From 5c1000d5b30866d9cc378a71ed310e22fcfcd4a8 Mon Sep 17 00:00:00 2001 From: "senww.eth" Date: Mon, 25 Sep 2023 23:33:50 -0500 Subject: [PATCH] add ipa implementation --- halo2-ecc/src/ipa/inner_product_argument.rs | 135 +++++++++++ halo2-ecc/src/ipa/mod.rs | 3 + halo2-ecc/src/ipa/tests/mod.rs | 244 ++++++++++++++++++++ halo2-ecc/src/lib.rs | 1 + 4 files changed, 383 insertions(+) create mode 100644 halo2-ecc/src/ipa/inner_product_argument.rs create mode 100644 halo2-ecc/src/ipa/mod.rs create mode 100644 halo2-ecc/src/ipa/tests/mod.rs diff --git a/halo2-ecc/src/ipa/inner_product_argument.rs b/halo2-ecc/src/ipa/inner_product_argument.rs new file mode 100644 index 00000000..17645d2d --- /dev/null +++ b/halo2-ecc/src/ipa/inner_product_argument.rs @@ -0,0 +1,135 @@ +#![allow(non_snake_case)] +#![allow(dead_code)] +use crate::bigint::{CRTInteger, ProperCrtUint}; +use crate::fields::fp::Reduced; +use crate::fields::{fp::FpChip, FieldChip}; +use halo2_base::utils::BigPrimeField; +use halo2_base::{utils::CurveAffineExt, AssignedValue, Context}; + +use crate::ecc::{multi_scalar_multiply, EcPoint, EccChip}; + +/// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and \\([s\_{i}]\\) for combined multiscalar multiplication +/// u_{i} is provided as input, assume is checked to be in [0, n - 1] +/// returns (u_{i}^{2}, u_{i}^{-2}, s_{i}) with size k, k, 2^k +pub fn verification_scalars( + ctx: &mut Context, + u: Vec, SF>>, // size = k, u_i < n, randomness generated by fiat-shamir transform + scalar_chip: &FpChip, +) -> (Vec>, Vec>, Vec>) { + let lg_n = u.len(); + let u_pow_two: Vec<_> = u + .iter() + .map(|u_i| { + scalar_chip.mul( + ctx, + CRTInteger::from(ProperCrtUint::from((*u_i).clone())), + CRTInteger::from(ProperCrtUint::from((*u_i).clone())), + ) + }) + .collect(); + + let sf_one = scalar_chip.load_constant(ctx, SF::ONE); + let u_invert: Vec<_> = u + .iter() + .map(|u_i| scalar_chip.divide(ctx, sf_one.clone(), ProperCrtUint::from((*u_i).clone()))) + .collect(); + let u_inv_pow_two: Vec<_> = u_invert + .iter() + .map(|u_i| { + scalar_chip.mul(ctx, CRTInteger::from((*u_i).clone()), CRTInteger::from((*u_i).clone())) + }) + .collect(); + + // Compute 1/(u_k...u_1) + let allinv = + u_invert.iter().fold(sf_one.clone(), |acc, x| scalar_chip.mul(ctx, acc, (*x).clone())); + // compute s + let mut s = Vec::with_capacity(2_usize.pow(lg_n as u32)); + println!("s capacity: {}", s.capacity()); + s.push(allinv); + for i in 1..2_usize.pow(lg_n as u32) { + let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize; + let k = 1 << lg_i; + let u_lg_i_sq = u_pow_two[(lg_n - 1) - lg_i].clone(); + s.push(scalar_chip.mul(ctx, s[i - k].clone(), u_lg_i_sq)); + } + + (u_pow_two, u_inv_pow_two, s) +} + +// CF is the coordinate field of GA +// SF is the scalar field of GA +// p = base field modulus +// n = scalar field modulus +/// follow verification equation in https://doc-internal.dalek.rs/bulletproofs/inner_product_proof/index.html +pub fn inner_product_argument( + chip: &EccChip>, + ctx: &mut Context, + P: EcPoint as FieldChip>::FieldPoint>, // commitment with form P = + + Q + G: Vec as FieldChip>::FieldPoint>>, // size n = 2^k + H: Vec as FieldChip>::FieldPoint>>, // size n = 2^k + L: Vec as FieldChip>::FieldPoint>>, // size k, + R: Vec as FieldChip>::FieldPoint>>, // size k, + Q: EcPoint as FieldChip>::FieldPoint>, + u: Vec>, // size = k, u_i < n, randomness generated by fiat-shamir transform + a: ProperCrtUint, // a < n + b: ProperCrtUint, // b < n + var_window_bits: usize, +) -> AssignedValue +where + GA: CurveAffineExt, +{ + // get FpChip for SF + let base_chip = chip.field_chip; + let scalar_chip = + FpChip::::new(base_chip.range, base_chip.limb_bits, base_chip.num_limbs); + + // validate vector length + let k = G.len().ilog2(); + assert_eq!(G.len(), H.len()); + assert_eq!(G.len(), 1 << k); + assert_eq!(u.len() as u32, k); + assert_eq!(L.len() as u32, k); + assert_eq!(R.len() as u32, k); + + // validate u, a, b < n + let a_valid = scalar_chip.enforce_less_than(ctx, a); + let b_valid = scalar_chip.enforce_less_than(ctx, b); + // todo: should enforce u_{i} != 0 + let u_valid: Vec, SF>> = + u.iter().map(|u_i| scalar_chip.enforce_less_than(ctx, (*u_i).clone())).collect(); + + let (u_pow_two, u_inv_pow_two, s) = verification_scalars(ctx, u_valid, &scalar_chip); + + let a_s: Vec> = + s.iter().map(|s_i| scalar_chip.mul(ctx, a_valid.0.clone(), (*s_i).clone())).collect(); + let b_invert_s: Vec> = + s.iter().rev().map(|s_i| scalar_chip.mul(ctx, b_valid.0.clone(), (*s_i).clone())).collect(); + let neg_u_pow_two: Vec> = + u_pow_two.iter().map(|u_i| scalar_chip.negate(ctx, (*u_i).clone())).collect(); + let neg_u_inv_pow_two: Vec> = + u_inv_pow_two.iter().map(|u_i| scalar_chip.negate(ctx, (*u_i).clone())).collect(); + let a_b = scalar_chip.mul(ctx, a_valid.0, b_valid.0); + let p_prime = multi_scalar_multiply::<_, _, GA>( + base_chip, + ctx, + &(G.iter() + .chain(H.iter()) + .chain(std::iter::once(&Q)) + .chain(L.iter()) + .chain(R.iter()) + .cloned() + .collect::>()), + a_s.iter() + .chain(b_invert_s.iter()) + .chain(std::iter::once(&a_b)) + .chain(neg_u_pow_two.iter()) + .chain(neg_u_inv_pow_two.iter()) + .map(|x| x.limbs().to_vec()) + .collect(), + base_chip.limb_bits, + var_window_bits, + ); + + chip.is_equal(ctx, p_prime, P) +} diff --git a/halo2-ecc/src/ipa/mod.rs b/halo2-ecc/src/ipa/mod.rs new file mode 100644 index 00000000..14ef2931 --- /dev/null +++ b/halo2-ecc/src/ipa/mod.rs @@ -0,0 +1,3 @@ +mod inner_product_argument; +#[cfg(test)] +mod tests; diff --git a/halo2-ecc/src/ipa/tests/mod.rs b/halo2-ecc/src/ipa/tests/mod.rs new file mode 100644 index 00000000..c1025acb --- /dev/null +++ b/halo2-ecc/src/ipa/tests/mod.rs @@ -0,0 +1,244 @@ +#![allow(non_snake_case)] +use crate::fields::FpStrategy; +use crate::halo2_proofs::{ + arithmetic::CurveAffine, + halo2curves::bn256::Fr, + halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, +}; +use crate::ipa::inner_product_argument::inner_product_argument; +use crate::secp256k1::{FpChip, FqChip}; +use crate::{ecc::EccChip, fields::FieldChip}; +use halo2_base::gates::RangeChip; +use halo2_base::halo2_proofs::arithmetic::Field; +use halo2_base::halo2_proofs::halo2curves::group::prime::PrimeCurveAffine; +use halo2_base::utils::testing::base_test; +use halo2_base::utils::BigPrimeField; +use halo2_base::Context; +use rand::rngs::StdRng; +use rand::SeedableRng; +use serde::{Deserialize, Serialize}; +use std::fs::File; + +fn inner_product(a: Vec, b: Vec) -> Fq { + assert_eq!(a.len(), b.len()); + let a_b = a + .iter() + .zip(b.iter()) + .map(|(a, b)| a * b) + .fold(::ScalarExt::zero(), |acc, x| acc + x); + a_b +} + +fn multi_scalar_multiply( + a: impl IntoIterator, + G: Vec, +) -> Secp256k1Affine { + let a_G = G + .iter() + .zip(a.into_iter()) + .map(|(g, a)| Secp256k1Affine::from(g * a)) + .fold(Secp256k1Affine::identity(), |acc, x| Secp256k1Affine::from(acc + x)); + a_G +} + +fn random_inputs_ipa_prover_output(k: usize, rng: &mut StdRng) -> IPAInput { + // generate random inputs + let mut G = (0..2_usize.pow(k as u32)) + .map(|_| { + Secp256k1Affine::from( + Secp256k1Affine::generator() + * ::ScalarExt::random(rng.clone()), + ) + }) + .collect::>(); + let G_origin = G.clone(); + + let mut H = (0..2_usize.pow(k as u32)) + .map(|_| { + Secp256k1Affine::from( + Secp256k1Affine::generator() + * ::ScalarExt::random(rng.clone()), + ) + }) + .collect::>(); + let H_origin = H.clone(); + + let Q = Secp256k1Affine::from( + Secp256k1Affine::generator() + * ::ScalarExt::random(rng.clone()), + ); + + let mut a = (0..2_usize.pow(k as u32)) + .map(|_| ::ScalarExt::random(rng.clone())) + .collect::>(); + + let mut b = (0..2_usize.pow(k as u32)) + .map(|_| ::ScalarExt::random(rng.clone())) + .collect::>(); + + let u = (0..k) + .map(|_| ::ScalarExt::random(rng.clone())) + .collect::>(); + + // P = + + Q + let a_G = multi_scalar_multiply(a.clone(), G.clone()); + + let b_H = multi_scalar_multiply(b.clone(), H.clone()); + + let a_b = inner_product(a.clone(), b.clone()); + + let P = Secp256k1Affine::from(a_G + b_H + (Q * a_b)); + + // IPA proof generation + let mut L_vec = Vec::with_capacity(k); + let mut R_vec = Vec::with_capacity(k); + + for j in 0..k { + let n = a.len() / 2; + let (a_L, a_R) = a.split_at_mut(n); + let (b_L, b_R) = b.split_at_mut(n); + let (G_L, G_R) = G.split_at_mut(n); + let (H_L, H_R) = H.split_at_mut(n); + let L_j = Secp256k1Affine::from( + multi_scalar_multiply(a_L.to_owned(), G_R.to_owned()) + + multi_scalar_multiply(b_R.to_owned(), H_L.to_owned()) + + (Q * inner_product(a_L.to_owned(), b_R.to_owned())), + ); + let R_j = Secp256k1Affine::from( + multi_scalar_multiply(a_R.to_owned(), G_L.to_owned()) + + multi_scalar_multiply(b_L.to_owned(), H_R.to_owned()) + + (Q * inner_product(a_R.to_owned(), b_L.to_owned())), + ); + L_vec.push(L_j); + R_vec.push(R_j); + + // a = a_L * u_j + u_j^-1 * a_R + a = a_L + .iter_mut() + .zip(a_R.iter_mut()) + .map(|(a_L, a_R)| { + let a_L = *a_L; + let a_R = *a_R; + let u_j = u.get(j).unwrap(); + let u_j_inv = u_j.invert().unwrap(); + a_L * u_j + a_R * u_j_inv + }) + .collect::>(); + + // b = b_L * u_j^-1 + u_j * b_R + b = b_L + .iter_mut() + .zip(b_R.iter_mut()) + .map(|(b_L, b_R)| { + let b_L = *b_L; + let b_R = *b_R; + let u_j = u.get(j).unwrap(); + let u_j_inv = u_j.invert().unwrap(); + b_L * u_j_inv + b_R * u_j + }) + .collect::>(); + + // G = G_L * u_j^-1 + u_j * G_R + G = G_L + .iter_mut() + .zip(G_R.iter_mut()) + .map(|(G_L, G_R)| { + let G_L = *G_L; + let G_R = *G_R; + let u_j = u.get(j).unwrap(); + let u_j_inv = u_j.invert().unwrap(); + Secp256k1Affine::from(G_L * u_j_inv + G_R * u_j) + }) + .collect::>(); + // H = H_L * u_j + u_j^-1 * H_R + H = H_L + .iter_mut() + .zip(H_R.iter_mut()) + .map(|(H_L, H_R)| { + let H_L = *H_L; + let H_R = *H_R; + let u_j = u.get(j).unwrap(); + let u_j_inv = u_j.invert().unwrap(); + Secp256k1Affine::from(H_L * u_j + H_R * u_j_inv) + }) + .collect::>(); + } + + assert!(a.len() == 1); + assert!(b.len() == 1); + let a = a.get(0).unwrap(); + let b = b.get(0).unwrap(); + IPAInput { a: *a, b: *b, G_origin, H_origin, Q, P, L_vec, R_vec, u } +} + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +#[derive(Clone, Debug)] +pub struct IPAInput { + pub a: Fq, + pub b: Fq, + pub G_origin: Vec, + pub H_origin: Vec, + pub Q: Secp256k1Affine, + pub P: Secp256k1Affine, + pub L_vec: Vec, + pub R_vec: Vec, + pub u: Vec, +} + +pub fn ipa_test( + ctx: &mut Context, + range: &RangeChip, + params: CircuitParams, + input: IPAInput, +) -> F { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); + + let ecc_chip = EccChip::>::new(&fp_chip); + let a = fq_chip.load_private(ctx, input.a); + let b = fq_chip.load_private(ctx, input.b); + let G = input.G_origin.iter().map(|g| ecc_chip.assign_point(ctx, *g)).collect::>(); + let H = input.H_origin.iter().map(|h| ecc_chip.assign_point(ctx, *h)).collect::>(); + let Q = ecc_chip.assign_point(ctx, input.Q); + let P = ecc_chip.assign_point(ctx, input.P); + let L = input.L_vec.iter().map(|l| ecc_chip.assign_point(ctx, *l)).collect::>(); + let R = input.R_vec.iter().map(|r| ecc_chip.assign_point(ctx, *r)).collect::>(); + let u = input.u.iter().map(|u| fq_chip.load_private(ctx, *u)).collect::>(); + // test inner product argument proof verification + let is_valid_ipa_proof = inner_product_argument::( + &ecc_chip, ctx, P, G, H, L, R, Q, u, a, b, 4, + ); + *is_valid_ipa_proof.value() +} + +pub fn run_test(input: IPAInput) { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let res = base_test() + .k(params.degree) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| ipa_test(ctx, range, params, input)); + assert_eq!(res, Fr::ONE); +} + +#[test] +fn test_secp256k1_ipa() { + let mut rng = StdRng::seed_from_u64(0); + let input = random_inputs_ipa_prover_output(2, &mut rng); + run_test(input); +} diff --git a/halo2-ecc/src/lib.rs b/halo2-ecc/src/lib.rs index 5b3f191a..edda454c 100644 --- a/halo2-ecc/src/lib.rs +++ b/halo2-ecc/src/lib.rs @@ -9,6 +9,7 @@ pub mod fields; pub mod bn254; pub mod grumpkin; +pub mod ipa; pub mod secp256k1; pub use halo2_base;