From 9425a9af49953a8edaad600521bdbf34e365d3b1 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 21 May 2024 15:53:39 -0700 Subject: [PATCH] test --- prover/src/lib.rs | 159 +++++++++++++++++++++++++--------------------- 1 file changed, 88 insertions(+), 71 deletions(-) diff --git a/prover/src/lib.rs b/prover/src/lib.rs index eabd8270ef..6e5de60dbd 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -390,7 +390,6 @@ impl SP1Prover { // Run the recursion and reduce programs. // Run the recursion programs. - let mut records = Vec::new(); let (core_inputs, deferred_inputs) = self.get_first_layer_inputs( vk, @@ -400,62 +399,71 @@ impl SP1Prover { batch_size, ); - for input in core_inputs { - let mut runtime = RecursionRuntime::, Challenge, _>::new( - &self.recursion_program, - self.compress_machine.config().perm.clone(), - ); - - let mut witness_stream = Vec::new(); - witness_stream.extend(input.write()); - - runtime.witness_stream = witness_stream.into(); - runtime.run(); - runtime.print_stats(); + let mut first_layer_proofs = Vec::new(); + let shard_batch_size = sp1_core::utils::env::shard_batch_size() as usize; + for inputs in core_inputs.chunks(shard_batch_size) { + let proofs = inputs + .into_par_iter() + .map(|input| { + let mut runtime = RecursionRuntime::, Challenge, _>::new( + &self.recursion_program, + self.compress_machine.config().perm.clone(), + ); - records.push((runtime.record, ReduceProgramType::Core)); + let mut witness_stream = Vec::new(); + witness_stream.extend(input.write()); + + runtime.witness_stream = witness_stream.into(); + runtime.run(); + runtime.print_stats(); + + let pk = &self.rec_pk; + let mut recursive_challenger = self.compress_machine.config().challenger(); + ( + self.compress_machine.prove::>( + pk, + runtime.record, + &mut recursive_challenger, + ), + ReduceProgramType::Core, + ) + }) + .collect::>(); + first_layer_proofs.extend(proofs); } // Run the deferred proofs programs. - for input in deferred_inputs { - let mut runtime = RecursionRuntime::, Challenge, _>::new( - &self.deferred_program, - self.compress_machine.config().perm.clone(), - ); - - let mut witness_stream = Vec::new(); - witness_stream.extend(input.write()); - - runtime.witness_stream = witness_stream.into(); - runtime.run(); - runtime.print_stats(); + for inputs in deferred_inputs.chunks(shard_batch_size) { + let proofs = inputs + .into_par_iter() + .map(|input| { + let mut runtime = RecursionRuntime::, Challenge, _>::new( + &self.deferred_program, + self.compress_machine.config().perm.clone(), + ); - records.push((runtime.record, ReduceProgramType::Deferred)); + let mut witness_stream = Vec::new(); + witness_stream.extend(input.write()); + + runtime.witness_stream = witness_stream.into(); + runtime.run(); + runtime.print_stats(); + + let pk = &self.deferred_pk; + let mut recursive_challenger = self.compress_machine.config().challenger(); + ( + self.compress_machine.prove::>( + pk, + runtime.record, + &mut recursive_challenger, + ), + ReduceProgramType::Deferred, + ) + }) + .collect::>(); + first_layer_proofs.extend(proofs); } - // Prove all recursion programs and recursion deferred programs and verify the proofs. - - // Make the recursive proofs for core and deferred proofs. - let first_layer_proofs = records - .into_par_iter() - .map(|(record, kind)| { - let pk = match kind { - ReduceProgramType::Core => &self.rec_pk, - ReduceProgramType::Deferred => &self.deferred_pk, - ReduceProgramType::Reduce => unreachable!(), - }; - let mut recursive_challenger = self.compress_machine.config().challenger(); - ( - self.compress_machine.prove::>( - pk, - record, - &mut recursive_challenger, - ), - kind, - ) - }) - .collect::>(); - // Chain all the individual shard proofs. let mut reduce_proofs = first_layer_proofs .into_iter() @@ -467,28 +475,37 @@ impl SP1Prover { loop { tracing::debug!("Recursive proof layer size: {}", reduce_proofs.len()); is_complete = reduce_proofs.len() <= batch_size; - reduce_proofs = reduce_proofs - .par_chunks(batch_size) - .map(|batch| { - let (shard_proofs, kinds) = - batch.iter().cloned().unzip::<_, _, Vec<_>, Vec<_>>(); - - let input = SP1ReduceMemoryLayout { - compress_vk: &self.compress_vk, - recursive_machine: &self.compress_machine, - shard_proofs, - kinds, - is_complete, - }; - - let proof = self.compress_machine_proof( - input, - &self.compress_program, - &self.compress_pk, - ); - (proof, ReduceProgramType::Reduce) + + let compress_inputs = reduce_proofs.chunks(batch_size).collect::>(); + let batched_compress_inputs = + compress_inputs.chunks(shard_batch_size).collect::>(); + reduce_proofs = batched_compress_inputs + .into_iter() + .flat_map(|batches| { + batches + .par_iter() + .map(|batch| { + let (shard_proofs, kinds) = + batch.iter().cloned().unzip::<_, _, Vec<_>, Vec<_>>(); + + let input = SP1ReduceMemoryLayout { + compress_vk: &self.compress_vk, + recursive_machine: &self.compress_machine, + shard_proofs, + kinds, + is_complete, + }; + + let proof = self.compress_machine_proof( + input, + &self.compress_program, + &self.compress_pk, + ); + (proof, ReduceProgramType::Reduce) + }) + .collect::>() }) - .collect(); + .collect::>(); if reduce_proofs.len() == 1 { break;