diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index c30a59c4f4..b1ef70b93f 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -412,14 +412,6 @@ where .when(local.c_sign_extend) .assert_eq(local.c_msb, one.clone()); - // If the opcode doesn't allow sign extension for an operand, we must not extend their sign. - builder - .when(local.is_mul + local.is_mulhu) - .assert_zero(local.b_sign_extend + local.c_sign_extend); - builder - .when(local.is_mul + local.is_mulhsu + local.is_mulhsu) - .assert_zero(local.c_sign_extend); - // Calculate the opcode. let opcode = { // Exactly one of the op codes must be on. diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 8f9ea721e5..58b92086b8 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -348,6 +348,16 @@ where .assert_eq(num_bits_to_shift.clone(), AB::F::from_canonical_usize(i)); } + // Bool check c_least_sig_bytes elements. + for bit in local.c_least_sig_byte.iter() { + builder.assert_bool(*bit); + } + + // Bool check shift_by_n_bits elements. + for shift in local.shift_by_n_bits.iter() { + builder.assert_bool(*shift); + } + // Exactly one of the shift_by_n_bits must be 1. builder.assert_eq( local @@ -485,6 +495,8 @@ where builder.assert_bool(local.is_sra); builder.assert_bool(local.is_real); + builder.assert_eq(local.is_srl + local.is_sra, local.is_real); + // Receive the arguments. 
builder.receive_alu( local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32) diff --git a/core/src/cpu/air/branch.rs b/core/src/cpu/air/branch.rs index fad654de35..0300420663 100644 --- a/core/src/cpu/air/branch.rs +++ b/core/src/cpu/air/branch.rs @@ -3,6 +3,7 @@ use p3_field::AbstractField; use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, OpcodeSelectorCols}; +use crate::operations::BabyBearWord; use crate::{cpu::CpuChip, runtime::Opcode}; impl CpuChip { @@ -42,26 +43,38 @@ impl CpuChip { // When we are branching, assert local.pc <==> branch_cols.pc as Word. builder .when(local.branching) - .assert_eq(branch_cols.pc.reduce::(), local.pc); + .assert_eq(branch_cols.pc.value.reduce::(), local.pc); // When we are branching, assert that next.pc <==> branch_columns.next_pc as Word. builder .when_transition() .when(next.is_real) .when(local.branching) - .assert_eq(branch_cols.next_pc.reduce::(), next.pc); + .assert_eq(branch_cols.next_pc.value.reduce::(), next.pc); // When the current row is real and local.branching, assert that local.next_pc <==> branch_columns.next_pc as Word. builder .when(local.is_real) .when(local.branching) - .assert_eq(branch_cols.next_pc.reduce::(), local.next_pc); + .assert_eq(branch_cols.next_pc.value.reduce::(), local.next_pc); + + // Range check branch_cols.pc and branch_cols.next_pc + BabyBearWord::::range_check( + builder, + branch_cols.pc, + is_branch_instruction.clone(), + ); + BabyBearWord::::range_check( + builder, + branch_cols.next_pc, + is_branch_instruction.clone(), + ); // When we are branching, calculate branch_cols.next_pc <==> branch_cols.pc + c. 
builder.send_alu( Opcode::ADD.as_field::(), - branch_cols.next_pc, - branch_cols.pc, + branch_cols.next_pc.value, + branch_cols.pc.value, local.op_c_val(), local.shard, local.channel, diff --git a/core/src/cpu/air/ecall.rs b/core/src/cpu/air/ecall.rs index 506b2c7b75..6fb631cd4f 100644 --- a/core/src/cpu/air/ecall.rs +++ b/core/src/cpu/air/ecall.rs @@ -35,10 +35,14 @@ impl CpuChip { let syscall_id = syscall_code[0]; let send_to_table = syscall_code[1]; - // When is_ecall_instruction == true AND sent_to_table == true, ecall_mul_send_to_table should be true. + // When is_ecall_instruction == true AND send_to_table == true, ecall_mul_send_to_table should be true. + builder.assert_bool(local.ecall_mul_send_to_table); builder .when(is_ecall_instruction.clone()) - .assert_eq(send_to_table, local.ecall_mul_send_to_table); + .assert_bool(send_to_table); + builder + .when(local.ecall_mul_send_to_table) + .assert_one(send_to_table * is_ecall_instruction.clone()); builder.send_syscall( local.shard, local.channel, diff --git a/core/src/cpu/air/memory.rs b/core/src/cpu/air/memory.rs index 6ac1a07c11..5b9dae515e 100644 --- a/core/src/cpu/air/memory.rs +++ b/core/src/cpu/air/memory.rs @@ -88,6 +88,35 @@ impl CpuChip { memory_columns.addr_word.reduce::(), ); + // Verify that the least significant byte of addr_word - addr_offset is divisible by 4. 
+ let offset = [ + memory_columns.offset_is_one, + memory_columns.offset_is_two, + memory_columns.offset_is_three, + ] + .iter() + .enumerate() + .fold(AB::Expr::zero(), |acc, (index, &value)| { + acc + AB::Expr::from_canonical_usize(index + 1) * value + }); + let mut recomposed_byte = AB::Expr::zero(); + memory_columns + .aa_least_sig_byte_decomp + .iter() + .enumerate() + .for_each(|(i, value)| { + builder + .when(is_memory_instruction.clone()) + .assert_bool(*value); + + recomposed_byte = + recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << (i + 2)) * *value; + }); + + builder + .when(is_memory_instruction.clone()) + .assert_eq(memory_columns.addr_word[0] - offset, recomposed_byte); + // For operations that require reading from memory (not registers), we need to read the // value into the memory columns. builder.eval_memory_access( @@ -98,6 +127,14 @@ impl CpuChip { &memory_columns.memory_access, is_memory_instruction.clone(), ); + + // On memory load instructions, make sure that the memory value is not changed. + builder + .when(self.is_load_instruction::(&local.selectors)) + .assert_word_eq( + *memory_columns.memory_access.value(), + *memory_columns.memory_access.prev_value(), + ); } /// Evaluates constraints related to loading from memory. @@ -146,7 +183,7 @@ impl CpuChip { local.mem_value_is_neg, ); - // When the memory value is not negaitve, assert that op_a value is equal to the unsigned + // When the memory value is not negative, assert that op_a value is equal to the unsigned // memory value. builder .when(is_load) @@ -195,6 +232,11 @@ impl CpuChip { .when(local.selectors.is_sh) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); + // When the instruction is SW, ensure that the offset is 0. + builder + .when(local.selectors.is_sw) + .assert_one(offset_is_zero.clone()); + // Compute the expected stored value for a SH instruction. 
let a_is_lower_half = offset_is_zero; let a_is_upper_half = memory_columns.offset_is_two; @@ -247,6 +289,12 @@ impl CpuChip { builder .when(local.selectors.is_lh + local.selectors.is_lhu) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); + + // When the instruction is LW, ensure that the offset is zero. + builder + .when(local.selectors.is_lw) + .assert_one(offset_is_zero.clone()); + let use_lower_half = offset_is_zero; let use_upper_half = memory_columns.offset_is_two; let half_value = Word([ diff --git a/core/src/cpu/air/mod.rs b/core/src/cpu/air/mod.rs index 11a985bb5e..b2e7fb8a35 100644 --- a/core/src/cpu/air/mod.rs +++ b/core/src/cpu/air/mod.rs @@ -22,6 +22,7 @@ use crate::bytes::ByteOpcode; use crate::cpu::columns::OpcodeSelectorCols; use crate::cpu::columns::{CpuCols, NUM_CPU_COLS}; use crate::cpu::CpuChip; +use crate::operations::BabyBearWord; use crate::runtime::Opcode; use super::columns::eval_channel_selectors; @@ -160,26 +161,34 @@ impl CpuChip { // Verify that the word form of local.pc is correct for JAL instructions. builder .when(local.selectors.is_jal) - .assert_eq(jump_columns.pc.reduce::(), local.pc); + .assert_eq(jump_columns.pc.value.reduce::(), local.pc); // Verify that the word form of next.pc is correct for both jump instructions. 
builder .when_transition() .when(next.is_real) .when(is_jump_instruction.clone()) - .assert_eq(jump_columns.next_pc.reduce::(), next.pc); + .assert_eq(jump_columns.next_pc.value.reduce::(), next.pc); // When the last row is real and it's a jump instruction, assert that local.next_pc <==> jump_column.next_pc builder .when(local.is_real) .when(is_jump_instruction.clone()) - .assert_eq(jump_columns.next_pc.reduce::(), local.next_pc); + .assert_eq(jump_columns.next_pc.value.reduce::(), local.next_pc); + + // Range check pc and next_pc + BabyBearWord::::range_check(builder, jump_columns.pc, local.selectors.is_jal.into()); + BabyBearWord::::range_check( + builder, + jump_columns.next_pc, + is_jump_instruction.clone(), + ); // Verify that the new pc is calculated correctly for JAL instructions. builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), - jump_columns.next_pc, - jump_columns.pc, + jump_columns.next_pc.value, + jump_columns.pc.value, local.op_b_val(), local.shard, local.channel, @@ -189,7 +198,7 @@ impl CpuChip { // Verify that the new pc is calculated correctly for JALR instructions. builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), - jump_columns.next_pc, + jump_columns.next_pc.value, local.op_b_val(), local.op_c_val(), local.shard, @@ -206,13 +215,20 @@ impl CpuChip { // Verify that the word form of local.pc is correct. builder .when(local.selectors.is_auipc) - .assert_eq(auipc_columns.pc.reduce::(), local.pc); + .assert_eq(auipc_columns.pc.value.reduce::(), local.pc); + + // Range check the pc. + BabyBearWord::::range_check( + builder, + auipc_columns.pc, + local.selectors.is_auipc.into(), + ); // Verify that op_a == pc + op_b. 
builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), local.op_a_val(), - auipc_columns.pc, + auipc_columns.pc.value, local.op_b_val(), local.shard, local.channel, @@ -288,17 +304,15 @@ impl CpuChip { next: &CpuCols, is_branch_instruction: AB::Expr, ) { - // Verify that if is_sequential_instr is true, assert that local.is_real is true. - // This is needed for the following constraint, which is already degree 3. - builder - .when(local.is_sequential_instr) - .assert_one(local.is_real); - - // When is_sequential_instr is true, assert that instruction is not branch, jump, or halt. - // Note that the condition `when(local_is_real)` is implied from the previous constraint. + // Verify the is_sequential_instr flag. let is_halt = self.get_is_halt_syscall::(builder, local); - builder.when(local.is_sequential_instr).assert_zero( - is_branch_instruction + local.selectors.is_jal + local.selectors.is_jalr + is_halt, + builder.when(local.is_real).assert_eq( + local.is_sequential_instr, + AB::Expr::one() + - (is_branch_instruction + + local.selectors.is_jal + + local.selectors.is_jalr + + is_halt), ); // Verify that the pc increments by 4 for all instructions except branch, jump and halt instructions. diff --git a/core/src/cpu/columns/auipc.rs b/core/src/cpu/columns/auipc.rs index a6eb410e7c..e1ba4f9745 100644 --- a/core/src/cpu/columns/auipc.rs +++ b/core/src/cpu/columns/auipc.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::operations::BabyBearWord; pub const NUM_AUIPC_COLS: usize = size_of::>(); @@ -9,5 +9,5 @@ pub const NUM_AUIPC_COLS: usize = size_of::>(); #[repr(C)] pub struct AuipcCols { /// The current program counter. 
- pub pc: Word, + pub pc: BabyBearWord, } diff --git a/core/src/cpu/columns/branch.rs b/core/src/cpu/columns/branch.rs index 06a77ad306..39adb51259 100644 --- a/core/src/cpu/columns/branch.rs +++ b/core/src/cpu/columns/branch.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::operations::BabyBearWord; pub const NUM_BRANCH_COLS: usize = size_of::>(); @@ -10,10 +10,10 @@ pub const NUM_BRANCH_COLS: usize = size_of::>(); #[repr(C)] pub struct BranchCols { /// The current program counter. - pub pc: Word, + pub pc: BabyBearWord, /// The next program counter. - pub next_pc: Word, + pub next_pc: BabyBearWord, /// Whether a equals b. pub a_eq_b: T, diff --git a/core/src/cpu/columns/jump.rs b/core/src/cpu/columns/jump.rs index ca94f3ecac..e339e99a93 100644 --- a/core/src/cpu/columns/jump.rs +++ b/core/src/cpu/columns/jump.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::operations::BabyBearWord; pub const NUM_JUMP_COLS: usize = size_of::>(); @@ -9,8 +9,8 @@ pub const NUM_JUMP_COLS: usize = size_of::>(); #[repr(C)] pub struct JumpCols { /// The current program counter. - pub pc: Word, + pub pc: BabyBearWord, /// THe next program counter. - pub next_pc: Word, + pub next_pc: BabyBearWord, } diff --git a/core/src/cpu/columns/memory.rs b/core/src/cpu/columns/memory.rs index fc54de34c4..f1ac91a582 100644 --- a/core/src/cpu/columns/memory.rs +++ b/core/src/cpu/columns/memory.rs @@ -18,6 +18,8 @@ pub struct MemoryColumns { // Note that this all needs to be verified in the AIR pub addr_word: Word, pub addr_aligned: T, + /// The LE bit decomposition of bits 2 through 7 of the aligned address's least significant byte. 
+ pub aa_least_sig_byte_decomp: [T; 6], pub addr_offset: T, pub memory_access: MemoryReadWriteCols, diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index aa71f130f5..a130513cd7 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -1,3 +1,4 @@ +use std::array; use std::borrow::BorrowMut; use std::collections::HashMap; @@ -177,13 +178,11 @@ impl CpuChip { self.populate_auipc(cols, event, &mut new_alu_events); let is_halt = self.populate_ecall(cols, event); - if !event.instruction.is_branch_instruction() - && !event.instruction.is_jump_instruction() - && !event.instruction.is_ecall_instruction() - && !is_halt - { - cols.is_sequential_instr = F::one(); - } + cols.is_sequential_instr = F::from_bool( + !event.instruction.is_branch_instruction() + && !event.instruction.is_jump_instruction() + && !is_halt, + ); // Assert that the instruction is not a no-op. cols.is_real = F::one(); @@ -261,9 +260,15 @@ impl CpuChip { // Populate addr_word and addr_aligned columns. let memory_columns = cols.opcode_specific_columns.memory_mut(); let memory_addr = event.b.wrapping_add(event.c); + let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32; memory_columns.addr_word = memory_addr.into(); - memory_columns.addr_aligned = - F::from_canonical_u32(memory_addr - memory_addr % WORD_SIZE as u32); + memory_columns.addr_aligned = F::from_canonical_u32(aligned_addr); + + // Populate the aa_least_sig_byte_decomp columns. 
+ assert!(aligned_addr % 4 == 0); + let aligned_addr_ls_byte = (aligned_addr & 0x000000FF) as u8; + let bits: [bool; 8] = array::from_fn(|i| aligned_addr_ls_byte & (1 << i) != 0); + memory_columns.aa_least_sig_byte_decomp = array::from_fn(|i| F::from_bool(bits[i + 2])); // Add event to ALU check to check that addr == b + c let add_event = AluEvent { @@ -442,8 +447,8 @@ impl CpuChip { let next_pc = event.pc.wrapping_add(event.c); cols.branching = F::one(); - branch_columns.pc = event.pc.into(); - branch_columns.next_pc = next_pc.into(); + branch_columns.pc.populate(event.pc); + branch_columns.next_pc.populate(next_pc); let add_event = AluEvent { shard: event.shard, @@ -478,8 +483,8 @@ impl CpuChip { match event.instruction.opcode { Opcode::JAL => { let next_pc = event.pc.wrapping_add(event.b); - jump_columns.pc = event.pc.into(); - jump_columns.next_pc = next_pc.into(); + jump_columns.pc.populate(event.pc); + jump_columns.next_pc.populate(next_pc); let add_event = AluEvent { shard: event.shard, @@ -498,7 +503,7 @@ impl CpuChip { } Opcode::JALR => { let next_pc = event.b.wrapping_add(event.c); - jump_columns.next_pc = next_pc.into(); + jump_columns.next_pc.populate(next_pc); let add_event = AluEvent { shard: event.shard, @@ -530,7 +535,7 @@ impl CpuChip { if matches!(event.instruction.opcode, Opcode::AUIPC) { let auipc_columns = cols.opcode_specific_columns.auipc_mut(); - auipc_columns.pc = event.pc.into(); + auipc_columns.pc.populate(event.pc); let add_event = AluEvent { shard: event.shard, diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index 60a27b0277..7749361c1d 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -131,6 +131,8 @@ where let local = main.row_slice(0); let local: &MemoryInitCols = (*local).borrow(); + builder.assert_bool(local.is_real); + if self.kind == MemoryChipType::Initialize { let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), local.addr.into()]; values.extend(local.value.map(Into::into)); @@ 
-158,9 +160,7 @@ where // and Finalize global memory chip is for register %x0 (i.e. addr = 0x0), and that those rows // have a value of 0. Additionally, in the CPU air, we ensure that whenever op_a is set to // %x0, its value is 0. - // - // TODO: Add a similar check for MemoryChipType::Initialize. - if self.kind == MemoryChipType::Finalize { + if self.kind == MemoryChipType::Initialize || self.kind == MemoryChipType::Finalize { builder.when_first_row().assert_zero(local.addr); builder.when_first_row().assert_word_zero(local.value); } diff --git a/core/src/operations/baby_bear_word.rs b/core/src/operations/baby_bear_word.rs new file mode 100644 index 0000000000..34754a957c --- /dev/null +++ b/core/src/operations/baby_bear_word.rs @@ -0,0 +1,72 @@ +use std::array; + +use p3_air::AirBuilder; +use p3_field::{AbstractField, Field}; +use sp1_derive::AlignedBorrow; + +use crate::{air::Word, stark::SP1AirBuilder}; + +/// A set of columns needed to hold a BabyBear element in word format and range check it. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct BabyBearWord { + /// The babybear element in word format. + pub value: Word, + + /// Most sig byte LE bit decomposition. + pub most_sig_byte_decomp: [T; 8], +} + +impl BabyBearWord { + pub fn populate(&mut self, value: u32) { + self.value = value.into(); + self.most_sig_byte_decomp = array::from_fn(|i| F::from_bool(value & (1 << (i + 24)) != 0)); + } + + pub fn range_check( + builder: &mut AB, + cols: BabyBearWord, + is_real: AB::Expr, + ) { + let mut recomposed_byte = AB::Expr::zero(); + cols.most_sig_byte_decomp + .iter() + .enumerate() + .for_each(|(i, value)| { + builder.when(is_real.clone()).assert_bool(*value); + + recomposed_byte = + recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << i) * *value; + }); + + builder + .when(is_real.clone()) + .assert_eq(recomposed_byte, cols.value[3]); + + // Range check that value is less than baby bear modulus. 
To do this, it is sufficient + // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) + // 01111000_00000000_00000000_00000001. So we need to check the following conditions: + // 1) if most_sig_byte > 01111000, then fail. + // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. + // 3) if most_sig_byte < 01111000, then pass. + + // Flag to see if the four most significant bits of value's most significant byte is set. + let most_sig_bits: AB::Expr = cols.most_sig_byte_decomp[4..8] + .iter() + .map(|bit| (*bit).into()) + .sum(); + // Flag to see if the four least significant bits of value's most significant byte is set. + let least_sig_bits: AB::Expr = cols.most_sig_byte_decomp[0..4] + .iter() + .map(|bit| (*bit).into()) + .sum(); + builder + .when(is_real.clone()) + .when(most_sig_bits.clone()) + .assert_zero(least_sig_bits); + builder + .when(is_real) + .when(most_sig_bits) + .assert_zero(cols.value[0] + cols.value[1] + cols.value[2]); + } +} diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 242c9100b1..006f3508e7 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -8,6 +8,7 @@ mod add; mod add4; mod add5; mod and; +mod baby_bear_word; pub mod field; mod fixed_rotate_right; mod fixed_shift_right; @@ -22,6 +23,7 @@ pub use add::*; pub use add4::*; pub use add5::*; pub use and::*; +pub use baby_bear_word::*; pub use fixed_rotate_right::*; pub use fixed_shift_right::*; pub use is_equal_word::*; diff --git a/recursion/core/src/air/builder.rs b/recursion/core/src/air/builder.rs index 6bcf20d408..5dfa895676 100644 --- a/recursion/core/src/air/builder.rs +++ b/recursion/core/src/air/builder.rs @@ -33,6 +33,8 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); + self.assert_bool(is_real.clone()); + self.eval_memory_access_timestamp(timestamp.clone(), 
mem_access, is_real.clone()); let addr = addr.into(); @@ -69,6 +71,8 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); + self.assert_bool(is_real.clone()); + self.eval_memory_access_timestamp(timestamp.clone(), mem_access, is_real.clone()); let addr = addr.into(); diff --git a/recursion/core/src/cpu/air/branch.rs b/recursion/core/src/cpu/air/branch.rs index 91bebfa65e..105dd771da 100644 --- a/recursion/core/src/cpu/air/branch.rs +++ b/recursion/core/src/cpu/air/branch.rs @@ -3,7 +3,9 @@ use p3_field::{AbstractField, Field}; use sp1_core::air::{BinomialExtension, ExtensionAirBuilder}; use crate::{ - air::{BinomialExtensionUtils, IsExtZeroOperation, SP1RecursionAirBuilder}, + air::{ + BinomialExtensionUtils, Block, BlockBuilder, IsExtZeroOperation, SP1RecursionAirBuilder, + }, cpu::{CpuChip, CpuCols}, memory::MemoryCols, }; @@ -22,18 +24,24 @@ impl CpuChip { let is_branch_instruction = self.is_branch_instruction::(local); let one = AB::Expr::one(); - // If the instruction is a BNEINC, verify that the a value is incremented by one. - builder - .when(local.is_real) - .when(local.selectors.is_bneinc) - .assert_eq(local.a.value()[0], local.a.prev_value()[0] + one.clone()); - // Convert operand values from Block to BinomialExtension. Note that it gets the // previous value of the `a` and `b` operands, since BNENIC will modify `a`. 
+ let a_prev_ext: BinomialExtension = + BinomialExtensionUtils::from_block(local.a.prev_value().map(|x| x.into())); let a_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.a.value().map(|x| x.into())); let b_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.b.value().map(|x| x.into())); + let one_ext: BinomialExtension = + BinomialExtensionUtils::from_block(Block::from(one.clone())); + + let expected_a_ext = a_prev_ext + one_ext; + + // If the instruction is a BNEINC, verify that the a value is incremented by one. + builder + .when(local.is_real) + .when(local.selectors.is_bneinc) + .assert_block_eq(a_ext.as_block(), expected_a_ext.as_block()); let comparison_diff = a_ext - b_ext; diff --git a/recursion/core/src/cpu/air/jump.rs b/recursion/core/src/cpu/air/jump.rs index bf86a70cce..dd5e9b8bba 100644 --- a/recursion/core/src/cpu/air/jump.rs +++ b/recursion/core/src/cpu/air/jump.rs @@ -2,7 +2,7 @@ use p3_air::AirBuilder; use p3_field::{AbstractField, Field}; use crate::{ - air::SP1RecursionAirBuilder, + air::{Block, BlockBuilder, SP1RecursionAirBuilder}, cpu::{CpuChip, CpuCols}, memory::MemoryCols, runtime::STACK_SIZE, @@ -21,19 +21,29 @@ impl CpuChip { ) where AB: SP1RecursionAirBuilder, { + let is_jump_instr = self.is_jump_instruction::(local); + // Verify the next row's fp. builder .when_first_row() .assert_eq(local.fp, F::from_canonical_usize(STACK_SIZE)); - let not_jump_instruction = AB::Expr::one() - self.is_jump_instruction::(local); + let not_jump_instruction = AB::Expr::one() - is_jump_instr.clone(); let expected_next_fp = local.selectors.is_jal * (local.fp + local.c.value()[0]) - + local.selectors.is_jalr * local.a.value()[0] + + local.selectors.is_jalr * local.c.value()[0] + not_jump_instruction * local.fp; builder .when_transition() .when(next.is_real) .assert_eq(next.fp, expected_next_fp); + // Verify the a operand values. 
+ let expected_a_val = local.selectors.is_jal * local.pc + + local.selectors.is_jalr * (local.pc + AB::Expr::one()); + let expected_a_val_block = Block::from(expected_a_val); + builder + .when(is_jump_instr) + .assert_block_eq(*local.a.value(), expected_a_val_block); + // Add to the `next_pc` expression. *next_pc += local.selectors.is_jal * (local.pc + local.b.value()[0]); *next_pc += local.selectors.is_jalr * local.b.value()[0]; diff --git a/recursion/core/src/cpu/air/memory.rs b/recursion/core/src/cpu/air/memory.rs index c0a3a2b639..d1b024130f 100644 --- a/recursion/core/src/cpu/air/memory.rs +++ b/recursion/core/src/cpu/air/memory.rs @@ -30,7 +30,7 @@ impl CpuChip { local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::Memory as u32), memory_cols.memory_addr, &memory_cols.memory, - is_memory_instr, + is_memory_instr.clone(), ); // Constraints on the memory column depending on load or store. @@ -41,7 +41,7 @@ impl CpuChip { ); // When there is a store, we ensure that we are writing the value of the a operand to the memory. 
builder - .when(local.selectors.is_store) + .when(is_memory_instr) .assert_block_eq(*local.a.value(), *memory_cols.memory.value()); } } diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 23173f7de9..10a9c3db0d 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -190,8 +190,7 @@ where local.poseidon2_receive_table, ); sub_builder.assert_eq( - local.is_poseidon2 - * Poseidon2Chip::do_memory_access::(poseidon2_columns), + local.is_poseidon2 * Poseidon2Chip::do_memory_access::(poseidon2_columns), local.poseidon2_memory_access, ); @@ -201,7 +200,7 @@ where local.poseidon2(), next.poseidon2(), local.poseidon2_receive_table, - local.poseidon2_memory_access.into(), + local.poseidon2_memory_access, ); } } diff --git a/recursion/core/src/poseidon2/columns.rs b/recursion/core/src/poseidon2/columns.rs index 12fa730477..9194cb1add 100644 --- a/recursion/core/src/poseidon2/columns.rs +++ b/recursion/core/src/poseidon2/columns.rs @@ -11,7 +11,10 @@ pub struct Poseidon2Cols { pub left_input: T, pub right_input: T, pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output + pub do_receive: T, + pub do_memory: T, pub round_specific_cols: RoundSpecificCols, + pub is_real: T, } #[derive(AlignedBorrow, Clone, Copy)] diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index c871bd873d..96a8d77449 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -6,7 +6,6 @@ use p3_field::AbstractField; use p3_matrix::Matrix; use sp1_core::air::{BaseAirBuilder, ExtensionAirBuilder, SP1AirBuilder}; use sp1_primitives::RC_16_30_U32; -use std::ops::Add; use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; use crate::memory::MemoryCols; @@ -40,7 +39,7 @@ impl Poseidon2Chip { local: &Poseidon2Cols, next: &Poseidon2Cols, 
receive_table: AB::Var, - memory_access: AB::Expr, + memory_access: AB::Var, ) { const NUM_ROUNDS_F: usize = 8; const NUM_ROUNDS_P: usize = 13; @@ -66,6 +65,10 @@ impl Poseidon2Chip { .sum::(); let is_memory_write = local.rounds[local.rounds.len() - 1]; + self.eval_control_flow_and_inputs(builder, local, next); + + self.eval_syscall(builder, local, receive_table); + self.eval_mem( builder, local, @@ -84,16 +87,71 @@ impl Poseidon2Chip { is_internal_layer.clone(), NUM_ROUNDS_F + NUM_ROUNDS_P + 1, ); + } - self.eval_syscall(builder, local, receive_table); - - // Range check all flags. - for i in 0..local.rounds.len() { + fn eval_control_flow_and_inputs( + &self, + builder: &mut AB, + local: &Poseidon2Cols, + next: &Poseidon2Cols, + ) { + let num_total_rounds = local.rounds.len(); + for i in 0..num_total_rounds { + // Verify that the round flags are boolean. builder.assert_bool(local.rounds[i]); + + if i != num_total_rounds - 1 { + // Verify that the round flags cycle. + builder + .when_transition() + .assert_eq(local.rounds[i], next.rounds[i + 1]); + + // Verify that the clk, dst_input, left_input, and right_input values are the same + // within a permutation. + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.dst_input, next.dst_input); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.left_input, next.left_input); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.right_input, next.right_input); + } } - builder.assert_bool( - is_memory_read + is_initial + is_external_layer + is_internal_layer + is_memory_write, - ); + // Ensure that at most one of the round flags is set. + let round_acc = local + .rounds + .iter() + .fold(AB::Expr::zero(), |acc, round_flag| acc + *round_flag); + builder.assert_bool(round_acc); + + // Verify the do_memory flag. 
+ builder + .when(local.is_real) + .assert_eq(local.do_memory, local.rounds[0] + local.rounds[23]); + // Verify the do_receive flag. + builder + .when(local.is_real) + .assert_eq(local.do_receive, local.rounds[0]); + // Verify the first row starts at round 0. + builder.when_first_row().assert_one(local.rounds[0]); + // The round count is not a power of 2, so the last row should not be real. + builder.when_last_row().assert_zero(local.is_real); + + // Verify that all is_real flags within a round are equal. + let is_last_round = local.rounds[23]; + builder + .when_transition() + .when_not(is_last_round) + .assert_eq(local.is_real, next.is_real); } fn eval_mem( @@ -103,20 +161,23 @@ impl Poseidon2Chip { next: &Poseidon2Cols, is_memory_read: AB::Var, is_memory_write: AB::Var, - memory_access: AB::Expr, + memory_access: AB::Var, ) { let memory_access_cols = local.round_specific_cols.memory_access(); builder + .when(local.is_real) .when(is_memory_read) .assert_eq(local.left_input, memory_access_cols.addr_first_half); builder + .when(local.is_real) .when(is_memory_read) .assert_eq(local.right_input, memory_access_cols.addr_second_half); builder + .when(local.is_real) .when(is_memory_write) .assert_eq(local.dst_input, memory_access_cols.addr_first_half); - builder.when(is_memory_write).assert_eq( + builder.when(local.is_real).when(is_memory_write).assert_eq( local.dst_input + AB::F::from_canonical_usize(WIDTH / 2), memory_access_cols.addr_second_half, ); @@ -131,7 +192,11 @@ impl Poseidon2Chip { local.clk + AB::Expr::one() * is_memory_write, addr, &memory_access_cols.mem_access[i], - memory_access.clone(), + memory_access, + ); + builder.when(local.is_real).when(is_memory_read).assert_eq( + *memory_access_cols.mem_access[i].value(), + *memory_access_cols.mem_access[i].prev_value(), ); } @@ -139,10 +204,14 @@ impl Poseidon2Chip { // computation round. 
let next_computation_col = next.round_specific_cols.computation(); for i in 0..WIDTH { - builder.when_transition().when(is_memory_read).assert_eq( - *memory_access_cols.mem_access[i].value(), - next_computation_col.input[i], - ); + builder + .when_transition() + .when(local.is_real) + .when(is_memory_read) + .assert_eq( + *memory_access_cols.mem_access[i].value(), + next_computation_col.input[i], + ); } } @@ -184,6 +253,7 @@ impl Poseidon2Chip { } } builder + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(result, computation_cols.add_rc[i]); } @@ -198,6 +268,7 @@ impl Poseidon2Chip { * computation_cols.add_rc[i]; let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3.clone() * computation_cols.add_rc[i]; builder + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(sbox_deg_7, computation_cols.sbox_deg_7[i]); } @@ -253,6 +324,7 @@ impl Poseidon2Chip { for i in 0..WIDTH { state[i] += sums[i % 4].clone(); builder + .when(local.is_real) .when(is_external_layer.clone() + is_initial.clone()) .assert_eq(state[i].clone(), computation_cols.output[i]); } @@ -264,6 +336,7 @@ impl Poseidon2Chip { let mut state: [AB::Expr; WIDTH] = sbox_result.clone(); internal_linear_layer(&mut state); builder + .when(local.is_real) .when(is_internal_layer.clone()) .assert_all_eq(state.clone(), computation_cols.output); } @@ -281,6 +354,7 @@ impl Poseidon2Chip { builder .when_transition() + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(computation_cols.output[i], next_round_value); } @@ -307,13 +381,11 @@ impl Poseidon2Chip { } pub const fn do_receive_table(local: &Poseidon2Cols) -> T { - local.rounds[0] + local.do_receive } - pub fn do_memory_access, Output>( - local: &Poseidon2Cols, - ) -> Output { - local.rounds[0] + local.rounds[23] + pub fn do_memory_access(local: &Poseidon2Cols) -> T { + 
local.do_memory } } @@ -333,7 +405,7 @@ where local, next, Self::do_receive_table::(local), - Self::do_memory_access::(local), + Self::do_memory_access::(local), ); } } diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index cc6a41d94f..f86ad71163 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -49,7 +49,9 @@ impl MachineAir for Poseidon2Chip { for r in 0..rounds { let mut row = [F::zero(); NUM_POSEIDON2_COLS]; let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); + cols.is_real = F::one(); + let is_receive = r == 0; let is_memory_read = r == 0; let is_initial_layer = r == 1; let is_external_layer = @@ -78,6 +80,10 @@ impl MachineAir for Poseidon2Chip { cols.right_input = poseidon2_event.right; cols.rounds[r] = F::one(); + if is_receive { + cols.do_receive = F::one(); + } + if is_memory_read || is_memory_write { let memory_access_cols = cols.round_specific_cols.memory_access_mut(); @@ -97,6 +103,7 @@ impl MachineAir for Poseidon2Chip { .populate(&poseidon2_event.result_records[i]); } } + cols.do_memory = F::one(); } else { let computation_cols = cols.round_specific_cols.computation_mut(); @@ -163,6 +170,8 @@ impl MachineAir for Poseidon2Chip { } } + let num_real_rows = rows.len(); + // Pad the trace to a power of two. pad_rows_fixed( &mut rows, @@ -170,6 +179,14 @@ impl MachineAir for Poseidon2Chip { self.fixed_log2_rows, ); + let mut round_num = 0; + for row in rows[num_real_rows..].iter_mut() { + let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); + cols.rounds[round_num] = F::one(); + + round_num = (round_num + 1) % rounds; + } + // Convert the trace to a row major matrix. RowMajorMatrix::new( rows.into_iter().flatten().collect::>(),