From 2f502fd8570fe4e9cff36eea5bbd6fef22002898 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Lindstr=C3=B8m?= Date: Wed, 2 Oct 2024 09:37:36 +0200 Subject: [PATCH] BLS fast sums (#840) * Group elements to and from uncompressed form * Add ToFromUncompressedBytes impl for G2Elements * docs + tests * add sum function * clippy * benchmarks * add fast sum function * very fast version works * clean up * keep only 'safe' version * clean up * more test cases * name * add to_bytes method * Clean up * benchmarks + docs * regression test * review comments #1 * test * tests * test * test * use vector of ptrs * remove redundant comment --- fastcrypto/benches/groups.rs | 57 +++++++++- fastcrypto/src/groups/bls12381.rs | 112 ++++++++++++++++++- fastcrypto/src/tests/bls12381_group_tests.rs | 108 +++++++++++++++++- 3 files changed, 266 insertions(+), 11 deletions(-) diff --git a/fastcrypto/benches/groups.rs b/fastcrypto/benches/groups.rs index 16626686bb..55bb5c9ba7 100644 --- a/fastcrypto/benches/groups.rs +++ b/fastcrypto/benches/groups.rs @@ -7,16 +7,16 @@ mod group_benches { use criterion::measurement::Measurement; use criterion::{measurement, BenchmarkGroup, BenchmarkId, Criterion}; use fastcrypto::groups::bls12381::{ - G1Element, G2Element, GTElement, Scalar as BlsScalar, G1_ELEMENT_BYTE_LENGTH, - G2_ELEMENT_BYTE_LENGTH, GT_ELEMENT_BYTE_LENGTH, SCALAR_LENGTH, + G1Element, G1ElementUncompressed, G2Element, GTElement, Scalar as BlsScalar, + G1_ELEMENT_BYTE_LENGTH, G2_ELEMENT_BYTE_LENGTH, GT_ELEMENT_BYTE_LENGTH, SCALAR_LENGTH, }; use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier; use fastcrypto::groups::multiplier::ScalarMultiplier; use fastcrypto::groups::ristretto255::RistrettoPoint; use fastcrypto::groups::secp256r1::ProjectivePoint; use fastcrypto::groups::{ - secp256r1, FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing, - Scalar, + bls12381, secp256r1, FromTrustedByteArray, GroupElement, HashToGroupElement, + MultiScalarMul, Pairing, Scalar, }; use fastcrypto::serde_helpers::ToFromByteArray; use rand::thread_rng; @@ -211,6 +211,54 @@ mod group_benches { pairing_single::("BLS12381-G1", &mut group); } + fn sum(c: &mut Criterion) { + static NUMBER_OF_TERMS: [usize; 4] = [10, 100, 500, 1000]; + + for n in NUMBER_OF_TERMS { + let terms: Vec = (0..n) + .map(|_| G1Element::generator() * bls12381::Scalar::rand(&mut thread_rng())) + .collect(); + + let terms_uncompressed = terms + .iter() + .map(G1ElementUncompressed::from) + .map(G1ElementUncompressed::into_byte_array) + .collect::>(); + + let terms_compressed = terms + .iter() + .map(G1Element::to_byte_array) + .collect::>(); + + c.bench_function(&format!("Sum/BLS12381-G1/{} uncompressed", n), move |b| { + b.iter_batched( + || terms_uncompressed.clone(), + |t| { + let terms_deserialized = t + .into_iter() + .map(G1ElementUncompressed::from_trusted_byte_array) + .collect::>(); + G1ElementUncompressed::sum(terms_deserialized.as_slice()) + }, + criterion::BatchSize::SmallInput, + ) + }); + + c.bench_function(&format!("Sum/BLS12381-G1/{} compressed", n), move |b| { + b.iter_batched( + || terms_compressed.clone(), + |t| { + t.iter() + .map(G1Element::from_trusted_byte_array) + .map(Result::unwrap) + .reduce(|a, b| a + b) + }, + criterion::BatchSize::SmallInput, + ) + }); + } + } + /// Implementation of a `Multiplier` where scalar multiplication is done without any pre-computation by /// simply calling the GroupElement implementation. Only used for benchmarking. struct DefaultMultiplier(G); @@ -294,6 +342,7 @@ mod group_benches { pairing, double_scale, blst_msm, + sum, } } diff --git a/fastcrypto/src/groups/bls12381.rs b/fastcrypto/src/groups/bls12381.rs index 3d07a1cb8d..e3bf5113c1 100644 --- a/fastcrypto/src/groups/bls12381.rs +++ b/fastcrypto/src/groups/bls12381.rs @@ -4,7 +4,7 @@ use crate::bls12381::min_pk::DST_G2; use crate::bls12381::min_sig::DST_G1; use crate::encoding::{Encoding, Hex}; -use crate::error::{FastCryptoError, FastCryptoResult}; +use crate::error::{FastCryptoError, FastCryptoError::InvalidInput, FastCryptoResult}; use crate::groups::{ FiatShamirChallenge, FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing, Scalar as ScalarType, @@ -20,11 +20,12 @@ use blst::{ blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_rshift, blst_fr_sub, blst_hash_to_g1, blst_hash_to_g2, blst_lendian_from_scalar, blst_miller_loop, blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_cneg, blst_p1_compress, - blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, blst_p1_to_affine, blst_p1_uncompress, - blst_p2, blst_p2_add_or_double, blst_p2_affine, blst_p2_cneg, blst_p2_compress, - blst_p2_from_affine, blst_p2_in_g2, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, - blst_scalar, blst_scalar_fr_check, blst_scalar_from_be_bytes, blst_scalar_from_bendian, - blst_scalar_from_fr, p1_affines, p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR, + blst_p1_deserialize, blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, blst_p1_serialize, + blst_p1_to_affine, blst_p1_uncompress, blst_p1s_add, blst_p2, blst_p2_add_or_double, + blst_p2_affine, blst_p2_cneg, blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2, + blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, blst_scalar, blst_scalar_fr_check, + blst_scalar_from_be_bytes, blst_scalar_from_bendian, blst_scalar_from_fr, p1_affines, + p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR, }; use fastcrypto_derive::GroupOpsExtend; use hex_literal::hex; @@ -333,6 +334,105 @@ impl Debug for G1Element { serialize_deserialize_with_to_from_byte_array!(G1Element); generate_bytes_representation!(G1Element, G1_ELEMENT_BYTE_LENGTH, G1ElementAsBytes); +/// An uncompressed serialization of a G1 element. This format is two times longer than the compressed +/// format used by `G1Element::serialize`, but is much faster to deserialize. +/// +/// The intended use of this struct is to deserialize and sum a large number of G1 elements without +/// having to decompress them first. +#[derive(Clone, Debug)] +#[repr(transparent)] +pub struct G1ElementUncompressed(pub(crate) [u8; 2 * G1_ELEMENT_BYTE_LENGTH]); + +impl From<&G1Element> for G1ElementUncompressed { + fn from(element: &G1Element) -> Self { + let mut bytes = [0u8; 2 * G1_ELEMENT_BYTE_LENGTH]; + unsafe { + blst_p1_serialize(bytes.as_mut_ptr(), &element.0); + } + G1ElementUncompressed(bytes) + } +} + +impl TryFrom<&G1ElementUncompressed> for G1Element { + type Error = FastCryptoError; + + fn try_from(value: &G1ElementUncompressed) -> Result { + // See https://github.com/supranational/blst for details on the serialization format. + + // Note that `blst_p1_deserialize` accepts both compressed and uncompressed serializations, + // so we check that the compressed bit flag (the 1st) is not set. The third is used for + // compressed points to indicate sign of the y-coordinate and should also not be set. + if value.0[0] & 0x20 != 0 || value.0[0] & 0x80 != 0 { + return Err(InvalidInput); + } + + let mut ret = blst_p1::default(); + unsafe { + let mut affine = blst_p1_affine::default(); + if blst_p1_deserialize(&mut affine, value.0.as_ptr()) != BLST_ERROR::BLST_SUCCESS { + return Err(InvalidInput); + } + blst_p1_from_affine(&mut ret, &affine); + + if !blst_p1_in_g1(&ret) { + return Err(InvalidInput); + } + } + Ok(G1Element(ret)) + } +} + +impl G1ElementUncompressed { + /// Create a new `G1ElementUncompressed` from a byte array. + /// The input is not validated so it should come from a trusted source. + /// + /// See [the blst docs](https://github.com/supranational/blst/tree/master?tab=readme-ov-file#serialization-format) for details about the uncompressed serialization format. + pub fn from_trusted_byte_array(bytes: [u8; 2 * G1_ELEMENT_BYTE_LENGTH]) -> Self { + Self(bytes) + } + + /// Get the byte array representation of this element. + pub fn into_byte_array(self) -> [u8; 2 * G1_ELEMENT_BYTE_LENGTH] { + self.0 + } + + /// This will never fail if the input is a valid G1 element. + fn to_blst_p1_affine(&self) -> FastCryptoResult { + let mut affine = blst_p1_affine::default(); + unsafe { + // This fails if the point is not on the curve or if it is (0, ±2) which is on the curve + // but not in the G1 subgroup. See https://github.com/supranational/blst/blob/6f3136ffb636974166a93f2f25436854fe8d10ff/src/e1.c#L296-L326. + // A subgroup check is not performed here. + if blst_p1_deserialize(&mut affine, self.0.as_ptr()) != BLST_ERROR::BLST_SUCCESS { + return Err(InvalidInput); + } + } + Ok(affine) + } + + /// Compute the sum of a slice of uncompressed G1 elements. + /// + /// This function will never fail if the inputs are valid G1 element. + pub fn sum(terms: &[G1ElementUncompressed]) -> FastCryptoResult { + if terms.is_empty() { + return Ok(G1Element::zero()); + } + + let affine_points: Vec = terms + .iter() + .map(G1ElementUncompressed::to_blst_p1_affine) + .collect::>>()?; + + let mut ret = blst_p1::default(); + let p = affine_points + .iter() + .map(|p| p as *const _) + .collect::>(); + unsafe { blst_p1s_add(&mut ret, p.as_ptr(), p.len()) }; + Ok(G1Element(ret)) + } +} + impl Add for G2Element { type Output = Self; diff --git a/fastcrypto/src/tests/bls12381_group_tests.rs b/fastcrypto/src/tests/bls12381_group_tests.rs index e4f494c220..09275a8dbd 100644 --- a/fastcrypto/src/tests/bls12381_group_tests.rs +++ b/fastcrypto/src/tests/bls12381_group_tests.rs @@ -2,7 +2,10 @@ // SPDX-License-Identifier: Apache-2.0 use crate::bls12381::min_pk::{BLS12381KeyPair, BLS12381Signature}; -use crate::groups::bls12381::{reduce_mod_uniform_buffer, G1Element, G2Element, GTElement, Scalar}; +use crate::groups::bls12381::{ + reduce_mod_uniform_buffer, G1Element, G1ElementUncompressed, G2Element, GTElement, Scalar, + G1_ELEMENT_BYTE_LENGTH, +}; use crate::groups::{ FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing, Scalar as ScalarTrait, @@ -653,3 +656,106 @@ fn test_serialization_gt() { assert!(GTElement::from_trusted_byte_array(&bytes).is_ok()); assert!(GTElement::from_byte_array(&bytes).is_err()); } + +#[test] +fn test_g1_to_uncompressed() { + let a = G1Element::generator() * Scalar::from(7u128); + + let uncompressed_bytes = G1ElementUncompressed::from(&a); + + // Compressed bit flags (1 and 3) should not be set. + assert_eq!(uncompressed_bytes.0[0] & 0xA0, 0); + + // Infinity bit flag (2) should not be set. + assert_eq!(uncompressed_bytes.0[0] & 0x40, 0); + + // Regression test + assert_eq!(&uncompressed_bytes.0, hex::decode("1928f3beb93519eecf0145da903b40a4c97dca00b21f12ac0df3be9116ef2ef27b2ae6bcd4c5bc2d54ef5a70627efcb7108dadbaa4b636445639d5ae3089b3c43a8a1d47818edd1839d7383959a41c10fdc66849cfa1b08c5a11ec7e28981a1c").unwrap().as_slice()); + + // Check round-trip + let b = G1Element::try_from(&uncompressed_bytes).unwrap(); + assert_eq!(a, b); + + // Simply padding a compressed serialization with 0's will fail + let mut padded = b.to_byte_array().to_vec(); + padded.extend_from_slice(&[0u8; G1_ELEMENT_BYTE_LENGTH]); + assert_eq!(padded.len(), 2 * G1_ELEMENT_BYTE_LENGTH); + let uncompressed = G1ElementUncompressed::from_trusted_byte_array(padded.try_into().unwrap()); + assert!(G1Element::try_from(&uncompressed).is_err()); + + // A point not on the curve fails + let mut bytes = uncompressed_bytes.into_byte_array(); + bytes[1] += 1; + let uncompressed_bytes = G1ElementUncompressed::from_trusted_byte_array(bytes); + assert!(G1Element::try_from(&uncompressed_bytes).is_err()); + + // Serialize the point-at-infinity + let a = G1Element::zero(); + let uncompressed_bytes = G1ElementUncompressed::from(&a); + + // Only the point at infinity flag should be set. + assert_eq!(uncompressed_bytes.0[0], 0x40); + + // The remaining bytes should all be zero + assert_eq!( + uncompressed_bytes.0[1..], + [0u8; G1_ELEMENT_BYTE_LENGTH * 2 - 1] + ); + + // All zeros + let uncompressed = + G1ElementUncompressed::from_trusted_byte_array([0u8; 2 * G1_ELEMENT_BYTE_LENGTH]); + assert!(G1Element::try_from(&uncompressed).is_err()); +} + +#[test] +fn test_g1_sum() { + // Empty sum + assert_eq!(G1ElementUncompressed::sum(&[]).unwrap(), G1Element::zero()); + + // Non-trivial sum + let a = G1Element::generator(); + let b = G1Element::generator() * Scalar::from(2u128); + let c = G1Element::generator() * Scalar::from(3u128); + let mut bytes: Vec = vec![(&a).into(), (&b).into(), (&c).into()]; + let sum = G1ElementUncompressed::sum(&bytes).unwrap(); + assert_eq!(sum, G1Element::generator() * Scalar::from(6u128)); + + // Adding zeros doesn't change anything + bytes.push(G1ElementUncompressed::from(&G1Element::zero())); + let sum = G1ElementUncompressed::sum(&bytes).unwrap(); + assert_eq!(sum, G1Element::generator() * Scalar::from(6u128)); + + // Equal elements in sum + let bytes = vec![(&b).into(), (&b).into()]; + let sum = G1ElementUncompressed::sum(&bytes).unwrap(); + assert_eq!(sum, G1Element::generator() * Scalar::from(4u128)); + + // Singleton sum + let bytes = [(&b).into()]; + let sum = G1ElementUncompressed::sum(&bytes).unwrap(); + assert_eq!(sum, b); + + // Adding zero's + let mut bytes = vec![G1ElementUncompressed::from(&G1Element::zero())]; + let sum = G1ElementUncompressed::sum(&bytes).unwrap(); + assert_eq!(sum, G1Element::zero()); + bytes.push(G1ElementUncompressed::from(&G1Element::zero())); + let sum = G1ElementUncompressed::sum(&bytes).unwrap(); + assert_eq!(sum, G1Element::zero()); +} + +#[test] +fn test_g1_large_sum() { + let mut rng = thread_rng(); + let n: usize = 100; + let points: Vec = (0..n) + .map(|_| G1Element::generator() * Scalar::rand(&mut rng)) + .collect(); + let expected = points.iter().fold(G1Element::zero(), |acc, p| acc + p); + + let as_uncompressed: Vec = + points.iter().map(G1ElementUncompressed::from).collect(); + let sum = G1ElementUncompressed::sum(as_uncompressed.as_slice()).unwrap(); + assert_eq!(expected, sum); +}