Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BLS fast sums #840

Merged
merged 24 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions fastcrypto/benches/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ mod group_benches {
use criterion::measurement::Measurement;
use criterion::{measurement, BenchmarkGroup, BenchmarkId, Criterion};
use fastcrypto::groups::bls12381::{
G1Element, G2Element, GTElement, Scalar as BlsScalar, G1_ELEMENT_BYTE_LENGTH,
G2_ELEMENT_BYTE_LENGTH, GT_ELEMENT_BYTE_LENGTH, SCALAR_LENGTH,
G1Element, G1ElementUncompressed, G2Element, GTElement, Scalar as BlsScalar,
G1_ELEMENT_BYTE_LENGTH, G2_ELEMENT_BYTE_LENGTH, GT_ELEMENT_BYTE_LENGTH, SCALAR_LENGTH,
};
use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier;
use fastcrypto::groups::multiplier::ScalarMultiplier;
use fastcrypto::groups::ristretto255::RistrettoPoint;
use fastcrypto::groups::secp256r1::ProjectivePoint;
use fastcrypto::groups::{
secp256r1, FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing,
Scalar,
bls12381, secp256r1, FromTrustedByteArray, GroupElement, HashToGroupElement,
MultiScalarMul, Pairing, Scalar,
};
use fastcrypto::serde_helpers::ToFromByteArray;
use rand::thread_rng;
Expand Down Expand Up @@ -211,6 +211,54 @@ mod group_benches {
pairing_single::<G1Element, _>("BLS12381-G1", &mut group);
}

fn sum(c: &mut Criterion) {
static NUMBER_OF_TERMS: [usize; 4] = [10, 100, 500, 1000];

for n in NUMBER_OF_TERMS {
let terms: Vec<G1Element> = (0..n)
.map(|_| G1Element::generator() * bls12381::Scalar::rand(&mut thread_rng()))
.collect();

let terms_uncompressed = terms
.iter()
.map(G1ElementUncompressed::from)
.map(G1ElementUncompressed::into_byte_array)
.collect::<Vec<_>>();

let terms_compressed = terms
.iter()
.map(G1Element::to_byte_array)
.collect::<Vec<_>>();

c.bench_function(&format!("Sum/BLS12381-G1/{} uncompressed", n), move |b| {
b.iter_batched(
|| terms_uncompressed.clone(),
|t| {
let terms_deserialized = t
.into_iter()
.map(G1ElementUncompressed::from_trusted_byte_array)
.collect::<Vec<_>>();
G1ElementUncompressed::sum(terms_deserialized.as_slice())
},
criterion::BatchSize::SmallInput,
)
});

c.bench_function(&format!("Sum/BLS12381-G1/{} compressed", n), move |b| {
b.iter_batched(
|| terms_compressed.clone(),
|t| {
t.iter()
.map(G1Element::from_trusted_byte_array)
.map(Result::unwrap)
.reduce(|a, b| a + b)
},
criterion::BatchSize::SmallInput,
)
});
}
}

/// Implementation of a `Multiplier` where scalar multiplication is done without any pre-computation by
/// simply calling the GroupElement implementation. Only used for benchmarking.
struct DefaultMultiplier<G: GroupElement>(G);
Expand Down Expand Up @@ -294,6 +342,7 @@ mod group_benches {
pairing,
double_scale,
blst_msm,
sum,
}
}

Expand Down
112 changes: 106 additions & 6 deletions fastcrypto/src/groups/bls12381.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use crate::bls12381::min_pk::DST_G2;
use crate::bls12381::min_sig::DST_G1;
use crate::encoding::{Encoding, Hex};
use crate::error::{FastCryptoError, FastCryptoResult};
use crate::error::{FastCryptoError, FastCryptoError::InvalidInput, FastCryptoResult};
use crate::groups::{
FiatShamirChallenge, FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul,
Pairing, Scalar as ScalarType,
Expand All @@ -20,11 +20,12 @@ use blst::{
blst_fr_from_scalar, blst_fr_from_uint64, blst_fr_inverse, blst_fr_mul, blst_fr_rshift,
blst_fr_sub, blst_hash_to_g1, blst_hash_to_g2, blst_lendian_from_scalar, blst_miller_loop,
blst_p1, blst_p1_add_or_double, blst_p1_affine, blst_p1_cneg, blst_p1_compress,
blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, blst_p1_to_affine, blst_p1_uncompress,
blst_p2, blst_p2_add_or_double, blst_p2_affine, blst_p2_cneg, blst_p2_compress,
blst_p2_from_affine, blst_p2_in_g2, blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress,
blst_scalar, blst_scalar_fr_check, blst_scalar_from_be_bytes, blst_scalar_from_bendian,
blst_scalar_from_fr, p1_affines, p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
blst_p1_deserialize, blst_p1_from_affine, blst_p1_in_g1, blst_p1_mult, blst_p1_serialize,
blst_p1_to_affine, blst_p1_uncompress, blst_p1s_add, blst_p2, blst_p2_add_or_double,
blst_p2_affine, blst_p2_cneg, blst_p2_compress, blst_p2_from_affine, blst_p2_in_g2,
blst_p2_mult, blst_p2_to_affine, blst_p2_uncompress, blst_scalar, blst_scalar_fr_check,
blst_scalar_from_be_bytes, blst_scalar_from_bendian, blst_scalar_from_fr, p1_affines,
p2_affines, BLS12_381_G1, BLS12_381_G2, BLST_ERROR,
};
use fastcrypto_derive::GroupOpsExtend;
use hex_literal::hex;
Expand Down Expand Up @@ -333,6 +334,105 @@ impl Debug for G1Element {
serialize_deserialize_with_to_from_byte_array!(G1Element);
generate_bytes_representation!(G1Element, G1_ELEMENT_BYTE_LENGTH, G1ElementAsBytes);

/// An uncompressed serialization of a G1 element. This format is two times longer than the compressed
/// format used by `G1Element::serialize`, but is much faster to deserialize.
///
/// The intended use of this struct is to deserialize and sum a large number of G1 elements without
/// having to decompress them first.
#[derive(Clone, Debug)]
#[repr(transparent)]
pub struct G1ElementUncompressed(pub(crate) [u8; 2 * G1_ELEMENT_BYTE_LENGTH]);

impl From<&G1Element> for G1ElementUncompressed {
jonas-lj marked this conversation as resolved.
Show resolved Hide resolved
fn from(element: &G1Element) -> Self {
let mut bytes = [0u8; 2 * G1_ELEMENT_BYTE_LENGTH];
unsafe {
blst_p1_serialize(bytes.as_mut_ptr(), &element.0);
}
G1ElementUncompressed(bytes)
}
}

impl TryFrom<&G1ElementUncompressed> for G1Element {
type Error = FastCryptoError;

fn try_from(value: &G1ElementUncompressed) -> Result<Self, Self::Error> {
// See https://github.com/supranational/blst for details on the serialization format.

// Note that `blst_p1_deserialize` accepts both compressed and uncompressed serializations,
// so we check that the compressed bit flag (the 1st) is not set. The third is used for
// compressed points to indicate sign of the y-coordinate and should also not be set.
if value.0[0] & 0x20 != 0 || value.0[0] & 0x80 != 0 {
return Err(InvalidInput);
}

let mut ret = blst_p1::default();
unsafe {
let mut affine = blst_p1_affine::default();
if blst_p1_deserialize(&mut affine, value.0.as_ptr()) != BLST_ERROR::BLST_SUCCESS {
return Err(InvalidInput);
}
blst_p1_from_affine(&mut ret, &affine);

if !blst_p1_in_g1(&ret) {
return Err(InvalidInput);
}
}
Ok(G1Element(ret))
}
}

impl G1ElementUncompressed {
/// Create a new `G1ElementUncompressed` from a byte array.
/// The input is not validated so it should come from a trusted source.
///
/// See [the blst docs](https://github.com/supranational/blst/tree/master?tab=readme-ov-file#serialization-format) for details about the uncompressed serialization format.
pub fn from_trusted_byte_array(bytes: [u8; 2 * G1_ELEMENT_BYTE_LENGTH]) -> Self {
Self(bytes)
}

/// Get the byte array representation of this element.
pub fn into_byte_array(self) -> [u8; 2 * G1_ELEMENT_BYTE_LENGTH] {
self.0
}

jonas-lj marked this conversation as resolved.
Show resolved Hide resolved
/// This will never fail if the input is a valid G1 element.
fn to_blst_p1_affine(&self) -> FastCryptoResult<blst_p1_affine> {
let mut affine = blst_p1_affine::default();
unsafe {
// This fails if the point is not on the curve or if it is (0, ±2) which is on the curve
// but not in the G1 subgroup. See https://github.com/supranational/blst/blob/6f3136ffb636974166a93f2f25436854fe8d10ff/src/e1.c#L296-L326.
// A subgroup check is not performed here.
if blst_p1_deserialize(&mut affine, self.0.as_ptr()) != BLST_ERROR::BLST_SUCCESS {
return Err(InvalidInput);
}
}
Ok(affine)
}

/// Compute the sum of a slice of uncompressed G1 elements.
///
/// This function will never fail if the inputs are valid G1 element.
pub fn sum(terms: &[G1ElementUncompressed]) -> FastCryptoResult<G1Element> {
if terms.is_empty() {
return Ok(G1Element::zero());
}

let affine_points: Vec<blst_p1_affine> = terms
.iter()
.map(G1ElementUncompressed::to_blst_p1_affine)
.collect::<FastCryptoResult<Vec<_>>>()?;

let mut ret = blst_p1::default();
let p = affine_points
.iter()
.map(|p| p as *const _)
.collect::<Vec<_>>();
unsafe { blst_p1s_add(&mut ret, p.as_ptr(), p.len()) };
Ok(G1Element(ret))
}
}

impl Add for G2Element {
type Output = Self;

Expand Down
108 changes: 107 additions & 1 deletion fastcrypto/src/tests/bls12381_group_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// SPDX-License-Identifier: Apache-2.0

use crate::bls12381::min_pk::{BLS12381KeyPair, BLS12381Signature};
use crate::groups::bls12381::{reduce_mod_uniform_buffer, G1Element, G2Element, GTElement, Scalar};
use crate::groups::bls12381::{
reduce_mod_uniform_buffer, G1Element, G1ElementUncompressed, G2Element, GTElement, Scalar,
G1_ELEMENT_BYTE_LENGTH,
};
use crate::groups::{
FromTrustedByteArray, GroupElement, HashToGroupElement, MultiScalarMul, Pairing,
Scalar as ScalarTrait,
Expand Down Expand Up @@ -653,3 +656,106 @@ fn test_serialization_gt() {
assert!(GTElement::from_trusted_byte_array(&bytes).is_ok());
assert!(GTElement::from_byte_array(&bytes).is_err());
}

#[test]
fn test_g1_to_uncompressed() {
let a = G1Element::generator() * Scalar::from(7u128);

let uncompressed_bytes = G1ElementUncompressed::from(&a);

// Compressed bit flags (1 and 3) should not be set.
assert_eq!(uncompressed_bytes.0[0] & 0xA0, 0);

// Infinity bit flag (2) should not be set.
assert_eq!(uncompressed_bytes.0[0] & 0x40, 0);

// Regression test
assert_eq!(&uncompressed_bytes.0, hex::decode("1928f3beb93519eecf0145da903b40a4c97dca00b21f12ac0df3be9116ef2ef27b2ae6bcd4c5bc2d54ef5a70627efcb7108dadbaa4b636445639d5ae3089b3c43a8a1d47818edd1839d7383959a41c10fdc66849cfa1b08c5a11ec7e28981a1c").unwrap().as_slice());

// Check round-trip
let b = G1Element::try_from(&uncompressed_bytes).unwrap();
assert_eq!(a, b);

// Simply padding a compressed serialization with 0's will fail
let mut padded = b.to_byte_array().to_vec();
padded.extend_from_slice(&[0u8; G1_ELEMENT_BYTE_LENGTH]);
assert_eq!(padded.len(), 2 * G1_ELEMENT_BYTE_LENGTH);
let uncompressed = G1ElementUncompressed::from_trusted_byte_array(padded.try_into().unwrap());
assert!(G1Element::try_from(&uncompressed).is_err());

// A point not on the curve fails
let mut bytes = uncompressed_bytes.into_byte_array();
bytes[1] += 1;
let uncompressed_bytes = G1ElementUncompressed::from_trusted_byte_array(bytes);
assert!(G1Element::try_from(&uncompressed_bytes).is_err());

// Serialize the point-at-infinity
let a = G1Element::zero();
let uncompressed_bytes = G1ElementUncompressed::from(&a);

// Only the point at infinity flag should be set.
assert_eq!(uncompressed_bytes.0[0], 0x40);

// The remaining bytes should all be zero
assert_eq!(
uncompressed_bytes.0[1..],
[0u8; G1_ELEMENT_BYTE_LENGTH * 2 - 1]
);

// All zeros
let uncompressed =
G1ElementUncompressed::from_trusted_byte_array([0u8; 2 * G1_ELEMENT_BYTE_LENGTH]);
assert!(G1Element::try_from(&uncompressed).is_err());
}

#[test]
fn test_g1_sum() {
// Empty sum
assert_eq!(G1ElementUncompressed::sum(&[]).unwrap(), G1Element::zero());

// Non-trivial sum
let a = G1Element::generator();
let b = G1Element::generator() * Scalar::from(2u128);
let c = G1Element::generator() * Scalar::from(3u128);
let mut bytes: Vec<G1ElementUncompressed> = vec![(&a).into(), (&b).into(), (&c).into()];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::generator() * Scalar::from(6u128));

// Adding zeros doesn't change anything
bytes.push(G1ElementUncompressed::from(&G1Element::zero()));
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::generator() * Scalar::from(6u128));

// Equal elements in sum
let bytes = vec![(&b).into(), (&b).into()];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::generator() * Scalar::from(4u128));

// Singleton sum
let bytes = [(&b).into()];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, b);

// Adding zero's
let mut bytes = vec![G1ElementUncompressed::from(&G1Element::zero())];
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::zero());
bytes.push(G1ElementUncompressed::from(&G1Element::zero()));
let sum = G1ElementUncompressed::sum(&bytes).unwrap();
assert_eq!(sum, G1Element::zero());
}
jonas-lj marked this conversation as resolved.
Show resolved Hide resolved

#[test]
fn test_g1_large_sum() {
let mut rng = thread_rng();
let n: usize = 100;
let points: Vec<G1Element> = (0..n)
.map(|_| G1Element::generator() * Scalar::rand(&mut rng))
.collect();
let expected = points.iter().fold(G1Element::zero(), |acc, p| acc + p);

let as_uncompressed: Vec<G1ElementUncompressed> =
points.iter().map(G1ElementUncompressed::from).collect();
let sum = G1ElementUncompressed::sum(as_uncompressed.as_slice()).unwrap();
assert_eq!(expected, sum);
}
Loading