diff --git a/src/msm.rs b/src/msm.rs index 25af9711..a2d3fae9 100644 --- a/src/msm.rs +++ b/src/msm.rs @@ -14,8 +14,7 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { // Booth encoding: // * step by `window` size // * slice by size of `window + 1`` - // * each window overlap by 1 bit - // * append a zero bit to the least significant end + // * each window overlap by 1 bit * append a zero bit to the least significant end // Indexing rule for example window size 3 where we slice by 4 bits: // `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]`` // So we can reduce the bucket size without preprocessing scalars @@ -54,7 +53,9 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 { } } -fn batch_add( +// Batch addition without edge case handling: +// Will panic if a point is the identity or if two points share the x coordinate. +fn batch_add_nonexceptional( size: usize, buckets: &mut [BucketAffine], points: &[SchedulePoint], @@ -85,7 +86,9 @@ fn batch_add( acc *= *z; } - acc = acc.invert().unwrap(); + acc = acc + .invert() + .expect("Attempted to invert 0 at batch_add_nmonexceptional"); for ( ( @@ -112,6 +115,94 @@ fn batch_add( } } +/// Batch addition with edge case handling. +fn batch_add_exceptional( + size: usize, + buckets: &mut [BucketAffine], + points: &[SchedulePoint], + bases: &[Affine], +) { + let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1 + let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1 + let mut acc = C::Base::ONE; + + for ( + ( + SchedulePoint { + base_idx, + buck_idx, + sign, + }, + t, + ), + z, + ) in points.iter().zip(t.iter_mut()).zip(z.iter_mut()) + { + if buckets[*buck_idx].is_inf() { + // We assume bases[*base_idx] != infinity always. + continue; + } + + if buckets[*buck_idx].x() == bases[*base_idx].x { + // y-coordinate matches: + // 1. y1 == y2 and sign = false or + // 2. y1 != y2 and sign = true + // => ( y1 == y2) xor !sign + // (This uses the fact that x1 == x2 and both points satisfy the curve eq.) + if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign { + // Doubling + let x_squared = bases[*base_idx].x.square(); + *z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y + *t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2 + acc *= *z; + continue; + } + // P + (-P) + buckets[*buck_idx].set_inf(); + continue; + } + // Addition + *z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1 + if *sign { + *t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y); + } else { + *t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y); + } // y2 - y1 + acc *= *z; + } + + acc = acc + .invert() + .expect("Some edge case has not been handled properly"); + + for ( + ( + SchedulePoint { + base_idx, + buck_idx, + sign, + }, + t, + ), + z, + ) in points.iter().zip(t.iter()).zip(z.iter()).rev() + { + if buckets[*buck_idx].is_inf() { + // We assume bases[*base_idx] != infinity always. + continue; + } + let lambda = acc * t; + acc *= z; // update acc + let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result + if *sign { + buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y)); + } else { + buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y)); + } // y_result = lambda * (x1 - x_result) - y1 + buckets[*buck_idx].set_x(&x); + } +} + #[derive(Debug, Clone, Copy)] struct Affine { x: C::Base, @@ -207,6 +298,13 @@ impl BucketAffine { } } + fn is_inf(&self) -> bool { + match self { + Self::None => true, + Self::Point(_) => false, + } + } + fn set_x(&mut self, x: &C::Base) { match self { Self::None => panic!("::set_x None"), @@ -220,6 +318,13 @@ impl BucketAffine { Self::Point(ref mut a) => a.y = *y, } } + + fn set_inf(&mut self) { + match self { + Self::None => {} + Self::Point(_) => *self = Self::None, + } + } } struct Schedule { @@ -266,7 +371,7 @@ impl Schedule { fn execute(&mut self, bases: &[Affine]) { if self.ptr != 0 { - batch_add(self.ptr, &mut self.buckets, &self.set, bases); + batch_add_nonexceptional(self.ptr, &mut self.buckets, &self.set, bases); self.ptr = 0; self.set .iter_mut() @@ -473,7 +578,6 @@ pub fn best_multiexp_independent_points( #[cfg(test)] mod test { - use std::ops::Neg; use crate::bn256::{Fr, G1Affine, G1}; @@ -529,6 +633,7 @@ mod test { } } + #[cfg(test)] fn run_msm_cross(min_k: usize, max_k: usize) { let points = (0..1 << max_k) .map(|_| C::Curve::random(OsRng)) @@ -545,12 +650,12 @@ mod test { let points = &points[..1 << k]; let scalars = &scalars[..1 << k]; - let t0 = start_timer!(|| format!("cyclone k={}", k)); - let e0 = super::best_multiexp_independent_points(scalars, points); + let t0 = start_timer!(|| format!("cyclone indep k={}", k)); + let e0 = super::best_multiexp_independent_points(&scalars, &points); end_timer!(t0); let t1 = start_timer!(|| format!("older k={}", k)); - let e1 = super::best_multiexp(scalars, points); + let e1 = super::best_multiexp(&scalars, &points); end_timer!(t1); assert_eq!(e0, e1); }