From 14d62bb3b5814403c7b5fe30eb78bb74a96e718c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Lindstr=C3=B8m?= Date: Mon, 11 Dec 2023 15:31:24 +0100 Subject: [PATCH] Optimze VDF (#706) * Update reduction * mutable reduction * Get rid of closure * Update numbigint * Use multi-scalar multiplication for VDF verification * fmt * clippy * Clean up * Clean up benchmark * Test * Optimise fs * Update tests * Remove unused function * constant * comment * Clean up * Trait bounds * align names * Review * cleanup * Refactor hashprime * Refactor: get rid of scalar_size constant * Remove complaint function * Rename test * fix bench --- Cargo.lock | 29 +-- fastcrypto-cli/src/vdf.rs | 4 +- fastcrypto-vdf/Cargo.toml | 4 +- fastcrypto-vdf/benches/class_group.rs | 1 + fastcrypto-vdf/benches/vdf.rs | 59 +++++- fastcrypto-vdf/src/class_group/mod.rs | 90 ++++---- fastcrypto-vdf/src/extended_gcd.rs | 12 +- fastcrypto-vdf/src/hash_prime.rs | 192 ++++++++++++++---- fastcrypto-vdf/src/lib.rs | 16 +- fastcrypto-vdf/src/vdf/wesolowski.rs | 136 ++++++++++--- fastcrypto/Cargo.toml | 1 + fastcrypto/benches/groups.rs | 49 +++-- fastcrypto/src/groups/mod.rs | 11 +- fastcrypto/src/groups/multiplier/bgmw.rs | 22 +- .../src/groups/multiplier/integer_utils.rs | 36 ++-- fastcrypto/src/groups/multiplier/mod.rs | 23 +-- fastcrypto/src/groups/multiplier/windowed.rs | 161 ++++++++------- fastcrypto/src/groups/ristretto255.rs | 10 +- fastcrypto/src/groups/secp256r1.rs | 11 +- fastcrypto/src/secp256r1/mod.rs | 9 +- fastcrypto/src/tests/secp256r1_group_tests.rs | 2 +- 21 files changed, 597 insertions(+), 281 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6575d7b412..8cad9e7e2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,7 +184,7 @@ dependencies = [ "derivative", "digest 0.10.6", "itertools", - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-traits", "paste", "rustc_version", @@ -207,7 +207,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" dependencies = [ - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-traits", "proc-macro2", "quote", @@ -253,7 +253,7 @@ dependencies = [ "ark-relations", "ark-std", "derivative", - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-integer", "num-traits", "tracing", @@ -291,7 +291,7 @@ dependencies = [ "ark-serialize-derive", "ark-std", "digest 0.10.6", - "num-bigint 0.4.3", + "num-bigint 0.4.4", ] [[package]] @@ -1414,6 +1414,7 @@ dependencies = [ "k256", "lazy_static", "merlin", + "num-bigint 0.4.4", "once_cell", "p256", "proptest", @@ -1500,7 +1501,7 @@ dependencies = [ "criterion 0.5.1", "fastcrypto", "hex", - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-integer", "num-modular 0.6.1", "num-prime", @@ -1536,7 +1537,7 @@ dependencies = [ "im", "lazy_static", "neptune", - "num-bigint 0.4.3", + "num-bigint 0.4.4", "once_cell", "poseidon-ark", "proptest", @@ -2246,7 +2247,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43db66d1170d347f9a065114077f7dccb00c1b9478c89384490a3425279a4606" dependencies = [ - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-complex", "num-integer", "num-iter", @@ -2267,9 +2268,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ "autocfg", "num-integer", @@ -2330,7 +2331,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64a5fe11d4135c3bcdf3a95b18b194afa9608a5f6ff034f5d857bc9a27fb0119" dependencies = [ - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-integer", "num-traits", ] @@ -2350,7 +2351,7 @@ dependencies = [ "bitvec", "either", "lru", - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-integer", "num-modular 0.5.1", "num-traits", @@ -2364,16 +2365,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" dependencies = [ "autocfg", - "num-bigint 0.4.3", + "num-bigint 0.4.4", "num-integer", "num-traits", ] [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", diff --git a/fastcrypto-cli/src/vdf.rs b/fastcrypto-cli/src/vdf.rs index 52209a48e0..4663452455 100644 --- a/fastcrypto-cli/src/vdf.rs +++ b/fastcrypto-cli/src/vdf.rs @@ -174,7 +174,7 @@ mod tests { iterations, })) .unwrap(); - let expected = "Output: 010027d513249bf8d6ad8cc854052080111a420b2771fab2ac566e63cb6a389cfe42c7920b90871fd1ea0b85e80d157d48e6759546cdcfef4a25b3f013b982c2970dfaa8d67e5f87564a91698ffd1407c505372fc52b0313f444937991c63b6b00040401\nProof: 0000aadd0fceb7cab33ad9991aaddfb234473d2c4dc987225cba6f1c6a259e01e893fecede62b459db56474f840e0da0e4de3d0b2da709083620dccfed9451dc3c1b4f911167c85f887dacaa6cac52db94682f9ddc73c18613d4ecf6513580ec2f270302"; + let expected = "Output: 010027d513249bf8d6ad8cc854052080111a420b2771fab2ac566e63cb6a389cfe42c7920b90871fd1ea0b85e80d157d48e6759546cdcfef4a25b3f013b982c2970dfaa8d67e5f87564a91698ffd1407c505372fc52b0313f444937991c63b6b00040401\nProof: 0200a79fea1d00b2d1bf7863098980146ad080d400141ff2333652cbcee96b524f273461f8e2e65d8b713663f7083954ef6246ea08d09e6909a047f34065bcfe1e2013c8e523a8a59a01fafa008c637240097d082486c8cc52803d5cad3d4e2aa9130402"; assert_eq!(expected, result); let invalid_discriminant = "abcx".to_string(); @@ -190,7 +190,7 @@ mod tests { let discriminant = "ff6cb04c161319209d438b6f016a9c3703b69fef3bb701550eb556a7b2dfec8676677282f2dd06c5688c51439c59e5e1f9efe8305df1957d6b7bf3433493668680e8b8bb05262cbdf4d020dafa8d5a3433199b8b53f6d487b3f37a4ab59493f050d1e2b535b7e9be19c0201055c0d7a07db3aaa67fe0eed63b63d86558668a27".to_string(); let iterations = 1000u64; let output = "010027d513249bf8d6ad8cc854052080111a420b2771fab2ac566e63cb6a389cfe42c7920b90871fd1ea0b85e80d157d48e6759546cdcfef4a25b3f013b982c2970dfaa8d67e5f87564a91698ffd1407c505372fc52b0313f444937991c63b6b00040401".to_string(); - let proof = "0000aadd0fceb7cab33ad9991aaddfb234473d2c4dc987225cba6f1c6a259e01e893fecede62b459db56474f840e0da0e4de3d0b2da709083620dccfed9451dc3c1b4f911167c85f887dacaa6cac52db94682f9ddc73c18613d4ecf6513580ec2f270302".to_string(); + let proof = "0200a79fea1d00b2d1bf7863098980146ad080d400141ff2333652cbcee96b524f273461f8e2e65d8b713663f7083954ef6246ea08d09e6909a047f34065bcfe1e2013c8e523a8a59a01fafa008c637240097d082486c8cc52803d5cad3d4e2aa9130402".to_string(); let result = execute(Command::Verify(VerifyArguments { discriminant, iterations, diff --git a/fastcrypto-vdf/Cargo.toml b/fastcrypto-vdf/Cargo.toml index 3a9117babe..98c969024a 100644 --- a/fastcrypto-vdf/Cargo.toml +++ b/fastcrypto-vdf/Cargo.toml @@ -11,8 +11,8 @@ repository = "https://github.com/MystenLabs/fastcrypto" [dependencies] fastcrypto = { path = "../fastcrypto" } -num-bigint = "0.4.3" -num-traits = "0.2.16" +num-bigint = "0.4.4" +num-traits = "0.2.17" num-integer = "0.1.45" num-modular = "0.6.1" num-prime = { version = "0.4.3", features = ["big-int"] } diff --git a/fastcrypto-vdf/benches/class_group.rs b/fastcrypto-vdf/benches/class_group.rs index e1a29e273b..4699d665f0 100644 --- a/fastcrypto-vdf/benches/class_group.rs +++ b/fastcrypto-vdf/benches/class_group.rs @@ -3,6 +3,7 @@ use criterion::measurement::Measurement; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkGroup, Criterion}; +use fastcrypto::groups::Doubling; use fastcrypto_vdf::class_group::{Discriminant, QuadraticForm}; use fastcrypto_vdf::ParameterizedGroupElement; use num_bigint::BigInt; diff --git a/fastcrypto-vdf/benches/vdf.rs b/fastcrypto-vdf/benches/vdf.rs index 5b8ce7d401..570d61edb1 100644 --- a/fastcrypto-vdf/benches/vdf.rs +++ b/fastcrypto-vdf/benches/vdf.rs @@ -6,8 +6,11 @@ extern crate criterion; use criterion::measurement::Measurement; use criterion::{BenchmarkGroup, BenchmarkId, Criterion}; +use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier; use fastcrypto_vdf::class_group::{Discriminant, QuadraticForm}; -use fastcrypto_vdf::vdf::wesolowski::StrongVDF; +use fastcrypto_vdf::hash_prime::{hash_prime_with_index, verify_prime, DefaultPrimalityCheck}; +use fastcrypto_vdf::vdf::wesolowski::CHALLENGE_SIZE; +use fastcrypto_vdf::vdf::wesolowski::{FastVerifier, StrongFiatShamir, StrongVDF}; use fastcrypto_vdf::vdf::VDF; use fastcrypto_vdf::Parameter; use num_bigint::BigInt; @@ -29,16 +32,29 @@ fn verify_single(parameters: VerificationInputs, c: &mut Benchma let result_bytes = hex::decode(parameters.result).unwrap(); let result = QuadraticForm::from_bytes(&result_bytes, &discriminant).unwrap(); + let result_copy = result.clone(); let proof_bytes = hex::decode(parameters.proof).unwrap(); let proof = QuadraticForm::from_bytes(&proof_bytes, &discriminant).unwrap(); + let proof_copy = proof.clone(); let input = QuadraticForm::generator(&discriminant); + let input_copy = input.clone(); - let vdf = StrongVDF::new(discriminant, parameters.iterations); + let vdf = StrongVDF::new(discriminant.clone(), parameters.iterations); c.bench_function(discriminant_size.to_string(), move |b| { b.iter(|| vdf.verify(&input, &result, &proof)) }); + + let vdf = StrongVDF::new(discriminant.clone(), parameters.iterations); + let fast_verify: FastVerifier< + QuadraticForm, + StrongFiatShamir, + WindowedScalarMultiplier, + > = FastVerifier::new(vdf, input_copy); + c.bench_function(format!("{} fast", discriminant_size), move |b| { + b.iter(|| fast_verify.verify(&result_copy, &proof_copy)) + }); } fn verify(c: &mut Criterion) { @@ -62,16 +78,16 @@ fn verify(c: &mut Criterion) { } fn sample_discriminant(c: &mut Criterion) { - let bit_lengths = [128, 256, 512, 1024, 2048]; + let byte_lengths = [16, 32, 64, 128, 256]; let mut seed = [0u8; 32]; let mut rng = thread_rng(); - for bit_length in bit_lengths { + for byte_length in byte_lengths { c.bench_with_input( - BenchmarkId::new("Sample class group discriminant".to_string(), bit_length), - &bit_length, + BenchmarkId::new("Sample class group discriminant".to_string(), byte_length), + &byte_length, |b, n| { b.iter(|| { rng.try_fill_bytes(&mut seed).unwrap(); @@ -82,10 +98,39 @@ fn sample_discriminant(c: &mut Criterion) { } } +fn verify_discriminant(c: &mut Criterion) { + let byte_lengths = [16, 32, 64, 128, 256]; + let seed = [0u8; 32]; + + for byte_length in byte_lengths { + let (i, _) = hash_prime_with_index::( + &seed, + byte_length, + &[0, 1, 8 * byte_length - 1], + ); + + c.bench_with_input( + BenchmarkId::new("Verify discriminant".to_string(), byte_length), + &byte_length, + |b, n| { + b.iter(|| { + verify_prime::( + &seed, + *n, + &[0, 1, 8 * byte_length - 1], + i, + ) + .unwrap() + }) + }, + ); + } +} + criterion_group! { name = vdf_benchmarks; config = Criterion::default().sample_size(100); - targets = verify, sample_discriminant + targets = verify, sample_discriminant, verify_discriminant } criterion_main!(vdf_benchmarks); diff --git a/fastcrypto-vdf/src/class_group/mod.rs b/fastcrypto-vdf/src/class_group/mod.rs index 5a4faa5fae..14ba5d7746 100644 --- a/fastcrypto-vdf/src/class_group/mod.rs +++ b/fastcrypto-vdf/src/class_group/mod.rs @@ -9,12 +9,14 @@ use crate::extended_gcd::{extended_euclidean_algorithm, EuclideanAlgorithmOutput use crate::{ParameterizedGroupElement, ToBytes, UnknownOrderGroupElement}; use fastcrypto::error::FastCryptoError::InvalidInput; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; +use fastcrypto::groups::Doubling; use num_bigint::BigInt; use num_integer::Integer; use num_traits::{One, Signed, Zero}; +use std::borrow::Borrow; use std::cmp::Ordering; use std::mem::swap; -use std::ops::{Add, Neg}; +use std::ops::{Add, AddAssign, Mul, Neg, Shl, Shr}; mod compressed; @@ -66,21 +68,15 @@ impl QuadraticForm { } /// Return a normalized form equivalent to this quadratic form. See [`QuadraticForm::is_normal`]. - fn normalize(self) -> Self { + fn normalize(&mut self) { // See section 5 in https://github.com/Chia-Network/chiavdf/blob/main/classgroups.pdf. if self.is_normal() { - return self; - } - let r = (&self.a - &self.b).div_floor(&(&self.a * 2)); - let ra = &r * &self.a; - let c = self.c + (&ra + &self.b) * &r; - let b = self.b + &ra * 2; - Self { - a: self.a, - b, - c, - partial_gcd_limit: self.partial_gcd_limit, + return; } + let r = (&self.a - &self.b).div_floor(&self.a).shr(1); + let ra: BigInt = &r * &self.a; + self.c.add_assign((&ra + &self.b) * &r); + self.b.add_assign(&ra.shl(1)); } /// Return true if this form is reduced: A form is reduced if it is normal (see @@ -94,18 +90,16 @@ impl QuadraticForm { } /// Return a reduced form (see [`QuadraticForm::is_reduced`]) equivalent to this quadratic form. - fn reduce(self) -> Self { + fn reduce(&mut self) { // See section 5 in https://github.com/Chia-Network/chiavdf/blob/main/classgroups.pdf. - let mut form = self.normalize(); - while !form.is_reduced() { - let s = (&form.b + &form.c).div_floor(&(&form.c * 2)); - let cs = &form.c * &s; - let old_c = form.c.clone(); - form.c = (&cs - &form.b) * &s + &form.a; - form.a = old_c; - form.b = &cs * 2 - &form.b; + self.normalize(); + while !self.is_reduced() { + let s = (&self.b + &self.c).div_floor(&self.c).shr(1); + let cs: BigInt = &self.c * &s; + swap(&mut self.a, &mut self.c); + self.c += (&cs - &self.b) * &s; + self.b = cs.shl(1) - &self.b; } - form } /// Compute the composition of this quadratic form with another quadratic form. @@ -216,27 +210,19 @@ impl QuadraticForm { v3 = &g * (&q3 + &q4) - &q1 - &q2; } - QuadraticForm { + let mut form = QuadraticForm { a: u3, b: v3, c: w3, partial_gcd_limit: self.partial_gcd_limit.clone(), - } - .reduce() + }; + form.reduce(); + form } } -impl ParameterizedGroupElement for QuadraticForm { - /// The discriminant of a quadratic form defines the class group. - type ParameterType = Discriminant; - - type ScalarType = BigInt; - - fn zero(discriminant: &Self::ParameterType) -> Self { - Self::from_a_b_discriminant(BigInt::one(), BigInt::one(), discriminant) - } - - fn double(self) -> Self { +impl Doubling for QuadraticForm { + fn double(&self) -> Self { // Slightly optimised version of Algorithm 2 from Jacobson, Jr, Michael & Poorten, Alfred // (2002). "Computational aspects of NUCOMP", Lecture Notes in Computer Science. // (https://www.researchgate.net/publication/221451638_Computational_aspects_of_NUCOMP) @@ -300,13 +286,33 @@ impl ParameterizedGroupElement for QuadraticForm { w3 = &w3 - &g * &x * &dx; } - QuadraticForm { + let mut form = QuadraticForm { a: u3, b: v3, c: w3, partial_gcd_limit: self.partial_gcd_limit.clone(), - } - .reduce() + }; + form.reduce(); + form + } +} + +impl<'a> Mul<&'a BigInt> for QuadraticForm { + type Output = Self; + + fn mul(self, rhs: &'a BigInt) -> Self::Output { + self.borrow().mul(rhs) + } +} + +impl ParameterizedGroupElement for QuadraticForm { + /// The discriminant of a quadratic form defines the class group. + type ParameterType = Discriminant; + + type ScalarType = BigInt; + + fn zero(discriminant: &Self::ParameterType) -> Self { + Self::from_a_b_discriminant(BigInt::one(), BigInt::one(), discriminant) } fn mul(&self, scale: &BigInt) -> Self { @@ -436,14 +442,14 @@ mod tests { QuadraticForm::from_a_b_discriminant(BigInt::from(11), BigInt::from(49), &discriminant); assert_eq!(quadratic_form.c, BigInt::from(55)); - quadratic_form = quadratic_form.normalize(); + quadratic_form.normalize(); // Test vector from https://github.com/Chia-Network/vdf-competition/blob/main/classgroups.pdf assert_eq!(quadratic_form.a, BigInt::from(11)); assert_eq!(quadratic_form.b, BigInt::from(5)); assert_eq!(quadratic_form.c, BigInt::from(1)); - quadratic_form = quadratic_form.reduce(); + quadratic_form.reduce(); // Test vector from https://github.com/Chia-Network/vdf-competition/blob/main/classgroups.pdf assert_eq!(quadratic_form.a, BigInt::from(1)); diff --git a/fastcrypto-vdf/src/extended_gcd.rs b/fastcrypto-vdf/src/extended_gcd.rs index 9ea3480177..d10a2f545f 100644 --- a/fastcrypto-vdf/src/extended_gcd.rs +++ b/fastcrypto-vdf/src/extended_gcd.rs @@ -50,13 +50,11 @@ pub fn extended_euclidean_algorithm(a: &BigInt, b: &BigInt) -> EuclideanAlgorith r.1 = r.0; r.0 = r_prime; - let f = |mut x: (BigInt, BigInt)| { - mem::swap(&mut x.0, &mut x.1); - x.0 -= &q * &x.1; - x - }; - s = f(s); - t = f(t); + mem::swap(&mut s.0, &mut s.1); + s.0 -= &q * &s.1; + + mem::swap(&mut t.0, &mut t.1); + t.0 -= &q * &t.1; } // The last coefficients are equal to +/- a / gcd(a,b) and b / gcd(a,b) respectively. diff --git a/fastcrypto-vdf/src/hash_prime.rs b/fastcrypto-vdf/src/hash_prime.rs index ee69140658..4994dfb069 100644 --- a/fastcrypto-vdf/src/hash_prime.rs +++ b/fastcrypto-vdf/src/hash_prime.rs @@ -4,12 +4,41 @@ //! This module contains an implementation of a hash-to-prime function identical to the HashPrime //! function from [chiavdf](https://github.com/Chia-Network/chiavdf/blob/bcc36af3a8de4d2fcafa571602040a4ebd4bdd56/src/proof_common.h#L14-L43). -use fastcrypto::error::FastCryptoError::InvalidInput; -use fastcrypto::error::FastCryptoResult; +use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::hash::{HashFunction, Sha256}; -use num_bigint::{BigInt, BigUint}; +use num_bigint::BigUint; +use num_prime::nt_funcs::is_prime; use std::cmp::min; +struct HashPrimeIterator { + seed: Vec, + length_in_bytes: usize, + bitmask: Vec, +} + +impl Iterator for HashPrimeIterator { + type Item = BigUint; + + fn next(&mut self) -> Option { + let mut blob = vec![]; + while blob.len() < self.length_in_bytes { + for i in (0..self.seed.len()).rev() { + self.seed[i] = self.seed[i].wrapping_add(1); + if self.seed[i] != 0 { + break; + } + } + let hash = Sha256::digest(&self.seed).digest; + blob.extend_from_slice(&hash[..min(hash.len(), self.length_in_bytes - blob.len())]); + } + let mut x = BigUint::from_bytes_be(&blob); + for b in &self.bitmask { + x.set_bit(*b as u64, true); + } + Some(x) + } +} + /// Implementation of a probabilistic primality test. pub trait PrimalityCheck { /// Return true if `x` is probably a prime. If `false` is returned, `x` is guaranteed to be composite. @@ -22,52 +51,59 @@ pub trait PrimalityCheck { /// (for b in bitmask) { x |= (1 << b) }. /// Then return x if it is a pseudo-prime, otherwise repeat. /// -/// The length must be a multiple of 8, otherwise `FastCryptoError::InvalidInput` is returned. +/// See also [hash_prime_with_index]. pub fn hash_prime( seed: &[u8], - length: usize, + length_in_bytes: usize, bitmask: &[usize], -) -> FastCryptoResult { - if length % 8 != 0 { - return Err(InvalidInput); - } - - let mut sprout: Vec = vec![]; - sprout.extend_from_slice(seed); +) -> BigUint { + hash_prime_with_index::

(seed, length_in_bytes, bitmask).1 +} - loop { - let mut blob = vec![]; - while blob.len() * 8 < length { - for i in (0..sprout.len()).rev() { - sprout[i] = sprout[i].wrapping_add(1); - if sprout[i] != 0 { - break; - } - } - let hash = Sha256::digest(&sprout).digest; - blob.extend_from_slice(&hash[..min(hash.len(), length / 8 - blob.len())]); - } - let mut x = BigUint::from_bytes_be(&blob); - for b in bitmask { - x.set_bit(*b as u64, true); - } +/// Generates a random pseudo-prime using the hash and check method: +/// Randomly chooses x with bit-length `length`, then applies a mask +/// (for b in bitmask) { x |= (1 << b) }. +/// Then return x if it is a pseudo-prime, otherwise repeat. +/// +/// This method returns both the prime and the index of the prime in the +/// iterator. +/// +/// See also [hash_prime]. +pub fn hash_prime_with_index( + seed: &[u8], + length_in_bytes: usize, + bitmask: &[usize], +) -> (usize, BigUint) { + let iterator = HashPrimeIterator { + seed: seed.to_vec(), + length_in_bytes, + bitmask: bitmask.to_vec(), + }; + iterator.enumerate().find(|(_, x)| P::is_prime(x)).unwrap() +} - // The implementations of the primality test used below might be slightly different from the - // one used by chiavdf, but since the risk of a false positive is very small (4^{-100}) this - // is not an issue. - if P::is_prime(&x) { - return Ok(x.into()); - } +/// Verify that the given prime is a prime and has the given index in the hash prime iterator. +pub fn verify_prime( + seed: &[u8], + length_in_bytes: usize, + bitmask: &[usize], + index: usize, +) -> FastCryptoResult<()> { + let mut iterator = HashPrimeIterator { + seed: seed.to_vec(), + length_in_bytes, + bitmask: bitmask.to_vec(), + }; + // Check that the index points to a prime + if P::is_prime(&iterator.nth(index).expect("Iterator is infinite")) { + return Ok(()); } + Err(FastCryptoError::InvalidProof) } /// Implementation of [hash_prime] using the primality test from `num_prime::nt_funcs::is_prime`. -pub fn hash_prime_default( - seed: &[u8], - length: usize, - bitmask: &[usize], -) -> FastCryptoResult { - hash_prime::(seed, length, bitmask) +pub fn hash_prime_default(seed: &[u8], length_in_bytes: usize, bitmask: &[usize]) -> BigUint { + hash_prime::(seed, length_in_bytes, bitmask) } /// Implementation of the [PrimalityCheck] trait using the primality test from `num_prime::nt_funcs::is_prime`. @@ -75,6 +111,80 @@ pub struct DefaultPrimalityCheck {} impl PrimalityCheck for DefaultPrimalityCheck { fn is_prime(x: &BigUint) -> bool { - num_prime::nt_funcs::is_prime(x, None).probably() + is_prime(x, None).probably() + } +} + +#[cfg(test)] +mod tests { + use crate::hash_prime::{ + hash_prime_default, verify_prime, DefaultPrimalityCheck, HashPrimeIterator, PrimalityCheck, + }; + use num_bigint::BigUint; + use num_integer::Integer; + use num_prime::PrimalityTestConfig; + use std::str::FromStr; + + #[test] + fn test_hash_prime() { + let seed = [0u8; 32]; + let length = 64; + let bitmask: [usize; 3] = [0, 1, 8 * length - 1]; + + let prime = hash_prime_default(&seed, length, &bitmask); + + // Prime has right length + assert_eq!((length * 8) as u64, prime.bits()); + + // The last two bits are set (see bitmask) + assert_eq!(BigUint::from(3u64), prime.mod_floor(&BigUint::from(4u64))); + + // The result is a prime, even when checking with a stricter test + assert!( + num_prime::nt_funcs::is_prime(&prime, Some(PrimalityTestConfig::strict())).probably() + ); + + // Regression test + assert_eq!(prime, BigUint::from_str("7904272817142338150419757415334055106926417574777773392214522399425467199262039794276651240832053626391864792937889238336287002167559810128294881253078163").unwrap()); + } + + #[test] + fn test_verify_hash_prime() { + let seed = [0u8; 32]; + let length_in_bytes = 64; + let bitmask: [usize; 3] = [0, 1, 8 * length_in_bytes - 1]; + + let iterator = HashPrimeIterator { + seed: seed.to_vec(), + length_in_bytes, + bitmask: bitmask.to_vec(), + } + .enumerate(); + + let mut candidates = vec![]; + + for (i, x) in iterator.take(1000) { + if DefaultPrimalityCheck::is_prime(&x) { + candidates.push((i, x)); + } + } + + for candidate in &candidates { + assert!(verify_prime::( + &seed, + length_in_bytes, + &bitmask, + candidate.0 + ) + .is_ok()); + } + + assert!(verify_prime::( + &seed, + length_in_bytes, + &bitmask, + &candidates[0].0 + 1, + ) + .is_err()); } } diff --git a/fastcrypto-vdf/src/lib.rs b/fastcrypto-vdf/src/lib.rs index e542ea49b6..cf61d0ab94 100644 --- a/fastcrypto-vdf/src/lib.rs +++ b/fastcrypto-vdf/src/lib.rs @@ -2,7 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use fastcrypto::error::FastCryptoResult; -use std::ops::{Add, Neg}; +use fastcrypto::groups::Doubling; +use std::ops::{Add, Mul, Neg}; #[cfg(any(test, feature = "experimental"))] pub mod class_group; @@ -25,7 +26,15 @@ pub trait Parameter: Eq + Sized + ToBytes { /// Trait implemented by elements of an additive group where the group is parameterized, for example /// by the modulus in case of the group being Z mod N or the discriminant in case of class groups. pub trait ParameterizedGroupElement: - Sized + Clone + for<'a> Add<&'a Self, Output = Self> + Add + Neg + Eq + ToBytes + Sized + + Clone + + for<'a> Add<&'a Self, Output = Self> + + Add + + for<'a> Mul<&'a Self::ScalarType, Output = Self> + + Neg + + Eq + + ToBytes + + Doubling { /// The type of the parameter which uniquely defines this group. type ParameterType: Parameter; @@ -36,9 +45,6 @@ pub trait ParameterizedGroupElement: /// Return an instance of the identity element in this group. fn zero(parameters: &Self::ParameterType) -> Self; - /// Compute 2 * Self. - fn double(self) -> Self; - /// Compute scale * self. fn mul(&self, scale: &Self::ScalarType) -> Self; diff --git a/fastcrypto-vdf/src/vdf/wesolowski.rs b/fastcrypto-vdf/src/vdf/wesolowski.rs index 8f3cbc42e5..e89fae60f2 100644 --- a/fastcrypto-vdf/src/vdf/wesolowski.rs +++ b/fastcrypto-vdf/src/vdf/wesolowski.rs @@ -7,20 +7,25 @@ use crate::vdf::VDF; use crate::{hash_prime, Parameter, ParameterizedGroupElement, ToBytes, UnknownOrderGroupElement}; use fastcrypto::error::FastCryptoError::{InvalidInput, InvalidProof}; use fastcrypto::error::FastCryptoResult; -use num_bigint::BigInt; +use fastcrypto::groups::multiplier::windowed::WindowedScalarMultiplier; +use fastcrypto::groups::multiplier::ScalarMultiplier; +use num_bigint::{BigInt, ToBigInt}; use num_integer::Integer; use std::marker::PhantomData; use std::ops::Neg; -/// An implementation of the Wesolowski VDF construction (https://eprint.iacr.org/2018/623) over a +/// An implementation of Wesolowski's VDF construction (https://eprint.iacr.org/2018/623) over a /// group of unknown order. -pub struct WesolowskiVDF { +pub struct WesolowskisVDF> +{ group_parameter: G::ParameterType, iterations: u64, _fiat_shamir: PhantomData, } -impl WesolowskiVDF { +impl> + WesolowskisVDF +{ /// Create a new VDF using the group defined by the given group parameter. Evaluating this VDF /// will require computing `2^iterations * input` which requires `iterations` group operations. pub fn new(group_parameter: G::ParameterType, iterations: u64) -> Self { @@ -35,7 +40,7 @@ impl WesolowskiVDF + UnknownOrderGroupElement, F: FiatShamir, - > VDF for WesolowskiVDF + > VDF for WesolowskisVDF { type InputType = G; type OutputType = G; @@ -72,7 +77,6 @@ impl< } let challenge = F::compute_challenge(self, input, output); - let f1 = proof.mul(&challenge); let r = BigInt::modpow(&BigInt::from(2), &BigInt::from(self.iterations), &challenge); @@ -85,16 +89,79 @@ impl< } } +/// A faster method of verification which uses fast multi-scalar multiplication. The scalar size +/// for the scalar multiplier `M` must be larger enough to hold the challenge in the Fiat-Shamir +/// construction `F`. +pub struct FastVerifier< + G: ParameterizedGroupElement + UnknownOrderGroupElement, + F: FiatShamir, + M: ScalarMultiplier, +> { + vdf: WesolowskisVDF, + input: G, + multiplier: M, +} + +impl< + G: ParameterizedGroupElement + UnknownOrderGroupElement, + F: FiatShamir, + M: ScalarMultiplier, + > FastVerifier +{ + /// Create a new FastVerifier for the given VDF instance. + pub fn new(vdf: WesolowskisVDF, input: G) -> Self { + let multiplier = M::new(input.clone(), G::zero(&vdf.group_parameter)); + Self { + vdf, + input, + multiplier, + } + } + + /// Verify the output and proof from a VDF using the input given in [new]. + pub fn verify(&self, output: &G, proof: &G) -> FastCryptoResult<()> { + if !self.input.same_group(output) || !self.input.same_group(proof) { + return Err(InvalidInput); + } + + let challenge = F::compute_challenge(&self.vdf, &self.input, output); + let r = BigInt::modpow( + &BigInt::from(2), + &BigInt::from(self.vdf.iterations), + &challenge, + ); + let actual = self.multiplier.two_scalar_mul(&r, proof, &challenge); + if actual != *output { + return Err(InvalidProof); + } + Ok(()) + } +} + /// Implementation of Wesolowski's VDF construction over a group of unknown order using a strong /// Fiat-Shamir implementation. pub type StrongVDF = - WesolowskiVDF>; + WesolowskisVDF>; + +pub type StrongVDFVerifier = FastVerifier< + G, + StrongFiatShamir, + WindowedScalarMultiplier, +>; /// Implementation of Wesolowski's VDF construction over a group of unknown order using the Fiat-Shamir /// construction from chiavdf (https://github.com/Chia-Network/chiavdf). -pub type WeakVDF = WesolowskiVDF>; +pub type WeakVDF = WesolowskisVDF>; + +pub type WeakVDFVerifier = FastVerifier< + G, + WeakFiatShamir, + WindowedScalarMultiplier, +>; -impl WesolowskiVDF { +impl> + WesolowskisVDF +{ /// Create a new VDF over an group of unknown where the discriminant has a given size and /// is generated based on a seed. The `iterations` parameters specifies the number of group /// operations the evaluation function requires. @@ -106,14 +173,14 @@ impl WesolowskiVDF { +pub trait FiatShamir: Sized { /// Compute the prime modulus used in proving and verification. This is a Fiat-Shamir construction /// to make the Wesolowski VDF non-interactive. - fn compute_challenge(vdf: &WesolowskiVDF, input: &G, output: &G) -> G::ScalarType; + fn compute_challenge(vdf: &WesolowskisVDF, input: &G, output: &G) -> G::ScalarType; } -/// Size of the challenge used in proving and verification. -pub const CHALLENGE_SIZE: usize = 264; +/// Default size in bytes of the Fiat-Shamir challenge used in proving and verification (same as chiavdf). +pub const CHALLENGE_SIZE: usize = 33; /// Implementation of the Fiat-Shamir challenge generation compatible with chiavdf. /// Note that this implementation is weak, meaning that not all public parameters are used in the @@ -130,12 +197,11 @@ impl< P: PrimalityCheck, > FiatShamir for WeakFiatShamir { - fn compute_challenge(_vdf: &WesolowskiVDF, input: &G, output: &G) -> BigInt { + fn compute_challenge(_vdf: &WesolowskisVDF, input: &G, output: &G) -> BigInt { let mut seed = vec![]; seed.extend_from_slice(&input.to_bytes()); seed.extend_from_slice(&output.to_bytes()); - hash_prime::hash_prime::

(&seed, CHALLENGE_SIZE, &[CHALLENGE_SIZE - 1]) - .expect("The length should be a multiple of 8") + hash_prime::hash_prime::

(&seed, CHALLENGE_SIZE, &[8 * CHALLENGE_SIZE - 1]).into() } } @@ -153,7 +219,7 @@ impl< P: PrimalityCheck, > FiatShamir for StrongFiatShamir { - fn compute_challenge(vdf: &WesolowskiVDF, input: &G, output: &G) -> BigInt { + fn compute_challenge(vdf: &WesolowskisVDF, input: &G, output: &G) -> BigInt { let mut seed = vec![]; let input_bytes = input.to_bytes(); @@ -167,16 +233,22 @@ impl< seed.extend_from_slice(&(vdf.iterations).to_be_bytes()); seed.extend_from_slice(&vdf.group_parameter.to_bytes()); - hash_prime::hash_prime::

(&seed, CHALLENGE_SIZE, &[CHALLENGE_SIZE - 1]) - .expect("The length should be a multiple of 8") + hash_prime::hash_prime::

(&seed, CHALLENGE_SIZE, &[0, 8 * CHALLENGE_SIZE - 1]).into() } } impl Parameter for Discriminant { /// Compute a valid discriminant (aka a negative prime equal to 3 mod 4) based on the given seed. + /// The size_in_bits must be divisible by 8. fn from_seed(seed: &[u8], size_in_bits: usize) -> FastCryptoResult { + if size_in_bits % 8 != 0 { + return Err(InvalidInput); + } Self::try_from( - hash_prime::hash_prime_default(seed, size_in_bits, &[0, 1, 2, size_in_bits - 1])?.neg(), + hash_prime::hash_prime_default(seed, size_in_bits / 8, &[0, 1, 2, size_in_bits - 1]) + .to_bigint() + .expect("Never fails") + .neg(), ) } } @@ -184,7 +256,7 @@ impl Parameter for Discriminant { #[cfg(test)] mod tests { use crate::class_group::{Discriminant, QuadraticForm}; - use crate::vdf::wesolowski::{StrongVDF, WeakVDF}; + use crate::vdf::wesolowski::{StrongVDF, StrongVDFVerifier, WeakVDF, WeakVDFVerifier}; use crate::vdf::VDF; use crate::{Parameter, ParameterizedGroupElement}; use num_bigint::BigInt; @@ -195,17 +267,20 @@ mod tests { let iterations = 1000u64; let discriminant = Discriminant::from_seed(&challenge, 1024).unwrap(); - let g = QuadraticForm::generator(&discriminant); + let input = QuadraticForm::generator(&discriminant); let vdf = StrongVDF::::new(discriminant, iterations); - let (output, proof) = vdf.evaluate(&g).unwrap(); - assert!(vdf.verify(&g, &output, &proof).is_ok()); + let (output, proof) = vdf.evaluate(&input).unwrap(); + assert!(vdf.verify(&input, &output, &proof).is_ok()); // A modified output or proof fails to verify let modified_output = output.mul(&BigInt::from(2)); let modified_proof = proof.mul(&BigInt::from(2)); - assert!(vdf.verify(&g, &modified_output, &proof).is_err()); - assert!(vdf.verify(&g, &output, &modified_proof).is_err()); + assert!(vdf.verify(&input, &modified_output, &proof).is_err()); + assert!(vdf.verify(&input, &output, &modified_proof).is_err()); + + let fast_verifier = StrongVDFVerifier::new(vdf, input); + assert!(fast_verifier.verify(&output, &proof).is_ok()); } #[test] @@ -229,6 +304,9 @@ mod tests { let vdf = WeakVDF::::from_seed(&challenge, 1024, iterations).unwrap(); assert!(vdf.verify(&input, &result, &proof).is_ok()); + + let fast_verifier = WeakVDFVerifier::new(vdf, input); + assert!(fast_verifier.verify(&result, &proof).is_ok()); } #[test] @@ -236,7 +314,7 @@ mod tests { let challenge = hex::decode("99c9e5e3a4449a4b4e15").unwrap(); let iterations = 1000u64; let result_hex = "030039c78c39cff6c29052bfc1453616ec7a47251509b9dbc33d1036bebd4d12e6711a51deb327120310f96be04c90fd4c3b1dab9617c3133132b827abe7bb2348707da8164b964e1b95cd6a8eaf36ffb80bab1f750410e793daec8228b222bd00370100"; - let proof_hex = "0100eb1e1b0b58bca2ceca30344321d77e6c6f995e7c9db878d63aa71348db3577634e309be81ed71cba185a0d3c6bba2945c7002cc757c29a612afec8bf95581c008d2fe1c77e5171f8b85706e5f823cd233d847117f25b53d45cb30fb036b5b0030100"; + let proof_hex = "03001a967ce490545d31cbd13ef87edbeac3c90f61f3fa0efc1dd6aab5a62007593a8bfba9aee82966a8f056c38334eeeafaa1afbb5c98eb9c97bf42bfcce90b0b5c9d32f743b735e1179f2ef169301906fe4acaa95755171b223af2c04038a983630100"; let discriminant = Discriminant::from_seed(&challenge, 1024).unwrap(); @@ -249,7 +327,9 @@ mod tests { let input = QuadraticForm::generator(&discriminant); let vdf = StrongVDF::::new(discriminant, iterations); - assert!(vdf.verify(&input, &result, &proof).is_ok()); + + let fast_verifier = StrongVDFVerifier::new(vdf, input); + assert!(fast_verifier.verify(&result, &proof).is_ok()); } } diff --git a/fastcrypto/Cargo.toml b/fastcrypto/Cargo.toml index b1fd81b190..9db6cfbd13 100644 --- a/fastcrypto/Cargo.toml +++ b/fastcrypto/Cargo.toml @@ -61,6 +61,7 @@ ark-serialize = "0.4.1" lazy_static = "1.4.0" fastcrypto-derive = { path = "../fastcrypto-derive", version = "0.1.3" } serde_json = "1.0.93" +num-bigint = "0.4.4" [[bench]] name = "crypto" diff --git a/fastcrypto/benches/groups.rs b/fastcrypto/benches/groups.rs index dc244d654d..d95682ec0e 100644 --- a/fastcrypto/benches/groups.rs +++ b/fastcrypto/benches/groups.rs @@ -40,14 +40,18 @@ mod group_benches { c.bench_function(&(name.to_string()), move |b| b.iter(|| x * y)); } - fn scale_single_precomputed, M: Measurement>( + fn scale_single_precomputed< + G: GroupElement, + Mul: ScalarMultiplier, + M: Measurement, + >( name: &str, c: &mut BenchmarkGroup, ) { let x = G::generator() * G::ScalarType::rand(&mut thread_rng()); let y = G::ScalarType::rand(&mut thread_rng()); - let multiplier = Mul::new(x); + let multiplier = Mul::new(x, G::zero()); c.bench_function(&(name.to_string()), move |b| b.iter(|| multiplier.mul(&y))); } @@ -61,32 +65,36 @@ mod group_benches { scale_single_precomputed::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Fixed window (16)", &mut group); scale_single_precomputed::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Fixed window (32)", &mut group); scale_single_precomputed::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Fixed window (64)", &mut group); scale_single_precomputed::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Fixed window (128)", &mut group); scale_single_precomputed::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Fixed window (256)", &mut group); } - fn double_scale_single, M: Measurement>( + fn double_scale_single< + G: GroupElement, + Mul: ScalarMultiplier, + M: Measurement, + >( name: &str, c: &mut BenchmarkGroup, ) { @@ -95,7 +103,7 @@ mod group_benches { let g2 = G::generator() * G::ScalarType::rand(&mut thread_rng()); let s2 = G::ScalarType::rand(&mut thread_rng()); - let multiplier = Mul::new(g1); + let multiplier = Mul::new(g1, G::zero()); c.bench_function(&(name.to_string()), move |b| { b.iter(|| multiplier.two_scalar_mul(&s1, &g2, &s2)) }); @@ -106,27 +114,27 @@ mod group_benches { double_scale_single::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Straus (16)", &mut group); double_scale_single::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Straus (32)", &mut group); double_scale_single::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Straus (64)", &mut group); double_scale_single::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Straus (128)", &mut group); double_scale_single::< ProjectivePoint, - WindowedScalarMultiplier, + WindowedScalarMultiplier, _, >("Secp256r1 Straus (256)", &mut group); double_scale_single::, _>( @@ -171,14 +179,23 @@ mod group_benches { /// simply calling the GroupElement implementation. Only used for benchmarking. struct DefaultMultiplier(G); - impl ScalarMultiplier for DefaultMultiplier { - fn new(base_element: G) -> Self { + impl ScalarMultiplier for DefaultMultiplier { + fn new(base_element: G, _zero: G) -> Self { Self(base_element) } fn mul(&self, scalar: &G::ScalarType) -> G { self.0 * scalar } + + fn two_scalar_mul( + &self, + base_scalar: &G::ScalarType, + other_element: &G, + other_scalar: &G::ScalarType, + ) -> G { + self.0 * base_scalar + *other_element * other_scalar + } } criterion_group! { diff --git a/fastcrypto/src/groups/mod.rs b/fastcrypto/src/groups/mod.rs index bd254f955f..b5446059e2 100644 --- a/fastcrypto/src/groups/mod.rs +++ b/fastcrypto/src/groups/mod.rs @@ -41,11 +41,6 @@ pub trait GroupElement: /// Return an instance of the generator for this group. fn generator() -> Self; - - /// Compute 2 * Self. May be overwritten by implementations that have a fast doubling operation. - fn double(&self) -> Self { - *self + self - } } // TODO: Move Serialize + DeserializeOwned to GroupElement. @@ -58,6 +53,12 @@ pub trait Scalar: fn inverse(&self) -> FastCryptoResult; } +/// Trait for group elements that has a fast doubling operation. +pub trait Doubling { + /// Compute 2 * Self = Self + Self. + fn double(&self) -> Self; +} + pub trait Pairing: GroupElement { type Other: GroupElement; type Output; diff --git a/fastcrypto/src/groups/multiplier/bgmw.rs b/fastcrypto/src/groups/multiplier/bgmw.rs index 2229401744..7e1a672b99 100644 --- a/fastcrypto/src/groups/multiplier/bgmw.rs +++ b/fastcrypto/src/groups/multiplier/bgmw.rs @@ -3,7 +3,7 @@ use crate::groups::multiplier::integer_utils::{compute_base_2w_expansion, div_ceil}; use crate::groups::multiplier::ScalarMultiplier; -use crate::groups::GroupElement; +use crate::groups::{Doubling, GroupElement}; use crate::serde_helpers::ToFromByteArray; /// Performs scalar multiplication using a windowed method with a larger pre-computation table than @@ -45,14 +45,14 @@ impl< } impl< - G: GroupElement, + G: GroupElement + Doubling, S: GroupElement + ToFromByteArray, const WIDTH: usize, const HEIGHT: usize, const SCALAR_SIZE: usize, - > ScalarMultiplier for BGMWScalarMultiplier + > ScalarMultiplier for BGMWScalarMultiplier { - fn new(base_element: G) -> Self { + fn new(base_element: G, zero: G) -> Self { // Verify parameters let lower_limit = div_ceil(SCALAR_SIZE * 8, Self::WINDOW_WIDTH); if HEIGHT < lower_limit { @@ -60,7 +60,7 @@ impl< } // Store cache[i][j] = 2^{i w} * j * base_element - let mut cache = [[G::zero(); WIDTH]; HEIGHT]; + let mut cache = [[zero; WIDTH]; HEIGHT]; // Compute cache[0][j] = j * base_element. for j in 1..WIDTH { @@ -83,8 +83,7 @@ impl< // Scalar as bytes in little-endian representation. let scalar_bytes = scalar.to_byte_array(); - let base_2w_expansion = - compute_base_2w_expansion::(&scalar_bytes, Self::WINDOW_WIDTH); + let base_2w_expansion = compute_base_2w_expansion(&scalar_bytes, Self::WINDOW_WIDTH); let mut result = self.get_precomputed_multiple(0, base_2w_expansion[0]); for (i, digit) in base_2w_expansion.iter().enumerate().skip(1) { @@ -92,6 +91,10 @@ impl< } result } + + fn two_scalar_mul(&self, base_scalar: &S, other_element: &G, other_scalar: &S) -> G { + self.cache[0][1] * base_scalar + *other_element * *other_scalar + } } #[cfg(test)] @@ -106,6 +109,7 @@ mod tests { fn test_scalar_multiplication_ristretto() { let multiplier = BGMWScalarMultiplier::::new( RistrettoPoint::generator(), + RistrettoPoint::zero(), ); let scalars = [ @@ -148,18 +152,21 @@ mod tests { let multiplier = BGMWScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); let multiplier = BGMWScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); let multiplier = BGMWScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); @@ -169,6 +176,7 @@ mod tests { assert!(std::panic::catch_unwind(|| { BGMWScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ) }) .is_err()); diff --git a/fastcrypto/src/groups/multiplier/integer_utils.rs b/fastcrypto/src/groups/multiplier/integer_utils.rs index 71504d93f0..5e0c95aadf 100644 --- a/fastcrypto/src/groups/multiplier/integer_utils.rs +++ b/fastcrypto/src/groups/multiplier/integer_utils.rs @@ -1,20 +1,22 @@ // Copyright (c) 2022, Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 +use crate::groups::multiplier::ToLittleEndianBytes; +use num_bigint::BigInt; + /// Given a binary representation of a number in little-endian format, return the digits of its base /// `2^bits_per_digit` expansion. -pub fn compute_base_2w_expansion( - bytes: &[u8; N], - bits_per_digit: usize, -) -> Vec { +pub fn compute_base_2w_expansion(bytes: &[u8], bits_per_digit: usize) -> Vec { assert!(0 < bits_per_digit && bits_per_digit <= usize::BITS as usize); // The base 2^window_size expansions digits in little-endian representation. let mut digits = Vec::new(); + let n = bytes.len(); + // Compute the number of digits needed to represent the numbed in base 2^w. This is equal to // ceil(8*N / window_size), and we compute like this because div_ceil is unstable as of rustc 1.69.0. - let digits_count = div_ceil(8 * N, bits_per_digit); + let digits_count = div_ceil(8 * n, bits_per_digit); for i in 0..digits_count { digits.push(get_bits_from_bytes( @@ -51,10 +53,10 @@ pub(crate) fn div_ceil(numerator: usize, denominator: usize) -> usize { /// Get the integer represented by a given range of bits of a an integer represented by a little-endian /// byte array from start to end (exclusive). The `end` argument may be arbitrarily large, but if it -/// is larger than 8*N, the remaining bits of the byte array will be assumed to be zero. +/// is larger than 8*bytes.len(), the remaining bits of the byte array will be assumed to be zero. #[inline] -pub fn get_bits_from_bytes(bytes: &[u8; N], start: usize, end: usize) -> usize { - assert!(start <= end && start < 8 * N); +pub fn get_bits_from_bytes(bytes: &[u8], start: usize, end: usize) -> usize { + assert!(start <= end && start < 8 * bytes.len()); let mut result: usize = 0; let mut bits_added = 0; @@ -62,7 +64,7 @@ pub fn get_bits_from_bytes(bytes: &[u8; N], start: usize, end: u let mut current_bit = start % 8; let mut current_byte = start / 8; - while bits_added < end - start && current_byte < N { + while bits_added < end - start && current_byte < bytes.len() { let remaining_bits = end - start - bits_added; let (bits_to_read, next_byte, next_bit) = if remaining_bits < 8 - current_bit { // There are enough bits left in the current byte @@ -90,8 +92,8 @@ pub fn get_bits_from_bytes(bytes: &[u8; N], start: usize, end: u /// Return true iff the bit at the given index is set. #[inline] -pub fn test_bit(bytes: &[u8; N], index: usize) -> bool { - assert!(index < 8 * N); +pub fn test_bit(bytes: &[u8], index: usize) -> bool { + assert!(index < 8 * bytes.len()); let byte = index >> 3; let shifted = bytes[byte] >> (index & 7); shifted & 1 != 0 @@ -148,7 +150,7 @@ mod tests { let bytes = value.to_le_bytes(); // Is w = 8, the base 2^w expansion should be equal to the le bytes. - let expansion = compute_base_2w_expansion::<16>(&bytes, 8); + let expansion = compute_base_2w_expansion(&bytes, 8); assert_eq!( bytes.to_vec(), expansion.iter().map(|x| *x as u8).collect::>() @@ -156,7 +158,7 @@ mod tests { // Verify that the expansion is correct for w = 1, ..., 64 for window_size in 1..=64 { - let expansion = compute_base_2w_expansion::<16>(&bytes, window_size); + let expansion = compute_base_2w_expansion(&bytes, window_size); let mut sum = 0u128; for (i, value) in expansion.iter().enumerate() { sum += (1 << (window_size * i)) * *value as u128; @@ -189,3 +191,11 @@ mod tests { assert!(is_power_of_2(4096)); } } + +// We implementation `ToLittleEndianByteArray` for BigInt in case it needs to be used as scalar for +// multi-scalar multiplication. +impl ToLittleEndianBytes for BigInt { + fn to_le_bytes(&self) -> Vec { + self.to_bytes_le().1 + } +} diff --git a/fastcrypto/src/groups/multiplier/mod.rs b/fastcrypto/src/groups/multiplier/mod.rs index 7a2295c824..be4afcfedd 100644 --- a/fastcrypto/src/groups/multiplier/mod.rs +++ b/fastcrypto/src/groups/multiplier/mod.rs @@ -4,29 +4,24 @@ //! This module contains implementations of optimised scalar multiplication algorithms where the //! group element is fixed and certain multiples of this may be pre-computed. -use crate::groups::GroupElement; - #[cfg(feature = "experimental")] pub mod bgmw; mod integer_utils; pub mod windowed; /// Trait for scalar multiplication for a fixed group element, e.g. by using precomputed values. -pub trait ScalarMultiplier { +pub trait ScalarMultiplier { /// Create a new scalar multiplier with the given base element. - fn new(base_element: G) -> Self; + fn new(base_element: G, zero: G) -> Self; /// Compute `self.base_element * scalar`. - fn mul(&self, scalar: &G::ScalarType) -> G; + fn mul(&self, scalar: &S) -> G; /// Compute `self.base_element * base_scalar + other_element * other_scalar`. - fn two_scalar_mul( - &self, - base_scalar: &G::ScalarType, - other_element: &G, - other_scalar: &G::ScalarType, - ) -> G { - // The default implementation. May be overwritten by implementations that allow optimised double multiplication. - self.mul(base_scalar) + *other_element * other_scalar - } + fn two_scalar_mul(&self, base_scalar: &S, other_element: &G, other_scalar: &S) -> G; +} + +pub trait ToLittleEndianBytes { + /// Serialize scalar into a byte vector in little-endian format. + fn to_le_bytes(&self) -> Vec; } diff --git a/fastcrypto/src/groups/multiplier/windowed.rs b/fastcrypto/src/groups/multiplier/windowed.rs index 006f577b71..69285f8a60 100644 --- a/fastcrypto/src/groups/multiplier/windowed.rs +++ b/fastcrypto/src/groups/multiplier/windowed.rs @@ -2,101 +2,93 @@ // SPDX-License-Identifier: Apache-2.0 use std::collections::HashMap; +use std::fmt::Debug; use std::iter::successors; +use std::marker::PhantomData; +use std::ops::{Add, Mul}; use crate::groups::multiplier::integer_utils::{get_bits_from_bytes, is_power_of_2, test_bit}; -use crate::groups::multiplier::{integer_utils, ScalarMultiplier}; -use crate::groups::GroupElement; -use crate::serde_helpers::ToFromByteArray; +use crate::groups::multiplier::{integer_utils, ScalarMultiplier, ToLittleEndianBytes}; +use crate::groups::Doubling; /// This scalar multiplier uses pre-computation with the windowed method. This multiplier is particularly /// fast for double multiplications, where a sliding window method is used, but this implies that the /// `double_mul`, is NOT constant time. However, the single multiplication method `mul` is constant /// time if the group operations for `G` are constant time. /// -/// The `CACHE_SIZE` should be a power of two > 1. The `SCALAR_SIZE` is the number of bytes in the byte -/// representation of the scalar type `S`, and we assume that the `S::to_byte_array` method returns -/// the scalar in little-endian format. +/// The `CACHE_SIZE` should be a power of two > 1. /// /// The `SLIDING_WINDOW_WIDTH` is the number of bits in the sliding window of the elements not already -/// with precomputed multiples. This should be approximately log2(sqrt(SCALAR_SIZE_IN_BITS)) + 1 for +/// with precomputed multiples. This should be approximately log2(sqrt(scalar size in bits)) + 1 for /// optimal performance. pub struct WindowedScalarMultiplier< - G: GroupElement, - S: GroupElement + ToFromByteArray, + G, + S, const CACHE_SIZE: usize, - const SCALAR_SIZE: usize, const SLIDING_WINDOW_WIDTH: usize, > { /// Precomputed multiples of the base element from 0 up to CACHE_SIZE - 1 = 2^WINDOW_WIDTH - 1. cache: [G; CACHE_SIZE], + _scalar: PhantomData, } -impl< - G: GroupElement, - S: GroupElement + ToFromByteArray, - const CACHE_SIZE: usize, - const SCALAR_SIZE: usize, - const SLIDING_WINDOW_WIDTH: usize, - > WindowedScalarMultiplier +impl + WindowedScalarMultiplier { /// The number of bits in the window. This is equal to the floor of the log2 of the cache size. const WINDOW_WIDTH: usize = integer_utils::log2(CACHE_SIZE); } impl< - G: GroupElement, - S: GroupElement + ToFromByteArray, + G: for<'a> Add<&'a G, Output = G> + for<'a> Mul<&'a S, Output = G> + Doubling + Clone + Debug, + S: ToLittleEndianBytes + Clone + Debug, const CACHE_SIZE: usize, - const SCALAR_SIZE: usize, const SLIDING_WINDOW_WIDTH: usize, - > ScalarMultiplier - for WindowedScalarMultiplier + > ScalarMultiplier for WindowedScalarMultiplier { - fn new(base_element: G) -> Self { + fn new(base_element: G, zero: G) -> Self { if !is_power_of_2(CACHE_SIZE) || CACHE_SIZE <= 1 { panic!("CACHE_SIZE must be a power of two greater than 1"); } - let mut cache = [G::zero(); CACHE_SIZE]; - cache[1] = base_element; + let mut cache = vec![]; + cache.push(zero); + cache.push(base_element.clone()); for i in 2..CACHE_SIZE { - cache[i] = cache[i - 1] + base_element; + cache.push(cache[i - 1].clone() + &base_element); + } + let cache: [G; CACHE_SIZE] = cache.try_into().unwrap(); + Self { + cache, + _scalar: PhantomData, } - Self { cache } } fn mul(&self, scalar: &S) -> G { // Scalar as bytes in little-endian representation. - let scalar_bytes = scalar.to_byte_array(); + let scalar_bytes = scalar.to_le_bytes(); - let base_2w_expansion = integer_utils::compute_base_2w_expansion::( - &scalar_bytes, - Self::WINDOW_WIDTH, - ); + let base_2w_expansion = + integer_utils::compute_base_2w_expansion(&scalar_bytes, Self::WINDOW_WIDTH); // Computer multiplication using the fixed-window method to ensure that it's constant time. - let mut result: G = self.cache[base_2w_expansion[base_2w_expansion.len() - 1]]; + let mut result: G = self.cache[base_2w_expansion[base_2w_expansion.len() - 1]].clone(); for digit in base_2w_expansion.iter().rev().skip(1) { for _ in 1..=Self::WINDOW_WIDTH { result = result.double(); } - result += self.cache[*digit]; + result = result + &self.cache[*digit]; } result } - fn two_scalar_mul( - &self, - base_scalar: &G::ScalarType, - other_element: &G, - other_scalar: &G::ScalarType, - ) -> G { + fn two_scalar_mul(&self, base_scalar: &S, other_element: &G, other_scalar: &S) -> G { // Compute the sum of the two multiples using Straus' algorithm combined with a sliding window algorithm. multi_scalar_mul( - &[*base_scalar, *other_scalar], - &[self.cache[1], *other_element], + &[base_scalar.clone(), other_scalar.clone()], + &[self.cache[1].clone(), other_element.clone()], &HashMap::from([(0, self.cache[CACHE_SIZE / 2..CACHE_SIZE].to_vec())]), SLIDING_WINDOW_WIDTH, + self.cache[0].clone(), ) } } @@ -114,16 +106,20 @@ impl< /// table and may be set to any value >= 1. As rule-of-thumb, this should be set to approximately /// the bit length of the square root of the scalar size for optimal performance. pub fn multi_scalar_mul< - G: GroupElement, - S: GroupElement + ToFromByteArray, - const SCALAR_SIZE: usize, + G: Doubling + for<'a> Add<&'a G, Output = G> + for<'a> Mul<&'a S, Output = G> + Clone + Debug, + S: ToLittleEndianBytes + Clone + Debug, const N: usize, >( - scalars: &[G::ScalarType; N], + scalars: &[S; N], elements: &[G; N], precomputed_multiples: &HashMap>, default_window_width: usize, + zero: G, ) -> G { + if N == 0 { + return zero; + } + let mut window_sizes = [0usize; N]; // Compute missing precomputation tables. @@ -152,8 +148,14 @@ pub fn multi_scalar_mul< // Compute little-endian byte representations of scalars. let scalar_bytes = scalars .iter() - .map(|s| s.to_byte_array()) - .collect::>(); + .map(|s| s.to_le_bytes()) + .collect::>>(); + + let scalar_size = scalar_bytes + .iter() + .map(|b| b.len()) + .max() + .expect("No scalars given."); // We iterate from the top bit and down for all scalars until we reach a set bit. This marks the // beginning of a window, and we continue the iteration. When the iterations exists the window, @@ -165,10 +167,10 @@ pub fn multi_scalar_mul< // We may skip doubling until result is non-zero. let mut is_zero = true; - let mut result = G::zero(); + let mut result = zero; // Iterate through all bits of the scalars from the top. - for bit in (0..SCALAR_SIZE * 8).rev() { + for bit in (0..scalar_size * 8).rev() { if !is_zero { result = result.double(); } @@ -180,9 +182,9 @@ pub fn multi_scalar_mul< // This window is finished. Add the right precomputed value and indicate that we are ready for a new window. result = if is_zero { is_zero = false; - all_precomputed_multiples[i][precomputed_multiple_index[i]] + all_precomputed_multiples[i][precomputed_multiple_index[i]].clone() } else { - result + all_precomputed_multiples[i][precomputed_multiple_index[i]] + result + &all_precomputed_multiples[i][precomputed_multiple_index[i]] }; is_in_window[i] = false; } @@ -201,9 +203,9 @@ pub fn multi_scalar_mul< // There is not enough room left for a window. Continue with regular double-and-add. result = if is_zero { is_zero = false; - elements[i] + elements[i].clone() } else { - result + elements[i] + result + &elements[i] }; } } @@ -213,13 +215,19 @@ pub fn multi_scalar_mul< } /// Compute multiples 2w-1 base_element, (2w-1 + 1) base_element, ..., (2w - 1) base_element. -fn compute_multiples(base_element: &G, window_size: usize) -> Vec { +fn compute_multiples< + S, + G: Doubling + for<'a> Add<&'a G, Output = G> + for<'a> Mul<&'a S, Output = G> + Clone + Debug, +>( + base_element: &G, + window_size: usize, +) -> Vec { assert!(window_size > 0, "Window size must be strictly positive."); - let mut smallest_multiple = *base_element; + let mut smallest_multiple = base_element.clone(); for _ in 1..window_size { smallest_multiple = smallest_multiple.double(); } - successors(Some(smallest_multiple), |g| Some(*g + base_element)) + successors(Some(smallest_multiple), |g| Some(g.clone() + base_element)) .take(1 << (window_size - 1)) .collect::>() } @@ -232,16 +240,24 @@ mod tests { use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar}; use crate::groups::secp256r1::{ProjectivePoint, Scalar}; + use crate::groups::GroupElement; use crate::groups::Scalar as ScalarTrait; + use crate::serde_helpers::ToFromByteArray; use super::*; + impl ToLittleEndianBytes for RistrettoScalar { + fn to_le_bytes(&self) -> Vec { + self.to_byte_array().to_vec() + } + } + #[test] fn test_scalar_multiplication_ristretto() { - let multiplier = - WindowedScalarMultiplier::::new( - RistrettoPoint::generator(), - ); + let multiplier = WindowedScalarMultiplier::::new( + RistrettoPoint::generator(), + RistrettoPoint::zero(), + ); let scalars = [ RistrettoScalar::from(0), @@ -281,32 +297,37 @@ mod tests { for scalar in scalars { let expected = ProjectivePoint::generator() * scalar; - let multiplier = WindowedScalarMultiplier::::new( + let multiplier = WindowedScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); - let multiplier = WindowedScalarMultiplier::::new( + let multiplier = WindowedScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); - let multiplier = WindowedScalarMultiplier::::new( + let multiplier = WindowedScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); - let multiplier = WindowedScalarMultiplier::::new( + let multiplier = WindowedScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); - let multiplier = WindowedScalarMultiplier::::new( + let multiplier = WindowedScalarMultiplier::::new( ProjectivePoint::generator(), + ProjectivePoint::zero(), ); let actual = multiplier.mul(&scalar); assert_eq!(expected, actual); @@ -315,10 +336,10 @@ mod tests { #[test] fn test_double_mul_ristretto() { - let multiplier = - WindowedScalarMultiplier::::new( - RistrettoPoint::generator(), - ); + let multiplier = WindowedScalarMultiplier::::new( + RistrettoPoint::generator(), + RistrettoPoint::zero(), + ); let other_point = RistrettoPoint::generator() * RistrettoScalar::from(3); diff --git a/fastcrypto/src/groups/ristretto255.rs b/fastcrypto/src/groups/ristretto255.rs index 8fb36e6ae3..f2c6776924 100644 --- a/fastcrypto/src/groups/ristretto255.rs +++ b/fastcrypto/src/groups/ristretto255.rs @@ -6,7 +6,7 @@ use crate::error::FastCryptoResult; use crate::groups::{ - FiatShamirChallenge, GroupElement, HashToGroupElement, MultiScalarMul, Scalar, + Doubling, FiatShamirChallenge, GroupElement, HashToGroupElement, MultiScalarMul, Scalar, }; use crate::hash::Sha512; use crate::serde_helpers::ToFromByteArray; @@ -23,7 +23,7 @@ use curve25519_dalek_ng::traits::{Identity, VartimeMultiscalarMul}; use derive_more::{Add, Div, From, Neg, Sub}; use fastcrypto_derive::GroupOpsExtend; use serde::{de, Deserialize}; -use std::ops::{Div, Mul}; +use std::ops::{Add, Div, Mul}; use zeroize::Zeroize; const RISTRETTO_POINT_BYTE_LENGTH: usize = 32; @@ -57,6 +57,12 @@ impl RistrettoPoint { } } +impl Doubling for RistrettoPoint { + fn double(&self) -> Self { + Self(self.0.add(&self.0)) + } +} + impl MultiScalarMul for RistrettoPoint { fn multi_scalar_mul(scalars: &[Self::ScalarType], points: &[Self]) -> FastCryptoResult { if scalars.len() != points.len() { diff --git a/fastcrypto/src/groups/secp256r1.rs b/fastcrypto/src/groups/secp256r1.rs index 87d6742d14..0e0ff43a23 100644 --- a/fastcrypto/src/groups/secp256r1.rs +++ b/fastcrypto/src/groups/secp256r1.rs @@ -5,7 +5,8 @@ //! See "SEC 2: Recommended Elliptic Curve Domain Parameters" for details." use crate::error::{FastCryptoError, FastCryptoResult}; -use crate::groups::{GroupElement, Scalar as ScalarTrait}; +use crate::groups::multiplier::ToLittleEndianBytes; +use crate::groups::{Doubling, GroupElement, Scalar as ScalarTrait}; use crate::serde_helpers::ToFromByteArray; use crate::serialize_deserialize_with_to_from_byte_array; use crate::traits::AllowedRng; @@ -34,7 +35,9 @@ impl GroupElement for ProjectivePoint { fn generator() -> Self { Self(Projective::generator()) } +} +impl Doubling for ProjectivePoint { fn double(&self) -> Self { ProjectivePoint::from(self.0.double()) } @@ -125,4 +128,10 @@ impl ToFromByteArray for Scalar { } } +impl ToLittleEndianBytes for Scalar { + fn to_le_bytes(&self) -> Vec { + self.to_byte_array().to_vec() + } +} + serialize_deserialize_with_to_from_byte_array!(Scalar); diff --git a/fastcrypto/src/secp256r1/mod.rs b/fastcrypto/src/secp256r1/mod.rs index e4fcfcbd07..96909e9ce3 100644 --- a/fastcrypto/src/secp256r1/mod.rs +++ b/fastcrypto/src/secp256r1/mod.rs @@ -46,7 +46,7 @@ use fastcrypto_derive::{SilentDebug, SilentDisplay}; use crate::groups::multiplier::windowed::WindowedScalarMultiplier; use crate::groups::multiplier::ScalarMultiplier; use crate::groups::secp256r1; -use crate::groups::secp256r1::{ProjectivePoint, SCALAR_SIZE_IN_BYTES}; +use crate::groups::secp256r1::ProjectivePoint; use crate::hash::{HashFunction, Sha256}; use crate::secp256r1::conversion::{ affine_pt_p256_to_projective_arkworks, arkworks_fq_to_fr, fr_arkworks_to_p256, @@ -150,15 +150,16 @@ lazy_static! { ProjectivePoint, ::ScalarType, PRECOMPUTED_POINTS, - SCALAR_SIZE_IN_BYTES, SLIDING_WINDOW_WIDTH, > = WindowedScalarMultiplier::< ProjectivePoint, ::ScalarType, PRECOMPUTED_POINTS, - SCALAR_SIZE_IN_BYTES, SLIDING_WINDOW_WIDTH, - >::new(secp256r1::ProjectivePoint::generator()); + >::new( + secp256r1::ProjectivePoint::generator(), + secp256r1::ProjectivePoint::zero() + ); } serialize_deserialize_with_to_from_bytes!(Secp256r1PublicKey, SECP256R1_PUBLIC_KEY_LENGTH); diff --git a/fastcrypto/src/tests/secp256r1_group_tests.rs b/fastcrypto/src/tests/secp256r1_group_tests.rs index 17694e9a0c..8c8a42482a 100644 --- a/fastcrypto/src/tests/secp256r1_group_tests.rs +++ b/fastcrypto/src/tests/secp256r1_group_tests.rs @@ -1,7 +1,7 @@ // Copyright (c) 2022, Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use crate::groups::GroupElement; +use crate::groups::{Doubling, GroupElement}; use crate::groups::secp256r1::{ProjectivePoint, Scalar}; use crate::groups::{secp256r1, Scalar as ScalarTrait};