From 751ec864da08346b93ea657d2be7261aa8f54c14 Mon Sep 17 00:00:00 2001 From: samkim-crypto Date: Wed, 21 Sep 2022 09:50:47 +0900 Subject: [PATCH] Add syscall curve group ops tests (#27937) * resolve rebase conflict * fix logic when group ops fail * update bpf loader id --- programs/bpf_loader/src/syscalls/mod.rs | 434 +++++++++++++++++++++++- 1 file changed, 428 insertions(+), 6 deletions(-) diff --git a/programs/bpf_loader/src/syscalls/mod.rs b/programs/bpf_loader/src/syscalls/mod.rs index 15e3e8504e668a..3d8bc275e9ac90 100644 --- a/programs/bpf_loader/src/syscalls/mod.rs +++ b/programs/bpf_loader/src/syscalls/mod.rs @@ -1251,7 +1251,8 @@ declare_syscall!( ), result ) = result_point; - *result = Ok(0); + } else { + *result = Ok(1); } } SUB => { @@ -1286,7 +1287,8 @@ declare_syscall!( ), result ) = result_point; - *result = Ok(0); + } else { + *result = Ok(1); } } MUL => { @@ -1321,7 +1323,8 @@ declare_syscall!( ), result ) = result_point; - *result = Ok(0); + } else { + *result = Ok(1); } } _ => { @@ -1362,7 +1365,8 @@ declare_syscall!( ), result ) = result_point; - *result = Ok(0); + } else { + *result = Ok(1); } } SUB => { @@ -1399,7 +1403,8 @@ declare_syscall!( ), result ) = result_point; - *result = Ok(0); + } else { + *result = Ok(1); } } MUL => { @@ -1434,7 +1439,8 @@ declare_syscall!( ), result ) = result_point; - *result = Ok(0); + } else { + *result = Ok(1); } } _ => { @@ -3121,6 +3127,422 @@ mod tests { ); } + #[test] + fn test_syscall_edwards_curve_group_ops() { + use solana_zk_token_sdk::curve25519::curve_syscall_traits::{ + ADD, CURVE25519_EDWARDS, MUL, SUB, + }; + + let config = Config::default(); + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), + ); + + let left_point: [u8; 32] = [ + 33, 124, 71, 170, 117, 69, 151, 247, 59, 12, 95, 125, 133, 166, 64, 5, 2, 27, 90, 27, + 200, 167, 59, 164, 52, 54, 52, 200, 29, 13, 34, 213, + ]; + let left_point_va = 0x100000000; + let right_point: [u8; 32] = [ + 70, 222, 137, 221, 253, 204, 71, 51, 78, 8, 124, 1, 67, 200, 102, 225, 122, 228, 111, + 183, 129, 14, 131, 210, 212, 95, 109, 246, 55, 10, 159, 91, + ]; + let right_point_va = 0x200000000; + let scalar: [u8; 32] = [ + 254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250, + 78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6, + ]; + let scalar_va = 0x300000000; + let invalid_point: [u8; 32] = [ + 120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84, + 60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79, + ]; + let invalid_point_va = 0x400000000; + let result_point: [u8; 32] = [0; 32]; + let result_point_va = 0x500000000; + + let mut memory_mapping = MemoryMapping::new::( + vec![ + MemoryRegion::default(), + MemoryRegion { + host_addr: left_point.as_ptr() as *const _ as u64, + vm_addr: left_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: right_point.as_ptr() as *const _ as u64, + vm_addr: right_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: scalar.as_ptr() as *const _ as u64, + vm_addr: scalar_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: invalid_point.as_ptr() as *const _ as u64, + vm_addr: invalid_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: result_point.as_ptr() as *const _ as u64, + vm_addr: result_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: true, + }, + ], + &config, + ) + .unwrap(); + + invoke_context + .get_compute_meter() + .borrow_mut() + .mock_set_remaining( + (invoke_context + .get_compute_budget() + .curve25519_edwards_add_cost + + invoke_context + .get_compute_budget() + .curve25519_edwards_subtract_cost + + invoke_context + .get_compute_budget() + .curve25519_edwards_multiply_cost) + * 2, + ); + let mut syscall = SyscallCurveGroupOps { + invoke_context: Rc::new(RefCell::new(&mut invoke_context)), + }; + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + ADD, + left_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + assert_eq!(0, result.unwrap()); + let expected_sum = [ + 7, 251, 187, 86, 186, 232, 57, 242, 193, 236, 49, 200, 90, 29, 254, 82, 46, 80, 83, 70, + 244, 153, 23, 156, 2, 138, 207, 51, 165, 38, 200, 85, + ]; + assert_eq!(expected_sum, result_point); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + ADD, + invalid_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + assert_eq!(1, result.unwrap()); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + SUB, + left_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + assert_eq!(0, result.unwrap()); + let expected_difference = [ + 60, 87, 90, 68, 232, 25, 7, 172, 247, 120, 158, 104, 52, 127, 94, 244, 5, 79, 253, 15, + 48, 69, 82, 134, 155, 70, 188, 81, 108, 95, 212, 9, + ]; + assert_eq!(expected_difference, result_point); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + SUB, + invalid_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + assert_eq!(1, result.unwrap()); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + MUL, + scalar_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + result.unwrap(); + let expected_product = [ + 64, 150, 40, 55, 80, 49, 217, 209, 105, 229, 181, 65, 241, 68, 2, 106, 220, 234, 211, + 71, 159, 76, 156, 114, 242, 68, 147, 31, 243, 211, 191, 124, + ]; + assert_eq!(expected_product, result_point); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + MUL, + scalar_va, + invalid_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + assert_eq!(1, result.unwrap()); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_EDWARDS, + MUL, + scalar_va, + invalid_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + assert_eq!( + Err(EbpfError::UserError(BpfError::SyscallError( + SyscallError::InstructionError(InstructionError::ComputationalBudgetExceeded) + ))), + result + ); + } + + #[test] + fn test_syscall_ristretto_curve_group_ops() { + use solana_zk_token_sdk::curve25519::curve_syscall_traits::{ + ADD, CURVE25519_RISTRETTO, MUL, SUB, + }; + + let config = Config::default(); + prepare_mockup!( + invoke_context, + transaction_context, + program_id, + bpf_loader::id(), + ); + + let left_point: [u8; 32] = [ + 208, 165, 125, 204, 2, 100, 218, 17, 170, 194, 23, 9, 102, 156, 134, 136, 217, 190, 98, + 34, 183, 194, 228, 153, 92, 11, 108, 103, 28, 57, 88, 15, + ]; + let left_point_va = 0x100000000; + let right_point: [u8; 32] = [ + 208, 241, 72, 163, 73, 53, 32, 174, 54, 194, 71, 8, 70, 181, 244, 199, 93, 147, 99, + 231, 162, 127, 25, 40, 39, 19, 140, 132, 112, 212, 145, 108, + ]; + let right_point_va = 0x200000000; + let scalar: [u8; 32] = [ + 254, 198, 23, 138, 67, 243, 184, 110, 236, 115, 236, 205, 205, 215, 79, 114, 45, 250, + 78, 137, 3, 107, 136, 237, 49, 126, 117, 223, 37, 191, 88, 6, + ]; + let scalar_va = 0x300000000; + let invalid_point: [u8; 32] = [ + 120, 140, 152, 233, 41, 227, 203, 27, 87, 115, 25, 251, 219, 5, 84, 148, 117, 38, 84, + 60, 87, 144, 161, 146, 42, 34, 91, 155, 158, 189, 121, 79, + ]; + let invalid_point_va = 0x400000000; + let result_point: [u8; 32] = [0; 32]; + let result_point_va = 0x500000000; + + let mut memory_mapping = MemoryMapping::new::( + vec![ + MemoryRegion::default(), + MemoryRegion { + host_addr: left_point.as_ptr() as *const _ as u64, + vm_addr: left_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: right_point.as_ptr() as *const _ as u64, + vm_addr: right_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: scalar.as_ptr() as *const _ as u64, + vm_addr: scalar_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: invalid_point.as_ptr() as *const _ as u64, + vm_addr: invalid_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: false, + }, + MemoryRegion { + host_addr: result_point.as_ptr() as *const _ as u64, + vm_addr: result_point_va, + len: 32, + vm_gap_shift: 63, + is_writable: true, + }, + ], + &config, + ) + .unwrap(); + + invoke_context + .get_compute_meter() + .borrow_mut() + .mock_set_remaining( + (invoke_context + .get_compute_budget() + .curve25519_ristretto_add_cost + + invoke_context + .get_compute_budget() + .curve25519_ristretto_subtract_cost + + invoke_context + .get_compute_budget() + .curve25519_ristretto_multiply_cost) + * 2, + ); + let mut syscall = SyscallCurveGroupOps { + invoke_context: Rc::new(RefCell::new(&mut invoke_context)), + }; + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + ADD, + left_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + assert_eq!(0, result.unwrap()); + let expected_sum = [ + 78, 173, 9, 241, 180, 224, 31, 107, 176, 210, 144, 240, 118, 73, 70, 191, 128, 119, + 141, 113, 125, 215, 161, 71, 49, 176, 87, 38, 180, 177, 39, 78, + ]; + assert_eq!(expected_sum, result_point); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + ADD, + invalid_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + assert_eq!(1, result.unwrap()); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + SUB, + left_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + assert_eq!(0, result.unwrap()); + let expected_difference = [ + 150, 72, 222, 61, 148, 79, 96, 130, 151, 176, 29, 217, 231, 211, 0, 215, 76, 86, 212, + 146, 110, 128, 24, 151, 187, 144, 108, 233, 221, 208, 157, 52, + ]; + assert_eq!(expected_difference, result_point); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + SUB, + invalid_point_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + assert_eq!(1, result.unwrap()); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + MUL, + scalar_va, + right_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + result.unwrap(); + let expected_product = [ + 4, 16, 46, 2, 53, 151, 201, 133, 117, 149, 232, 164, 119, 109, 136, 20, 153, 24, 124, + 21, 101, 124, 80, 19, 119, 100, 77, 108, 65, 187, 228, 5, + ]; + assert_eq!(expected_product, result_point); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + MUL, + scalar_va, + invalid_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + + assert_eq!(1, result.unwrap()); + + let mut result: Result> = Ok(0); + syscall.call( + CURVE25519_RISTRETTO, + MUL, + scalar_va, + invalid_point_va, + result_point_va, + &mut memory_mapping, + &mut result, + ); + assert_eq!( + Err(EbpfError::UserError(BpfError::SyscallError( + SyscallError::InstructionError(InstructionError::ComputationalBudgetExceeded) + ))), + result + ); + } + fn create_filled_type(zero_init: bool) -> T { let mut val = T::default(); let p = &mut val as *mut _ as *mut u8;