From a2fee3afa7a5551764bb962c26f8f32134811ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Wed, 17 Jan 2024 10:16:00 -0500 Subject: [PATCH] refactor: Improve Spartan SNARK polynomial computations and evaluations This gathers the changes to the pre-processing SNARK excerpted from the Supernova implementation in PR #283. The main changes extracted here are the introduction of the masked eq polynomial (and its use fixing [this issue](https://hackmd.io/@adr1anh/Sy08YaVBa)), additional sum-check tooling, removal of two calls to evaluation argument. Main reference Arecibo PRs: - https://github.com/lurk-lab/arecibo/pull/106 - https://github.com/lurk-lab/arecibo/pull/131 - https://github.com/lurk-lab/arecibo/pull/174 - https://github.com/lurk-lab/arecibo/pull/182 - Enhancement of polynomial related code in Spartan protocol including new polynomial types, modified existing types, better evaluation methods, and improved polynomial operations. - Introduction of `squares` function and change in the generation of `t_pow` in Spartan. - Addition of a new polynomial type through `MaskedEqPolynomial` with methods for its creation and evaluation. - Enhancements in `UniPoly` struct by addition of `PartialEq`, and `Eq` traits. - Improvements in `snark.rs` for proving and verifying batch evaluation claims, leveraging `gamma` and `rho` for random linear combinations, and optimizing various variable computations. - Updates in `multilinear.rs` with refactoring, optimization, error handling, and supplementing unit tests. - Refactor in `spartan/mod.rs` with import updates, function overhauls, struct visibility changes, and asynchronous operations for efficient calculations. - Additions and amendments in `sumcheck.rs` for batch verification, handling of vectors of claims, handling of cubic bounds with additive terms, visibility adjustments, and typo fixes. - Modifications in `eq.rs` including a debug derive for `EqPolynomial`, enhanced visibility of `r` vector, provision of `evals_from_points` for enhanced evaluation, and addition of `FromIterator` implementation. Co-authored-by: porcuquine Co-authored-by: Matej Penciak <96667244+mpenciak@users.noreply.github.com> Co-authored-by: Adrian Hamelink --- src/spartan/mod.rs | 182 +++++++++++--- src/spartan/polys/eq.rs | 22 +- src/spartan/polys/masked_eq.rs | 150 ++++++++++++ src/spartan/polys/mod.rs | 1 + src/spartan/polys/multilinear.rs | 37 +-- src/spartan/polys/power.rs | 22 +- src/spartan/polys/univariate.rs | 2 +- src/spartan/ppsnark.rs | 405 ++++++++++++++++++++----------- src/spartan/snark.rs | 197 +++++++-------- src/spartan/sumcheck.rs | 146 +++++++++-- 10 files changed, 819 insertions(+), 345 deletions(-) create mode 100644 src/spartan/polys/masked_eq.rs diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 18479d19d..1189e531e 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -5,24 +5,29 @@ //! We also provide direct.rs that allows proving a step circuit directly with either of the two SNARKs. //! //! In polynomial.rs we also provide foundational types and functions for manipulating multilinear polynomials. +pub mod direct; #[macro_use] mod macros; -pub mod direct; pub(crate) mod math; pub mod polys; pub mod ppsnark; pub mod snark; mod sumcheck; -use crate::{traits::Engine, Commitment}; +use crate::{ + r1cs::{R1CSShape, SparseMatrix}, + traits::Engine, + Commitment, +}; use ff::Field; use itertools::Itertools as _; use polys::multilinear::SparsePolynomial; use rayon::{iter::IntoParallelRefIterator, prelude::*}; +// Creates a vector of the first `n` powers of `s`. fn powers(s: &E::Scalar, n: usize) -> Vec { assert!(n >= 1); - let mut powers = Vec::new(); + let mut powers = Vec::with_capacity(n); powers.push(E::Scalar::ONE); for i in 1..n { powers.push(powers[i - 1] * s); @@ -31,35 +36,60 @@ fn powers(s: &E::Scalar, n: usize) -> Vec { } /// A type that holds a witness to a polynomial evaluation instance -pub struct PolyEvalWitness { +struct PolyEvalWitness { p: Vec, // polynomial } impl PolyEvalWitness { - fn pad(mut W: Vec>) -> Vec> { - // determine the maximum size - if let Some(n) = W.iter().map(|w| w.p.len()).max() { - W.iter_mut().for_each(|w| { - w.p.resize(n, E::Scalar::ZERO); - }); - W - } else { - Vec::new() - } - } + /// Given [Pᵢ] and s, compute P = ∑ᵢ sⁱ⋅Pᵢ + /// + /// # Details + /// + /// We allow the input polynomials to have different sizes, and interpret smaller ones as + /// being padded with 0 to the maximum size of all polynomials. + fn batch_diff_size(W: Vec>, s: E::Scalar) -> PolyEvalWitness { + let powers = powers::(&s, W.len()); + + let size_max = W.iter().map(|w| w.p.len()).max().unwrap(); + // Scale the input polynomials by the power of s + let p = W + .into_par_iter() + .zip_eq(powers.par_iter()) + .map(|(mut w, s)| { + if *s != E::Scalar::ONE { + w.p.par_iter_mut().for_each(|e| *e *= s); + } + w.p + }) + .reduce( + || vec![E::Scalar::ZERO; size_max], + |left, right| { + // Sum into the largest polynomial + let (mut big, small) = if left.len() > right.len() { + (left, right) + } else { + (right, left) + }; + + big + .par_iter_mut() + .zip(small.par_iter()) + .for_each(|(b, s)| *b += s); + + big + }, + ); - fn weighted_sum(W: &[PolyEvalWitness], s: &[E::Scalar]) -> PolyEvalWitness { - assert_eq!(W.len(), s.len()); - let mut p = vec![E::Scalar::ZERO; W[0].p.len()]; - for i in 0..W.len() { - for j in 0..W[i].p.len() { - p[j] += W[i].p[j] * s[i] - } - } PolyEvalWitness { p } } - // This method panics unless all vectors in p_vec are of the same length + /// Given a set of polynomials \[Pᵢ\] and a scalar `s`, this method computes the weighted sum + /// of the polynomials, where each polynomial Pᵢ is scaled by sⁱ. The method handles polynomials + /// of different sizes by padding smaller ones with zeroes up to the size of the largest polynomial. + /// + /// # Panics + /// + /// This method panics if the polynomials in `p_vec` are not all of the same length. fn batch(p_vec: &[&Vec], s: &E::Scalar) -> PolyEvalWitness { p_vec .iter() @@ -69,7 +99,7 @@ impl PolyEvalWitness { let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| { // compute the weighted sum for each vector - v.iter().map(|&x| x * weight).collect::>() + v.iter().map(|&x| x * *weight).collect::>() }) .reduce( || vec![E::Scalar::ZERO; p_vec[0].len()], @@ -84,25 +114,54 @@ impl PolyEvalWitness { } /// A type that holds a polynomial evaluation instance -pub struct PolyEvalInstance { +struct PolyEvalInstance { c: Commitment, // commitment to the polynomial x: Vec, // evaluation point e: E::Scalar, // claimed evaluation } impl PolyEvalInstance { - fn pad(U: Vec>) -> Vec> { - // determine the maximum size - if let Some(ell) = U.iter().map(|u| u.x.len()).max() { - U.into_iter() - .map(|mut u| { - let mut x = vec![E::Scalar::ZERO; ell - u.x.len()]; - x.append(&mut u.x); - PolyEvalInstance { x, ..u } - }) - .collect() - } else { - Vec::new() + fn batch_diff_size( + c_vec: &[Commitment], + e_vec: &[E::Scalar], + num_vars: &[usize], + x: Vec, + s: E::Scalar, + ) -> PolyEvalInstance { + let num_instances = num_vars.len(); + assert_eq!(c_vec.len(), num_instances); + assert_eq!(e_vec.len(), num_instances); + + let num_vars_max = x.len(); + let powers: Vec = powers::(&s, num_instances); + // Rescale evaluations by the first Lagrange polynomial, + // so that we can check its evaluation against x + let evals_scaled = zip_with!(iter, (e_vec, num_vars), |eval, num_rounds| { + // x_lo = [ x[0] , ..., x[n-nᵢ-1] ] + // x_hi = [ x[n-nᵢ], ..., x[n] ] + let (r_lo, _r_hi) = x.split_at(num_vars_max - num_rounds); + // Compute L₀(x_lo) + let lagrange_eval = r_lo + .iter() + .map(|r| E::Scalar::ONE - r) + .product::(); + + // vᵢ = L₀(x_lo)⋅Pᵢ(x_hi) + lagrange_eval * eval + }) + .collect::>(); + + // C = ∑ᵢ γⁱ⋅Cᵢ + let comm_joint = zip_with!(iter, (c_vec, powers), |c, g_i| *c * *g_i) + .fold(Commitment::::default(), |acc, item| acc + item); + + // v = ∑ᵢ γⁱ⋅vᵢ + let eval_joint = zip_with!((evals_scaled.into_iter(), powers.iter()), |e, g_i| e * g_i).sum(); + + PolyEvalInstance { + c: comm_joint, + x, + e: eval_joint, } } @@ -112,8 +171,13 @@ impl PolyEvalInstance { e_vec: &[E::Scalar], s: &E::Scalar, ) -> PolyEvalInstance { - let powers_of_s = powers::(s, c_vec.len()); + let num_instances = c_vec.len(); + assert_eq!(e_vec.len(), num_instances); + + let powers_of_s = powers::(s, num_instances); + // Weighted sum of evaluations let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum(); + // Weighted sum of commitments let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p) .reduce(Commitment::::default, |acc, item| acc + item); @@ -124,3 +188,43 @@ impl PolyEvalInstance { } } } + +/// Bounds "row" variables of (A, B, C) matrices viewed as 2d multilinear polynomials +fn compute_eval_table_sparse( + S: &R1CSShape, + rx: &[E::Scalar], +) -> (Vec, Vec, Vec) { + assert_eq!(rx.len(), S.num_cons); + + let inner = |M: &SparseMatrix, M_evals: &mut Vec| { + for (row_idx, ptrs) in M.indptr.windows(2).enumerate() { + for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) { + M_evals[*col_idx] += rx[row_idx] * val; + } + } + }; + + let (A_evals, (B_evals, C_evals)) = rayon::join( + || { + let mut A_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; + inner(&S.A, &mut A_evals); + A_evals + }, + || { + rayon::join( + || { + let mut B_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; + inner(&S.B, &mut B_evals); + B_evals + }, + || { + let mut C_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; + inner(&S.C, &mut C_evals); + C_evals + }, + ) + }, + ); + + (A_evals, B_evals, C_evals) +} diff --git a/src/spartan/polys/eq.rs b/src/spartan/polys/eq.rs index 5e2dcd9e2..aacab8f24 100644 --- a/src/spartan/polys/eq.rs +++ b/src/spartan/polys/eq.rs @@ -14,8 +14,9 @@ use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, Parall /// This polynomial evaluates to 1 if every component $x_i$ equals its corresponding $e_i$, and 0 otherwise. /// /// For instance, for e = 6 (with a binary representation of 0b110), the vector r would be [1, 1, 0]. +#[derive(Debug)] pub struct EqPolynomial { - r: Vec, + pub(in crate::spartan::polys) r: Vec, } impl EqPolynomial { @@ -43,12 +44,20 @@ impl EqPolynomial { /// /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. pub fn evals(&self) -> Vec { - let ell = self.r.len(); + Self::evals_from_points(&self.r) + } + + /// Evaluates the `EqPolynomial` from the `2^|r|` points in its domain, without creating an intermediate polynomial + /// representation. + /// + /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. + pub fn evals_from_points(r: &[Scalar]) -> Vec { + let ell = r.len(); let mut evals: Vec = vec![Scalar::ZERO; (2_usize).pow(ell as u32)]; let mut size = 1; evals[0] = Scalar::ONE; - for r in self.r.iter().rev() { + for r in r.iter().rev() { let (evals_left, evals_right) = evals.split_at_mut(size); let (evals_right, _) = evals_right.split_at_mut(size); @@ -64,6 +73,13 @@ impl EqPolynomial { } } +impl FromIterator for EqPolynomial { + fn from_iter>(iter: I) -> Self { + let r: Vec<_> = iter.into_iter().collect(); + EqPolynomial { r } + } +} + #[cfg(test)] mod tests { use crate::provider; diff --git a/src/spartan/polys/masked_eq.rs b/src/spartan/polys/masked_eq.rs new file mode 100644 index 000000000..5fa5fe224 --- /dev/null +++ b/src/spartan/polys/masked_eq.rs @@ -0,0 +1,150 @@ +//! `MaskedEqPolynomial`: Represents the `eq` polynomial over n variables, where the first 2^m entries are 0. + +use crate::spartan::polys::eq::EqPolynomial; +use ff::PrimeField; +use itertools::zip_eq; + +/// Represents the multilinear extension polynomial (MLE) of the equality polynomial $eqₘ(x,r)$ +/// over n variables, where the first 2^m evaluations are 0. +/// +/// The polynomial is defined by the formula: +/// eqₘ(x,r) = eq(x,r) - ( ∏_{0 ≤ i < n-m} (1−rᵢ)(1−xᵢ) )⋅( ∏_{n-m ≤ i < n} (1−rᵢ)(1−xᵢ) + rᵢ⋅xᵢ ) +#[derive(Debug)] +pub struct MaskedEqPolynomial<'a, Scalar: PrimeField> { + eq: &'a EqPolynomial, + num_masked_vars: usize, +} + +impl<'a, Scalar: PrimeField> MaskedEqPolynomial<'a, Scalar> { + /// Creates a new `MaskedEqPolynomial` from a vector of Scalars `r` of size n, with the number of + /// masked variables m = `num_masked_vars`. + pub const fn new(eq: &'a EqPolynomial, num_masked_vars: usize) -> Self { + MaskedEqPolynomial { + eq, + num_masked_vars, + } + } + + /// Evaluates the `MaskedEqPolynomial` at a given point `rx`. + /// + /// This function computes the value of the polynomial at the point specified by `rx`. + /// It expects `rx` to have the same length as the internal vector `r`. + /// + /// Panics if `rx` and `r` have different lengths. + pub fn evaluate(&self, rx: &[Scalar]) -> Scalar { + let r = &self.eq.r; + assert_eq!(r.len(), rx.len()); + let split_idx = r.len() - self.num_masked_vars; + + let (r_lo, r_hi) = r.split_at(split_idx); + let (rx_lo, rx_hi) = rx.split_at(split_idx); + let eq_lo = zip_eq(r_lo, rx_lo) + .map(|(r, rx)| *r * rx + (Scalar::ONE - r) * (Scalar::ONE - rx)) + .product::(); + let eq_hi = zip_eq(r_hi, rx_hi) + .map(|(r, rx)| *r * rx + (Scalar::ONE - r) * (Scalar::ONE - rx)) + .product::(); + let mask_lo = zip_eq(r_lo, rx_lo) + .map(|(r, rx)| (Scalar::ONE - r) * (Scalar::ONE - rx)) + .product::(); + + (eq_lo - mask_lo) * eq_hi + } + + /// Evaluates the `MaskedEqPolynomial` at all the `2^|r|` points in its domain. + /// + /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. + pub fn evals(&self) -> Vec { + Self::evals_from_points(&self.eq.r, self.num_masked_vars) + } + + /// Evaluates the `MaskedEqPolynomial` from the `2^|r|` points in its domain, without creating an intermediate polynomial + /// representation. + /// + /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. + fn evals_from_points(r: &[Scalar], num_masked_vars: usize) -> Vec { + let mut evals = EqPolynomial::evals_from_points(r); + + // replace the first 2^m evaluations with 0 + let num_masked_evals = 1 << num_masked_vars; + evals[..num_masked_evals] + .iter_mut() + .for_each(|e| *e = Scalar::ZERO); + + evals + } +} + +#[cfg(test)] +mod tests { + use crate::provider; + + use super::*; + use crate::spartan::polys::eq::EqPolynomial; + use pasta_curves::Fp; + use rand_chacha::ChaCha20Rng; + use rand_core::{CryptoRng, RngCore, SeedableRng}; + + fn test_masked_eq_polynomial_with( + num_vars: usize, + num_masked_vars: usize, + mut rng: &mut R, + ) { + let num_masked_evals = 1 << num_masked_vars; + + // random point + let r = std::iter::from_fn(|| Some(F::random(&mut rng))) + .take(num_vars) + .collect::>(); + // evaluation point + let rx = std::iter::from_fn(|| Some(F::random(&mut rng))) + .take(num_vars) + .collect::>(); + + let poly_eq = EqPolynomial::new(r); + let poly_eq_evals = poly_eq.evals(); + + let masked_eq_poly = MaskedEqPolynomial::new(&poly_eq, num_masked_vars); + let masked_eq_poly_evals = masked_eq_poly.evals(); + + // ensure the first 2^m entries are 0 + assert_eq!( + masked_eq_poly_evals[..num_masked_evals], + vec![F::ZERO; num_masked_evals] + ); + // ensure the remaining evaluations match eq(r) + assert_eq!( + masked_eq_poly_evals[num_masked_evals..], + poly_eq_evals[num_masked_evals..] + ); + + // compute the evaluation at rx succinctly + let masked_eq_eval = masked_eq_poly.evaluate(&rx); + + // compute the evaluation as a MLE + let rx_evals = EqPolynomial::evals_from_points(&rx); + let expected_masked_eq_eval = zip_eq(rx_evals, masked_eq_poly_evals) + .map(|(rx, r)| rx * r) + .sum(); + + assert_eq!(masked_eq_eval, expected_masked_eq_eval); + } + + #[test] + fn test_masked_eq_polynomial() { + let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + let num_vars = 5; + let num_masked_vars = 2; + test_masked_eq_polynomial_with::(num_vars, num_masked_vars, &mut rng); + test_masked_eq_polynomial_with::( + num_vars, + num_masked_vars, + &mut rng, + ); + test_masked_eq_polynomial_with::( + num_vars, + num_masked_vars, + &mut rng, + ); + } +} diff --git a/src/spartan/polys/mod.rs b/src/spartan/polys/mod.rs index d19d56f77..a1a192ef8 100644 --- a/src/spartan/polys/mod.rs +++ b/src/spartan/polys/mod.rs @@ -1,6 +1,7 @@ //! This module contains the definitions of polynomial types used in the Spartan SNARK. pub(crate) mod eq; pub(crate) mod identity; +pub(crate) mod masked_eq; pub(crate) mod multilinear; pub(crate) mod power; pub(crate) mod univariate; diff --git a/src/spartan/polys/multilinear.rs b/src/spartan/polys/multilinear.rs index 6610fc2f1..43cbe63ce 100644 --- a/src/spartan/polys/multilinear.rs +++ b/src/spartan/polys/multilinear.rs @@ -39,13 +39,12 @@ pub struct MultilinearPolynomial { impl MultilinearPolynomial { /// Creates a new `MultilinearPolynomial` from the given evaluations. /// + /// # Panics /// The number of evaluations must be a power of two. pub fn new(Z: Vec) -> Self { - assert_eq!(Z.len(), (2_usize).pow((Z.len() as f64).log2() as u32)); - MultilinearPolynomial { - num_vars: usize::try_from(Z.len().ilog2()).unwrap(), - Z, - } + let num_vars = Z.len().log_2(); + assert_eq!(Z.len(), 1 << num_vars); + MultilinearPolynomial { num_vars, Z } } /// Returns the number of variables in the multilinear polynomial @@ -58,10 +57,12 @@ impl MultilinearPolynomial { self.Z.len() } - /// Bounds the polynomial's top variable using the given scalar. + /// Binds the polynomial's top variable using the given scalar. /// /// This operation modifies the polynomial in-place. pub fn bind_poly_var_top(&mut self, r: &Scalar) { + assert!(self.num_vars > 0); + let n = self.len() / 2; let (left, right) = self.Z.split_at_mut(n); @@ -81,20 +82,22 @@ impl MultilinearPolynomial { pub fn evaluate(&self, r: &[Scalar]) -> Scalar { // r must have a value for each variable assert_eq!(r.len(), self.get_num_vars()); - let chis = EqPolynomial::new(r.to_vec()).evals(); - assert_eq!(chis.len(), self.Z.len()); + let chis = EqPolynomial::evals_from_points(r); - (0..chis.len()) - .into_par_iter() - .map(|i| chis[i] * self.Z[i]) - .sum() + zip_with!( + (chis.into_par_iter(), self.Z.par_iter()), + |chi_i, Z_i| chi_i * Z_i + ) + .sum() } /// Evaluates the polynomial with the given evaluations and point. pub fn evaluate_with(Z: &[Scalar], r: &[Scalar]) -> Scalar { zip_with!( - into_par_iter, - (EqPolynomial::new(r.to_vec()).evals(), Z), + ( + EqPolynomial::evals_from_points(r).into_par_iter(), + Z.par_iter() + ), |a, b| a * b ) .sum() @@ -141,7 +144,7 @@ impl SparsePolynomial { chi_i } - // Takes O(n log n) + // Takes O(m log n) where m is the number of non-zero evaluations and n is the number of variables. pub fn evaluate(&self, r: &[Scalar]) -> Scalar { assert_eq!(self.num_vars, r.len()); @@ -165,7 +168,7 @@ impl Add for MultilinearPolynomial { return Err("The two polynomials must have the same number of variables"); } - let sum: Vec = zip_with!(iter, (self.Z, other.Z), |a, b| *a + *b).collect(); + let sum: Vec = zip_with!(into_iter, (self.Z, other.Z), |a, b| a + b).collect(); Ok(MultilinearPolynomial::new(sum)) } @@ -257,7 +260,7 @@ mod tests { let num_evals = 4; let mut evals: Vec = Vec::with_capacity(num_evals); for _ in 0..num_evals { - evals.push(F::from_u128(8)); + evals.push(F::from(8)); } let dense_poly: MultilinearPolynomial = MultilinearPolynomial::new(evals.clone()); diff --git a/src/spartan/polys/power.rs b/src/spartan/polys/power.rs index 06721a23c..fd6ef50a7 100644 --- a/src/spartan/polys/power.rs +++ b/src/spartan/polys/power.rs @@ -2,6 +2,7 @@ use crate::spartan::polys::eq::EqPolynomial; use ff::PrimeField; +use std::iter::successors; /// Represents the multilinear extension polynomial (MLE) of the equality polynomial $pow(x,t)$, denoted as $\tilde{pow}(x, t)$. /// @@ -10,7 +11,6 @@ use ff::PrimeField; /// \tilde{power}(x, t) = \prod_{i=1}^m(1 + (t^{2^i} - 1) * x_i) /// $$ pub struct PowPolynomial { - t_pow: Vec, eq: EqPolynomial, } @@ -18,30 +18,34 @@ impl PowPolynomial { /// Creates a new `PowPolynomial` from a Scalars `t`. pub fn new(t: &Scalar, ell: usize) -> Self { // t_pow = [t^{2^0}, t^{2^1}, ..., t^{2^{ell-1}}] - let mut t_pow = vec![Scalar::ONE; ell]; - t_pow[0] = *t; - for i in 1..ell { - t_pow[i] = t_pow[i - 1].square(); - } + let t_pow = Self::squares(t, ell); PowPolynomial { - t_pow: t_pow.clone(), eq: EqPolynomial::new(t_pow), } } + /// Create powers the following powers of `t`: + /// [t^{2^0}, t^{2^1}, ..., t^{2^{ell-1}}] + pub(in crate::spartan) fn squares(t: &Scalar, ell: usize) -> Vec { + successors(Some(*t), |p: &Scalar| Some(p.square())) + .take(ell) + .collect::>() + } + /// Evaluates the `PowPolynomial` at a given point `rx`. /// /// This function computes the value of the polynomial at the point specified by `rx`. /// It expects `rx` to have the same length as the internal vector `t_pow`. /// /// Panics if `rx` and `t_pow` have different lengths. + #[allow(dead_code)] pub fn evaluate(&self, rx: &[Scalar]) -> Scalar { self.eq.evaluate(rx) } - pub fn coordinates(&self) -> Vec { - self.t_pow.clone() + pub fn coordinates(self) -> Vec { + self.eq.r } /// Evaluates the `PowPolynomial` at all the `2^|t_pow|` points in its domain. diff --git a/src/spartan/polys/univariate.rs b/src/spartan/polys/univariate.rs index bfe983e5b..4bb96c5a7 100644 --- a/src/spartan/polys/univariate.rs +++ b/src/spartan/polys/univariate.rs @@ -9,7 +9,7 @@ use crate::traits::{Group, TranscriptReprTrait}; // ax^2 + bx + c stored as vec![c, b, a] // ax^3 + bx^2 + cx + d stored as vec![d, c, b, a] -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub struct UniPoly { coeffs: Vec, } diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 1fbe22cf4..c4a152ad3 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -13,6 +13,7 @@ use crate::{ polys::{ eq::EqPolynomial, identity::IdentityPolynomial, + masked_eq::MaskedEqPolynomial, multilinear::MultilinearPolynomial, power::PowPolynomial, univariate::{CompressedUniPoly, UniPoly}, @@ -173,7 +174,7 @@ impl R1CSShapeSparkRepr { } } - fn commit(&self, ck: &CommitmentKey) -> R1CSShapeSparkCommitment { + pub(in crate::spartan) fn commit(&self, ck: &CommitmentKey) -> R1CSShapeSparkCommitment { let comm_vec: Vec> = [ &self.row, &self.col, @@ -257,7 +258,85 @@ pub trait SumcheckEngine: Send + Sync { fn final_claims(&self) -> Vec>; } -struct MemorySumcheckInstance { +/// The [WitnessBoundSumcheck] ensures that the witness polynomial W defined over n = log(N) variables, +/// is zero outside of the first `num_vars = 2^m` entries. +/// +/// # Details +/// +/// The `W` polynomial is padded with zeros to size N = 2^n. +/// The `masked_eq` polynomials is defined as with regards to a random challenge `tau` as +/// the eq(tau) polynomial, where the first 2^m evaluations to 0. +/// +/// The instance is given by +/// `0 = ∑_{0≤i<2^n} masked_eq[i] * W[i]`. +/// It is equivalent to the expression +/// `0 = ∑_{2^m≤i<2^n} eq[i] * W[i]` +/// Since `eq` is random, the instance is only satisfied if `W[2^{m}..] = 0`. +pub(in crate::spartan) struct WitnessBoundSumcheck { + poly_W: MultilinearPolynomial, + poly_masked_eq: MultilinearPolynomial, +} + +impl WitnessBoundSumcheck { + pub fn new(tau: E::Scalar, poly_W_padded: Vec, num_vars: usize) -> Self { + let num_vars_log = num_vars.log_2(); + // When num_vars = num_rounds, we shouldn't have to prove anything + // but we still want this instance to compute the evaluation of W + let num_rounds = poly_W_padded.len().log_2(); + assert!(num_vars_log < num_rounds); + + let tau_coords = PowPolynomial::new(&tau, num_rounds).coordinates(); + let poly_masked_eq_evals = + MaskedEqPolynomial::new(&EqPolynomial::new(tau_coords), num_vars_log).evals(); + + Self { + poly_W: MultilinearPolynomial::new(poly_W_padded), + poly_masked_eq: MultilinearPolynomial::new(poly_masked_eq_evals), + } + } +} +impl SumcheckEngine for WitnessBoundSumcheck { + fn initial_claims(&self) -> Vec { + vec![E::Scalar::ZERO] + } + + fn degree(&self) -> usize { + 3 + } + + fn size(&self) -> usize { + assert_eq!(self.poly_W.len(), self.poly_masked_eq.len()); + self.poly_W.len() + } + + fn evaluation_points(&self) -> Vec> { + let comb_func = |poly_A_comp: &E::Scalar, + poly_B_comp: &E::Scalar, + _: &E::Scalar| + -> E::Scalar { *poly_A_comp * *poly_B_comp }; + + let (eval_point_0, eval_point_2, eval_point_3) = SumcheckProof::::compute_eval_points_cubic( + &self.poly_masked_eq, + &self.poly_W, + &self.poly_W, // unused + &comb_func, + ); + + vec![vec![eval_point_0, eval_point_2, eval_point_3]] + } + + fn bound(&mut self, r: &E::Scalar) { + [&mut self.poly_W, &mut self.poly_masked_eq] + .par_iter_mut() + .for_each(|poly| poly.bind_poly_var_top(r)); + } + + fn final_claims(&self) -> Vec> { + vec![vec![self.poly_W[0], self.poly_masked_eq[0]]] + } +} + +pub(in crate::spartan) struct MemorySumcheckInstance { // row w_plus_r_row: MultilinearPolynomial, t_plus_r_row: MultilinearPolynomial, @@ -280,17 +359,65 @@ struct MemorySumcheckInstance { } impl MemorySumcheckInstance { - pub fn new( + /// Computes witnesses for MemoryInstanceSumcheck + /// + /// # Description + /// We use the logUp protocol to prove that + /// ∑ TS[i]/(T[i] + r) - 1/(W[i] + r) = 0 + /// where + /// T_row[i] = mem_row[i] * gamma + i + /// = eq(tau)[i] * gamma + i + /// W_row[i] = L_row[i] * gamma + addr_row[i] + /// = eq(tau)[row[i]] * gamma + addr_row[i] + /// T_col[i] = mem_col[i] * gamma + i + /// = z[i] * gamma + i + /// W_col[i] = addr_col[i] * gamma + addr_col[i] + /// = z[col[i]] * gamma + addr_col[i] + /// and + /// TS_row, TS_col are integer-valued vectors representing the number of reads + /// to each memory cell of L_row, L_col + /// + /// The function returns oracles for the polynomials TS[i]/(T[i] + r), 1/(W[i] + r), + /// as well as auxiliary polynomials T[i] + r, W[i] + r + pub fn compute_oracles( ck: &CommitmentKey, r: &E::Scalar, - T_row: &[E::Scalar], - W_row: &[E::Scalar], - ts_row: Vec, - T_col: &[E::Scalar], - W_col: &[E::Scalar], - ts_col: Vec, - transcript: &mut E::TE, - ) -> Result<(Self, [Commitment; 4], [Vec; 4]), NovaError> { + gamma: &E::Scalar, + mem_row: &[E::Scalar], + addr_row: &[E::Scalar], + L_row: &[E::Scalar], + ts_row: &[E::Scalar], + mem_col: &[E::Scalar], + addr_col: &[E::Scalar], + L_col: &[E::Scalar], + ts_col: &[E::Scalar], + ) -> Result<([Commitment; 4], [Vec; 4], [Vec; 4]), NovaError> { + // hash the tuples of (addr,val) memory contents and read responses into a single field element using `hash_func` + let hash_func_vec = |mem: &[E::Scalar], + addr: &[E::Scalar], + lookups: &[E::Scalar]| + -> (Vec, Vec) { + let hash_func = |addr: &E::Scalar, val: &E::Scalar| -> E::Scalar { *val * gamma + *addr }; + assert_eq!(addr.len(), lookups.len()); + rayon::join( + || { + (0..mem.len()) + .map(|i| hash_func(&E::Scalar::from(i as u64), &mem[i])) + .collect::>() + }, + || { + (0..addr.len()) + .map(|i| hash_func(&addr[i], &lookups[i])) + .collect::>() + }, + ) + }; + + let ((T_row, W_row), (T_col, W_col)) = rayon::join( + || hash_func_vec(mem_row, addr_row, L_row), + || hash_func_vec(mem_col, addr_col, L_col), + ); + let batch_invert = |v: &[E::Scalar]| -> Result, NovaError> { let mut products = vec![E::Scalar::ZERO; v.len()]; let mut acc = E::Scalar::ONE; @@ -340,7 +467,10 @@ impl MemorySumcheckInstance { let inv = batch_invert(&T.par_iter().map(|e| *e + *r).collect::>())?; // compute inv[i] * TS[i] in parallel - Ok(zip_with!(par_iter, (inv, TS), |e1, e2| *e1 * e2).collect::>()) + Ok( + zip_with!((inv.into_par_iter(), TS.par_iter()), |e1, e2| e1 * *e2) + .collect::>(), + ) }, || batch_invert(&W.par_iter().map(|e| *e + *r).collect::>()), ) @@ -358,8 +488,8 @@ impl MemorySumcheckInstance { ((t_plus_r_inv_row, w_plus_r_inv_row), (t_plus_r_row, w_plus_r_row)), ((t_plus_r_inv_col, w_plus_r_inv_col), (t_plus_r_col, w_plus_r_col)), ) = rayon::join( - || helper(T_row, W_row, &ts_row, r), - || helper(T_col, W_col, &ts_col, r), + || helper(&T_row, &W_row, ts_row, r), + || helper(&T_col, &W_col, ts_col, r), ); let t_plus_r_inv_row = t_plus_r_inv_row?; @@ -385,21 +515,6 @@ impl MemorySumcheckInstance { }, ); - // absorb the commitments - transcript.absorb( - b"l", - &[ - comm_t_plus_r_inv_row, - comm_w_plus_r_inv_row, - comm_t_plus_r_inv_col, - comm_w_plus_r_inv_col, - ] - .as_slice(), - ); - - let rho = transcript.squeeze(b"r")?; - let poly_eq = MultilinearPolynomial::new(PowPolynomial::new(&rho, T_row.len().log_2()).evals()); - let comm_vec = [ comm_t_plus_r_inv_row, comm_w_plus_r_inv_row, @@ -408,32 +523,43 @@ impl MemorySumcheckInstance { ]; let poly_vec = [ - t_plus_r_inv_row.clone(), - w_plus_r_inv_row.clone(), - t_plus_r_inv_col.clone(), - w_plus_r_inv_col.clone(), + t_plus_r_inv_row, + w_plus_r_inv_row, + t_plus_r_inv_col, + w_plus_r_inv_col, ]; - let zero = vec![E::Scalar::ZERO; t_plus_r_inv_row.len()]; + let aux_poly_vec = [t_plus_r_row?, w_plus_r_row?, t_plus_r_col?, w_plus_r_col?]; - Ok(( - Self { - w_plus_r_row: MultilinearPolynomial::new(w_plus_r_row?), - t_plus_r_row: MultilinearPolynomial::new(t_plus_r_row?), - t_plus_r_inv_row: MultilinearPolynomial::new(t_plus_r_inv_row), - w_plus_r_inv_row: MultilinearPolynomial::new(w_plus_r_inv_row), - ts_row: MultilinearPolynomial::new(ts_row), - w_plus_r_col: MultilinearPolynomial::new(w_plus_r_col?), - t_plus_r_col: MultilinearPolynomial::new(t_plus_r_col?), - t_plus_r_inv_col: MultilinearPolynomial::new(t_plus_r_inv_col), - w_plus_r_inv_col: MultilinearPolynomial::new(w_plus_r_inv_col), - ts_col: MultilinearPolynomial::new(ts_col), - poly_eq, - poly_zero: MultilinearPolynomial::new(zero), - }, - comm_vec, - poly_vec, - )) + Ok((comm_vec, poly_vec, aux_poly_vec)) + } + + pub fn new( + polys_oracle: [Vec; 4], + polys_aux: [Vec; 4], + poly_eq: Vec, + ts_row: Vec, + ts_col: Vec, + ) -> Self { + let [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] = polys_oracle; + let [t_plus_r_row, w_plus_r_row, t_plus_r_col, w_plus_r_col] = polys_aux; + + let zero = vec![E::Scalar::ZERO; poly_eq.len()]; + + Self { + w_plus_r_row: MultilinearPolynomial::new(w_plus_r_row), + t_plus_r_row: MultilinearPolynomial::new(t_plus_r_row), + t_plus_r_inv_row: MultilinearPolynomial::new(t_plus_r_inv_row), + w_plus_r_inv_row: MultilinearPolynomial::new(w_plus_r_inv_row), + ts_row: MultilinearPolynomial::new(ts_row), + w_plus_r_col: MultilinearPolynomial::new(w_plus_r_col), + t_plus_r_col: MultilinearPolynomial::new(t_plus_r_col), + t_plus_r_inv_col: MultilinearPolynomial::new(t_plus_r_inv_col), + w_plus_r_inv_col: MultilinearPolynomial::new(w_plus_r_inv_col), + ts_col: MultilinearPolynomial::new(ts_col), + poly_eq: MultilinearPolynomial::new(poly_eq), + poly_zero: MultilinearPolynomial::new(zero), + } } } @@ -478,6 +604,7 @@ impl SumcheckEngine for MemorySumcheckInstance { -> E::Scalar { *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) }; // inv related evaluation points + // 0 = ∑ TS[i]/(T[i] + r) - 1/(W[i] + r) let (eval_inv_0_row, eval_inv_2_row, eval_inv_3_row) = SumcheckProof::::compute_eval_points_cubic( &self.t_plus_r_inv_row, @@ -495,6 +622,7 @@ impl SumcheckEngine for MemorySumcheckInstance { ); // row related evaluation points + // 0 = ∑ eq[i] * (inv_T[i] * (T[i] + r) - TS[i])) let (eval_T_0_row, eval_T_2_row, eval_T_3_row) = SumcheckProof::::compute_eval_points_cubic_with_additive_term( &self.poly_eq, @@ -503,6 +631,7 @@ impl SumcheckEngine for MemorySumcheckInstance { &self.ts_row, &comb_func3, ); + // 0 = ∑ eq[i] * (inv_W[i] * (T[i] + r) - 1)) let (eval_W_0_row, eval_W_2_row, eval_W_3_row) = SumcheckProof::::compute_eval_points_cubic_with_additive_term( &self.poly_eq, @@ -756,7 +885,7 @@ impl> SimpleDigestible for VerifierKey> { // commitment to oracles: the first three are for Az, Bz, Cz, @@ -805,15 +934,15 @@ pub struct RelaxedR1CSSNARK> { eval_ts_col: E::Scalar, // a PCS evaluation argument - eval_arg_W: EE::EvaluationArgument, - eval_arg_batch: EE::EvaluationArgument, + eval_arg: EE::EvaluationArgument, } impl> RelaxedR1CSSNARK { - fn prove_helper( + fn prove_helper( mem: &mut T1, outer: &mut T2, inner: &mut T3, + witness: &mut T4, transcript: &mut E::TE, ) -> Result< ( @@ -822,6 +951,7 @@ impl> RelaxedR1CSSNARK { Vec>, Vec>, Vec>, + Vec>, ), NovaError, > @@ -829,12 +959,15 @@ impl> RelaxedR1CSSNARK { T1: SumcheckEngine, T2: SumcheckEngine, T3: SumcheckEngine, + T4: SumcheckEngine, { // sanity checks assert_eq!(mem.size(), outer.size()); assert_eq!(mem.size(), inner.size()); + assert_eq!(mem.size(), witness.size()); assert_eq!(mem.degree(), outer.degree()); assert_eq!(mem.degree(), inner.degree()); + assert_eq!(mem.degree(), witness.degree()); // these claims are already added to the transcript, so we do not need to add let claims = mem @@ -842,28 +975,30 @@ impl> RelaxedR1CSSNARK { .into_iter() .chain(outer.initial_claims()) .chain(inner.initial_claims()) + .chain(witness.initial_claims()) .collect::>(); let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, claims.len()); // compute the joint claim - let claim = zip_with!(iter, (claims, coeffs), |c_1, c_2| *c_1 * c_2).sum(); + let claim = zip_with!((claims.iter(), coeffs.iter()), |c_1, c_2| *c_1 * c_2).sum(); let mut e = claim; let mut r: Vec = Vec::new(); let mut cubic_polys: Vec> = Vec::new(); let num_rounds = mem.size().log_2(); for _ in 0..num_rounds { - let (evals_mem, (evals_outer, evals_inner)) = rayon::join( - || mem.evaluation_points(), - || rayon::join(|| outer.evaluation_points(), || inner.evaluation_points()), + let ((evals_mem, evals_outer), (evals_inner, evals_witness)) = rayon::join( + || rayon::join(|| mem.evaluation_points(), || outer.evaluation_points()), + || rayon::join(|| inner.evaluation_points(), || witness.evaluation_points()), ); let evals: Vec> = evals_mem .into_iter() .chain(evals_outer.into_iter()) .chain(evals_inner.into_iter()) + .chain(evals_witness.into_iter()) .collect::>>(); assert_eq!(evals.len(), claims.len()); @@ -887,8 +1022,8 @@ impl> RelaxedR1CSSNARK { r.push(r_i); let _ = rayon::join( - || mem.bound(&r_i), - || rayon::join(|| outer.bound(&r_i), || inner.bound(&r_i)), + || rayon::join(|| mem.bound(&r_i), || outer.bound(&r_i)), + || rayon::join(|| inner.bound(&r_i), || witness.bound(&r_i)), ); e = poly.evaluate(&r_i); @@ -898,6 +1033,7 @@ impl> RelaxedR1CSSNARK { let mem_claims = mem.final_claims(); let outer_claims = outer.final_claims(); let inner_claims = inner.final_claims(); + let witness_claims = witness.final_claims(); Ok(( SumcheckProof::new(cubic_polys), @@ -905,6 +1041,7 @@ impl> RelaxedR1CSSNARK { mem_claims, outer_claims, inner_claims, + witness_claims, )) } } @@ -1018,13 +1155,14 @@ impl> RelaxedR1CSSNARKTrait for Relax let tau_coords = PowPolynomial::new(&tau, num_rounds_sc).coordinates(); // (1) send commitments to Az, Bz, and Cz along with their evaluations at tau - let (Az, Bz, Cz, E) = { + let (Az, Bz, Cz, W, E) = { Az.resize(pk.S_repr.N, E::Scalar::ZERO); Bz.resize(pk.S_repr.N, E::Scalar::ZERO); Cz.resize(pk.S_repr.N, E::Scalar::ZERO); let E = padded::(&W.E, pk.S_repr.N, &E::Scalar::ZERO); + let W = padded::(&W.W, pk.S_repr.N, &E::Scalar::ZERO); - (Az, Bz, Cz, E) + (Az, Bz, Cz, W, E) }; let (eval_Az_at_tau, eval_Bz_at_tau, eval_Cz_at_tau) = { let evals_at_tau = [&Az, &Bz, &Cz] @@ -1097,51 +1235,50 @@ impl> RelaxedR1CSSNARKTrait for Relax // we now need to prove that L_row and L_col are well-formed // hash the tuples of (addr,val) memory contents and read responses into a single field element using `hash_func` - let hash_func_vec = |mem: &[E::Scalar], - addr: &[E::Scalar], - lookups: &[E::Scalar]| - -> (Vec, Vec) { - let hash_func = |addr: &E::Scalar, val: &E::Scalar| -> E::Scalar { *val * gamma + *addr }; - assert_eq!(addr.len(), lookups.len()); - rayon::join( - || { - (0..mem.len()) - .map(|i| hash_func(&E::Scalar::from(i as u64), &mem[i])) - .collect::>() - }, - || { - (0..addr.len()) - .map(|i| hash_func(&addr[i], &lookups[i])) - .collect::>() - }, - ) - }; - - let ((T_row, W_row), (T_col, W_col)) = rayon::join( - || hash_func_vec(&mem_row, &pk.S_repr.row, &L_row), - || hash_func_vec(&mem_col, &pk.S_repr.col, &L_col), - ); - MemorySumcheckInstance::new( - ck, - &r, - &T_row, - &W_row, - pk.S_repr.ts_row.clone(), - &T_col, - &W_col, - pk.S_repr.ts_col.clone(), - &mut transcript, - ) + let (comm_mem_oracles, mem_oracles, mem_aux) = + MemorySumcheckInstance::::compute_oracles( + ck, + &r, + &gamma, + &mem_row, + &pk.S_repr.row, + &L_row, + &pk.S_repr.ts_row, + &mem_col, + &pk.S_repr.col, + &L_col, + &pk.S_repr.ts_col, + )?; + // absorb the commitments + transcript.absorb(b"l", &comm_mem_oracles.as_slice()); + + let rho = transcript.squeeze(b"r")?; + let poly_eq = MultilinearPolynomial::new(PowPolynomial::new(&rho, num_rounds_sc).evals()); + + Ok::<_, NovaError>(( + MemorySumcheckInstance::new( + mem_oracles.clone(), + mem_aux, + poly_eq.Z, + pk.S_repr.ts_row.clone(), + pk.S_repr.ts_col.clone(), + ), + comm_mem_oracles, + mem_oracles, + )) }, ); let (mut mem_sc_inst, comm_mem_oracles, mem_oracles) = mem_res?; - let (sc, rand_sc, claims_mem, claims_outer, claims_inner) = Self::prove_helper( + let mut witness_sc_inst = WitnessBoundSumcheck::new(tau, W.clone(), S.num_vars); + + let (sc, rand_sc, claims_mem, claims_outer, claims_inner, claims_witness) = Self::prove_helper( &mut mem_sc_inst, &mut outer_sc_inst, &mut inner_sc_inst, + &mut witness_sc_inst, &mut transcript, )?; @@ -1159,6 +1296,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let eval_t_plus_r_inv_col = claims_mem[1][0]; let eval_w_plus_r_inv_col = claims_mem[1][1]; let eval_ts_col = claims_mem[1][2]; + let eval_W = claims_witness[0][0]; // compute the remaining claims that did not come for free from the sum-check prover let (eval_Cz, eval_E, eval_val_A, eval_val_B, eval_val_C, eval_row, eval_col) = { @@ -1177,8 +1315,9 @@ impl> RelaxedR1CSSNARKTrait for Relax (e[0], e[1], e[2], e[3], e[4], e[5], e[6]) }; - // all the following evaluations are at rand_sc, we can fold them into one claim + // all the evaluations are at rand_sc, we can fold them into one claim let eval_vec = vec![ + eval_W, eval_Az, eval_Bz, eval_Cz, @@ -1201,6 +1340,7 @@ impl> RelaxedR1CSSNARKTrait for Relax .collect::>(); let comm_vec = [ + U.comm_W, comm_Az, comm_Bz, comm_Cz, @@ -1220,6 +1360,7 @@ impl> RelaxedR1CSSNARKTrait for Relax pk.S_comm.comm_ts_col, ]; let poly_vec = [ + &W, &Az, &Bz, &Cz, @@ -1243,21 +1384,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let w: PolyEvalWitness = PolyEvalWitness::batch(&poly_vec, &c); let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &rand_sc, &eval_vec, &c); - let eval_arg_batch = EE::prove(ck, &pk.pk_ee, &mut transcript, &u.c, &w.p, &rand_sc, &u.e)?; - - // prove eval_W at the shortened vector - let l = pk.S_comm.N.log_2() - (2 * S.num_vars).log_2(); - let rand_sc_unpad = rand_sc[l..].to_vec(); - let eval_W = MultilinearPolynomial::evaluate_with(&W.W, &rand_sc_unpad[1..]); - let eval_arg_W = EE::prove( - ck, - &pk.pk_ee, - &mut transcript, - &U.comm_W, - &W.W, - &rand_sc_unpad[1..], - &eval_W, - )?; + let eval_arg = EE::prove(ck, &pk.pk_ee, &mut transcript, &u.c, &w.p, &rand_sc, &u.e)?; Ok(RelaxedR1CSSNARK { comm_Az: comm_Az.compress(), @@ -1299,8 +1426,7 @@ impl> RelaxedR1CSSNARKTrait for Relax eval_w_plus_r_inv_col, eval_ts_col, - eval_arg_batch, - eval_arg_W, + eval_arg, }) } @@ -1362,7 +1488,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let rho = transcript.squeeze(b"r")?; - let num_claims = 9; + let num_claims = 10; let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, num_claims); let claim = (coeffs[7] + coeffs[8]) * claim; // rest are zeros @@ -1376,7 +1502,12 @@ impl> RelaxedR1CSSNARKTrait for Relax let poly_eq_coords = PowPolynomial::new(&rho, num_rounds_sc).coordinates(); EqPolynomial::new(poly_eq_coords).evaluate(&rand_sc) }; - let taus_bound_rand_sc = PowPolynomial::new(&tau, num_rounds_sc).evaluate(&rand_sc); + let taus_coords = PowPolynomial::new(&tau, num_rounds_sc).coordinates(); + let eq_tau = EqPolynomial::new(taus_coords); + + let taus_bound_rand_sc = eq_tau.evaluate(&rand_sc); + let taus_masked_bound_rand_sc = + MaskedEqPolynomial::new(&eq_tau, vk.num_vars.log_2()).evaluate(&rand_sc); let eval_t_plus_r_row = { let eval_addr_row = IdentityPolynomial::new(num_rounds_sc).evaluate(&rand_sc); @@ -1406,10 +1537,7 @@ impl> RelaxedR1CSSNARKTrait for Relax factor *= E::Scalar::ONE - r_p } - let rand_sc_unpad = { - let l = vk.S_comm.N.log_2() - (2 * vk.num_vars).log_2(); - rand_sc[l..].to_vec() - }; + let rand_sc_unpad = rand_sc[l..].to_vec(); (factor, rand_sc_unpad) }; @@ -1426,7 +1554,7 @@ impl> RelaxedR1CSSNARKTrait for Relax SparsePolynomial::new(vk.num_vars.log_2(), poly_X).evaluate(&rand_sc_unpad[1..]) }; - factor * ((E::Scalar::ONE - rand_sc_unpad[0]) * self.eval_W + rand_sc_unpad[0] * eval_X) + self.eval_W + factor * rand_sc_unpad[0] * eval_X }; let eval_t = eval_addr_col + gamma * eval_val_col; eval_t + r @@ -1464,7 +1592,12 @@ impl> RelaxedR1CSSNARKTrait for Relax * self.eval_L_col * (self.eval_val_A + c * self.eval_val_B + c * c * self.eval_val_C); - claim_mem_final_expected + claim_outer_final_expected + claim_inner_final_expected + let claim_witness_final_expected = coeffs[9] * taus_masked_bound_rand_sc * self.eval_W; + + claim_mem_final_expected + + claim_outer_final_expected + + claim_inner_final_expected + + claim_witness_final_expected }; if claim_sc_final_expected != claim_sc_final { @@ -1472,6 +1605,7 @@ impl> RelaxedR1CSSNARKTrait for Relax } let eval_vec = vec![ + self.eval_W, self.eval_Az, self.eval_Bz, self.eval_Cz, @@ -1493,6 +1627,7 @@ impl> RelaxedR1CSSNARKTrait for Relax .into_iter() .collect::>(); let comm_vec = [ + U.comm_W, comm_Az, comm_Bz, comm_Cz, @@ -1515,28 +1650,14 @@ impl> RelaxedR1CSSNARKTrait for Relax let c = transcript.squeeze(b"c")?; let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &rand_sc, &eval_vec, &c); - // verify eval_arg_batch + // verify EE::verify( &vk.vk_ee, &mut transcript, &u.c, &rand_sc, &u.e, - &self.eval_arg_batch, - )?; - - // verify eval_arg_W - let rand_sc_unpad = { - let l = vk.S_comm.N.log_2() - (2 * vk.num_vars).log_2(); - rand_sc[l..].to_vec() - }; - EE::verify( - &vk.vk_ee, - &mut transcript, - &U.comm_W, - &rand_sc_unpad[1..], - &self.eval_W, - &self.eval_arg_W, + &self.eval_arg, )?; Ok(()) diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index 5e0e54d2e..9755f9fe4 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -9,6 +9,7 @@ use crate::{ errors::NovaError, r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness, SparseMatrix}, spartan::{ + compute_eval_table_sparse, polys::{eq::EqPolynomial, multilinear::MultilinearPolynomial, multilinear::SparsePolynomial}, powers, sumcheck::SumcheckProof, @@ -19,8 +20,9 @@ use crate::{ snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait}, Engine, TranscriptEngineTrait, }, - Commitment, CommitmentKey, + CommitmentKey, }; + use ff::Field; use itertools::Itertools as _; use once_cell::sync::OnceCell; @@ -75,7 +77,7 @@ impl> DigestHelperTrait for VerifierK /// A succinct proof of knowledge of a witness to a relaxed R1CS instance /// The proof is produced using Spartan's combination of the sum-check and /// the commitment to a vector viewed as a polynomial commitment -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(bound = "")] pub struct RelaxedR1CSSNARK> { sc_proof_outer: SumcheckProof, @@ -141,9 +143,9 @@ impl> RelaxedR1CSSNARKTrait for Relax // outer sum-check let tau = (0..num_rounds_x) .map(|_i| transcript.squeeze(b"t")) - .collect::, NovaError>>()?; + .collect::, NovaError>>()?; - let mut poly_tau = MultilinearPolynomial::new(EqPolynomial::new(tau).evals()); + let mut poly_tau = MultilinearPolynomial::new(tau.evals()); let (mut poly_Az, mut poly_Bz, poly_Cz, mut poly_uCz_E) = { let (poly_Az, poly_Bz, poly_Cz) = S.multiply_vec(&z)?; let poly_uCz_E = (0..S.num_cons) @@ -189,45 +191,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let poly_ABC = { // compute the initial evaluation table for R(\tau, x) - let evals_rx = EqPolynomial::new(r_x.clone()).evals(); - - // Bounds "row" variables of (A, B, C) matrices viewed as 2d multilinear polynomials - let compute_eval_table_sparse = - |S: &R1CSShape, rx: &[E::Scalar]| -> (Vec, Vec, Vec) { - assert_eq!(rx.len(), S.num_cons); - - let inner = |M: &SparseMatrix, M_evals: &mut Vec| { - for (row_idx, ptrs) in M.indptr.windows(2).enumerate() { - for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) { - M_evals[*col_idx] += rx[row_idx] * val; - } - } - }; - - let (A_evals, (B_evals, C_evals)) = rayon::join( - || { - let mut A_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; - inner(&S.A, &mut A_evals); - A_evals - }, - || { - rayon::join( - || { - let mut B_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; - inner(&S.B, &mut B_evals); - B_evals - }, - || { - let mut C_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; - inner(&S.C, &mut C_evals); - C_evals - }, - ) - }, - ); - - (A_evals, B_evals, C_evals) - }; + let evals_rx = EqPolynomial::evals_from_points(&r_x.clone()); let (evals_A, evals_B, evals_C) = compute_eval_table_sparse(&S, &evals_rx); @@ -321,7 +285,7 @@ impl> RelaxedR1CSSNARKTrait for Relax // outer sum-check let tau = (0..num_rounds_x) .map(|_i| transcript.squeeze(b"t")) - .collect::, NovaError>>()?; + .collect::, NovaError>>()?; let (claim_outer_final, r_x) = self @@ -330,7 +294,7 @@ impl> RelaxedR1CSSNARKTrait for Relax // verify claim_outer_final let (claim_Az, claim_Bz, claim_Cz) = self.claims_outer; - let taus_bound_rx = EqPolynomial::new(tau).evaluate(&r_x); + let taus_bound_rx = tau.evaluate(&r_x); let claim_outer_final_expected = taus_bound_rx * (claim_Az * claim_Bz - U.u * claim_Cz - self.eval_E); if claim_outer_final != claim_outer_final_expected { @@ -394,8 +358,8 @@ impl> RelaxedR1CSSNARKTrait for Relax }; let (T_x, T_y) = rayon::join( - || EqPolynomial::new(r_x.to_vec()).evals(), - || EqPolynomial::new(r_y.to_vec()).evals(), + || EqPolynomial::evals_from_points(r_x), + || EqPolynomial::evals_from_points(r_y), ); (0..M_vec.len()) @@ -448,7 +412,19 @@ impl> RelaxedR1CSSNARKTrait for Relax /// Proves a batch of polynomial evaluation claims using Sumcheck /// reducing them to a single claim at the same point. -fn batch_eval_prove( +/// +/// # Details +/// +/// We are given as input a list of instance/witness pairs +/// u = [(Cᵢ, xᵢ, eᵢ)], w = [Pᵢ], such that +/// - nᵢ = |xᵢ| +/// - Cᵢ = Commit(Pᵢ) +/// - eᵢ = Pᵢ(xᵢ) +/// - |Pᵢ| = 2^nᵢ +/// +/// We allow the polynomial Pᵢ to have different sizes, by appropriately scaling +/// the claims and resulting evaluations from Sumcheck. +pub(in crate::spartan) fn batch_eval_prove( u_vec: Vec>, w_vec: Vec>, transcript: &mut E::TE, @@ -461,34 +437,44 @@ fn batch_eval_prove( ), NovaError, > { - assert_eq!(u_vec.len(), w_vec.len()); + let num_claims = u_vec.len(); + assert_eq!(w_vec.len(), num_claims); - let w_vec_padded = PolyEvalWitness::pad(w_vec); // pad the polynomials to be of the same size - let u_vec_padded = PolyEvalInstance::pad(u_vec); // pad the evaluation points + // Compute nᵢ and n = maxᵢ{nᵢ} + let num_rounds = u_vec.iter().map(|u| u.x.len()).collect::>(); - // generate a challenge + // Check polynomials match number of variables, i.e. |Pᵢ| = 2^nᵢ + w_vec + .iter() + .zip_eq(num_rounds.iter()) + .for_each(|(w, num_vars)| assert_eq!(w.p.len(), 1 << num_vars)); + + // generate a challenge, and powers of it for random linear combination let rho = transcript.squeeze(b"r")?; - let num_claims = w_vec_padded.len(); let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = zip_with!(iter, (u_vec_padded, powers_of_rho), |u, p| u.e * p).sum(); - let mut polys_left: Vec> = w_vec_padded + let (claims, u_xs, comms): (Vec<_>, Vec<_>, Vec<_>) = + u_vec.into_iter().map(|u| (u.e, u.x, u.c)).multiunzip(); + + // Create clones of polynomials to be given to Sumcheck + // Pᵢ(X) + let polys_P: Vec> = w_vec .iter() .map(|w| MultilinearPolynomial::new(w.p.clone())) .collect(); - let mut polys_right: Vec> = u_vec_padded - .iter() - .map(|u| MultilinearPolynomial::new(EqPolynomial::new(u.x.clone()).evals())) + // eq(xᵢ, X) + let polys_eq: Vec> = u_xs + .into_iter() + .map(|ux| MultilinearPolynomial::new(EqPolynomial::evals_from_points(&ux))) .collect(); - let num_rounds_z = u_vec_padded[0].x.len(); - let comb_func = - |poly_A_comp: &E::Scalar, poly_B_comp: &E::Scalar| -> E::Scalar { *poly_A_comp * *poly_B_comp }; - let (sc_proof_batch, r_z, claims_batch) = SumcheckProof::prove_quad_batch( - &claim_batch_joint, - num_rounds_z, - &mut polys_left, - &mut polys_right, + // For each i, check eᵢ = ∑ₓ Pᵢ(x)eq(xᵢ,x), where x ∈ {0,1}^nᵢ + let comb_func = |poly_P: &E::Scalar, poly_eq: &E::Scalar| -> E::Scalar { *poly_P * *poly_eq }; + let (sc_proof_batch, r, claims_batch) = SumcheckProof::prove_quad_batch( + &claims, + &num_rounds, + polys_P, + polys_eq, &powers_of_rho, comb_func, transcript, @@ -498,62 +484,51 @@ fn batch_eval_prove( transcript.absorb(b"l", &claims_batch_left.as_slice()); - // we now combine evaluation claims at the same point rz into one + // we now combine evaluation claims at the same point r into one let gamma = transcript.squeeze(b"g")?; - let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = zip_with!(iter, (u_vec_padded, powers_of_gamma), |u, g_i| u.c * *g_i) - .fold(Commitment::::default(), |acc, item| acc + item); - let poly_joint = PolyEvalWitness::weighted_sum(&w_vec_padded, &powers_of_gamma); - let eval_joint = zip_with!(iter, (claims_batch_left, powers_of_gamma), |e, g_i| *e - * *g_i) - .sum(); - - Ok(( - PolyEvalInstance:: { - c: comm_joint, - x: r_z, - e: eval_joint, - }, - poly_joint, - sc_proof_batch, - claims_batch_left, - )) + + let u_joint = + PolyEvalInstance::batch_diff_size(&comms, &claims_batch_left, &num_rounds, r, gamma); + + // P = ∑ᵢ γⁱ⋅Pᵢ + let w_joint = PolyEvalWitness::batch_diff_size(w_vec, gamma); + + Ok((u_joint, w_joint, sc_proof_batch, claims_batch_left)) } /// Verifies a batch of polynomial evaluation claims using Sumcheck /// reducing them to a single claim at the same point. -fn batch_eval_verify( +pub(in crate::spartan) fn batch_eval_verify( u_vec: Vec>, transcript: &mut E::TE, sc_proof_batch: &SumcheckProof, evals_batch: &[E::Scalar], ) -> Result, NovaError> { - assert_eq!(evals_batch.len(), evals_batch.len()); - - let u_vec_padded = PolyEvalInstance::pad(u_vec); // pad the evaluation points + let num_claims = u_vec.len(); + assert_eq!(evals_batch.len(), num_claims); // generate a challenge let rho = transcript.squeeze(b"r")?; - let num_claims: usize = u_vec_padded.len(); let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = zip_with!(iter, (u_vec_padded, powers_of_rho), |u, p| u.e * p).sum(); - let num_rounds_z = u_vec_padded[0].x.len(); + // Compute nᵢ and n = maxᵢ{nᵢ} + let num_rounds = u_vec.iter().map(|u| u.x.len()).collect::>(); + let num_rounds_max = *num_rounds.iter().max().unwrap(); - let (claim_batch_final, r_z) = - sc_proof_batch.verify(claim_batch_joint, num_rounds_z, 2, transcript)?; + let claims = u_vec.iter().map(|u| u.e).collect::>(); + + let (claim_batch_final, r) = + sc_proof_batch.verify_batch(&claims, &num_rounds, &powers_of_rho, 2, transcript)?; let claim_batch_final_expected = { - let poly_rz = EqPolynomial::new(r_z.clone()); - let evals = u_vec_padded - .iter() - .map(|u| poly_rz.evaluate(&u.x)) - .collect::>(); + let evals_r = u_vec.iter().map(|u| { + let (_, r_hi) = r.split_at(num_rounds_max - u.x.len()); + EqPolynomial::new(r_hi.to_vec()).evaluate(&u.x) + }); zip_with!( - iter, - (evals, evals_batch, powers_of_rho), - |e_i, p_i, rho_i| *e_i * *p_i * rho_i + (evals_r, evals_batch.iter(), powers_of_rho.iter()), + |e_i, p_i, rho_i| e_i * *p_i * rho_i ) .sum() }; @@ -564,16 +539,12 @@ fn batch_eval_verify( transcript.absorb(b"l", &evals_batch); - // we now combine evaluation claims at the same point rz into one + // we now combine evaluation claims at the same point r into one let gamma = transcript.squeeze(b"g")?; - let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = zip_with!(iter, (u_vec_padded, powers_of_gamma), |u, g_i| u.c * *g_i) - .fold(Commitment::::default(), |acc, item| acc + item); - let eval_joint = zip_with!(iter, (evals_batch, powers_of_gamma), |e, g_i| *e * *g_i).sum(); - - Ok(PolyEvalInstance:: { - c: comm_joint, - x: r_z, - e: eval_joint, - }) + + let comms = u_vec.into_iter().map(|u| u.c).collect::>(); + + let u_joint = PolyEvalInstance::batch_diff_size(&comms, evals_batch, &num_rounds, r, gamma); + + Ok(u_joint) } diff --git a/src/spartan/sumcheck.rs b/src/spartan/sumcheck.rs index f9a6ecb23..3e4277838 100644 --- a/src/spartan/sumcheck.rs +++ b/src/spartan/sumcheck.rs @@ -5,6 +5,7 @@ use crate::spartan::polys::{ }; use crate::traits::{Engine, TranscriptEngineTrait}; use ff::Field; +use itertools::Itertools as _; use rayon::prelude::*; use serde::{Deserialize, Serialize}; @@ -61,8 +62,42 @@ impl SumcheckProof { Ok((e, r)) } + pub fn verify_batch( + &self, + claims: &[E::Scalar], + num_rounds: &[usize], + coeffs: &[E::Scalar], + degree_bound: usize, + transcript: &mut E::TE, + ) -> Result<(E::Scalar, Vec), NovaError> { + let num_instances = claims.len(); + assert_eq!(num_rounds.len(), num_instances); + assert_eq!(coeffs.len(), num_instances); + + // n = maxᵢ{nᵢ} + let num_rounds_max = *num_rounds.iter().max().unwrap(); + + // Random linear combination of claims, + // where each claim is scaled by 2^{n-nᵢ} to account for the padding. + // + // claim = ∑ᵢ coeffᵢ⋅2^{n-nᵢ}⋅cᵢ + let claim = zip_with!( + ( + zip_with!(iter, (claims, num_rounds), |claim, num_rounds| { + let scaling_factor = 1 << (num_rounds_max - num_rounds); + E::Scalar::from(scaling_factor as u64) * claim + }), + coeffs.iter() + ), + |scaled_claim, coeff| scaled_claim * coeff + ) + .sum(); + + self.verify(claim, num_rounds_max, degree_bound, transcript) + } + #[inline] - pub(in crate::spartan) fn compute_eval_points_quad( + fn compute_eval_points_quad( poly_A: &MultilinearPolynomial, poly_B: &MultilinearPolynomial, comb_func: &F, @@ -123,7 +158,7 @@ impl SumcheckProof { // Set up next round claim_per_round = poly.evaluate(&r_i); - // bound all tables to the verifier's challenege + // bind all tables to the verifier's challenge rayon::join( || poly_A.bind_poly_var_top(&r_i), || poly_B.bind_poly_var_top(&r_i), @@ -140,10 +175,10 @@ impl SumcheckProof { } pub fn prove_quad_batch( - claim: &E::Scalar, - num_rounds: usize, - poly_A_vec: &mut Vec>, - poly_B_vec: &mut Vec>, + claims: &[E::Scalar], + num_rounds: &[usize], + mut poly_A_vec: Vec>, + mut poly_B_vec: Vec>, coeffs: &[E::Scalar], comb_func: F, transcript: &mut E::TE, @@ -151,16 +186,58 @@ impl SumcheckProof { where F: Fn(&E::Scalar, &E::Scalar) -> E::Scalar + Sync, { - let mut e = *claim; + let num_claims = claims.len(); + + assert_eq!(num_rounds.len(), num_claims); + assert_eq!(poly_A_vec.len(), num_claims); + assert_eq!(poly_B_vec.len(), num_claims); + assert_eq!(coeffs.len(), num_claims); + + for (i, &num_rounds) in num_rounds.iter().enumerate() { + let expected_size = 1 << num_rounds; + + // Direct indexing with the assumption that the index will always be in bounds + let a = &poly_A_vec[i]; + let b = &poly_B_vec[i]; + + for (l, polyname) in [(a.len(), "poly_A_vec"), (b.len(), "poly_B_vec")].iter() { + assert_eq!( + *l, expected_size, + "Mismatch in size for {} at index {}", + polyname, i + ); + } + } + + let num_rounds_max = *num_rounds.iter().max().unwrap(); + let mut e = zip_with!( + iter, + (claims, num_rounds, coeffs), + |claim, num_rounds, coeff| { + let scaled_claim = E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim; + scaled_claim * coeff + } + ) + .sum(); let mut r: Vec = Vec::new(); let mut quad_polys: Vec> = Vec::new(); - for _ in 0..num_rounds { - let evals: Vec<(E::Scalar, E::Scalar)> = - zip_with!(par_iter, (poly_A_vec, poly_B_vec), |poly_A, poly_B| { - Self::compute_eval_points_quad(poly_A, poly_B, &comb_func) - }) - .collect(); + for current_round in 0..num_rounds_max { + let remaining_rounds = num_rounds_max - current_round; + let evals: Vec<(E::Scalar, E::Scalar)> = zip_with!( + par_iter, + (num_rounds, claims, poly_A_vec, poly_B_vec), + |num_rounds, claim, poly_A, poly_B| { + if remaining_rounds <= *num_rounds { + Self::compute_eval_points_quad(poly_A, poly_B, &comb_func) + } else { + let remaining_variables = remaining_rounds - num_rounds - 1; + let scaled_claim = E::Scalar::from((1 << remaining_variables) as u64) * claim; + (scaled_claim, scaled_claim) + } + } + ) + .collect(); let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * coeffs[i]).sum(); let evals_combined_2 = (0..evals.len()).map(|i| evals[i].1 * coeffs[i]).sum(); @@ -176,19 +253,45 @@ impl SumcheckProof { r.push(r_i); // bound all tables to the verifier's challenge - zip_with_for_each!(par_iter_mut, (poly_A_vec, poly_B_vec), |poly_A, poly_B| { - let _ = rayon::join( - || poly_A.bind_poly_var_top(&r_i), - || poly_B.bind_poly_var_top(&r_i), - ); - }); + zip_with_for_each!( + ( + num_rounds.par_iter(), + poly_A_vec.par_iter_mut(), + poly_B_vec.par_iter_mut() + ), + |num_rounds, poly_A, poly_B| { + if remaining_rounds <= *num_rounds { + let _ = rayon::join( + || poly_A.bind_poly_var_top(&r_i), + || poly_B.bind_poly_var_top(&r_i), + ); + } + } + ); e = poly.evaluate(&r_i); quad_polys.push(poly.compress()); } + poly_A_vec.iter().for_each(|p| assert_eq!(p.len(), 1)); + poly_B_vec.iter().for_each(|p| assert_eq!(p.len(), 1)); + + let poly_A_final = poly_A_vec + .into_iter() + .map(|poly| poly[0]) + .collect::>(); + let poly_B_final = poly_B_vec + .into_iter() + .map(|poly| poly[0]) + .collect::>(); + + let eval_expected = zip_with!( + iter, + (poly_A_final, poly_B_final, coeffs), + |eA, eB, coeff| comb_func(eA, eB) * coeff + ) + .sum::(); + assert_eq!(e, eval_expected); - let poly_A_final = (0..poly_A_vec.len()).map(|i| poly_A_vec[i][0]).collect(); - let poly_B_final = (0..poly_B_vec.len()).map(|i| poly_B_vec[i][0]).collect(); let claims_prod = (poly_A_final, poly_B_final); Ok((SumcheckProof::new(quad_polys), r, claims_prod)) @@ -357,4 +460,5 @@ impl SumcheckProof { vec![poly_A[0], poly_B[0], poly_C[0], poly_D[0]], )) } + }