diff --git a/reed-solomon-novelpoly/src/field/inc_afft.rs b/reed-solomon-novelpoly/src/field/inc_afft.rs index ad877a7..6193048 100644 --- a/reed-solomon-novelpoly/src/field/inc_afft.rs +++ b/reed-solomon-novelpoly/src/field/inc_afft.rs @@ -88,27 +88,25 @@ fn b_is_one() { // We're hunting for the differences and trying to undersrtand the algorithm. /// Inverse additive FFT in the "novel polynomial basis" -/// -/// # Safety -/// See safety section of `AdditiveFFT::inverse_afft`. -pub unsafe fn inverse_afft(data: &mut [Additive], size: usize, index: usize) { +#[inline(always)] +pub fn inverse_afft(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.inverse_afft(data, size, index) } #[cfg(all(target_feature = "avx", feature = "avx"))] +#[inline(always)] pub fn inverse_afft_faster8(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.inverse_afft_faster8(data, size, index) } /// Additive FFT in the "novel polynomial basis" -/// -/// # Safety -/// See safety section of `AdditiveFFT::afft`. -pub unsafe fn afft(data: &mut [Additive], size: usize, index: usize) { +#[inline(always)] +pub fn afft(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.afft(data, size, index) } #[cfg(all(target_feature = "avx", feature = "avx"))] +#[inline(always)] /// Additive FFT in the "novel polynomial basis" pub fn afft_faster8(data: &mut [Additive], size: usize, index: usize) { unsafe { &AFFT }.afft_faster8(data, size, index) @@ -138,12 +136,7 @@ impl AdditiveFFT { } /// Inverse additive FFT in the "novel polynomial basis" - /// - /// # Safety - /// - /// - caller must ensure than `size` is a power of two and that the length of the `data` slice is at least equal to `size`. - /// - caller must ensure that `index + size - 2` is less than or equal to 65534. - pub unsafe fn inverse_afft(&self, data: &mut [Additive], size: usize, index: usize) { + pub fn inverse_afft(&self, data: &mut [Additive], size: usize, index: usize) { // All line references to Algorithm 2 page 6288 of // https://www.citi.sinica.edu.tw/papers/whc/5524-F.pdf @@ -154,6 +147,8 @@ impl AdditiveFFT { // After this, we start at depth (i of Algorithm 2) = (k of Algorithm 2) - 1 // and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0). let mut depart_no = 1_usize; + assert!(data.len() >= size); + while depart_no < size { // if depart_no >= 8 { // println!("\n\n\nplain/Round depart_no={depart_no}"); @@ -180,42 +175,16 @@ impl AdditiveFFT { // if depart_no >= 8 && false{ // data[i + depart_no] ^= dbg!(data[dbg!(i)]); // } else { - #[cfg(debug)] - { - data[i + depart_no] ^= data[i]; - } - #[cfg(not(debug))] - { - // SAFETY - // - // j is smaller than size. depart_no is smaller than size. - // depart_no is always doubled, so it's always a power of two smaller than size. - // this means that depart_no is at most half of size, assuming size is a power of two. - // - // i is at most j - 1. j is greater than depart_no but is incremented by double of depart_no. - // for the max depart_no value of size/2, j will only have the one value of size/2, - // so the index will be size/2 - 1 + size/2, which is equal to size - 1, which is safe. - // i will always be smaller than i + depart_no, since they're positive integers. qed. - let local = unsafe { *data.get_unchecked(i) }; - unsafe { *data.get_unchecked_mut(i + depart_no) ^= local }; - } - // } + // TODO: Optimising bounds checks on this line will yield a great performance improvement. + data[i + depart_no] ^= data[i]; } // Algorithm 2 indexs the skew factor in line 5 page 6288 // by i and \omega_{j 2^{i+1}}, but not by r explicitly. // We further explore this confusion below. (TODO) - #[cfg(debug)] let skew = self.skews[j + index - 1]; - #[cfg(not(debug))] - // SAFETY: - // - // Safe because caller ensured that index + size - 2 is less than or equal to 65534 (the skew vector len). - // Since, j is at most size - 1, this is safe. - let skew = unsafe { *self.skews.get_unchecked(j + index - 1) }; - // It's reasonale to skip the loop if skew is zero, but doing so with // all bits set requires justification. (TODO) if skew.0 != ONEMASK { @@ -226,18 +195,9 @@ impl AdditiveFFT { // if depart_no >= 8 && false{ // data[i] ^= dbg!(dbg!(data[dbg!(i + depart_no)]).mul(skew)); // } else { - #[cfg(debug)] - { - data[i] ^= data[i + depart_no].mul(skew); - } - - #[cfg(not(debug))] - // Same safety princicples as the first `for i in (j - depart_no)..j` loop. - { - let local = unsafe { *data.get_unchecked(i + depart_no) }; - unsafe { *data.get_unchecked_mut(i) ^= local.mul(skew) }; - } - // } + + // TODO: Optimising bounds checks on this line will yield a great performance improvement. + data[i] ^= data[i + depart_no].mul(skew); } } @@ -304,12 +264,7 @@ impl AdditiveFFT { } /// Additive FFT in the "novel polynomial basis" - /// - /// # Safety - /// - /// - caller must ensure than `size` is a power of two and that the length of the `data` slice is at least equal to `size`. - /// - caller must ensure that `index + size - 2` is less than or equal to 65534. - pub unsafe fn afft(&self, data: &mut [Additive], size: usize, index: usize) { + pub fn afft(&self, data: &mut [Additive], size: usize, index: usize) { // All line references to Algorithm 1 page 6287 of // https://www.citi.sinica.edu.tw/papers/whc/5524-F.pdf @@ -320,6 +275,8 @@ impl AdditiveFFT { // After this, we start at depth (i of Algorithm 1) = (k of Algorithm 1) - 1 // and progress through FIELD_BITS-1 steps, obtaining \Psi_\beta(0,0). let mut depart_no = size >> 1_usize; + assert!(data.len() >= size); + while depart_no > 0 { // Agrees with for loop (j of Algorithm 1) in (0..2^{k-i-1}) from line 5, // except we've j in (depart_no..size).step_by(2*depart_no), meaning @@ -342,16 +299,8 @@ impl AdditiveFFT { // like in (19) in the proof of Lemma 4. (TODO) // We should understand the rest of this basis story, like (8) too. (TODO) - #[cfg(debug)] let skew = self.skews[j + index - 1]; - #[cfg(not(debug))] - // SAFETY: - // - // Safe because caller ensured that index + size - 2 is less than or equal to 65534 (the skew vector len). - // Since, j is at most size - 1, this is safe. - let skew = unsafe { *self.skews.get_unchecked(j + index - 1) }; - // It's reasonale to skip the loop if skew is zero, but doing so with // all bits set requires justification. (TODO) if skew.0 != ONEMASK { @@ -360,25 +309,8 @@ impl AdditiveFFT { // Line 6, explained by (28) page 6287, but // adding depart_no acts like the r+2^i superscript. - #[cfg(debug)] - { - data[i] ^= data[i + depart_no].mul(skew); - } - - #[cfg(not(debug))] - { - // SAFETY - // - // j is smaller than size. depart_no is smaller than size/2 and it's always halved. - // this means that depart_no is at most half of size, assuming size is a power of two. - // - // i is at most j - 1. j is greater than depart_no but is incremented by double of depart_no. - // for the max depart_no value of size/2, j will only have the one value of size/2, - // so the index will be size/2 - 1 + size/2, which is equal to size - 1, which is safe. - // i will always be smaller than i + depart_no, since they're positive integers. qed. - let local = unsafe { *data.get_unchecked(i + depart_no) }; - unsafe { *data.get_unchecked_mut(i) ^= local.mul(skew) }; - } + // TODO: Optimising bounds checks on this line will yield a great performance improvement. + data[i] ^= data[i + depart_no].mul(skew); } } @@ -386,17 +318,9 @@ impl AdditiveFFT { for i in (j - depart_no)..j { // Line 7, explained by (31) page 6287, but // adding depart_no acts like the r+2^i superscript. - #[cfg(debug)] - { - data[i + depart_no] ^= data[i]; - } - #[cfg(not(debug))] - { - // Same safety princicples as the first `for i in (j - depart_no)..j` loop. - let local = unsafe { *data.get_unchecked(i) }; - unsafe { *data.get_unchecked_mut(i + depart_no) ^= local }; - } + // TODO: Optimising bounds checks on this line will yield a great performance improvement. + data[i + depart_no] ^= data[i]; } // Increment by double depart_no in agreement with diff --git a/reed-solomon-novelpoly/src/field/inc_encode.rs b/reed-solomon-novelpoly/src/field/inc_encode.rs index 1bdb472..3746fd2 100644 --- a/reed-solomon-novelpoly/src/field/inc_encode.rs +++ b/reed-solomon-novelpoly/src/field/inc_encode.rs @@ -29,36 +29,18 @@ pub fn encode_low_plain(data: &[Additive], k: usize, codeword: &mut [Additive], // split after the first k let (codeword_first_k, codeword_skip_first_k) = codeword.split_at_mut(k); - // - safe because codeword_first_k is exactly k elements and k is a power of two. - // - safe because `index + size - 2` is `k - 2`. k is at most n/2 and n is at most 65536. Therefore, - // k is at most 65536/2-2 = 32766 (smaller than 65535). qed. - unsafe { inverse_afft(codeword_first_k, k, 0) }; + inverse_afft(codeword_first_k, k, 0); // dbg!(&codeword_first_k); // the first codeword is now the basis for the remaining transforms // denoted `M_topdash` for shift in (k..n).step_by(k) { - #[cfg(debug)] let codeword_at_shift = &mut codeword_skip_first_k[(shift - k)..shift]; - #[cfg(not(debug))] - // SAFETY - // - // n is i*k, with i at least 2. shift is at most (i-1)*k. - // (i-1) * k will always be smaller than i*k for all i greater than 2. - // Similarly, shift - k will always be smaller than shift, since they're positive integers - // and shift is at least equal to k. - let codeword_at_shift = unsafe { codeword_skip_first_k.get_unchecked_mut((shift - k)..shift) }; - // copy `M_topdash` to the position we are currently at, the n transform codeword_at_shift.copy_from_slice(codeword_first_k); - // SAFETY - // - // - safe because codeword_first_k is exactly k elements and k is a power of two. - // - k is at most n/2 (32768). `index + size - 2` is therefore equal to 2*k - 2 = 65534 which - // is less than or equal to 65534. qed. - unsafe { afft(codeword_at_shift, k, shift) }; + afft(codeword_at_shift, k, shift); } // restore `M` from the derived ones @@ -118,68 +100,65 @@ pub fn encode_low_faster8(data: &[Additive], k: usize, codeword: &mut [Additive] //data: message array. parity: parity array. mem: buffer(size>= n-k) //Encoding alg for k/n>0.5: parity is a power of two. -// Function is not exposed/tested. Consider the safety guidelines of the *afft functions before using. -// #[inline(always)] -// pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { -// #[cfg(all(target_feature = "avx", feature = "avx"))] -// if (n - k) % Additive8x::LANE == 0 && n % Additive8x::LANE == 0 && k % Additive8x::LANE == 0 { -// encode_high_faster8(data, k, parity, mem, n); -// } else { -// encode_high_plain(data, k, parity, mem, n); -// } -// #[cfg(not(target_feature = "avx"))] -// encode_high_plain(data, k, parity, mem, n); -// } +#[inline(always)] +pub fn encode_high(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { + #[cfg(all(target_feature = "avx", feature = "avx"))] + if (n - k) % Additive8x::LANE == 0 && n % Additive8x::LANE == 0 && k % Additive8x::LANE == 0 { + encode_high_faster8(data, k, parity, mem, n); + } else { + encode_high_plain(data, k, parity, mem, n); + } + #[cfg(not(target_feature = "avx"))] + encode_high_plain(data, k, parity, mem, n); +} //data: message array. parity: parity array. mem: buffer(size>= n-k) //Encoding alg for k/n>0.5: parity is a power of two. -// Function is not exposed/tested. Consider the safety guidelines of the *afft functions before using. -// pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { -// assert!(is_power_of_2(n)); - -// let t: usize = n - k; - -// // mem_zero(&mut parity[0..t]); -// for i in 0..t { -// parity[i] = Additive(0); -// } - -// let mut i = t; -// while i < n { -// mem[..t].copy_from_slice(&data[(i - t)..t]); - -// unsafe { inverse_afft(mem, t, i) }; -// for j in 0..t { -// parity[j] ^= mem[j]; -// } -// i += t; -// } -// unsafe { afft(parity, t, 0) }; -// } - -// #[cfg(all(target_feature = "avx", feature = "avx"))] -// Function is not exposed/tested. Consider the safety guidelines of the *afft functions before using. -// pub fn encode_high_faster8(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { -// let t: usize = n - k; -// assert!(t >= 8); -// assert_eq!(t % 8, 0); - -// for i in 0..t { -// parity[i] = Additive::zero(); -// } - -// let mut i = t; -// while i < n { -// mem[..t].copy_from_slice(&data[(i - t)..t]); - -// inverse_afft_faster8(mem, t, i); -// for j in 0..t { -// parity[j] ^= mem[j]; -// } -// i += t; -// } -// afft_faster8(parity, t, 0); -// } +pub fn encode_high_plain(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { + assert!(is_power_of_2(n)); + + let t: usize = n - k; + + // mem_zero(&mut parity[0..t]); + for i in 0..t { + parity[i] = Additive(0); + } + + let mut i = t; + while i < n { + mem[..t].copy_from_slice(&data[(i - t)..t]); + + inverse_afft(mem, t, i); + for j in 0..t { + parity[j] ^= mem[j]; + } + i += t; + } + afft(parity, t, 0); +} + +#[cfg(all(target_feature = "avx", feature = "avx"))] +pub fn encode_high_faster8(data: &[Additive], k: usize, parity: &mut [Additive], mem: &mut [Additive], n: usize) { + let t: usize = n - k; + assert!(t >= 8); + assert_eq!(t % 8, 0); + + for i in 0..t { + parity[i] = Additive::zero(); + } + + let mut i = t; + while i < n { + mem[..t].copy_from_slice(&data[(i - t)..t]); + + inverse_afft_faster8(mem, t, i); + for j in 0..t { + parity[j] ^= mem[j]; + } + i += t; + } + afft_faster8(parity, t, 0); +} pub fn encode_sub(bytes: &[u8], n: usize, k: usize) -> Result> { #[cfg(all(target_feature = "avx", feature = "avx"))] @@ -221,32 +200,14 @@ pub fn encode_sub_plain(bytes: &[u8], n: usize, k: usize) -> Result= size); + let mask = ONEMASK as Wide; let mut depart_no = 1_usize; while depart_no < size { diff --git a/reed-solomon-novelpoly/src/field/inc_reconstruct.rs b/reed-solomon-novelpoly/src/field/inc_reconstruct.rs index 87dfacb..cc9b140 100644 --- a/reed-solomon-novelpoly/src/field/inc_reconstruct.rs +++ b/reed-solomon-novelpoly/src/field/inc_reconstruct.rs @@ -73,23 +73,11 @@ pub(crate) fn decode_main( codeword[i] = if erasure[i] { Additive(0) } else { codeword[i].mul(log_walsh2[i]) }; } - // SAFETY - // - // - safe because we check in `reconstruct_sub` that n is a power of two and we also check that - // codeword.len() is equal to n. - // - n is at most 65536. `index + size - 2` is therefore equal to 65536 - 2 = 65534 which - // is less than or equal to 65534. qed. - unsafe { inverse_afft(codeword, n, 0) }; + inverse_afft(codeword, n, 0); tweaked_formal_derivative(codeword); - // SAFETY - // - // - safe because we check in `reconstruct_sub` that n is a power of two and we also check that - // codeword.len() is equal to n. - // - n is at most 65536. `index + size - 2` is therefore equal to 65536 - 2 = 65534 which - // is less than or equal to 65534. qed. - unsafe { afft(codeword, n, 0) }; + afft(codeword, n, 0); for i in 0..recover_up_to { codeword[i] = if erasure[i] { codeword[i].mul(log_walsh2[i]) } else { Additive(0) }; @@ -101,6 +89,10 @@ pub(crate) fn decode_main( // since this has only to be called once per reconstruction pub fn eval_error_polynomial(erasure: &[bool], log_walsh2: &mut [Multiplier], n: usize) { let z = std::cmp::min(n, erasure.len()); + assert!(z <= erasure.len()); + assert!(n <= log_walsh2.len()); + assert!(z <= log_walsh2.len()); + for i in 0..z { log_walsh2[i] = Multiplier(erasure[i] as Elt); } diff --git a/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs b/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs index 5c90d9c..d8b6728 100644 --- a/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs +++ b/reed-solomon-novelpoly/src/novel_poly_basis/tests.rs @@ -70,12 +70,12 @@ fn flt_back_and_forth() { let mut data = (0..N).map(|_x| rand_gf_element()).collect::>(); let expected = data.clone(); - unsafe { afft(&mut data, N, N / 4) }; + afft(&mut data, N, N / 4); // make sure something is done assert!(data.iter().zip(expected.iter()).filter(|(a, b)| { a != b }).count() > 0); - unsafe { inverse_afft(&mut data, N, N / 4) }; + inverse_afft(&mut data, N, N / 4); itertools::assert_equal(data, expected); } @@ -314,7 +314,7 @@ fn flt_roundtrip_small() { let mut data = EXPECTED; - unsafe { f2e16::afft(&mut data, N, N / 4) }; + f2e16::afft(&mut data, N, N / 4); println!("novel basis(rust):"); data.iter().for_each(|sym| { @@ -322,7 +322,7 @@ fn flt_roundtrip_small() { }); println!(); - unsafe { f2e16::inverse_afft(&mut data, N, N / 4) }; + f2e16::inverse_afft(&mut data, N, N / 4); itertools::assert_equal(data.iter(), EXPECTED.iter()); }