Skip to content

Commit

Permalink
Downsize VMs when they reach a loop (#1676)
Browse files Browse the repository at this point in the history
Builds on #1666

With this PR, Witgen choses the minimum possible size of VMs (rounding
to the next power of two after first detecting a loop):

```
$ MAX_DEGREE_LOG=10 cargo run -r pil test_data/asm/vm_to_block_different_length.asm -o output -f --field bn254 --prove-with halo2-composite
...
== Proving machine: main (size 128), stage 0
Starting proof generation...
Generating PK for snark...
Generating proof...
Time taken: 59.970083ms
Proof generation done.
==> Proof stage computed in 81.618417ms
== Proving machine: main__rom (size 8), stage 0
Starting proof generation...
Generating PK for snark...
Generating proof...
Time taken: 12.319334ms
Proof generation done.
==> Proof stage computed in 29.754417ms
== Proving machine: main_arith (size 32), stage 0
Starting proof generation...
Generating PK for snark...
Generating proof...
Time taken: 20.056834ms
Proof generation done.
==> Proof stage computed in 32.017167ms
Proof generation took 0.14349297s
Proof size: 8296 bytes
```
  • Loading branch information
georgwiese authored Aug 13, 2024
1 parent 2b9984c commit 88785a8
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 13 deletions.
24 changes: 21 additions & 3 deletions executor/src/witgen/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl<'a, T: FieldElement> Generator<'a, T> {
}

fn process<'b, Q: QueryCallback<T>>(
&self,
&mut self,
first_row: Row<T>,
row_offset: DegreeType,
mutable_state: &mut MutableState<'a, 'b, T, Q>,
Expand All @@ -213,6 +213,7 @@ impl<'a, T: FieldElement> Generator<'a, T> {
);

let mut processor = VmProcessor::new(
self.name().to_string(),
RowIndex::from_degree(row_offset, self.degree),
self.fixed_data,
&self.identities,
Expand All @@ -224,7 +225,11 @@ impl<'a, T: FieldElement> Generator<'a, T> {
processor = processor.with_outer_query(outer_query);
}
let eval_value = processor.run(is_main_run);
let block = processor.finish();
let (block, degree) = processor.finish();

// The processor might have detected a loop, in which case the degree has changed
self.degree = degree;

ProcessResult { eval_value, block }
}

Expand All @@ -234,6 +239,19 @@ impl<'a, T: FieldElement> Generator<'a, T> {
assert_eq!(self.data.len() as DegreeType, self.degree + 1);

let last_row = self.data.pop().unwrap();
self.data[0].merge_with(&last_row).unwrap();
if self.data[0].merge_with(&last_row).is_err() {
log::error!(
"{}",
self.data[0].render("First row", false, &self.witnesses, self.fixed_data)
);
log::error!(
"{}",
last_row.render("Last row", false, &self.witnesses, self.fixed_data)
);
panic!(
"Failed to merge the first and last row of the VM '{}'",
self.name()
);
}
}
}
4 changes: 4 additions & 0 deletions executor/src/witgen/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> Processor<'a, 'b, 'c, T,
}
}

pub fn set_size(&mut self, size: DegreeType) {
self.size = size;
}

pub fn finished_outer_query(&self) -> bool {
self.outer_query
.as_ref()
Expand Down
40 changes: 33 additions & 7 deletions executor/src/witgen/vm_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::cmp::max;
use std::collections::HashSet;
use std::time::Instant;

use crate::constant_evaluator::MIN_DEGREE_LOG;
use crate::witgen::identity_processor::{self};
use crate::witgen::IncompleteCause;
use crate::Identity;
Expand Down Expand Up @@ -43,6 +44,8 @@ impl<'a, T: FieldElement> CompletableIdentities<'a, T> {
}

pub struct VmProcessor<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> {
/// The name of the machine being run
machine_name: String,
/// The common degree of all referenced columns
degree: DegreeType,
/// The global index of the first row of [VmProcessor::data].
Expand All @@ -64,6 +67,7 @@ pub struct VmProcessor<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> {

impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T, Q> {
pub fn new(
machine_name: String,
row_offset: RowIndex,
fixed_data: &'a FixedData<'a, T>,
identities: &[&'a Identity<T>],
Expand Down Expand Up @@ -94,6 +98,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T
);

VmProcessor {
machine_name,
degree,
row_offset: row_offset.into(),
witnesses: witnesses.clone(),
Expand All @@ -112,8 +117,8 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T
Self { processor, ..self }
}

pub fn finish(self) -> FinalizableData<T> {
self.processor.finish()
pub fn finish(self) -> (FinalizableData<T>, DegreeType) {
(self.processor.finish(), self.degree)
}

/// Starting out with a single row (at a given offset), iteratively append rows
Expand All @@ -137,9 +142,16 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T
} else {
log::Level::Debug
};
let rows_left = self.degree - self.row_offset + 1;
let mut finalize_start = 1;
for row_index in 0..rows_left {

for row_index in 0.. {
// The total number of rows to run for. Note that `self.degree` might change during
// the computation, so we need to recompute this value in each iteration.
let rows_to_run = self.degree - self.row_offset + 1;
if row_index >= rows_to_run {
break;
}

if is_main_run {
self.maybe_log_performance(row_index);
}
Expand All @@ -152,7 +164,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T
finalize_start = finalize_end;
}

if row_index >= rows_left - 2 {
if row_index >= rows_to_run - 2 {
// On th last few rows, it is quite normal for the constraints to be different,
// so we reduce the log level for loop detection.
loop_detection_log_level = log::Level::Debug;
Expand All @@ -166,14 +178,28 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T
loop_detection_log_level,
"Found loop with period {p} starting at row {row_index}"
);

if self.fixed_data.is_variable_size(&self.witnesses) {
let new_degree = self.processor.len().next_power_of_two() as DegreeType;
let new_degree = new_degree.max(1 << MIN_DEGREE_LOG);
log::info!(
"Resizing variable length machine '{}': {} -> {} (rounded up from {})",
self.machine_name,
self.degree,
new_degree,
self.processor.len()
);
self.degree = new_degree;
self.processor.set_size(new_degree);
}
}
}
if let Some(period) = looping_period {
let proposed_row = self.processor.row(row_index as usize - period).clone();
if !self.try_proposed_row(row_index, proposed_row) {
log::log!(
loop_detection_log_level,
"Looping failed. Trying to generate regularly again. (Use RUST_LOG=debug to see whether this happens more often.) {row_index} {rows_left}"
"Looping failed. Trying to generate regularly again. (Use RUST_LOG=debug to see whether this happens more often.) {row_index} / {rows_to_run}"
);
looping_period = None;
// For some programs, loop detection will often find loops and then fail.
Expand All @@ -184,7 +210,7 @@ impl<'a, 'b, 'c, T: FieldElement, Q: QueryCallback<T>> VmProcessor<'a, 'b, 'c, T
// Note that we exit one iteration early in the non-loop case,
// because ensure_has_next_row() + compute_row() will already
// add and compute some values for the next row as well.
if looping_period.is_none() && row_index != rows_left - 1 {
if looping_period.is_none() && row_index != rows_to_run - 1 {
self.ensure_has_next_row(row_index);
outer_assignments.extend(self.compute_row(row_index).into_iter());

Expand Down
5 changes: 2 additions & 3 deletions test_data/asm/vm_to_block_different_length.asm
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
machine Main with degree: 16 {
machine Main {
Arith arith;

reg pc[@pc];
Expand All @@ -21,8 +21,7 @@ machine Main with degree: 16 {

machine Arith with
latch: latch,
operation_id: operation_id,
degree: 16
operation_id: operation_id
{

operation add<0> x[0], x[1] -> y;
Expand Down

0 comments on commit 88785a8

Please sign in to comment.