Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto witgen #2071

Draft
wants to merge 18 commits into
base: call_jit_from_block
Choose a base branch
from
1 change: 1 addition & 0 deletions executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ bit-vec = "0.6.3"
num-traits = "0.2.15"
derive_more = "0.99.17"
lazy_static = "1.4.0"
libloading = "0.8"
indicatif = "0.17.7"
serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] }

Expand Down
147 changes: 136 additions & 11 deletions executor/src/witgen/data_structures/finalizable_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,98 @@ pub struct CompactData<T: FieldElement> {
/// The cell values, stored in row-major order.
data: Vec<T>,
/// Bit vector of known cells, stored in row-major order.
known_cells: BitVec,
known_cells: PaddedBitVec,
}

#[derive(Clone)]
pub struct PaddedBitVec {
data: Vec<u32>,
// TODO check if we really need all of those.
bits_per_row: usize,
words_per_row: usize,
rows: usize,
bits_in_last_row: usize,
}

impl PaddedBitVec {
fn new(bits_per_row: usize) -> Self {
let words_per_row = (bits_per_row + 31) / 32;
Self {
data: Vec::new(),
bits_per_row,
words_per_row,
rows: 0,
bits_in_last_row: 0,
}
}

fn truncate_to_rows(&mut self, len: usize) {
assert!(len <= self.rows);
self.data.truncate(len * self.words_per_row);
self.rows = len;
self.bits_in_last_row = 0;
}

fn clear(&mut self) {
self.data.clear();
self.rows = 0;
self.bits_in_last_row = 0;
}

fn reserve_rows(&mut self, count: usize) {
self.data.reserve(count * self.words_per_row);
}

fn push(&mut self, value: bool) {
if self.bits_in_last_row == 0 {
self.data.push(value as u32);
self.rows += 1;
} else {
let last_word = self.data.last_mut().unwrap();
if value {
*last_word |= 1 << (self.bits_in_last_row - 1);
}
}
self.bits_in_last_row = (self.bits_in_last_row + 1) % self.bits_per_row;
}

fn append_empty_rows(&mut self, count: usize) {
assert!(self.bits_in_last_row == 0);
self.data
.resize(self.data.len() + count * self.words_per_row, 0);
self.rows += count;
}

fn get(&self, row: usize, col: u64) -> bool {
let word = &self.data[row * self.words_per_row + (col / 32) as usize];
(word & (1 << (col % 32))) != 0
}

fn set(&mut self, row: usize, col: u64, value: bool) {
let word = &mut self.data[row * self.words_per_row + (col / 32) as usize];
if value {
*word |= 1 << (col % 32);
} else {
*word &= !(1 << (col % 32));
}
}

fn mut_slice(&mut self) -> &mut [u32] {
self.data.as_mut_slice()
}
}

impl<T: FieldElement> CompactData<T> {
/// Creates a new empty compact data storage.
pub fn new(column_ids: &[PolyID]) -> Self {
let col_id_range = column_ids.iter().map(|id| id.id).minmax();
let (first_column_id, last_column_id) = col_id_range.into_option().unwrap();
let column_count = (last_column_id - first_column_id + 1) as usize;
Self {
first_column_id,
column_count: (last_column_id - first_column_id + 1) as usize,
column_count,
data: Vec::new(),
known_cells: BitVec::new(),
known_cells: PaddedBitVec::new(column_count),
}
}

Expand All @@ -49,7 +128,7 @@ impl<T: FieldElement> CompactData<T> {
/// Truncates the data to `len` rows.
pub fn truncate(&mut self, len: usize) {
self.data.truncate(len * self.column_count);
self.known_cells.truncate(len * self.column_count);
self.known_cells.truncate_to_rows(len);
}

pub fn clear(&mut self) {
Expand All @@ -60,7 +139,7 @@ impl<T: FieldElement> CompactData<T> {
/// Appends a non-finalized row to the data, turning it into a finalized row.
pub fn push(&mut self, row: Row<T>) {
self.data.reserve(self.column_count);
self.known_cells.reserve(self.column_count);
self.known_cells.reserve_rows(1);
for col_id in self.first_column_id..(self.first_column_id + self.column_count as u64) {
if let Some(v) = row.value(&PolyID {
id: col_id,
Expand All @@ -78,8 +157,7 @@ impl<T: FieldElement> CompactData<T> {
pub fn append_new_rows(&mut self, count: usize) {
self.data
.resize(self.data.len() + count * self.column_count, T::zero());
self.known_cells
.grow(self.known_cells.len() + count * self.column_count, false);
self.known_cells.append_empty_rows(count);
}

fn index(&self, row: usize, col: u64) -> usize {
Expand All @@ -89,25 +167,36 @@ impl<T: FieldElement> CompactData<T> {

pub fn get(&self, row: usize, col: u64) -> (T, bool) {
let idx = self.index(row, col);
(self.data[idx], self.known_cells[idx])
(
self.data[idx],
self.known_cells.get(row, col - self.first_column_id),
)
}

pub fn set(&mut self, row: usize, col: u64, value: T) {
let idx = self.index(row, col);
assert!(!self.known_cells[idx] || self.data[idx] == value);
assert!(!self.known_cells.get(row, col - self.first_column_id) || self.data[idx] == value);
self.data[idx] = value;
self.known_cells.set(idx, true);
self.known_cells.set(row, col - self.first_column_id, true);
}

pub fn set_known(&mut self, row: usize, col: u64) {
self.known_cells.set(row, col - self.first_column_id, true);
}

pub fn known_values_in_row(&self, row: usize) -> impl Iterator<Item = (u64, &T)> {
(0..self.column_count).filter_map(move |i| {
let idx = row * self.column_count + i;
self.known_cells[idx].then(|| {
self.known_cells.get(row, i as u64).then(|| {
let col_id = self.first_column_id + i as u64;
(col_id, &self.data[idx])
})
})
}

pub fn data_slice(&mut self) -> (&mut [T], &mut [u32]) {
(self.data.as_mut_slice(), self.known_cells.mut_slice())
}
}

/// A mutable reference into CompactData that is meant to be used
Expand All @@ -132,12 +221,27 @@ impl<'a, T: FieldElement> CompactDataRef<'a, T> {
}

pub fn set(&mut self, row: i32, col: u32, value: T) {
//println!("outer row {row} is inner row {}", self.inner_row(row));
self.data.set(self.inner_row(row), col as u64, value);
}

pub fn set_known(&mut self, row: i32, col: u32) {
self.data.set_known(self.inner_row(row), col as u64);
}

fn inner_row(&self, row: i32) -> usize {
(row + self.row_offset as i32) as usize
}

pub fn direct_slice(&mut self) -> (&mut [T], &mut [u32], usize) {
// println!(
// "Extracting slice at row offset {}, total length: {}",
// self.row_offset,
// self.data.len()
// );
let (data, known) = self.data.data_slice();
(data, known, self.row_offset)
}
}

/// A data structure that stores witness data.
Expand Down Expand Up @@ -333,6 +437,25 @@ impl<T: FieldElement> FinalizableData<T> {
(current, next)
}

pub fn set_row(&mut self, i: usize, row: Row<T>) {
match self.location_of_row(i) {
Location::PreFinalized(local) => {
self.pre_finalized_data[local] = row;
}
Location::Finalized(local) => {
for poly_id in &self.column_ids {
// TODO this ignores values that have been unset in `row`.
if let Some(v) = row.value(poly_id) {
self.finalized_data.set(local, poly_id.id, v);
}
}
}
Location::PostFinalized(local) => {
self.post_finalized_data[local] = row;
}
}
}

pub fn finalize_range(&mut self, range: std::ops::Range<usize>) -> usize {
if range.is_empty() {
return 0;
Expand Down Expand Up @@ -442,6 +565,8 @@ impl<T: FieldElement> FinalizableData<T> {
}
}

// TODO are these here stil used?

impl<T: FieldElement> Index<usize> for FinalizableData<T> {
type Output = Row<T>;

Expand Down
26 changes: 21 additions & 5 deletions executor/src/witgen/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::witgen::EvalValue;
use super::affine_expression::AlgebraicVariable;
use super::block_processor::BlockProcessor;
use super::data_structures::multiplicity_counter::MultiplicityCounter;
use super::jit::jit_processor::{self, JitProcessor};
use super::machines::{Machine, MachineParts};
use super::processor::SolverState;
use super::rows::{Row, RowIndex, RowPair};
Expand All @@ -31,6 +32,9 @@ pub struct Generator<'a, T: FieldElement> {
name: String,
degree: DegreeType,
multiplicity_counter: MultiplicityCounter,
/// The JIT processor for this machine, i.e. the component that tries to generate
/// witgen code based on which elements of the connection are known.
jit_processor: JitProcessor<'a, T>,
}

impl<'a, T: FieldElement> Machine<'a, T> for Generator<'a, T> {
Expand Down Expand Up @@ -119,6 +123,12 @@ impl<'a, T: FieldElement> Generator<'a, T> {
let data = FinalizableData::new(&parts.witnesses);
let multiplicity_counter = MultiplicityCounter::new(&parts.connections);

let jit_processor = JitProcessor::new(
fixed_data,
parts.clone(),
1, // block size
0, // latch row
);
Self {
degree: parts.common_degree_range().max,
name,
Expand All @@ -128,18 +138,24 @@ impl<'a, T: FieldElement> Generator<'a, T> {
publics: Default::default(),
latch,
multiplicity_counter,
jit_processor,
}
}

/// Runs the machine without any arguments from the first row.
pub fn run<'b, Q: QueryCallback<T>>(&mut self, mutable_state: &mut MutableState<'a, 'b, T, Q>) {
record_start(self.name());
assert!(self.data.is_empty());
let first_row = self.compute_partial_first_row(mutable_state);
self.data = self
.process(first_row, 0, mutable_state, None, true)
.updated_data
.block;
if self.jit_processor.can_run() {
//self.jit_processor.run();
unimplemented!()
} else {
let first_row = self.compute_partial_first_row(mutable_state);
self.data = self
.process(first_row, 0, mutable_state, None, true)
.updated_data
.block;
}
record_end(self.name());
}

Expand Down
31 changes: 29 additions & 2 deletions executor/src/witgen/identity_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ use crate::{
};

use super::{
affine_expression::AlgebraicVariable, machines::KnownMachine, processor::OuterQuery,
rows::RowPair, EvalResult, EvalValue, IncompleteCause, MutableState, QueryCallback,
affine_expression::AlgebraicVariable,
machines::{KnownMachine, LookupCell},
processor::OuterQuery,
rows::RowPair,
EvalResult, EvalValue, IncompleteCause, MutableState, QueryCallback,
};

/// A list of mutable references to machines.
Expand Down Expand Up @@ -88,6 +91,30 @@ impl<'a, 'b, T: FieldElement> Machines<'a, 'b, T> {
current.process_plookup_timed(&mut mutable_state, identity_id, caller_rows)
}

pub fn call_direct<Q: QueryCallback<T>>(
&mut self,
identity_id: u64,
values: Vec<LookupCell<'_, T>>,
query_callback: &mut Q,
) -> Result<bool, EvalError<T>> {
let machine_index = *self
.identity_to_machine_index
.get(&identity_id)
.unwrap_or_else(|| panic!("No executor machine matched identity ID: {identity_id}"));

// TOOD this has horrible performance, avoid this.
// It's probably much better to use runtime borrow checks.
// This will fail if we have circular calls.
// At some point, we can probably turn this into an async message passing interface.
let (current, others) = self.split(machine_index);

let mut mutable_state = MutableState {
machines: others,
query_callback,
};
current.process_lookup_direct(&mut mutable_state, identity_id, values)
}

pub fn take_witness_col_values<Q: QueryCallback<T>>(
&mut self,
query_callback: &mut Q,
Expand Down
48 changes: 48 additions & 0 deletions executor/src/witgen/jit/cell.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::hash::{Hash, Hasher};

use powdr_ast::analyzed::AlgebraicReference;

/// The identifier of a cell in the trace table, relative to a "zero row".
#[derive(Debug, Clone, Eq)]
pub struct Cell {
pub column_name: String,
pub id: u64,
pub row_offset: i32,
}

impl Hash for Cell {
fn hash<H: Hasher>(&self, state: &mut H) {
self.id.hash(state);
self.row_offset.hash(state);
}
}

impl PartialEq for Cell {
fn eq(&self, other: &Self) -> bool {
self.id == other.id && self.row_offset == other.row_offset
}
}

impl Ord for Cell {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
(self.id, self.row_offset).cmp(&(other.id, other.row_offset))
}
}

impl PartialOrd for Cell {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Cell {
pub fn from_reference(r: &AlgebraicReference, offset: i32) -> Self {
assert!(r.is_witness());
let row_offset = r.next as i32 + offset;
Self {
column_name: r.name.clone(),
id: r.poly_id.id,
row_offset,
}
}
}
Loading