diff --git a/core/src/utils/config.rs b/core/src/utils/config.rs index b7939666e4..582ce68c5a 100644 --- a/core/src/utils/config.rs +++ b/core/src/utils/config.rs @@ -15,7 +15,7 @@ use p3_symmetric::Hash; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use serde::Deserialize; use serde::Serialize; -use sp1_primitives::poseidon2_init; +use sp1_primitives::{poseidon2_16_init, poseidon2_24_init}; pub const DIGEST_SIZE: usize = 8; @@ -24,10 +24,12 @@ pub type InnerVal = BabyBear; pub type InnerChallenge = BinomialExtensionField; pub type InnerPerm = Poseidon2; +pub type InnerPerm16 = + Poseidon2; pub type InnerHash = PaddingFreeSponge; pub type InnerDigestHash = Hash; pub type InnerDigest = [InnerVal; DIGEST_SIZE]; -pub type InnerCompress = TruncatedPermutation; +pub type InnerCompress = TruncatedPermutation; pub type InnerValMmcs = FieldMerkleTreeMmcs< ::Packing, ::Packing, @@ -48,14 +50,19 @@ pub type InnerPcsProof = /// The permutation for inner recursion. pub fn inner_perm() -> InnerPerm { - poseidon2_init() + poseidon2_24_init() +} + +pub fn inner_perm16() -> InnerPerm16 { + poseidon2_16_init() } /// The FRI config for sp1 proofs. pub fn sp1_fri_config() -> FriConfig { let perm = inner_perm(); + let perm16 = inner_perm16(); let hash = InnerHash::new(perm.clone()); - let compress = InnerCompress::new(perm.clone()); + let compress = InnerCompress::new(perm16.clone()); let challenge_mmcs = InnerChallengeMmcs::new(InnerValMmcs::new(hash, compress)); let num_queries = match std::env::var("FRI_QUERIES") { Ok(value) => value.parse().unwrap(), @@ -72,8 +79,9 @@ pub fn sp1_fri_config() -> FriConfig { /// The FRI config for inner recursion. pub fn inner_fri_config() -> FriConfig { let perm = inner_perm(); + let perm16 = inner_perm16(); let hash = InnerHash::new(perm.clone()); - let compress = InnerCompress::new(perm.clone()); + let compress = InnerCompress::new(perm16.clone()); let challenge_mmcs = InnerChallengeMmcs::new(InnerValMmcs::new(hash, compress)); let num_queries = match std::env::var("FRI_QUERIES") { Ok(value) => value.parse().unwrap(), @@ -119,8 +127,9 @@ impl From> for BabyBearPoseidon impl BabyBearPoseidon2Inner { pub fn new() -> Self { let perm = inner_perm(); + let perm16 = inner_perm16(); let hash = InnerHash::new(perm.clone()); - let compress = InnerCompress::new(perm.clone()); + let compress = InnerCompress::new(perm16.clone()); let val_mmcs = InnerValMmcs::new(hash, compress); let dft = InnerDft {}; let fri_config = inner_fri_config(); diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index 25852d6c0e..b7716ca029 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -403,7 +403,7 @@ pub mod baby_bear_poseidon2 { use p3_poseidon2::Poseidon2ExternalMatrixGeneral; use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; use serde::{Deserialize, Serialize}; - use sp1_primitives::RC_24_29; + use sp1_primitives::{RC_24_29, RC_16_30}; use crate::stark::StarkGenericConfig; @@ -411,8 +411,9 @@ pub mod baby_bear_poseidon2 { pub type Challenge = BinomialExtensionField; pub type Perm = Poseidon2; + pub type Perm16 = Poseidon2; pub type MyHash = PaddingFreeSponge; - pub type MyCompress = TruncatedPermutation; + pub type MyCompress = TruncatedPermutation; pub type ValMmcs = FieldMerkleTreeMmcs< ::Packing, ::Packing, @@ -445,11 +446,34 @@ pub mod baby_bear_poseidon2 { DiffusionMatrixBabyBear, ) } + + pub fn my_perm16() -> Perm16 { + const ROUNDS_F: usize = 8; + const ROUNDS_P: usize = 13; + let mut round_constants = RC_16_30.to_vec(); + let internal_start = ROUNDS_F / 2; + let internal_end = (ROUNDS_F / 2) + ROUNDS_P; + let internal_round_constants = round_constants + .drain(internal_start..internal_end) + .map(|vec| vec[0]) + .collect::>(); + let external_round_constants = round_constants; + Perm16::new( + ROUNDS_F, + external_round_constants, + Poseidon2ExternalMatrixGeneral, + ROUNDS_P, + internal_round_constants, + DiffusionMatrixBabyBear, + ) + } + pub fn default_fri_config() -> FriConfig { let perm = my_perm(); + let perm16 = my_perm16(); let hash = MyHash::new(perm.clone()); - let compress = MyCompress::new(perm.clone()); + let compress = MyCompress::new(perm16.clone()); let challenge_mmcs = ChallengeMmcs::new(ValMmcs::new(hash, compress)); let num_queries = match std::env::var("FRI_QUERIES") { Ok(value) => value.parse().unwrap(), @@ -465,8 +489,9 @@ pub mod baby_bear_poseidon2 { pub fn compressed_fri_config() -> FriConfig { let perm = my_perm(); + let perm16 = my_perm16(); let hash = MyHash::new(perm.clone()); - let compress = MyCompress::new(perm.clone()); + let compress = MyCompress::new(perm16.clone()); let challenge_mmcs = ChallengeMmcs::new(ValMmcs::new(hash, compress)); let num_queries = match std::env::var("FRI_QUERIES") { Ok(value) => value.parse().unwrap(), @@ -496,8 +521,9 @@ pub mod baby_bear_poseidon2 { impl BabyBearPoseidon2 { pub fn new() -> Self { let perm = my_perm(); + let perm16 = my_perm16(); let hash = MyHash::new(perm.clone()); - let compress = MyCompress::new(perm.clone()); + let compress = MyCompress::new(perm16.clone()); let val_mmcs = ValMmcs::new(hash, compress); let dft = Dft {}; let fri_config = default_fri_config(); @@ -511,8 +537,9 @@ pub mod baby_bear_poseidon2 { pub fn compressed() -> Self { let perm = my_perm(); + let perm16 = my_perm16(); let hash = MyHash::new(perm.clone()); - let compress = MyCompress::new(perm.clone()); + let compress = MyCompress::new(perm16.clone()); let val_mmcs = ValMmcs::new(hash, compress); let dft = Dft {}; let fri_config = compressed_fri_config(); diff --git a/primitives/src/lib.rs b/primitives/src/lib.rs index a5e0996218..0d8a027bc2 100644 --- a/primitives/src/lib.rs +++ b/primitives/src/lib.rs @@ -1137,7 +1137,29 @@ lazy_static! { }; } -pub fn poseidon2_init( +pub fn poseidon2_16_init( +) -> Poseidon2 { + const ROUNDS_F: usize = 8; + const ROUNDS_P: usize = 13; + let mut round_constants = RC_16_30.to_vec(); + let internal_start = ROUNDS_F / 2; + let internal_end = (ROUNDS_F / 2) + ROUNDS_P; + let internal_round_constants = round_constants + .drain(internal_start..internal_end) + .map(|vec| vec[0]) + .collect::>(); + let external_round_constants = round_constants; + Poseidon2::new( + ROUNDS_F, + external_round_constants, + Poseidon2ExternalMatrixGeneral, + ROUNDS_P, + internal_round_constants, + DiffusionMatrixBabyBear, + ) +} + +pub fn poseidon2_24_init( ) -> Poseidon2 { const ROUNDS_F: usize = 8; const ROUNDS_P: usize = 21; @@ -1166,7 +1188,7 @@ mod tests { #[test] fn test_24_permutation() { - let h1 = poseidon2_init(); + let h1 = poseidon2_24_init(); type Perm = Poseidon2; let h2 = Perm::new_from_rng_128( @@ -1209,7 +1231,7 @@ pub fn poseidon2_hasher() -> PaddingFreeSponge< 16, 8, > { - let hasher = poseidon2_init(); + let hasher = poseidon2_24_init(); PaddingFreeSponge::< Poseidon2, 24, diff --git a/recursion/program/src/fri/two_adic_pcs.rs b/recursion/program/src/fri/two_adic_pcs.rs index 71b361f8ef..936baa9767 100644 --- a/recursion/program/src/fri/two_adic_pcs.rs +++ b/recursion/program/src/fri/two_adic_pcs.rs @@ -282,6 +282,7 @@ pub mod tests { use rand::rngs::OsRng; use sp1_core::utils::baby_bear_poseidon2::compressed_fri_config; use sp1_core::utils::inner_perm; + use sp1_core::utils::inner_perm16; use sp1_core::utils::InnerChallenge; use sp1_core::utils::InnerChallenger; use sp1_core::utils::InnerCompress; @@ -307,9 +308,10 @@ pub mod tests { let mut rng = &mut OsRng; let log_degrees = &[nb_log2_rows]; let perm = inner_perm(); + let perm16 = inner_perm16(); let fri_config = compressed_fri_config(); let hash = InnerHash::new(perm.clone()); - let compress = InnerCompress::new(perm.clone()); + let compress = InnerCompress::new(perm16.clone()); let val_mmcs = InnerValMmcs::new(hash, compress); let dft = InnerDft {}; let pcs_val: InnerPcs = InnerPcs::new( diff --git a/recursion/program/src/utils.rs b/recursion/program/src/utils.rs index 83b7b2adbe..ae4cd98c1e 100644 --- a/recursion/program/src/utils.rs +++ b/recursion/program/src/utils.rs @@ -27,8 +27,9 @@ type C = AsmConfig; type Val = BabyBear; type Challenge = BinomialExtensionField; type Perm = Poseidon2; +type Perm16 = Poseidon2; type Hash = PaddingFreeSponge; -type Compress = TruncatedPermutation; +type Compress = TruncatedPermutation; type ValMmcs = FieldMerkleTreeMmcs<::Packing, ::Packing, Hash, Compress, 8>; type ChallengeMmcs = ExtensionMmcs;