Skip to content

Commit

Permalink
Add GPU lock in oracle and Merkle tree building
Browse files Browse the repository at this point in the history
  • Loading branch information
dloghin committed Apr 16, 2024
1 parent b507df8 commit 8928129
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 10 additions & 0 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use cryptography_cuda::{
device::memory::HostOrDeviceSlice, lde_batch, lde_batch_multi_gpu, transpose_rev_batch,
types::*,
};
#[cfg(feature = "cuda")]
use crate::hash::merkle_tree::GPU_LOCK;

use itertools::Itertools;
use plonky2_field::types::Field;
use plonky2_maybe_rayon::*;
Expand Down Expand Up @@ -242,6 +245,10 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
log_n: usize,
_degree: usize,
) -> MerkleTree<F, <C as GenericConfig<D>>::Hasher> {

let mut lock = GPU_LOCK.lock().unwrap();
*lock += 1;

// let salt_size = if blinding { SALT_SIZE } else { 0 };
// println!("salt_size: {:?}", salt_size);
let output_domain_size = log_n + rate_bits;
Expand Down Expand Up @@ -367,6 +374,9 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>

#[cfg(all(feature = "cuda", feature = "batch"))]
if log_n > 10 && polynomials.len() > 0 {
let mut lock = GPU_LOCK.lock().unwrap();
*lock += 1;

println!("log_n: {:?}", log_n);
let start_lde = std::time::Instant::now();

Expand Down
10 changes: 6 additions & 4 deletions plonky2/src/hash/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::plonk::config::{GenericHashOut, Hasher};
use crate::util::log2_strict;

#[cfg(feature = "cuda")]
static GPU_LOCK: Lazy<Arc<Mutex<i32>>> = Lazy::new(|| Arc::new(Mutex::new(0)));
pub static GPU_LOCK: Lazy<Arc<Mutex<u64>>> = Lazy::new(|| Arc::new(Mutex::new(0)));

#[cfg(feature = "cuda_timing")]
fn print_time(now: Instant, msg: &str) {
Expand Down Expand Up @@ -283,7 +283,8 @@ fn fill_digests_buf_gpu_v1<F: RichField, H: Hasher<F>>(
let cap_height: u64 = cap_height.try_into().unwrap();
let hash_size: u64 = H::HASH_SIZE.try_into().unwrap();

let _lock = GPU_LOCK.lock().unwrap();
let mut lock = GPU_LOCK.lock().unwrap();
*lock += 1;

unsafe {
let now = Instant::now();
Expand Down Expand Up @@ -436,7 +437,8 @@ fn fill_digests_buf_gpu_v2<F: RichField, H: Hasher<F>>(
cap_buf.len() * NUM_HASH_OUT_ELTS
};

let _lock = GPU_LOCK.lock().unwrap();
let mut lock = GPU_LOCK.lock().unwrap();
*lock += 1;

// println!("{} {} {} {} {:?}", leaves_count, leaf_size, digests_count, caps_count, H::HASHER_TYPE);
let mut gpu_leaves_buf: HostOrDeviceSlice<'_, F> =
Expand Down Expand Up @@ -569,7 +571,7 @@ fn fill_digests_buf_gpu_ptr<F: RichField, H: Hasher<F>>(
let cap_height: u64 = cap_height.try_into().unwrap();
let leaf_size: u64 = leaf_len.try_into().unwrap();

let _lock = GPU_LOCK.lock().unwrap();
GPU_LOCK.try_lock().expect_err("GPU_LOCK should be locked!");

let now = Instant::now();
// if digests_buf is empty (size 0), just allocate a few bytes to avoid errors
Expand Down

0 comments on commit 8928129

Please sign in to comment.