Skip to content

Commit

Permalink
feat: regularize proof shape (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored May 1, 2024
1 parent 4b1d6a4 commit f09a33e
Show file tree
Hide file tree
Showing 22 changed files with 249 additions and 101 deletions.
8 changes: 4 additions & 4 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,13 +499,13 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
let permutation_width = permutation_traces[i].width();
let total_width = trace_width + permutation_width;
tracing::debug!(
"{:<11} | Cols = {:<5} | Rows = {:<5} | Cells = {:<10} | Main Cols = {:.2}% | Perm Cols = {:.2}%",
"{:<11} | Main Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<10} | Cells = {:<10}",
chips[i].name(),
total_width,
trace_width,
permutation_width,
traces[i].0.height(),
total_width * traces[i].0.height(),
(100f32 * trace_width as f32) / total_width as f32,
(100f32 * permutation_width as f32) / total_width as f32);
);
}

tracing::info_span!("debug constraints").in_scope(|| {
Expand Down
7 changes: 3 additions & 4 deletions core/src/stark/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,12 @@ where
let permutation_width = permutation_traces[i].width();
let total_width = trace_width + permutation_width;
tracing::debug!(
"{:<15} | Cols = {:<5} | Rows = {:<5} | Cells = {:<10} | Main Cols = {:.2}% | Perm Cols = {:.2}%",
"{:<15} | Main Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}",
chips[i].name(),
total_width,
trace_width,
permutation_width,
traces[i].height(),
total_width * traces[i].height(),
(100f32 * trace_width as f32) / total_width as f32,
(100f32 * permutation_width as f32) / total_width as f32,
);
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/utils/logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub fn setup_logger() {
// if the RUST_LOGGER environment variable is set, use it to determine which logger to configure
// (tracing_forest or tracing_subscriber)
// otherwise, default to 'forest'
let logger_type = std::env::var("RUST_LOGGER").unwrap_or_else(|_| "forest".to_string());
let logger_type = std::env::var("RUST_LOGGER").unwrap_or_else(|_| "flat".to_string());
match logger_type.as_str() {
"forest" => {
Registry::default()
Expand Down
38 changes: 38 additions & 0 deletions core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,44 @@ pub fn pad_rows<T: Clone, const N: usize>(rows: &mut Vec<[T; N]>, row_fn: impl F
rows.resize(padded_nb_rows, dummy_row);
}

pub fn pad_rows_fixed<T: Clone, const N: usize>(
rows: &mut Vec<[T; N]>,
row_fn: impl Fn() -> [T; N],
size_log2: Option<usize>,
) {
let nb_rows = rows.len();
let dummy_row = row_fn();
match size_log2 {
Some(size_log2) => {
let padded_nb_rows = 1 << size_log2;
if nb_rows * 2 < padded_nb_rows {
tracing::warn!(
"fixed log2 rows can be potentially reduced: got {}, expected {}",
nb_rows,
padded_nb_rows
);
}
if nb_rows > padded_nb_rows {
panic!(
"fixed log2 rows is too small: got {}, expected {}",
nb_rows, padded_nb_rows
);
}
rows.resize(padded_nb_rows, dummy_row);
}
None => {
let mut padded_nb_rows = nb_rows.next_power_of_two();
if padded_nb_rows == 2 || padded_nb_rows == 1 {
padded_nb_rows = 4;
}
if padded_nb_rows == nb_rows {
return;
}
rows.resize(padded_nb_rows, dummy_row);
}
}
}

/// Converts a slice of words to a slice of bytes in little endian.
pub fn words_to_bytes_le<const B: usize>(words: &[u32]) -> [u8; B] {
debug_assert_eq!(words.len() * 4, B);
Expand Down
12 changes: 8 additions & 4 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ impl SP1Prover {
&self,
config: SC,
pk: &StarkProvingKey<SC>,
vk: &SP1VerifyingKey,
core_vk: &SP1VerifyingKey,
core_challenger: Challenger<CoreSC>,
reconstruct_challenger: Challenger<CoreSC>,
state: ReduceState,
Expand Down Expand Up @@ -484,7 +484,7 @@ impl SP1Prover {
})
.collect();
let (prep_sorted_indices, prep_domains): (Vec<usize>, Vec<Domain<CoreSC>>) =
get_preprocessed_data(&self.core_machine, &vk.vk);
get_preprocessed_data(&self.core_machine, &core_vk.vk);
let (reduce_prep_sorted_indices, reduce_prep_domains): (Vec<usize>, Vec<Domain<InnerSC>>) =
get_preprocessed_data(&self.reduce_machine, &self.reduce_vk);
let (compress_prep_sorted_indices, compress_prep_domains): (
Expand Down Expand Up @@ -517,7 +517,7 @@ impl SP1Prover {
witness_stream.extend(Hintable::write(&reduce_prep_domains));
witness_stream.extend(compress_prep_sorted_indices.write());
witness_stream.extend(Hintable::write(&compress_prep_domains));
witness_stream.extend(vk.vk.write());
witness_stream.extend(core_vk.vk.write());
witness_stream.extend(self.reduce_vk.write());
witness_stream.extend(self.compress_vk.write());
witness_stream.extend(state.committed_values_digest.write());
Expand Down Expand Up @@ -577,7 +577,11 @@ impl SP1Prover {

// Generate proof.
let start = Instant::now();
let proof = if proving_with_skinny {
let proof = if proving_with_skinny && verifying_compressed_proof {
let machine = RecursionAirSkinnyDeg7::wrap_machine(config);
let mut challenger = machine.config().challenger();
machine.prove::<LocalProver<_, _>>(pk, runtime.record.clone(), &mut challenger)
} else if proving_with_skinny {
let machine = RecursionAirSkinnyDeg7::machine(config);
let mut challenger = machine.config().challenger();
machine.prove::<LocalProver<_, _>>(pk, runtime.record.clone(), &mut challenger)
Expand Down
2 changes: 2 additions & 0 deletions recursion/circuit/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub fn verify_two_adic_pcs<C: Config>(
rounds: Vec<TwoAdicPcsRoundVariable<C>>,
) {
let alpha = challenger.sample_ext(builder);

let fri_challenges =
verify_shape_and_sample_challenges(builder, config, &proof.fri_proof, challenger);

Expand Down Expand Up @@ -161,6 +162,7 @@ pub fn verify_challenges<C: Config>(
ro,
log_max_height,
);

builder.assert_ext_eq(folded_eval, proof.final_poly);
}
}
Expand Down
24 changes: 12 additions & 12 deletions recursion/circuit/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use sp1_core::{
};
use sp1_recursion_compiler::config::OuterConfig;
use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler};
use sp1_recursion_compiler::ir::{Builder, Config, Felt};
use sp1_recursion_compiler::ir::{Builder, Config};
use sp1_recursion_compiler::ir::{Usize, Witness};
use sp1_recursion_compiler::prelude::SymbolicVar;
use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer};
Expand Down Expand Up @@ -235,38 +235,38 @@ type OuterF = <BabyBearPoseidon2Outer as StarkGenericConfig>::Val;
type OuterC = OuterConfig;

pub fn build_wrap_circuit(
vk: &StarkVerifyingKey<OuterSC>,
dummy_proof: ShardProof<OuterSC>,
template_vk: &StarkVerifyingKey<OuterSC>,
template_proof: ShardProof<OuterSC>,
) -> Vec<Constraint> {
let dev_mode = std::env::var("SP1_DEV_WRAPPER")
.unwrap_or("true".to_string())
.to_lowercase()
.eq("true");

let outer_config = OuterSC::new();
let outer_machine = RecursionAirSkinnyDeg7::<OuterF>::machine(outer_config);
let outer_machine = RecursionAirSkinnyDeg7::<OuterF>::wrap_machine(outer_config);

let mut builder = Builder::<OuterConfig>::default();
let mut challenger = MultiField32ChallengerVariable::new(&mut builder);

let preprocessed_commit_val: [Bn254Fr; 1] = vk.commit.into();
let preprocessed_commit_val: [Bn254Fr; 1] = template_vk.commit.into();
let preprocessed_commit: OuterDigestVariable<OuterC> =
[builder.eval(preprocessed_commit_val[0])];
challenger.observe_commitment(&mut builder, preprocessed_commit);
let pc_start: Felt<_> = builder.eval(vk.pc_start);
let pc_start = builder.eval(template_vk.pc_start);
challenger.observe(&mut builder, pc_start);

if !dev_mode {
let chips = outer_machine
.shard_chips_ordered(&dummy_proof.chip_ordering)
.shard_chips_ordered(&template_proof.chip_ordering)
.map(|chip| chip.name())
.collect::<Vec<_>>();

let sorted_indices = outer_machine
.chips()
.iter()
.map(|chip| {
dummy_proof
template_proof
.chip_ordering
.get(&chip.name())
.copied()
Expand All @@ -275,7 +275,7 @@ pub fn build_wrap_circuit(
.collect::<Vec<_>>();

let chip_quotient_data = outer_machine
.shard_chips_ordered(&dummy_proof.chip_ordering)
.shard_chips_ordered(&template_proof.chip_ordering)
.map(|chip| {
let log_quotient_degree = chip.log_quotient_degree();
QuotientDataValues {
Expand All @@ -286,8 +286,8 @@ pub fn build_wrap_circuit(
.collect();

let mut witness = Witness::default();
dummy_proof.write(&mut witness);
let proof = dummy_proof.read(&mut builder);
template_proof.write(&mut witness);
let proof = template_proof.read(&mut builder);

let ShardCommitment { main_commit, .. } = &proof.commitment;
challenger.observe_commitment(&mut builder, *main_commit);
Expand All @@ -300,7 +300,7 @@ pub fn build_wrap_circuit(

StarkVerifierCircuit::<OuterC, OuterSC>::verify_shard(
&mut builder,
vk,
template_vk,
&outer_machine,
&mut challenger.clone(),
&proof,
Expand Down
4 changes: 4 additions & 0 deletions recursion/compiler/src/ir/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ pub struct Witness<C: Config> {
}

impl<C: Config> Witness<C> {
pub fn size(&self) -> usize {
self.vars.len() + self.felts.len() + self.exts.len() + 2
}

pub fn set_vkey_hash(&mut self, vkey_hash: C::N) {
self.vkey_hash = vkey_hash;
}
Expand Down
3 changes: 2 additions & 1 deletion recursion/core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ pub struct CpuEvent<F> {

#[derive(Default)]
pub struct CpuChip<F> {
_phantom: std::marker::PhantomData<F>,
pub fixed_log2_rows: Option<usize>,
pub _phantom: std::marker::PhantomData<F>,
}
20 changes: 12 additions & 8 deletions recursion/core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use p3_field::{extension::BinomiallyExtendable, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sp1_core::{
air::{BinomialExtension, MachineAir},
utils::pad_rows,
utils::pad_rows_fixed,
};
use tracing::instrument;

Expand All @@ -28,7 +28,7 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for CpuChip<F> {
// There are no dependencies, since we do it all in the runtime. This is just a placeholder.
}

#[instrument(name = "generate cpu trace", level = "debug", skip_all)]
#[instrument(name = "generate cpu trace", level = "debug", skip_all, fields(rows = input.cpu_events.len()))]
fn generate_trace(
&self,
input: &ExecutionRecord<F>,
Expand Down Expand Up @@ -103,12 +103,16 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for CpuChip<F> {
})
.collect::<Vec<_>>();

pad_rows(&mut rows, || {
let mut row = [F::zero(); NUM_CPU_COLS];
let cols: &mut CpuCols<F> = row.as_mut_slice().borrow_mut();
cols.selectors.is_noop = F::one();
row
});
pad_rows_fixed(
&mut rows,
|| {
let mut row = [F::zero(); NUM_CPU_COLS];
let cols: &mut CpuCols<F> = row.as_mut_slice().borrow_mut();
cols.selectors.is_noop = F::one();
row
},
self.fixed_log2_rows,
);

let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_CPU_COLS);
Expand Down
26 changes: 16 additions & 10 deletions recursion/core/src/fri_fold/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use sp1_core::air::{BinomialExtension, MachineAir};
use sp1_core::utils::pad_to_power_of_two;
use sp1_core::utils::pad_rows_fixed;
use sp1_derive::AlignedBorrow;
use std::borrow::BorrowMut;
use tracing::instrument;
Expand All @@ -21,7 +21,9 @@ use crate::runtime::{ExecutionRecord, RecursionProgram};
pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::<FriFoldCols<u8>>();

#[derive(Default)]
pub struct FriFoldChip;
pub struct FriFoldChip {
pub fixed_log2_rows: Option<usize>,
}

#[derive(Debug, Clone)]
pub struct FriFoldEvent<F> {
Expand Down Expand Up @@ -94,16 +96,16 @@ impl<F: PrimeField32> MachineAir<F> for FriFoldChip {
// This is a no-op.
}

#[instrument(name = "generate fri fold trace", level = "debug", skip_all)]
#[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.fri_fold_events.len()))]
fn generate_trace(
&self,
input: &ExecutionRecord<F>,
_: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let trace_values = input
let mut rows = input
.fri_fold_events
.iter()
.flat_map(|event| {
.map(|event| {
let mut row = [F::zero(); NUM_FRI_FOLD_COLS];

let cols: &mut FriFoldCols<F> = row.as_mut_slice().borrow_mut();
Expand All @@ -129,15 +131,19 @@ impl<F: PrimeField32> MachineAir<F> for FriFoldChip {
.populate(&event.alpha_pow_at_log_height);
cols.ro_at_log_height.populate(&event.ro_at_log_height);

row.into_iter()
row
})
.collect_vec();

// Convert the trace to a row major matrix.
let mut trace = RowMajorMatrix::new(trace_values, NUM_FRI_FOLD_COLS);

// Pad the trace to a power of two.
pad_to_power_of_two::<NUM_FRI_FOLD_COLS, F>(&mut trace.values);
pad_rows_fixed(
&mut rows,
|| [F::zero(); NUM_FRI_FOLD_COLS],
self.fixed_log2_rows,
);

// Convert the trace to a row major matrix.
let trace = RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_FRI_FOLD_COLS);

#[cfg(debug_assertions)]
println!(
Expand Down
Loading

0 comments on commit f09a33e

Please sign in to comment.