Skip to content

Commit

Permalink
Use neptune implementation of Poseidon (#671)
Browse files Browse the repository at this point in the history
* Use neptune for poseidon hashing (works)

* Add assertion to test

* Add constants for n=1,..,16

* Move neptune test

* Add comment

* Parameterize test

* Use neptune for poseidon hashing

* Fix comments

* Keep poseidon instances as static

* Clean up

* Use neptune fork

* Works

* Clean up

* Clean up imports

* Get rid of static mutable objects

* More clean ups

* Even more clean ups

* Keep clippy happy

* Try to fix rebase

* Fix tests

* Use actual neptune repo instead of fork

* Clean up after rebase

* Clean up

* Add proptest

* fmt

* Docs

* Clean up proptest

* Move test

* Comment

* More docs

* Fix proptest

* Fix messy imports

* Use new released version of neptune

* Add some clarifying comments

* fmt
  • Loading branch information
jonas-lj authored Nov 9, 2023
1 parent 3c366db commit 546b7a8
Show file tree
Hide file tree
Showing 8 changed files with 13,504 additions and 147 deletions.
239 changes: 209 additions & 30 deletions Cargo.lock

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion fastcrypto-zkp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ schemars ="0.8.10"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.93"
once_cell = "1.16"
poseidon-ark = { git = "https://github.com/arnaucube/poseidon-ark.git", rev = "ff7f5e05d55667b4ffba129b837da780c4c5c849" }
bcs = "0.1.4"
im = "15"
reqwest = { version = "0.11.20", default_features = false, features = ["blocking", "json", "rustls-tls"] }
rustls-webpki = "0.101.4"
neptune = "13.0.0"
ff = { version = "0.13.0", features = ["derive"] }
typenum = "1.13.0"
lazy_static = "1.4.0"

[dev-dependencies]
ark-bls12-377 = "0.4.0"
Expand All @@ -54,7 +57,9 @@ blake2 = "0.10.6"
criterion = "0.5.1"
hex = "0.4.3"
proptest = "1.1.0"
poseidon-ark = { git = "https://github.com/arnaucube/poseidon-ark.git", rev = "ff7f5e05d55667b4ffba129b837da780c4c5c849" }
tokio = { version = "1.24.1", features = ["sync", "rt", "macros"] }
lazy_static = "1.4.0"

[features]
e2e = []
13,125 changes: 13,125 additions & 0 deletions fastcrypto-zkp/src/bn254/poseidon/constants.rs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,65 +1,84 @@
// Copyright (c) 2022, Mysten Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use crate::bn254::poseidon::constants::*;
use crate::FrRepr;
use ark_bn254::Fr;
use ark_ff::{BigInteger, PrimeField};
use byte_slice_cast::AsByteSlice;
use fastcrypto::error::FastCryptoError;
use fastcrypto::error::FastCryptoError::{InputTooLong, InvalidInput};
use once_cell::sync::OnceCell;
use poseidon_ark::Poseidon;
use ff::PrimeField as OtherPrimeField;
use neptune::poseidon::HashMode::OptimizedStatic;
use neptune::Poseidon;
use std::cmp::Ordering;
use std::fmt::Debug;
use std::fmt::Formatter;

/// The output of the Poseidon hash function is a field element in BN254 which is 254 bits long, so
/// we need 32 bytes to represent it as an integer.
pub const FIELD_ELEMENT_SIZE_IN_BYTES: usize = 32;
mod constants;

/// Wrapper struct for Poseidon hash instance.
pub struct PoseidonWrapper {
instance: Poseidon,
}

impl Debug for PoseidonWrapper {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PoseidonWrapper").finish()
}
}
macro_rules! define_poseidon_hash {
($inputs:expr, $poseidon_constants:expr) => {{
let mut poseidon = Poseidon::new(&$poseidon_constants);
poseidon.reset();
for input in $inputs.iter() {
poseidon.input(bn254_to_fr(*input)).expect("The number of inputs must be aligned with the constants");
}
poseidon.hash_in_mode(OptimizedStatic);

impl Default for PoseidonWrapper {
fn default() -> Self {
Self::new()
}
// Neptune returns the state element with index 1 but we want the first element to be aligned
// with poseidon-rs and circomlib's implementation which returns the 0'th element.
//
// See:
// * https://github.com/lurk-lab/neptune/blob/b7a9db1fc6ce096aff52b903f7d228eddea6d4e3/src/poseidon.rs#L698
// * https://github.com/arnaucube/poseidon-rs/blob/f4ba1f7c32905cd2ae5a71e7568564bb150a9862/src/lib.rs#L116
// * https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/poseidon.circom#L207
poseidon.elements[0]
}};
}

impl PoseidonWrapper {
/// Initialize a Poseidon hash function.
pub fn new() -> Self {
Self {
instance: Poseidon::new(),
}
/// Poseidon hash function over BN254. The input vector cannot be empty and must contain at most 16
/// elements, otherwise an error is returned.
pub fn hash(inputs: Vec<Fr>) -> Result<Fr, FastCryptoError> {
if inputs.is_empty() || inputs.len() > 16 {
return Err(FastCryptoError::InputLengthWrong(inputs.len()));
}

/// Calculate the hash of the given inputs.
pub fn hash(&self, inputs: Vec<Fr>) -> Result<Fr, FastCryptoError> {
self.instance
.hash(inputs)
.map_err(|_| FastCryptoError::InvalidInput)
}
// Instances of Poseidon and PoseidonConstants from neptune have different types depending on
// the number of inputs, so unfortunately we need to use a macro here.
let result = match inputs.len() {
1 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U1),
2 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U2),
3 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U3),
4 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U4),
5 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U5),
6 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U6),
7 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U7),
8 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U8),
9 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U9),
10 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U10),
11 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U11),
12 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U12),
13 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U13),
14 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U14),
15 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U15),
16 => define_poseidon_hash!(inputs, POSEIDON_CONSTANTS_U16),
_ => return Err(InvalidInput),
};
Ok(fr_to_bn254fr(result))
}

/// Calculate the poseidon hash of the field element inputs. If the input length is <= 16, calculate
/// H(inputs), if it is <= 32, calculate H(H(inputs[0..16]), H(inputs[16..])), otherwise return an
/// error.
pub fn to_poseidon_hash(inputs: Vec<Fr>) -> Result<Fr, FastCryptoError> {
static POSEIDON: OnceCell<PoseidonWrapper> = OnceCell::new();
let poseidon_ref = POSEIDON.get_or_init(PoseidonWrapper::new);
if inputs.len() <= 16 {
poseidon_ref.hash(inputs)
hash(inputs)
} else if inputs.len() <= 32 {
let hash1 = poseidon_ref.hash(inputs[0..16].to_vec())?;
let hash2 = poseidon_ref.hash(inputs[16..].to_vec())?;
poseidon_ref.hash([hash1, hash2].to_vec())
let hash1 = hash(inputs[0..16].to_vec())?;
let hash2 = hash(inputs[16..].to_vec())?;
hash([hash1, hash2].to_vec())
} else {
Err(FastCryptoError::GeneralError(format!(
"Yet to implement: Unable to hash a vector of length {}",
Expand Down Expand Up @@ -121,12 +140,31 @@ pub fn hash_to_bytes(
.expect("Leading zeros are added in to_bytes_be"))
}

/// Convert an ff field element to an arkworks-ff field element.
fn fr_to_bn254fr(fr: crate::Fr) -> Fr {
// We use big-endian as in the definition of the BN254 prime field (see fastcrypto-zkp/src/lib.rs).
Fr::from_be_bytes_mod_order(fr.to_repr().as_byte_slice())
}

/// Convert an arkworks-ff field element to an ff field element.
fn bn254_to_fr(fr: Fr) -> crate::Fr {
let mut bytes = [0u8; 32];
// We use big-endian as in the definition of the BN254 prime field (see fastcrypto-zkp/src/lib.rs).
bytes.clone_from_slice(&fr.into_bigint().to_bytes_be());
crate::Fr::from_repr_vartime(FrRepr(bytes))
.expect("The bytes of fr are guaranteed to be canonical here")
}

#[cfg(test)]
mod test {
use super::PoseidonWrapper;
use crate::bn254::poseidon::hash;
use crate::bn254::poseidon::hash_to_bytes;
use crate::bn254::{poseidon::to_poseidon_hash, zk_login::Bn254Fr};
use ark_bn254::Fr;
use ark_ff::{BigInteger, PrimeField};
use lazy_static::lazy_static;
use proptest::arbitrary::Arbitrary;
use proptest::collection;
use std::str::FromStr;

fn to_bigint_arr(vals: Vec<u8>) -> Vec<Bn254Fr> {
Expand All @@ -135,15 +173,14 @@ mod test {

#[test]
fn poseidon_test() {
let poseidon = PoseidonWrapper::new();
let input1 = Fr::from_str("134696963602902907403122104327765350261").unwrap();
let input2 = Fr::from_str("17932473587154777519561053972421347139").unwrap();
let input3 = Fr::from_str("10000").unwrap();
let input4 = Fr::from_str(
"50683480294434968413708503290439057629605340925620961559740848568164438166",
)
.unwrap();
let hash = poseidon.hash(vec![input1, input2, input3, input4]).unwrap();
let hash = hash(vec![input1, input2, input3, input4]).unwrap();
assert_eq!(
hash,
Fr::from_str(
Expand Down Expand Up @@ -199,52 +236,6 @@ mod test {
.is_err());
}

#[test]
fn test_all_inputs_hash() {
let poseidon = PoseidonWrapper::new();
let jwt_sha2_hash_0 = Fr::from_str("248987002057371616691124650904415756047").unwrap();
let jwt_sha2_hash_1 = Fr::from_str("113498781424543581252500776698433499823").unwrap();
let masked_content_hash = Fr::from_str(
"14900420995580824499222150327925943524564997104405553289134597516335134742309",
)
.unwrap();
let payload_start_index = Fr::from_str("103").unwrap();
let payload_len = Fr::from_str("564").unwrap();
let eph_public_key_0 = Fr::from_str("17932473587154777519561053972421347139").unwrap();
let eph_public_key_1 = Fr::from_str("134696963602902907403122104327765350261").unwrap();
let max_epoch = Fr::from_str("10000").unwrap();
let num_sha2_blocks = Fr::from_str("11").unwrap();
let key_claim_name_f = Fr::from_str(
"18523124550523841778801820019979000409432455608728354507022210389496924497355",
)
.unwrap();
let addr_seed = Fr::from_str(
"15604334753912523265015800787270404628529489918817818174033741053550755333691",
)
.unwrap();

let hash = poseidon
.hash(vec![
jwt_sha2_hash_0,
jwt_sha2_hash_1,
masked_content_hash,
payload_start_index,
payload_len,
eph_public_key_0,
eph_public_key_1,
max_epoch,
num_sha2_blocks,
key_claim_name_f,
addr_seed,
])
.unwrap();
assert_eq!(
hash.to_string(),
"2487117669597822357956926047501254969190518860900347921480370492048882803688"
.to_string()
);
}

#[test]
fn test_hash_to_bytes() {
let inputs: Vec<Vec<u8>> = vec![vec![1u8]];
Expand All @@ -271,4 +262,21 @@ mod test {
let inputs = vec![vec![255; 31]];
assert!(hash_to_bytes(&inputs).is_ok());
}

#[cfg(test)]
lazy_static! {
static ref POSEIDON_ARK: poseidon_ark::Poseidon = poseidon_ark::Poseidon::new();
}

proptest::proptest! {
#[test]
fn test_against_poseidon_ark(r in collection::vec(<[u8; 32]>::arbitrary(), 1..16)) {

let inputs = r.into_iter().map(|ri| ark_bn254::Fr::from_le_bytes_mod_order(&ri)).collect::<Vec<_>>();
let expected = POSEIDON_ARK.hash(inputs.clone()).unwrap().into_bigint().to_bytes_le();

let actual = hash_to_bytes(&inputs.iter().map(|i| i.into_bigint().to_bytes_le().to_vec()).collect::<Vec<_>>()).unwrap();
assert_eq!(&actual, expected.as_slice());
}
}
}
52 changes: 46 additions & 6 deletions fastcrypto-zkp/src/bn254/unit_tests/zk_login_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use std::str::FromStr;

use crate::bn254::poseidon::PoseidonWrapper;
use crate::bn254::poseidon::hash;
use crate::bn254::utils::{
gen_address_seed, gen_address_seed_with_salt_hash, get_nonce, get_zk_login_address,
};
Expand All @@ -18,6 +18,7 @@ use crate::bn254::{
zk_login::{ZkLoginInputs, JWK},
zk_login_api::verify_zk_login,
};
use ark_bn254::Fr;
use ark_std::rand::rngs::StdRng;
use ark_std::rand::SeedableRng;
use fastcrypto::ed25519::Ed25519KeyPair;
Expand Down Expand Up @@ -483,11 +484,7 @@ fn test_verify_zk_login() {
let aud = "575519204237-msop9ep45u2uo98hapqmngv8d84qdc8k.apps.googleusercontent.com";
let salt = "6588741469050502421550140105345050859";
let iss = "https://accounts.google.com";
let poseidon = PoseidonWrapper::new();
let salt_hash = poseidon
.hash(vec![to_field(salt).unwrap()])
.unwrap()
.to_string();
let salt_hash = hash(vec![to_field(salt).unwrap()]).unwrap().to_string();
assert!(verify_zk_login_id(&address, name, value, aud, iss, &salt_hash).is_ok());

let address_seed = gen_address_seed_with_salt_hash(&salt_hash, name, value, aud).unwrap();
Expand Down Expand Up @@ -529,3 +526,46 @@ fn test_verify_zk_login() {
Err(FastCryptoError::GeneralError("in_arr too long".to_string()))
);
}

#[test]
fn test_all_inputs_hash() {
let jwt_sha2_hash_0 = Fr::from_str("248987002057371616691124650904415756047").unwrap();
let jwt_sha2_hash_1 = Fr::from_str("113498781424543581252500776698433499823").unwrap();
let masked_content_hash = Fr::from_str(
"14900420995580824499222150327925943524564997104405553289134597516335134742309",
)
.unwrap();
let payload_start_index = Fr::from_str("103").unwrap();
let payload_len = Fr::from_str("564").unwrap();
let eph_public_key_0 = Fr::from_str("17932473587154777519561053972421347139").unwrap();
let eph_public_key_1 = Fr::from_str("134696963602902907403122104327765350261").unwrap();
let max_epoch = Fr::from_str("10000").unwrap();
let num_sha2_blocks = Fr::from_str("11").unwrap();
let key_claim_name_f = Fr::from_str(
"18523124550523841778801820019979000409432455608728354507022210389496924497355",
)
.unwrap();
let addr_seed = Fr::from_str(
"15604334753912523265015800787270404628529489918817818174033741053550755333691",
)
.unwrap();

let hash = hash(vec![
jwt_sha2_hash_0,
jwt_sha2_hash_1,
masked_content_hash,
payload_start_index,
payload_len,
eph_public_key_0,
eph_public_key_1,
max_epoch,
num_sha2_blocks,
key_claim_name_f,
addr_seed,
])
.unwrap();
assert_eq!(
hash.to_string(),
"2487117669597822357956926047501254969190518860900347921480370492048882803688".to_string()
);
}
Loading

0 comments on commit 546b7a8

Please sign in to comment.