diff --git a/Cargo.lock b/Cargo.lock index ab5874283bc759..69ae47c2ad2abf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6191,6 +6191,7 @@ dependencies = [ "ark-serialize", "array-bytes", "bytemuck", + "criterion", "serde", "serde_derive", "serde_json", diff --git a/curves/bn254/Cargo.toml b/curves/bn254/Cargo.toml index 02d5b8c18300e8..112afd0d44da13 100644 --- a/curves/bn254/Cargo.toml +++ b/curves/bn254/Cargo.toml @@ -22,9 +22,14 @@ ark-serialize = { workspace = true } [dev-dependencies] array-bytes = { workspace = true } +criterion = { workspace = true } serde = { workspace = true } serde_derive = { workspace = true } serde_json = { workspace = true } +[[bench]] +name = "bn254" +harness = false + [lints] workspace = true diff --git a/curves/bn254/benches/bn254.rs b/curves/bn254/benches/bn254.rs new file mode 100644 index 00000000000000..f94c4bcc7c9f8c --- /dev/null +++ b/curves/bn254/benches/bn254.rs @@ -0,0 +1,71 @@ +use { + criterion::{criterion_group, criterion_main, Criterion}, + solana_bn254::prelude::{alt_bn128_addition, alt_bn128_multiplication, alt_bn128_pairing}, +}; + +fn bench_addition(c: &mut Criterion) { + let p_bytes = [ + 24, 177, 138, 207, 180, 194, 195, 2, 118, 219, 84, 17, 54, 142, 113, 133, 179, 17, 221, 18, + 70, 145, 97, 12, 93, 59, 116, 3, 78, 9, 61, 201, 6, 60, 144, 156, 71, 32, 132, 12, 181, 19, + 76, 185, 245, 159, 167, 73, 117, 87, 150, 129, 150, 88, 211, 46, 252, 13, 40, 129, 152, + 243, 114, 102, + ]; + let q_bytes = [ + 7, 194, 183, 245, 138, 132, 189, 97, 69, 240, 12, 156, 43, 192, 187, 26, 24, 127, 32, 255, + 44, 146, 150, 58, 136, 1, 158, 124, 106, 1, 78, 237, 6, 97, 78, 32, 193, 71, 233, 64, 242, + 215, 13, 163, 247, 76, 154, 23, 223, 54, 23, 6, 164, 72, 92, 116, 43, 214, 120, 132, 120, + 250, 23, 215, + ]; + + let input_bytes = [&p_bytes[..], &q_bytes[..]].concat(); + + c.bench_function("bn128 addition", |b| { + b.iter(|| alt_bn128_addition(&input_bytes)) + }); +} + +fn bench_multiplication(c: &mut Criterion) { + let point_bytes 
= [ + 43, 211, 230, 208, 243, 177, 66, 146, 79, 92, 167, 180, 156, 229, 185, 213, 76, 71, 3, 215, + 174, 86, 72, 230, 29, 2, 38, 139, 26, 10, 159, 183, 33, 97, 28, 224, 166, 175, 133, 145, + 94, 47, 29, 112, 48, 9, 9, 206, 46, 73, 223, 173, 74, 70, 25, 200, 57, 12, 174, 102, 206, + 253, 178, 4, + ]; + let scalar_bytes = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 19, 140, 231, + 80, 250, 21, 194, + ]; + + let input_bytes = [&point_bytes[..], &scalar_bytes[..]].concat(); + + c.bench_function("bn128 multiplication", |b| { + b.iter(|| alt_bn128_multiplication(&input_bytes)) + }); +} + +fn bench_pairing(c: &mut Criterion) { + let p_bytes = [ + 28, 118, 71, 111, 77, 239, 75, 185, 69, 65, 213, 126, 187, 161, 25, 51, 129, 255, 167, 170, + 118, 173, 166, 100, 221, 49, 193, 96, 36, 196, 63, 89, 48, 52, 221, 41, 32, 246, 115, 226, + 4, 254, 226, 129, 28, 103, 135, 69, 252, 129, 155, 85, 211, 233, 210, 148, 228, 92, 155, 3, + 167, 106, 239, 65, + ]; + let q_bytes = [ + 32, 157, 209, 94, 191, 245, 212, 108, 75, 216, 136, 229, 26, 147, 207, 153, 167, 50, 150, + 54, 198, 53, 20, 57, 107, 74, 69, 32, 3, 163, 91, 247, 4, 191, 17, 202, 1, 72, 59, 250, + 139, 52, 180, 53, 97, 132, 141, 40, 144, 89, 96, 17, 76, 138, 192, 64, 73, 175, 75, 99, 21, + 164, 22, 120, 43, 184, 50, 74, 246, 207, 201, 53, 55, 162, 173, 26, 68, 92, 253, 12, 162, + 167, 26, 205, 122, 196, 31, 173, 191, 147, 60, 42, 81, 190, 52, 77, 18, 10, 42, 76, 243, + 12, 27, 249, 132, 95, 32, 198, 254, 57, 224, 126, 162, 204, 230, 31, 12, 155, 176, 72, 22, + 95, 229, 228, 222, 135, 117, 80, + ]; + + let input_bytes = [&p_bytes[..], &q_bytes[..]].concat(); + + c.bench_function("bn128 pairing", |b| { + b.iter(|| alt_bn128_pairing(&input_bytes)) + }); +} + +criterion_group!(benches, bench_addition, bench_multiplication, bench_pairing,); +criterion_main!(benches); diff --git a/curves/bn254/src/lib.rs b/curves/bn254/src/lib.rs index ff582abab32818..4052d2eefeb3b2 100644 --- 
a/curves/bn254/src/lib.rs +++ b/curves/bn254/src/lib.rs @@ -86,13 +86,46 @@ impl From for u64 { } } +use consts::{ALT_BN128_FIELD_SIZE as FIELD_SIZE, ALT_BN128_POINT_SIZE as G1_POINT_SIZE}; + +/// The BN254 (BN128) group element in G1 as a POD type. +/// +/// A group element in G1 consists of two field elements `(x, y)`. A `PodG1` +/// type expects a group element to be encoded as `[le(x), le(y)]` where +/// `le(..)` is the little-endian encoding of the input field element as used +/// in the `ark-bn254` crate. Note that this differs from the EIP-197 standard, +/// which specifies that the field elements are encoded as big-endian. +/// +/// The Solana syscalls still expect the inputs to be encoded in big-endian as +/// specified in EIP-197. The type `PodG1` is an intermediate type that +/// facilitates the translation between the EIP-197 encoding and the arkworks +/// implementation encoding. #[derive(Clone, Copy, Debug, PartialEq, Eq, Pod, Zeroable)] #[repr(transparent)] -pub struct PodG1(pub [u8; 64]); - +pub struct PodG1(pub [u8; G1_POINT_SIZE]); + +const G2_POINT_SIZE: usize = FIELD_SIZE * 4; + +/// The BN254 (BN128) group element in G2 as a POD type. +/// +/// Elements in G2 are represented by 2 field-extension elements `(x, y)`. Each +/// field-extension element itself is a degree 1 polynomial `x = x0 + x1*X`, +/// `y = y0 + y1*X`. The EIP-197 standard encodes a G2 element as +/// `[be(x1), be(x0), be(y1), be(y0)]` where `be(..)` is the big-endian +/// encoding of the input field element. The `ark-bn254` crate encodes a G2 +/// element as `[le(x0), le(x1), le(y0), le(y1)]` where `le(..)` is the +/// little-endian encoding of the input field element. Notably, in addition to +/// the differences in the big-endian vs. little-endian encodings of field +/// elements, the order of the polynomial field coefficients `x0`, `x1`, `y0`, +/// and `y1` is different. +/// +/// The Solana syscalls still expect the inputs to be encoded as specified in +/// EIP-197.
The type `PodG2` is an intermediate type that facilitates the +/// translation between the EIP-197 encoding and the encoding used in the +/// arkworks implementation. #[derive(Clone, Copy, Debug, PartialEq, Eq, Pod, Zeroable)] #[repr(transparent)] -pub struct PodG2(pub [u8; 128]); +pub struct PodG2(pub [u8; G2_POINT_SIZE]); #[cfg(not(target_os = "solana"))] mod target_arch { @@ -107,6 +140,60 @@ mod target_arch { type G1 = ark_bn254::g1::G1Affine; type G2 = ark_bn254::g2::G2Affine; + impl PodG1 { + /// Takes in an EIP-197 (big-endian) byte encoding of a group element in G1 and constructs a + /// `PodG1` struct that encodes the same bytes in little-endian. + fn from_be_bytes(be_bytes: &[u8]) -> Result { + if be_bytes.len() != G1_POINT_SIZE { + return Err(AltBn128Error::SliceOutOfBounds); + } + let mut pod_bytes = [0u8; G1_POINT_SIZE]; + reverse_copy(&be_bytes[..FIELD_SIZE], &mut pod_bytes[..FIELD_SIZE])?; + reverse_copy(&be_bytes[FIELD_SIZE..], &mut pod_bytes[FIELD_SIZE..])?; + Ok(Self(pod_bytes)) + } + } + + impl PodG2 { + /// Takes in an EIP-197 (big-endian) byte encoding of a group element in G2 + /// and constructs a `PodG2` struct that encodes the same bytes in + /// little-endian.
+ fn from_be_bytes(be_bytes: &[u8]) -> Result { + if be_bytes.len() != G2_POINT_SIZE { + return Err(AltBn128Error::SliceOutOfBounds); + } + // note the cross order + const SOURCE_X1_INDEX: usize = 0; + const SOURCE_X0_INDEX: usize = SOURCE_X1_INDEX.saturating_add(FIELD_SIZE); + const SOURCE_Y1_INDEX: usize = SOURCE_X0_INDEX.saturating_add(FIELD_SIZE); + const SOURCE_Y0_INDEX: usize = SOURCE_Y1_INDEX.saturating_add(FIELD_SIZE); + + const TARGET_X0_INDEX: usize = 0; + const TARGET_X1_INDEX: usize = TARGET_X0_INDEX.saturating_add(FIELD_SIZE); + const TARGET_Y0_INDEX: usize = TARGET_X1_INDEX.saturating_add(FIELD_SIZE); + const TARGET_Y1_INDEX: usize = TARGET_Y0_INDEX.saturating_add(FIELD_SIZE); + + let mut pod_bytes = [0u8; G2_POINT_SIZE]; + reverse_copy( + &be_bytes[SOURCE_X1_INDEX..SOURCE_X1_INDEX.saturating_add(FIELD_SIZE)], + &mut pod_bytes[TARGET_X1_INDEX..TARGET_X1_INDEX.saturating_add(FIELD_SIZE)], + )?; + reverse_copy( + &be_bytes[SOURCE_X0_INDEX..SOURCE_X0_INDEX.saturating_add(FIELD_SIZE)], + &mut pod_bytes[TARGET_X0_INDEX..TARGET_X0_INDEX.saturating_add(FIELD_SIZE)], + )?; + reverse_copy( + &be_bytes[SOURCE_Y1_INDEX..SOURCE_Y1_INDEX.saturating_add(FIELD_SIZE)], + &mut pod_bytes[TARGET_Y1_INDEX..TARGET_Y1_INDEX.saturating_add(FIELD_SIZE)], + )?; + reverse_copy( + &be_bytes[SOURCE_Y0_INDEX..SOURCE_Y0_INDEX.saturating_add(FIELD_SIZE)], + &mut pod_bytes[TARGET_Y0_INDEX..TARGET_Y0_INDEX.saturating_add(FIELD_SIZE)], + )?; + Ok(Self(pod_bytes)) + } + } + impl TryFrom for G1 { type Error = AltBn128Error; @@ -167,18 +254,8 @@ mod target_arch { let mut input = input.to_vec(); input.resize(ALT_BN128_ADDITION_INPUT_LEN, 0); - let p: G1 = PodG1( - convert_endianness_64(&input[..64]) - .try_into() - .map_err(AltBn128Error::TryIntoVecError)?, - ) - .try_into()?; - let q: G1 = PodG1( - convert_endianness_64(&input[64..ALT_BN128_ADDITION_INPUT_LEN]) - .try_into() - .map_err(AltBn128Error::TryIntoVecError)?, - ) - .try_into()?; + let p: G1 = 
PodG1::from_be_bytes(&input[..64])?.try_into()?; + let q: G1 = PodG1::from_be_bytes(&input[64..ALT_BN128_ADDITION_INPUT_LEN])?.try_into()?; #[allow(clippy::arithmetic_side_effects)] let result_point = p + q; @@ -194,7 +271,7 @@ mod target_arch { .serialize_with_mode(&mut result_point_data[32..], Compress::No) .map_err(|_| AltBn128Error::InvalidInputData)?; - Ok(convert_endianness_64(&result_point_data[..]).to_vec()) + Ok(convert_endianness_64(&result_point_data[..])) } pub fn alt_bn128_multiplication(input: &[u8]) -> Result, AltBn128Error> { @@ -205,16 +282,11 @@ mod target_arch { let mut input = input.to_vec(); input.resize(ALT_BN128_MULTIPLICATION_INPUT_LEN, 0); - let p: G1 = PodG1( - convert_endianness_64(&input[..64]) - .try_into() - .map_err(AltBn128Error::TryIntoVecError)?, - ) - .try_into()?; - let fr = BigInteger256::deserialize_uncompressed_unchecked( - &convert_endianness_64(&input[64..96])[..], - ) - .map_err(|_| AltBn128Error::InvalidInputData)?; + let p: G1 = PodG1::from_be_bytes(&input[..64])?.try_into()?; + let mut fr_bytes = [0u8; 32]; + reverse_copy(&input[64..96], &mut fr_bytes)?; + let fr = BigInteger256::deserialize_uncompressed_unchecked(fr_bytes.as_slice()) + .map_err(|_| AltBn128Error::InvalidInputData)?; let result_point: G1 = p.mul_bigint(fr).into(); @@ -229,10 +301,9 @@ mod target_arch { .serialize_with_mode(&mut result_point_data[32..], Compress::No) .map_err(|_| AltBn128Error::InvalidInputData)?; - Ok( - convert_endianness_64(&result_point_data[..ALT_BN128_MULTIPLICATION_OUTPUT_LEN]) - .to_vec(), - ) + Ok(convert_endianness_64( + &result_point_data[..ALT_BN128_MULTIPLICATION_OUTPUT_LEN], + )) } pub fn alt_bn128_pairing(input: &[u8]) -> Result, AltBn128Error> { @@ -246,32 +317,14 @@ mod target_arch { let ele_len = input.len().saturating_div(ALT_BN128_PAIRING_ELEMENT_LEN); - let mut vec_pairs: Vec<(G1, G2)> = Vec::new(); - for i in 0..ele_len { - vec_pairs.push(( - PodG1( - convert_endianness_64( - 
&input[i.saturating_mul(ALT_BN128_PAIRING_ELEMENT_LEN) - ..i.saturating_mul(ALT_BN128_PAIRING_ELEMENT_LEN) - .saturating_add(ALT_BN128_POINT_SIZE)], - ) - .try_into() - .map_err(AltBn128Error::TryIntoVecError)?, - ) - .try_into()?, - PodG2( - convert_endianness_128( - &input[i - .saturating_mul(ALT_BN128_PAIRING_ELEMENT_LEN) - .saturating_add(ALT_BN128_POINT_SIZE) - ..i.saturating_mul(ALT_BN128_PAIRING_ELEMENT_LEN) - .saturating_add(ALT_BN128_PAIRING_ELEMENT_LEN)], - ) - .try_into() - .map_err(AltBn128Error::TryIntoVecError)?, - ) - .try_into()?, - )); + let mut vec_pairs: Vec<(G1, G2)> = Vec::with_capacity(ele_len); + for chunk in input.chunks(ALT_BN128_PAIRING_ELEMENT_LEN) { + let (p_bytes, q_bytes) = chunk.split_at(G1_POINT_SIZE); + + let g1 = PodG1::from_be_bytes(p_bytes)?.try_into()?; + let g2 = PodG2::from_be_bytes(q_bytes)?.try_into()?; + + vec_pairs.push((g1, g2)); } let mut result = BigInteger256::from(0u64); @@ -295,11 +348,15 @@ mod target_arch { .collect::>() } - fn convert_endianness_128(bytes: &[u8]) -> Vec { - bytes - .chunks(64) - .flat_map(|b| b.iter().copied().rev().collect::>()) - .collect::>() + /// Copies a `source` byte slice into a `destination` byte slice in reverse order. + fn reverse_copy(source: &[u8], destination: &mut [u8]) -> Result<(), AltBn128Error> { + if source.len() != destination.len() { + return Err(AltBn128Error::SliceOutOfBounds); + } + for (source_index, destination_index) in source.iter().rev().zip(destination.iter_mut()) { + *destination_index = *source_index; + } + Ok(()) } }