Skip to content

Commit

Permalink
perf: poseidon2 parallel tracegen (#1118)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue authored Jul 17, 2024
1 parent 1a8509f commit 22f51bb
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 73 deletions.
17 changes: 11 additions & 6 deletions core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,23 +194,28 @@ pub fn log2_strict_usize(n: usize) -> usize {
res as usize
}

pub fn par_for_each_row<P, F>(vec: &mut [F], num_cols: usize, processor: P)
pub fn par_for_each_row<P, F>(vec: &mut [F], num_elements_per_event: usize, processor: P)
where
F: Send,
P: Fn(usize, &mut [F]) + Send + Sync,
{
// Split the vector into `num_cpus` chunks, but at least `num_cpus` rows per chunk.
let len = vec.len();
assert!(vec.len() % num_elements_per_event == 0);
let len = vec.len() / num_elements_per_event;
let cpus = num_cpus::get();
let ceil_div = (len + cpus - 1) / cpus;
let chunk_size = std::cmp::max(ceil_div, cpus);

vec.chunks_mut(chunk_size * num_cols)
vec.chunks_mut(chunk_size * num_elements_per_event)
.enumerate()
.par_bridge()
.for_each(|(i, chunk)| {
chunk.chunks_mut(num_cols).enumerate().for_each(|(j, row)| {
processor(i * chunk_size + j, row);
});
chunk
.chunks_mut(num_elements_per_event)
.enumerate()
.for_each(|(j, row)| {
assert!(row.len() == num_elements_per_event);
processor(i * chunk_size + j, row);
});
});
}
6 changes: 3 additions & 3 deletions recursion/core/src/poseidon2_wide/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ impl<'a, const DEGREE: usize> Poseidon2WideChip<DEGREE> {
/// Transmute a row it to a mutable Poseidon2 instance.
pub(crate) fn convert_mut<'b: 'a, F: PrimeField32>(
&self,
row: &'b mut Vec<F>,
row: &'b mut [F],
) -> Box<dyn Poseidon2Mut<'a, F> + 'a> {
if DEGREE == 3 {
let convert: &mut Poseidon2Degree3<F> = row.as_mut_slice().borrow_mut();
let convert: &mut Poseidon2Degree3<F> = row.borrow_mut();
Box::new(convert)
} else if DEGREE == 9 || DEGREE == 17 {
let convert: &mut Poseidon2Degree9<F> = row.as_mut_slice().borrow_mut();
let convert: &mut Poseidon2Degree9<F> = row.borrow_mut();
Box::new(convert)
} else {
panic!("Unsupported degree");
Expand Down
144 changes: 80 additions & 64 deletions recursion/core/src/poseidon2_wide/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ use std::borrow::Borrow;
use p3_air::BaseAir;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use sp1_core::{air::MachineAir, utils::pad_rows_fixed};
use p3_maybe_rayon::prelude::IndexedParallelIterator;
use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSliceMut;
use sp1_core::air::MachineAir;
use sp1_core::utils::next_power_of_two;
use sp1_core::utils::par_for_each_row;
use sp1_primitives::RC_16_30_U32;
use tracing::instrument;

Expand Down Expand Up @@ -34,43 +39,69 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for Poseidon2WideChip<D
input: &ExecutionRecord<F>,
output: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let mut rows = Vec::new();
// Calculate the number of rows in the trace.
let mut nb_rows = 0;
for event in input.poseidon2_hash_events.iter() {
match event {
Poseidon2HashEvent::Absorb(absorb_event) => {
nb_rows += absorb_event.iterations.len();
}
Poseidon2HashEvent::Finalize(_) => {
nb_rows += 1;
}
}
}
nb_rows += input.poseidon2_compress_events.len() * 2;

let nb_padded_rows = if self.pad {
next_power_of_two(nb_rows, self.fixed_log2_rows)
} else {
nb_rows
};

let num_columns = <Poseidon2WideChip<DEGREE> as BaseAir<F>>::width(self);
let mut rows = vec![F::zero(); nb_padded_rows * num_columns];

// Populate the hash events.
// Populate the hash events. We do this serially, since each absorb event could populate a different
// number of rows. Also, most of the rows are populated by the compress events.
let mut row_cursor = 0;
for event in &input.poseidon2_hash_events {
match event {
Poseidon2HashEvent::Absorb(absorb_event) => {
rows.extend(self.populate_absorb_event(absorb_event, num_columns, output));
let num_absorb_elements = absorb_event.iterations.len() * num_columns;
let absorb_rows = &mut rows[row_cursor..row_cursor + num_absorb_elements];
self.populate_absorb_event(absorb_rows, absorb_event, num_columns, output);
row_cursor += num_absorb_elements;
}

Poseidon2HashEvent::Finalize(finalize_event) => {
rows.push(self.populate_finalize_event(finalize_event, num_columns));
let finalize_row = &mut rows[row_cursor..row_cursor + num_columns];
self.populate_finalize_event(finalize_row, finalize_event);
row_cursor += num_columns;
}
}
}

// Populate the compress events.
for event in &input.poseidon2_compress_events {
rows.extend(self.populate_compress_event(event, num_columns));
}
let compress_rows = &mut rows[row_cursor..nb_rows * num_columns];
par_for_each_row(compress_rows, num_columns * 2, |i, rows| {
self.populate_compress_event(rows, &input.poseidon2_compress_events[i], num_columns);
});

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(rows, num_columns);

let padded_rows = trace.values.par_chunks_mut(num_columns).skip(nb_rows);

if self.pad {
// Pad the trace to a power of two.
pad_rows_fixed(
&mut rows,
|| {
let mut padded_row = vec![F::zero(); num_columns];
self.populate_permutation([F::zero(); WIDTH], None, &mut padded_row);
padded_row
},
self.fixed_log2_rows,
);
let mut dummy_row = vec![F::zero(); num_columns];
self.populate_permutation([F::zero(); WIDTH], None, &mut dummy_row);
padded_rows.for_each(|padded_row| {
padded_row.copy_from_slice(&dummy_row);
});
}

// Convert the trace to a row major matrix.
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), num_columns)
trace
}

fn included(&self, record: &Self::Record) -> bool {
Expand All @@ -81,15 +112,14 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for Poseidon2WideChip<D
impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
pub fn populate_compress_event<F: PrimeField32>(
&self,
rows: &mut [F],
compress_event: &Poseidon2CompressEvent<F>,
num_columns: usize,
) -> Vec<Vec<F>> {
let mut compress_rows = Vec::new();

let mut input_row = vec![F::zero(); num_columns];
) {
let input_row = &mut rows[0..num_columns];
// Populate the control flow fields.
{
let mut cols = self.convert_mut(&mut input_row);
let mut cols = self.convert_mut(input_row);
let control_flow = cols.control_flow_mut();

control_flow.is_compress = F::one();
Expand All @@ -98,7 +128,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the syscall params fields.
{
let mut cols = self.convert_mut(&mut input_row);
let mut cols = self.convert_mut(input_row);
let syscall_params = cols.syscall_params_mut().compress_mut();

syscall_params.clk = compress_event.clk;
Expand All @@ -109,7 +139,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the memory fields.
{
let mut cols = self.convert_mut(&mut input_row);
let mut cols = self.convert_mut(input_row);
let memory = cols.memory_mut();

memory.start_addr = compress_event.left;
Expand All @@ -122,7 +152,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the opcode workspace fields.
{
let mut cols = self.convert_mut(&mut input_row);
let mut cols = self.convert_mut(input_row);
let compress_cols = cols.opcode_workspace_mut().compress_mut();
compress_cols.start_addr = compress_event.right;

Expand All @@ -137,22 +167,20 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
self.populate_permutation(
compress_event.input,
Some(compress_event.result_array),
&mut input_row,
input_row,
);

compress_rows.push(input_row);

let mut output_row = vec![F::zero(); num_columns];
let output_row = &mut rows[num_columns..];
{
let mut cols = self.convert_mut(&mut output_row);
let mut cols = self.convert_mut(output_row);
let control_flow = cols.control_flow_mut();

control_flow.is_compress = F::one();
control_flow.is_compress_output = F::one();
}

{
let mut cols = self.convert_mut(&mut output_row);
let mut cols = self.convert_mut(output_row);
let syscall_cols = cols.syscall_params_mut().compress_mut();

syscall_cols.clk = compress_event.clk;
Expand All @@ -162,7 +190,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
}

{
let mut cols = self.convert_mut(&mut output_row);
let mut cols = self.convert_mut(output_row);
let memory = cols.memory_mut();

memory.start_addr = compress_event.dst;
Expand All @@ -174,7 +202,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
}

{
let mut cols = self.convert_mut(&mut output_row);
let mut cols = self.convert_mut(output_row);
let compress_cols = cols.opcode_workspace_mut().compress_mut();

compress_cols.start_addr = compress_event.dst + F::from_canonical_usize(WIDTH / 2);
Expand All @@ -184,34 +212,30 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
}
}

self.populate_permutation(compress_event.result_array, None, &mut output_row);

compress_rows.push(output_row);
compress_rows
self.populate_permutation(compress_event.result_array, None, output_row);
}

pub fn populate_absorb_event<F: PrimeField32>(
&self,
rows: &mut [F],
absorb_event: &Poseidon2AbsorbEvent<F>,
num_columns: usize,
output: &mut ExecutionRecord<F>,
) -> Vec<Vec<F>> {
let mut absorb_rows = Vec::new();

) {
// We currently don't support an input_len of 0, since it will need special logic in the AIR.
assert!(absorb_event.input_len > F::zero());

let mut last_row_ending_cursor = 0;
let num_absorb_rows = absorb_event.iterations.len();

for (iter_num, absorb_iter) in absorb_event.iterations.iter().enumerate() {
let mut absorb_row = vec![F::zero(); num_columns];
let absorb_row = &mut rows[iter_num * num_columns..(iter_num + 1) * num_columns];
let is_syscall_row = iter_num == 0;
let is_last_row = iter_num == num_absorb_rows - 1;

// Populate the control flow fields.
{
let mut cols = self.convert_mut(&mut absorb_row);
let mut cols = self.convert_mut(absorb_row);
let control_flow = cols.control_flow_mut();

control_flow.is_absorb = F::one();
Expand All @@ -223,7 +247,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the syscall params fields.
{
let mut cols = self.convert_mut(&mut absorb_row);
let mut cols = self.convert_mut(absorb_row);
let syscall_params = cols.syscall_params_mut().absorb_mut();

syscall_params.clk = absorb_event.clk;
Expand All @@ -239,7 +263,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the memory fields.
{
let mut cols = self.convert_mut(&mut absorb_row);
let mut cols = self.convert_mut(absorb_row);
let memory = cols.memory_mut();

memory.start_addr = absorb_iter.start_addr;
Expand All @@ -251,7 +275,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the opcode workspace fields.
{
let mut cols = self.convert_mut(&mut absorb_row);
let mut cols = self.convert_mut(absorb_row);
let absorb_workspace = cols.opcode_workspace_mut().absorb_mut();

absorb_workspace.hash_num = absorb_event.hash_num;
Expand Down Expand Up @@ -337,33 +361,27 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
} else {
None
},
&mut absorb_row,
absorb_row,
);

absorb_rows.push(absorb_row);
}

absorb_rows
}

pub fn populate_finalize_event<F: PrimeField32>(
&self,
row: &mut [F],
finalize_event: &Poseidon2FinalizeEvent<F>,
num_columns: usize,
) -> Vec<F> {
let mut finalize_row = vec![F::zero(); num_columns];

) {
// Populate the control flow fields.
{
let mut cols = self.convert_mut(&mut finalize_row);
let mut cols = self.convert_mut(row);
let control_flow = cols.control_flow_mut();
control_flow.is_finalize = F::one();
control_flow.is_syscall_row = F::one();
}

// Populate the syscall params fields.
{
let mut cols = self.convert_mut(&mut finalize_row);
let mut cols = self.convert_mut(row);

let syscall_params = cols.syscall_params_mut().finalize_mut();
syscall_params.clk = finalize_event.clk;
Expand All @@ -373,7 +391,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the memory fields.
{
let mut cols = self.convert_mut(&mut finalize_row);
let mut cols = self.convert_mut(row);
let memory = cols.memory_mut();

memory.start_addr = finalize_event.output_ptr;
Expand All @@ -385,7 +403,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {

// Populate the opcode workspace fields.
{
let mut cols = self.convert_mut(&mut finalize_row);
let mut cols = self.convert_mut(row);
let finalize_workspace = cols.opcode_workspace_mut().finalize_mut();

finalize_workspace.previous_state = finalize_event.previous_state;
Expand All @@ -404,10 +422,8 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
} else {
None
},
&mut finalize_row,
row,
);

finalize_row
}

pub fn populate_permutation<F: PrimeField32>(
Expand Down

0 comments on commit 22f51bb

Please sign in to comment.