diff --git a/crypto/Cargo.toml b/crypto/Cargo.toml index 8261ee196..adb02143a 100644 --- a/crypto/Cargo.toml +++ b/crypto/Cargo.toml @@ -13,6 +13,7 @@ sha3 = "0.10" sha2 = "0.10" thiserror = "1.0.38" serde = { version = "1.0", features = ["derive"] } +rayon = { version = "1.8.0", optional = true } [dev-dependencies] criterion = "0.4" @@ -21,6 +22,7 @@ rand = "0.8.5" [features] test_fiat_shamir = [] +parallel = ["dep:rayon"] [[bench]] name = "criterion_merkle" diff --git a/crypto/src/merkle_tree/traits.rs b/crypto/src/merkle_tree/traits.rs index 6f2f7097f..dce3e253b 100644 --- a/crypto/src/merkle_tree/traits.rs +++ b/crypto/src/merkle_tree/traits.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "parallel")] +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; + /// A backend for Merkle trees. This defines raw `Data` from which the Merkle /// tree is built from. It also defines the `Node` type and the hash function /// used to build parent nodes from children nodes. @@ -11,10 +14,12 @@ pub trait IsMerkleTreeBackend: Default { /// This function takes the list of data from which the Merkle /// tree will be built from and converts it to a list of leaf nodes. fn hash_leaves(unhashed_leaves: &[Self::Data]) -> Vec { - unhashed_leaves - .iter() - .map(|leaf| Self::hash_data(leaf)) - .collect() + #[cfg(feature = "parallel")] + let iter = unhashed_leaves.par_iter(); + #[cfg(not(feature = "parallel"))] + let iter = unhashed_leaves.iter(); + + iter.map(|leaf| Self::hash_data(leaf)).collect() } /// This function takes to children nodes and builds a new parent node. diff --git a/provers/stark/Cargo.toml b/provers/stark/Cargo.toml index 26f6e25c9..aa1f5c683 100644 --- a/provers/stark/Cargo.toml +++ b/provers/stark/Cargo.toml @@ -45,7 +45,7 @@ wasm-bindgen-test = "0.3.0" test_fiat_shamir = [] instruments = [] # This enables timing prints in prover and verifier metal = ["lambdaworks-math/metal"] -parallel = ["dep:rayon"] +parallel = ["dep:rayon", "lambdaworks-crypto/parallel"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]