From 2049a358a91b64f2169811b11bb5230f8eca756a Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Fri, 15 Nov 2024 22:03:58 +0000 Subject: [PATCH] Inlining system calls when possible. (#2087) Solves #2040. --- .github/workflows/pr-tests.yml | 2 + riscv-runtime/src/arith.rs | 107 ++++++++++++------------------ riscv-runtime/src/ec.rs | 62 ++++++++--------- riscv-runtime/src/ecall.rs | 17 +++++ riscv-runtime/src/fmt.rs | 2 +- riscv-runtime/src/hash.rs | 19 ++---- riscv-runtime/src/io.rs | 8 +-- riscv-runtime/src/lib.rs | 16 +++-- riscv-runtime/src/std_support.rs | 2 +- riscv-syscalls/src/lib.rs | 18 +++-- riscv/Cargo.toml | 4 ++ riscv/src/elf/mod.rs | 25 +++++++ riscv/src/large_field/code_gen.rs | 16 +++-- riscv/src/large_field/runtime.rs | 32 +++++---- riscv/src/runtime.rs | 6 +- riscv/src/small_field/code_gen.rs | 16 +++-- riscv/src/small_field/runtime.rs | 39 ++++++----- riscv/tests/riscv.rs | 71 ++++++++++++++++++-- 18 files changed, 291 insertions(+), 171 deletions(-) create mode 100644 riscv-runtime/src/ecall.rs diff --git a/.github/workflows/pr-tests.yml b/.github/workflows/pr-tests.yml index 76a6e98175..a3f5760b34 100644 --- a/.github/workflows/pr-tests.yml +++ b/.github/workflows/pr-tests.yml @@ -106,6 +106,8 @@ jobs: run: rustup component add rust-src --toolchain nightly-2024-08-01-x86_64-unknown-linux-gnu - name: Install riscv target run: rustup target add riscv32imac-unknown-none-elf --toolchain nightly-2024-08-01-x86_64-unknown-linux-gnu + - name: Install test dependencies + run: sudo apt-get install -y binutils-riscv64-unknown-elf lld - name: Install pilcom run: git clone https://github.com/0xPolygonHermez/pilcom.git && cd pilcom && npm install - uses: taiki-e/install-action@nextest diff --git a/riscv-runtime/src/arith.rs b/riscv-runtime/src/arith.rs index 012b7bd53e..b3f5d0ca38 100644 --- a/riscv-runtime/src/arith.rs +++ b/riscv-runtime/src/arith.rs @@ -2,6 +2,8 @@ use core::arch::asm; use powdr_riscv_syscalls::Syscall; +use crate::ecall; + /// 
convert a big-endian u8 array to u32 array (arith machine format) pub(crate) fn be_to_u32(from: &[u8; 32], to: &mut [u32; 8]) { for (i, chunk) in from.chunks_exact(4).rev().enumerate() { @@ -29,11 +31,10 @@ pub fn affine_256_u8_be(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8; be_to_u32(&c, &mut c1); unsafe { - asm!("ecall", - in("a0") &mut a1 as *mut [u32; 8], - in("a1") &mut b1 as *mut [u32; 8], - in("a2") &mut c1 as *mut [u32; 8], - in("t0") u32::from(Syscall::Affine256)); + ecall!(Syscall::Affine256, + in("a0") a1.as_mut_ptr(), + in("a1") b1.as_mut_ptr(), + in("a2") c1.as_ptr()); } u32_to_be(&a1, &mut a); @@ -45,11 +46,10 @@ pub fn affine_256_u8_be(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8; /// Returns `(hi, lo)`. pub fn affine_256_u8_le(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8; 32], [u8; 32]) { unsafe { - asm!("ecall", - in("a0") a.as_mut_ptr() as *mut [u32; 8], - in("a1") b.as_mut_ptr() as *mut [u32; 8], - in("a2") c.as_ptr() as *const [u32; 8], - in("t0") u32::from(Syscall::Affine256)); + ecall!(Syscall::Affine256, + in("a0") a.as_mut_ptr(), + in("a1") b.as_mut_ptr(), + in("a2") c.as_ptr()); } (a, b) @@ -57,28 +57,19 @@ pub fn affine_256_u8_le(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8; /// Calculate `a*b + c = hi*2**256 + lo` for 256 bit values (as u32 little-endian arrays). /// Returns `(hi, lo)`. -pub fn affine_256_u32_le( - mut a: [u32; 8], - mut b: [u32; 8], - mut c: [u32; 8], -) -> ([u32; 8], [u32; 8]) { +pub fn affine_256_u32_le(mut a: [u32; 8], mut b: [u32; 8], c: [u32; 8]) -> ([u32; 8], [u32; 8]) { unsafe { - asm!("ecall", - in("a0") &mut a as *mut [u32; 8], - in("a1") &mut b as *mut [u32; 8], - in("a2") &mut c as *mut [u32; 8], - in("t0") u32::from(Syscall::Affine256)); + ecall!(Syscall::Affine256, + in("a0") a.as_mut_ptr(), + in("a1") b.as_mut_ptr(), + in("a2") c.as_ptr()); } (a, b) } /// Calculate `(a*b) % m = r` for 256 bit values (as u8 big-endian arrays). /// Returns `r`. 
-pub fn modmul_256_u8_be( - mut a: [u8; 32], - b: [u8; 32], - m: [u8; 32], -) -> [u8; 32] { +pub fn modmul_256_u8_be(mut a: [u8; 32], b: [u8; 32], m: [u8; 32]) -> [u8; 32] { let mut a1: [u32; 8] = Default::default(); let mut b1: [u32; 8] = Default::default(); let mut m1: [u32; 8] = Default::default(); @@ -90,17 +81,15 @@ pub fn modmul_256_u8_be( unsafe { // First compute the two halves of the result a*b. // Results are stored in place in a and b. - asm!("ecall", - in("a0") &mut a1 as *mut [u32; 8], - in("a1") &mut b1 as *mut [u32; 8], - in("a2") &mut [0u32; 8] as *mut [u32; 8], - in("t0") u32::from(Syscall::Affine256)); + ecall!(Syscall::Affine256, + in("a0") a1.as_mut_ptr(), + in("a1") b1.as_mut_ptr(), + in("a2") [0u32; 8].as_ptr()); // Next compute the remainder, stored in place in a. - asm!("ecall", - in("a0") &mut a1 as *mut [u32; 8], - in("a1") &mut b1 as *mut [u32; 8], - in("a2") &mut m1 as *mut [u32; 8], - in("t0") u32::from(Syscall::Mod256)); + ecall!(Syscall::Mod256, + in("a0") a1.as_mut_ptr(), + in("a1") b1.as_ptr(), + in("a2") m1.as_ptr()); } u32_to_be(&a1, &mut a); @@ -109,25 +98,19 @@ pub fn modmul_256_u8_be( /// Calculate `(a*b) % m = r` for 256 bit values (as u8 little-endian arrays). /// Returns `r`. -pub fn modmul_256_u8_le( - mut a: [u8; 32], - mut b: [u8; 32], - m: [u8; 32], -) -> [u8; 32] { +pub fn modmul_256_u8_le(mut a: [u8; 32], mut b: [u8; 32], m: [u8; 32]) -> [u8; 32] { unsafe { // First compute the two halves of the result a*b. // Results are stored in place in a and b. - asm!("ecall", - in("a0") a.as_mut_ptr() as *mut [u32; 8], - in("a1") b.as_mut_ptr() as *mut [u32; 8], - in("a2") &mut [0u32; 8] as *mut [u32; 8], - in("t0") u32::from(Syscall::Affine256)); + ecall!(Syscall::Affine256, + in("a0") a.as_mut_ptr(), + in("a1") b.as_mut_ptr(), + in("a2") [0u32; 8].as_ptr()); // Next compute the remainder, stored in place in a. 
- asm!("ecall", - in("a0") a.as_mut_ptr() as *mut [u32; 8], - in("a1") b.as_mut_ptr() as *mut [u32; 8], - in("a2") m.as_ptr() as *const [u32; 8], - in("t0") u32::from(Syscall::Mod256)); + ecall!(Syscall::Mod256, + in("a0") a.as_mut_ptr(), + in("a1") b.as_ptr(), + in("a2") m.as_ptr()); } a @@ -135,25 +118,19 @@ pub fn modmul_256_u8_le( /// Calculate `(a*b) % m = r` for 256 bit values (as u32 little-endian arrays). /// Returns `r`. -pub fn modmul_256_u32_le( - mut a: [u32; 8], - mut b: [u32; 8], - m: [u32; 8], -) -> [u32; 8] { +pub fn modmul_256_u32_le(mut a: [u32; 8], mut b: [u32; 8], m: [u32; 8]) -> [u32; 8] { unsafe { // First compute the two halves of the result a*b. // Results are stored in place in a and b. - asm!("ecall", - in("a0") &mut a as *mut [u32; 8], - in("a1") &mut b as *mut [u32; 8], - in("a2") &[0u32; 8] as *const [u32; 8], - in("t0") u32::from(Syscall::Affine256)); + ecall!(Syscall::Affine256, + in("a0") a.as_mut_ptr(), + in("a1") b.as_mut_ptr(), + in("a2") [0u32; 8].as_ptr()); // Next compute the remainder, stored in place in a. 
- asm!("ecall", - in("a0") &mut a as *mut [u32; 8], - in("a1") &mut b as *mut [u32; 8], - in("a2") &m as *const [u32; 8], - in("t0") u32::from(Syscall::Mod256)); + ecall!(Syscall::Mod256, + in("a0") a.as_mut_ptr(), + in("a1") b.as_ptr(), + in("a2") m.as_ptr()); } a diff --git a/riscv-runtime/src/ec.rs b/riscv-runtime/src/ec.rs index 14594417f6..13a8faa141 100644 --- a/riscv-runtime/src/ec.rs +++ b/riscv-runtime/src/ec.rs @@ -21,12 +21,11 @@ pub fn add_u8_be( be_to_u32(&by, &mut by1); unsafe { - asm!("ecall", - in("a0") &mut ax1 as *mut [u32; 8], - in("a1") &mut ay1 as *mut [u32; 8], - in("a2") &mut bx1 as *mut [u32; 8], - in("a3") &mut by1 as *mut [u32; 8], - in("t0") u32::from(Syscall::EcAdd)); + ecall!(Syscall::EcAdd, + in("a0") ax1.as_mut_ptr(), + in("a1") ay1.as_mut_ptr(), + in("a2") bx1.as_ptr(), + in("a3") by1.as_ptr()); } u32_to_be(&ax1, &mut ax); @@ -39,16 +38,15 @@ pub fn add_u8_be( pub fn add_u8_le( mut ax: [u8; 32], mut ay: [u8; 32], - mut bx: [u8; 32], - mut by: [u8; 32], + bx: [u8; 32], + by: [u8; 32], ) -> ([u8; 32], [u8; 32]) { unsafe { - asm!("ecall", - in("a0") ax.as_mut_ptr() as *mut [u32; 8], - in("a1") ay.as_mut_ptr() as *mut [u32; 8], - in("a2") bx.as_mut_ptr() as *mut [u32; 8], - in("a3") by.as_mut_ptr() as *mut [u32; 8], - in("t0") u32::from(Syscall::EcAdd)); + ecall!(Syscall::EcAdd, + in("a0") ax.as_mut_ptr(), + in("a1") ay.as_mut_ptr(), + in("a2") bx.as_ptr(), + in("a3") by.as_ptr()); } (ax, ay) } @@ -57,16 +55,15 @@ pub fn add_u8_le( pub fn add_u32_le( mut ax: [u32; 8], mut ay: [u32; 8], - mut bx: [u32; 8], - mut by: [u32; 8], + bx: [u32; 8], + by: [u32; 8], ) -> ([u32; 8], [u32; 8]) { unsafe { - asm!("ecall", - in("a0") &mut ax as *mut [u32; 8], - in("a1") &mut ay as *mut [u32; 8], - in("a2") &mut bx as *mut [u32; 8], - in("a3") &mut by as *mut [u32; 8], - in("t0") u32::from(Syscall::EcAdd)); + ecall!(Syscall::EcAdd, + in("a0") ax.as_mut_ptr(), + in("a1") ay.as_mut_ptr(), + in("a2") bx.as_ptr(), + in("a3") by.as_ptr()); } (ax, ay) } @@ 
-80,10 +77,9 @@ pub fn double_u8_be(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) { be_to_u32(&y, &mut y1); unsafe { - asm!("ecall", - in("a0") &mut x1 as *mut [u32; 8], - in("a1") &mut y1 as *mut [u32; 8], - in("t0") u32::from(Syscall::EcDouble)); + ecall!(Syscall::EcDouble, + in("a0") x1.as_mut_ptr(), + in("a1") y1.as_mut_ptr()); } u32_to_be(&x1, &mut x); @@ -95,10 +91,9 @@ pub fn double_u8_be(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) { /// Double a k256 ec point. Coordinates are little-endian u8 arrays. pub fn double_u8_le(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) { unsafe { - asm!("ecall", - in("a0") x.as_mut_ptr() as *mut [u32; 8], - in("a1") y.as_mut_ptr() as *mut [u32; 8], - in("t0") u32::from(Syscall::EcDouble)); + ecall!(Syscall::EcDouble, + in("a0") x.as_mut_ptr(), + in("a1") y.as_mut_ptr()); } (x, y) @@ -107,10 +102,9 @@ pub fn double_u8_le(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) { /// Double a k256 ec point. Coordinates are little-endian u32 arrays. pub fn double_u32_le(mut x: [u32; 8], mut y: [u32; 8]) -> ([u32; 8], [u32; 8]) { unsafe { - asm!("ecall", - in("a0") &mut x as *mut [u32; 8], - in("a1") &mut y as *mut [u32; 8], - in("t0") u32::from(Syscall::EcDouble)); + ecall!(Syscall::EcDouble, + in("a0") x.as_mut_ptr(), + in("a1") y.as_mut_ptr()); } (x, y) } diff --git a/riscv-runtime/src/ecall.rs b/riscv-runtime/src/ecall.rs new file mode 100644 index 0000000000..7d19eb356b --- /dev/null +++ b/riscv-runtime/src/ecall.rs @@ -0,0 +1,17 @@ +/// Generates an ecall instruction with the given system call number and arguments. +/// +/// Uses the instruction sequence convention for the system call to be inlined. +#[macro_export] +macro_rules! ecall { + ($syscall:expr, $($tokens:tt)*) => { + asm!( + "addi t0, x0, {}", + "ecall", + const $syscall as u8, + // No system call we have at this point allocates on stack. 
+ options(nostack), + out("t0") _, + $($tokens)* + ); + }; +} diff --git a/riscv-runtime/src/fmt.rs b/riscv-runtime/src/fmt.rs index 2297bd2b03..495d36417b 100644 --- a/riscv-runtime/src/fmt.rs +++ b/riscv-runtime/src/fmt.rs @@ -37,6 +37,6 @@ fn print_prover_char(c: u8) { let mut value = c as u32; #[allow(unused_assignments)] unsafe { - asm!("ecall", lateout("a0") value, in("a0") 1, in("a1") value, in("t0") u32::from(Syscall::Output)); + ecall!(Syscall::Output, lateout("a0") value, in("a0") 1, in("a1") value); } } diff --git a/riscv-runtime/src/hash.rs b/riscv-runtime/src/hash.rs index f74af1cbff..caa4ee7310 100644 --- a/riscv-runtime/src/hash.rs +++ b/riscv-runtime/src/hash.rs @@ -7,7 +7,7 @@ use powdr_riscv_syscalls::Syscall; pub fn native_hash(data: &mut [u64; 12]) -> &[u64; 4] { unsafe { - asm!("ecall", in("a0") data as *mut _, in("t0") u32::from(Syscall::NativeHash)); + ecall!(Syscall::NativeHash, in("a0") data); } data[..4].try_into().unwrap() } @@ -17,20 +17,15 @@ pub fn native_hash(data: &mut [u64; 12]) -> &[u64; 4] { /// sub-array is returned. pub fn poseidon_gl(data: &mut [Goldilocks; 12]) -> &[Goldilocks; 4] { unsafe { - asm!("ecall", in("a0") data as *mut _, in("t0") u32::from(Syscall::PoseidonGL)); + ecall!(Syscall::PoseidonGL, in("a0") data); } data[..4].try_into().unwrap() } /// Perform one Poseidon2 permutation with 8 Goldilocks field elements in-place. 
pub fn poseidon2_gl_inplace(data: &mut [Goldilocks; 8]) { - let ptr = data as *mut _; unsafe { - asm!("ecall", - in("a0") ptr, - in("a1") ptr, - in("t0") u32::from(Syscall::Poseidon2GL) - ); + ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") data); } } @@ -38,11 +33,7 @@ pub fn poseidon2_gl_inplace(data: &mut [Goldilocks; 8]) { pub fn poseidon2_gl(data: &[Goldilocks; 8]) -> [Goldilocks; 8] { unsafe { let mut output: MaybeUninit<[Goldilocks; 8]> = MaybeUninit::uninit(); - asm!("ecall", - in("a0") data as *const _, - in("a1") output.as_mut_ptr(), - in("t0") u32::from(Syscall::Poseidon2GL) - ); + ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") output.as_mut_ptr()); output.assume_init() } } @@ -52,7 +43,7 @@ pub fn poseidon2_gl(data: &[Goldilocks; 8]) -> [Goldilocks; 8] { pub fn keccakf(input: &[u64; 25], output: &mut [u64; 25]) { unsafe { // Syscall inputs: memory pointer to input array and memory pointer to output array. - asm!("ecall", in("a0") input, in("a1") output, in("t0") u32::from(Syscall::KeccakF)); + ecall!(Syscall::KeccakF, in("a0") input, in("a1") output); } } diff --git a/riscv-runtime/src/io.rs b/riscv-runtime/src/io.rs index 198ca01f98..1345270455 100644 --- a/riscv-runtime/src/io.rs +++ b/riscv-runtime/src/io.rs @@ -11,7 +11,7 @@ use alloc::vec::Vec; pub fn read_u32(idx: u32) -> u32 { let mut value: u32; unsafe { - asm!("ecall", lateout("a0") value, in("a0") 0, in("a1") idx + 1, in("t0") u32::from(Syscall::Input)); + ecall!(Syscall::Input, lateout("a0") value, in("a0") 0, in("a1") idx + 1); } value } @@ -20,7 +20,7 @@ pub fn read_u32(idx: u32) -> u32 { pub fn read_slice(fd: u32, data: &mut [u32]) { for (i, d) in data.iter_mut().enumerate() { unsafe { - asm!("ecall", lateout("a0") *d, in("a0") fd, in("a1") (i+1) as u32, in("t0") u32::from(Syscall::Input)) + ecall!(Syscall::Input, lateout("a0") *d, in("a0") fd, in("a1") (i+1) as u32); }; } } @@ -29,7 +29,7 @@ pub fn read_slice(fd: u32, data: &mut [u32]) { pub fn read_data_len(fd: u32) -> usize { 
let mut out: u32; unsafe { - asm!("ecall", lateout("a0") out, in("a0") fd, in("a1") 0, in("t0") u32::from(Syscall::Input)) + ecall!(Syscall::Input, lateout("a0") out, in("a0") fd, in("a1") 0); }; out as usize } @@ -37,7 +37,7 @@ pub fn read_data_len(fd: u32) -> usize { /// Writes a single u8 to the file descriptor fd. pub fn write_u8(fd: u32, byte: u8) { unsafe { - asm!("ecall", in("a0") fd, in("a1") byte, in("t0") u32::from(Syscall::Output)); + ecall!(Syscall::Output, in("a0") fd, in("a1") byte); } } diff --git a/riscv-runtime/src/lib.rs b/riscv-runtime/src/lib.rs index 9aafdc06c0..db1e06e295 100644 --- a/riscv-runtime/src/lib.rs +++ b/riscv-runtime/src/lib.rs @@ -3,11 +3,12 @@ start, alloc_error_handler, maybe_uninit_write_slice, - round_char_boundary + round_char_boundary, + asm_const )] -use core::arch::{asm, global_asm}; -use powdr_riscv_syscalls::Syscall; +#[macro_use] +mod ecall; mod allocator; pub mod arith; @@ -26,11 +27,14 @@ mod no_std_support; #[cfg(feature = "std")] mod std_support; +use core::arch::{asm, global_asm}; +use powdr_riscv_syscalls::Syscall; + #[no_mangle] pub fn halt() -> ! { finalize(); unsafe { - asm!("ecall", in("t0") u32::from(Syscall::Halt)); + ecall!(Syscall::Halt,); } #[allow(clippy::empty_loop)] loop {} @@ -43,8 +47,8 @@ pub fn finalize() { let low = *limb as u32; let high = (*limb >> 32) as u32; // TODO this is not going to work properly for BB for now. 
- asm!("ecall", in("t0") u32::from(Syscall::CommitPublic), in("a0") i * 2, in("a1") low); - asm!("ecall", in("t0") u32::from(Syscall::CommitPublic), in("a0") i * 2 + 1, in("a1") high); + ecall!(Syscall::CommitPublic, in("a0") i * 2, in("a1") low); + ecall!(Syscall::CommitPublic, in("a0") i * 2 + 1, in("a1") high); } } } diff --git a/riscv-runtime/src/std_support.rs b/riscv-runtime/src/std_support.rs index 2727784469..cdeecdbc92 100644 --- a/riscv-runtime/src/std_support.rs +++ b/riscv-runtime/src/std_support.rs @@ -25,7 +25,7 @@ extern "C" fn sys_rand(buf: *mut u32, words: usize) { #[no_mangle] extern "C" fn sys_panic(msg_ptr: *const u8, len: usize) -> ! { - let out = u32::from(Syscall::Output); + let out: u32 = u8::from(Syscall::Output).into(); unsafe { write_slice(out, "Panic: ".as_bytes()); write_slice(out, slice::from_raw_parts(msg_ptr, len)); diff --git a/riscv-syscalls/src/lib.rs b/riscv-syscalls/src/lib.rs index edc437918e..b4a47725a8 100644 --- a/riscv-syscalls/src/lib.rs +++ b/riscv-syscalls/src/lib.rs @@ -2,12 +2,22 @@ macro_rules! syscalls { ($(($num:expr, $identifier:ident, $name:expr)),* $(,)?) => { + /// We use repr(u8) to make sure the enum discriminant will fit into the + /// 12 bits of the immediate field of the `addi` instruction, #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)] - #[repr(u32)] + #[repr(u8)] pub enum Syscall { $($identifier = $num),* } + impl Syscall { + pub fn name(&self) -> &'static str { + match self { + $(Syscall::$identifier => $name),* + } + } + } + impl core::fmt::Display for Syscall { fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { write!(f, "{}", match self { @@ -26,15 +36,15 @@ macro_rules! 
syscalls { } } - impl From for u32 { + impl From for u8 { fn from(syscall: Syscall) -> Self { syscall as Self } } - impl core::convert::TryFrom for Syscall { + impl core::convert::TryFrom for Syscall { type Error = (); - fn try_from(value: u32) -> Result { + fn try_from(value: u8) -> Result { match value { $($num => Ok(Syscall::$identifier)),*, _ => Err(()), diff --git a/riscv/Cargo.toml b/riscv/Cargo.toml index b8fa1ab6ea..0d1c6f9310 100644 --- a/riscv/Cargo.toml +++ b/riscv/Cargo.toml @@ -10,8 +10,12 @@ repository = { workspace = true } [features] default = [] # complex-tests is disabled by default + +# features below only affect tests +# not sure this is the best approach complex-tests = [] estark-polygon = ["powdr-pipeline/estark-polygon"] +plonky3 = ["powdr-pipeline/plonky3"] [dependencies] powdr-ast.workspace = true diff --git a/riscv/src/elf/mod.rs b/riscv/src/elf/mod.rs index f9dd5055da..08f5bedfc4 100644 --- a/riscv/src/elf/mod.rs +++ b/riscv/src/elf/mod.rs @@ -14,6 +14,7 @@ use goblin::elf::{ }; use itertools::{Either, Itertools}; use powdr_isa_utils::SingleDataValue; +use powdr_riscv_syscalls::Syscall; use raki::{ decode::Decode, instruction::{Extensions, Instruction as Ins, OpcodeKind as Op}, @@ -783,6 +784,30 @@ impl TwoOrOneMapper for InstructionLifter<'_> { HighLevelInsn { op, args, loc } } + ( + // inline-able system call: + // addi t0, x0, immediate + // ecall + Ins { + opc: Op::ADDI, + rd: Some(5), + rs1: Some(0), + imm: Some(opcode), + .. + }, + Ins { opc: Op::ECALL, .. }, + ) => { + // If this is not a known system call, we just let the executor deal with the problem. + let syscall = u8::try_from(*opcode) + .ok() + .and_then(|opcode| Syscall::try_from(opcode).ok())?; + + HighLevelInsn { + loc, + op: syscall.name(), + args: Default::default(), + } + } ( // All other double instructions we can lift start with auipc. 
Ins { diff --git a/riscv/src/large_field/code_gen.rs b/riscv/src/large_field/code_gen.rs index 6e3ee32ad9..a32dd8bd17 100644 --- a/riscv/src/large_field/code_gen.rs +++ b/riscv/src/large_field/code_gen.rs @@ -154,7 +154,7 @@ fn translate_program_impl( } Statement::Label(l) => statements.push(format!("{}:", escape_label(l.as_ref()))), Statement::Instruction { op, args } => { - let processed_instr = match process_instruction(op, args) { + let processed_instr = match process_instruction(op, args, runtime) { Ok(s) => s, Err(e) => panic!("Failed to process instruction '{op}'. {e}"), }; @@ -703,7 +703,11 @@ pub fn pop_register(name: &str) -> Vec { ] } -fn process_instruction(instr: &str, args: A) -> Result, A::Error> { +fn process_instruction( + instr: &str, + args: A, + runtime: &Runtime, +) -> Result, A::Error> { let tmp1 = Register::from("tmp1"); let tmp2 = Register::from("tmp2"); let tmp3 = Register::from("tmp3"); @@ -1514,8 +1518,12 @@ fn process_instruction(instr: &str, args: A) -> Result { - panic!("Unknown instruction: {instr}"); + // possibly inlined system calls + insn => { + let Some(syscall_impl) = runtime.get_syscall_impl(insn) else { + panic!("Unknown instruction: {instr}"); + }; + syscall_impl.statements.clone() } }; for s in &statements { diff --git a/riscv/src/large_field/runtime.rs b/riscv/src/large_field/runtime.rs index cf598b3e08..5e4c95c0a1 100644 --- a/riscv/src/large_field/runtime.rs +++ b/riscv/src/large_field/runtime.rs @@ -6,10 +6,7 @@ use itertools::Itertools; use crate::code_gen::Register; -use crate::runtime::{ - parse_function_statement, parse_instruction_declaration, SubMachine, SyscallImpl, - EXTRA_REG_PREFIX, -}; +use crate::runtime::{parse_instruction_declaration, SubMachine, SyscallImpl, EXTRA_REG_PREFIX}; use crate::RuntimeLibs; /// RISCV powdr assembly runtime. 
@@ -17,7 +14,7 @@ use crate::RuntimeLibs; #[derive(Clone)] pub struct Runtime { submachines: BTreeMap, - syscalls: BTreeMap, + syscalls: BTreeMap<&'static str, SyscallImpl>, } impl Runtime { @@ -234,14 +231,19 @@ impl Runtime { syscall: Syscall, implementation: I, ) { - let implementation = SyscallImpl( - implementation + let implementation = SyscallImpl { + syscall, + statements: implementation .into_iter() - .map(|s| parse_function_statement(s.as_ref())) + .map(|s| s.as_ref().to_string()) .collect(), - ); + }; - if self.syscalls.insert(syscall, implementation).is_some() { + if self + .syscalls + .insert(syscall.name(), implementation) + .is_some() + { panic!("duplicate syscall {syscall}"); } } @@ -459,10 +461,10 @@ impl Runtime { ] .into_iter(); - let jump_table = self.syscalls.keys().map(|s| { + let jump_table = self.syscalls.values().map(|s| { format!( "branch_if_diff_equal 5, 0, {}, __ecall_handler_{};", - *s as u32, s + s.syscall as u8, s.syscall ) }); @@ -470,7 +472,7 @@ impl Runtime { let handlers = self.syscalls.iter().flat_map(|(syscall, implementation)| { std::iter::once(format!("__ecall_handler_{syscall}:")) - .chain(implementation.0.iter().map(|i| i.to_string())) + .chain(implementation.statements.iter().cloned()) .chain([format!("jump_dyn 1, {};", Register::from("tmp1").addr())]) }); @@ -481,6 +483,10 @@ impl Runtime { .chain(std::iter::once("// end of ecall handler".to_string())) .collect() } + + pub fn get_syscall_impl(&self, syscall_name: &str) -> Option<&SyscallImpl> { + self.syscalls.get(syscall_name) + } } /// Helper function for register names used in instruction params diff --git a/riscv/src/runtime.rs b/riscv/src/runtime.rs index e963bbf76b..46068dc5e9 100644 --- a/riscv/src/runtime.rs +++ b/riscv/src/runtime.rs @@ -1,6 +1,7 @@ use powdr_ast::parsed::asm::{FunctionStatement, MachineStatement, SymbolPath}; use powdr_parser::ParserContext; +use powdr_riscv_syscalls::Syscall; pub static EXTRA_REG_PREFIX: &str = "xtra"; @@ -67,4 +68,7 @@ impl 
SubMachine { /// Any of the registers used as input/output to the syscall should be usable without issue. /// Other registers should be saved/restored from memory, as LLVM doesn't know about their usage here. #[derive(Clone)] -pub struct SyscallImpl(pub Vec); +pub struct SyscallImpl { + pub syscall: Syscall, + pub statements: Vec, +} diff --git a/riscv/src/small_field/code_gen.rs b/riscv/src/small_field/code_gen.rs index 8ef0385f1f..399263b4cd 100644 --- a/riscv/src/small_field/code_gen.rs +++ b/riscv/src/small_field/code_gen.rs @@ -165,7 +165,7 @@ fn translate_program_impl( } Statement::Label(l) => statements.push(format!("{}:", escape_label(l.as_ref()))), Statement::Instruction { op, args } => { - let processed_instr = match process_instruction(op, args) { + let processed_instr = match process_instruction(op, args, runtime) { Ok(s) => s, Err(e) => panic!("Failed to process instruction '{op}'. {e}"), }; @@ -777,7 +777,11 @@ fn i32_low(x: i32) -> u16 { (x & 0xffff) as u16 } -fn process_instruction(instr: &str, args: A) -> Result, A::Error> { +fn process_instruction( + instr: &str, + args: A, + runtime: &Runtime, +) -> Result, A::Error> { let tmp1 = Register::from("tmp1"); let tmp2 = Register::from("tmp2"); let tmp3 = Register::from("tmp3"); @@ -1806,8 +1810,12 @@ fn process_instruction(instr: &str, args: A) -> Result { - panic!("Unknown instruction: {instr}"); + // possibly inlined system calls + insn => { + let Some(syscall_impl) = runtime.get_syscall_impl(insn) else { + panic!("Unknown instruction: {instr}"); + }; + syscall_impl.statements.clone() } }; for s in &statements { diff --git a/riscv/src/small_field/runtime.rs b/riscv/src/small_field/runtime.rs index 52dbe7e69d..a67c3ecd83 100644 --- a/riscv/src/small_field/runtime.rs +++ b/riscv/src/small_field/runtime.rs @@ -5,12 +5,8 @@ use powdr_riscv_syscalls::Syscall; use itertools::Itertools; use crate::code_gen::Register; -use crate::small_field::code_gen::{u32_high, u32_low}; -use crate::runtime::{ - 
parse_function_statement, parse_instruction_declaration, SubMachine, SyscallImpl, - EXTRA_REG_PREFIX, -}; +use crate::runtime::{parse_instruction_declaration, SubMachine, SyscallImpl, EXTRA_REG_PREFIX}; use crate::RuntimeLibs; /// RISCV powdr assembly runtime. @@ -18,7 +14,7 @@ use crate::RuntimeLibs; #[derive(Clone)] pub struct Runtime { submachines: BTreeMap, - syscalls: BTreeMap, + syscalls: BTreeMap<&'static str, SyscallImpl>, } impl Runtime { @@ -204,14 +200,19 @@ impl Runtime { syscall: Syscall, implementation: I, ) { - let implementation = SyscallImpl( - implementation + let implementation = SyscallImpl { + syscall, + statements: implementation .into_iter() - .map(|s| parse_function_statement(s.as_ref())) + .map(|s| s.as_ref().to_string()) .collect(), - ); + }; - if self.syscalls.insert(syscall, implementation).is_some() { + if self + .syscalls + .insert(syscall.name(), implementation) + .is_some() + { panic!("duplicate syscall {syscall}"); } } @@ -278,17 +279,19 @@ impl Runtime { ] .into_iter(); - let jump_table = self.syscalls.keys().map(|s| { - let s32_h = u32_high(*s as u32); - let s32_l = u32_low(*s as u32); - format!("branch_if_diff_equal 5, 0, {s32_h}, {s32_l}, __ecall_handler_{s};",) + let jump_table = self.syscalls.values().map(|s| { + let opcode = s.syscall as u8; + format!( + "branch_if_diff_equal 5, 0, 0, {opcode}, __ecall_handler_{};", + s.syscall + ) }); let invalid_handler = ["__invalid_syscall:".to_string(), "fail;".to_string()].into_iter(); let handlers = self.syscalls.iter().flat_map(|(syscall, implementation)| { std::iter::once(format!("__ecall_handler_{syscall}:")) - .chain(implementation.0.iter().map(|i| i.to_string())) + .chain(implementation.statements.iter().cloned()) .chain([format!("jump_dyn 1, {};", Register::from("tmp1").addr())]) }); @@ -299,4 +302,8 @@ impl Runtime { .chain(std::iter::once("// end of ecall handler".to_string())) .collect() } + + pub fn get_syscall_impl(&self, syscall_name: &str) -> Option<&SyscallImpl> { + 
self.syscalls.get(syscall_name) + } } diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 8972078b6c..3ba1a0d1d0 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -1,10 +1,8 @@ mod common; -use common::{verify_riscv_asm_file, verify_riscv_asm_string}; +use common::{compile_riscv_asm_file, verify_riscv_asm_file, verify_riscv_asm_string}; use mktemp::Temp; -use powdr_number::{ - read_polys_csv_file, BabyBearField, CsvRenderMode, FieldElement, GoldilocksField, KnownField, -}; +use powdr_number::{BabyBearField, FieldElement, GoldilocksField, KnownField}; use powdr_pipeline::{ test_util::{run_pilcom_with_backend_variant, BackendVariant}, Pipeline, @@ -287,6 +285,68 @@ fn read_slice() { read_slice_with_options::(CompilerOptions::new_gl()); } +/// Tests that the syscalls are inlined when the following pattern is used: +/// addi t0, x0, opcode +/// ecall +#[test] +fn syscalls_inlined_when_possible() { + // The following program should have two inlined syscalls, + // and two calls to the dispatcher that could not be inlined. 
+ let asm = r#" + .section .text + .globl _start + _start: + # inlined commit_public + addi t0, x0, 12 + ecall + + # non-inlined halt + ori t0, x0, 9 + ecall + + # inlined input + addi t0, x0, 1 # input opcode + ecall + + # non-inlined output + ori t0, x0, 2 # output opcode + ecall + "#; + let tmp_dir = Temp::new_dir().unwrap(); + let asm_file = tmp_dir.join("test.s"); + std::fs::write(&asm_file, asm).unwrap(); + + let compiled = compile_riscv_asm_file(&asm_file, CompilerOptions::new_gl(), true); + + // The resulting compiled program should contain the following strings in + // between the first automatic "return;" and the "__data_init:" definition, + // in the provided order: + let expected_strings = [ + // initial marker + "return;", + // from the inlined commit_public + "commit_public 10, 11;", + // from the non-inlined halt call + "or 0, 0, 9, 5;", + "jump __ecall_handler, 1;", + // from the inlined input call + "std::prelude::Query::Input", + // from the non-inlined output call + "or 0, 0, 2, 5;", + "jump __ecall_handler, 1;", + // final marker + "__data_init:", + ]; + + let mut remaining = compiled.as_str(); + for expected in &expected_strings { + let pos = remaining + .find(expected) + .unwrap_or_else(|| panic!("Expected string not found in generated code: {expected}")); + remaining = &remaining[pos + expected.len()..]; + } +} + fn read_slice_with_options(options: CompilerOptions) { let case = "read_slice"; let temp_dir = Temp::new_dir().unwrap(); @@ -628,10 +688,13 @@ fn profiler_sanity_check() { assert!(!callgrind.unwrap().is_empty()); } +#[cfg(feature = "plonky3")] #[test] #[ignore = "Too slow"] /// check that exported witness CSV can be loaded back in fn exported_csv_as_external_witness() { + use powdr_number::{read_polys_csv_file, CsvRenderMode}; + let case = "keccak"; let temp_dir = Temp::new_dir().unwrap();