From 0f2edd8ac4346eba7fdf35d03971039ae9820ec9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 16:14:52 +0800 Subject: [PATCH] [Experiment] new crates subprotocols --- Cargo.lock | 28 ++ Cargo.toml | 1 + subprotocols/Cargo.toml | 30 ++ subprotocols/benches/expr_based_logup.rs | 144 +++++++ subprotocols/examples/zerocheck_logup.rs | 93 +++++ subprotocols/src/error.rs | 9 + subprotocols/src/expression.rs | 167 ++++++++ subprotocols/src/expression/evaluate.rs | 459 +++++++++++++++++++++ subprotocols/src/expression/macros.rs | 100 +++++ subprotocols/src/expression/op.rs | 81 ++++ subprotocols/src/lib.rs | 9 + subprotocols/src/points.rs | 75 ++++ subprotocols/src/sumcheck.rs | 454 +++++++++++++++++++++ subprotocols/src/test_utils.rs | 46 +++ subprotocols/src/utils.rs | 235 +++++++++++ subprotocols/src/zerocheck.rs | 495 +++++++++++++++++++++++ 16 files changed, 2426 insertions(+) create mode 100644 subprotocols/Cargo.toml create mode 100644 subprotocols/benches/expr_based_logup.rs create mode 100644 subprotocols/examples/zerocheck_logup.rs create mode 100644 subprotocols/src/error.rs create mode 100644 subprotocols/src/expression.rs create mode 100644 subprotocols/src/expression/evaluate.rs create mode 100644 subprotocols/src/expression/macros.rs create mode 100644 subprotocols/src/expression/op.rs create mode 100644 subprotocols/src/lib.rs create mode 100644 subprotocols/src/points.rs create mode 100644 subprotocols/src/sumcheck.rs create mode 100644 subprotocols/src/test_utils.rs create mode 100644 subprotocols/src/utils.rs create mode 100644 subprotocols/src/zerocheck.rs diff --git a/Cargo.lock b/Cargo.lock index 2f23ff4d6..112d8db29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -115,6 +115,16 @@ dependencies = [ "backtrace", ] +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "arrayref" version = "0.3.9" @@ -2752,6 +2762,24 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "subprotocols" +version = "0.1.0" +dependencies = [ + "ark-std", + "criterion", + "ff_ext", + "itertools 0.13.0", + "multilinear_extensions", + "p3-field", + "p3-goldilocks", + "rand", + "rayon", + "serde", + "thiserror 1.0.69", + "transcript", +] + [[package]] name = "substrate-bn" version = "0.6.0" diff --git a/Cargo.toml b/Cargo.toml index d90f59573..1e4bcc417 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "multilinear_extensions", "sumcheck_macro", "poseidon", + "subprotocols", "sumcheck", "transcript", "whir", diff --git a/subprotocols/Cargo.toml b/subprotocols/Cargo.toml new file mode 100644 index 000000000..b425bf380 --- /dev/null +++ b/subprotocols/Cargo.toml @@ -0,0 +1,30 @@ +[package] +categories.workspace = true +description = "Subprotocols" +edition.workspace = true +keywords.workspace = true +license.workspace = true +name = "subprotocols" +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +ark-std = { version = "0.5" } +ff_ext = { path = "../ff_ext" } +itertools.workspace = true +multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } +p3-field.workspace = true +rand.workspace = true +rayon.workspace = true +serde.workspace = true +thiserror = "1" +transcript = { path = "../transcript" } + +[dev-dependencies] +criterion.workspace = true +p3-goldilocks.workspace = true + +[[bench]] +harness = false +name = "expr_based_logup" diff --git a/subprotocols/benches/expr_based_logup.rs b/subprotocols/benches/expr_based_logup.rs new file mode 100644 index 000000000..7c1bbaef7 --- /dev/null +++ b/subprotocols/benches/expr_based_logup.rs @@ -0,0 +1,144 @@ +use std::{array, time::Duration}; + +use ark_std::test_rng; +use criterion::*; +use ff_ext::FromUniformBytes; +use itertools::Itertools; +use p3_field::extension::BinomialExtensionField; +use p3_goldilocks::Goldilocks; +use subprotocols::{ + expression::{Constant, Expression, Witness}, + sumcheck::SumcheckProverState, + test_utils::{random_point, random_poly}, + zerocheck::ZerocheckProverState, +}; +use transcript::BasicTranscript as Transcript; + +criterion_group!(benches, zerocheck_fn, sumcheck_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; +const NV: [usize; 2] = [25, 26]; + +fn sumcheck_fn(c: &mut Criterion) { + type E = BinomialExtensionField; + + for nv in NV { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + let mut rng = test_rng(); + // Initialize logup expression. + let eq = Expression::Wit(Witness::EqPoly(0)); + let beta = Expression::Const(Constant::Challenge(0)); + let [d0, d1, n0, n1] = + array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let expr = eq * (d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0)); + + // Randomly generate point and witness. + let point = random_point(&mut rng, nv); + + let d0 = random_poly(&mut rng, nv); + let d1 = random_poly(&mut rng, nv); + let n0 = random_poly(&mut rng, nv); + let n1 = random_poly(&mut rng, nv); + let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; + + let challenges = vec![E::random(&mut rng)]; + + let ext_mle_refs = + ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + + let mut prover_transcript = Transcript::new(b"test"); + let prover = SumcheckProverState::new( + expr, + &[&point], + ext_mle_refs, + vec![], + &challenges, + &mut prover_transcript, + ); + + let instant = std::time::Instant::now(); + let _ = black_box(prover.prove()); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }, + ); + + group.finish(); + } +} + +fn zerocheck_fn(c: &mut Criterion) { + type E = BinomialExtensionField; + + for nv in NV { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + group.bench_function( + BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Initialize logup expression. + let mut rng = test_rng(); + let beta = Expression::Const(Constant::Challenge(0)); + let [d0, d1, n0, n1] = + array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); + + // Randomly generate point and witness. + let point = random_point(&mut rng, nv); + + let d0 = random_poly(&mut rng, nv); + let d1 = random_poly(&mut rng, nv); + let n0 = random_poly(&mut rng, nv); + let n1 = random_poly(&mut rng, nv); + let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; + + let challenges = vec![E::random(&mut rng)]; + + let ext_mle_refs = + ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + + let mut prover_transcript = Transcript::new(b"test"); + let prover = ZerocheckProverState::new( + vec![expr], + &[&point], + ext_mle_refs, + vec![], + &challenges, + &mut prover_transcript, + ); + + let instant = std::time::Instant::now(); + let _ = black_box(prover.prove()); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }, + ); + + group.finish(); + } +} diff --git a/subprotocols/examples/zerocheck_logup.rs b/subprotocols/examples/zerocheck_logup.rs new file mode 100644 index 000000000..36c227b7e --- /dev/null +++ b/subprotocols/examples/zerocheck_logup.rs @@ -0,0 +1,93 @@ +use std::array; + +use ff_ext::{ExtensionField, FromUniformBytes}; +use itertools::{Itertools, izip}; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; +use p3_goldilocks::Goldilocks as F; +use rand::thread_rng; +use subprotocols::{ + expression::{Constant, Expression, Witness}, + sumcheck::{SumcheckProof, SumcheckProverOutput}, + test_utils::{random_point, random_poly}, + utils::eq_vecs, + zerocheck::{ZerocheckProverState, ZerocheckVerifierState}, +}; +use transcript::BasicTranscript; + +type E = BinomialExtensionField; + +fn run_prover( + point: &[E], + ext_mles: &mut [Vec], + expr: Expression, + challenges: Vec, +) -> SumcheckProof { + let timer = std::time::Instant::now(); + let ext_mle_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + + let mut prover_transcript = BasicTranscript::new(b"test"); + let prover = ZerocheckProverState::new( + vec![expr], + &[point], + ext_mle_refs, + vec![], + &challenges, + &mut prover_transcript, + ); + + let SumcheckProverOutput { proof, .. } = prover.prove(); + println!("Proving time: {:?}", timer.elapsed()); + proof +} + +fn run_verifier( + proof: SumcheckProof, + ans: &E, + point: &[E], + expr: Expression, + challenges: Vec, +) { + let mut verifier_transcript = BasicTranscript::new(b"test"); + let verifier = ZerocheckVerifierState::new( + vec![*ans], + vec![expr], + vec![], + vec![point], + proof, + &challenges, + &mut verifier_transcript, + ); + + verifier.verify().expect("verification failed"); +} + +fn main() { + let num_vars = 20; + let mut rng = thread_rng(); + + // Initialize logup expression. + let beta = Expression::Const(Constant::Challenge(0)); + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); + + // Randomly generate point and witness. + let point = random_point(&mut rng, num_vars); + + let d0 = random_poly(&mut rng, num_vars); + let d1 = random_poly(&mut rng, num_vars); + let n0 = random_poly(&mut rng, num_vars); + let n1 = random_poly(&mut rng, num_vars); + let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; + + let challenges = vec![E::random(&mut rng)]; + + let proof = run_prover(&point, &mut ext_mles, expr.clone(), challenges.clone()); + + let eqs = eq_vecs([point.as_slice()].into_iter(), &[E::ONE]); + + let ans: E = izip!(&eqs[0], &d0, &d1, &n0, &n1) + .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) + .sum(); + + run_verifier(proof, &ans, &point, expr, challenges); +} diff --git a/subprotocols/src/error.rs b/subprotocols/src/error.rs new file mode 100644 index 000000000..ca8eddefe --- /dev/null +++ b/subprotocols/src/error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; + +use crate::expression::Expression; + +#[derive(Clone, Debug, Error)] +pub enum VerifierError { + #[error("Claim not match: expr: {0:?}\n (expr name: {3:?})\n expect: {1:?}, got: {2:?}")] + ClaimNotMatch(Expression, E, E, String), +} diff --git a/subprotocols/src/expression.rs b/subprotocols/src/expression.rs new file mode 100644 index 000000000..60fd882cb --- /dev/null +++ b/subprotocols/src/expression.rs @@ -0,0 +1,167 @@ +use std::sync::Arc; + +use ff_ext::ExtensionField; +use serde::{Deserialize, Serialize}; + +mod evaluate; +mod op; + +mod macros; + +pub type Point = Arc>; + +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Constant { + /// Base field + Base(i64), + /// Challenge + Challenge(usize), + /// Sum + Sum(Box, Box), + /// Product + Product(Box, Box), + /// Neg + Neg(Box), + /// Pow + Pow(Box, usize), +} + +impl Default for Constant { + fn default() -> Self { + Constant::Base(0) + } +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Witness { + /// Base field polynomial (index). + BasePoly(usize), + /// Extension field polynomial (index). + ExtPoly(usize), + /// Eq polynomial + EqPoly(usize), +} + +impl Default for Witness { + fn default() -> Self { + Witness::BasePoly(0) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Expression { + /// Constant + Const(Constant), + /// Witness. + Wit(Witness), + /// This is the sum of two expressions, with `degree`. + Sum(Box, Box, usize), + /// This is the product of two expressions, with `degree`. + Product(Box, Box, usize), + /// Neg, with `degree`. + Neg(Box, usize), + /// Pow, with `D` and `degree`. + Pow(Box, usize, usize), +} + +impl std::fmt::Debug for Expression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Expression::Const(c) => write!(f, "{:?}", c), + Expression::Wit(w) => write!(f, "{:?}", w), + Expression::Sum(a, b, _) => write!(f, "({:?} + {:?})", a, b), + Expression::Product(a, b, _) => write!(f, "({:?} * {:?})", a, b), + Expression::Neg(a, _) => write!(f, "(-{:?})", a), + Expression::Pow(a, n, _) => write!(f, "({:?})^({})", a, n), + } + } +} + +impl std::fmt::Debug for Witness { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Witness::BasePoly(i) => write!(f, "BP[{}]", i), + Witness::ExtPoly(i) => write!(f, "EP[{}]", i), + Witness::EqPoly(i) => write!(f, "EQ[{}]", i), + } + } +} + +/// Vector of univariate polys. +#[derive(Clone, Debug)] +enum UniPolyVectorType { + Base(Vec>), + Ext(Vec>), +} + +/// Vector of field type. +#[derive(Clone, PartialEq, Eq)] +pub enum VectorType { + Base(Vec), + Ext(Vec), +} + +impl std::fmt::Debug for VectorType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorType::Base(v) => { + let mut v = v.iter(); + write!(f, "[")?; + if let Some(e) = v.next() { + write!(f, "{:?}", e)?; + } + for _ in 0..2 { + if let Some(e) = v.next() { + write!(f, ", {:?}", e)?; + } else { + break; + } + } + if v.next().is_some() { + write!(f, ", ...]")?; + } else { + write!(f, "]")?; + }; + Ok(()) + } + VectorType::Ext(v) => { + let mut v = v.iter(); + write!(f, "[")?; + if let Some(e) = v.next() { + write!(f, "{:?}", e)?; + } + for _ in 0..2 { + if let Some(e) = v.next() { + write!(f, ", {:?}", e)?; + } else { + break; + } + } + if v.next().is_some() { + write!(f, ", ...]")?; + } else { + write!(f, "]")?; + }; + Ok(()) + } + } + } +} + +#[derive(Clone, Debug)] +enum ScalarType { + Base(E::BaseField), + Ext(E), +} + +impl From for Expression { + fn from(w: Witness) -> Self { + Expression::Wit(w) + } +} + +impl From for Expression { + fn from(c: Constant) -> Self { + Expression::Const(c) + } +} diff --git a/subprotocols/src/expression/evaluate.rs b/subprotocols/src/expression/evaluate.rs new file mode 100644 index 000000000..cb17c9e34 --- /dev/null +++ b/subprotocols/src/expression/evaluate.rs @@ -0,0 +1,459 @@ +use ff_ext::ExtensionField; +use itertools::{Itertools, zip_eq}; +use multilinear_extensions::virtual_poly::eq_eval; +use p3_field::{Field, PrimeCharacteristicRing}; + +use crate::{op_by_type, utils::i64_to_field}; + +use super::{Constant, Expression, ScalarType, UniPolyVectorType, VectorType, Witness}; + +impl Expression { + pub fn degree(&self) -> usize { + match self { + Expression::Const(_) => 0, + Expression::Wit(_) => 1, + Expression::Sum(_, _, degree) => *degree, + Expression::Product(_, _, degree) => *degree, + Expression::Neg(_, degree) => *degree, + Expression::Pow(_, _, degree) => *degree, + } + } + + pub fn is_ext(&self) -> bool { + match self { + Expression::Const(c) => c.is_ext(), + Expression::Wit(w) => w.is_ext(), + Expression::Sum(e0, e1, _) | Expression::Product(e0, e1, _) => { + e0.is_ext() || e1.is_ext() + } + Expression::Neg(e, _) => e.is_ext(), + Expression::Pow(e, d, _) => { + if *d > 0 { + e.is_ext() + } else { + false + } + } + } + } + + pub fn evaluate( + &self, + ext_mle_evals: &[E], + base_mle_evals: &[E], + out_points: &[&[E]], + in_point: &[E], + challenges: &[E], + ) -> E { + match self { + Expression::Const(constant) => constant.evaluate(challenges), + Expression::Wit(w) => w.evaluate(base_mle_evals, ext_mle_evals, out_points, in_point), + Expression::Sum(e0, e1, _) => { + e0.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + e1.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + } + Expression::Product(e0, e1, _) => { + e0.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) * e1.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + } + Expression::Neg(e, _) => -e.evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ), + Expression::Pow(e, d, _) => e + .evaluate( + ext_mle_evals, + base_mle_evals, + out_points, + in_point, + challenges, + ) + .exp_u64(*d as u64), + } + } + + pub fn calc( + &self, + ext: &[Vec], + base: &[Vec], + eqs: &[Vec], + challenges: &[E], + ) -> VectorType { + assert!(!(ext.is_empty() && base.is_empty())); + let size = if !ext.is_empty() { + ext[0].len() + } else { + base[0].len() + }; + match self { + Expression::Const(constant) => { + VectorType::Ext(vec![constant.evaluate(challenges); size]) + } + Expression::Wit(w) => match w { + Witness::BasePoly(index) => VectorType::Base(base[*index].clone()), + Witness::ExtPoly(index) => VectorType::Ext(ext[*index].clone()), + Witness::EqPoly(index) => VectorType::Ext(eqs[*index].clone()), + }, + Expression::Sum(e0, e1, _) => { + e0.calc(ext, base, eqs, challenges) + e1.calc(ext, base, eqs, challenges) + } + Expression::Product(e0, e1, _) => { + e0.calc(ext, base, eqs, challenges) * e1.calc(ext, base, eqs, challenges) + } + Expression::Neg(e, _) => -e.calc(ext, base, eqs, challenges), + Expression::Pow(e, d, _) => { + let poly = e.calc(ext, base, eqs, challenges); + op_by_type!( + VectorType, + poly, + |poly| { poly.into_iter().map(|x| x.exp_u64(*d as u64)).collect_vec() }, + |ext| VectorType::Ext(ext), + |base| VectorType::Base(base) + ) + } + } + } + + #[allow(clippy::too_many_arguments)] + pub fn sumcheck_uni_poly( + &self, + ext_mles: &[&mut [E]], + base_after_mles: &[Vec], + base_mles: &[&[E::BaseField]], + eqs: &[Vec], + challenges: &[E], + size: usize, + degree: usize, + ) -> Vec { + let poly = self.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + op_by_type!(UniPolyVectorType, poly, |poly| { + poly.into_iter().fold(vec![E::ZERO; degree], |acc, x| { + zip_eq(acc, x).map(|(a, b)| a + b).collect_vec() + }) + }) + } + + /// Compute \sum_x (eq(0, x) + eq(1, x)) * expr_0(X, x) + #[allow(clippy::too_many_arguments)] + pub fn zerocheck_uni_poly<'a, E: ExtensionField>( + &self, + ext_mles: &[&mut [E]], + base_after_mles: &[Vec], + base_mles: &[&[E::BaseField]], + challenges: &[E], + coeffs: impl Iterator, + size: usize, + ) -> Vec { + let degree = self.degree(); + let poly = self.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + &[], + challenges, + size, + degree, + ); + + op_by_type!(UniPolyVectorType, poly, |poly| { + zip_eq(coeffs, poly).fold(vec![E::ZERO; degree], |mut acc, (c, poly)| { + zip_eq(&mut acc, poly).for_each(|(a, x)| *a += *c * x); + acc + }) + }) + } + + /// Compute the extension field univariate polynomial evaluated on 1..degree + 1. + #[allow(clippy::too_many_arguments)] + fn uni_poly_inner( + &self, + ext_mles: &[&mut [E]], + base_after_mles: &[Vec], + base_mles: &[&[E::BaseField]], + eqs: &[Vec], + challenges: &[E], + size: usize, + degree: usize, + ) -> UniPolyVectorType { + match self { + Expression::Const(constant) => { + let value = constant.evaluate(challenges); + UniPolyVectorType::Ext(vec![vec![value; degree]; size >> 1]) + } + Expression::Wit(w) => match w { + Witness::BasePoly(index) => { + if !base_mles.is_empty() { + UniPolyVectorType::Base(uni_poly_helper(base_mles[*index], size, degree)) + } else { + UniPolyVectorType::Ext(uni_poly_helper( + &base_after_mles[*index], + size, + degree, + )) + } + } + Witness::ExtPoly(index) => { + UniPolyVectorType::Ext(uni_poly_helper(ext_mles[*index], size, degree)) + } + Witness::EqPoly(index) => { + UniPolyVectorType::Ext(uni_poly_helper(&eqs[*index], size, degree)) + } + }, + Expression::Sum(expr0, expr1, _) => { + let poly0 = expr0.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + let poly1 = expr1.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + poly0 + poly1 + } + Expression::Product(expr0, expr1, _) => { + let poly0 = expr0.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + let poly1 = expr1.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + poly0 * poly1 + } + Expression::Neg(expr, _) => { + let poly = expr.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + -poly + } + Expression::Pow(expr, d, _) => { + let poly = expr.uni_poly_inner( + ext_mles, + base_after_mles, + base_mles, + eqs, + challenges, + size, + degree, + ); + op_by_type!( + UniPolyVectorType, + poly, + |poly| { + poly.into_iter() + .map(|x| x.iter().map(|x| x.exp_u64(*d as u64)).collect_vec()) + .collect_vec() + }, + |ext| UniPolyVectorType::Ext(ext), + |base| UniPolyVectorType::Base(base) + ) + } + } + } +} + +impl Constant { + pub fn is_ext(&self) -> bool { + match self { + Constant::Base(_) => false, + Constant::Challenge(_) => true, + Constant::Sum(c0, c1) | Constant::Product(c0, c1) => c0.is_ext() || c1.is_ext(), + Constant::Neg(c) => c.is_ext(), + Constant::Pow(c, _) => c.is_ext(), + } + } + + pub fn evaluate(&self, challenges: &[E]) -> E { + let res = self.evaluate_inner(challenges); + op_by_type!(ScalarType, res, |b| b, |e| e, |bf| E::from(bf)) + } + + fn evaluate_inner(&self, challenges: &[E]) -> ScalarType { + match self { + Constant::Base(value) => ScalarType::Base(i64_to_field(*value)), + Constant::Challenge(index) => ScalarType::Ext(challenges[*index]), + Constant::Sum(c0, c1) => c0.evaluate_inner(challenges) + c1.evaluate_inner(challenges), + Constant::Product(c0, c1) => { + c0.evaluate_inner(challenges) * c1.evaluate_inner(challenges) + } + Constant::Neg(c) => -c.evaluate_inner(challenges), + Constant::Pow(c, degree) => { + let value = c.evaluate_inner(challenges); + op_by_type!( + ScalarType, + value, + |value| { value.exp_u64(*degree as u64) }, + |ext| ScalarType::Ext(ext), + |base| ScalarType::Base(base) + ) + } + } + } + + pub fn entry(&self, challenges: &[E]) -> E { + match self { + Constant::Challenge(index) => challenges[*index], + _ => unreachable!(), + } + } + + pub fn entry_mut<'a, E: ExtensionField>(&self, challenges: &'a mut [E]) -> &'a mut E { + match self { + Constant::Challenge(index) => &mut challenges[*index], + _ => unreachable!(), + } + } +} + +impl Witness { + pub fn is_ext(&self) -> bool { + match self { + Witness::BasePoly(_) => false, + Witness::ExtPoly(_) => true, + Witness::EqPoly(_) => true, + } + } + + pub fn evaluate( + &self, + base_mle_evals: &[E], + ext_mle_evals: &[E], + out_point: &[&[E]], + in_point: &[E], + ) -> E { + match self { + Witness::BasePoly(index) => base_mle_evals[*index], + Witness::ExtPoly(index) => ext_mle_evals[*index], + Witness::EqPoly(index) => eq_eval(out_point[*index], in_point), + } + } + + pub fn base<'a, T>(&self, base_mle_evals: &'a [T]) -> &'a T { + match self { + Witness::BasePoly(index) => &base_mle_evals[*index], + _ => unreachable!(), + } + } + + pub fn base_mut<'a, T>(&self, base_mle_evals: &'a mut [T]) -> &'a mut T { + match self { + Witness::BasePoly(index) => &mut base_mle_evals[*index], + _ => unreachable!(), + } + } + + pub fn ext<'a, T>(&self, ext_mle_evals: &'a [T]) -> &'a T { + match self { + Witness::ExtPoly(index) => &ext_mle_evals[*index], + _ => unreachable!(), + } + } + + pub fn ext_mut<'a, T>(&self, ext_mle_evals: &'a mut [T]) -> &'a mut T { + match self { + Witness::ExtPoly(index) => &mut ext_mle_evals[*index], + _ => unreachable!(), + } + } +} + +/// Compute the univariate polynomial evaluated on 1..degree. +#[inline] +fn uni_poly_helper(mle: &[F], size: usize, degree: usize) -> Vec> { + mle.chunks(2) + .take(size >> 1) + .map(|p| { + let start = p[0]; + let step = p[1] - start; + (0..degree) + .scan(start, |state, _| { + *state += step; + Some(*state) + }) + .collect_vec() + }) + .collect_vec() +} + +#[cfg(test)] +mod test { + use crate::field_vec; + use p3_field::PrimeCharacteristicRing; + use p3_goldilocks::Goldilocks as F; + + #[test] + fn test_uni_poly_helper() { + // (x + 2), (3x + 4), (5x + 6), (7x + 8) + let mle = field_vec![F, 2, 3, 4, 7, 6, 11, 8, 15, 11, 13, 17, 19, 23, 29, 31, 37]; + let size = 8; + let degree = 3; + let expected = vec![ + field_vec![F, 3, 4, 5], + field_vec![F, 7, 10, 13], + field_vec![F, 11, 16, 21], + field_vec![F, 15, 22, 29], + ]; + let result = super::uni_poly_helper(&mle, size, degree); + assert_eq!(result, expected); + } +} diff --git a/subprotocols/src/expression/macros.rs b/subprotocols/src/expression/macros.rs new file mode 100644 index 000000000..8930be70f --- /dev/null +++ b/subprotocols/src/expression/macros.rs @@ -0,0 +1,100 @@ +#[macro_export] +macro_rules! op_by_type { + ($ele_type:ident, $ele:ident, |$x:ident| $op:expr, |$y_ext:ident| $convert_ext:expr, |$y_base:ident| $convert_base:expr) => { + match $ele { + $ele_type::Base($x) => { + let $y_base = $op; + $convert_base + } + $ele_type::Ext($x) => { + let $y_ext = $op; + $convert_ext + } + } + }; + + ($ele_type:ident, $ele:ident, |$x:ident| $op:expr, |$y_base:ident| $convert_base:expr) => { + match $ele { + $ele_type::Base($x) => { + let $y_base = $op; + $convert_base + } + $ele_type::Ext($x) => $op, + } + }; + + ($ele_type:ident, $ele:ident, |$x:ident| $op:expr) => { + match $ele { + $ele_type::Base($x) => $op, + $ele_type::Ext($x) => $op, + } + }; +} + +#[macro_export] +macro_rules! define_commutative_op_mle2 { + ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident, $y:ident| $op:expr) => { + impl $trait_type for $ele_type { + type Output = Self; + + fn $func_type(self, other: Self) -> Self::Output { + #[allow(unused)] + match (self, other) { + ($ele_type::Base(mut $x), $ele_type::Base($y)) => $ele_type::Base($op), + ($ele_type::Ext(mut $x), $ele_type::Base($y)) + | ($ele_type::Base($y), $ele_type::Ext(mut $x)) => $ele_type::Ext($op), + ($ele_type::Ext(mut $x), $ele_type::Ext($y)) => $ele_type::Ext($op), + } + } + } + + // impl<'a, E: ExtensionField> $trait_type<&'a Self> for $ele_type { + // type Output = Self; + + // fn $func_type(self, other: &'a Self) -> Self::Output { + // #[allow(unused)] + // match (self, other) { + // ($ele_type::Base(mut $x), $ele_type::Base($y)) => $ele_type::Base($op), + // ($ele_type::Ext(mut $x), $ele_type::Base($y)) => $ele_type::Ext($op), + // ($ele_type::Base($y), $ele_type::Ext($x)) => { + // let mut $x = $x.clone(); + // $ele_type::Ext($op) + // } + // ($ele_type::Ext(mut $x), $ele_type::Ext($y)) => $ele_type::Ext($op), + // } + // } + // } + }; +} + +#[macro_export] +macro_rules! define_op_mle2 { + ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident, $y:ident| $op:expr) => { + impl $trait_type for $ele_type { + type Output = Self; + + fn $func_type(self, other: Self) -> Self::Output { + let $x = self; + let $y = other; + $op + } + } + }; +} + +#[macro_export] +macro_rules! define_op_mle { + ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident| $op:expr) => { + impl $trait_type for $ele_type { + type Output = Self; + + fn $func_type(self) -> Self::Output { + #[allow(unused)] + match (self) { + $ele_type::Base(mut $x) => $ele_type::Base($op), + $ele_type::Ext(mut $x) => $ele_type::Ext($op), + } + } + } + }; +} diff --git a/subprotocols/src/expression/op.rs b/subprotocols/src/expression/op.rs new file mode 100644 index 000000000..1690c24e4 --- /dev/null +++ b/subprotocols/src/expression/op.rs @@ -0,0 +1,81 @@ +use std::{ + cmp::max, + ops::{Add, Mul, Neg, Sub}, +}; + +use ff_ext::ExtensionField; +use itertools::zip_eq; + +use crate::{define_commutative_op_mle2, define_op_mle, define_op_mle2}; + +use super::{Expression, ScalarType, UniPolyVectorType, VectorType}; + +impl Add for Expression { + type Output = Self; + + fn add(self, other: Self) -> Self { + let degree = max(self.degree(), other.degree()); + Expression::Sum(Box::new(self), Box::new(other), degree) + } +} + +impl Mul for Expression { + type Output = Self; + + fn mul(self, other: Self) -> Self { + #[allow(clippy::suspicious_arithmetic_impl)] + let degree = self.degree() + other.degree(); + Expression::Product(Box::new(self), Box::new(other), degree) + } +} + +impl Neg for Expression { + type Output = Self; + + fn neg(self) -> Self { + let deg = self.degree(); + Expression::Neg(Box::new(self), deg) + } +} + +impl Sub for Expression { + type Output = Self; + + fn sub(self, other: Self) -> Self { + self + (-other) + } +} + +define_commutative_op_mle2!(UniPolyVectorType, Add, add, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| zip_eq(x, y).for_each(|(x, y)| *x += y)); + x +}); +define_commutative_op_mle2!(UniPolyVectorType, Mul, mul, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| zip_eq(x, y).for_each(|(x, y)| *x *= y)); + x +}); +define_op_mle2!(UniPolyVectorType, Sub, sub, |x, y| x + (-y)); +define_op_mle!(UniPolyVectorType, Neg, neg, |x| { + x.iter_mut() + .for_each(|x| x.iter_mut().for_each(|x| *x = -(*x))); + x +}); + +define_commutative_op_mle2!(VectorType, Add, add, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| *x += y); + x +}); +define_commutative_op_mle2!(VectorType, Mul, mul, |x, y| { + zip_eq(&mut x, y).for_each(|(x, y)| *x *= y); + x +}); +define_op_mle2!(VectorType, Sub, sub, |x, y| x + (-y)); +define_op_mle!(VectorType, Neg, neg, |x| { + x.iter_mut().for_each(|x| *x = -(*x)); + x +}); + +define_commutative_op_mle2!(ScalarType, Add, add, |x, y| x + y); +define_commutative_op_mle2!(ScalarType, Mul, mul, |x, y| x * y); +define_op_mle2!(ScalarType, Sub, sub, |x, y| x + (-y)); +define_op_mle!(ScalarType, Neg, neg, |x| -x); diff --git a/subprotocols/src/lib.rs b/subprotocols/src/lib.rs new file mode 100644 index 000000000..a86f12c8f --- /dev/null +++ b/subprotocols/src/lib.rs @@ -0,0 +1,9 @@ +pub mod error; +pub mod expression; +pub mod points; +pub mod sumcheck; +pub mod utils; +pub mod zerocheck; + +#[macro_use] +pub mod test_utils; diff --git a/subprotocols/src/points.rs b/subprotocols/src/points.rs new file mode 100644 index 000000000..9d128e0ac --- /dev/null +++ b/subprotocols/src/points.rs @@ -0,0 +1,75 @@ +use std::sync::Arc; + +use ff_ext::ExtensionField; +use itertools::izip; + +use crate::expression::Point; + +pub trait PointBeforeMerge { + fn point_before_merge(&self, pos: &[usize]) -> Point; +} + +pub trait PointBeforePartition { + fn point_before_partition( + &self, + pos_and_var_ids: &[(usize, usize)], + challenges: &[E], + ) -> Point; +} + +/// Suppose we have several vectors v_0, ..., v_{N-1}, and want to merge it through n = log(N) variables, +/// x_0, ..., x_{n-1}, at the positions i_0, ..., i_{n - 1}. Suppose the output point is P, then the point +/// before it is P_0, ..., P_{i_0 - 1}, P_{i_0 + 1}, ..., P_{i_1 - 1}, ..., P_{i_{n - 1} + 1}, ..., P_{N - 1}. +impl PointBeforeMerge for Point { + fn point_before_merge(&self, pos: &[usize]) -> Point { + if pos.is_empty() { + return self.clone(); + } + + assert!(izip!(pos.iter(), pos.iter().skip(1)).all(|(i, j)| i < j)); + + let mut new_point = Vec::with_capacity(self.len() - pos.len()); + let mut i = 0usize; + for (j, p) in self.iter().enumerate() { + if j != pos[i] { + new_point.push(*p); + } else { + i += 1; + } + } + + Arc::new(new_point) + } +} + +/// Suppose we have a vector v, and want to partition it through n = log(N) variables +/// x_0, ..., x_{n-1}, at the positions i_0, ..., i_{n - 1}. Suppose the output point +/// is P, then the point before it is P after calling P.insert(i_0, x_0), ... +impl PointBeforePartition for Point { + fn point_before_partition( + &self, + pos_and_var_ids: &[(usize, usize)], + challenges: &[E], + ) -> Point { + if pos_and_var_ids.is_empty() { + return self.clone(); + } + + assert!( + izip!(pos_and_var_ids.iter(), pos_and_var_ids.iter().skip(1)).all(|(i, j)| i.0 < j.0) + ); + + let mut new_point = Vec::with_capacity(self.len() + pos_and_var_ids.len()); + let mut i = 0usize; + for (j, p) in self.iter().enumerate() { + if i + j != pos_and_var_ids[i].0 { + new_point.push(*p); + } else { + new_point.push(challenges[pos_and_var_ids[i].1]); + i += 1; + } + } + + Arc::new(new_point) + } +} diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs new file mode 100644 index 000000000..73c331f57 --- /dev/null +++ b/subprotocols/src/sumcheck.rs @@ -0,0 +1,454 @@ +use std::{iter, mem, sync::Arc, vec}; + +use ark_std::log2; +use ff_ext::ExtensionField; +use itertools::chain; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use transcript::Transcript; + +use crate::{ + error::VerifierError, + expression::{Expression, Point}, + utils::eq_vecs, +}; + +use super::utils::{fix_variables_ext, fix_variables_inplace, interpolate_uni_poly}; + +/// This is an randomly combined sumcheck protocol for the following equation: +/// \sigma = \sum_x expr(x) +pub struct SumcheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + /// Expression. + expr: Expression, + + /// Extension field mles. + ext_mles: Vec<&'a mut [E]>, + /// Base field mles after the first round. + base_mles_after: Vec>, + /// Base field mles. + base_mles: Vec<&'a [E::BaseField]>, + /// Eq polys + eqs: Vec>, + /// Challenges occurred in expressions + challenges: &'a [E], + + transcript: &'a mut Trans, + + degree: usize, + num_vars: usize, +} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct SumcheckProof { + /// Messages for each round. + pub univariate_polys: Vec>>, + pub ext_mle_evals: Vec, + pub base_mle_evals: Vec, +} + +pub struct SumcheckProverOutput { + pub proof: SumcheckProof, + pub point: Point, +} + +impl<'a, E, Trans> SumcheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + expr: Expression, + points: &[&[E]], + ext_mles: Vec<&'a mut [E]>, + base_mles: Vec<&'a [E::BaseField]>, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + assert!(!(ext_mles.is_empty() && base_mles.is_empty())); + + let num_vars = if !ext_mles.is_empty() { + log2(ext_mles[0].len()) as usize + } else { + log2(base_mles[0].len()) as usize + }; + + // The length of all mles should be 2^{num_vars}. + assert!(ext_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + assert!(base_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + + let degree = expr.degree(); + + let eqs = eq_vecs(points.iter().copied(), &vec![E::ONE; points.len()]); + + Self { + expr, + ext_mles, + base_mles_after: vec![], + base_mles, + eqs, + challenges, + transcript, + num_vars, + degree, + } + } + + pub fn prove(mut self) -> SumcheckProverOutput { + let (univariate_polys, point) = (0..self.num_vars) + .map(|round| { + let round_msg = self.compute_univariate_poly(round); + self.transcript.append_field_element_exts(&round_msg); + + let r = self + .transcript + .sample_and_append_challenge(b"sumcheck round") + .elements; + self.update_mles(&r, round); + (vec![round_msg], r) + }) + .unzip(); + let point = Arc::new(point); + + // Send the final evaluations + let SumcheckProverState { + ext_mles, + base_mles_after, + base_mles, + .. + } = self; + let ext_mle_evaluations = ext_mles.into_iter().map(|mle| mle[0]).collect(); + let base_mle_evaluations = if !base_mles.is_empty() { + base_mles.into_iter().map(|mle| E::from(mle[0])).collect() + } else { + base_mles_after.into_iter().map(|mle| mle[0]).collect() + }; + + SumcheckProverOutput { + proof: SumcheckProof { + univariate_polys, + ext_mle_evals: ext_mle_evaluations, + base_mle_evals: base_mle_evaluations, + }, + point, + } + } + + /// Compute f(X) = r^0 \sum_x expr_0(X || x) + r^1 \sum_x expr_1(X || x) + ... + fn compute_univariate_poly(&self, round: usize) -> Vec { + self.expr.sumcheck_uni_poly( + &self.ext_mles, + &self.base_mles_after, + &self.base_mles, + &self.eqs, + self.challenges, + 1 << (self.num_vars - round), + self.degree, + ) + } + + fn update_mles(&mut self, r: &E, round: usize) { + // fix variables of eq polynomials + self.eqs.iter_mut().for_each(|eq| { + fix_variables_inplace(eq, r); + }); + // fix variables of ext field polynomials. + self.ext_mles.iter_mut().for_each(|mle| { + fix_variables_inplace(mle, r); + }); + // fix variables of base field polynomials. + if round == 0 { + self.base_mles_after = mem::take(&mut self.base_mles) + .into_iter() + .map(|mle| fix_variables_ext(mle, r)) + .collect(); + } else { + self.base_mles_after + .iter_mut() + .for_each(|mle| fix_variables_inplace(mle, r)); + } + } +} + +pub struct SumcheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + sigma: E, + expr: Expression, + proof: SumcheckProof, + expr_names: Vec, + challenges: &'a [E], + transcript: &'a mut Trans, + out_points: Vec<&'a [E]>, +} + +pub struct SumcheckClaims { + pub in_point: Point, + pub base_mle_evals: Vec, + pub ext_mle_evals: Vec, +} + +impl<'a, E, Trans> SumcheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + pub fn new( + sigma: E, + expr: Expression, + out_points: Vec<&'a [E]>, + proof: SumcheckProof, + challenges: &'a [E], + transcript: &'a mut Trans, + expr_names: Vec, + ) -> Self { + // Fill in missing debug data + let mut expr_names = expr_names; + expr_names.resize(1, "nothing".to_owned()); + Self { + sigma, + expr, + proof, + challenges, + transcript, + out_points, + expr_names, + } + } + + pub fn verify(self) -> Result, VerifierError> { + let SumcheckVerifierState { + sigma, + expr, + proof, + challenges, + transcript, + out_points, + expr_names, + } = self; + let SumcheckProof { + univariate_polys, + ext_mle_evals, + base_mle_evals, + } = proof; + + let (in_point, expected_claim) = univariate_polys.into_iter().fold( + (vec![], sigma), + |(mut last_point, last_sigma), msg| { + let msg = msg.into_iter().next().unwrap(); + transcript.append_field_element_exts(&msg); + + let len = msg.len() + 1; + let eval_at_0 = last_sigma - msg[0]; + + // Evaluations on degree, degree - 1, ..., 1, 0. + let evals_iter_rev = chain![msg.into_iter().rev(), iter::once(eval_at_0)]; + + let r = transcript + .sample_and_append_challenge(b"sumcheck round") + .elements; + let sigma = interpolate_uni_poly(evals_iter_rev, len, r); + last_point.push(r); + (last_point, sigma) + }, + ); + + // Check the final evaluations. + let got_claim = expr.evaluate( + &ext_mle_evals, + &base_mle_evals, + &out_points, + &in_point, + challenges, + ); + + if expected_claim != got_claim { + return Err(VerifierError::ClaimNotMatch( + expr, + expected_claim, + got_claim, + expr_names[0].clone(), + )); + } + + let in_point = Arc::new(in_point); + Ok(SumcheckClaims { + in_point, + base_mle_evals, + ext_mle_evals, + }) + } +} + +#[cfg(test)] +mod test { + use std::array; + + use ff_ext::ExtensionField; + use itertools::{Itertools, izip}; + use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; + use p3_goldilocks::Goldilocks as F; + use transcript::BasicTranscript; + + type E = BinomialExtensionField; + + use crate::{ + expression::{Constant, Expression, Witness}, + field_vec, + utils::eq_vecs, + }; + + use super::{SumcheckProverOutput, SumcheckProverState, SumcheckVerifierState}; + + #[allow(clippy::too_many_arguments)] + fn run( + points: Vec<&[E]>, + expr: Expression, + ext_mle_refs: Vec<&mut [E]>, + base_mle_refs: Vec<&[E::BaseField]>, + challenges: Vec, + + sigma: E, + ) { + let mut prover_transcript = BasicTranscript::new(b"test"); + let prover = SumcheckProverState::new( + expr.clone(), + &points, + ext_mle_refs, + base_mle_refs, + &challenges, + &mut prover_transcript, + ); + + let SumcheckProverOutput { proof, .. } = prover.prove(); + + let mut verifier_transcript = BasicTranscript::new(b"test"); + let verifier = SumcheckVerifierState::new( + sigma, + expr, + points, + proof, + &challenges, + &mut verifier_transcript, + vec![], + ); + + verifier.verify().expect("verification failed"); + } + + #[test] + fn test_sumcheck_trivial() { + let f = field_vec![F, 2]; + let g = field_vec![F, 3]; + let out_point = vec![]; + + let base_mle_refs = vec![f.as_slice(), g.as_slice()]; + let f = Expression::Wit(Witness::BasePoly(0)); + let g = Expression::Wit(Witness::BasePoly(1)); + let expr = f * g; + + run( + vec![out_point.as_slice()], + expr, + vec![], + base_mle_refs, + vec![], + E::from_u64(6), + ); + } + + #[test] + fn test_sumcheck_simple() { + let f = field_vec![F, 1, 2, 3, 4]; + let ans = E::from(f.iter().fold(F::ZERO, |acc, x| acc + *x)); + let base_mle_refs = vec![f.as_slice()]; + let expr = Expression::Wit(Witness::BasePoly(0)); + + run(vec![], expr, vec![], base_mle_refs, vec![], ans); + } + + #[test] + fn test_sumcheck_logup() { + let point = field_vec![E, 2, 3]; + + let eqs = eq_vecs([point.as_slice()].into_iter(), &[E::ONE]); + + let d0 = field_vec![E, 1, 2, 3, 4]; + let d1 = field_vec![E, 5, 6, 7, 8]; + let n0 = field_vec![E, 9, 10, 11, 12]; + let n1 = field_vec![E, 13, 14, 15, 16]; + + let challenges = vec![E::from_u64(7)]; + let ans = izip!(&eqs[0], &d0, &d1, &n0, &n1) + .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) + .sum(); + + let mut ext_mles = [d0, d1, n0, n1]; + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let eq = Expression::Wit(Witness::EqPoly(0)); + let beta = Expression::Const(Constant::Challenge(0)); + + let expr = eq * (d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0)); + + let ext_mle_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + run( + vec![point.as_slice()], + expr, + ext_mle_refs, + vec![], + challenges, + ans, + ); + } + + #[test] + fn test_sumcheck_multi_points() { + let challenges = vec![E::from_u64(2)]; + + let points = [ + field_vec![E, 2, 3], + field_vec![E, 5, 7], + field_vec![E, 2, 5], + ]; + let point_refs = points.iter().map(|v| v.as_slice()).collect_vec(); + + let eqs = eq_vecs(point_refs.clone().into_iter(), &vec![E::ONE; points.len()]); + + let d0 = field_vec![F, 1, 2, 3, 4]; + let d1 = field_vec![F, 5, 6, 7, 8]; + let n0 = field_vec![F, 9, 10, 11, 12]; + let n1 = field_vec![F, 13, 14, 15, 16]; + + let ans_0 = izip!(&eqs[0], &d0, &d1) + .map(|(eq0, d0, d1)| *eq0 * *d0 * *d1) + .sum::(); + let ans_1 = izip!(&eqs[1], &d0, &n1) + .map(|(eq1, d0, n1)| *eq1 * *d0 * *n1) + .sum::(); + let ans_2 = izip!(&eqs[2], &d1, &n0) + .map(|(eq2, d1, n0)| *eq2 * *d1 * *n0) + .sum::(); + let ans = (ans_0 * challenges[0] + ans_1) * challenges[0] + ans_2; + + let base_mles = [d0, d1, n0, n1]; + let [eq0, eq1, eq2] = array::from_fn(|i| Expression::Wit(Witness::EqPoly(i))); + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::BasePoly(i))); + let rlc_challenge = Expression::Const(Constant::Challenge(0)); + + let expr = (eq0 * d0.clone() * d1.clone() * rlc_challenge.clone() + eq1 * d0 * n1) + * rlc_challenge + + eq2 * d1 * n0; + + let base_mle_refs = base_mles.iter().map(|v| v.as_slice()).collect_vec(); + run(point_refs, expr, vec![], base_mle_refs, challenges, ans); + } +} diff --git a/subprotocols/src/test_utils.rs b/subprotocols/src/test_utils.rs new file mode 100644 index 000000000..cb0812e1a --- /dev/null +++ b/subprotocols/src/test_utils.rs @@ -0,0 +1,46 @@ +use ff_ext::{ExtensionField, FromUniformBytes}; +use itertools::Itertools; +use p3_field::Field; +use rand::RngCore; + +pub fn random_point(mut rng: impl RngCore, num_vars: usize) -> Vec { + (0..num_vars).map(|_| E::random(&mut rng)).collect_vec() +} + +pub fn random_vec(mut rng: impl RngCore, len: usize) -> Vec { + (0..len).map(|_| E::random(&mut rng)).collect_vec() +} + +pub fn random_poly(mut rng: impl RngCore, num_vars: usize) -> Vec { + (0..1 << num_vars) + .map(|_| E::random(&mut rng)) + .collect_vec() +} + +#[macro_export] +macro_rules! field_vec { + () => ( + $crate::vec::Vec::new() + ); + ($field_type:ident; $elem:expr; $n:expr) => ( + $crate::vec::from_elem({ + if $x < 0 { + -$field_type::from((-$x) as u64) + } else { + $field_type::from($x as u64) + } + }, $n) + ); + ($field_type:ident, $($x:expr),+ $(,)?) => ( + <[_]>::into_vec( + std::boxed::Box::new([$({ + let x = $x as i64; + if $x < 0 { + -$field_type::from_u64((-x) as u64) + } else { + $field_type::from_u64(x as u64) + } + }),+]) + ) + ); +} diff --git a/subprotocols/src/utils.rs b/subprotocols/src/utils.rs new file mode 100644 index 000000000..025fdf3c3 --- /dev/null +++ b/subprotocols/src/utils.rs @@ -0,0 +1,235 @@ +use std::{iter, ops::Mul}; + +use ff_ext::ExtensionField; +use itertools::{Itertools, chain, izip}; +use multilinear_extensions::virtual_poly::build_eq_x_r_vec_with_scalar; +use p3_field::Field; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; + +pub fn i64_to_field(i: i64) -> F { + if i < 0 { + -F::from_u64(i.unsigned_abs()) + } else { + F::from_u64(i as u64) + } +} + +pub fn power_list(ele: &F, size: usize) -> Vec { + (0..size) + .scan(F::ONE, |state, _| { + let last = *state; + *state *= *ele; + Some(last) + }) + .collect() +} + +/// Grand product of ele, start from 1, with length ele.len() + 1. +pub fn grand_product(ele: &[F]) -> Vec { + let one = F::ONE; + chain![iter::once(&one), ele.iter()] + .scan(F::ONE, |state, e| { + *state *= *e; + Some(*state) + }) + .collect() +} + +pub fn eq_vecs<'a, E: ExtensionField>( + points: impl Iterator, + scalars: &[E], +) -> Vec> { + izip!(points, scalars) + .map(|(point, scalar)| build_eq_x_r_vec_with_scalar(point, *scalar)) + .collect_vec() +} + +#[inline(always)] +pub fn eq(x: &F, y: &F) -> F { + // x * y + (1 - x) * (1 - y) + let xy = *x * *y; + xy + xy - *x - *y + F::ONE +} + +pub fn fix_variables_ext(base_mle: &[E::BaseField], r: &E) -> Vec { + base_mle + .par_iter() + .chunks(2) + .with_min_len(64) + .map(|buf| *r * (*buf[1] - *buf[0]) + *buf[0]) + .collect() +} + +pub fn fix_variables_inplace(ext_mle: &mut [E], r: &E) { + ext_mle + .par_iter_mut() + .chunks(2) + .with_min_len(64) + .for_each(|mut buf| *buf[0] = *buf[0] + (*buf[1] - *buf[0]) * *r); + // sequentially update buf[b1, b2,..bt] = buf[b1, b2,..bt, 0] + let half_len = ext_mle.len() >> 1; + for index in 0..half_len { + ext_mle[index] = ext_mle[index << 1]; + } +} + +pub fn evaluate_mle_inplace(mle: &mut [E], point: &[E]) -> E { + for r in point { + fix_variables_inplace(mle, r); + } + mle[0] +} + +pub fn evaluate_mle_ext(mle: &[E::BaseField], point: &[E]) -> E { + let mut ext_mle = fix_variables_ext(mle, &point[0]); + evaluate_mle_inplace(&mut ext_mle, &point[1..]) +} + +/// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this +/// polynomial at `eval_at`: +/// +/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) ) +/// +/// This implementation is linear in number of inputs in terms of field +/// operations. It also has a quadratic term in primitive operations which is +/// negligible compared to field operations. +/// TODO: The quadratic term can be removed by precomputing the lagrange +/// coefficients. +pub(crate) fn interpolate_uni_poly>( + p_iter_rev: impl Iterator, + len: usize, + eval_at: E, +) -> E { + let mut evals = vec![eval_at]; + let mut prod = eval_at; + + // `prod = \prod_{j} (eval_at - j)` + for j in 1..len { + let tmp = eval_at - E::from_u64(j as u64); + evals.push(tmp); + prod *= tmp; + } + let mut res = E::ZERO; + // we want to compute \prod (j!=i) (i-j) for a given i + // + // we start from the last step, which is + // denom[len-1] = (len-1) * (len-2) *... * 2 * 1 + // the step before that is + // denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1 + // and the step before that is + // denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2 + // + // i.e., for any i, the one before this will be derived from + // denom[i-1] = denom[i] * (len-i) / i + // + // that is, we only need to store + // - the last denom for i = len-1, and + // - the ratio between current step and fhe last step, which is the product of (len-i) / i from + // all previous steps and we store this product as a fraction number to reduce field + // divisions. + + let mut denom_up = field_factorial::(len - 1); + let mut denom_down = F::ONE; + + for (j, p_i) in p_iter_rev.enumerate() { + let i = len - j - 1; + res += prod * p_i * denom_down * (evals[i] * denom_up).inverse(); + + // compute denom for the next step is current_denom * (len-i)/i + if i != 0 { + denom_up *= -F::from_u64((j + 1) as u64); + denom_down *= F::from_u64(i as u64); + } + } + res +} + +/// compute the factorial(a) = 1 * 2 * ... * a +#[inline] +fn field_factorial(a: usize) -> F { + let mut res = F::ONE; + for i in 2..=a { + res *= F::from_u64(i as u64); + } + res +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + use multilinear_extensions::virtual_poly::eq_eval; + use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; + use p3_goldilocks::Goldilocks as F; + + use crate::field_vec; + + use super::*; + + type E = BinomialExtensionField; + + #[test] + fn test_power_list() { + let ele = F::from_u64(3u64); + let list = power_list(&ele, 4); + assert_eq!(list, field_vec![F, 1, 3, 9, 27]); + } + + #[test] + fn test_grand_product() { + let ele = field_vec![F, 2, 3, 4, 5]; + let expected = field_vec![F, 1, 2, 6, 24, 120]; + assert_eq!(grand_product(&ele), expected); + } + + #[test] + fn test_eq_vecs() { + let points = [field_vec![E, 2, 3, 5], field_vec![E, 7, 11, 13]]; + let point_refs = points.iter().map(|p| p.as_slice()).collect_vec(); + + let scalars = field_vec![E, 3, 5]; + + let eq_evals = eq_vecs(point_refs.into_iter(), &scalars); + + let expected = vec![ + field_vec![E, -24, 48, 36, -72, 30, -60, -45, 90], + field_vec![E, -3600, 4200, 3960, -4620, 3900, -4550, -4290, 5005], + ]; + assert_eq!(eq_evals, expected); + } + + #[test] + fn test_eq_eval() { + let xs = field_vec![E, 2, 3, 5]; + let ys = field_vec![E, 7, 11, 13]; + let expected = E::from_u64(119780); + assert_eq!(eq_eval(&xs, &ys), expected); + } + + #[test] + fn test_fix_variables_ext() { + let base_mle = field_vec![F, 1, 2, 3, 4, 5, 6]; + let r = E::from_u64(3u64); + let expected = field_vec![E, 4, 6, 8]; + assert_eq!(fix_variables_ext(&base_mle, &r), expected); + } + + #[test] + fn test_fix_variables_inplace() { + let mut ext_mle = field_vec![E, 1, 2, 3, 4, 5, 6]; + let r = E::from_u64(3u64); + fix_variables_inplace(&mut ext_mle, &r); + let expected = field_vec![E, 4, 6, 8]; + assert_eq!(ext_mle[..3], expected); + } + + #[test] + fn test_interpolate_uni_poly() { + // p(x) = x^3 + 2x^2 + 3x + 4 + let p_iter = field_vec![F, 4, 10, 26, 58].into_iter().rev(); + let eval_at = E::from_u64(11); + let expected = E::from_u64(1610); + assert_eq!(interpolate_uni_poly(p_iter, 4, eval_at), expected); + } +} diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs new file mode 100644 index 000000000..00f2fd296 --- /dev/null +++ b/subprotocols/src/zerocheck.rs @@ -0,0 +1,495 @@ +use std::{iter, mem, sync::Arc, vec}; + +use ark_std::log2; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain, izip, zip_eq}; +use p3_field::batch_multiplicative_inverse; +use transcript::Transcript; + +use crate::{ + error::VerifierError, + expression::Expression, + sumcheck::{SumcheckProof, SumcheckProverOutput}, +}; + +use super::{ + sumcheck::SumcheckClaims, + utils::{ + eq_vecs, fix_variables_ext, fix_variables_inplace, grand_product, interpolate_uni_poly, + }, +}; + +/// This is an randomly combined zerocheck protocol for the following equation: +/// \sigma = \sum_x (r^0 eq_0(X) \cdot expr_0(x) + r^1 eq_1(X) \cdot expr_1(x) + ...) +pub struct ZerocheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + /// Expressions and corresponding half eq reference. + exprs: Vec<(Expression, Vec)>, + + /// Extension field mles. + ext_mles: Vec<&'a mut [E]>, + /// Base field mles after the first round. + base_mles_after: Vec>, + /// Base field mles. + base_mles: Vec<&'a [E::BaseField]>, + /// Challenges occurred in expressions + challenges: &'a [E], + /// For each point in points, the inverse of prod_{j < i}(1 - point[i]) for 0 <= i < point.len(). + grand_prod_of_not_inv: Vec>, + + transcript: &'a mut Trans, + + num_vars: usize, +} + +impl<'a, E, Trans> ZerocheckProverState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + #[allow(clippy::too_many_arguments)] + pub fn new( + exprs: Vec, + points: &[&[E]], + ext_mles: Vec<&'a mut [E]>, + base_mles: Vec<&'a [E::BaseField]>, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + assert!(!(ext_mles.is_empty() && base_mles.is_empty())); + + let num_vars = if !ext_mles.is_empty() { + log2(ext_mles[0].len()) as usize + } else { + log2(base_mles[0].len()) as usize + }; + + // For each point, compute eq(point[1..], b) for b in [0, 2^{num_vars - 1}). + let (exprs, grand_prod_of_not_inv) = if num_vars > 0 { + let half_eq_evals = eq_vecs( + points.iter().map(|point| &point[1..]), + &vec![E::ONE; exprs.len()], + ); + let exprs = zip_eq(exprs, half_eq_evals).collect_vec(); + let grand_prod_of_not_inv = points + .iter() + .flat_map(|point| point[1..].iter().map(|p| E::ONE - *p).collect_vec()) + .collect_vec(); + let grand_prod_of_not_inv = batch_multiplicative_inverse(&grand_prod_of_not_inv); + let (_, grand_prod_of_not_inv) = + points + .iter() + .fold((0usize, vec![]), |(start, mut last_vec), point| { + let end = start + point.len() - 1; + last_vec.push(grand_product(&grand_prod_of_not_inv[start..end])); + (end, last_vec) + }); + (exprs, grand_prod_of_not_inv) + } else { + let expr = exprs.into_iter().map(|expr| (expr, vec![])).collect_vec(); + (expr, vec![]) + }; + + // The length of all mles should be 2^{num_vars}. + assert!(ext_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + assert!(base_mles.iter().all(|mle| mle.len() == 1 << num_vars)); + + Self { + exprs, + ext_mles, + base_mles_after: vec![], + base_mles, + challenges, + grand_prod_of_not_inv, + transcript, + num_vars, + } + } + + pub fn prove(mut self) -> SumcheckProverOutput { + let (univariate_polys, point) = (0..self.num_vars) + .map(|round| { + let round_msg = self.compute_univariate_poly(round); + round_msg + .iter() + .for_each(|poly| self.transcript.append_field_element_exts(poly)); + + let r = self + .transcript + .sample_and_append_challenge(b"sumcheck round") + .elements; + self.update_mles(&r, round); + (round_msg, r) + }) + .unzip(); + let point = Arc::new(point); + + // Send the final evaluations + let ZerocheckProverState { + ext_mles, + base_mles_after, + base_mles, + .. + } = self; + let ext_mle_evaluations = ext_mles.into_iter().map(|mle| mle[0]).collect(); + let base_mle_evaluations = if !base_mles.is_empty() { + base_mles.into_iter().map(|mle| E::from(mle[0])).collect() + } else { + base_mles_after.into_iter().map(|mle| mle[0]).collect() + }; + + SumcheckProverOutput { + proof: SumcheckProof { + univariate_polys, + ext_mle_evals: ext_mle_evaluations, + base_mle_evals: base_mle_evaluations, + }, + point, + } + } + + /// Compute f_i(X) = \sum_x eq_i(x) expr_i(X || x) + fn compute_univariate_poly(&self, round: usize) -> Vec> { + izip!(&self.exprs, &self.grand_prod_of_not_inv) + .map(|((expr, half_eq_mle), coeff)| { + let mut uni_poly = expr.zerocheck_uni_poly( + &self.ext_mles, + &self.base_mles_after, + &self.base_mles, + self.challenges, + half_eq_mle.iter().step_by(1 << round), + 1 << (self.num_vars - round), + ); + uni_poly.iter_mut().for_each(|x| *x *= coeff[round]); + uni_poly + }) + .collect_vec() + } + + fn update_mles(&mut self, r: &E, round: usize) { + // fix variables of base field polynomials. + self.ext_mles.iter_mut().for_each(|mle| { + fix_variables_inplace(mle, r); + }); + if round == 0 { + self.base_mles_after = mem::take(&mut self.base_mles) + .into_iter() + .map(|mle| fix_variables_ext(mle, r)) + .collect(); + } else { + self.base_mles_after + .iter_mut() + .for_each(|mle| fix_variables_inplace(mle, r)); + } + } +} + +pub struct ZerocheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + sigmas: Vec, + inv_of_one_minus_points: Vec>, + exprs: Vec<(Expression, &'a [E])>, + proof: SumcheckProof, + expr_names: Vec, + challenges: &'a [E], + transcript: &'a mut Trans, +} + +impl<'a, E, Trans> ZerocheckVerifierState<'a, E, Trans> +where + E: ExtensionField, + Trans: Transcript, +{ + pub fn new( + sigmas: Vec, + exprs: Vec, + expr_names: Vec, + points: Vec<&'a [E]>, + proof: SumcheckProof, + challenges: &'a [E], + transcript: &'a mut Trans, + ) -> Self { + // Fill in missing debug data + let mut expr_names = expr_names; + expr_names.resize(exprs.len(), "nothing".to_owned()); + + let inv_of_one_minus_points = points + .iter() + .flat_map(|point| point.iter().map(|p| E::ONE - *p).collect_vec()) + .collect_vec(); + let inv_of_one_minus_points = batch_multiplicative_inverse(&inv_of_one_minus_points); + let (_, inv_of_one_minus_points) = + points + .iter() + .fold((0usize, vec![]), |(start, mut last_vec), point| { + let end = start + point.len(); + last_vec.push(inv_of_one_minus_points[start..start + point.len()].to_vec()); + (end, last_vec) + }); + + let exprs = zip_eq(exprs, points).collect_vec(); + Self { + sigmas, + inv_of_one_minus_points, + exprs, + proof, + challenges, + transcript, + expr_names, + } + } + + pub fn verify(self) -> Result, VerifierError> { + let ZerocheckVerifierState { + sigmas, + inv_of_one_minus_points, + exprs, + proof, + challenges, + transcript, + expr_names, + .. + } = self; + let SumcheckProof { + univariate_polys, + ext_mle_evals, + base_mle_evals, + } = proof; + + let (in_point, expected_claims) = univariate_polys.into_iter().enumerate().fold( + (vec![], sigmas), + |(mut last_point, last_sigmas), (round, round_msg)| { + round_msg + .iter() + .for_each(|poly| transcript.append_field_element_exts(poly)); + let r = transcript + .sample_and_append_challenge(b"sumcheck round") + .elements; + last_point.push(r); + + let sigmas = izip!(&exprs, &inv_of_one_minus_points, round_msg, last_sigmas) + .map(|((_, point), inv_of_one_minus_point, poly, last_sigma)| { + let len = poly.len() + 1; + // last_sigma = (1 - point[round]) * eval_at_0 + point[round] * eval_at_1 + // eval_at_0 = (last_sigma - point[round] * eval_at_1) * inv(1 - point[round]) + let eval_at_0 = if !poly.is_empty() { + (last_sigma - point[round] * poly[0]) * inv_of_one_minus_point[round] + } else { + last_sigma + }; + + // Evaluations on degree, degree - 1, ..., 1, 0. + let evals_iter_rev = chain![poly.into_iter().rev(), iter::once(eval_at_0)]; + + interpolate_uni_poly(evals_iter_rev, len, r) + }) + .collect_vec(); + + (last_point, sigmas) + }, + ); + + // Check the final evaluations. + assert_eq!(expr_names.len(), exprs.len()); + // assert_eq!(expected_claims.len(), expr_names.len()); + + for (expected_claim, (expr, _), expr_name) in izip!(expected_claims, exprs, expr_names) { + let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); + + if expected_claim != got_claim { + return Err(VerifierError::ClaimNotMatch( + expr, + expected_claim, + got_claim, + expr_name.clone(), + )); + } + } + + let in_point = Arc::new(in_point); + Ok(SumcheckClaims { + in_point, + ext_mle_evals, + base_mle_evals, + }) + } +} + +#[cfg(test)] +mod test { + use std::array; + + use ff_ext::ExtensionField; + use itertools::{Itertools, izip}; + use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; + use p3_goldilocks::Goldilocks as F; + use transcript::BasicTranscript; + + use crate::{ + expression::{Constant, Expression, Witness}, + field_vec, + sumcheck::SumcheckProverOutput, + }; + + use super::{ZerocheckProverState, ZerocheckVerifierState}; + + type E = BinomialExtensionField; + + #[allow(clippy::too_many_arguments)] + fn run<'a, E: ExtensionField>( + points: Vec<&[E]>, + exprs: Vec, + ext_mle_refs: Vec<&'a mut [E]>, + base_mle_refs: Vec<&'a [E::BaseField]>, + challenges: Vec, + + sigmas: Vec, + ) { + let mut prover_transcript = BasicTranscript::new(b"test"); + let prover = ZerocheckProverState::new( + exprs.clone(), + &points, + ext_mle_refs, + base_mle_refs, + &challenges, + &mut prover_transcript, + ); + + let SumcheckProverOutput { proof, .. } = prover.prove(); + + let mut verifier_transcript = BasicTranscript::new(b"test"); + let verifier = ZerocheckVerifierState::new( + sigmas, + exprs, + vec![], + points, + proof, + &challenges, + &mut verifier_transcript, + ); + + verifier.verify().expect("verification failed"); + } + + #[test] + fn test_zerocheck_trivial() { + let f = field_vec![F, 2]; + let g = field_vec![F, 3]; + let out_point = vec![]; + + let base_mle_refs = vec![f.as_slice(), g.as_slice()]; + let f = Expression::Wit(Witness::BasePoly(0)); + let g = Expression::Wit(Witness::BasePoly(1)); + let expr = f * g; + + run( + vec![out_point.as_slice()], + vec![expr], + vec![], + base_mle_refs, + vec![], + vec![E::from_u64(6)], + ); + } + + #[test] + fn test_zerocheck_simple() { + let f = field_vec![F, 1, 2, 3, 4, 5, 6, 7, 8]; + let out_point = field_vec![E, 2, 3, 5]; + let out_eq = field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30]; + let ans = izip!(&out_eq, &f).fold(E::ZERO, |acc, (c, x)| acc + *c * *x); + + let base_mle_refs = vec![f.as_slice()]; + let expr = Expression::Wit(Witness::BasePoly(0)); + run( + vec![out_point.as_slice()], + vec![expr.clone()], + vec![], + base_mle_refs, + vec![], + vec![ans], + ); + } + + #[test] + fn test_zerocheck_logup() { + let out_point = field_vec![E, 2, 3, 5]; + let out_eq = field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30]; + + let d0 = field_vec![E, 1, 2, 3, 4, 5, 6, 7, 8]; + let d1 = field_vec![E, 9, 10, 11, 12, 13, 14, 15, 16]; + let n0 = field_vec![E, 17, 18, 19, 20, 21, 22, 23, 24]; + let n1 = field_vec![E, 25, 26, 27, 28, 29, 30, 31, 32]; + + let challenges = vec![E::from_u64(7)]; + let ans = izip!(&out_eq, &d0, &d1, &n0, &n1) + .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) + .sum(); + + let mut ext_mles = [d0, d1, n0, n1]; + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); + let beta = Expression::Const(Constant::Challenge(0)); + let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); + + let ext_mles_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); + run( + vec![out_point.as_slice()], + vec![expr.clone()], + ext_mles_refs, + vec![], + challenges, + vec![ans], + ); + } + + #[test] + fn test_zerocheck_multi_points() { + let points = [ + field_vec![E, 2, 3, 5], + field_vec![E, 7, 11, 13], + field_vec![E, 17, 19, 23], + ]; + let out_eqs = [ + field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30], + field_vec![E, -720, 840, 792, -924, 780, -910, -858, 1001], + field_vec![E, -6336, 6732, 6688, -7106, 6624, -7038, -6992, 7429], + ]; + let point_refs = points.iter().map(|v| v.as_slice()).collect_vec(); + + let d0 = field_vec![F, 1, 2, 3, 4, 5, 6, 7, 8]; + let d1 = field_vec![F, 9, 10, 11, 12, 13, 14, 15, 16]; + let n0 = field_vec![F, 17, 18, 19, 20, 21, 22, 23, 24]; + let n1 = field_vec![F, 25, 26, 27, 28, 29, 30, 31, 32]; + + let ans_0 = izip!(&out_eqs[0], &d0, &d1) + .map(|(eq0, d0, d1)| *eq0 * *d0 * *d1) + .sum(); + let ans_1 = izip!(&out_eqs[1], &d0, &n1) + .map(|(eq1, d0, n1)| *eq1 * *d0 * *n1) + .sum(); + let ans_2 = izip!(&out_eqs[2], &d1, &n0) + .map(|(eq2, d1, n0)| *eq2 * *d1 * *n0) + .sum(); + + let base_mles = [d0, d1, n0, n1]; + let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::BasePoly(i))); + + let exprs = vec![d0.clone() * d1.clone(), d0 * n1, d1 * n0]; + + let base_mle_refs = base_mles.iter().map(|v| v.as_slice()).collect_vec(); + run( + point_refs, + exprs, + vec![], + base_mle_refs, + vec![], + vec![ans_0, ans_1, ans_2], + ); + } +}