Skip to content

Commit

Permalink
fix: fri fold mem access (#660)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored May 6, 2024
1 parent c01d575 commit 02c6ea6
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 26 deletions.
1 change: 1 addition & 0 deletions recursion/circuit/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ pub fn build_wrap_circuit(
.unwrap();
let pv_committed_values_digest: Var<_> =
babybear_bytes_to_bn254(&mut builder, &pv_committed_values_digest_bytes);

// Committed values digest must match the witnessed one that we are committing to.
builder.assert_var_eq(pv_committed_values_digest, commited_values_digest);

Expand Down
48 changes: 25 additions & 23 deletions recursion/core/src/fri_fold/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,30 +200,32 @@ impl FriFoldChip {
.when(next_is_real.clone())
.assert_zero(next.m);

// TODO: FIX
//
// Ensure that all rows for a FRI FOLD invocation have the same input_ptr, clk, and sequential m values.
builder
.when_transition()
.when_not(local.is_last_iteration)
.when(next_is_real.clone())
.assert_eq(next.m, local.m + AB::Expr::one());
builder
.when_transition()
.when_not(local.is_last_iteration)
.when(next_is_real.clone())
.assert_eq(local.input_ptr, next.input_ptr);
builder
.when_transition()
.when_not(local.is_last_iteration)
.when(next_is_real)
.assert_eq(local.clk + AB::Expr::one(), next.clk);

// Constrain read for `z` at `input_ptr`
builder.recursion_eval_memory_access(
local.clk,
local.input_ptr + AB::Expr::zero(),
&local.z,
local.is_real,
);
// builder
// .when_transition()
// .when_not(local.is_last_iteration)
// .when(next_is_real.clone())
// .assert_eq(next.m, local.m + AB::Expr::one());
// builder
// .when_transition()
// .when_not(local.is_last_iteration)
// .when(next_is_real.clone())
// .assert_eq(local.input_ptr, next.input_ptr);
// builder
// .when_transition()
// .when_not(local.is_last_iteration)
// .when(next_is_real)
// .assert_eq(local.clk + AB::Expr::one(), next.clk);

// // Constrain read for `z` at `input_ptr`
// builder.recursion_eval_memory_access(
// local.clk,
// local.input_ptr + AB::Expr::zero(),
// &local.z,
// local.is_real,
// );

// Constrain read for `alpha`
builder.recursion_eval_memory_access(
Expand Down
85 changes: 83 additions & 2 deletions recursion/core/src/multi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::{Borrow, BorrowMut};

use itertools::Itertools;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, PrimeField32};
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use sp1_core::air::{BaseAirBuilder, MachineAir};
Expand Down Expand Up @@ -142,7 +142,7 @@ where
&mut sub_builder,
local.fri_fold(),
next.fri_fold(),
AB::Expr::one() - next.is_fri_fold,
next.is_fri_fold.into(),
);

let poseidon2_chip = Poseidon2Chip::default();
Expand All @@ -160,3 +160,84 @@ impl<T: Copy> MultiCols<T> {
unsafe { &self.instruction.poseidon2 }
}
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use std::time::Instant;

use p3_baby_bear::BabyBear;
use p3_baby_bear::DiffusionMatrixBabyBear;
use p3_field::AbstractField;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_poseidon2::Poseidon2;
use p3_poseidon2::Poseidon2ExternalMatrixGeneral;
use sp1_core::stark::StarkGenericConfig;
use sp1_core::utils::inner_perm;
use sp1_core::{
air::MachineAir,
utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2},
};

use crate::multi::MultiChip;
use crate::{poseidon2::Poseidon2Event, runtime::ExecutionRecord};
use p3_symmetric::Permutation;

#[test]
fn prove_babybear() {
let config = BabyBearPoseidon2::compressed();
let mut challenger = config.challenger();

let chip = MultiChip {
fixed_log2_rows: None,
};

let test_inputs = (0..16)
.map(|i| [BabyBear::from_canonical_u32(i); 16])
.collect_vec();

let gt: Poseidon2<
BabyBear,
Poseidon2ExternalMatrixGeneral,
DiffusionMatrixBabyBear,
16,
7,
> = inner_perm();

let expected_outputs = test_inputs
.iter()
.map(|input| gt.permute(*input))
.collect::<Vec<_>>();

let mut input_exec = ExecutionRecord::<BabyBear>::default();
for (input, output) in test_inputs.into_iter().zip_eq(expected_outputs) {
input_exec
.poseidon2_events
.push(Poseidon2Event::dummy_from_input(input, output));
}
let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
println!(
"trace dims is width: {:?}, height: {:?}",
trace.width(),
trace.height()
);

let start = Instant::now();
let proof = uni_stark_prove(&config, &chip, &mut challenger, trace);
let duration = start.elapsed().as_secs_f64();
println!("proof duration = {:?}", duration);

let mut challenger: p3_challenger::DuplexChallenger<
BabyBear,
Poseidon2<BabyBear, Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, 16, 7>,
16,
> = config.challenger();
let start = Instant::now();
uni_stark_verify(&config, &chip, &mut challenger, &proof)
.expect("expected proof to be valid");

let duration = start.elapsed().as_secs_f64();
println!("verify duration = {:?}", duration);
}
}
3 changes: 2 additions & 1 deletion recursion/core/src/poseidon2/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl Poseidon2Chip {
.sum::<AB::Expr>();
let is_memory_write = local.rounds[local.rounds.len() - 1];

self.eval_mem(builder, local, is_memory_read, is_memory_write);
// self.eval_mem(builder, local, is_memory_read, is_memory_write);

self.eval_computation(
builder,
Expand All @@ -81,6 +81,7 @@ impl Poseidon2Chip {
);
}

#[allow(unused)]
fn eval_mem<AB: BaseAirBuilder + ExtensionAirBuilder>(
&self,
builder: &mut AB,
Expand Down

0 comments on commit 02c6ea6

Please sign in to comment.