Skip to content

Commit

Permalink
chore(recursion): consolidate initial and finalize memory tables (#656)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevjue authored May 3, 2024
1 parent ca54384 commit 9e0d419
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 90 deletions.
116 changes: 56 additions & 60 deletions recursion/core/src/memory/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@ use std::borrow::{Borrow, BorrowMut};
use tracing::instrument;

use super::columns::MemoryInitCols;
use crate::memory::MemoryChipKind;
use crate::memory::MemoryGlobalChip;
use crate::runtime::{ExecutionRecord, RecursionProgram};

pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();

#[allow(dead_code)]
impl MemoryGlobalChip {
pub fn new(kind: MemoryChipKind) -> Self {
pub fn new() -> Self {
Self {
kind,
fixed_log2_rows: None,
}
}
Expand All @@ -32,10 +30,7 @@ impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
type Program = RecursionProgram<F>;

fn name(&self) -> String {
match self.kind {
MemoryChipKind::Init => "MemoryInit".to_string(),
MemoryChipKind::Finalize => "MemoryFinalize".to_string(),
}
"MemoryGlobalChip".to_string()
}

fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
Expand All @@ -48,23 +43,28 @@ impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
input: &ExecutionRecord<F>,
_output: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let mut rows = match self.kind {
MemoryChipKind::Init => {
let addresses = &input.first_memory_record;
addresses
.iter()
.map(|(addr, value)| {
let mut row = [F::zero(); NUM_MEMORY_INIT_COLS];
let cols: &mut MemoryInitCols<F> = row.as_mut_slice().borrow_mut();
cols.addr = *addr;
cols.timestamp = F::zero();
cols.value = *value;
cols.is_real = F::one();
row
})
.collect::<Vec<_>>()
}
MemoryChipKind::Finalize => input
let mut rows = Vec::new();

// Fill in the initial memory records.
rows.extend(
input
.first_memory_record
.iter()
.map(|(addr, value)| {
let mut row = [F::zero(); NUM_MEMORY_INIT_COLS];
let cols: &mut MemoryInitCols<F> = row.as_mut_slice().borrow_mut();
cols.addr = *addr;
cols.timestamp = F::zero();
cols.value = *value;
cols.is_initialize = F::one();
row
})
.collect::<Vec<_>>(),
);

// Fill in the finalize memory records.
rows.extend(
input
.last_memory_record
.iter()
.map(|(addr, timestamp, value)| {
Expand All @@ -73,11 +73,11 @@ impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
cols.addr = *addr;
cols.timestamp = *timestamp;
cols.value = *value;
cols.is_real = F::one();
cols.is_finalize = F::one();
row
})
.collect::<Vec<_>>(),
};
);

// Pad the trace to a power of two.
pad_rows_fixed(
Expand All @@ -93,10 +93,7 @@ impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
}

fn included(&self, shard: &Self::Record) -> bool {
match self.kind {
MemoryChipKind::Init => !shard.first_memory_record.is_empty(),
MemoryChipKind::Finalize => !shard.last_memory_record.is_empty(),
}
!shard.first_memory_record.is_empty() || !shard.last_memory_record.is_empty()
}
}

Expand All @@ -115,35 +112,34 @@ where
let local = main.row_slice(0);
let local: &MemoryInitCols<AB::Var> = (*local).borrow();

match self.kind {
MemoryChipKind::Init => {
builder.send(AirInteraction::new(
vec![
local.timestamp.into(),
local.addr.into(),
local.value[0].into(),
local.value[1].into(),
local.value[2].into(),
local.value[3].into(),
],
local.is_real.into(),
InteractionKind::Memory,
));
}
MemoryChipKind::Finalize => {
builder.receive(AirInteraction::new(
vec![
local.timestamp.into(),
local.addr.into(),
local.value[0].into(),
local.value[1].into(),
local.value[2].into(),
local.value[3].into(),
],
local.is_real.into(),
InteractionKind::Memory,
));
}
};
// Verify that is_initilize and is_finalize are bool and that at most one is true.
builder.assert_bool(local.is_initialize);
builder.assert_bool(local.is_finalize);
builder.assert_bool(local.is_initialize + local.is_finalize);

builder.send(AirInteraction::new(
vec![
local.timestamp.into(),
local.addr.into(),
local.value[0].into(),
local.value[1].into(),
local.value[2].into(),
local.value[3].into(),
],
local.is_initialize.into(),
InteractionKind::Memory,
));
builder.receive(AirInteraction::new(
vec![
local.timestamp.into(),
local.addr.into(),
local.value[0].into(),
local.value[1].into(),
local.value[2].into(),
local.value[3].into(),
],
local.is_finalize.into(),
InteractionKind::Memory,
));
}
}
3 changes: 2 additions & 1 deletion recursion/core/src/memory/columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ pub struct MemoryInitCols<T> {
pub addr: T,
pub timestamp: T,
pub value: Block<T>,
pub is_real: T,
pub is_initialize: T,
pub is_finalize: T,
}

/// NOTE: These are very similar to core/src/memory/columns.rs
Expand Down
9 changes: 1 addition & 8 deletions recursion/core/src/memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,7 @@ impl<F: PrimeField32, TValue> MemoryAccessCols<F, TValue> {
}
}

#[allow(dead_code)]
#[derive(PartialEq)]
pub enum MemoryChipKind {
Init,
Finalize,
}

#[derive(Default)]
pub struct MemoryGlobalChip {
pub fixed_log2_rows: Option<usize>,
pub kind: MemoryChipKind,
}
26 changes: 5 additions & 21 deletions recursion/core/src/stark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@ pub mod poseidon2;
pub mod utils;

use crate::{
cpu::CpuChip,
fri_fold::FriFoldChip,
memory::{MemoryChipKind, MemoryGlobalChip},
multi::MultiChip,
poseidon2::Poseidon2Chip,
poseidon2_wide::Poseidon2WideChip,
program::ProgramChip,
cpu::CpuChip, fri_fold::FriFoldChip, memory::MemoryGlobalChip, multi::MultiChip,
poseidon2::Poseidon2Chip, poseidon2_wide::Poseidon2WideChip, program::ProgramChip,
range_check::RangeCheckChip,
};
use core::iter::once;
Expand All @@ -31,8 +26,7 @@ pub type RecursionAirSkinnyDeg7<F> = RecursionAir<F, 5>;
pub enum RecursionAir<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> {
Program(ProgramChip),
Cpu(CpuChip<F>),
MemoryInit(MemoryGlobalChip),
MemoryFinalize(MemoryGlobalChip),
MemoryGlobal(MemoryGlobalChip),
Poseidon2Wide(Poseidon2WideChip<DEGREE>),
Poseidon2Skinny(Poseidon2Chip),
FriFold(FriFoldChip),
Expand Down Expand Up @@ -65,12 +59,7 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAi
fixed_log2_rows: None,
_phantom: PhantomData,
})))
.chain(once(RecursionAir::MemoryInit(MemoryGlobalChip {
kind: MemoryChipKind::Init,
fixed_log2_rows: None,
})))
.chain(once(RecursionAir::MemoryFinalize(MemoryGlobalChip {
kind: MemoryChipKind::Finalize,
.chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip {
fixed_log2_rows: None,
})))
.chain(once(RecursionAir::Poseidon2Wide(Poseidon2WideChip::<
Expand All @@ -91,12 +80,7 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAi
fixed_log2_rows: Some(20),
_phantom: PhantomData,
})))
.chain(once(RecursionAir::MemoryInit(MemoryGlobalChip {
kind: MemoryChipKind::Init,
fixed_log2_rows: Some(18),
})))
.chain(once(RecursionAir::MemoryFinalize(MemoryGlobalChip {
kind: MemoryChipKind::Finalize,
.chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip {
fixed_log2_rows: Some(18),
})))
.chain(once(RecursionAir::Multi(MultiChip {
Expand Down

0 comments on commit 9e0d419

Please sign in to comment.