Skip to content

Commit

Permalink
Refactor: get rid of scalar_size constant
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-lj committed Dec 6, 2023
1 parent 98795c8 commit 999edf9
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 103 deletions.
2 changes: 1 addition & 1 deletion fastcrypto-vdf/benches/vdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn verify_single<M: Measurement>(parameters: VerificationInputs, c: &mut Benchma
let fast_verify: FastVerifier<
QuadraticForm,
StrongFiatShamir<QuadraticForm, CHALLENGE_SIZE, DefaultPrimalityCheck>,
WindowedScalarMultiplier<QuadraticForm, BigInt, 256, CHALLENGE_SIZE, 5>,
WindowedScalarMultiplier<QuadraticForm, BigInt, 256, 5>,
> = FastVerifier::new(vdf, input_copy);
c.bench_function(format!("{} fast", discriminant_size), move |b| {
b.iter(|| fast_verify.verify(&result_copy, &proof_copy))
Expand Down
14 changes: 4 additions & 10 deletions fastcrypto-vdf/src/hash_prime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,13 @@ pub fn verify_prime<P: PrimalityCheck>(
bitmask: &[usize],
prime: &(usize, BigUint),
) -> FastCryptoResult<()> {
let iterator = HashPrimeIterator {
let mut iterator = HashPrimeIterator {
seed: seed.to_vec(),
length_in_bytes,
bitmask: bitmask.to_vec(),
};
// Check that the original index points to a prime
let original_prime = iterator
.skip(prime.0)
.next()
.expect("Iterator failed to give next");
let original_prime = iterator.nth(prime.0).expect("Iterator is infinite");
if P::is_prime(&original_prime) && original_prime == prime.1 {
return Ok(());
}
Expand All @@ -105,7 +102,7 @@ pub fn verify_complaint<P: PrimalityCheck>(
prime: &(usize, BigUint),
complaint: &usize,
) -> FastCryptoResult<()> {
let iterator = HashPrimeIterator {
let mut iterator = HashPrimeIterator {
seed: seed.to_vec(),
length_in_bytes,
bitmask: bitmask.to_vec(),
Expand All @@ -116,10 +113,7 @@ pub fn verify_complaint<P: PrimalityCheck>(
}

// Check that the complaint index points to a prime
let complaint_prime = iterator
.skip(*complaint)
.next()
.expect("Iterator failed to give next");
let complaint_prime = iterator.nth(*complaint).expect("Iterator is infinite");
if !P::is_prime(&complaint_prime) {
return Err(FastCryptoError::InvalidProof);
}
Expand Down
4 changes: 2 additions & 2 deletions fastcrypto-vdf/src/vdf/wesolowski.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ pub type StrongVDF<G> =
pub type StrongVDFVerifier<G> = FastVerifier<
G,
StrongFiatShamir<G, CHALLENGE_SIZE, DefaultPrimalityCheck>,
WindowedScalarMultiplier<G, BigInt, 256, CHALLENGE_SIZE, 5>,
WindowedScalarMultiplier<G, BigInt, 256, 5>,
>;

/// Implementation of Wesolowski's VDF construction over a group of unknown order using the Fiat-Shamir
Expand All @@ -156,7 +156,7 @@ pub type WeakVDF<G> = WesolowskisVDF<G, WeakFiatShamir<G, CHALLENGE_SIZE, Defaul
pub type WeakVDFVerifier<G> = FastVerifier<
G,
WeakFiatShamir<G, CHALLENGE_SIZE, DefaultPrimalityCheck>,
WindowedScalarMultiplier<G, BigInt, 256, CHALLENGE_SIZE, 5>,
WindowedScalarMultiplier<G, BigInt, 256, 5>,
>;

impl<G: ParameterizedGroupElement + UnknownOrderGroupElement, F: FiatShamir<G>>
Expand Down
20 changes: 10 additions & 10 deletions fastcrypto/benches/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,27 @@ mod group_benches {

scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 16, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 16, 5>,
_,
>("Secp256r1 Fixed window (16)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 32, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 32, 5>,
_,
>("Secp256r1 Fixed window (32)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 64, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 64, 5>,
_,
>("Secp256r1 Fixed window (64)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 128, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 128, 5>,
_,
>("Secp256r1 Fixed window (128)", &mut group);
scale_single_precomputed::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 256, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 256, 5>,
_,
>("Secp256r1 Fixed window (256)", &mut group);
}
Expand Down Expand Up @@ -114,27 +114,27 @@ mod group_benches {

double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 16, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 16, 5>,
_,
>("Secp256r1 Straus (16)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 32, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 32, 5>,
_,
>("Secp256r1 Straus (32)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 64, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 64, 5>,
_,
>("Secp256r1 Straus (64)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 128, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 128, 5>,
_,
>("Secp256r1 Straus (128)", &mut group);
double_scale_single::<
ProjectivePoint,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 256, 32, 5>,
WindowedScalarMultiplier<ProjectivePoint, secp256r1::Scalar, 256, 5>,
_,
>("Secp256r1 Straus (256)", &mut group);
double_scale_single::<ProjectivePoint, DefaultMultiplier<ProjectivePoint>, _>(
Expand Down
3 changes: 1 addition & 2 deletions fastcrypto/src/groups/multiplier/bgmw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ impl<
// Scalar as bytes in little-endian representation.
let scalar_bytes = scalar.to_byte_array();

let base_2w_expansion =
compute_base_2w_expansion::<SCALAR_SIZE>(&scalar_bytes, Self::WINDOW_WIDTH);
let base_2w_expansion = compute_base_2w_expansion(&scalar_bytes, Self::WINDOW_WIDTH);

let mut result = self.get_precomputed_multiple(0, base_2w_expansion[0]);
for (i, digit) in base_2w_expansion.iter().enumerate().skip(1) {
Expand Down
40 changes: 16 additions & 24 deletions fastcrypto/src/groups/multiplier/integer_utils.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use crate::groups::multiplier::ToLittleEndianByteArray;
use crate::groups::multiplier::ToLittleEndianBytes;
use num_bigint::BigInt;
use std::cmp::min;

/// Given a binary representation of a number in little-endian format, return the digits of its base
/// `2^bits_per_digit` expansion.
pub fn compute_base_2w_expansion<const N: usize>(
bytes: &[u8; N],
bits_per_digit: usize,
) -> Vec<usize> {
pub fn compute_base_2w_expansion(bytes: &[u8], bits_per_digit: usize) -> Vec<usize> {
assert!(0 < bits_per_digit && bits_per_digit <= usize::BITS as usize);

// The base 2^window_size expansions digits in little-endian representation.
let mut digits = Vec::new();

let n = bytes.len();

// Compute the number of digits needed to represent the numbed in base 2^w. This is equal to
// ceil(8*N / window_size), and we compute like this because div_ceil is unstable as of rustc 1.69.0.
let digits_count = div_ceil(8 * N, bits_per_digit);
let digits_count = div_ceil(8 * n, bits_per_digit);

for i in 0..digits_count {
digits.push(get_bits_from_bytes(
Expand Down Expand Up @@ -55,18 +53,18 @@ pub(crate) fn div_ceil(numerator: usize, denominator: usize) -> usize {

/// Get the integer represented by a given range of bits of a an integer represented by a little-endian
/// byte array from start to end (exclusive). The `end` argument may be arbitrarily large, but if it
/// is larger than 8*N, the remaining bits of the byte array will be assumed to be zero.
/// is larger than 8*bytes.len(), the remaining bits of the byte array will be assumed to be zero.
#[inline]
pub fn get_bits_from_bytes<const N: usize>(bytes: &[u8; N], start: usize, end: usize) -> usize {
assert!(start <= end && start < 8 * N);
pub fn get_bits_from_bytes(bytes: &[u8], start: usize, end: usize) -> usize {
assert!(start <= end && start < 8 * bytes.len());

let mut result: usize = 0;
let mut bits_added = 0;

let mut current_bit = start % 8;
let mut current_byte = start / 8;

while bits_added < end - start && current_byte < N {
while bits_added < end - start && current_byte < bytes.len() {
let remaining_bits = end - start - bits_added;
let (bits_to_read, next_byte, next_bit) = if remaining_bits < 8 - current_bit {
// There are enough bits left in the current byte
Expand Down Expand Up @@ -94,8 +92,8 @@ pub fn get_bits_from_bytes<const N: usize>(bytes: &[u8; N], start: usize, end: u

/// Return true iff the bit at the given index is set.
#[inline]
pub fn test_bit<const N: usize>(bytes: &[u8; N], index: usize) -> bool {
assert!(index < 8 * N);
pub fn test_bit(bytes: &[u8], index: usize) -> bool {
assert!(index < 8 * bytes.len());
let byte = index >> 3;
let shifted = bytes[byte] >> (index & 7);
shifted & 1 != 0
Expand Down Expand Up @@ -152,15 +150,15 @@ mod tests {
let bytes = value.to_le_bytes();

// Is w = 8, the base 2^w expansion should be equal to the le bytes.
let expansion = compute_base_2w_expansion::<16>(&bytes, 8);
let expansion = compute_base_2w_expansion(&bytes, 8);
assert_eq!(
bytes.to_vec(),
expansion.iter().map(|x| *x as u8).collect::<Vec<u8>>()
);

// Verify that the expansion is correct for w = 1, ..., 64
for window_size in 1..=64 {
let expansion = compute_base_2w_expansion::<16>(&bytes, window_size);
let expansion = compute_base_2w_expansion(&bytes, window_size);
let mut sum = 0u128;
for (i, value) in expansion.iter().enumerate() {
sum += (1 << (window_size * i)) * *value as u128;
Expand Down Expand Up @@ -196,14 +194,8 @@ mod tests {

// We implementation `ToLittleEndianByteArray` for BigInt in case it needs to be used as scalar for
// multi-scalar multiplication.
impl<const N: usize> ToLittleEndianByteArray<N> for BigInt {
fn to_le_byte_array(&self) -> [u8; N] {
let mut output = [0u8; N];
let bytes = self.to_bytes_le().1;

// It's up to the caller to ensure that the BigInt is not too large to fit in N bytes.
// Otherwise, we just truncate the output.
output[0..min(bytes.len(), N)].clone_from_slice(&bytes);
output
impl ToLittleEndianBytes for BigInt {
fn to_le_bytes(&self) -> Vec<u8> {
self.to_bytes_le().1
}
}
4 changes: 2 additions & 2 deletions fastcrypto/src/groups/multiplier/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub trait ScalarMultiplier<G, S> {
fn two_scalar_mul(&self, base_scalar: &S, other_element: &G, other_scalar: &S) -> G;
}

pub trait ToLittleEndianByteArray<const SCALAR_SIZE: usize> {
pub trait ToLittleEndianBytes {
/// Serialize scalar into a byte vector in little-endian format.
fn to_le_byte_array(&self) -> [u8; SCALAR_SIZE];
fn to_le_bytes(&self) -> Vec<u8>;
}
Loading

0 comments on commit 999edf9

Please sign in to comment.