From f0d96608174944e0e64b568ebc8a9f1c60cfe134 Mon Sep 17 00:00:00 2001
From: Anders Trier Olesen
Date: Wed, 9 Oct 2024 16:49:43 +0200
Subject: [PATCH] avx512: Avoid continuously reloading the look up table

---
 src/engine/engine_avx512.rs | 106 ++++++++++++++++++++++++------------
 1 file changed, 71 insertions(+), 35 deletions(-)

diff --git a/src/engine/engine_avx512.rs b/src/engine/engine_avx512.rs
index 5f5a5ef..7c2a38d 100644
--- a/src/engine/engine_avx512.rs
+++ b/src/engine/engine_avx512.rs
@@ -93,16 +93,74 @@ impl Default for Avx512 {
 //
 //
 
+#[derive(Copy, Clone)]
+struct Lut512 {
+    t0_t2_lo: __m512i,
+    t0_t2_hi: __m512i,
+    t1_t3_lo: __m512i,
+    t1_t3_hi: __m512i,
+}
+
+#[inline(always)]
+fn lut512(lut: &Multiply128lutT) -> Lut512 {
+    let t0_t2_lo: __m512i;
+    let t0_t2_hi: __m512i;
+    let t1_t3_lo: __m512i;
+    let t1_t3_hi: __m512i;
+
+    unsafe {
+        let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.lo[0] as *const u128 as *const __m128i,
+        ));
+        let t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.lo[1] as *const u128 as *const __m128i,
+        ));
+        let t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.lo[2] as *const u128 as *const __m128i,
+        ));
+        let t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.lo[3] as *const u128 as *const __m128i,
+        ));
+
+        let t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.hi[0] as *const u128 as *const __m128i,
+        ));
+        let t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.hi[1] as *const u128 as *const __m128i,
+        ));
+        let t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.hi[2] as *const u128 as *const __m128i,
+        ));
+        let t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
+            &lut.hi[3] as *const u128 as *const __m128i,
+        ));
+
+        t0_t2_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t0_lo), t2_lo, 1);
+        t0_t2_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t0_hi), t2_hi, 1);
+        t1_t3_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t1_lo), t3_lo, 1);
+        t1_t3_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t1_hi), t3_hi, 1);
+    }
+
+    Lut512 {
+        t0_t2_lo,
+        t0_t2_hi,
+        t1_t3_lo,
+        t1_t3_hi,
+    }
+}
+
 impl Avx512 {
     #[target_feature(enable = "avx512f,avx512vl,avx512bw")]
     unsafe fn mul_avx512(&self, x: &mut [[u8; 64]], log_m: GfElement) {
         let lut = &self.mul128[log_m as usize];
 
+        let lut512 = lut512(lut);
+
         for chunk in x.iter_mut() {
             let x_ptr = chunk.as_mut_ptr() as *mut i32;
             unsafe {
                 let x = _mm512_loadu_si512(x_ptr);
-                let prod = Self::mul_512(x, lut);
+                let prod = Self::mul_512(x, lut512);
                 _mm512_storeu_si512(x_ptr, prod);
             }
         }
@@ -110,38 +168,12 @@ impl Avx512 {
 
     // Impelemntation of LEO_MUL_256
     #[inline(always)]
-    fn mul_512(value: __m512i, lut: &Multiply128lutT) -> __m512i {
+    fn mul_512(value: __m512i, lut: Lut512) -> __m512i {
         unsafe {
-            let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[0] as *const u128 as *const __m128i,
-            ));
-            let t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[1] as *const u128 as *const __m128i,
-            ));
-            let t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[2] as *const u128 as *const __m128i,
-            ));
-            let t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.lo[3] as *const u128 as *const __m128i,
-            ));
-
-            let t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[0] as *const u128 as *const __m128i,
-            ));
-            let t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[1] as *const u128 as *const __m128i,
-            ));
-            let t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[2] as *const u128 as *const __m128i,
-            ));
-            let t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
-                &lut.hi[3] as *const u128 as *const __m128i,
-            ));
-
-            let t0_t2_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t0_lo), t2_lo, 1);
-            let t0_t2_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t0_hi), t2_hi, 1);
-            let t1_t3_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t1_lo), t3_lo, 1);
-            let t1_t3_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t1_hi), t3_hi, 1);
+            let t0_t2_lo = lut.t0_t2_lo;
+            let t0_t2_hi = lut.t0_t2_hi;
+            let t1_t3_lo = lut.t1_t3_lo;
+            let t1_t3_hi = lut.t1_t3_hi;
 
             let clr_mask = _mm512_set1_epi8(0x0f);
 
@@ -171,7 +203,7 @@ impl Avx512 {
     // Implementation of LEO_MULADD_256
     #[allow(clippy::too_many_arguments)]
     #[inline(always)]
-    fn muladd_512(x: __m512i, y: __m512i, lut: &Multiply128lutT) -> __m512i {
+    fn muladd_512(x: __m512i, y: __m512i, lut: Lut512) -> __m512i {
         unsafe {
             let prod = Self::mul_512(y, lut);
             _mm512_xor_si512(x, prod)
@@ -189,6 +221,8 @@ impl Avx512 {
     fn fft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
         let lut = &self.mul128[log_m as usize];
 
+        let lut512 = lut512(lut);
+
         for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
             let x_ptr = x_chunk.as_mut_ptr() as *mut i32;
             let y_ptr = y_chunk.as_mut_ptr() as *mut i32;
@@ -197,7 +231,7 @@ impl Avx512 {
                 let mut x = _mm512_loadu_si512(x_ptr);
                 let mut y = _mm512_loadu_si512(y_ptr);
 
-                x = Self::muladd_512(x, y, lut);
+                x = Self::muladd_512(x, y, lut512);
                 y = _mm512_xor_si512(y, x);
 
                 _mm512_storeu_si512(x_ptr, x);
@@ -318,6 +352,8 @@ impl Avx512 {
     fn ifft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
         let lut = &self.mul128[log_m as usize];
 
+        let lut512 = lut512(lut);
+
         for (x_chunk, y_chunk) in zip(&mut x.iter_mut(), &mut y.iter_mut()) {
             let x_ptr = x_chunk.as_mut_ptr() as *mut i32;
             let y_ptr = y_chunk.as_mut_ptr() as *mut i32;
@@ -327,7 +363,7 @@ impl Avx512 {
                 let mut x = _mm512_loadu_si512(x_ptr);
                 let mut y = _mm512_loadu_si512(y_ptr);
                 y = _mm512_xor_si512(y, x);
-                x = Self::muladd_512(x, y, lut);
+                x = Self::muladd_512(x, y, lut512);
 
                 _mm512_storeu_si512(x_ptr, x);
                 _mm512_storeu_si512(y_ptr, y);