-
Notifications
You must be signed in to change notification settings - Fork 38
fix lookup keccak rotation to use max 16 limb #1034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,11 +130,11 @@ pub struct KeccakFixedCols<T> { | |
pub struct KeccakWitCols<T> { | ||
pub input8: [T; 200], | ||
pub c_aux: [T; 200], | ||
pub c_temp: [T; 30], | ||
pub c_temp: [T; 40], | ||
pub c_rot: [T; 40], | ||
pub d: [T; 40], | ||
pub theta_output: [T; 200], | ||
pub rotation_witness: [T; 146], | ||
pub rotation_witness: [T; 196], | ||
pub rhopi_output: [T; 200], | ||
pub nonlinear: [T; 200], | ||
pub chi_output: [T; 8], | ||
|
@@ -312,7 +312,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> { | |
// documentation of `constrain_left_rotation64`. Here c_temp is the split | ||
// witness for a 1-rotation. | ||
|
||
let c_temp: ArrayView<WitIn, Ix2> = ArrayView::from_shape((5, 6), c_temp).unwrap(); | ||
let c_temp: ArrayView<WitIn, Ix2> = ArrayView::from_shape((5, 8), c_temp).unwrap(); | ||
let c_rot: ArrayView<WitIn, Ix2> = ArrayView::from_shape((5, 8), c_rot).unwrap(); | ||
|
||
let (sizes, _) = rotation_split(1); | ||
|
@@ -405,6 +405,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> { | |
)?; | ||
} | ||
} | ||
assert!(rotation_witness.next().is_none()); | ||
|
||
let mut chi_output = chi_output.to_vec(); | ||
chi_output.extend(iota_output[8..].to_vec()); | ||
|
@@ -785,12 +786,12 @@ where | |
c8[x] = conv64to8(c64[x]); | ||
} | ||
|
||
let mut c_temp = [[0u64; 6]; 5]; | ||
let mut c_temp = [[0u64; 8]; 5]; | ||
for i in 0..5 { | ||
let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) | ||
.convert(vec![16, 15, 1, 16, 15, 1]) | ||
.convert(vec![15, 1, 15, 1, 15, 1, 15, 1]) | ||
.values(); | ||
for (j, size) in [16, 15, 1, 16, 15, 1].iter().enumerate() { | ||
for (j, size) in [15, 1, 15, 1, 15, 1, 15, 1].iter().enumerate() { | ||
lk_multiplicity.assert_ux_in_u16(*size, rep[j]); | ||
} | ||
c_temp[i] = rep.try_into().unwrap(); | ||
|
@@ -834,13 +835,20 @@ where | |
.convert(sizes.clone()) | ||
.values(); | ||
for (j, size) in sizes.iter().enumerate() { | ||
if *size != 32 { | ||
lk_multiplicity.assert_ux_in_u16(*size, rep[j]); | ||
match *size { | ||
32 | 1 => (), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not checking it when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because size 1 means bit check, and we compiled it into |
||
18 => lk_multiplicity.assert_ux::<18>(rep[j]), | ||
16 => lk_multiplicity.assert_ux::<16>(rep[j]), | ||
14 => lk_multiplicity.assert_ux::<14>(rep[j]), | ||
8 => lk_multiplicity.assert_ux::<8>(rep[j]), | ||
5 => lk_multiplicity.assert_ux::<5>(rep[j]), | ||
_ => lk_multiplicity.assert_ux_in_u16(*size, rep[j]), | ||
} | ||
} | ||
rotation_witness.extend(rep); | ||
} | ||
} | ||
assert_eq!(rotation_witness.len(), rotation_witness_witin.len()); | ||
|
||
// Rho and Pi steps | ||
let mut rhopi_output64 = [[0u64; 5]; 5]; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ use multilinear_extensions::{ | |
Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn, WitnessId, rlc_chip_record, | ||
}; | ||
use serde::de::DeserializeOwned; | ||
use std::{cmp::Ordering, collections::HashMap, iter::once, marker::PhantomData}; | ||
use std::{collections::HashMap, iter::once, marker::PhantomData}; | ||
|
||
use ff_ext::ExtensionField; | ||
|
||
|
@@ -782,6 +782,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { | |
14 => self.assert_u14(name_fn, expr), | ||
8 => self.assert_byte(name_fn, expr), | ||
5 => self.assert_u5(name_fn, expr), | ||
1 => self.assert_bit(name_fn, expr), | ||
c => panic!("Unsupported bit range {c}"), | ||
} | ||
} | ||
|
@@ -1058,8 +1059,29 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { | |
|
||
// Lookup ranges | ||
for (i, (size, elem)) in split_rep.iter().enumerate() { | ||
if *size != 32 { | ||
self.assert_ux_in_u16(|| format!("{}_{}", name().into(), i), *size, elem.clone())?; | ||
match *size { | ||
32 => (), | ||
18 => { | ||
self.assert_ux::<_, _, 18>(|| format!("{}_{}", name().into(), i), elem.clone())? | ||
} | ||
16 => { | ||
self.assert_ux::<_, _, 16>(|| format!("{}_{}", name().into(), i), elem.clone())? | ||
} | ||
14 => { | ||
self.assert_ux::<_, _, 14>(|| format!("{}_{}", name().into(), i), elem.clone())? | ||
} | ||
8 => { | ||
self.assert_ux::<_, _, 8>(|| format!("{}_{}", name().into(), i), elem.clone())? | ||
} | ||
5 => { | ||
self.assert_ux::<_, _, 5>(|| format!("{}_{}", name().into(), i), elem.clone())? | ||
} | ||
1 => self.assert_bit(|| format!("{}_{}", name().into(), i), elem.clone())?, | ||
_ => self.assert_ux_in_u16( | ||
|| format!("{}_{}", name().into(), i), | ||
*size, | ||
elem.clone(), | ||
)?, | ||
Comment on lines
+1062
to
+1084
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we put this in a helper function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes will do it in next PR #1036 to fix precompile e2e sound |
||
} | ||
} | ||
|
||
|
@@ -1073,25 +1095,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { | |
let mut rep_x = rep_x.to_owned(); | ||
rep_x.rotate_right(chunks_rotation); | ||
|
||
for i in 0..2 { | ||
// The respective 4 elements in the byte representation | ||
let lhs = rep8[4 * i..4 * (i + 1)] | ||
// 64 bits represent in 4 limb, each with 16 bits | ||
let num_limbs = 4; | ||
let mut rep_x_iter = rep_x.iter().cloned(); | ||
for limb_i in 0..num_limbs { | ||
let lhs = rep8[2 * limb_i..2 * (limb_i + 1)] | ||
.iter() | ||
.map(|wit| (8, wit.expr())) | ||
.collect_vec(); | ||
let cnt = rep_x.len() / 2; | ||
let rhs = &rep_x[cnt * i..cnt * (i + 1)]; | ||
let rhs_limbs = take_til_threshold(&mut rep_x_iter, 16, &|limb| limb.0).unwrap(); | ||
|
||
assert_eq!(rhs.iter().map(|e| e.0).sum::<usize>(), 32); | ||
assert_eq!(rhs_limbs.iter().map(|e| e.0).sum::<usize>(), 16); | ||
|
||
self.require_reps_equal::<32, _, _>( | ||
self.require_reps_equal::<16, _, _>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. key highlights: constrain on 16 bits limb individually instead of 32 bits (limb_16 + (1 << 16( * limb_16) because later are larger than some finite field characteristic, e.g. babybear |
||
||format!( | ||
"rotation internal {}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", | ||
"rotation internal {}, round {limb_i}, rot: {chunks_rotation}, delta: {delta}, {:?}", | ||
name().into(), | ||
sizes | ||
), | ||
&lhs, | ||
rhs, | ||
&rhs_limbs, | ||
)?; | ||
} | ||
Ok(()) | ||
|
@@ -1113,34 +1136,64 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { | |
} | ||
} | ||
|
||
/// take items from an iterator until the accumulated "weight" (measured by `f`) | ||
/// reaches exactly `threshold`. | ||
/// | ||
/// - `iter`: a mutable iterator to consume items from | ||
/// - `threshold`: the sum of weights at which to stop and return the group | ||
/// - `f`: closure that extracts the "weight" (e.g., bit length) from each item | ||
/// | ||
/// returns: | ||
/// - `Some(Vec<T>)` containing the next group of items whose weights sum to `threshold` | ||
/// - `None` if the iterator is exhausted and no items remain | ||
/// | ||
/// panics if the sum of weights ever exceeds `threshold`. | ||
pub fn take_til_threshold<I, T, F>(iter: &mut I, threshold: usize, f: &F) -> Option<Vec<T>> | ||
where | ||
I: Iterator<Item = T>, | ||
F: Fn(&T) -> usize, | ||
{ | ||
let mut group = Vec::new(); | ||
let mut sum = 0; | ||
|
||
for x in iter.by_ref() { | ||
sum += f(&x); | ||
group.push(x); | ||
|
||
if sum == threshold { | ||
return Some(group); | ||
} else if sum > threshold { | ||
panic!("sum exceeded threshold!"); | ||
} | ||
} | ||
|
||
if group.is_empty() { | ||
None | ||
} else { | ||
Some(group) // leftover if input not perfectly divisible | ||
} | ||
} | ||
|
||
/// Compute an adequate split of 64-bits into chunks for performing a rotation | ||
/// by `delta`. The first element of the return value is the vec of chunk sizes. | ||
/// The second one is the length of its suffix that needs to be rotated | ||
pub fn rotation_split(delta: usize) -> (Vec<usize>, usize) { | ||
let delta = delta % 64; | ||
|
||
if delta == 0 { | ||
return (vec![32, 32], 0); | ||
} | ||
|
||
// This split meets all requirements except for <= 16 sizes | ||
let split32 = match delta.cmp(&32) { | ||
Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], | ||
Ordering::Equal => vec![32, 32], | ||
Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], | ||
}; | ||
|
||
// Split off large chunks | ||
let split16 = split32 | ||
.into_iter() | ||
.flat_map(|size| { | ||
assert!(size < 32); | ||
if size <= 16 { | ||
vec![size] | ||
return (vec![16, 16, 16, 16], 0); | ||
} | ||
|
||
let remainder = delta % 16; | ||
let split16 = std::iter::repeat_with(|| [16 - remainder, remainder]) | ||
.flatten() | ||
.scan(0, |sum, x| { | ||
if *sum >= 64 { | ||
None | ||
} else { | ||
vec![16, size - 16] | ||
*sum += x; | ||
Some(x) | ||
} | ||
}) | ||
.filter(|v| *v > 0) | ||
.collect_vec(); | ||
|
||
let mut sum = 0; | ||
|
@@ -1151,7 +1204,7 @@ pub fn rotation_split(delta: usize) -> (Vec<usize>, usize) { | |
} | ||
} | ||
|
||
panic!(); | ||
panic!("delta {:?} split16 {:?}", remainder, split16); | ||
} | ||
|
||
pub fn expansion_expr<E: ExtensionField, const SIZE: usize>( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lookup_keccak
involving some deep recursion when build the circuit. I think it's a existing issue, just the new change hit the threshold unluckily. With that, we temporarily increase RUST_MIN_STACK to make tests pass. It doesn't break functional correctness.