diff --git a/fastcrypto-vdf/src/math/parameterized_group.rs b/fastcrypto-vdf/src/math/parameterized_group.rs index 7b0543d7b..7ff868aff 100644 --- a/fastcrypto-vdf/src/math/parameterized_group.rs +++ b/fastcrypto-vdf/src/math/parameterized_group.rs @@ -5,6 +5,7 @@ use fastcrypto::error::FastCryptoError::InvalidInput; use fastcrypto::error::FastCryptoResult; use fastcrypto::groups::Doubling; use num_bigint::BigUint; +use num_traits::Zero; use std::ops::{Add, Neg}; /// This trait is implemented by types which can be used as parameters for a parameterized group. @@ -28,8 +29,8 @@ pub trait ParameterizedGroupElement: /// Returns true if this is an element of the group defined by `parameter`. fn is_in_group(&self, parameter: &Self::ParameterType) -> bool; - /// Compute self * scalar using a "Double-and-Add" algorithm. Returns an `InvalidInput` error if - /// the input is not in the group defined by `parameter`. + /// Compute self * scalar using a "Double-and-Add" algorithm for a positive scalar. Returns an + /// `InvalidInput` error if the scalar is zero. fn multiply( &self, scalar: &BigUint, @@ -38,16 +39,20 @@ pub trait ParameterizedGroupElement: if !self.is_in_group(parameter) { return Err(InvalidInput); } - let result = (0..scalar.bits()).rev().map(|i| scalar.bit(i)).fold( - Self::zero(parameter), - |acc, bit| { + if scalar.is_zero() { + return Ok(Self::zero(parameter)); + } + let result = (0..scalar.bits()) + .rev() + .map(|i| scalar.bit(i)) + .skip(1) // The most significant bit is always 1. + .fold(self.clone(), |acc, bit| { let mut res = acc.double(); if bit { res = res + self; } res - }, - ); + }); Ok(result) } } @@ -64,11 +69,10 @@ mod tests { fn test_scalar_multiplication() { let discriminant = Discriminant::from_seed(b"test", 256).unwrap(); let input = QuadraticForm::generator(&discriminant); - let zero = QuadraticForm::zero(&discriminant); // Edge cases assert_eq!( - zero, + QuadraticForm::zero(&discriminant), input.multiply(&BigUint::zero(), &discriminant).unwrap() ); assert_eq!( diff --git a/fastcrypto-vdf/src/vdf/pietrzak/mod.rs b/fastcrypto-vdf/src/vdf/pietrzak/mod.rs index 654caae5a..4fc3d9d51 100644 --- a/fastcrypto-vdf/src/vdf/pietrzak/mod.rs +++ b/fastcrypto-vdf/src/vdf/pietrzak/mod.rs @@ -1,10 +1,9 @@ // Copyright (c) 2022, Mysten Labs, Inc. // SPDX-License-Identifier: Apache-2.0 -use std::ops::{AddAssign, ShrAssign}; - use num_integer::Integer; use serde::Serialize; +use std::ops::{AddAssign, ShrAssign}; use crate::math::parameterized_group::ParameterizedGroupElement; use crate::vdf::pietrzak::fiat_shamir::{DefaultFiatShamir, FiatShamir}; @@ -34,16 +33,6 @@ impl PietrzaksVDF { } } -/// Replace t with (t+1) >> 1 and return true iff the input was odd. -fn check_parity_and_iterate(t: &mut u64) -> bool { - let parity = t.is_odd(); - if parity { - t.add_assign(1); - } - t.shr_assign(1); - parity -} - impl VDF for PietrzaksVDF { type InputType = G; type OutputType = G; @@ -57,28 +46,29 @@ impl VDF for PietrzaksVDF { // Compute output = 2^iterations * input let output = input.repeated_doubling(self.iterations); - let mut x_i = input.clone(); - let mut y_i = output.clone(); - let mut t_i = self.iterations; + let mut x = input.clone(); + let mut y = output.clone(); + let mut t = self.iterations; // This is ceil(log_2(iterations)). See also https://oeis.org/A029837. let iterations = 64 - (self.iterations - 1).leading_zeros(); let mut proof = Vec::with_capacity(iterations as usize); - // Compute the full proof. This loop may stop at any time which will give a shorter proof that is computationally harder to verify. - while t_i != 1 { - if check_parity_and_iterate(&mut t_i) { - y_i = y_i.double(); + // Compute the full proof. This loop may stop at any time which will give a shorter proof + // that is computationally harder to verify. + while t != 1 { + if check_parity_and_iterate(&mut t) { + y = y.double(); } // TODO: Precompute some of the mu's - let mu_i = x_i.repeated_doubling(t_i); + let mu = x.repeated_doubling(t); - let r = DefaultFiatShamir::compute_challenge(&x_i, &y_i, self.iterations, &mu_i); - x_i = x_i.multiply(&r, &self.group_parameter)? + &mu_i; - y_i = mu_i.multiply(&r, &self.group_parameter)? + &y_i; + let r = DefaultFiatShamir::compute_challenge(&x, &y, self.iterations, &mu); + x = x.multiply(&r, &self.group_parameter)? + μ + y = mu.multiply(&r, &self.group_parameter)? + &y; - proof.push(mu_i); + proof.push(mu); } Ok((output, proof)) @@ -93,28 +83,39 @@ impl VDF for PietrzaksVDF { return Err(InvalidInput); } - let mut x_i = input.clone(); - let mut y_i = output.clone(); - let mut t_i = self.iterations; + let mut x = input.clone(); + let mut y = output.clone(); + let mut t = self.iterations; - for mu_i in proof { - if check_parity_and_iterate(&mut t_i) { - y_i = y_i.double(); + for mu in proof { + if check_parity_and_iterate(&mut t) { + y = y.double(); } - let r = DefaultFiatShamir::compute_challenge(&x_i, &y_i, self.iterations, mu_i); - x_i = x_i.multiply(&r, &self.group_parameter)? + mu_i; - y_i = y_i + mu_i.multiply(&r, &self.group_parameter)?; + let r = DefaultFiatShamir::compute_challenge(&x, &y, self.iterations, mu); + x = x.multiply(&r, &self.group_parameter)? + mu; + y = mu.multiply(&r, &self.group_parameter)? + y; } - let expected = x_i.repeated_doubling(t_i); - if y_i != expected { + // In case the proof is shorter than the full proof, we need to compute the remaining powers. + let expected = x.repeated_doubling(t); + if y != expected { return Err(InvalidProof); } Ok(()) } } +/// Replace t with (t+1) >> 1 and return true iff the input was odd. +fn check_parity_and_iterate(t: &mut u64) -> bool { + let parity = t.is_odd(); + if parity { + t.add_assign(1); + } + t.shr_assign(1); + parity +} + #[cfg(test)] mod tests { use crate::class_group::discriminant::Discriminant;