Skip to content

Commit

Permalink
avx512: Avoid continuously reloading the look up table
Browse files Browse the repository at this point in the history
  • Loading branch information
AndersTrier committed Oct 9, 2024
1 parent 945c03e commit f0d9660
Showing 1 changed file with 71 additions and 35 deletions.
106 changes: 71 additions & 35 deletions src/engine/engine_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,55 +93,87 @@ impl Default for Avx512 {
//
//

// Pre-widened multiplication lookup tables, built once per log_m by `lut512`
// so the hot loops don't reload/broadcast the 16-byte tables on every chunk.
// Each register packs two of the four nibble tables: the low 256-bit half
// holds t0 (resp. t1) and the high half holds t2 (resp. t3), with each
// 16-byte table broadcast across both 128-bit lanes of its half.
#[derive(Copy, Clone)]
struct Lut512 {
// t0 in lanes 0-1, t2 in lanes 2-3 (low-nibble tables)
t0_t2_lo: __m512i,
// t0 in lanes 0-1, t2 in lanes 2-3 (high-nibble tables)
t0_t2_hi: __m512i,
// t1 in lanes 0-1, t3 in lanes 2-3 (low-nibble tables)
t1_t3_lo: __m512i,
// t1 in lanes 0-1, t3 in lanes 2-3 (high-nibble tables)
t1_t3_hi: __m512i,
}

// Expands a `Multiply128lutT` (four 16-byte `lo` tables and four 16-byte `hi`
// tables, stored as u128s) into the four 512-bit registers of `Lut512`.
// Each 16-byte table is loaded and broadcast to both 128-bit lanes of a
// 256-bit register, then pairs of tables (t0+t2, t1+t3) are fused into one
// 512-bit register so a single shuffle can serve two tables at once.
#[inline(always)]
fn lut512(lut: &Multiply128lutT) -> Lut512 {
let t0_t2_lo: __m512i;
let t0_t2_hi: __m512i;
let t1_t3_lo: __m512i;
let t1_t3_hi: __m512i;

// SAFETY: each `&lut.lo[i]`/`&lut.hi[i]` is a valid, readable u128 (16
// bytes), which is exactly what `_mm_loadu_si128` (unaligned load) needs.
// NOTE(review): assumes the caller only runs this on AVX2/AVX-512-capable
// CPUs — confirm feature gating happens at the call sites.
unsafe {
// Broadcast each of the four low-nibble tables into both lanes of a
// 256-bit register.
let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[0] as *const u128 as *const __m128i,
));
let t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[1] as *const u128 as *const __m128i,
));
let t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[2] as *const u128 as *const __m128i,
));
let t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[3] as *const u128 as *const __m128i,
));

// Same for the four high-nibble tables.
let t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[0] as *const u128 as *const __m128i,
));
let t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[1] as *const u128 as *const __m128i,
));
let t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[2] as *const u128 as *const __m128i,
));
let t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[3] as *const u128 as *const __m128i,
));

// Fuse table pairs: t0/t1 occupy the low 256 bits, t2/t3 are inserted
// into the high 256 bits (insert index 1 = upper 4 x i64).
t0_t2_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t0_lo), t2_lo, 1);
t0_t2_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t0_hi), t2_hi, 1);
t1_t3_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t1_lo), t3_lo, 1);
t1_t3_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t1_hi), t3_hi, 1);
}

Lut512 {
t0_t2_lo,
t0_t2_hi,
t1_t3_lo,
t1_t3_hi,
}
}

impl Avx512 {
#[target_feature(enable = "avx512f,avx512vl,avx512bw")]
// Multiplies every 64-byte chunk of `x` in place by the GF element whose
// log is `log_m`, using the precomputed 128-entry table `self.mul128`.
//
// The lookup table is widened into 512-bit registers exactly once (via
// `lut512`) before the loop, instead of being reloaded for every chunk —
// that is the whole point of the `Lut512` struct.
unsafe fn mul_avx512(&self, x: &mut [[u8; 64]], log_m: GfElement) {
    let lut = &self.mul128[log_m as usize];

    // Hoisted out of the loop: broadcast/fuse the tables once.
    let lut512 = lut512(lut);

    for chunk in x.iter_mut() {
        let x_ptr = chunk.as_mut_ptr() as *mut i32;
        // SAFETY: `chunk` is exactly 64 bytes, matching one unaligned
        // 512-bit load and store.
        unsafe {
            let x = _mm512_loadu_si512(x_ptr);
            let prod = Self::mul_512(x, lut512);
            _mm512_storeu_si512(x_ptr, prod);
        }
    }
}

// Implementation of LEO_MUL_256
#[inline(always)]
fn mul_512(value: __m512i, lut: &Multiply128lutT) -> __m512i {
fn mul_512(value: __m512i, lut: Lut512) -> __m512i {
unsafe {
let t0_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[0] as *const u128 as *const __m128i,
));
let t1_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[1] as *const u128 as *const __m128i,
));
let t2_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[2] as *const u128 as *const __m128i,
));
let t3_lo = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.lo[3] as *const u128 as *const __m128i,
));

let t0_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[0] as *const u128 as *const __m128i,
));
let t1_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[1] as *const u128 as *const __m128i,
));
let t2_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[2] as *const u128 as *const __m128i,
));
let t3_hi = _mm256_broadcastsi128_si256(_mm_loadu_si128(
&lut.hi[3] as *const u128 as *const __m128i,
));

let t0_t2_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t0_lo), t2_lo, 1);
let t0_t2_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t0_hi), t2_hi, 1);
let t1_t3_lo = _mm512_inserti64x4(_mm512_castsi256_si512(t1_lo), t3_lo, 1);
let t1_t3_hi = _mm512_inserti64x4(_mm512_castsi256_si512(t1_hi), t3_hi, 1);
let t0_t2_lo = lut.t0_t2_lo;
let t0_t2_hi = lut.t0_t2_hi;
let t1_t3_lo = lut.t1_t3_lo;
let t1_t3_hi = lut.t1_t3_hi;

let clr_mask = _mm512_set1_epi8(0x0f);

Expand Down Expand Up @@ -171,7 +203,7 @@ impl Avx512 {
// Implementation of LEO_MULADD_256
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn muladd_512(x: __m512i, y: __m512i, lut: &Multiply128lutT) -> __m512i {
fn muladd_512(x: __m512i, y: __m512i, lut: Lut512) -> __m512i {
unsafe {
let prod = Self::mul_512(y, lut);
_mm512_xor_si512(x, prod)
Expand All @@ -189,6 +221,8 @@ impl Avx512 {
fn fft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];

let lut512 = lut512(lut);

for (x_chunk, y_chunk) in zip(x.iter_mut(), y.iter_mut()) {
let x_ptr = x_chunk.as_mut_ptr() as *mut i32;
let y_ptr = y_chunk.as_mut_ptr() as *mut i32;
Expand All @@ -197,7 +231,7 @@ impl Avx512 {
let mut x = _mm512_loadu_si512(x_ptr);
let mut y = _mm512_loadu_si512(y_ptr);

x = Self::muladd_512(x, y, lut);
x = Self::muladd_512(x, y, lut512);
y = _mm512_xor_si512(y, x);

_mm512_storeu_si512(x_ptr, x);
Expand Down Expand Up @@ -318,6 +352,8 @@ impl Avx512 {
fn ifft_butterfly_partial(&self, x: &mut [[u8; 64]], y: &mut [[u8; 64]], log_m: GfElement) {
let lut = &self.mul128[log_m as usize];

let lut512 = lut512(lut);

for (x_chunk, y_chunk) in zip(&mut x.iter_mut(), &mut y.iter_mut()) {
let x_ptr = x_chunk.as_mut_ptr() as *mut i32;
let y_ptr = y_chunk.as_mut_ptr() as *mut i32;
Expand All @@ -327,7 +363,7 @@ impl Avx512 {
let mut y = _mm512_loadu_si512(y_ptr);

y = _mm512_xor_si512(y, x);
x = Self::muladd_512(x, y, lut);
x = Self::muladd_512(x, y, lut512);

_mm512_storeu_si512(x_ptr, x);
_mm512_storeu_si512(y_ptr, y);
Expand Down

0 comments on commit f0d9660

Please sign in to comment.