Skip to content

Commit

Permalink
fix: optimize compute_inner_product with parallel computation (#37)
Browse files Browse the repository at this point in the history
* Update arithmetic.rs

* Update arithmetic.rs

* Update arithmetic.rs

* Update arithmetic.rs
  • Loading branch information
crStiv authored Feb 3, 2025
1 parent a4140d7 commit 916c0bf
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use group::{
prime::PrimeCurveAffine,
Curve, GroupOpsOwned, ScalarMulOwned,
};
use rayon::prelude::*;

Check failure on line 11 in halo2_proofs/src/arithmetic.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-wasi

failed to resolve: use of undeclared crate or module `rayon`

use halo2curves::msm::msm_best;
pub use halo2curves::{CurveAffine, CurveExt};
Expand Down Expand Up @@ -110,16 +111,26 @@ pub fn eval_polynomial<F: Field>(poly: &[F], point: F) -> F {
/// This computes the inner product of two vectors `a` and `b`.
///
/// This function will panic if the two vectors are not the same size.
/// For vectors smaller than 32 elements, it uses sequential computation for better performance.
/// For larger vectors, it switches to parallel computation.
pub fn compute_inner_product<F: Field>(a: &[F], b: &[F]) -> F {
// TODO: parallelize?
assert_eq!(a.len(), b.len());

let mut acc = F::ZERO;
for (a, b) in a.iter().zip(b.iter()) {
acc += (*a) * (*b);

if a.len() < 32 {
// Use sequential computation for small vectors
let mut acc = F::ZERO;
for (a, b) in a.iter().zip(b.iter()) {
acc += (*a) * (*b);
}
return acc;
}

acc
// Use parallel computation
let mut result = F::ZERO;
parallelize(&mut [result], |results, _| {
results[0] = a.iter().zip(b.iter()).fold(F::ZERO, |acc, (a, b)| acc + (*a) * (*b));
});
result
}

/// Divides polynomial `a` in `X` by `X - b` with
Expand Down Expand Up @@ -328,3 +339,30 @@ fn test_lagrange_interpolate() {
}
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand_core::OsRng;

    /// Checks `compute_inner_product` against a plain sequential fold for
    /// random inputs of several lengths.
    #[test]
    fn test_compute_inner_product() {
        let rng = OsRng;

        // Cover the empty input, both sides of the 32-element threshold
        // mentioned in the function's documentation, and lengths that are
        // not a multiple of 32.
        for &len in &[0usize, 1, 16, 31, 32, 33, 64, 100] {
            let a: Vec<Fp> = (0..len).map(|_| Fp::random(rng)).collect();
            let b: Vec<Fp> = (0..len).map(|_| Fp::random(rng)).collect();

            let expected = a
                .iter()
                .zip(b.iter())
                .fold(Fp::ZERO, |acc, (x, y)| acc + (*x) * (*y));

            assert_eq!(
                compute_inner_product(&a, &b),
                expected,
                "inner product mismatch for length {}",
                len
            );
        }
    }

    /// Inputs of different lengths must be rejected with a panic.
    #[test]
    #[should_panic]
    fn test_compute_inner_product_length_mismatch() {
        let a = vec![Fp::ZERO; 3];
        let b = vec![Fp::ZERO; 2];
        let _ = compute_inner_product(&a, &b);
    }
}

0 comments on commit 916c0bf

Please sign in to comment.