From 98a23d30c88e3ced1ad68152ab6a50196bc685c6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 25 Aug 2025 14:50:09 +0800 Subject: [PATCH 1/2] fix lookup keccak rotation to use max 16 limb --- Makefile.toml | 2 + ceno_zkvm/src/precompiles/lookup_keccakf.rs | 22 +++- gkr_iop/src/circuit_builder.rs | 116 ++++++++++++++------ 3 files changed, 101 insertions(+), 39 deletions(-) diff --git a/Makefile.toml b/Makefile.toml index 03e6d4ee4..4776fa231 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -12,6 +12,7 @@ args = [ "--workspace", ] command = "cargo" +env = { RUST_MIN_STACK = "33554432" } workspace = false [tasks.tests_v2] @@ -27,6 +28,7 @@ args = [ "u16limb_circuit", ] command = "cargo" +env = { RUST_MIN_STACK = "33554432" } workspace = false diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 7a7d62fda..c8b672cc0 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -130,11 +130,11 @@ pub struct KeccakFixedCols { pub struct KeccakWitCols { pub input8: [T; 200], pub c_aux: [T; 200], - pub c_temp: [T; 30], + pub c_temp: [T; 40], pub c_rot: [T; 40], pub d: [T; 40], pub theta_output: [T; 200], - pub rotation_witness: [T; 146], + pub rotation_witness: [T; 196], pub rhopi_output: [T; 200], pub nonlinear: [T; 200], pub chi_output: [T; 8], @@ -312,7 +312,7 @@ impl ProtocolBuilder for KeccakLayout { // documentation of `constrain_left_rotation64`. Here c_temp is the split // witness for a 1-rotation. - let c_temp: ArrayView = ArrayView::from_shape((5, 6), c_temp).unwrap(); + let c_temp: ArrayView = ArrayView::from_shape((5, 8), c_temp).unwrap(); let c_rot: ArrayView = ArrayView::from_shape((5, 8), c_rot).unwrap(); let (sizes, _) = rotation_split(1); @@ -405,6 +405,7 @@ impl ProtocolBuilder for KeccakLayout { )?; } } + assert!(rotation_witness.next().is_none()); let mut chi_output = chi_output.to_vec(); chi_output.extend(iota_output[8..].to_vec()); @@ -785,12 +786,12 @@ where c8[x] = conv64to8(c64[x]); } - let mut c_temp = [[0u64; 6]; 5]; + let mut c_temp = [[0u64; 8]; 5]; for i in 0..5 { let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) - .convert(vec![16, 15, 1, 16, 15, 1]) + .convert(vec![15, 1, 15, 1, 15, 1, 15, 1]) .values(); - for (j, size) in [16, 15, 1, 16, 15, 1].iter().enumerate() { + for (j, size) in [1, 15, 1, 15, 1, 15, 1, 15].iter().enumerate() { lk_multiplicity.assert_ux_in_u16(*size, rep[j]); } c_temp[i] = rep.try_into().unwrap(); @@ -837,10 +838,19 @@ where if *size != 32 { lk_multiplicity.assert_ux_in_u16(*size, rep[j]); } + match *size { + 32 | 1 => (), + 16 => lk_multiplicity.assert_ux::<16>(rep[j]), + 14 => lk_multiplicity.assert_ux::<14>(rep[j]), + 8 => lk_multiplicity.assert_ux::<8>(rep[j]), + 5 => lk_multiplicity.assert_ux::<5>(rep[j]), + _ => lk_multiplicity.assert_ux_in_u16(*size, rep[j]), + } } rotation_witness.extend(rep); } } + assert_eq!(rotation_witness.len(), rotation_witness_witin.len()); // Rho and Pi steps let mut rhopi_output64 = [[0u64; 5]; 5]; diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 7fd7cf1b9..c0d284348 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -3,7 +3,7 @@ use multilinear_extensions::{ Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn, WitnessId, rlc_chip_record, }; use serde::de::DeserializeOwned; -use std::{cmp::Ordering, collections::HashMap, iter::once, marker::PhantomData}; +use std::{collections::HashMap, iter::once, marker::PhantomData}; use ff_ext::ExtensionField; @@ -782,6 +782,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { 14 => self.assert_u14(name_fn, expr), 8 => self.assert_byte(name_fn, expr), 5 => self.assert_u5(name_fn, expr), + 1 => self.assert_bit(name_fn, expr), c => panic!("Unsupported bit range {c}"), } } @@ -1058,8 +1059,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { // Lookup ranges for (i, (size, elem)) in split_rep.iter().enumerate() { - if *size != 32 { - self.assert_ux_in_u16(|| format!("{}_{}", name().into(), i), *size, elem.clone())?; + match *size { + 32 => (), + 16 => { + self.assert_ux::<_, _, 16>(|| format!("{}_{}", name().into(), i), elem.clone())? + } + 14 => { + self.assert_ux::<_, _, 14>(|| format!("{}_{}", name().into(), i), elem.clone())? + } + 8 => { + self.assert_ux::<_, _, 8>(|| format!("{}_{}", name().into(), i), elem.clone())? + } + 5 => { + self.assert_ux::<_, _, 5>(|| format!("{}_{}", name().into(), i), elem.clone())? + } + 1 => self.assert_bit(|| format!("{}_{}", name().into(), i), elem.clone())?, + _ => self.assert_ux_in_u16( + || format!("{}_{}", name().into(), i), + *size, + elem.clone(), + )?, } } @@ -1073,25 +1092,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { let mut rep_x = rep_x.to_owned(); rep_x.rotate_right(chunks_rotation); - for i in 0..2 { - // The respective 4 elements in the byte representation - let lhs = rep8[4 * i..4 * (i + 1)] + // 64 bits represent in 4 limb, each with 16 bits + let num_limbs = 4; + let mut rep_x_iter = rep_x.iter().cloned(); + for limb_i in 0..num_limbs { + let lhs = rep8[2 * limb_i..2 * (limb_i + 1)] .iter() .map(|wit| (8, wit.expr())) .collect_vec(); - let cnt = rep_x.len() / 2; - let rhs = &rep_x[cnt * i..cnt * (i + 1)]; + let rhs_limbs = take_til_threshold(&mut rep_x_iter, 16, &|limb| limb.0).unwrap(); - assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); + assert_eq!(rhs_limbs.iter().map(|e| e.0).sum::(), 16); - self.require_reps_equal::<32, _, _>( + self.require_reps_equal::<16, _, _>( ||format!( - "rotation internal {}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", + "rotation internal {}, round {limb_i}, rot: {chunks_rotation}, delta: {delta}, {:?}", name().into(), sizes ), &lhs, - rhs, + &rhs_limbs, )?; } Ok(()) @@ -1113,34 +1133,64 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { } } +/// take items from an iterator until the accumulated "weight" (measured by `f`) +/// reaches exactly `threshold`. +/// +/// - `iter`: a mutable iterator to consume items from +/// - `threshold`: the sum of weights at which to stop and return the group +/// - `f`: closure that extracts the "weight" (e.g., bit length) from each item +/// +/// returns: +/// - `Some(Vec)` containing the next group of items whose weights sum to `threshold` +/// - `None` if the iterator is exhausted and no items remain +/// +/// panics if the sum of weights ever exceeds `threshold`. +pub fn take_til_threshold(iter: &mut I, threshold: usize, f: &F) -> Option> +where + I: Iterator, + F: Fn(&T) -> usize, +{ + let mut group = Vec::new(); + let mut sum = 0; + + for x in iter.by_ref() { + sum += f(&x); + group.push(x); + + if sum == threshold { + return Some(group); + } else if sum > threshold { + panic!("sum exceeded threshold!"); + } + } + + if group.is_empty() { + None + } else { + Some(group) // leftover if input not perfectly divisible + } +} + /// Compute an adequate split of 64-bits into chunks for performing a rotation /// by `delta`. The first element of the return value is the vec of chunk sizes. /// The second one is the length of its suffix that needs to be rotated pub fn rotation_split(delta: usize) -> (Vec, usize) { - let delta = delta % 64; - if delta == 0 { - return (vec![32, 32], 0); - } - - // This split meets all requirements except for <= 16 sizes - let split32 = match delta.cmp(&32) { - Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], - Ordering::Equal => vec![32, 32], - Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], - }; - - // Split off large chunks - let split16 = split32 - .into_iter() - .flat_map(|size| { - assert!(size < 32); - if size <= 16 { - vec![size] + return (vec![16, 16, 16, 16], 0); + } + + let remainder = delta % 16; + let split16 = std::iter::repeat_with(|| [16 - remainder, remainder]) + .flatten() + .scan(0, |sum, x| { + if *sum >= 64 { + None } else { - vec![16, size - 16] + *sum += x; + Some(x) } }) + .filter(|v| *v > 0) .collect_vec(); let mut sum = 0; @@ -1151,7 +1201,7 @@ pub fn rotation_split(delta: usize) -> (Vec, usize) { } } - panic!(); + panic!("delta {:?} split16 {:?}", remainder, split16); } pub fn expansion_expr( From 8f0a7e395c3fc804958dd1e1ecb3dfe65a41b6b2 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 25 Aug 2025 17:21:17 +0800 Subject: [PATCH 2/2] bug fix --- ceno_zkvm/src/precompiles/lookup_keccakf.rs | 6 ++---- gkr_iop/src/circuit_builder.rs | 3 +++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index c8b672cc0..79efe3b34 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -791,7 +791,7 @@ where let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) .convert(vec![15, 1, 15, 1, 15, 1, 15, 1]) .values(); - for (j, size) in [1, 15, 1, 15, 1, 15, 1, 15].iter().enumerate() { + for (j, size) in [15, 1, 15, 1, 15, 1, 15, 1].iter().enumerate() { lk_multiplicity.assert_ux_in_u16(*size, rep[j]); } c_temp[i] = rep.try_into().unwrap(); @@ -835,11 +835,9 @@ where .convert(sizes.clone()) .values(); for (j, size) in sizes.iter().enumerate() { - if *size != 32 { - lk_multiplicity.assert_ux_in_u16(*size, rep[j]); - } match *size { 32 | 1 => (), + 18 => lk_multiplicity.assert_ux::<18>(rep[j]), 16 => lk_multiplicity.assert_ux::<16>(rep[j]), 14 => lk_multiplicity.assert_ux::<14>(rep[j]), 8 => lk_multiplicity.assert_ux::<8>(rep[j]), diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index c0d284348..361b75d11 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -1061,6 +1061,9 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { for (i, (size, elem)) in split_rep.iter().enumerate() { match *size { 32 => (), + 18 => { + self.assert_ux::<_, _, 18>(|| format!("{}_{}", name().into(), i), elem.clone())? + } 16 => { self.assert_ux::<_, _, 16>(|| format!("{}_{}", name().into(), i), elem.clone())? }