Skip to content

Commit

Permalink
ThresholdBls: accept Iterator directly where possible (#709)
Browse files Browse the repository at this point in the history
Instead of requiring a slice that we immediately and only call
`iter()` on, accept the Iterator. This can enable clients to avoid
extra copies.
  • Loading branch information
aschran authored Dec 13, 2023
1 parent 3ae45d6 commit 4d8041f
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 89 deletions.
6 changes: 3 additions & 3 deletions fastcrypto-tbls/benches/tbls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod tbls_benches {
.collect::<Vec<_>>();

create.bench_function(format!("w={}", w).as_str(), |b| {
b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(&shares, msg))
b.iter(|| ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg))
});
}
}
Expand All @@ -39,10 +39,10 @@ mod tbls_benches {
.map(|i| private_poly.eval(NonZeroU32::new(i as u32).unwrap()))
.collect::<Vec<_>>();

let sigs = ThresholdBls12381MinSig::partial_sign_batch(&shares, msg);
let sigs = ThresholdBls12381MinSig::partial_sign_batch(shares.iter(), msg);

create.bench_function(format!("w={}", w).as_str(), |b| {
b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, &sigs).unwrap())
b.iter(|| ThresholdBls12381MinSig::aggregate(w as u32, sigs.iter()).unwrap())
});
}
}
Expand Down
5 changes: 2 additions & 3 deletions fastcrypto-tbls/src/nidkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,8 @@ where
.map(|(i, pk)| Eval {
index: NonZeroU32::new((i + 1) as u32).expect("non zero"),
value: *pk,
})
.collect::<Vec<Eval<G>>>();
let pk = Poly::<G>::recover_c0(self.t, &evals).expect("enough shares");
});
let pk = Poly::<G>::recover_c0(self.t, evals).expect("enough shares");

(pk, partial_pks)
}
Expand Down
50 changes: 30 additions & 20 deletions fastcrypto-tbls/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use fastcrypto::error::{FastCryptoError, FastCryptoResult};
use fastcrypto::groups::{GroupElement, MultiScalarMul, Scalar};
use fastcrypto::traits::AllowedRng;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashSet;

/// Types
Expand Down Expand Up @@ -81,22 +82,25 @@ impl<C: GroupElement> Poly<C> {
// Expects exactly t unique shares.
fn get_lagrange_coefficients_for_c0(
t: u32,
shares: &[Eval<C>],
mut shares: impl Iterator<Item = impl Borrow<Eval<C>>>,
) -> FastCryptoResult<Vec<C::ScalarType>> {
if shares.len() != t as usize {
return Err(FastCryptoError::InvalidInput);
}
// Check for duplicates.
let mut ids_set = HashSet::new();
if !shares.iter().map(|s| &s.index).all(|id| ids_set.insert(id)) {
return Err(FastCryptoError::InvalidInput); // expected unique ids
let (shares_size_lower, shares_size_upper) = shares.size_hint();
let indices = shares.try_fold(
Vec::with_capacity(shares_size_upper.unwrap_or(shares_size_lower)),
|mut vec, s| {
// Check for duplicates.
if !ids_set.insert(s.borrow().index) {
return Err(FastCryptoError::InvalidInput); // expected unique ids
}
vec.push(C::ScalarType::from(s.borrow().index.get() as u64));
Ok(vec)
},
)?;
if indices.len() != t as usize {
return Err(FastCryptoError::InvalidInput);
}

let indices = shares
.iter()
.map(|s| C::ScalarType::from(s.index.get() as u64))
.collect::<Vec<_>>();

let full_numerator = indices
.iter()
.fold(C::ScalarType::generator(), |acc, i| acc * i);
Expand All @@ -113,13 +117,16 @@ impl<C: GroupElement> Poly<C> {
}

/// Given exactly `t` polynomial evaluations, it will recover the polynomial's constant term.
pub fn recover_c0(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
pub fn recover_c0(
t: u32,
shares: impl Iterator<Item = impl Borrow<Eval<C>>> + Clone,
) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?;
let plain_shares = shares.map(|s| s.borrow().value);
let res = coeffs
.iter()
.zip(plain_shares.iter())
.fold(C::zero(), |acc, (c, s)| acc + (*s * *c));
.zip(plain_shares)
.fold(C::zero(), |acc, (c, s)| acc + (s * *c));
Ok(res)
}

Expand Down Expand Up @@ -172,9 +179,12 @@ impl<C: Scalar> Poly<C> {
impl<C: GroupElement + MultiScalarMul> Poly<C> {
/// Given exactly `t` polynomial evaluations, it will recover the polynomial's
/// constant term.
pub fn recover_c0_msm(t: u32, shares: &[Eval<C>]) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares)?;
let plain_shares = shares.iter().map(|s| s.value).collect::<Vec<_>>();
pub fn recover_c0_msm(
t: u32,
shares: impl Iterator<Item = impl Borrow<Eval<C>>> + Clone,
) -> Result<C, FastCryptoError> {
let coeffs = Self::get_lagrange_coefficients_for_c0(t, shares.clone())?;
let plain_shares = shares.map(|s| s.borrow().value).collect::<Vec<_>>();
let res = C::multi_scalar_mul(&coeffs, &plain_shares).expect("sizes match");
Ok(res)
}
Expand Down
51 changes: 25 additions & 26 deletions fastcrypto-tbls/src/tbls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// Some of the code below is based on code from https://github.com/celo-org/celo-threshold-bls-rs,
// modified for our needs.

use std::borrow::Borrow;

use crate::dl_verification::{batch_coefficients, get_random_scalars};
use crate::polynomial::Poly;
use crate::types::IndexedValue;
Expand All @@ -29,20 +31,22 @@ pub trait ThresholdBls {

/// Sign a message using the private share/partial key.
fn partial_sign(share: &Share<Self::Private>, msg: &[u8]) -> PartialSignature<Self::Signature> {
Self::partial_sign_batch(&[share.clone()], msg)[0].clone()
Self::partial_sign_batch(std::iter::once(share), msg)[0].clone()
}

/// Sign a message using one of more private share/partial keys.
fn partial_sign_batch(
shares: &[Share<Self::Private>],
shares: impl Iterator<Item = impl Borrow<Share<Self::Private>>>,
msg: &[u8],
) -> Vec<PartialSignature<Self::Signature>> {
let h = Self::Signature::hash_to_group_element(msg);
shares
.iter()
.map(|share| PartialSignature {
index: share.index,
value: h * share.value,
.map(|share| {
let share = share.borrow();
PartialSignature {
index: share.index,
value: h * share.value,
}
})
.collect()
}
Expand All @@ -63,46 +67,41 @@ pub trait ThresholdBls {
fn partial_verify_batch<R: AllowedRng>(
vss_pk: &Poly<Self::Public>,
msg: &[u8],
partial_sigs: &[PartialSignature<Self::Signature>],
partial_sigs: impl Iterator<Item = impl Borrow<PartialSignature<Self::Signature>>>,
rng: &mut R,
) -> FastCryptoResult<()> {
assert!(vss_pk.degree() > 0 || !msg.is_empty());
if partial_sigs.is_empty() {
let (evals_as_scalars, points): (Vec<_>, Vec<_>) = partial_sigs
.map(|sig| {
let sig = sig.borrow();
(Self::Private::from(sig.index.get().into()), sig.value)
})
.unzip();
if points.is_empty() {
return Ok(());
}
let rs = get_random_scalars::<Self::Private, R>(partial_sigs.len() as u32, rng);
let evals_as_scalars = partial_sigs
.iter()
.map(|e| Self::Private::from(e.index.get().into()))
.collect::<Vec<_>>();
let rs = get_random_scalars::<Self::Private, R>(points.len() as u32, rng);
// TODO: should we cache it instead? that would replace t-wide msm with w-wide msm.
let coeffs = batch_coefficients(&rs, &evals_as_scalars, vss_pk.degree());
let pk = Self::Public::multi_scalar_mul(&coeffs, vss_pk.as_vec()).expect("sizes match");
let aggregated_sig = Self::Signature::multi_scalar_mul(
&rs,
&partial_sigs.iter().map(|s| s.value).collect::<Vec<_>>(),
)
.expect("sizes match");
let aggregated_sig = Self::Signature::multi_scalar_mul(&rs, &points).expect("sizes match");

Self::verify(&pk, msg, &aggregated_sig)
}

/// Interpolate partial signatures to recover the full signature.
fn aggregate(
threshold: u32,
partials: &[PartialSignature<Self::Signature>],
partials: impl Iterator<Item = impl Borrow<PartialSignature<Self::Signature>>> + Clone,
) -> FastCryptoResult<Self::Signature> {
let unique_partials = partials
.iter()
.unique_by(|p| p.index)
.take(threshold as usize)
.cloned()
.collect::<Vec<_>>();
if unique_partials.len() != threshold as usize {
.unique_by(|p| p.borrow().index)
.take(threshold as usize);
if unique_partials.clone().count() != threshold as usize {
return Err(FastCryptoError::NotEnoughInputs);
}
// No conversion is required since PartialSignature<S> and Eval<S> are different aliases to
// IndexedValue<S>.
Poly::<Self::Signature>::recover_c0_msm(threshold, &unique_partials)
Poly::<Self::Signature>::recover_c0_msm(threshold, unique_partials)
}
}
2 changes: 1 addition & 1 deletion fastcrypto-tbls/src/tests/dkg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() {
S::partial_verify(&o3.vss_pk, &MSG, &sig31).unwrap();

let sigs = vec![sig00, sig30, sig31];
let sig = S::aggregate(d0.t(), &sigs).unwrap();
let sig = S::aggregate(d0.t(), sigs.iter()).unwrap();
S::verify(o0.vss_pk.c0(), &MSG, &sig).unwrap();
}

Expand Down
37 changes: 17 additions & 20 deletions fastcrypto-tbls/src/tests/polynomial_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,13 @@ mod scalar_tests {
let threshold = degree + 1;
let poly = Poly::<S>::rand(4, &mut thread_rng());
// insufficient shares gathered
let shares = (1..threshold)
.map(|i| poly.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
Poly::<S>::recover_c0(threshold, &shares).unwrap_err();
let shares = (1..threshold).map(|i| poly.eval(ShareIndex::new(i).unwrap()));
Poly::<S>::recover_c0(threshold, shares).unwrap_err();
// duplications
let mut shares = (1..=threshold)
let shares = (1..=threshold)
.map(|i| poly.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
shares.push(shares[0].clone());
Poly::<S>::recover_c0(threshold, &shares).unwrap_err();
.chain(std::iter::once(poly.eval(ShareIndex::new(1).unwrap()))); // duplicate value 1
Poly::<S>::recover_c0(threshold, shares).unwrap_err();
}

#[test]
Expand All @@ -75,7 +72,7 @@ mod scalar_tests {
let c0 = poly.c0();
for _ in 0..10 {
shares.shuffle(&mut thread_rng());
let used_shares = &shares[..124];
let used_shares = shares.iter().take(124);
assert_eq!(c0, &Poly::<S>::recover_c0(124, used_shares).unwrap());
}
}
Expand Down Expand Up @@ -112,7 +109,10 @@ mod points_tests {
let s2 = p.eval(NonZeroU32::new(20).unwrap());
let s3 = p.eval(NonZeroU32::new(30).unwrap());
let shares = vec![s1, s2, s3];
assert_eq!(Poly::<G::ScalarType>::recover_c0(3, &shares).unwrap(), one);
assert_eq!(
Poly::<G::ScalarType>::recover_c0(3, shares.iter()).unwrap(),
one
);
}

#[test]
Expand All @@ -122,16 +122,13 @@ mod points_tests {
let poly = Poly::<G::ScalarType>::rand(4, &mut thread_rng());
let poly_g = poly.commit();
// insufficient shares gathered
let shares = (1..threshold)
.map(|i| poly_g.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
Poly::<G>::recover_c0_msm(threshold, &shares).unwrap_err();
let shares = (1..threshold).map(|i| poly_g.eval(ShareIndex::new(i).unwrap()));
Poly::<G>::recover_c0_msm(threshold, shares).unwrap_err();
// duplications
let mut shares = (1..threshold)
let shares = (1..threshold)
.map(|i| poly_g.eval(ShareIndex::new(i).unwrap()))
.collect::<Vec<_>>();
shares.push(shares[0].clone());
Poly::<G>::recover_c0_msm(threshold, &shares).unwrap_err();
.chain(std::iter::once(poly_g.eval(ShareIndex::new(1).unwrap()))); // duplicate value 1
Poly::<G>::recover_c0_msm(threshold, shares).unwrap_err();
}

#[test]
Expand All @@ -144,7 +141,7 @@ mod points_tests {
let s2 = p.eval(NonZeroU32::new(20).unwrap());
let s3 = p.eval(NonZeroU32::new(30).unwrap());
let shares = vec![s1, s2, s3];
assert_eq!(Poly::<G>::recover_c0_msm(3, &shares).unwrap(), one);
assert_eq!(Poly::<G>::recover_c0_msm(3, shares.iter()).unwrap(), one);

// and random tests
let poly = Poly::<G::ScalarType>::rand(123, &mut thread_rng());
Expand All @@ -156,7 +153,7 @@ mod points_tests {
let c0 = poly_g.c0();
for _ in 0..10 {
shares.shuffle(&mut thread_rng());
let used_shares = &shares[..124];
let used_shares = shares.iter().take(124);
assert_eq!(c0, &Poly::<G>::recover_c0_msm(124, used_shares).unwrap());
}
}
Expand Down
Loading

0 comments on commit 4d8041f

Please sign in to comment.