From 623608054cf507e036c939fff65e2ededd4c0e26 Mon Sep 17 00:00:00 2001 From: Leo Alt Date: Mon, 18 Nov 2024 16:10:18 +0100 Subject: [PATCH 1/8] keccak32 --- std/machines/hash/keccakf32_memory.asm | 787 +++++++++++++++++++++++++ 1 file changed, 787 insertions(+) create mode 100644 std/machines/hash/keccakf32_memory.asm diff --git a/std/machines/hash/keccakf32_memory.asm b/std/machines/hash/keccakf32_memory.asm new file mode 100644 index 0000000000..d8b16f51bb --- /dev/null +++ b/std/machines/hash/keccakf32_memory.asm @@ -0,0 +1,787 @@ +use std::array; +use std::utils; +use std::utils::unchanged_until; +use std::utils::force_bool; +use std::convert::expr; +use std::convert::int; +use std::convert::fe; +use std::prelude::set_hint; +use std::prelude::Query; +use std::prover::eval; +use std::prover::provide_value; +use std::machines::large_field::memory::Memory; + +machine Keccakf32Memory(mem: Memory) with + latch: final_step, + operation_id: operation_id, + call_selectors: sel, +{ + /* + ------------- Begin memory read / write --------------- + Additional columns compared to the non-memory version: + - 1 column for user input address (of first byte of input). + - 1 column for user output address (of first byte of output). + - 49 columns for computed input/output address of all bytes. + Overall, given that there are 2,600+ columns in the non-memory version, this isn't a huge cost + Methodology description: + 1. The latch with the input and output addresses + time step is in the last row of each block. + 2. User input address is copied to the first row. + 3. Input addresses for all bytes are calculated from user input address in the first row. + 4. Load all input bytes from memory to the preimage columns. + 5. Keccak is computed from top to bottom. + 6. Output addresses for all bytes are calculated from user output address in the last row. + 7. Store all output bytes from keccak computation columns to memory. + Note that this methodology reuses the same 49 columns of addr to calculate input and output addresses of all bytes. + However, these 49 columns are only used in the first and last rows of each block. + Essentially, we conduct all memory reads in the first row and all memory writes in the last row. + This might seem wasteful, but it's still cheaper than spreading memory reads/writes over different rows while using as few columns as possible, + which requires 100 columns to make outputs available in all rows in additional to the memory columns. + This alternative methodology (memory reads/writes over different rows) also doesn't work well with our auto witgen infrastructure, + because one would need to do memory read -> keccak computation -> memory write as three sequential passes during witgen. + On the contrary, our current methodology performs all memory reads at once in the first row, then immediately does the keccak computation, + and finally performs all memory writes at once in the last row, and thus only requires one pass with auto witgen. + Though note that input address need to be first copied from the last row to the first row. + */ + + operation keccakf32_memory<0> input_addr, output_addr, time_step ->; + + // Get an intermediate column that indicates that we're in an + // actual block, not a default block. Its value is constant + // within the block. + let used = array::sum(sel); + array::map(sel, |s| unchanged_until(s, final_step + is_last)); + std::utils::force_bool(used); + let first_step_used: expr = used * first_step; + let final_step_used: expr = used * final_step; + + // Repeat the time step and input address in the whole block. + col witness time_step; + unchanged_until(time_step, final_step + is_last); + + // Input address for the first byte of input array from the user. + // Copied from user input in the last row to the first row. + col witness input_addr; + unchanged_until(input_addr, final_step + is_last); + + // Output address for the first byte of output array from the user. + // Used in the last row directly from user input. + col witness output_addr; + + // Load memory while converting to little endian format for keccak computation. + // Specifically, this keccakf32 machine accepts big endian inputs in memory. + // However the keccak computation constraints are written for little endian iputs. + // Therefore memory load converts big endian inputs to little endian for the preimage. + link if first_step_used ~> preimage[1] = mem.mload(input_addr, time_step); + link if first_step_used ~> preimage[0] = mem.mload(input_addr + 4, time_step); + link if first_step_used ~> preimage[3] = mem.mload(input_addr + 12, time_step); + link if first_step_used ~> preimage[2] = mem.mload(input_addr + 8, time_step); + link if first_step_used ~> preimage[5] = mem.mload(input_addr + 20, time_step); + link if first_step_used ~> preimage[4] = mem.mload(input_addr + 16, time_step); + link if first_step_used ~> preimage[7] = mem.mload(input_addr + 28, time_step); + link if first_step_used ~> preimage[6] = mem.mload(input_addr + 24, time_step); + link if first_step_used ~> preimage[9] = mem.mload(input_addr + 36, time_step); + link if first_step_used ~> preimage[8] = mem.mload(input_addr + 32, time_step); + link if first_step_used ~> preimage[11] = mem.mload(input_addr + 44, time_step); + link if first_step_used ~> preimage[10] = mem.mload(input_addr + 40, time_step); + link if first_step_used ~> preimage[13] = mem.mload(input_addr + 52, time_step); + link if first_step_used ~> preimage[12] = mem.mload(input_addr + 48, time_step); + link if first_step_used ~> preimage[15] = mem.mload(input_addr + 60, time_step); + link if first_step_used ~> preimage[14] = mem.mload(input_addr + 56, time_step); + link if first_step_used ~> preimage[17] = mem.mload(input_addr + 68, time_step); + link if first_step_used ~> preimage[16] = mem.mload(input_addr + 64, time_step); + link if first_step_used ~> preimage[19] = mem.mload(input_addr + 76, time_step); + link if first_step_used ~> preimage[18] = mem.mload(input_addr + 72, time_step); + link if first_step_used ~> preimage[21] = mem.mload(input_addr + 84, time_step); + link if first_step_used ~> preimage[20] = mem.mload(input_addr + 80, time_step); + link if first_step_used ~> preimage[23] = mem.mload(input_addr + 92, time_step); + link if first_step_used ~> preimage[22] = mem.mload(input_addr + 88, time_step); + link if first_step_used ~> preimage[25] = mem.mload(input_addr + 100, time_step); + link if first_step_used ~> preimage[24] = mem.mload(input_addr + 96, time_step); + link if first_step_used ~> preimage[27] = mem.mload(input_addr + 108, time_step); + link if first_step_used ~> preimage[26] = mem.mload(input_addr + 104, time_step); + link if first_step_used ~> preimage[29] = mem.mload(input_addr + 116, time_step); + link if first_step_used ~> preimage[28] = mem.mload(input_addr + 112, time_step); + link if first_step_used ~> preimage[31] = mem.mload(input_addr + 124, time_step); + link if first_step_used ~> preimage[30] = mem.mload(input_addr + 120, time_step); + link if first_step_used ~> preimage[33] = mem.mload(input_addr + 132, time_step); + link if first_step_used ~> preimage[32] = mem.mload(input_addr + 128, time_step); + link if first_step_used ~> preimage[35] = mem.mload(input_addr + 140, time_step); + link if first_step_used ~> preimage[34] = mem.mload(input_addr + 136, time_step); + link if first_step_used ~> preimage[37] = mem.mload(input_addr + 148, time_step); + link if first_step_used ~> preimage[36] = mem.mload(input_addr + 144, time_step); + link if first_step_used ~> preimage[39] = mem.mload(input_addr + 156, time_step); + link if first_step_used ~> preimage[38] = mem.mload(input_addr + 152, time_step); + link if first_step_used ~> preimage[41] = mem.mload(input_addr + 164, time_step); + link if first_step_used ~> preimage[40] = mem.mload(input_addr + 160, time_step); + link if first_step_used ~> preimage[43] = mem.mload(input_addr + 172, time_step); + link if first_step_used ~> preimage[42] = mem.mload(input_addr + 168, time_step); + link if first_step_used ~> preimage[45] = mem.mload(input_addr + 180, time_step); + link if first_step_used ~> preimage[44] = mem.mload(input_addr + 176, time_step); + link if first_step_used ~> preimage[47] = mem.mload(input_addr + 188, time_step); + link if first_step_used ~> preimage[46] = mem.mload(input_addr + 184, time_step); + link if first_step_used ~> preimage[49] = mem.mload(input_addr + 196, time_step); + link if first_step_used ~> preimage[48] = mem.mload(input_addr + 192, time_step); + + // Expects input of 25 64-bit numbers decomposed to 25 chunks of 2 32-bit little endian limbs. + // The output is a_prime_prime_prime_0_0_limbs for the first 2 and a_prime_prime for the rest. + + // Write memory while converting output to big endian format. + // Specifically, output obtained from the keccak computation are little endian. + // However, this keccakf16_memory machine produces big endian outputs in memory. + // Therefore, memory write converts little endian from keccak computation to big endian for the output in memory. + link if final_step_used ~> mem.mstore(output_addr, time_step, a_prime_prime_prime_0_0_limbs[1]); + link if final_step_used ~> mem.mstore(output_addr + 4, time_step, a_prime_prime_prime_0_0_limbs[0]); + link if final_step_used ~> mem.mstore(output_addr + 8, time_step, a_prime_prime_prime_0_0_limbs[3]); + link if final_step_used ~> mem.mstore(output_addr + 12, time_step, a_prime_prime_prime_0_0_limbs[2]); + link if final_step_used ~> mem.mstore(output_addr + 16, time_step, a_prime_prime_prime_0_0_limbs[5]); + link if final_step_used ~> mem.mstore(output_addr + 20, time_step, a_prime_prime_prime_0_0_limbs[4]); + link if final_step_used ~> mem.mstore(output_addr + 24, time_step, a_prime_prime_prime_0_0_limbs[7]); + link if final_step_used ~> mem.mstore(output_addr + 28, time_step, a_prime_prime_prime_0_0_limbs[6]); + link if final_step_used ~> mem.mstore(output_addr + 32, time_step, a_prime_prime_prime_0_0_limbs[9]); + link if final_step_used ~> mem.mstore(output_addr + 36, time_step, a_prime_prime_prime_0_0_limbs[8]); + link if final_step_used ~> mem.mstore(output_addr + 40, time_step, a_prime_prime_prime_0_0_limbs[11]); + link if final_step_used ~> mem.mstore(output_addr + 44, time_step, a_prime_prime_prime_0_0_limbs[10]); + link if final_step_used ~> mem.mstore(output_addr + 48, time_step, a_prime_prime_prime_0_0_limbs[13]); + link if final_step_used ~> mem.mstore(output_addr + 52, time_step, a_prime_prime_prime_0_0_limbs[12]); + link if final_step_used ~> mem.mstore(output_addr + 56, time_step, a_prime_prime_prime_0_0_limbs[15]); + link if final_step_used ~> mem.mstore(output_addr + 60, time_step, a_prime_prime_prime_0_0_limbs[14]); + link if final_step_used ~> mem.mstore(output_addr + 64, time_step, a_prime_prime_prime_0_0_limbs[17]); + link if final_step_used ~> mem.mstore(output_addr + 68, time_step, a_prime_prime_prime_0_0_limbs[16]); + link if final_step_used ~> mem.mstore(output_addr + 72, time_step, a_prime_prime_prime_0_0_limbs[19]); + link if final_step_used ~> mem.mstore(output_addr + 76, time_step, a_prime_prime_prime_0_0_limbs[18]); + link if final_step_used ~> mem.mstore(output_addr + 80, time_step, a_prime_prime_prime_0_0_limbs[21]); + link if final_step_used ~> mem.mstore(output_addr + 84, time_step, a_prime_prime_prime_0_0_limbs[20]); + link if final_step_used ~> mem.mstore(output_addr + 88, time_step, a_prime_prime_prime_0_0_limbs[23]); + link if final_step_used ~> mem.mstore(output_addr + 92, time_step, a_prime_prime_prime_0_0_limbs[22]); + link if final_step_used ~> mem.mstore(output_addr + 96, time_step, a_prime_prime_prime_0_0_limbs[25]); + link if final_step_used ~> mem.mstore(output_addr + 100, time_step, a_prime_prime_prime_0_0_limbs[24]); + link if final_step_used ~> mem.mstore(output_addr + 104, time_step, a_prime_prime_prime_0_0_limbs[27]); + link if final_step_used ~> mem.mstore(output_addr + 108, time_step, a_prime_prime_prime_0_0_limbs[26]); + link if final_step_used ~> mem.mstore(output_addr + 112, time_step, a_prime_prime_prime_0_0_limbs[29]); + link if final_step_used ~> mem.mstore(output_addr + 116, time_step, a_prime_prime_prime_0_0_limbs[28]); + link if final_step_used ~> mem.mstore(output_addr + 120, time_step, a_prime_prime_prime_0_0_limbs[31]); + link if final_step_used ~> mem.mstore(output_addr + 124, time_step, a_prime_prime_prime_0_0_limbs[30]); + link if final_step_used ~> mem.mstore(output_addr + 128, time_step, a_prime_prime_prime_0_0_limbs[33]); + link if final_step_used ~> mem.mstore(output_addr + 132, time_step, a_prime_prime_prime_0_0_limbs[32]); + link if final_step_used ~> mem.mstore(output_addr + 136, time_step, a_prime_prime_prime_0_0_limbs[35]); + link if final_step_used ~> mem.mstore(output_addr + 140, time_step, a_prime_prime_prime_0_0_limbs[34]); + link if final_step_used ~> mem.mstore(output_addr + 144, time_step, a_prime_prime_prime_0_0_limbs[37]); + link if final_step_used ~> mem.mstore(output_addr + 148, time_step, a_prime_prime_prime_0_0_limbs[36]); + link if final_step_used ~> mem.mstore(output_addr + 152, time_step, a_prime_prime_prime_0_0_limbs[39]); + link if final_step_used ~> mem.mstore(output_addr + 156, time_step, a_prime_prime_prime_0_0_limbs[38]); + link if final_step_used ~> mem.mstore(output_addr + 160, time_step, a_prime_prime_prime_0_0_limbs[41]); + link if final_step_used ~> mem.mstore(output_addr + 164, time_step, a_prime_prime_prime_0_0_limbs[40]); + link if final_step_used ~> mem.mstore(output_addr + 168, time_step, a_prime_prime_prime_0_0_limbs[43]); + link if final_step_used ~> mem.mstore(output_addr + 172, time_step, a_prime_prime_prime_0_0_limbs[42]); + link if final_step_used ~> mem.mstore(output_addr + 176, time_step, a_prime_prime_prime_0_0_limbs[45]); + link if final_step_used ~> mem.mstore(output_addr + 180, time_step, a_prime_prime_prime_0_0_limbs[44]); + link if final_step_used ~> mem.mstore(output_addr + 184, time_step, a_prime_prime_prime_0_0_limbs[47]); + link if final_step_used ~> mem.mstore(output_addr + 188, time_step, a_prime_prime_prime_0_0_limbs[46]); + link if final_step_used ~> mem.mstore(output_addr + 192, time_step, a_prime_prime_prime_0_0_limbs[49]); + link if final_step_used ~> mem.mstore(output_addr + 196, time_step, a_prime_prime_prime_0_0_limbs[48]); + // ------------- End memory read / write --------------- + + // Adapted from Plonky3 implementation of Keccak: https://github.com/Plonky3/Plonky3/tree/main/keccak-air/src + + std::check::require_field_bits(32, || "The field modulus should be at least 2^32 - 1 to work in the keccakf32 machine."); + + col witness operation_id; + + let NUM_ROUNDS: int = 24; + + // pub struct KeccakCols { + // /// The `i`th value is set to 1 if we are in the `i`th round, otherwise 0. + // pub step_flags: [T; NUM_ROUNDS], + + // /// A register which indicates if a row should be exported, i.e. included in a multiset equality + // /// argument. Should be 1 only for certain rows which are final steps, i.e. with + // /// `step_flags[23] = 1`. + // pub export: T, + + // /// Permutation inputs, stored in y-major order. + // pub preimage: [[[T; U64_LIMBS]; 5]; 5], + + // pub a: [[[T; U64_LIMBS]; 5]; 5], + + // /// ```ignore + // /// C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]) + // /// ``` + // pub c: [[T; 64]; 5], + + // /// ```ignore + // /// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]) + // /// ``` + // pub c_prime: [[T; 64]; 5], + + // // Note: D is inlined, not stored in the witness. + // /// ```ignore + // /// A'[x, y] = xor(A[x, y], D[x]) + // /// = xor(A[x, y], C[x - 1], ROT(C[x + 1], 1)) + // /// ``` + // pub a_prime: [[[T; 64]; 5]; 5], + + // /// ```ignore + // /// A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + // /// ``` + // pub a_prime_prime: [[[T; U64_LIMBS]; 5]; 5], + + // /// The bits of `A''[0, 0]`. + // pub a_prime_prime_0_0_bits: [T; 64], + + // /// ```ignore + // /// A'''[0, 0, z] = A''[0, 0, z] ^ RC[k, z] + // /// ``` + // pub a_prime_prime_prime_0_0_limbs: [T; U64_LIMBS], + // } + + // @Steve are these constants correct for 32? + // It would be good to replace all these constants by named constants + col witness preimage[5 * 5 * 2]; + col witness a[5 * 5 * 2]; + col witness c[5 * 64]; + array::map(c, |i| force_bool(i)); + col witness c_prime[5 * 64]; + col witness a_prime[5 * 5 * 64]; + array::map(a_prime, |i| force_bool(i)); + col witness a_prime_prime[5 * 5 * 2]; + col witness a_prime_prime_0_0_bits[64]; + array::map(a_prime_prime_0_0_bits, |i| force_bool(i)); + col witness a_prime_prime_prime_0_0_limbs[2]; + + // Initially, the first step flag should be 1 while the others should be 0. + // builder.when_first_row().assert_one(local.step_flags[0]); + // for i in 1..NUM_ROUNDS { + // builder.when_first_row().assert_zero(local.step_flags[i]); + // } + // for i in 0..NUM_ROUNDS { + // let current_round_flag = local.step_flags[i]; + // let next_round_flag = next.step_flags[(i + 1) % NUM_ROUNDS]; + // builder + // .when_transition() + // .assert_eq(next_round_flag, current_round_flag); + // } + + let step_flags: col[NUM_ROUNDS] = array::new(NUM_ROUNDS, |i| |row| if row % NUM_ROUNDS == i { 1 } else { 0 } ); + + // let main = builder.main(); + // let (local, next) = (main.row_slice(0), main.row_slice(1)); + // let local: &KeccakCols = (*local).borrow(); + // let next: &KeccakCols = (*next).borrow(); + + // let first_step = local.step_flags[0]; + // let final_step = local.step_flags[NUM_ROUNDS - 1]; + // let not_final_step = AB::Expr::one() - final_step; + + let first_step: expr = step_flags[0]; // Aliasing instead of defining a new fixed column. + let final_step: expr = step_flags[NUM_ROUNDS - 1]; + col fixed is_last = [0]* + [1]; + + // // If this is the first step, the input A must match the preimage. + // for y in 0..5 { + // for x in 0..5 { + // for limb in 0..U64_LIMBS { + // builder + // .when(first_step) + // .assert_eq(local.preimage[y][x][limb], local.a[y][x][limb]); + // } + // } + // } + + array::zip(preimage, a, |p_i, a_i| first_step * (p_i - a_i) = 0); + + // // The export flag must be 0 or 1. + // builder.assert_bool(local.export); + + // force_bool(export); + + // // If this is not the final step, the export flag must be off. + // builder + // .when(not_final_step.clone()) + // .assert_zero(local.export); + + // not_final_step * export = 0; + + // // If this is not the final step, the local and next preimages must match. + // for y in 0..5 { + // for x in 0..5 { + // for limb in 0..U64_LIMBS { + // builder + // .when(not_final_step.clone()) + // .when_transition() + // .assert_eq(local.preimage[y][x][limb], next.preimage[y][x][limb]); + // } + // } + // } + + array::map(preimage, |p| unchanged_until(p, final_step + is_last)); + + // for x in 0..5 { + // for z in 0..64 { + // builder.assert_bool(local.c[x][z]); + // let xor = xor3_gen::( + // local.c[x][z].into(), + // local.c[(x + 4) % 5][z].into(), + // local.c[(x + 1) % 5][(z + 63) % 64].into(), + // ); + // let c_prime = local.c_prime[x][z]; + // builder.assert_eq(c_prime, xor); + // } + // } + + let andn: expr, expr -> expr = |a, b| (1 - a) * b; + let xor: expr, expr -> expr = |a, b| a + b - 2*a*b; + let xor3: expr, expr, expr -> expr = |a, b, c| xor(xor(a, b), c); + // a b c xor3 + // 0 0 0 0 + // 0 0 1 1 + // 0 1 0 1 + // 0 1 1 0 + // 1 0 0 1 + // 1 0 1 0 + // 1 1 0 0 + // 1 1 1 1 + + // @Steve are these constants correct for 32? + array::new(320, |i| { + let x = i / 64; + let z = i % 64; + c_prime[i] = xor3( + c[i], + c[((x + 4) % 5) * 64 + z], + c[((x + 1) % 5) * 64 + ((z + 63) % 64)] + ) + }); + + // // Check that the input limbs are consistent with A' and D. + // // A[x, y, z] = xor(A'[x, y, z], D[x, y, z]) + // // = xor(A'[x, y, z], C[x - 1, z], C[x + 1, z - 1]) + // // = xor(A'[x, y, z], C[x, z], C'[x, z]). + // // The last step is valid based on the identity we checked above. + // // It isn't required, but makes this check a bit cleaner. + // for y in 0..5 { + // for x in 0..5 { + // let get_bit = |z| { + // let a_prime: AB::Var = local.a_prime[y][x][z]; + // let c: AB::Var = local.c[x][z]; + // let c_prime: AB::Var = local.c_prime[x][z]; + // xor3_gen::(a_prime.into(), c.into(), c_prime.into()) + // }; + + // for limb in 0..U64_LIMBS { + // let a_limb = local.a[y][x][limb]; + // let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) // bigger address correspond to more significant bit + // .rev() + // .fold(AB::Expr::zero(), |acc, z| { + // builder.assert_bool(local.a_prime[y][x][z]); + // acc.double() + get_bit(z) + // }); + // builder.assert_eq(computed_limb, a_limb); + // } + // } + // } + + let bits_to_value_be: expr[] -> expr = |bits_be| array::fold(bits_be, 0, |acc, e| (acc * 2 + e)); + + // @Steve are these constants correct for 32? + array::new(100, |i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + let get_bit: int -> expr = |z| xor3(a_prime[y * 320 + x * 64 + z], c[x * 64 + z], c_prime[x * 64 + z]); + + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_bit(limb * 16 + z))); + a[i] = bits_to_value_be(limb_bits_be) + }); + + // // xor_{i=0}^4 A'[x, i, z] = C'[x, z], so for each x, z, + // // diff * (diff - 2) * (diff - 4) = 0, where + // // diff = sum_{i=0}^4 A'[x, i, z] - C'[x, z] + // for x in 0..5 { + // for z in 0..64 { + // let sum: AB::Expr = (0..5).map(|y| local.a_prime[y][x][z].into()).sum(); + // let diff = sum - local.c_prime[x][z]; + // let four = AB::Expr::from_canonical_u8(4); + // builder + // .assert_zero(diff.clone() * (diff.clone() - AB::Expr::two()) * (diff - four)); + // } + // } + + // @Steve are these constants correct for 32? + array::new(320, |i| { + let x = i / 64; + let z = i % 64; + let sum = utils::sum(5, |y| a_prime[y * 320 + i]); + let diff = sum - c_prime[i]; + diff * (diff - 2) * (diff - 4) = 0 + }); + + // // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + // for y in 0..5 { + // for x in 0..5 { + // let get_bit = |z| { + // let andn = andn_gen::( + // local.b((x + 1) % 5, y, z).into(), + // local.b((x + 2) % 5, y, z).into(), + // ); + // xor_gen::(local.b(x, y, z).into(), andn) + // }; + + // for limb in 0..U64_LIMBS { + // let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(AB::Expr::zero(), |acc, z| acc.double() + get_bit(z)); + // builder.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]); + // } + // } + // } + + // @Steve are these constants correct for 32? + array::new(100, |i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + + let get_bit: int -> expr = |z| { + xor(b(x, y, z), andn(b((x + 1) % 5, y, z), b((x + 2) % 5, y, z))) + }; + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_bit(limb * 16 + z))); + a_prime_prime[i] = bits_to_value_be(limb_bits_be) + }); + + // pub fn b(&self, x: usize, y: usize, z: usize) -> T { + // debug_assert!(x < 5); + // debug_assert!(y < 5); + // debug_assert!(z < 64); + + // // B is just a rotation of A', so these are aliases for A' registers. + // // From the spec, + // // B[y, (2x + 3y) % 5] = ROT(A'[x, y], r[x, y]) + // // So, + // // B[x, y] = f((x + 3y) % 5, x) + // // where f(a, b) = ROT(A'[a, b], r[a, b]) + // let a = (x + 3 * y) % 5; + // let b = x; + // let rot = R[a][b] as usize; + // self.a_prime[b][a][(z + 64 - rot) % 64] + // } + + // @Steve are these constants correct for 32? + let b: int, int, int -> expr = |x, y, z| { + let a: int = (x + 3 * y) % 5; + let rot: int = R[a * 5 + x]; // b = x + a_prime[x * 320 + a * 64 + (z + 64 - rot) % 64] + }; + + // // A'''[0, 0] = A''[0, 0] XOR RC + // for limb in 0..U64_LIMBS { + // let computed_a_prime_prime_0_0_limb = (limb * BITS_PER_LIMB + // ..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(AB::Expr::zero(), |acc, z| { + // builder.assert_bool(local.a_prime_prime_0_0_bits[z]); + // acc.double() + local.a_prime_prime_0_0_bits[z] + // }); + // let a_prime_prime_0_0_limb = local.a_prime_prime[0][0][limb]; + // builder.assert_eq(computed_a_prime_prime_0_0_limb, a_prime_prime_0_0_limb); + // } + + // @Steve are these constants correct for 32? + array::new(4, |limb| { + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| a_prime_prime_0_0_bits[limb * 16 + z])); + a_prime_prime[limb] = bits_to_value_be(limb_bits_be) + }); + + // let get_xored_bit = |i| { + // let mut rc_bit_i = AB::Expr::zero(); + // for r in 0..NUM_ROUNDS { + // let this_round = local.step_flags[r]; + // let this_round_constant = AB::Expr::from_canonical_u8(rc_value_bit(r, i)); + // rc_bit_i += this_round * this_round_constant; + // } + + // xor_gen::(local.a_prime_prime_0_0_bits[i].into(), rc_bit_i) + // }; + + let get_xored_bit: int -> expr = |i| xor(a_prime_prime_0_0_bits[i], utils::sum(NUM_ROUNDS, |r| expr(RC_BITS[r * 64 + i]) * step_flags[r] )); + + // for limb in 0..U64_LIMBS { + // let a_prime_prime_prime_0_0_limb = local.a_prime_prime_prime_0_0_limbs[limb]; + // let computed_a_prime_prime_prime_0_0_limb = (limb * BITS_PER_LIMB + // ..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(AB::Expr::zero(), |acc, z| acc.double() + get_xored_bit(z)); + // builder.assert_eq( + // computed_a_prime_prime_prime_0_0_limb, + // a_prime_prime_prime_0_0_limb, + // ); + // } + + array::new(4, |limb| { + let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_xored_bit(limb * 16 + z))); + a_prime_prime_prime_0_0_limbs[limb] = bits_to_value_be(limb_bits_be) + }); + + // // Enforce that this round's output equals the next round's input. + // for x in 0..5 { + // for y in 0..5 { + // for limb in 0..U64_LIMBS { + // let output = local.a_prime_prime_prime(y, x, limb); + // let input = next.a[y][x][limb]; + // builder + // .when_transition() + // .when(not_final_step.clone()) + // .assert_eq(output, input); + // } + // } + // } + + // final_step and is_last should never be 1 at the same time, because final_step is 1 at multiples of 24 and can never be 1 at power of 2. + // (1 - final_step - is_last) is used to deactivate constraints that reference the next row, whenever we are at the latch row or the last row of the trace (so that we don't incorrectly cycle to the first row). + array::new(100, |i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + (1 - final_step - is_last) * (a_prime_prime_prime(y, x, limb) - a[i]') = 0 + }); + + // pub fn a_prime_prime_prime(&self, y: usize, x: usize, limb: usize) -> T { + // debug_assert!(y < 5); + // debug_assert!(x < 5); + // debug_assert!(limb < U64_LIMBS); + + // if y == 0 && x == 0 { + // self.a_prime_prime_prime_0_0_limbs[limb] + // } else { + // self.a_prime_prime[y][x][limb] + // } + // } + + let a_prime_prime_prime: int, int, int -> expr = |y, x, limb| if y == 0 && x == 0 { a_prime_prime_prime_0_0_limbs[limb] } else { a_prime_prime[y * 20 + x * 4 + limb] }; + + // @Steve are these constants correct for 32? + let R: int[] = [ + 0, 36, 3, 41, 18, + 1, 44, 10, 45, 2, + 62, 6, 43, 15, 61, + 28, 55, 25, 21, 56, + 27, 20, 39, 8, 14 + ]; + + // @Steve are these constants correct for 32? + let RC: int[] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008 + ]; + + // @Steve are these constants correct for 32? + let RC_BITS: int[] = array::new(24 * 64, |i| { + let rc_idx = i / 64; + let bit = i % 64; + RC[rc_idx] >> bit & 0x1 + }); + + // Prover function section (for witness generation). + + // // Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]). + // for x in 0..5 { + // for z in 0..64 { + // let limb = z / BITS_PER_LIMB; + // let bit_in_limb = z % BITS_PER_LIMB; + // let a = (0..5).map(|y| { + // let a_limb = row.a[y][x][limb].as_canonical_u64() as u16; + // ((a_limb >> bit_in_limb) & 1) != 0 + // }); + // row.c[x][z] = F::from_bool(a.fold(false, |acc, x| acc ^ x)); + // } + // } + + let query_c: int, int, int -> int = query |x, limb, bit_in_limb| + utils::fold( + 5, + |y| (int(eval(a[y * 20 + x * 4 + limb])) >> bit_in_limb) & 0x1, + 0, + |acc, e| acc ^ e + ); + + // @Steve are these constants correct for 32? + query |row| { + let _ = array::map_enumerated(c, |i, c_i| { + let x = i / 64; + let z = i % 64; + let limb = z / 32; + let bit_in_limb = z % 32; + + provide_value(c_i, row, fe(query_c(x, limb, bit_in_limb))); + }); + }; + + // // Populate C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]). + // for x in 0..5 { + // for z in 0..64 { + // row.c_prime[x][z] = xor([ + // row.c[x][z], + // row.c[(x + 4) % 5][z], + // row.c[(x + 1) % 5][(z + 63) % 64], + // ]); + // } + // } + + let query_c_prime: int, int -> int = query |x, z| + int(eval(c[x * 64 + z])) ^ + int(eval(c[((x + 4) % 5) * 64 + z])) ^ + int(eval(c[((x + 1) % 5) * 64 + (z + 63) % 64])); + + query |row| { + let _ = array::map_enumerated(c_prime, |i, c_i| { + let x = i / 64; + let z = i % 64; + + provide_value(c_i, row, fe(query_c_prime(x, z))); + }); + }; + + // // Populate A'. To avoid shifting indices, we rewrite + // // A'[x, y, z] = xor(A[x, y, z], C[x - 1, z], C[x + 1, z - 1]) + // // as + // // A'[x, y, z] = xor(A[x, y, z], C[x, z], C'[x, z]). + // for x in 0..5 { + // for y in 0..5 { + // for z in 0..64 { + // let limb = z / BITS_PER_LIMB; + // let bit_in_limb = z % BITS_PER_LIMB; + // let a_limb = row.a[y][x][limb].as_canonical_u64() as u16; + // let a_bit = F::from_bool(((a_limb >> bit_in_limb) & 1) != 0); + // row.a_prime[y][x][z] = xor([a_bit, row.c[x][z], row.c_prime[x][z]]); + // } + // } + // } + + let query_a_prime: int, int, int, int, int -> int = query |x, y, z, limb, bit_in_limb| + ((int(eval(a[y * 20 + x * 4 + limb])) >> bit_in_limb) & 0x1) ^ + int(eval(c[x * 64 + z])) ^ + int(eval(c_prime[x * 64 + z])); + + query |row| { + let _ = array::map_enumerated(a_prime, |i, a_i| { + let y = i / 320; + let x = (i / 64) % 5; + let z = i % 64; + let limb = z / 16; + let bit_in_limb = z % 16; + + provide_value(a_i, row, fe(query_a_prime(x, y, z, limb, bit_in_limb))); + }); + }; + + // // Populate A''.P + // // A''[x, y] = xor(B[x, y], andn(B[x + 1, y], B[x + 2, y])). + // for y in 0..5 { + // for x in 0..5 { + // for limb in 0..U64_LIMBS { + // row.a_prime_prime[y][x][limb] = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB) + // .rev() + // .fold(F::zero(), |acc, z| { + // let bit = xor([ + // row.b(x, y, z), + // andn(row.b((x + 1) % 5, y, z), row.b((x + 2) % 5, y, z)), + // ]); + // acc.double() + bit + // }); + // } + // } + // } + + let query_a_prime_prime: int, int, int -> int = query |x, y, limb| + utils::fold( + 16, + |z| + int(eval(b(x, y, (limb + 1) * 16 - 1 - z))) ^ + int(eval(andn(b((x + 1) % 5, y, (limb + 1) * 16 - 1 - z), + b((x + 2) % 5, y, (limb + 1) * 16 - 1 - z)))), + 0, + |acc, e| acc * 2 + e + ); + + query |row| { + let _ = array::map_enumerated(a_prime_prime, |i, a_i| { + let y = i / 20; + let x = (i / 4) % 5; + let limb = i % 4; + + provide_value(a_i, row, fe(query_a_prime_prime(x, y, limb))); + }); + }; + + // // For the XOR, we split A''[0, 0] to bits. + // let mut val = 0; // smaller address correspond to less significant limb + // for limb in 0..U64_LIMBS { + // let val_limb = row.a_prime_prime[0][0][limb].as_canonical_u64(); + // val |= val_limb << (limb * BITS_PER_LIMB); + // } + // let val_bits: Vec = (0..64) // smaller address correspond to less significant bit + // .scan(val, |acc, _| { + // let bit = (*acc & 1) != 0; + // *acc >>= 1; + // Some(bit) + // }) + // .collect(); + // for (i, bit) in row.a_prime_prime_0_0_bits.iter_mut().enumerate() { + // *bit = F::from_bool(val_bits[i]); + // } + + query |row| { + let _ = array::map_enumerated(a_prime_prime_0_0_bits, |i, a_i| { + let limb = i / 16; + let bit_in_limb = i % 16; + + provide_value( + a_i, + row, + fe((int(eval(a_prime_prime[limb])) >> bit_in_limb) & 0x1) + ); + }); + }; + + // // A''[0, 0] is additionally xor'd with RC. + // for limb in 0..U64_LIMBS { + // let rc_lo = rc_value_limb(round, limb); + // row.a_prime_prime_prime_0_0_limbs[limb] = + // F::from_canonical_u16(row.a_prime_prime[0][0][limb].as_canonical_u64() as u16 ^ rc_lo); + // } + + let query_a_prime_prime_prime_0_0_limbs: int, int -> int = query |round, limb| + int(eval(a_prime_prime[limb])) ^ + ((RC[round] >> (limb * 32)) & 0xffff); + + query |row| { + let _ = array::new(4, |limb| { + provide_value( + a_prime_prime_prime_0_0_limbs[limb], + row, + fe(query_a_prime_prime_prime_0_0_limbs(row % NUM_ROUNDS, limb) + )); + }); + }; +} From f3720a36197f250808f1a802c1a770c902a06c7c Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Mon, 25 Nov 2024 15:42:10 -0500 Subject: [PATCH 2/8] fixes for 32 bit --- std/machines/hash/keccakf32_memory.asm | 164 ++++++++++++------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/std/machines/hash/keccakf32_memory.asm b/std/machines/hash/keccakf32_memory.asm index d8b16f51bb..e512fa069f 100644 --- a/std/machines/hash/keccakf32_memory.asm +++ b/std/machines/hash/keccakf32_memory.asm @@ -131,54 +131,54 @@ machine Keccakf32Memory(mem: Memory) with // Therefore, memory write converts little endian from keccak computation to big endian for the output in memory. link if final_step_used ~> mem.mstore(output_addr, time_step, a_prime_prime_prime_0_0_limbs[1]); link if final_step_used ~> mem.mstore(output_addr + 4, time_step, a_prime_prime_prime_0_0_limbs[0]); - link if final_step_used ~> mem.mstore(output_addr + 8, time_step, a_prime_prime_prime_0_0_limbs[3]); - link if final_step_used ~> mem.mstore(output_addr + 12, time_step, a_prime_prime_prime_0_0_limbs[2]); - link if final_step_used ~> mem.mstore(output_addr + 16, time_step, a_prime_prime_prime_0_0_limbs[5]); - link if final_step_used ~> mem.mstore(output_addr + 20, time_step, a_prime_prime_prime_0_0_limbs[4]); - link if final_step_used ~> mem.mstore(output_addr + 24, time_step, a_prime_prime_prime_0_0_limbs[7]); - link if final_step_used ~> mem.mstore(output_addr + 28, time_step, a_prime_prime_prime_0_0_limbs[6]); - link if final_step_used ~> mem.mstore(output_addr + 32, time_step, a_prime_prime_prime_0_0_limbs[9]); - link if final_step_used ~> mem.mstore(output_addr + 36, time_step, a_prime_prime_prime_0_0_limbs[8]); - link if final_step_used ~> mem.mstore(output_addr + 40, time_step, a_prime_prime_prime_0_0_limbs[11]); - link if final_step_used ~> mem.mstore(output_addr + 44, time_step, a_prime_prime_prime_0_0_limbs[10]); - link if final_step_used ~> mem.mstore(output_addr + 48, time_step, a_prime_prime_prime_0_0_limbs[13]); - link if final_step_used ~> mem.mstore(output_addr + 52, time_step, a_prime_prime_prime_0_0_limbs[12]); - link if final_step_used ~> mem.mstore(output_addr + 56, time_step, a_prime_prime_prime_0_0_limbs[15]); - link if final_step_used ~> mem.mstore(output_addr + 60, time_step, a_prime_prime_prime_0_0_limbs[14]); - link if final_step_used ~> mem.mstore(output_addr + 64, time_step, a_prime_prime_prime_0_0_limbs[17]); - link if final_step_used ~> mem.mstore(output_addr + 68, time_step, a_prime_prime_prime_0_0_limbs[16]); - link if final_step_used ~> mem.mstore(output_addr + 72, time_step, a_prime_prime_prime_0_0_limbs[19]); - link if final_step_used ~> mem.mstore(output_addr + 76, time_step, a_prime_prime_prime_0_0_limbs[18]); - link if final_step_used ~> mem.mstore(output_addr + 80, time_step, a_prime_prime_prime_0_0_limbs[21]); - link if final_step_used ~> mem.mstore(output_addr + 84, time_step, a_prime_prime_prime_0_0_limbs[20]); - link if final_step_used ~> mem.mstore(output_addr + 88, time_step, a_prime_prime_prime_0_0_limbs[23]); - link if final_step_used ~> mem.mstore(output_addr + 92, time_step, a_prime_prime_prime_0_0_limbs[22]); - link if final_step_used ~> mem.mstore(output_addr + 96, time_step, a_prime_prime_prime_0_0_limbs[25]); - link if final_step_used ~> mem.mstore(output_addr + 100, time_step, a_prime_prime_prime_0_0_limbs[24]); - link if final_step_used ~> mem.mstore(output_addr + 104, time_step, a_prime_prime_prime_0_0_limbs[27]); - link if final_step_used ~> mem.mstore(output_addr + 108, time_step, a_prime_prime_prime_0_0_limbs[26]); - link if final_step_used ~> mem.mstore(output_addr + 112, time_step, a_prime_prime_prime_0_0_limbs[29]); - link if final_step_used ~> mem.mstore(output_addr + 116, time_step, a_prime_prime_prime_0_0_limbs[28]); - link if final_step_used ~> mem.mstore(output_addr + 120, time_step, a_prime_prime_prime_0_0_limbs[31]); - link if final_step_used ~> mem.mstore(output_addr + 124, time_step, a_prime_prime_prime_0_0_limbs[30]); - link if final_step_used ~> mem.mstore(output_addr + 128, time_step, a_prime_prime_prime_0_0_limbs[33]); - link if final_step_used ~> mem.mstore(output_addr + 132, time_step, a_prime_prime_prime_0_0_limbs[32]); - link if final_step_used ~> mem.mstore(output_addr + 136, time_step, a_prime_prime_prime_0_0_limbs[35]); - link if final_step_used ~> mem.mstore(output_addr + 140, time_step, a_prime_prime_prime_0_0_limbs[34]); - link if final_step_used ~> mem.mstore(output_addr + 144, time_step, a_prime_prime_prime_0_0_limbs[37]); - link if final_step_used ~> mem.mstore(output_addr + 148, time_step, a_prime_prime_prime_0_0_limbs[36]); - link if final_step_used ~> mem.mstore(output_addr + 152, time_step, a_prime_prime_prime_0_0_limbs[39]); - link if final_step_used ~> mem.mstore(output_addr + 156, time_step, a_prime_prime_prime_0_0_limbs[38]); - link if final_step_used ~> mem.mstore(output_addr + 160, time_step, a_prime_prime_prime_0_0_limbs[41]); - link if final_step_used ~> mem.mstore(output_addr + 164, time_step, a_prime_prime_prime_0_0_limbs[40]); - link if final_step_used ~> mem.mstore(output_addr + 168, time_step, a_prime_prime_prime_0_0_limbs[43]); - link if final_step_used ~> mem.mstore(output_addr + 172, time_step, a_prime_prime_prime_0_0_limbs[42]); - link if final_step_used ~> mem.mstore(output_addr + 176, time_step, a_prime_prime_prime_0_0_limbs[45]); - link if final_step_used ~> mem.mstore(output_addr + 180, time_step, a_prime_prime_prime_0_0_limbs[44]); - link if final_step_used ~> mem.mstore(output_addr + 184, time_step, a_prime_prime_prime_0_0_limbs[47]); - link if final_step_used ~> mem.mstore(output_addr + 188, time_step, a_prime_prime_prime_0_0_limbs[46]); - link if final_step_used ~> mem.mstore(output_addr + 192, time_step, a_prime_prime_prime_0_0_limbs[49]); - link if final_step_used ~> mem.mstore(output_addr + 196, time_step, a_prime_prime_prime_0_0_limbs[48]); + link if final_step_used ~> mem.mstore(output_addr + 8, time_step, a_prime_prime[3]); + link if final_step_used ~> mem.mstore(output_addr + 12, time_step, a_prime_prime[2]); + link if final_step_used ~> mem.mstore(output_addr + 16, time_step, a_prime_prime[5]); + link if final_step_used ~> mem.mstore(output_addr + 20, time_step, a_prime_prime[4]); + link if final_step_used ~> mem.mstore(output_addr + 24, time_step, a_prime_prime[7]); + link if final_step_used ~> mem.mstore(output_addr + 28, time_step, a_prime_prime[6]); + link if final_step_used ~> mem.mstore(output_addr + 32, time_step, a_prime_prime[9]); + link if final_step_used ~> mem.mstore(output_addr + 36, time_step, a_prime_prime[8]); + link if final_step_used ~> mem.mstore(output_addr + 40, time_step, a_prime_prime[11]); + link if final_step_used ~> mem.mstore(output_addr + 44, time_step, a_prime_prime[10]); + link if final_step_used ~> mem.mstore(output_addr + 48, time_step, a_prime_prime[13]); + link if final_step_used ~> mem.mstore(output_addr + 52, time_step, a_prime_prime[12]); + link if final_step_used ~> mem.mstore(output_addr + 56, time_step, a_prime_prime[15]); + link if final_step_used ~> mem.mstore(output_addr + 60, time_step, a_prime_prime[14]); + link if final_step_used ~> mem.mstore(output_addr + 64, time_step, a_prime_prime[17]); + link if final_step_used ~> mem.mstore(output_addr + 68, time_step, a_prime_prime[16]); + link if final_step_used ~> mem.mstore(output_addr + 72, time_step, a_prime_prime[19]); + link if final_step_used ~> mem.mstore(output_addr + 76, time_step, a_prime_prime[18]); + link if final_step_used ~> mem.mstore(output_addr + 80, time_step, a_prime_prime[21]); + link if final_step_used ~> mem.mstore(output_addr + 84, time_step, a_prime_prime[20]); + link if final_step_used ~> mem.mstore(output_addr + 88, time_step, a_prime_prime[23]); + link if final_step_used ~> mem.mstore(output_addr + 92, time_step, a_prime_prime[22]); + link if final_step_used ~> mem.mstore(output_addr + 96, time_step, a_prime_prime[25]); + link if final_step_used ~> mem.mstore(output_addr + 100, time_step, a_prime_prime[24]); + link if final_step_used ~> mem.mstore(output_addr + 104, time_step, a_prime_prime[27]); + link if final_step_used ~> mem.mstore(output_addr + 108, time_step, a_prime_prime[26]); + link if final_step_used ~> mem.mstore(output_addr + 112, time_step, a_prime_prime[29]); + link if final_step_used ~> mem.mstore(output_addr + 116, time_step, a_prime_prime[28]); + link if final_step_used ~> mem.mstore(output_addr + 120, time_step, a_prime_prime[31]); + link if final_step_used ~> mem.mstore(output_addr + 124, time_step, a_prime_prime[30]); + link if final_step_used ~> mem.mstore(output_addr + 128, time_step, a_prime_prime[33]); + link if final_step_used ~> mem.mstore(output_addr + 132, time_step, a_prime_prime[32]); + link if final_step_used ~> mem.mstore(output_addr + 136, time_step, a_prime_prime[35]); + link if final_step_used ~> mem.mstore(output_addr + 140, time_step, a_prime_prime[34]); + link if final_step_used ~> mem.mstore(output_addr + 144, time_step, a_prime_prime[37]); + link if final_step_used ~> mem.mstore(output_addr + 148, time_step, a_prime_prime[36]); + link if final_step_used ~> mem.mstore(output_addr + 152, time_step, a_prime_prime[39]); + link if final_step_used ~> mem.mstore(output_addr + 156, time_step, a_prime_prime[38]); + link if final_step_used ~> mem.mstore(output_addr + 160, time_step, a_prime_prime[41]); + link if final_step_used ~> mem.mstore(output_addr + 164, time_step, a_prime_prime[40]); + link if final_step_used ~> mem.mstore(output_addr + 168, time_step, a_prime_prime[43]); + link if final_step_used ~> mem.mstore(output_addr + 172, time_step, a_prime_prime[42]); + link if final_step_used ~> mem.mstore(output_addr + 176, time_step, a_prime_prime[45]); + link if final_step_used ~> mem.mstore(output_addr + 180, time_step, a_prime_prime[44]); + link if final_step_used ~> mem.mstore(output_addr + 184, time_step, a_prime_prime[47]); + link if final_step_used ~> mem.mstore(output_addr + 188, time_step, a_prime_prime[46]); + link if final_step_used ~> mem.mstore(output_addr + 192, time_step, a_prime_prime[49]); + link if final_step_used ~> mem.mstore(output_addr + 196, time_step, a_prime_prime[48]); // ------------- End memory read / write --------------- // Adapted from Plonky3 implementation of Keccak: https://github.com/Plonky3/Plonky3/tree/main/keccak-air/src @@ -383,13 +383,13 @@ machine Keccakf32Memory(mem: Memory) with let bits_to_value_be: expr[] -> expr = |bits_be| array::fold(bits_be, 0, |acc, e| (acc * 2 + e)); // @Steve are these constants correct for 32? - array::new(100, |i| { - let y = i / 20; - let x = (i / 4) % 5; - let limb = i % 4; + array::new(50, |i| { + let y = i / 10; + let x = (i / 2) % 5; + let limb = i % 2; let get_bit: int -> expr = |z| xor3(a_prime[y * 320 + x * 64 + z], c[x * 64 + z], c_prime[x * 64 + z]); - let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_bit(limb * 16 + z))); + let limb_bits_be: expr[] = array::reverse(array::new(32, |z| get_bit(limb * 32 + z))); a[i] = bits_to_value_be(limb_bits_be) }); @@ -436,15 +436,15 @@ machine Keccakf32Memory(mem: Memory) with // } // @Steve are these constants correct for 32? - array::new(100, |i| { - let y = i / 20; - let x = (i / 4) % 5; - let limb = i % 4; + array::new(50, |i| { + let y = i / 10; + let x = (i / 2) % 5; + let limb = i % 2; let get_bit: int -> expr = |z| { xor(b(x, y, z), andn(b((x + 1) % 5, y, z), b((x + 2) % 5, y, z))) }; - let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_bit(limb * 16 + z))); + let limb_bits_be: expr[] = array::reverse(array::new(32, |z| get_bit(limb * 32 + z))); a_prime_prime[i] = bits_to_value_be(limb_bits_be) }); @@ -486,8 +486,8 @@ machine Keccakf32Memory(mem: Memory) with // } // @Steve are these constants correct for 32? - array::new(4, |limb| { - let limb_bits_be: expr[] = array::reverse(array::new(16, |z| a_prime_prime_0_0_bits[limb * 16 + z])); + array::new(2, |limb| { + let limb_bits_be: expr[] = array::reverse(array::new(32, |z| a_prime_prime_0_0_bits[limb * 32 + z])); a_prime_prime[limb] = bits_to_value_be(limb_bits_be) }); @@ -516,8 +516,8 @@ machine Keccakf32Memory(mem: Memory) with // ); // } - array::new(4, |limb| { - let limb_bits_be: expr[] = array::reverse(array::new(16, |z| get_xored_bit(limb * 16 + z))); + array::new(2, |limb| { + let limb_bits_be: expr[] = array::reverse(array::new(32, |z| get_xored_bit(limb * 32 + z))); a_prime_prime_prime_0_0_limbs[limb] = bits_to_value_be(limb_bits_be) }); @@ -537,10 +537,10 @@ machine Keccakf32Memory(mem: Memory) with // final_step and is_last should never be 1 at the same time, because final_step is 1 at multiples of 24 and can never be 1 at power of 2. // (1 - final_step - is_last) is used to deactivate constraints that reference the next row, whenever we are at the latch row or the last row of the trace (so that we don't incorrectly cycle to the first row). - array::new(100, |i| { - let y = i / 20; - let x = (i / 4) % 5; - let limb = i % 4; + array::new(50, |i| { + let y = i / 10; + let x = (i / 2) % 5; + let limb = i % 2; (1 - final_step - is_last) * (a_prime_prime_prime(y, x, limb) - a[i]') = 0 }); @@ -556,7 +556,7 @@ machine Keccakf32Memory(mem: Memory) with // } // } - let a_prime_prime_prime: int, int, int -> expr = |y, x, limb| if y == 0 && x == 0 { a_prime_prime_prime_0_0_limbs[limb] } else { a_prime_prime[y * 20 + x * 4 + limb] }; + let a_prime_prime_prime: int, int, int -> expr = |y, x, limb| if y == 0 && x == 0 { a_prime_prime_prime_0_0_limbs[limb] } else { a_prime_prime[y * 10 + x * 2 + limb] }; // @Steve are these constants correct for 32? let R: int[] = [ @@ -620,7 +620,7 @@ machine Keccakf32Memory(mem: Memory) with let query_c: int, int, int -> int = query |x, limb, bit_in_limb| utils::fold( 5, - |y| (int(eval(a[y * 20 + x * 4 + limb])) >> bit_in_limb) & 0x1, + |y| (int(eval(a[y * 10 + x * 2 + limb])) >> bit_in_limb) & 0x1, 0, |acc, e| acc ^ e ); @@ -679,7 +679,7 @@ machine Keccakf32Memory(mem: Memory) with // } let query_a_prime: int, int, int, int, int -> int = query |x, y, z, limb, bit_in_limb| - ((int(eval(a[y * 20 + x * 4 + limb])) >> bit_in_limb) & 0x1) ^ + ((int(eval(a[y * 10 + x * 2 + limb])) >> bit_in_limb) & 0x1) ^ int(eval(c[x * 64 + z])) ^ int(eval(c_prime[x * 64 + z])); @@ -688,8 +688,8 @@ machine Keccakf32Memory(mem: Memory) with let y = i / 320; let x = (i / 64) % 5; let z = i % 64; - let limb = z / 16; - let bit_in_limb = z % 16; + let limb = z / 32; + let bit_in_limb = z % 32; provide_value(a_i, row, fe(query_a_prime(x, y, z, limb, bit_in_limb))); }); @@ -715,20 +715,20 @@ machine Keccakf32Memory(mem: Memory) with let query_a_prime_prime: int, int, int -> int = query |x, y, limb| utils::fold( - 16, + 32, |z| - int(eval(b(x, y, (limb + 1) * 16 - 1 - z))) ^ - int(eval(andn(b((x + 1) % 5, y, (limb + 1) * 16 - 1 - z), - b((x + 2) % 5, y, (limb + 1) * 16 - 1 - z)))), + int(eval(b(x, y, (limb + 1) * 32 - 1 - z))) ^ + int(eval(andn(b((x + 1) % 5, y, (limb + 1) * 32 - 1 - z), + b((x + 2) % 5, y, (limb + 1) * 32 - 1 - z)))), 0, |acc, e| acc * 2 + e ); query |row| { let _ = array::map_enumerated(a_prime_prime, |i, a_i| { - let y = i / 20; - let x = (i / 4) % 5; - let limb = i % 4; + let y = i / 10; + let x = (i / 2) % 5; + let limb = i % 2; provide_value(a_i, row, fe(query_a_prime_prime(x, y, limb))); }); @@ -753,8 +753,8 @@ machine Keccakf32Memory(mem: Memory) with query |row| { let _ = array::map_enumerated(a_prime_prime_0_0_bits, |i, a_i| { - let limb = i / 16; - let bit_in_limb = i % 16; + let limb = i / 32; + let bit_in_limb = i % 32; provide_value( a_i, @@ -773,10 +773,10 @@ machine Keccakf32Memory(mem: Memory) with let query_a_prime_prime_prime_0_0_limbs: int, int -> int = query |round, limb| int(eval(a_prime_prime[limb])) ^ - ((RC[round] >> (limb * 32)) & 0xffff); + ((RC[round] >> (limb * 32)) & 0xffffffff); query |row| { - let _ = array::new(4, |limb| { + let _ = array::new(2, |limb| { provide_value( a_prime_prime_prime_0_0_limbs[limb], row, From 9ba2f3a3b9a5d76b8494722e6813f45e2a4058df Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Mon, 25 Nov 2024 15:51:56 -0500 Subject: [PATCH 3/8] first pass test case --- test_data/std/keccakf32_memory_test.asm | 120 ++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 test_data/std/keccakf32_memory_test.asm diff --git a/test_data/std/keccakf32_memory_test.asm b/test_data/std/keccakf32_memory_test.asm new file mode 100644 index 0000000000..425c61f9ce --- /dev/null +++ b/test_data/std/keccakf32_memory_test.asm @@ -0,0 +1,120 @@ +use std::machines::hash::keccakf32_memory::Keccakf32Memory; +use std::machines::large_field::memory::Memory; +use std::machines::range::Byte2; + +machine Main with degree: 65536 { + reg pc[@pc]; + + reg X[<=]; + + reg Y[<=]; + + Byte2 byte2; + Memory memory(byte2); + + Keccakf32Memory keccakf32_memory(memory); + + col fixed STEP(i) { i }; + + // Big endian. + // Usage: mstore addr, val; + instr mstore X, Y -> link ~> memory.mstore(X, STEP, Y); + // Usage: keccakf32_memory input_addr, output_addr; + instr keccakf32_memory X, Y -> link ~> keccakf32_memory.keccakf32_memory(X, Y, STEP); + + col witness val; + // Usage: assert_eq addr, val; + instr assert_eq X, Y -> + link ~> val = memory.mload(X, STEP) + { + val = Y + } + + function main { + // Test 1: Input/output address computations have no carry. + // 0 for all 25 64-bit inputs except setting the second 64-bit input to 1. All 64-bit inputs in chunks of 4 16-bit big endian limbs. + mstore 0, 0; + mstore 4, 0; + mstore 8, 0; + mstore 12, 1; + mstore 16, 0; + mstore 20, 0; + mstore 24, 0; + mstore 28, 0; + mstore 32, 0; + mstore 36, 0; + mstore 40, 0; + mstore 44, 0; + mstore 48, 0; + mstore 52, 0; + mstore 56, 0; + mstore 60, 0; + mstore 64, 0; + mstore 68, 0; + mstore 72, 0; + mstore 76, 0; + mstore 80, 0; + mstore 84, 0; + mstore 88, 0; + mstore 92, 0; + mstore 96, 0; + // Input address 0. Output address 200. + keccakf16_memory 0, 200; + // Selectively checking a few registers only. + // Test vector generated from Tiny Keccak. + assert_eq 200, 0xfdbbbbdf; + assert_eq 204, 0x9001405f; + assert_eq 392, 0xeac9f006; + assert_eq 396, 0x664deb35; +// + // // Test 2: Same as Test 1 but sets input and output addresses to be the same. + // // No need to rerun the mstores because input values from Test 1 should be intact. + // keccakf16_memory 0, 0, 0, 0; + // // Selectively checking a few registers only. + // // Test vector generated from Tiny Keccak. + // assert_eq 0, 0, 0xfdbb, 0xbbdf; + // assert_eq 0, 4, 0x9001, 0x405f; + // assert_eq 0, 192, 0xeac9, 0xf006; + // assert_eq 0, 196, 0x664d, 0xeb35; +// + // // Test 3: Input/output address computations have carry. + // // 0 for all 25 64-bit inputs except setting the second 64-bit input to 1. All 64-bit inputs in chunks of 4 16-bit big endian limbs. + // mstore 100, 65520, 0, 0; + // mstore 100, 65524, 0, 0; + // mstore 100, 65528, 0, 0; + // mstore 100, 65532, 0, 1; + // mstore 101, 0, 0, 0; + // mstore 101, 4, 0, 0; + // mstore 101, 8, 0, 0; + // mstore 101, 12, 0, 0; + // mstore 101, 16, 0, 0; + // mstore 101, 20, 0, 0; + // mstore 101, 24, 0, 0; + // mstore 101, 28, 0, 0; + // mstore 101, 32, 0, 0; + // mstore 101, 36, 0, 0; + // mstore 101, 40, 0, 0; + // mstore 101, 44, 0, 0; + // mstore 101, 48, 0, 0; + // mstore 101, 52, 0, 0; + // mstore 101, 56, 0, 0; + // mstore 101, 60, 0, 0; + // mstore 101, 64, 0, 0; + // mstore 101, 68, 0, 0; + // mstore 101, 72, 0, 0; + // mstore 101, 76, 0, 0; + // mstore 101, 80, 0, 0; +// + // // Input address (100 * 65536 + 65520). Output address (50000 * 65536 + 65528). + // keccakf16_memory 100, 65520, 50000, 65528; +// + // // Selectively checking a few registers only. + // // Test vector generated from Tiny Keccak. + // assert_eq 50000, 65528, 0xfdbb, 0xbbdf; + // assert_eq 50000, 65532, 0x9001, 0x405f; + // assert_eq 50001, 184, 0xeac9, 0xf006; + // assert_eq 50001, 188, 0x664d, 0xeb35; +// + return; + } +} From c38b548ee47546ad03e8ee145d6386c87452e21e Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Mon, 25 Nov 2024 16:25:45 -0500 Subject: [PATCH 4/8] first test case passed --- std/machines/hash/keccakf32_memory.asm | 96 ++++++++++++------------- std/machines/hash/mod.asm | 3 +- test_data/std/keccakf32_memory_test.asm | 2 +- 3 files changed, 51 insertions(+), 50 deletions(-) diff --git a/std/machines/hash/keccakf32_memory.asm b/std/machines/hash/keccakf32_memory.asm index e512fa069f..7eb04b16ea 100644 --- a/std/machines/hash/keccakf32_memory.asm +++ b/std/machines/hash/keccakf32_memory.asm @@ -73,54 +73,54 @@ machine Keccakf32Memory(mem: Memory) with // Therefore memory load converts big endian inputs to little endian for the preimage. link if first_step_used ~> preimage[1] = mem.mload(input_addr, time_step); link if first_step_used ~> preimage[0] = mem.mload(input_addr + 4, time_step); - link if first_step_used ~> preimage[3] = mem.mload(input_addr + 12, time_step); - link if first_step_used ~> preimage[2] = mem.mload(input_addr + 8, time_step); - link if first_step_used ~> preimage[5] = mem.mload(input_addr + 20, time_step); - link if first_step_used ~> preimage[4] = mem.mload(input_addr + 16, time_step); - link if first_step_used ~> preimage[7] = mem.mload(input_addr + 28, time_step); - link if first_step_used ~> preimage[6] = mem.mload(input_addr + 24, time_step); - link if first_step_used ~> preimage[9] = mem.mload(input_addr + 36, time_step); - link if first_step_used ~> preimage[8] = mem.mload(input_addr + 32, time_step); - link if first_step_used ~> preimage[11] = mem.mload(input_addr + 44, time_step); - link if first_step_used ~> preimage[10] = mem.mload(input_addr + 40, time_step); - link if first_step_used ~> preimage[13] = mem.mload(input_addr + 52, time_step); - link if first_step_used ~> preimage[12] = mem.mload(input_addr + 48, time_step); - link if first_step_used ~> preimage[15] = mem.mload(input_addr + 60, time_step); - link if first_step_used ~> preimage[14] = mem.mload(input_addr + 56, time_step); - link if first_step_used ~> preimage[17] = mem.mload(input_addr + 68, time_step); - link if first_step_used ~> preimage[16] = mem.mload(input_addr + 64, time_step); - link if first_step_used ~> preimage[19] = mem.mload(input_addr + 76, time_step); - link if first_step_used ~> preimage[18] = mem.mload(input_addr + 72, time_step); - link if first_step_used ~> preimage[21] = mem.mload(input_addr + 84, time_step); - link if first_step_used ~> preimage[20] = mem.mload(input_addr + 80, time_step); - link if first_step_used ~> preimage[23] = mem.mload(input_addr + 92, time_step); - link if first_step_used ~> preimage[22] = mem.mload(input_addr + 88, time_step); - link if first_step_used ~> preimage[25] = mem.mload(input_addr + 100, time_step); - link if first_step_used ~> preimage[24] = mem.mload(input_addr + 96, time_step); - link if first_step_used ~> preimage[27] = mem.mload(input_addr + 108, time_step); - link if first_step_used ~> preimage[26] = mem.mload(input_addr + 104, time_step); - link if first_step_used ~> preimage[29] = mem.mload(input_addr + 116, time_step); - link if first_step_used ~> preimage[28] = mem.mload(input_addr + 112, time_step); - link if first_step_used ~> preimage[31] = mem.mload(input_addr + 124, time_step); - link if first_step_used ~> preimage[30] = mem.mload(input_addr + 120, time_step); - link if first_step_used ~> preimage[33] = mem.mload(input_addr + 132, time_step); - link if first_step_used ~> preimage[32] = mem.mload(input_addr + 128, time_step); - link if first_step_used ~> preimage[35] = mem.mload(input_addr + 140, time_step); - link if first_step_used ~> preimage[34] = mem.mload(input_addr + 136, time_step); - link if first_step_used ~> preimage[37] = mem.mload(input_addr + 148, time_step); - link if first_step_used ~> preimage[36] = mem.mload(input_addr + 144, time_step); - link if first_step_used ~> preimage[39] = mem.mload(input_addr + 156, time_step); - link if first_step_used ~> preimage[38] = mem.mload(input_addr + 152, time_step); - link if first_step_used ~> preimage[41] = mem.mload(input_addr + 164, time_step); - link if first_step_used ~> preimage[40] = mem.mload(input_addr + 160, time_step); - link if first_step_used ~> preimage[43] = mem.mload(input_addr + 172, time_step); - link if first_step_used ~> preimage[42] = mem.mload(input_addr + 168, time_step); - link if first_step_used ~> preimage[45] = mem.mload(input_addr + 180, time_step); - link if first_step_used ~> preimage[44] = mem.mload(input_addr + 176, time_step); - link if first_step_used ~> preimage[47] = mem.mload(input_addr + 188, time_step); - link if first_step_used ~> preimage[46] = mem.mload(input_addr + 184, time_step); - link if first_step_used ~> preimage[49] = mem.mload(input_addr + 196, time_step); - link if first_step_used ~> preimage[48] = mem.mload(input_addr + 192, time_step); + link if first_step_used ~> preimage[3] = mem.mload(input_addr + 8, time_step); + link if first_step_used ~> preimage[2] = mem.mload(input_addr + 12, time_step); + link if first_step_used ~> preimage[5] = mem.mload(input_addr + 16, time_step); + link if first_step_used ~> preimage[4] = mem.mload(input_addr + 20, time_step); + link if first_step_used ~> preimage[7] = mem.mload(input_addr + 24, time_step); + link if first_step_used ~> preimage[6] = mem.mload(input_addr + 28, time_step); + link if first_step_used ~> preimage[9] = mem.mload(input_addr + 32, time_step); + link if first_step_used ~> preimage[8] = mem.mload(input_addr + 36, time_step); + link if first_step_used ~> preimage[11] = mem.mload(input_addr + 40, time_step); + link if first_step_used ~> preimage[10] = mem.mload(input_addr + 44, time_step); + link if first_step_used ~> preimage[13] = mem.mload(input_addr + 48, time_step); + link if first_step_used ~> preimage[12] = mem.mload(input_addr + 52, time_step); + link if first_step_used ~> preimage[15] = mem.mload(input_addr + 56, time_step); + link if first_step_used ~> preimage[14] = mem.mload(input_addr + 60, time_step); + link if first_step_used ~> preimage[17] = mem.mload(input_addr + 64, time_step); + link if first_step_used ~> preimage[16] = mem.mload(input_addr + 68, time_step); + link if first_step_used ~> preimage[19] = mem.mload(input_addr + 72, time_step); + link if first_step_used ~> preimage[18] = mem.mload(input_addr + 76, time_step); + link if first_step_used ~> preimage[21] = mem.mload(input_addr + 80, time_step); + link if first_step_used ~> preimage[20] = mem.mload(input_addr + 84, time_step); + link if first_step_used ~> preimage[23] = mem.mload(input_addr + 88, time_step); + link if first_step_used ~> preimage[22] = mem.mload(input_addr + 92, time_step); + link if first_step_used ~> preimage[25] = mem.mload(input_addr + 96, time_step); + link if first_step_used ~> preimage[24] = mem.mload(input_addr + 100, time_step); + link if first_step_used ~> preimage[27] = mem.mload(input_addr + 104, time_step); + link if first_step_used ~> preimage[26] = mem.mload(input_addr + 108, time_step); + link if first_step_used ~> preimage[29] = mem.mload(input_addr + 112, time_step); + link if first_step_used ~> preimage[28] = mem.mload(input_addr + 116, time_step); + link if first_step_used ~> preimage[31] = mem.mload(input_addr + 120, time_step); + link if first_step_used ~> preimage[30] = mem.mload(input_addr + 124, time_step); + link if first_step_used ~> preimage[33] = mem.mload(input_addr + 128, time_step); + link if first_step_used ~> preimage[32] = mem.mload(input_addr + 132, time_step); + link if first_step_used ~> preimage[35] = mem.mload(input_addr + 136, time_step); + link if first_step_used ~> preimage[34] = mem.mload(input_addr + 140, time_step); + link if first_step_used ~> preimage[37] = mem.mload(input_addr + 144, time_step); + link if first_step_used ~> preimage[36] = mem.mload(input_addr + 148, time_step); + link if first_step_used ~> preimage[39] = mem.mload(input_addr + 152, time_step); + link if first_step_used ~> preimage[38] = mem.mload(input_addr + 156, time_step); + link if first_step_used ~> preimage[41] = mem.mload(input_addr + 160, time_step); + link if first_step_used ~> preimage[40] = mem.mload(input_addr + 164, time_step); + link if first_step_used ~> preimage[43] = mem.mload(input_addr + 168, time_step); + link if first_step_used ~> preimage[42] = mem.mload(input_addr + 172, time_step); + link if first_step_used ~> preimage[45] = mem.mload(input_addr + 176, time_step); + link if first_step_used ~> preimage[44] = mem.mload(input_addr + 180, time_step); + link if first_step_used ~> preimage[47] = mem.mload(input_addr + 184, time_step); + link if first_step_used ~> preimage[46] = mem.mload(input_addr + 188, time_step); + link if first_step_used ~> preimage[49] = mem.mload(input_addr + 192, time_step); + link if first_step_used ~> preimage[48] = mem.mload(input_addr + 196, time_step); // Expects input of 25 64-bit numbers decomposed to 25 chunks of 2 32-bit little endian limbs. // The output is a_prime_prime_prime_0_0_limbs for the first 2 and a_prime_prime for the rest. diff --git a/std/machines/hash/mod.asm b/std/machines/hash/mod.asm index 46f3c9e610..3b9a215595 100644 --- a/std/machines/hash/mod.asm +++ b/std/machines/hash/mod.asm @@ -6,4 +6,5 @@ mod poseidon2_common; mod poseidon2_bb; mod poseidon2_gl; mod keccakf16; -mod keccakf16_memory; \ No newline at end of file +mod keccakf16_memory; +mod keccakf32_memory; diff --git a/test_data/std/keccakf32_memory_test.asm b/test_data/std/keccakf32_memory_test.asm index 425c61f9ce..f15e9c426b 100644 --- a/test_data/std/keccakf32_memory_test.asm +++ b/test_data/std/keccakf32_memory_test.asm @@ -59,7 +59,7 @@ machine Main with degree: 65536 { mstore 92, 0; mstore 96, 0; // Input address 0. Output address 200. - keccakf16_memory 0, 200; + keccakf32_memory 0, 200; // Selectively checking a few registers only. // Test vector generated from Tiny Keccak. assert_eq 200, 0xfdbbbbdf; From 924464f58e73d26dad608a2500228278d4150cfc Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Mon, 25 Nov 2024 16:35:39 -0500 Subject: [PATCH 5/8] both test cases pass; updated comments as well --- std/machines/hash/keccakf32_memory.asm | 24 ++-------- test_data/std/keccakf32_memory_test.asm | 63 +++++-------------------- 2 files changed, 15 insertions(+), 72 deletions(-) diff --git a/std/machines/hash/keccakf32_memory.asm b/std/machines/hash/keccakf32_memory.asm index 7eb04b16ea..dbb1d6604f 100644 --- a/std/machines/hash/keccakf32_memory.asm +++ b/std/machines/hash/keccakf32_memory.asm @@ -21,7 +21,7 @@ machine Keccakf32Memory(mem: Memory) with Additional columns compared to the non-memory version: - 1 column for user input address (of first byte of input). - 1 column for user output address (of first byte of output). - - 49 columns for computed input/output address of all bytes. + - 1 column for time step. Overall, given that there are 2,600+ columns in the non-memory version, this isn't a huge cost Methodology description: 1. The latch with the input and output addresses + time step is in the last row of each block. @@ -31,14 +31,8 @@ machine Keccakf32Memory(mem: Memory) with 5. Keccak is computed from top to bottom. 6. Output addresses for all bytes are calculated from user output address in the last row. 7. Store all output bytes from keccak computation columns to memory. - Note that this methodology reuses the same 49 columns of addr to calculate input and output addresses of all bytes. - However, these 49 columns are only used in the first and last rows of each block. Essentially, we conduct all memory reads in the first row and all memory writes in the last row. - This might seem wasteful, but it's still cheaper than spreading memory reads/writes over different rows while using as few columns as possible, - which requires 100 columns to make outputs available in all rows in additional to the memory columns. - This alternative methodology (memory reads/writes over different rows) also doesn't work well with our auto witgen infrastructure, - because one would need to do memory read -> keccak computation -> memory write as three sequential passes during witgen. - On the contrary, our current methodology performs all memory reads at once in the first row, then immediately does the keccak computation, + Our current methodology performs all memory reads at once in the first row, then immediately does the keccak computation, and finally performs all memory writes at once in the last row, and thus only requires one pass with auto witgen. Though note that input address need to be first copied from the last row to the first row. */ @@ -127,7 +121,7 @@ machine Keccakf32Memory(mem: Memory) with // Write memory while converting output to big endian format. // Specifically, output obtained from the keccak computation are little endian. - // However, this keccakf16_memory machine produces big endian outputs in memory. + // However, this keccakf32_memory machine produces big endian outputs in memory. // Therefore, memory write converts little endian from keccak computation to big endian for the output in memory. link if final_step_used ~> mem.mstore(output_addr, time_step, a_prime_prime_prime_0_0_limbs[1]); link if final_step_used ~> mem.mstore(output_addr + 4, time_step, a_prime_prime_prime_0_0_limbs[0]); @@ -234,8 +228,6 @@ machine Keccakf32Memory(mem: Memory) with // pub a_prime_prime_prime_0_0_limbs: [T; U64_LIMBS], // } - // @Steve are these constants correct for 32? - // It would be good to replace all these constants by named constants col witness preimage[5 * 5 * 2]; col witness a[5 * 5 * 2]; col witness c[5 * 64]; @@ -341,7 +333,6 @@ machine Keccakf32Memory(mem: Memory) with // 1 1 0 0 // 1 1 1 1 - // @Steve are these constants correct for 32? array::new(320, |i| { let x = i / 64; let z = i % 64; @@ -382,7 +373,6 @@ machine Keccakf32Memory(mem: Memory) with let bits_to_value_be: expr[] -> expr = |bits_be| array::fold(bits_be, 0, |acc, e| (acc * 2 + e)); - // @Steve are these constants correct for 32? array::new(50, |i| { let y = i / 10; let x = (i / 2) % 5; @@ -406,7 +396,6 @@ machine Keccakf32Memory(mem: Memory) with // } // } - // @Steve are these constants correct for 32? array::new(320, |i| { let x = i / 64; let z = i % 64; @@ -435,7 +424,6 @@ machine Keccakf32Memory(mem: Memory) with // } // } - // @Steve are these constants correct for 32? array::new(50, |i| { let y = i / 10; let x = (i / 2) % 5; @@ -465,7 +453,6 @@ machine Keccakf32Memory(mem: Memory) with // self.a_prime[b][a][(z + 64 - rot) % 64] // } - // @Steve are these constants correct for 32? let b: int, int, int -> expr = |x, y, z| { let a: int = (x + 3 * y) % 5; let rot: int = R[a * 5 + x]; // b = x @@ -485,7 +472,6 @@ machine Keccakf32Memory(mem: Memory) with // builder.assert_eq(computed_a_prime_prime_0_0_limb, a_prime_prime_0_0_limb); // } - // @Steve are these constants correct for 32? array::new(2, |limb| { let limb_bits_be: expr[] = array::reverse(array::new(32, |z| a_prime_prime_0_0_bits[limb * 32 + z])); a_prime_prime[limb] = bits_to_value_be(limb_bits_be) @@ -558,7 +544,6 @@ machine Keccakf32Memory(mem: Memory) with let a_prime_prime_prime: int, int, int -> expr = |y, x, limb| if y == 0 && x == 0 { a_prime_prime_prime_0_0_limbs[limb] } else { a_prime_prime[y * 10 + x * 2 + limb] }; - // @Steve are these constants correct for 32? let R: int[] = [ 0, 36, 3, 41, 18, 1, 44, 10, 45, 2, @@ -567,7 +552,6 @@ machine Keccakf32Memory(mem: Memory) with 27, 20, 39, 8, 14 ]; - // @Steve are these constants correct for 32? let RC: int[] = [ 0x0000000000000001, 0x0000000000008082, @@ -595,7 +579,6 @@ machine Keccakf32Memory(mem: Memory) with 0x8000000080008008 ]; - // @Steve are these constants correct for 32? let RC_BITS: int[] = array::new(24 * 64, |i| { let rc_idx = i / 64; let bit = i % 64; @@ -625,7 +608,6 @@ machine Keccakf32Memory(mem: Memory) with |acc, e| acc ^ e ); - // @Steve are these constants correct for 32? query |row| { let _ = array::map_enumerated(c, |i, c_i| { let x = i / 64; diff --git a/test_data/std/keccakf32_memory_test.asm b/test_data/std/keccakf32_memory_test.asm index f15e9c426b..c4efc9acf2 100644 --- a/test_data/std/keccakf32_memory_test.asm +++ b/test_data/std/keccakf32_memory_test.asm @@ -31,8 +31,7 @@ machine Main with degree: 65536 { } function main { - // Test 1: Input/output address computations have no carry. - // 0 for all 25 64-bit inputs except setting the second 64-bit input to 1. All 64-bit inputs in chunks of 4 16-bit big endian limbs. + // Test 1: 0 for all 25 64-bit inputs except setting the second 64-bit input to 1. All 64-bit inputs in chunks of 2 32-bit big endian limbs. mstore 0, 0; mstore 4, 0; mstore 8, 0; @@ -66,55 +65,17 @@ machine Main with degree: 65536 { assert_eq 204, 0x9001405f; assert_eq 392, 0xeac9f006; assert_eq 396, 0x664deb35; -// - // // Test 2: Same as Test 1 but sets input and output addresses to be the same. - // // No need to rerun the mstores because input values from Test 1 should be intact. - // keccakf16_memory 0, 0, 0, 0; - // // Selectively checking a few registers only. - // // Test vector generated from Tiny Keccak. - // assert_eq 0, 0, 0xfdbb, 0xbbdf; - // assert_eq 0, 4, 0x9001, 0x405f; - // assert_eq 0, 192, 0xeac9, 0xf006; - // assert_eq 0, 196, 0x664d, 0xeb35; -// - // // Test 3: Input/output address computations have carry. - // // 0 for all 25 64-bit inputs except setting the second 64-bit input to 1. All 64-bit inputs in chunks of 4 16-bit big endian limbs. - // mstore 100, 65520, 0, 0; - // mstore 100, 65524, 0, 0; - // mstore 100, 65528, 0, 0; - // mstore 100, 65532, 0, 1; - // mstore 101, 0, 0, 0; - // mstore 101, 4, 0, 0; - // mstore 101, 8, 0, 0; - // mstore 101, 12, 0, 0; - // mstore 101, 16, 0, 0; - // mstore 101, 20, 0, 0; - // mstore 101, 24, 0, 0; - // mstore 101, 28, 0, 0; - // mstore 101, 32, 0, 0; - // mstore 101, 36, 0, 0; - // mstore 101, 40, 0, 0; - // mstore 101, 44, 0, 0; - // mstore 101, 48, 0, 0; - // mstore 101, 52, 0, 0; - // mstore 101, 56, 0, 0; - // mstore 101, 60, 0, 0; - // mstore 101, 64, 0, 0; - // mstore 101, 68, 0, 0; - // mstore 101, 72, 0, 0; - // mstore 101, 76, 0, 0; - // mstore 101, 80, 0, 0; -// - // // Input address (100 * 65536 + 65520). Output address (50000 * 65536 + 65528). - // keccakf16_memory 100, 65520, 50000, 65528; -// - // // Selectively checking a few registers only. - // // Test vector generated from Tiny Keccak. - // assert_eq 50000, 65528, 0xfdbb, 0xbbdf; - // assert_eq 50000, 65532, 0x9001, 0x405f; - // assert_eq 50001, 184, 0xeac9, 0xf006; - // assert_eq 50001, 188, 0x664d, 0xeb35; -// + + // Test 2: Same as Test 1 but sets input and output addresses to be the same. + // No need to rerun the mstores because input values from Test 1 should be intact. + keccakf32_memory 0, 0; + // Selectively checking a few registers only. + // Test vector generated from Tiny Keccak. + assert_eq 0, 0xfdbbbbdf; + assert_eq 4, 0x9001405f; + assert_eq 192, 0xeac9f006; + assert_eq 196, 0x664deb35; + return; } } From fe823e83e429d4f24f4b1803f72d0675bc69b947 Mon Sep 17 00:00:00 2001 From: Steve Wang Date: Mon, 25 Nov 2024 16:37:49 -0500 Subject: [PATCH 6/8] added pipeline --- pipeline/tests/powdr_std.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index 046c057c4f..0ab45fe264 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -58,6 +58,15 @@ fn keccakf16_memory_test() { test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Monolithic); } +#[test] +#[ignore = "Too slow"] +fn keccakf32_memory_test() { + let f = "std/keccakf32_memory_test.asm"; + let pipeline = make_simple_prepared_pipeline(f); + test_pilcom(pipeline.clone()); + gen_estark_proof(pipeline); +} + #[test] #[ignore = "Too slow"] fn poseidon_bb_test() { From 9c6feae2ac4049092b604e3ee2c97fbee57be910 Mon Sep 17 00:00:00 2001 From: Leo Alt Date: Tue, 26 Nov 2024 21:31:21 +0100 Subject: [PATCH 7/8] fix machine and test --- std/machines/hash/keccakf32_memory.asm | 100 ++++++++++++------------ test_data/std/keccakf32_memory_test.asm | 2 +- 2 files changed, 51 insertions(+), 51 deletions(-) diff --git a/std/machines/hash/keccakf32_memory.asm b/std/machines/hash/keccakf32_memory.asm index dbb1d6604f..8b166b9777 100644 --- a/std/machines/hash/keccakf32_memory.asm +++ b/std/machines/hash/keccakf32_memory.asm @@ -123,56 +123,56 @@ machine Keccakf32Memory(mem: Memory) with // Specifically, output obtained from the keccak computation are little endian. // However, this keccakf32_memory machine produces big endian outputs in memory. // Therefore, memory write converts little endian from keccak computation to big endian for the output in memory. - link if final_step_used ~> mem.mstore(output_addr, time_step, a_prime_prime_prime_0_0_limbs[1]); - link if final_step_used ~> mem.mstore(output_addr + 4, time_step, a_prime_prime_prime_0_0_limbs[0]); - link if final_step_used ~> mem.mstore(output_addr + 8, time_step, a_prime_prime[3]); - link if final_step_used ~> mem.mstore(output_addr + 12, time_step, a_prime_prime[2]); - link if final_step_used ~> mem.mstore(output_addr + 16, time_step, a_prime_prime[5]); - link if final_step_used ~> mem.mstore(output_addr + 20, time_step, a_prime_prime[4]); - link if final_step_used ~> mem.mstore(output_addr + 24, time_step, a_prime_prime[7]); - link if final_step_used ~> mem.mstore(output_addr + 28, time_step, a_prime_prime[6]); - link if final_step_used ~> mem.mstore(output_addr + 32, time_step, a_prime_prime[9]); - link if final_step_used ~> mem.mstore(output_addr + 36, time_step, a_prime_prime[8]); - link if final_step_used ~> mem.mstore(output_addr + 40, time_step, a_prime_prime[11]); - link if final_step_used ~> mem.mstore(output_addr + 44, time_step, a_prime_prime[10]); - link if final_step_used ~> mem.mstore(output_addr + 48, time_step, a_prime_prime[13]); - link if final_step_used ~> mem.mstore(output_addr + 52, time_step, a_prime_prime[12]); - link if final_step_used ~> mem.mstore(output_addr + 56, time_step, a_prime_prime[15]); - link if final_step_used ~> mem.mstore(output_addr + 60, time_step, a_prime_prime[14]); - link if final_step_used ~> mem.mstore(output_addr + 64, time_step, a_prime_prime[17]); - link if final_step_used ~> mem.mstore(output_addr + 68, time_step, a_prime_prime[16]); - link if final_step_used ~> mem.mstore(output_addr + 72, time_step, a_prime_prime[19]); - link if final_step_used ~> mem.mstore(output_addr + 76, time_step, a_prime_prime[18]); - link if final_step_used ~> mem.mstore(output_addr + 80, time_step, a_prime_prime[21]); - link if final_step_used ~> mem.mstore(output_addr + 84, time_step, a_prime_prime[20]); - link if final_step_used ~> mem.mstore(output_addr + 88, time_step, a_prime_prime[23]); - link if final_step_used ~> mem.mstore(output_addr + 92, time_step, a_prime_prime[22]); - link if final_step_used ~> mem.mstore(output_addr + 96, time_step, a_prime_prime[25]); - link if final_step_used ~> mem.mstore(output_addr + 100, time_step, a_prime_prime[24]); - link if final_step_used ~> mem.mstore(output_addr + 104, time_step, a_prime_prime[27]); - link if final_step_used ~> mem.mstore(output_addr + 108, time_step, a_prime_prime[26]); - link if final_step_used ~> mem.mstore(output_addr + 112, time_step, a_prime_prime[29]); - link if final_step_used ~> mem.mstore(output_addr + 116, time_step, a_prime_prime[28]); - link if final_step_used ~> mem.mstore(output_addr + 120, time_step, a_prime_prime[31]); - link if final_step_used ~> mem.mstore(output_addr + 124, time_step, a_prime_prime[30]); - link if final_step_used ~> mem.mstore(output_addr + 128, time_step, a_prime_prime[33]); - link if final_step_used ~> mem.mstore(output_addr + 132, time_step, a_prime_prime[32]); - link if final_step_used ~> mem.mstore(output_addr + 136, time_step, a_prime_prime[35]); - link if final_step_used ~> mem.mstore(output_addr + 140, time_step, a_prime_prime[34]); - link if final_step_used ~> mem.mstore(output_addr + 144, time_step, a_prime_prime[37]); - link if final_step_used ~> mem.mstore(output_addr + 148, time_step, a_prime_prime[36]); - link if final_step_used ~> mem.mstore(output_addr + 152, time_step, a_prime_prime[39]); - link if final_step_used ~> mem.mstore(output_addr + 156, time_step, a_prime_prime[38]); - link if final_step_used ~> mem.mstore(output_addr + 160, time_step, a_prime_prime[41]); - link if final_step_used ~> mem.mstore(output_addr + 164, time_step, a_prime_prime[40]); - link if final_step_used ~> mem.mstore(output_addr + 168, time_step, a_prime_prime[43]); - link if final_step_used ~> mem.mstore(output_addr + 172, time_step, a_prime_prime[42]); - link if final_step_used ~> mem.mstore(output_addr + 176, time_step, a_prime_prime[45]); - link if final_step_used ~> mem.mstore(output_addr + 180, time_step, a_prime_prime[44]); - link if final_step_used ~> mem.mstore(output_addr + 184, time_step, a_prime_prime[47]); - link if final_step_used ~> mem.mstore(output_addr + 188, time_step, a_prime_prime[46]); - link if final_step_used ~> mem.mstore(output_addr + 192, time_step, a_prime_prime[49]); - link if final_step_used ~> mem.mstore(output_addr + 196, time_step, a_prime_prime[48]); + link if final_step_used ~> mem.mstore(output_addr, time_step + 1, a_prime_prime_prime_0_0_limbs[1]); + link if final_step_used ~> mem.mstore(output_addr + 4, time_step + 1, a_prime_prime_prime_0_0_limbs[0]); + link if final_step_used ~> mem.mstore(output_addr + 8, time_step + 1, a_prime_prime[3]); + link if final_step_used ~> mem.mstore(output_addr + 12, time_step + 1, a_prime_prime[2]); + link if final_step_used ~> mem.mstore(output_addr + 16, time_step + 1, a_prime_prime[5]); + link if final_step_used ~> mem.mstore(output_addr + 20, time_step + 1, a_prime_prime[4]); + link if final_step_used ~> mem.mstore(output_addr + 24, time_step + 1, a_prime_prime[7]); + link if final_step_used ~> mem.mstore(output_addr + 28, time_step + 1, a_prime_prime[6]); + link if final_step_used ~> mem.mstore(output_addr + 32, time_step + 1, a_prime_prime[9]); + link if final_step_used ~> mem.mstore(output_addr + 36, time_step + 1, a_prime_prime[8]); + link if final_step_used ~> mem.mstore(output_addr + 40, time_step + 1, a_prime_prime[11]); + link if final_step_used ~> mem.mstore(output_addr + 44, time_step + 1, a_prime_prime[10]); + link if final_step_used ~> mem.mstore(output_addr + 48, time_step + 1, a_prime_prime[13]); + link if final_step_used ~> mem.mstore(output_addr + 52, time_step + 1, a_prime_prime[12]); + link if final_step_used ~> mem.mstore(output_addr + 56, time_step + 1, a_prime_prime[15]); + link if final_step_used ~> mem.mstore(output_addr + 60, time_step + 1, a_prime_prime[14]); + link if final_step_used ~> mem.mstore(output_addr + 64, time_step + 1, a_prime_prime[17]); + link if final_step_used ~> mem.mstore(output_addr + 68, time_step + 1, a_prime_prime[16]); + link if final_step_used ~> mem.mstore(output_addr + 72, time_step + 1, a_prime_prime[19]); + link if final_step_used ~> mem.mstore(output_addr + 76, time_step + 1, a_prime_prime[18]); + link if final_step_used ~> mem.mstore(output_addr + 80, time_step + 1, a_prime_prime[21]); + link if final_step_used ~> mem.mstore(output_addr + 84, time_step + 1, a_prime_prime[20]); + link if final_step_used ~> mem.mstore(output_addr + 88, time_step + 1, a_prime_prime[23]); + link if final_step_used ~> mem.mstore(output_addr + 92, time_step + 1, a_prime_prime[22]); + link if final_step_used ~> mem.mstore(output_addr + 96, time_step + 1, a_prime_prime[25]); + link if final_step_used ~> mem.mstore(output_addr + 100, time_step + 1, a_prime_prime[24]); + link if final_step_used ~> mem.mstore(output_addr + 104, time_step + 1, a_prime_prime[27]); + link if final_step_used ~> mem.mstore(output_addr + 108, time_step + 1, a_prime_prime[26]); + link if final_step_used ~> mem.mstore(output_addr + 112, time_step + 1, a_prime_prime[29]); + link if final_step_used ~> mem.mstore(output_addr + 116, time_step + 1, a_prime_prime[28]); + link if final_step_used ~> mem.mstore(output_addr + 120, time_step + 1, a_prime_prime[31]); + link if final_step_used ~> mem.mstore(output_addr + 124, time_step + 1, a_prime_prime[30]); + link if final_step_used ~> mem.mstore(output_addr + 128, time_step + 1, a_prime_prime[33]); + link if final_step_used ~> mem.mstore(output_addr + 132, time_step + 1, a_prime_prime[32]); + link if final_step_used ~> mem.mstore(output_addr + 136, time_step + 1, a_prime_prime[35]); + link if final_step_used ~> mem.mstore(output_addr + 140, time_step + 1, a_prime_prime[34]); + link if final_step_used ~> mem.mstore(output_addr + 144, time_step + 1, a_prime_prime[37]); + link if final_step_used ~> mem.mstore(output_addr + 148, time_step + 1, a_prime_prime[36]); + link if final_step_used ~> mem.mstore(output_addr + 152, time_step + 1, a_prime_prime[39]); + link if final_step_used ~> mem.mstore(output_addr + 156, time_step + 1, a_prime_prime[38]); + link if final_step_used ~> mem.mstore(output_addr + 160, time_step + 1, a_prime_prime[41]); + link if final_step_used ~> mem.mstore(output_addr + 164, time_step + 1, a_prime_prime[40]); + link if final_step_used ~> mem.mstore(output_addr + 168, time_step + 1, a_prime_prime[43]); + link if final_step_used ~> mem.mstore(output_addr + 172, time_step + 1, a_prime_prime[42]); + link if final_step_used ~> mem.mstore(output_addr + 176, time_step + 1, a_prime_prime[45]); + link if final_step_used ~> mem.mstore(output_addr + 180, time_step + 1, a_prime_prime[44]); + link if final_step_used ~> mem.mstore(output_addr + 184, time_step + 1, a_prime_prime[47]); + link if final_step_used ~> mem.mstore(output_addr + 188, time_step + 1, a_prime_prime[46]); + link if final_step_used ~> mem.mstore(output_addr + 192, time_step + 1, a_prime_prime[49]); + link if final_step_used ~> mem.mstore(output_addr + 196, time_step + 1, a_prime_prime[48]); // ------------- End memory read / write --------------- // Adapted from Plonky3 implementation of Keccak: https://github.com/Plonky3/Plonky3/tree/main/keccak-air/src diff --git a/test_data/std/keccakf32_memory_test.asm b/test_data/std/keccakf32_memory_test.asm index c4efc9acf2..299326f281 100644 --- a/test_data/std/keccakf32_memory_test.asm +++ b/test_data/std/keccakf32_memory_test.asm @@ -14,7 +14,7 @@ machine Main with degree: 65536 { Keccakf32Memory keccakf32_memory(memory); - col fixed STEP(i) { i }; + col fixed STEP(i) { i * 2 }; // Big endian. // Usage: mstore addr, val; From 306ab73425ddebae651a8f4f0588c10ff029c08f Mon Sep 17 00:00:00 2001 From: Leo Alt Date: Tue, 26 Nov 2024 21:46:47 +0100 Subject: [PATCH 8/8] test --- pipeline/tests/powdr_std.rs | 5 ++--- test_data/std/keccakf32_memory_test.asm | 8 +++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pipeline/tests/powdr_std.rs b/pipeline/tests/powdr_std.rs index ab822bdcbc..7726f547eb 100644 --- a/pipeline/tests/powdr_std.rs +++ b/pipeline/tests/powdr_std.rs @@ -65,9 +65,8 @@ fn keccakf16_memory_test() { #[ignore = "Too slow"] fn keccakf32_memory_test() { let f = "std/keccakf32_memory_test.asm"; - let pipeline = make_simple_prepared_pipeline(f); - test_pilcom(pipeline.clone()); - gen_estark_proof(pipeline); + test_mock_backend(make_simple_prepared_pipeline::(f)); + test_plonky3_with_backend_variant::(f, vec![], BackendVariant::Monolithic); } #[test] diff --git a/test_data/std/keccakf32_memory_test.asm b/test_data/std/keccakf32_memory_test.asm index 299326f281..fbea832b3e 100644 --- a/test_data/std/keccakf32_memory_test.asm +++ b/test_data/std/keccakf32_memory_test.asm @@ -2,7 +2,9 @@ use std::machines::hash::keccakf32_memory::Keccakf32Memory; use std::machines::large_field::memory::Memory; use std::machines::range::Byte2; -machine Main with degree: 65536 { +let MIN: int = 2**5; +let MAX: int = 2**8; +machine Main with min_degree: MIN, max_degree: MAX { reg pc[@pc]; reg X[<=]; @@ -10,9 +12,9 @@ machine Main with degree: 65536 { reg Y[<=]; Byte2 byte2; - Memory memory(byte2); + Memory memory(byte2, MIN, MAX); - Keccakf32Memory keccakf32_memory(memory); + Keccakf32Memory keccakf32_memory(memory, MIN, MAX); col fixed STEP(i) { i * 2 };