From 886be0eda35e5f28bfae46d20b68ba1e8eed73e2 Mon Sep 17 00:00:00 2001 From: Alin Dima Date: Thu, 11 Jan 2024 12:36:39 +0200 Subject: [PATCH] optimise some of the bounds checks (#34) * optimise bounds checks This brings a performance improvement of 40-100%, making this implementation as fast as the C++ alternative in kagome. Where possible, compiler is aided to optimise away the bounds checks without any unsafe code. However, a fair amount of unsafe code was needed, but it doesn't lower the security posture as the needed assertions were already being made. Signed-off-by: alindima * fix clippy * switch to using safe optimisations * revert some changes --------- Signed-off-by: alindima --- reed-solomon-novelpoly/src/field/inc_afft.rs | 48 ++++++++++++------- .../src/field/inc_encode.rs | 23 ++++----- .../src/field/inc_log_mul.rs | 8 +++- .../src/field/inc_reconstruct.rs | 8 +++- .../src/novel_poly_basis/mod.rs | 2 +- 5 files changed, 53 insertions(+), 36 deletions(-) diff --git a/reed-solomon-novelpoly/src/field/inc_afft.rs b/reed-solomon-novelpoly/src/field/inc_afft.rs index 5032cd2..1fd0d52 100644 --- a/reed-solomon-novelpoly/src/field/inc_afft.rs +++ b/reed-solomon-novelpoly/src/field/inc_afft.rs @@ -14,16 +14,16 @@ pub struct AdditiveFFT { } /// Formal derivative of polynomial in the new?? basis -pub fn formal_derivative(cos: &mut [Additive], size: usize) { - for i in 1..size { +pub fn formal_derivative(cos: &mut [Additive]) { + for i in 1..cos.len() { let length = ((i ^ (i - 1)) + 1) >> 1; for j in (i - length)..i { cos[j] ^= cos.get(j + length).copied().unwrap_or(Additive::ZERO); } } - let mut i = size; + let mut i = cos.len(); while i < FIELD_SIZE && i < cos.len() { - for j in 0..size { + for j in 0..cos.len() { cos[j] ^= cos.get(j + i).copied().unwrap_or(Additive::ZERO); } i <<= 1; @@ -32,9 +32,11 @@ pub fn formal_derivative(cos: &mut [Additive], size: usize) { /// Formal derivative of polynomial in tweaked?? basis #[allow(non_snake_case)] -pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) { +pub fn tweaked_formal_derivative(codeword: &mut [Additive]) { #[cfg(b_is_not_one)] let B = unsafe { &AFFT.B }; + #[cfg(b_is_not_one)] + let n = codeword.len(); // We change nothing when multiplying by b from B. #[cfg(b_is_not_one)] @@ -44,7 +46,7 @@ pub fn tweaked_formal_derivative(codeword: &mut [Additive], n: usize) { codeword[i + 1] = codeword[i + 1].mul(b); } - formal_derivative(codeword, n); + formal_derivative(codeword); // Again changes nothing by multiplying by b although b differs here. #[cfg(b_is_not_one)] @@ -86,21 +88,25 @@ fn b_is_one() { // We're hunting for the differences and trying to undersrtand the algorithm. /// Inverse additive FFT in the "novel polynomial basis" +#[inline(always)] pub fn inverse_afft(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.inverse_afft(data, size, index) } #[cfg(all(target_feature = "avx", feature = "avx"))] +#[inline(always)] pub fn inverse_afft_faster8(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.inverse_afft_faster8(data, size, index) } /// Additive FFT in the "novel polynomial basis" +#[inline(always)] pub fn afft(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.afft(data, size, index) } #[cfg(all(target_feature = "avx", feature = "avx"))] +#[inline(always)] /// Additive FFT in the "novel polynomial basis" pub fn afft_faster8(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.afft_faster8(data, size, index) @@ -141,6 +147,8 @@ impl AdditiveFFT { // After this, we start at depth (i of Algorithm 2) = (k of Algorithm 2) - 1 // and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0). let mut depart_no = 1_usize; + assert!(data.len() >= size); + while depart_no < size { // if depart_no >= 8 { // println!("\n\n\nplain/Round depart_no={depart_no}"); @@ -167,20 +175,16 @@ impl AdditiveFFT { // if depart_no >= 8 && false{ // data[i + depart_no] ^= dbg!(data[dbg!(i)]); // } else { + + // TODO: Optimising bounds checks on this line will yield a great performance improvement. data[i + depart_no] ^= data[i]; - // } } // Algorithm 2 indexs the skew factor in line 5 page 6288 // by i and \omega_{j 2^{i+1}}, but not by r explicitly. // We further explore this confusion below. (TODO) - let skew = - // if depart_no >= 8 && false { - // dbg!(self.skews[j + index - 1]) - // } else { - self.skews[j + index - 1] - // } - ; + let skew = self.skews[j + index - 1]; + // It's reasonale to skip the loop if skew is zero, but doing so with // all bits set requires justification. (TODO) if skew.0 != ONEMASK { @@ -191,8 +195,9 @@ impl AdditiveFFT { // if depart_no >= 8 && false{ // data[i] ^= dbg!(dbg!(data[dbg!(i + depart_no)]).mul(skew)); // } else { + + // TODO: Optimising bounds checks on this line will yield a great performance improvement. data[i] ^= data[i + depart_no].mul(skew); - // } } } @@ -270,6 +275,8 @@ impl AdditiveFFT { // After this, we start at depth (i of Algorithm 1) = (k of Algorithm 1) - 1 // and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0). let mut depart_no = size >> 1_usize; + assert!(data.len() >= size); + while depart_no > 0 { // Agrees with for loop (j of Algorithm 1) in (0..2^{k-i-1}) from line 5, // except we've j in (depart_no..size).step_by(2*depart_no), meaning @@ -291,6 +298,7 @@ impl AdditiveFFT { // we think r actually appears but the skew factor repeats itself // like in (19) in the proof of Lemma 4. (TODO) // We should understand the rest of this basis story, like (8) too. (TODO) + let skew = self.skews[j + index - 1]; // It's reasonale to skip the loop if skew is zero, but doing so with @@ -300,6 +308,8 @@ impl AdditiveFFT { for i in (j - depart_no)..j { // Line 6, explained by (28) page 6287, but // adding depart_no acts like the r+2^i superscript. + + // TODO: Optimising bounds checks on this line will yield a great performance improvement. data[i] ^= data[i + depart_no].mul(skew); } } @@ -308,6 +318,8 @@ impl AdditiveFFT { for i in (j - depart_no)..j { // Line 7, explained by (31) page 6287, but // adding depart_no acts like the r+2^i superscript. + + // TODO: Optimising bounds checks on this line will yield a great performance improvement. data[i + depart_no] ^= data[i]; } @@ -484,7 +496,7 @@ pub mod test_utils { let data = gen_plain::(size); gen_faster8_from_plain(data) } - + #[cfg(all(target_feature = "avx", feature = "avx"))] pub fn assert_plain_eq_faster8(plain: impl AsRef<[Additive]>, faster8: impl AsRef<[Additive]>) { let plain = plain.as_ref(); @@ -502,7 +514,7 @@ mod afft_tests { use super::super::*; use super::super::test_utils::*; use rand::rngs::SmallRng; - + #[cfg(all(target_feature = "avx", feature = "avx"))] #[test] fn afft_output_plain_eq_faster8_size_16() { @@ -544,7 +556,7 @@ mod afft_tests { println!(">>>>"); assert_plain_eq_faster8(data_plain, data_faster8); } - + #[cfg(all(target_feature = "avx", feature = "avx"))] #[test] fn afft_output_plain_eq_faster8_impulse_data() { diff --git a/reed-solomon-novelpoly/src/field/inc_encode.rs b/reed-solomon-novelpoly/src/field/inc_encode.rs index 5a0aeb2..103497e 100644 --- a/reed-solomon-novelpoly/src/field/inc_encode.rs +++ b/reed-solomon-novelpoly/src/field/inc_encode.rs @@ -7,7 +7,7 @@ pub fn encode_low(data: &[Additive], k: usize, codeword: &mut [Additive], n: usi encode_low_plain(data, k, codeword, n); } - #[cfg(not(target_feature = "avx"))] + #[cfg(not(all(target_feature = "avx", feature = "avx")))] encode_low_plain(data, k, codeword, n); } @@ -37,12 +37,10 @@ pub fn encode_low_plain(data: &[Additive], k: usize, codeword: &mut [Additive], for shift in (k..n).step_by(k) { let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift]; + // copy `M_topdash` to the position we are currently at, the n transform codeword_at_shift.copy_from_slice(codeword_first_k); - // dbg!(&codeword_at_shift); afft(codeword_at_shift, k, shift); - // let post = &codeword_at_shift; - // dbg!(post); } // restore `M` from the derived ones @@ -79,11 +77,10 @@ pub fn encode_low_faster8(data: &[Additive], k: usize, codeword: &mut [Additive] for shift in (k..n).step_by(k) { let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift]; + // copy `M_topdash` to the position we are currently at, the n transform codeword_at_shift.copy_from_slice(codeword_first_k); - afft_faster8(codeword_at_shift, k, shift); - // let post = &codeword8x_at_shift; } // restore `M` from the derived ones @@ -108,6 +105,8 @@ pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &m //data: message array. parity: parity array. mem: buffer(size>= n-k) //Encoding alg for k/n>0.5: parity is a power of two. pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { + assert!(is_power_of_2(n)); + let t: usize = n - k; // mem_zero(&mut parity[0..t]); @@ -158,7 +157,7 @@ pub fn encode_sub(bytes: &[u8], n: usize, k: usize) -> Result> { } else { encode_sub_plain(bytes, n, k) } - #[cfg(not(target_feature = "avx"))] + #[cfg(not(all(target_feature = "avx", feature = "avx")))] encode_sub_plain(bytes, n, k) } @@ -194,13 +193,11 @@ pub fn encode_sub_plain(bytes: &[u8], n: usize, k: usize) -> Result Result= size); + let mask = ONEMASK as Wide; let mut depart_no = 1_usize; while depart_no < size { diff --git a/reed-solomon-novelpoly/src/field/inc_reconstruct.rs b/reed-solomon-novelpoly/src/field/inc_reconstruct.rs index 591def0..cc9b140 100644 --- a/reed-solomon-novelpoly/src/field/inc_reconstruct.rs +++ b/reed-solomon-novelpoly/src/field/inc_reconstruct.rs @@ -69,13 +69,13 @@ pub(crate) fn decode_main( assert!(n >= recover_up_to); assert_eq!(erasure.len(), n); - for i in 0..n { + for i in 0..codeword.len() { codeword[i] = if erasure[i] { Additive(0) } else { codeword[i].mul(log_walsh2[i]) }; } inverse_afft(codeword, n, 0); - tweaked_formal_derivative(codeword, n); + tweaked_formal_derivative(codeword); afft(codeword, n, 0); @@ -89,6 +89,10 @@ pub(crate) fn decode_main( // since this has only to be called once per reconstruction pub fn eval_error_polynomial(erasure: &[bool], log_walsh2: &mut [Multiplier], n: usize) { let z = std::cmp::min(n, erasure.len()); + assert!(z <= erasure.len()); + assert!(n <= log_walsh2.len()); + assert!(z <= log_walsh2.len()); + for i in 0..z { log_walsh2[i] = Multiplier(erasure[i] as Elt); } diff --git a/reed-solomon-novelpoly/src/novel_poly_basis/mod.rs b/reed-solomon-novelpoly/src/novel_poly_basis/mod.rs index 3a80ff2..d65f688 100644 --- a/reed-solomon-novelpoly/src/novel_poly_basis/mod.rs +++ b/reed-solomon-novelpoly/src/novel_poly_basis/mod.rs @@ -66,7 +66,7 @@ impl CodeParams { { self.k >= (Additive8x::LANE << 1) && self.n % Additive8x::LANE == 0 } - #[cfg(not(target_feature = "avx"))] + #[cfg(not(all(target_feature = "avx", feature = "avx")))] false }