Skip to content

Commit

Permalink
Inlining system calls when possible. (#2087)
Browse files Browse the repository at this point in the history
Solves #2040.
  • Loading branch information
lvella authored Nov 15, 2024
1 parent a3474c3 commit 2049a35
Show file tree
Hide file tree
Showing 18 changed files with 291 additions and 171 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pr-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ jobs:
run: rustup component add rust-src --toolchain nightly-2024-08-01-x86_64-unknown-linux-gnu
- name: Install riscv target
run: rustup target add riscv32imac-unknown-none-elf --toolchain nightly-2024-08-01-x86_64-unknown-linux-gnu
- name: Install test dependencies
run: sudo apt-get install -y binutils-riscv64-unknown-elf lld
- name: Install pilcom
run: git clone https://github.com/0xPolygonHermez/pilcom.git && cd pilcom && npm install
- uses: taiki-e/install-action@nextest
Expand Down
107 changes: 42 additions & 65 deletions riscv-runtime/src/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use core::arch::asm;

use powdr_riscv_syscalls::Syscall;

use crate::ecall;

/// convert a big-endian u8 array to u32 array (arith machine format)
pub(crate) fn be_to_u32(from: &[u8; 32], to: &mut [u32; 8]) {
for (i, chunk) in from.chunks_exact(4).rev().enumerate() {
Expand Down Expand Up @@ -29,11 +31,10 @@ pub fn affine_256_u8_be(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8;
be_to_u32(&c, &mut c1);

unsafe {
asm!("ecall",
in("a0") &mut a1 as *mut [u32; 8],
in("a1") &mut b1 as *mut [u32; 8],
in("a2") &mut c1 as *mut [u32; 8],
in("t0") u32::from(Syscall::Affine256));
ecall!(Syscall::Affine256,
in("a0") a1.as_mut_ptr(),
in("a1") b1.as_mut_ptr(),
in("a2") c1.as_ptr());
}

u32_to_be(&a1, &mut a);
Expand All @@ -45,40 +46,30 @@ pub fn affine_256_u8_be(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8;
/// Returns `(hi, lo)`.
pub fn affine_256_u8_le(mut a: [u8; 32], mut b: [u8; 32], c: [u8; 32]) -> ([u8; 32], [u8; 32]) {
unsafe {
asm!("ecall",
in("a0") a.as_mut_ptr() as *mut [u32; 8],
in("a1") b.as_mut_ptr() as *mut [u32; 8],
in("a2") c.as_ptr() as *const [u32; 8],
in("t0") u32::from(Syscall::Affine256));
ecall!(Syscall::Affine256,
in("a0") a.as_mut_ptr(),
in("a1") b.as_mut_ptr(),
in("a2") c.as_ptr());
}

(a, b)
}

/// Calculate `a*b + c = hi*2**256 + lo` for 256 bit values (as u32 little-endian arrays).
/// Returns `(hi, lo)`.
pub fn affine_256_u32_le(
mut a: [u32; 8],
mut b: [u32; 8],
mut c: [u32; 8],
) -> ([u32; 8], [u32; 8]) {
pub fn affine_256_u32_le(mut a: [u32; 8], mut b: [u32; 8], c: [u32; 8]) -> ([u32; 8], [u32; 8]) {
unsafe {
asm!("ecall",
in("a0") &mut a as *mut [u32; 8],
in("a1") &mut b as *mut [u32; 8],
in("a2") &mut c as *mut [u32; 8],
in("t0") u32::from(Syscall::Affine256));
ecall!(Syscall::Affine256,
in("a0") a.as_mut_ptr(),
in("a1") b.as_mut_ptr(),
in("a2") c.as_ptr());
}
(a, b)
}

/// Calculate `(a*b) % m = r` for 256 bit values (as u8 big-endian arrays).
/// Returns `r`.
pub fn modmul_256_u8_be(
mut a: [u8; 32],
b: [u8; 32],
m: [u8; 32],
) -> [u8; 32] {
pub fn modmul_256_u8_be(mut a: [u8; 32], b: [u8; 32], m: [u8; 32]) -> [u8; 32] {
let mut a1: [u32; 8] = Default::default();
let mut b1: [u32; 8] = Default::default();
let mut m1: [u32; 8] = Default::default();
Expand All @@ -90,17 +81,15 @@ pub fn modmul_256_u8_be(
unsafe {
// First compute the two halves of the result a*b.
// Results are stored in place in a and b.
asm!("ecall",
in("a0") &mut a1 as *mut [u32; 8],
in("a1") &mut b1 as *mut [u32; 8],
in("a2") &mut [0u32; 8] as *mut [u32; 8],
in("t0") u32::from(Syscall::Affine256));
ecall!(Syscall::Affine256,
in("a0") a1.as_mut_ptr(),
in("a1") b1.as_mut_ptr(),
in("a2") [0u32; 8].as_ptr());
// Next compute the remainder, stored in place in a.
asm!("ecall",
in("a0") &mut a1 as *mut [u32; 8],
in("a1") &mut b1 as *mut [u32; 8],
in("a2") &mut m1 as *mut [u32; 8],
in("t0") u32::from(Syscall::Mod256));
ecall!(Syscall::Mod256,
in("a0") a1.as_mut_ptr(),
in("a1") b1.as_ptr(),
in("a2") m1.as_ptr());
}

u32_to_be(&a1, &mut a);
Expand All @@ -109,51 +98,39 @@ pub fn modmul_256_u8_be(

/// Calculate `(a*b) % m = r` for 256 bit values (as u8 little-endian arrays).
/// Returns `r`.
pub fn modmul_256_u8_le(
mut a: [u8; 32],
mut b: [u8; 32],
m: [u8; 32],
) -> [u8; 32] {
pub fn modmul_256_u8_le(mut a: [u8; 32], mut b: [u8; 32], m: [u8; 32]) -> [u8; 32] {
unsafe {
// First compute the two halves of the result a*b.
// Results are stored in place in a and b.
asm!("ecall",
in("a0") a.as_mut_ptr() as *mut [u32; 8],
in("a1") b.as_mut_ptr() as *mut [u32; 8],
in("a2") &mut [0u32; 8] as *mut [u32; 8],
in("t0") u32::from(Syscall::Affine256));
ecall!(Syscall::Affine256,
in("a0") a.as_mut_ptr(),
in("a1") b.as_mut_ptr(),
in("a2") [0u32; 8].as_ptr());
// Next compute the remainder, stored in place in a.
asm!("ecall",
in("a0") a.as_mut_ptr() as *mut [u32; 8],
in("a1") b.as_mut_ptr() as *mut [u32; 8],
in("a2") m.as_ptr() as *const [u32; 8],
in("t0") u32::from(Syscall::Mod256));
ecall!(Syscall::Mod256,
in("a0") a.as_mut_ptr(),
in("a1") b.as_ptr(),
in("a2") m.as_ptr());
}

a
}

/// Calculate `(a*b) % m = r` for 256 bit values (as u32 little-endian arrays).
/// Returns `r`.
pub fn modmul_256_u32_le(
mut a: [u32; 8],
mut b: [u32; 8],
m: [u32; 8],
) -> [u32; 8] {
pub fn modmul_256_u32_le(mut a: [u32; 8], mut b: [u32; 8], m: [u32; 8]) -> [u32; 8] {
unsafe {
// First compute the two halves of the result a*b.
// Results are stored in place in a and b.
asm!("ecall",
in("a0") &mut a as *mut [u32; 8],
in("a1") &mut b as *mut [u32; 8],
in("a2") &[0u32; 8] as *const [u32; 8],
in("t0") u32::from(Syscall::Affine256));
ecall!(Syscall::Affine256,
in("a0") a.as_mut_ptr(),
in("a1") b.as_mut_ptr(),
in("a2") [0u32; 8].as_ptr());
// Next compute the remainder, stored in place in a.
asm!("ecall",
in("a0") &mut a as *mut [u32; 8],
in("a1") &mut b as *mut [u32; 8],
in("a2") &m as *const [u32; 8],
in("t0") u32::from(Syscall::Mod256));
ecall!(Syscall::Mod256,
in("a0") a.as_mut_ptr(),
in("a1") b.as_ptr(),
in("a2") m.as_ptr());
}

a
Expand Down
62 changes: 28 additions & 34 deletions riscv-runtime/src/ec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ pub fn add_u8_be(
be_to_u32(&by, &mut by1);

unsafe {
asm!("ecall",
in("a0") &mut ax1 as *mut [u32; 8],
in("a1") &mut ay1 as *mut [u32; 8],
in("a2") &mut bx1 as *mut [u32; 8],
in("a3") &mut by1 as *mut [u32; 8],
in("t0") u32::from(Syscall::EcAdd));
ecall!(Syscall::EcAdd,
in("a0") ax1.as_mut_ptr(),
in("a1") ay1.as_mut_ptr(),
in("a2") bx1.as_ptr(),
in("a3") by1.as_ptr());
}

u32_to_be(&ax1, &mut ax);
Expand All @@ -39,16 +38,15 @@ pub fn add_u8_be(
pub fn add_u8_le(
mut ax: [u8; 32],
mut ay: [u8; 32],
mut bx: [u8; 32],
mut by: [u8; 32],
bx: [u8; 32],
by: [u8; 32],
) -> ([u8; 32], [u8; 32]) {
unsafe {
asm!("ecall",
in("a0") ax.as_mut_ptr() as *mut [u32; 8],
in("a1") ay.as_mut_ptr() as *mut [u32; 8],
in("a2") bx.as_mut_ptr() as *mut [u32; 8],
in("a3") by.as_mut_ptr() as *mut [u32; 8],
in("t0") u32::from(Syscall::EcAdd));
ecall!(Syscall::EcAdd,
in("a0") ax.as_mut_ptr(),
in("a1") ay.as_mut_ptr(),
in("a2") bx.as_ptr(),
in("a3") by.as_ptr());
}
(ax, ay)
}
Expand All @@ -57,16 +55,15 @@ pub fn add_u8_le(
pub fn add_u32_le(
mut ax: [u32; 8],
mut ay: [u32; 8],
mut bx: [u32; 8],
mut by: [u32; 8],
bx: [u32; 8],
by: [u32; 8],
) -> ([u32; 8], [u32; 8]) {
unsafe {
asm!("ecall",
in("a0") &mut ax as *mut [u32; 8],
in("a1") &mut ay as *mut [u32; 8],
in("a2") &mut bx as *mut [u32; 8],
in("a3") &mut by as *mut [u32; 8],
in("t0") u32::from(Syscall::EcAdd));
ecall!(Syscall::EcAdd,
in("a0") ax.as_mut_ptr(),
in("a1") ay.as_mut_ptr(),
in("a2") bx.as_ptr(),
in("a3") by.as_ptr());
}
(ax, ay)
}
Expand All @@ -80,10 +77,9 @@ pub fn double_u8_be(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) {
be_to_u32(&y, &mut y1);

unsafe {
asm!("ecall",
in("a0") &mut x1 as *mut [u32; 8],
in("a1") &mut y1 as *mut [u32; 8],
in("t0") u32::from(Syscall::EcDouble));
ecall!(Syscall::EcDouble,
in("a0") x1.as_mut_ptr(),
in("a1") y1.as_mut_ptr());
}

u32_to_be(&x1, &mut x);
Expand All @@ -95,10 +91,9 @@ pub fn double_u8_be(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) {
/// Double a k256 ec point. Coordinates are little-endian u8 arrays.
pub fn double_u8_le(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) {
unsafe {
asm!("ecall",
in("a0") x.as_mut_ptr() as *mut [u32; 8],
in("a1") y.as_mut_ptr() as *mut [u32; 8],
in("t0") u32::from(Syscall::EcDouble));
ecall!(Syscall::EcDouble,
in("a0") x.as_mut_ptr(),
in("a1") y.as_mut_ptr());
}

(x, y)
Expand All @@ -107,10 +102,9 @@ pub fn double_u8_le(mut x: [u8; 32], mut y: [u8; 32]) -> ([u8; 32], [u8; 32]) {
/// Double a k256 ec point. Coordinates are little-endian u32 arrays.
pub fn double_u32_le(mut x: [u32; 8], mut y: [u32; 8]) -> ([u32; 8], [u32; 8]) {
unsafe {
asm!("ecall",
in("a0") &mut x as *mut [u32; 8],
in("a1") &mut y as *mut [u32; 8],
in("t0") u32::from(Syscall::EcDouble));
ecall!(Syscall::EcDouble,
in("a0") x.as_mut_ptr(),
in("a1") y.as_mut_ptr());
}
(x, y)
}
17 changes: 17 additions & 0 deletions riscv-runtime/src/ecall.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/// Generates an ecall instruction with the given system call number and arguments.
///
/// Uses the instruction sequence convention for the system call to be inlined.
#[macro_export]
macro_rules! ecall {
($syscall:expr, $($tokens:tt)*) => {
asm!(
"addi t0, x0, {}",
"ecall",
const $syscall as u8,
// No system call we have at this point allocates on stack.
options(nostack),
out("t0") _,
$($tokens)*
);
};
}
2 changes: 1 addition & 1 deletion riscv-runtime/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ fn print_prover_char(c: u8) {
let mut value = c as u32;
#[allow(unused_assignments)]
unsafe {
asm!("ecall", lateout("a0") value, in("a0") 1, in("a1") value, in("t0") u32::from(Syscall::Output));
ecall!(Syscall::Output, lateout("a0") value, in("a0") 1, in("a1") value);
}
}
19 changes: 5 additions & 14 deletions riscv-runtime/src/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use powdr_riscv_syscalls::Syscall;

pub fn native_hash(data: &mut [u64; 12]) -> &[u64; 4] {
unsafe {
asm!("ecall", in("a0") data as *mut _, in("t0") u32::from(Syscall::NativeHash));
ecall!(Syscall::NativeHash, in("a0") data);
}
data[..4].try_into().unwrap()
}
Expand All @@ -17,32 +17,23 @@ pub fn native_hash(data: &mut [u64; 12]) -> &[u64; 4] {
/// sub-array is returned.
pub fn poseidon_gl(data: &mut [Goldilocks; 12]) -> &[Goldilocks; 4] {
unsafe {
asm!("ecall", in("a0") data as *mut _, in("t0") u32::from(Syscall::PoseidonGL));
ecall!(Syscall::PoseidonGL, in("a0") data);
}
data[..4].try_into().unwrap()
}

/// Perform one Poseidon2 permutation with 8 Goldilocks field elements in-place.
pub fn poseidon2_gl_inplace(data: &mut [Goldilocks; 8]) {
let ptr = data as *mut _;
unsafe {
asm!("ecall",
in("a0") ptr,
in("a1") ptr,
in("t0") u32::from(Syscall::Poseidon2GL)
);
ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") data);
}
}

/// Perform one Poseidon2 permutation with 8 Goldilocks field elements.
pub fn poseidon2_gl(data: &[Goldilocks; 8]) -> [Goldilocks; 8] {
unsafe {
let mut output: MaybeUninit<[Goldilocks; 8]> = MaybeUninit::uninit();
asm!("ecall",
in("a0") data as *const _,
in("a1") output.as_mut_ptr(),
in("t0") u32::from(Syscall::Poseidon2GL)
);
ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") output.as_mut_ptr());
output.assume_init()
}
}
Expand All @@ -52,7 +43,7 @@ pub fn poseidon2_gl(data: &[Goldilocks; 8]) -> [Goldilocks; 8] {
pub fn keccakf(input: &[u64; 25], output: &mut [u64; 25]) {
unsafe {
// Syscall inputs: memory pointer to input array and memory pointer to output array.
asm!("ecall", in("a0") input, in("a1") output, in("t0") u32::from(Syscall::KeccakF));
ecall!(Syscall::KeccakF, in("a0") input, in("a1") output);
}
}

Expand Down
Loading

0 comments on commit 2049a35

Please sign in to comment.