Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ args = [
"--workspace",
]
command = "cargo"
env = { RUST_MIN_STACK = "33554432" }
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lookup_keccak involves some deep recursion when building the circuit. I think it's an existing issue; the new change just happened to hit the threshold. For now, we temporarily increase RUST_MIN_STACK to make the tests pass. It doesn't affect functional correctness.

workspace = false

[tasks.tests_v2]
Expand All @@ -27,6 +28,7 @@ args = [
"u16limb_circuit",
]
command = "cargo"
env = { RUST_MIN_STACK = "33554432" }
workspace = false


Expand Down
24 changes: 16 additions & 8 deletions ceno_zkvm/src/precompiles/lookup_keccakf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ pub struct KeccakFixedCols<T> {
pub struct KeccakWitCols<T> {
pub input8: [T; 200],
pub c_aux: [T; 200],
pub c_temp: [T; 30],
pub c_temp: [T; 40],
pub c_rot: [T; 40],
pub d: [T; 40],
pub theta_output: [T; 200],
pub rotation_witness: [T; 146],
pub rotation_witness: [T; 196],
pub rhopi_output: [T; 200],
pub nonlinear: [T; 200],
pub chi_output: [T; 8],
Expand Down Expand Up @@ -312,7 +312,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
// documentation of `constrain_left_rotation64`. Here c_temp is the split
// witness for a 1-rotation.

let c_temp: ArrayView<WitIn, Ix2> = ArrayView::from_shape((5, 6), c_temp).unwrap();
let c_temp: ArrayView<WitIn, Ix2> = ArrayView::from_shape((5, 8), c_temp).unwrap();
let c_rot: ArrayView<WitIn, Ix2> = ArrayView::from_shape((5, 8), c_rot).unwrap();

let (sizes, _) = rotation_split(1);
Expand Down Expand Up @@ -405,6 +405,7 @@ impl<E: ExtensionField> ProtocolBuilder<E> for KeccakLayout<E> {
)?;
}
}
assert!(rotation_witness.next().is_none());

let mut chi_output = chi_output.to_vec();
chi_output.extend(iota_output[8..].to_vec());
Expand Down Expand Up @@ -785,12 +786,12 @@ where
c8[x] = conv64to8(c64[x]);
}

let mut c_temp = [[0u64; 6]; 5];
let mut c_temp = [[0u64; 8]; 5];
for i in 0..5 {
let rep = MaskRepresentation::new(vec![(64, c64[i]).into()])
.convert(vec![16, 15, 1, 16, 15, 1])
.convert(vec![15, 1, 15, 1, 15, 1, 15, 1])
.values();
for (j, size) in [16, 15, 1, 16, 15, 1].iter().enumerate() {
for (j, size) in [15, 1, 15, 1, 15, 1, 15, 1].iter().enumerate() {
lk_multiplicity.assert_ux_in_u16(*size, rep[j]);
}
c_temp[i] = rep.try_into().unwrap();
Expand Down Expand Up @@ -834,13 +835,20 @@ where
.convert(sizes.clone())
.values();
for (j, size) in sizes.iter().enumerate() {
if *size != 32 {
lk_multiplicity.assert_ux_in_u16(*size, rep[j]);
match *size {
32 | 1 => (),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not check it when size is 1?

Copy link
Collaborator Author

@hero78119 hero78119 Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because size 1 means a bit check, and we have already compiled it into a zero check on (1 - expr) * expr. So for size 1 we can skip the multiplicity += 1.

18 => lk_multiplicity.assert_ux::<18>(rep[j]),
16 => lk_multiplicity.assert_ux::<16>(rep[j]),
14 => lk_multiplicity.assert_ux::<14>(rep[j]),
8 => lk_multiplicity.assert_ux::<8>(rep[j]),
5 => lk_multiplicity.assert_ux::<5>(rep[j]),
_ => lk_multiplicity.assert_ux_in_u16(*size, rep[j]),
}
}
rotation_witness.extend(rep);
}
}
assert_eq!(rotation_witness.len(), rotation_witness_witin.len());

// Rho and Pi steps
let mut rhopi_output64 = [[0u64; 5]; 5];
Expand Down
119 changes: 86 additions & 33 deletions gkr_iop/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use multilinear_extensions::{
Expression, Fixed, Instance, StructuralWitIn, ToExpr, WitIn, WitnessId, rlc_chip_record,
};
use serde::de::DeserializeOwned;
use std::{cmp::Ordering, collections::HashMap, iter::once, marker::PhantomData};
use std::{collections::HashMap, iter::once, marker::PhantomData};

use ff_ext::ExtensionField;

Expand Down Expand Up @@ -782,6 +782,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
14 => self.assert_u14(name_fn, expr),
8 => self.assert_byte(name_fn, expr),
5 => self.assert_u5(name_fn, expr),
1 => self.assert_bit(name_fn, expr),
c => panic!("Unsupported bit range {c}"),
}
}
Expand Down Expand Up @@ -1058,8 +1059,29 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {

// Lookup ranges
for (i, (size, elem)) in split_rep.iter().enumerate() {
if *size != 32 {
self.assert_ux_in_u16(|| format!("{}_{}", name().into(), i), *size, elem.clone())?;
match *size {
32 => (),
18 => {
self.assert_ux::<_, _, 18>(|| format!("{}_{}", name().into(), i), elem.clone())?
}
16 => {
self.assert_ux::<_, _, 16>(|| format!("{}_{}", name().into(), i), elem.clone())?
}
14 => {
self.assert_ux::<_, _, 14>(|| format!("{}_{}", name().into(), i), elem.clone())?
}
8 => {
self.assert_ux::<_, _, 8>(|| format!("{}_{}", name().into(), i), elem.clone())?
}
5 => {
self.assert_ux::<_, _, 5>(|| format!("{}_{}", name().into(), i), elem.clone())?
}
1 => self.assert_bit(|| format!("{}_{}", name().into(), i), elem.clone())?,
_ => self.assert_ux_in_u16(
|| format!("{}_{}", name().into(), i),
*size,
elem.clone(),
)?,
Comment on lines +1062 to +1084
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we put this in a helper function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will do it in the next PR, #1036, which fixes precompile e2e soundness.

}
}

Expand All @@ -1073,25 +1095,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
let mut rep_x = rep_x.to_owned();
rep_x.rotate_right(chunks_rotation);

for i in 0..2 {
// The respective 4 elements in the byte representation
let lhs = rep8[4 * i..4 * (i + 1)]
// 64 bits represent in 4 limb, each with 16 bits
let num_limbs = 4;
let mut rep_x_iter = rep_x.iter().cloned();
for limb_i in 0..num_limbs {
let lhs = rep8[2 * limb_i..2 * (limb_i + 1)]
.iter()
.map(|wit| (8, wit.expr()))
.collect_vec();
let cnt = rep_x.len() / 2;
let rhs = &rep_x[cnt * i..cnt * (i + 1)];
let rhs_limbs = take_til_threshold(&mut rep_x_iter, 16, &|limb| limb.0).unwrap();

assert_eq!(rhs.iter().map(|e| e.0).sum::<usize>(), 32);
assert_eq!(rhs_limbs.iter().map(|e| e.0).sum::<usize>(), 16);

self.require_reps_equal::<32, _, _>(
self.require_reps_equal::<16, _, _>(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Key highlight: constrain each 16-bit limb individually instead of a 32-bit value (limb_16 + (1 << 16) * limb_16), because the latter can be larger than the characteristic of some finite fields, e.g. BabyBear.

||format!(
"rotation internal {}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}",
"rotation internal {}, round {limb_i}, rot: {chunks_rotation}, delta: {delta}, {:?}",
name().into(),
sizes
),
&lhs,
rhs,
&rhs_limbs,
)?;
}
Ok(())
Expand All @@ -1113,34 +1136,64 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
}
}

/// Pulls items off `iter` until their accumulated weight equals `threshold`.
///
/// The iterator is only borrowed, so it can be passed to this function again
/// to fetch the following group.
///
/// - `iter`: iterator that items are consumed from
/// - `threshold`: accumulated weight at which the current group is complete
/// - `f`: returns the weight (e.g. a bit length) of a single item
///
/// # Returns
///
/// - `Some(group)` as soon as the weights of the consumed items sum to
///   exactly `threshold`
/// - `Some(group)` holding the remaining items if the iterator runs out
///   before `threshold` is reached; in that case the weights sum to less than
///   `threshold`, so callers that need a full group must check the sum
/// - `None` if the iterator yields no item at all
///
/// # Panics
///
/// Panics if the accumulated weight steps past `threshold` without ever
/// being equal to it.
pub fn take_til_threshold<I, T, F>(iter: &mut I, threshold: usize, f: &F) -> Option<Vec<T>>
where
    I: Iterator<Item = T>,
    F: Fn(&T) -> usize,
{
    let mut taken = Vec::new();
    let mut weight_so_far = 0;

    // Reborrow so the caller keeps ownership of the iterator between calls.
    for item in &mut *iter {
        weight_so_far += f(&item);
        taken.push(item);

        if weight_so_far > threshold {
            panic!("sum exceeded threshold!");
        }
        if weight_so_far == threshold {
            return Some(taken);
        }
    }

    // Iterator exhausted: hand back whatever was collected, if anything.
    (!taken.is_empty()).then_some(taken)
}

/// Compute an adequate split of 64-bits into chunks for performing a rotation
/// by `delta`. The first element of the return value is the vec of chunk sizes.
/// The second one is the length of its suffix that needs to be rotated
pub fn rotation_split(delta: usize) -> (Vec<usize>, usize) {
let delta = delta % 64;

if delta == 0 {
return (vec![32, 32], 0);
}

// This split meets all requirements except for <= 16 sizes
let split32 = match delta.cmp(&32) {
Ordering::Less => vec![32 - delta, delta, 32 - delta, delta],
Ordering::Equal => vec![32, 32],
Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32],
};

// Split off large chunks
let split16 = split32
.into_iter()
.flat_map(|size| {
assert!(size < 32);
if size <= 16 {
vec![size]
return (vec![16, 16, 16, 16], 0);
}

let remainder = delta % 16;
let split16 = std::iter::repeat_with(|| [16 - remainder, remainder])
.flatten()
.scan(0, |sum, x| {
if *sum >= 64 {
None
} else {
vec![16, size - 16]
*sum += x;
Some(x)
}
})
.filter(|v| *v > 0)
.collect_vec();

let mut sum = 0;
Expand All @@ -1151,7 +1204,7 @@ pub fn rotation_split(delta: usize) -> (Vec<usize>, usize) {
}
}

panic!();
panic!("delta {:?} split16 {:?}", remainder, split16);
}

pub fn expansion_expr<E: ExtensionField, const SIZE: usize>(
Expand Down