Skip to content

Commit

Permalink
feat: multi-threaded tracing (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Jul 18, 2024
1 parent 6e92b1b commit b17f86e
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 119 deletions.
12 changes: 7 additions & 5 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,14 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
let chips = self.chips();
records.iter_mut().for_each(|record| {
chips.iter().for_each(|chip| {
let mut output = A::Record::default();
chip.generate_dependencies(record, &mut output);
record.append(&mut output);
tracing::debug_span!("chip dependencies", chip = chip.name()).in_scope(|| {
let mut output = A::Record::default();
chip.generate_dependencies(record, &mut output);
record.append(&mut output);
});
});
record.register_nonces(opts);
});
tracing::debug_span!("register nonces").in_scope(|| record.register_nonces(opts));
})
}

pub const fn config(&self) -> &SC {
Expand Down
253 changes: 139 additions & 114 deletions core/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ where
let (pk, vk) = prover.setup(runtime.program.as_ref());

// Execute the program, saving checkpoints at the start of every `shard_batch_size` cycle range.
let make_checkpoint_span = tracing::debug_span!("Execute and save checkpoints").entered();
let create_checkpoints_span = tracing::debug_span!("create checkpoints").entered();
let mut checkpoints = Vec::new();
let (public_values_stream, public_values) = loop {
// Execute the runtime until we reach a checkpoint.
Expand All @@ -174,94 +174,109 @@ where
);
}
};
make_checkpoint_span.exit();
create_checkpoints_span.exit();

// Commit to the shards.
#[cfg(debug_assertions)]
let mut debug_records: Vec<ExecutionRecord> = Vec::new();

let commit_span = tracing::debug_span!("Commit to shards").entered();
let mut deferred = ExecutionRecord::new(program.clone().into());
let mut state = public_values.reset();
let nb_checkpoints = checkpoints.len();
let mut challenger = prover.config().challenger();
vk.observe_into(&mut challenger);

let scope_span = tracing::Span::current().clone();
std::thread::scope(move |s| {
let _span = scope_span.enter();

// Spawn a thread for committing to the shards.
let span = tracing::Span::current().clone();
let (records_tx, records_rx) =
sync_channel::<Vec<ExecutionRecord>>(opts.commit_stream_capacity);
let challenger_handle = s.spawn(move || {
for records in records_rx.iter() {
let commitments = records
.par_iter()
.map(|record| prover.commit(record))
.collect::<Vec<_>>();
for (commit, record) in commitments.into_iter().zip(records) {
prover.update(
&mut challenger,
commit,
&record.public_values::<SC::Val>()[0..prover.machine().num_pv_elts()],
);
let _span = span.enter();
tracing::debug_span!("phase 1 commiter").in_scope(|| {
for records in records_rx.iter() {
let commitments = tracing::debug_span!("batch").in_scope(|| {
let span = tracing::Span::current().clone();
records
.par_iter()
.map(|record| {
let _span = span.enter();
prover.commit(record)
})
.collect::<Vec<_>>()
});
for (commit, record) in commitments.into_iter().zip(records) {
prover.update(
&mut challenger,
commit,
&record.public_values::<SC::Val>()[0..prover.machine().num_pv_elts()],
);
}
}
}
});

challenger
});

for (checkpoint_idx, checkpoint_file) in checkpoints.iter_mut().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, _) = trace_checkpoint(program.clone(), checkpoint_file, opts);
reset_seek(&mut *checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}
tracing::debug_span!("phase 1 record generator").in_scope(|| {
for (checkpoint_idx, checkpoint_file) in checkpoints.iter_mut().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, _) = tracing::debug_span!("trace checkpoint")
.in_scope(|| trace_checkpoint(program.clone(), checkpoint_file, opts));
reset_seek(&mut *checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}

// Generate the dependencies.
tracing::debug_span!("Generate dependencies", checkpoint_idx = checkpoint_idx)
.in_scope(|| prover.machine().generate_dependencies(&mut records, &opts));
// Generate the dependencies.
tracing::debug_span!("generate dependencies")
.in_scope(|| prover.machine().generate_dependencies(&mut records, &opts));

// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}
// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}

// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);
// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);

// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);
// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);

#[cfg(debug_assertions)]
{
debug_records.extend(records.clone());
}
#[cfg(debug_assertions)]
{
debug_records.extend(records.clone());
}

records_tx.send(records).unwrap();
}
records_tx.send(records).unwrap();
}
});
drop(records_tx);
let challenger = challenger_handle.join().unwrap();
commit_span.exit();

// Debug the constraints if debug assertions are enabled.
#[cfg(debug_assertions)]
Expand All @@ -279,65 +294,76 @@ where
let (records_tx, records_rx) =
sync_channel::<Vec<ExecutionRecord>>(opts.prove_stream_capacity);

let commit_and_open = tracing::Span::current().clone();
let shard_proofs = s.spawn(move || {
let _span = commit_and_open.enter();
let mut shard_proofs = Vec::new();
for records in records_rx.iter() {
shard_proofs.par_extend(records.into_par_iter().map(|record| {
prover
.commit_and_open(&pk, record, &mut challenger.clone())
.unwrap()
}));
}
tracing::debug_span!("phase 2 prover").in_scope(|| {
for records in records_rx.iter() {
tracing::debug_span!("batch").in_scope(|| {
let span = tracing::Span::current().clone();
shard_proofs.par_extend(records.into_par_iter().map(|record| {
let _span = span.enter();
prover
.commit_and_open(&pk, record, &mut challenger.clone())
.unwrap()
}));
});
}
});
shard_proofs
});

// let mut shard_proofs = Vec::new();
for (checkpoint_idx, mut checkpoint_file) in checkpoints.into_iter().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, report) = trace_checkpoint(program.clone(), &checkpoint_file, opts);
report_aggregate += report;
reset_seek(&mut checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}
tracing::debug_span!("phase 2 record generator").in_scope(|| {
for (checkpoint_idx, mut checkpoint_file) in checkpoints.into_iter().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, report) = tracing::debug_span!("trace checkpoint")
.in_scope(|| trace_checkpoint(program.clone(), &checkpoint_file, opts));
report_aggregate += report;
reset_seek(&mut checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}

// Generate the dependencies.
prover.machine().generate_dependencies(&mut records, &opts);
// Generate the dependencies.
tracing::debug_span!("generate dependencies")
.in_scope(|| prover.machine().generate_dependencies(&mut records, &opts));

// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}
// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}

// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);
// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);

// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);
// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);

records_tx.send(records).unwrap();
}
records_tx.send(records).unwrap();
}
});
drop(records_tx);
let shard_proofs = shard_proofs.join().unwrap();

Expand Down Expand Up @@ -381,7 +407,7 @@ pub fn run_test_io<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
program: Program,
inputs: SP1Stdin,
) -> Result<SP1PublicValues, crate::stark::MachineVerificationError<BabyBearPoseidon2>> {
let runtime = tracing::info_span!("runtime.run(...)").in_scope(|| {
let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
let mut runtime = Runtime::new(program, SP1CoreOpts::default());
runtime.write_vecs(&inputs.buffer);
runtime.run().unwrap();
Expand All @@ -398,7 +424,7 @@ pub fn run_test<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
crate::stark::MachineProof<BabyBearPoseidon2>,
crate::stark::MachineVerificationError<BabyBearPoseidon2>,
> {
let runtime = tracing::info_span!("runtime.run(...)").in_scope(|| {
let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
let mut runtime = Runtime::new(program, SP1CoreOpts::default());
runtime.run().unwrap();
runtime
Expand Down Expand Up @@ -483,8 +509,7 @@ fn trace_checkpoint(
// We already passed the deferred proof verifier when creating checkpoints, so the proofs were
// already verified. So here we use a noop verifier to not print any warnings.
runtime.subproof_verifier = Arc::new(NoOpSubproofVerifier);
let (events, _) =
tracing::debug_span!("runtime.trace").in_scope(|| runtime.execute_record().unwrap());
let (events, _) = runtime.execute_record().unwrap();
(events, runtime.report)
}

Expand Down

0 comments on commit b17f86e

Please sign in to comment.