From cd1f6d09af05529ae62e0656ae0fd45a17d405eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20Lindstr=C3=B8m?= Date: Thu, 9 Jan 2025 11:21:48 +0100 Subject: [PATCH] Fix poseidon api --- fastcrypto-zkp/src/bn254/poseidon/mod.rs | 53 +++++++++++------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/fastcrypto-zkp/src/bn254/poseidon/mod.rs b/fastcrypto-zkp/src/bn254/poseidon/mod.rs index 88f3c8c8f..0dba6a111 100644 --- a/fastcrypto-zkp/src/bn254/poseidon/mod.rs +++ b/fastcrypto-zkp/src/bn254/poseidon/mod.rs @@ -7,12 +7,11 @@ use crate::FrRepr; use ark_bn254::Fr; use ark_ff::{BigInteger, PrimeField}; use byte_slice_cast::AsByteSlice; -use fastcrypto::error::FastCryptoError::{InputTooLong, InvalidInput}; +use fastcrypto::error::FastCryptoError::InvalidInput; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use ff::PrimeField as OtherPrimeField; use neptune::poseidon::HashMode::OptimizedStatic; use neptune::Poseidon; -use std::cmp::Ordering; /// The output of the Poseidon hash function is a field element in BN254 which is 254 bits long, so /// we need 32 bytes to represent it as an integer. @@ -90,12 +89,12 @@ pub fn poseidon_merkle_tree(inputs: &[FieldElement]) -> FastCryptoResult]) -> FastCryptoResult<[u8; FIELD_ELEMENT_SIZE_IN_BYTES]> { @@ -111,25 +110,21 @@ pub fn poseidon_bytes(inputs: &[Vec]) -> FastCryptoResult<[u8; FIELD_ELEMENT /// Given a binary representation of a BN254 field element as an integer in little-endian encoding, /// this function returns the corresponding field element. If the field element is not canonical (is -/// larger than the field size as an integer), an `FastCryptoError::InvalidInput` is returned. +/// larger than the field size as an integer), an [InvalidInput] is returned. /// -/// If more than 32 bytes is given, an `FastCryptoError::InputTooLong` is returned. +/// If the input is not exactly 32 bytes long, an [InvalidInput] is returned. fn canonical_le_bytes_to_field_element(bytes: &[u8]) -> FastCryptoResult { - match bytes.len().cmp(&FIELD_ELEMENT_SIZE_IN_BYTES) { - Ordering::Less => Ok(Fr::from_le_bytes_mod_order(bytes)), - Ordering::Equal => { - let field_element = Fr::from_le_bytes_mod_order(bytes); - // Unfortunately, there doesn't seem to be a nice way to check if a modular reduction - // happened without doing the extra work of serializing the field element again. - let reduced_bytes = field_element.into_bigint().to_bytes_le(); - if reduced_bytes != bytes { - return Err(InvalidInput); - } - Ok(field_element) - } - Ordering::Greater => Err(InputTooLong(FIELD_ELEMENT_SIZE_IN_BYTES)), + if bytes.len() != FIELD_ELEMENT_SIZE_IN_BYTES { + return Err(InvalidInput); + } + let field_element = Fr::from_le_bytes_mod_order(bytes); + // Unfortunately, there doesn't seem to be a nice way to check if a modular reduction + // happened without doing the extra work of serializing the field element again. + let reduced_bytes = field_element.into_bigint().to_bytes_le(); + if reduced_bytes != bytes { + return Err(InvalidInput); } - .map(FieldElement) + Ok(FieldElement(field_element)) } /// Convert a BN254 field element to a byte array as the little-endian representation of the @@ -247,7 +242,10 @@ mod test { #[test] fn test_hash_to_bytes() { - let inputs: Vec> = vec![vec![1u8]]; + let mut one = vec![0; 32]; + one[0] = 1; + + let inputs: Vec> = vec![one.clone()]; let hash = poseidon_bytes(&inputs).unwrap(); // 18586133768512220936620570745912940619677854269274689475585506675881198879027 in decimal let expected = @@ -255,8 +253,11 @@ mod test { .unwrap(); assert_eq!(hash.as_slice(), &expected); + let mut two = vec![0; 32]; + two[0] = 2; + // 7853200120776062878684798364095072458815029376092732009249414926327459813530 in decimal - let inputs: Vec> = vec![vec![1u8], vec![2u8]]; + let inputs: Vec> = vec![one.clone(), two.clone()]; let hash = poseidon_bytes(&inputs).unwrap(); let expected = hex::decode("9a1817447a60199e51453274f217362acfe962966b4cf63d4190d6e7f5c05c11") @@ -266,10 +267,6 @@ mod test { // Input larger than the modulus let inputs = vec![vec![255; 32]]; assert!(poseidon_bytes(&inputs).is_err()); - - // Input smaller than the modulus - let inputs = vec![vec![255; 31]]; - assert!(poseidon_bytes(&inputs).is_ok()); } #[cfg(test)]