diff --git a/arith/src/extension_field/gf2_128x8.rs b/arith/src/extension_field/gf2_128x8.rs index e745affd..0236dc60 100644 --- a/arith/src/extension_field/gf2_128x8.rs +++ b/arith/src/extension_field/gf2_128x8.rs @@ -13,4 +13,3 @@ cfg_if::cfg_if! { pub type GF2_128x8 = avx::AVX512GF2_128x8; } } - diff --git a/arith/src/extension_field/gf2_128x8/avx256.rs b/arith/src/extension_field/gf2_128x8/avx256.rs index e765a126..1a23f7f4 100644 --- a/arith/src/extension_field/gf2_128x8/avx256.rs +++ b/arith/src/extension_field/gf2_128x8/avx256.rs @@ -21,10 +21,11 @@ field_common!(AVX256GF2_128x8); impl AVX256GF2_128x8 { #[inline(always)] pub(crate) fn pack_full(data: __m128i) -> [__m256i; 4] { - [unsafe { _mm256_broadcast_i32x4(data) }, - unsafe { _mm256_broadcast_i32x4(data) }, - unsafe { _mm256_broadcast_i32x4(data) }, - unsafe { _mm256_broadcast_i32x4(data) }, + [ + unsafe { _mm256_broadcast_i32x4(data) }, + unsafe { _mm256_broadcast_i32x4(data) }, + unsafe { _mm256_broadcast_i32x4(data) }, + unsafe { _mm256_broadcast_i32x4(data) }, ] } @@ -86,14 +87,7 @@ const PACKED_0: [__m256i; 4] = [ unsafe { transmute::<[i32; 8], std::arch::x86_64::__m256i>([0; 8]) }, unsafe { transmute::<[i32; 8], std::arch::x86_64::__m256i>([0; 8]) }, ]; -const _M256_INV_2: __m256i = unsafe { - transmute([ - 67_u64, - (1_u64) << 63, - 67_u64, - (1_u64) << 63, - ]) -}; +const _M256_INV_2: __m256i = unsafe { transmute([67_u64, (1_u64) << 63, 67_u64, (1_u64) << 63]) }; const PACKED_INV_2: [__m256i; 4] = [_M256_INV_2, _M256_INV_2, _M256_INV_2, _M256_INV_2]; // Should not be used? // p(x) = x^128 + x^7 + x^2 + x + 1 @@ -124,7 +118,9 @@ impl Field for AVX256GF2_128x8 { fn zero() -> Self { unsafe { let zero = _mm256_setzero_si256(); - Self { data: [zero, zero, zero, zero] } + Self { + data: [zero, zero, zero, zero], + } } } @@ -132,8 +128,10 @@ impl Field for AVX256GF2_128x8 { fn is_zero(&self) -> bool { unsafe { let zero = _mm256_setzero_si256(); - let cmp_0 = _mm256_cmpeq_epi64_mask(self.data[0], zero) & _mm256_cmpeq_epi64_mask(self.data[1], zero); - let cmp_1 = _mm256_cmpeq_epi64_mask(self.data[2], zero) & _mm256_cmpeq_epi64_mask(self.data[3], zero); + let cmp_0 = _mm256_cmpeq_epi64_mask(self.data[0], zero) + & _mm256_cmpeq_epi64_mask(self.data[1], zero); + let cmp_1 = _mm256_cmpeq_epi64_mask(self.data[2], zero) + & _mm256_cmpeq_epi64_mask(self.data[3], zero); (cmp_0 & cmp_1) == 0xF // All 16 64-bit integers are equal (zero) } } @@ -142,7 +140,9 @@ impl Field for AVX256GF2_128x8 { fn one() -> Self { unsafe { let one = _mm256_set_epi64x(0, 1, 0, 1); - Self { data: [one, one, one, one] } + Self { + data: [one, one, one, one], + } } } @@ -404,7 +404,9 @@ impl From for AVX256GF2_128x8 { fn from(v: u32) -> AVX256GF2_128x8 { assert!(v < 2); // only 0 and 1 are allowed let data = unsafe { _mm256_set_epi64x(0, v as i64, 0, v as i64) }; - AVX256GF2_128x8 { data: [data, data, data, data] } + AVX256GF2_128x8 { + data: [data, data, data, data], + } } } @@ -422,9 +424,9 @@ impl Debug for AVX256GF2_128x8 { let mut data = [0u8; 128]; unsafe { _mm256_storeu_si256(data.as_mut_ptr() as *mut __m256i, self.data[0]); - _mm256_storeu_si256((data.as_mut_ptr() as *mut __m256i).offset(8), self.data[1]); - _mm256_storeu_si256((data.as_mut_ptr() as *mut __m256i).offset(16), self.data[2]); - _mm256_storeu_si256((data.as_mut_ptr() as *mut __m256i).offset(24), self.data[3]); + _mm256_storeu_si256((data.as_mut_ptr() as *mut __m256i).offset(1), self.data[1]); + _mm256_storeu_si256((data.as_mut_ptr() as *mut __m256i).offset(2), self.data[2]); + _mm256_storeu_si256((data.as_mut_ptr() as *mut __m256i).offset(3), self.data[3]); } f.debug_struct("AVX256GF2_128x8") .field("data", &data) @@ -436,8 +438,10 @@ impl PartialEq for AVX256GF2_128x8 { #[inline(always)] fn eq(&self, other: &Self) -> bool { unsafe { - let cmp_0 = _mm256_cmpeq_epi64_mask(self.data[0], other.data[0]) & _mm256_cmpeq_epi64_mask(self.data[1], other.data[1]); - let cmp_1 = _mm256_cmpeq_epi64_mask(self.data[2], other.data[2]) & _mm256_cmpeq_epi64_mask(self.data[3], other.data[3]); + let cmp_0 = _mm256_cmpeq_epi64_mask(self.data[0], other.data[0]) + & _mm256_cmpeq_epi64_mask(self.data[1], other.data[1]); + let cmp_1 = _mm256_cmpeq_epi64_mask(self.data[2], other.data[2]) + & _mm256_cmpeq_epi64_mask(self.data[3], other.data[3]); (cmp_0 & cmp_1) == 0xF // All 16 64-bit integers are equal } } @@ -504,16 +508,8 @@ fn sub_internal(a: &AVX256GF2_128x8, b: &AVX256GF2_128x8) -> AVX256GF2_128x8 { #[inline] fn _m256_mul_internal(a: __m256i, b: __m256i) -> __m256i { unsafe { - let xmmmask = _mm256_set_epi32( - 0, - 0, - 0, - 0xffffffffu32 as i32, - 0, - 0, - 0, - 0xffffffffu32 as i32, - ); + let xmmmask = + _mm256_set_epi32(0, 0, 0, 0xffffffffu32 as i32, 0, 0, 0, 0xffffffffu32 as i32); let mut tmp3 = _mm256_clmulepi64_epi128(a, b, 0x00); let mut tmp6 = _mm256_clmulepi64_epi128(a, b, 0x11); @@ -632,14 +628,10 @@ impl ExtensionField for AVX256GF2_128x8 { let v7 = (base.v & 1u8) as i64; let mut res = *self; - res.data[0] = - unsafe { _mm256_xor_si256(res.data[0], _mm256_set_epi64x(0, v0, 0, v2)) }; - res.data[1] = - unsafe { _mm256_xor_si256(res.data[1], _mm256_set_epi64x(0, v4, 0, v6)) }; - res.data[2] = - unsafe { _mm256_xor_si256(res.data[2], _mm256_set_epi64x(0, v1, 0, v3,)) }; - res.data[3] = - unsafe { _mm256_xor_si256(res.data[3], _mm256_set_epi64x(0, v5, 0, v7)) }; + res.data[0] = unsafe { _mm256_xor_si256(res.data[0], _mm256_set_epi64x(0, v0, 0, v2)) }; + res.data[1] = unsafe { _mm256_xor_si256(res.data[1], _mm256_set_epi64x(0, v4, 0, v6)) }; + res.data[2] = unsafe { _mm256_xor_si256(res.data[2], _mm256_set_epi64x(0, v1, 0, v3)) }; + res.data[3] = unsafe { _mm256_xor_si256(res.data[3], _mm256_set_epi64x(0, v5, 0, v7)) }; res } diff --git a/arith/src/field/m31/m31_avx256.rs b/arith/src/field/m31/m31_avx256.rs index e6210fbc..637de711 100644 --- a/arith/src/field/m31/m31_avx256.rs +++ b/arith/src/field/m31/m31_avx256.rs @@ -11,10 +11,10 @@ use rand::{Rng, RngCore}; use crate::{field_common, Field, FieldSerde, FieldSerdeResult, SimdField, M31, M31_MOD}; -const M31_PACK_SIZE: usize = 16; -const PACKED_MOD: __m256i = unsafe { transmute([M31_MOD; M31_PACK_SIZE / 2]) }; -const PACKED_0: __m256i = unsafe { transmute([0; M31_PACK_SIZE / 2]) }; -const PACKED_INV_2: __m256i = unsafe { transmute([1 << 30; M31_PACK_SIZE/ 2]) }; +const M31_PACK_SIZE: usize = 16; +const PACKED_MOD: __m256i = unsafe { transmute([M31_MOD; M31_PACK_SIZE / 2]) }; +const PACKED_0: __m256i = unsafe { transmute([0; M31_PACK_SIZE / 2]) }; +const PACKED_INV_2: __m256i = unsafe { transmute([1 << 30; M31_PACK_SIZE / 2]) }; #[inline(always)] unsafe fn mod_reduce_epi32(x: __m256i) -> __m256i { @@ -85,20 +85,24 @@ impl Field for AVXM31 { // size in bytes const SIZE: usize = 512 / 8; - const ZERO: Self = Self { v: [PACKED_0, PACKED_0] }; + const ZERO: Self = Self { + v: [PACKED_0, PACKED_0], + }; const ONE: Self = Self { - v: unsafe { transmute::<[u32; 16], [__m256i; 2]>([1; M31_PACK_SIZE]) }, + v: unsafe { transmute::<[u32; 16], [__m256i; 2]>([1; M31_PACK_SIZE]) }, }; - const INV_2: Self = Self { v: [PACKED_INV_2, PACKED_INV_2] }; + const INV_2: Self = Self { + v: [PACKED_INV_2, PACKED_INV_2], + }; const FIELD_SIZE: usize = 32; #[inline(always)] fn zero() -> Self { AVXM31 { - v: unsafe { [_mm256_set1_epi32(0), _mm256_set1_epi32(0)] }, + v: unsafe { [_mm256_set1_epi32(0), _mm256_set1_epi32(0)] }, } } @@ -106,8 +110,10 @@ impl Field for AVXM31 { fn is_zero(&self) -> bool { // value is either zero or 0x7FFFFFFF unsafe { - let pcmp = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_0) & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_0); - let pcmp2 = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_MOD) & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_MOD); + let pcmp = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_0) + & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_0); + let pcmp2 = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_MOD) + & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_MOD); (pcmp | pcmp2) == 0xFF } } @@ -115,7 +121,7 @@ impl Field for AVXM31 { #[inline(always)] fn one() -> Self { AVXM31 { - v: unsafe { [_mm256_set1_epi32(1), _mm256_set1_epi32(1)] }, + v: unsafe { [_mm256_set1_epi32(1), _mm256_set1_epi32(1)] }, } } @@ -123,7 +129,7 @@ impl Field for AVXM31 { // this function is for internal testing only. it is not // a source for uniformly random field elements and // should not be used in production. - fn random_unsafe(mut rng: impl RngCore) -> Self { + fn random_unsafe(mut rng: impl RngCore) -> Self { // Caution: this may not produce uniformly random elements unsafe { let mut v = [ @@ -146,7 +152,7 @@ impl Field for AVXM31 { rng.gen::(), rng.gen::(), rng.gen::(), - ) + ), ]; v = mod_reduce_epi32_2(v); v = mod_reduce_epi32_2(v); @@ -161,26 +167,26 @@ impl Field for AVXM31 { AVXM31 { v: unsafe { [ - _mm256_setr_epi32( - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - ), - _mm256_setr_epi32( - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - ) + _mm256_setr_epi32( + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + ), + _mm256_setr_epi32( + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + rng.gen::() as i32, + ), ] }, } @@ -198,9 +204,24 @@ impl Field for AVXM31 { #[inline(always)] // modified fn mul_by_5(&self) -> AVXM31 { - let double = unsafe { mod_reduce_epi32_2([_mm256_slli_epi32::<1>(self.v[0]), _mm256_slli_epi32::<1>(self.v[1])]) }; - let quad = unsafe { mod_reduce_epi32_2([_mm256_slli_epi32::<1>(double[0]), _mm256_slli_epi32::<1>(double[1])]) }; - let res = unsafe { mod_reduce_epi32_2([_mm256_add_epi32(self.v[0], quad[0]), _mm256_add_epi32(self.v[1], quad[1])]) }; + let double = unsafe { + mod_reduce_epi32_2([ + _mm256_slli_epi32::<1>(self.v[0]), + _mm256_slli_epi32::<1>(self.v[1]), + ]) + }; + let quad = unsafe { + mod_reduce_epi32_2([ + _mm256_slli_epi32::<1>(double[0]), + _mm256_slli_epi32::<1>(double[1]), + ]) + }; + let res = unsafe { + mod_reduce_epi32_2([ + _mm256_add_epi32(self.v[0], quad[0]), + _mm256_add_epi32(self.v[1], quad[1]), + ]) + }; Self { v: res } } @@ -227,14 +248,24 @@ impl Field for AVXM31 { fn from_uniform_bytes(bytes: &[u8; 32]) -> Self { let m = M31::from_uniform_bytes(bytes); Self { - v: unsafe {[ _mm256_set1_epi32(m.v as i32), _mm256_set1_epi32(m.v as i32)] }, + v: unsafe { [_mm256_set1_epi32(m.v as i32), _mm256_set1_epi32(m.v as i32)] }, } } #[inline(always)] fn mul_by_3(&self) -> AVXM31 { - let double = unsafe { mod_reduce_epi32_2([_mm256_slli_epi32::<1>(self.v[0]), _mm256_slli_epi32::<1>(self.v[1])]) }; - let res = unsafe { mod_reduce_epi32_2([_mm256_add_epi32(self.v[0], double[0]), _mm256_add_epi32(self.v[1], double[1])]) }; + let double = unsafe { + mod_reduce_epi32_2([ + _mm256_slli_epi32::<1>(self.v[0]), + _mm256_slli_epi32::<1>(self.v[1]), + ]) + }; + let res = unsafe { + mod_reduce_epi32_2([ + _mm256_add_epi32(self.v[0], double[0]), + _mm256_add_epi32(self.v[1], double[1]), + ]) + }; Self { v: res } } } @@ -292,7 +323,7 @@ impl Default for AVXM31 { impl PartialEq for AVXM31 { #[inline(always)] - fn eq(&self, other: &Self) -> bool { + fn eq(&self, other: &Self) -> bool { unsafe { let cmp0 = _mm256_cmpeq_epi32_mask(self.v[0], other.v[0]); let cmp1 = _mm256_cmpeq_epi32_mask(self.v[1], other.v[1]); @@ -303,7 +334,7 @@ impl PartialEq for AVXM31 { #[inline] #[must_use] -fn mask_movehdup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { +fn mask_movehdup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { // The instruction is only available in the floating-point flavor; this distinction is only for // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. unsafe { @@ -315,7 +346,7 @@ fn mask_movehdup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { #[inline] #[must_use] -fn mask_moveldup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { +fn mask_moveldup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { // The instruction is only available in the floating-point flavor; this distinction is only for // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. unsafe { @@ -342,11 +373,12 @@ impl Mul<&M31> for AVXM31 { type Output = AVXM31; #[inline(always)] - fn mul(self, rhs: &M31) -> Self::Output { + fn mul(self, rhs: &M31) -> Self::Output { let rhsv = AVXM31::pack_full(*rhs); unsafe { let mut res: [__m256i; 2] = [_mm256_setzero_si256(); 2]; - for i in 0..2 { + #[allow(clippy::needless_range_loop)] + for i in 0..res.len() { let rhs_evn = rhsv.v[i]; let lhs_odd_dbl = _mm256_srli_epi64(self.v[i], 31); let lhs_evn_dbl = _mm256_add_epi32(self.v[i], self.v[i]); @@ -395,39 +427,62 @@ impl From for AVXM31 { impl Neg for AVXM31 { type Output = AVXM31; #[inline(always)] - fn neg(self) -> Self::Output { + fn neg(self) -> Self::Output { AVXM31 { - v: unsafe { [_mm256_xor_epi32(self.v[0], PACKED_MOD), _mm256_xor_epi32(self.v[1], PACKED_MOD)] }, + v: unsafe { + [ + _mm256_xor_epi32(self.v[0], PACKED_MOD), + _mm256_xor_epi32(self.v[1], PACKED_MOD), + ] + }, } } } #[inline] #[must_use] -fn movehdup_epi32(x: __m256i) -> __m256i { +fn movehdup_epi32(x: __m256i) -> __m256i { // The instruction is only available in the floating-point flavor; this distinction is only for // historical reasons and no longer matters. We cast to floats, duplicate, and cast back. unsafe { _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))) } } #[inline(always)] -fn add_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { +fn add_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { unsafe { - let mut result = [_mm256_add_epi32(a.v[0], b.v[0]), _mm256_add_epi32(a.v[1], b.v[1])]; - let subx = [_mm256_sub_epi32(result[0], PACKED_MOD), _mm256_sub_epi32(result[1], PACKED_MOD)]; - result = [_mm256_min_epu32(result[0], subx[0]), _mm256_min_epu32(result[1], subx[1])]; + let mut result = [ + _mm256_add_epi32(a.v[0], b.v[0]), + _mm256_add_epi32(a.v[1], b.v[1]), + ]; + let subx = [ + _mm256_sub_epi32(result[0], PACKED_MOD), + _mm256_sub_epi32(result[1], PACKED_MOD), + ]; + result = [ + _mm256_min_epu32(result[0], subx[0]), + _mm256_min_epu32(result[1], subx[1]), + ]; AVXM31 { v: result } } } #[inline(always)] -fn sub_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { +fn sub_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { AVXM31 { v: unsafe { - let t = [_mm256_sub_epi32(a.v[0], b.v[0]), _mm256_sub_epi32(a.v[1], b.v[1])]; - let subx = [_mm256_add_epi32(t[0], PACKED_MOD), _mm256_add_epi32(t[1], PACKED_MOD)]; - [_mm256_min_epu32(t[0], subx[0]), _mm256_min_epu32(t[1], subx[1])] + let t = [ + _mm256_sub_epi32(a.v[0], b.v[0]), + _mm256_sub_epi32(a.v[1], b.v[1]), + ]; + let subx = [ + _mm256_add_epi32(t[0], PACKED_MOD), + _mm256_add_epi32(t[1], PACKED_MOD), + ]; + [ + _mm256_min_epu32(t[0], subx[0]), + _mm256_min_epu32(t[1], subx[1]), + ] }, } } @@ -437,7 +492,8 @@ fn mul_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { // credit: https://github.com/Plonky3/Plonky3/blob/eeb4e37b20127c4daa871b2bad0df30a7c7380db/mersenne-31/src/x86_64_avx2/packing.rs#L154 unsafe { let mut res: [__m256i; 2] = [_mm256_setzero_si256(); 2]; - for i in 0..2{ + #[allow(clippy::needless_range_loop)] + for i in 0..res.len() { let rhs_evn = b.v[i]; let lhs_odd_dbl = _mm256_srli_epi64(a.v[i], 31); let lhs_evn_dbl = _mm256_add_epi32(a.v[i], a.v[i]); diff --git a/arith/src/field/m31/m31_avx256_err.rs b/arith/src/field/m31/m31_avx256_err.rs deleted file mode 100644 index 33484fff..00000000 --- a/arith/src/field/m31/m31_avx256_err.rs +++ /dev/null @@ -1,494 +0,0 @@ -use std::{ - arch::x86_64::*, - fmt::Debug, - io::{Read, Write}, - iter::{Product, Sum}, - mem::transmute, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; - -use rand::{Rng, RngCore}; - -use crate::{field_common, Field, FieldSerde, FieldSerdeResult, SimdField, M31, M31_MOD}; - -const M31_PACK_SIZE: usize = 16; -const PACKED_MOD: __m256i = unsafe { transmute([M31_MOD; M31_PACK_SIZE / 2]) }; -const PACKED_0: __m256i = unsafe { transmute([0; M31_PACK_SIZE / 2]) }; -const PACKED_INV_2: __m256i = unsafe { transmute([1 << 30; M31_PACK_SIZE/ 2]) }; - -#[inline(always)] -unsafe fn mod_reduce_epi32(x: __m256i) -> __m256i { - _mm256_add_epi32(_mm256_and_si256(x, PACKED_MOD), _mm256_srli_epi32(x, 31)) -} - -#[inline(always)] -unsafe fn mod_reduce_epi32_2(x: [__m256i; 2]) -> [__m256i; 2] { - [mod_reduce_epi32(x[0]), mod_reduce_epi32(x[1])] -} - -#[derive(Clone, Copy)] -pub struct AVXM31 { - pub v: [__m256i; 2], -} - -impl AVXM31 { - #[inline(always)] - pub(crate) fn pack_full(x: M31) -> AVXM31 { - AVXM31 { - v: unsafe { [_mm256_set1_epi32(x.v as i32), _mm256_set1_epi32(x.v as i32)] }, - } - } -} - -field_common!(AVXM31); - -impl FieldSerde for AVXM31 { - const SERIALIZED_SIZE: usize = 512 / 8; - - #[inline(always)] - /// serialize self into bytes - fn serialize_into(&self, mut writer: W) -> FieldSerdeResult<()> { - let data = unsafe { transmute::<[__m256i; 2], [u8; 64]>(self.v) }; - writer.write_all(&data)?; - Ok(()) - } - - /// deserialize bytes into field - #[inline(always)] - fn deserialize_from(mut reader: R) -> FieldSerdeResult { - let mut data = [0; Self::SERIALIZED_SIZE]; - reader.read_exact(&mut data)?; - unsafe { - let mut value = transmute::<[u8; Self::SERIALIZED_SIZE], [__m256i; 2]>(data); - value = mod_reduce_epi32_2(value); - Ok(AVXM31 { v: value }) - } - } - - #[inline(always)] - fn try_deserialize_from_ecc_format(mut reader: R) -> FieldSerdeResult { - let mut buf = [0u8; 32]; - reader.read_exact(&mut buf)?; - assert!( - buf.iter().skip(4).all(|&x| x == 0), - "non-zero byte found in witness byte" - ); - Ok(Self::pack_full( - u32::from_le_bytes(buf[..4].try_into().unwrap()).into(), - )) - } -} - -impl Field for AVXM31 { - const NAME: &'static str = "AVX Packed Mersenne 31"; - - // size in bytes - const SIZE: usize = 512 / 8; - - const ZERO: Self = Self { v: [PACKED_0, PACKED_0] }; - - const ONE: Self = Self { - v: unsafe { transmute::<[u32; 16], [__m256i; 2]>([1; M31_PACK_SIZE]) }, - }; - - const INV_2: Self = Self { v: [PACKED_INV_2, PACKED_INV_2] }; - - const FIELD_SIZE: usize = 32; - - #[inline(always)] - fn zero() -> Self { - AVXM31 { - v: unsafe { [_mm256_set1_epi32(0), _mm256_set1_epi32(0)] }, - } - } - - #[inline(always)] - fn is_zero(&self) -> bool { - // value is either zero or 0x7FFFFFFF - unsafe { - let pcmp = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_0) & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_0); - let pcmp2 = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_MOD) & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_MOD); - (pcmp | pcmp2) == 0xFF - } - } - - #[inline(always)] - fn one() -> Self { - AVXM31 { - v: unsafe { [_mm256_set1_epi32(1), _mm256_set1_epi32(1)] }, - } - } - - #[inline(always)] - // this function is for internal testing only. it is not - // a source for uniformly random field elements and - // should not be used in production. - fn random_unsafe(mut rng: impl RngCore) -> Self { - // Caution: this may not produce uniformly random elements - unsafe { - let mut v = [ - _mm256_setr_epi32( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - ), - _mm256_setr_epi32( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - ) - ]; - v = mod_reduce_epi32_2(v); - v = mod_reduce_epi32_2(v); - AVXM31 { v } - } - } - - #[inline(always)] - // modified - fn random_bool(mut rng: impl RngCore) -> Self { - // TODO: optimize this code - AVXM31 { - v: unsafe { - [ - _mm256_setr_epi32( - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - ), - _mm256_setr_epi32( - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - ) - ] - }, - } - } - - fn exp(&self, _exponent: u128) -> Self { - unimplemented!("exp not implemented for AVXM31") - } - - #[inline(always)] - fn double(&self) -> Self { - self.mul_by_2() - } - - #[inline(always)] - // modified - fn mul_by_5(&self) -> AVXM31 { - let double = unsafe { mod_reduce_epi32_2([_mm256_slli_epi32::<1>(self.v[0]), _mm256_slli_epi32::<1>(self.v[1])]) }; - let quad = unsafe { mod_reduce_epi32_2([_mm256_slli_epi32::<1>(double[0]), _mm256_slli_epi32::<1>(double[1])]) }; - let res = unsafe { mod_reduce_epi32_2([_mm256_add_epi32(self.v[0], quad[0]), _mm256_add_epi32(self.v[1], quad[1])]) }; - Self { v: res } - } - - #[inline(always)] - fn inv(&self) -> Option { - // slow, should not be used in production - let mut m31_vec = unsafe { transmute::<[__m256i; 2], [M31; 16]>(self.v) }; - let is_non_zero = m31_vec.iter().all(|x| !x.is_zero()); - if !is_non_zero { - return None; - } - - m31_vec.iter_mut().for_each(|x| *x = x.inv().unwrap()); // safe unwrap - Some(Self { - v: unsafe { transmute::<[M31; 16], [__m256i; 2]>(m31_vec) }, - }) - } - - fn as_u32_unchecked(&self) -> u32 { - unimplemented!("self is a vector, cannot convert to u32") - } - - #[inline] - fn from_uniform_bytes(bytes: &[u8; 32]) -> Self { - let m = M31::from_uniform_bytes(bytes); - Self { - v: unsafe {[ _mm256_set1_epi32(m.v as i32), _mm256_set1_epi32(m.v as i32)] }, - } - } - - #[inline(always)] - // modified - fn mul_by_3(&self) -> AVXM31 { - let double = unsafe { mod_reduce_epi32_2([_mm256_slli_epi32::<1>(self.v[0]), _mm256_slli_epi32::<1>(self.v[1])]) }; - let res = unsafe { mod_reduce_epi32_2([_mm256_add_epi32(self.v[0], double[0]), _mm256_add_epi32(self.v[1], double[1])]) }; - Self { v: res } - } -} - -impl SimdField for AVXM31 { - type Scalar = M31; - - #[inline] - fn scale(&self, challenge: &Self::Scalar) -> Self { - *self * *challenge - } - - #[inline(always)] - fn pack_size() -> usize { - M31_PACK_SIZE - } -} - -impl From for AVXM31 { - #[inline(always)] - fn from(x: M31) -> Self { - AVXM31::pack_full(x) - } -} - -impl Debug for AVXM31 { - // modified - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut data = [0; M31_PACK_SIZE]; - unsafe { - _mm256_storeu_si256(data.as_mut_ptr() as *mut __m256i, self.v[0]); - _mm256_storeu_si256(data.as_mut_ptr().add(8) as *mut __m256i, self.v[1]); - } - // if all data is the same, print only one - if data.iter().all(|&x| x == data[0]) { - write!( - f, - "mm256i<8 x {}>", - if M31_MOD - data[0] > 1024 { - format!("{}", data[0]) - } else { - format!("-{}", M31_MOD - data[0]) - } - ) - } else { - write!(f, "mm256i<{:?}>", data) - } - } -} - -impl Default for AVXM31 { - fn default() -> Self { - AVXM31::zero() - } -} - -impl PartialEq for AVXM31 { - #[inline(always)] - fn eq(&self, other: &Self) -> bool { - unsafe { - let cmp0 = _mm256_cmpeq_epi32_mask(self.v[0], other.v[0]); - let cmp1 = _mm256_cmpeq_epi32_mask(self.v[1], other.v[1]); -if (cmp0 & cmp1) != 0xFF { - print_m256i_bits(self.v[0]); - print_m256i_bits(self.v[1]); - print_m256i_bits(other.v[0]); - print_m256i_bits(other.v[1]); - println!("======= {} {}", cmp0, cmp1); -} - (cmp0 & cmp1) == 0xFF - } - } -} - -#[inline] -#[must_use] -fn mask_movehdup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { - // The instruction is only available in the floating-point flavor; this distinction is only for - // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. - unsafe { - let src = _mm256_castsi256_ps(src); - let a = _mm256_castsi256_ps(a); - _mm256_castps_si256(_mm256_mask_movehdup_ps(src, k, a)) - } -} - -#[inline] -#[must_use] -fn mask_moveldup_epi32(src: __m256i, k: __mmask8, a: __m256i) -> __m256i { - // The instruction is only available in the floating-point flavor; this distinction is only for - // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. - unsafe { - let src = _mm256_castsi256_ps(src); - let a = _mm256_castsi256_ps(a); - _mm256_castps_si256(_mm256_mask_moveldup_ps(src, k, a)) - } -} - -use std::arch::x86_64::*; // 确保在x86_64架构上使用 - -fn print_m256i_bits(value: __m256i) { - println!("__m256i: {:?}", m256i2arr(value)); -} - -fn m256i2arr(value: __m256i) -> [u32; 8]{ - let mut arr = [0u32; 8]; - unsafe { - _mm256_storeu_si256(arr.as_mut_ptr() as *mut __m256i, value); - } - arr -} - -#[inline] -#[must_use] -fn add(lhs: __m256i, rhs: __m256i) -> __m256i { -// print_m256i_bits(lhs); -// print_m256i_bits(rhs); - unsafe { - let t = _mm256_add_epi32(lhs, rhs); - let u = _mm256_sub_epi32(t, PACKED_MOD); -//println!("{:?} + {:?} == {:?} ? {:?}", m256i2arr(lhs), m256i2arr(rhs), m256i2arr(t), m256i2arr(u)); -let t_ = m256i2arr(t); -let u_ = m256i2arr(u); -if t_[4] == 2609101268u32 || u_[4] == 2609101268u32 || t_[4] == 1062346703u32 || u_[4] == 1062346703u32{ - print_m256i_bits(lhs); - print_m256i_bits(rhs); - print_m256i_bits(t); - print_m256i_bits(u); - println!("======="); -} - _mm256_min_epu32(t, u) - } -} - -const EVENS: __mmask8 = 0b01010101; -const ODDS: __mmask8 = 0b10101010; - -impl Mul<&M31> for AVXM31 { - type Output = AVXM31; - - #[inline(always)] - fn mul(self, rhs: &M31) -> Self::Output { - let rhsv = AVXM31::pack_full(*rhs); - unsafe { - let mut res: [__m256i; 2] = [_mm256_setzero_si256(); 2]; - for i in 0..2 { - let rhs_evn = rhsv.v[i]; - let lhs_odd_dbl = _mm256_srli_epi64(self.v[i], 31); - let lhs_evn_dbl = _mm256_add_epi32(self.v[i], self.v[i]); - let rhs_odd = movehdup_epi32(rhsv.v[i]); - - let prod_odd_dbl = _mm256_mul_epu32(lhs_odd_dbl, rhs_odd); - let prod_evn_dbl = _mm256_mul_epu32(lhs_evn_dbl, rhs_evn); - - let prod_lo_dbl = mask_moveldup_epi32(prod_evn_dbl, ODDS, prod_odd_dbl); - let prod_hi = mask_movehdup_epi32(prod_odd_dbl, EVENS, prod_evn_dbl); - // Right shift to undo the doubling. - let prod_lo = _mm256_srli_epi32::<1>(prod_lo_dbl); - - // Standard addition of two 31-bit values. - res[i] = add(prod_lo, prod_hi); - } - AVXM31 { v: res } - } - } -} - -impl Mul for AVXM31 { - type Output = AVXM31; - #[inline(always)] - fn mul(self, rhs: M31) -> Self::Output { - self * &rhs - } -} - -impl Add for AVXM31 { - type Output = AVXM31; - #[inline(always)] - #[allow(clippy::op_ref)] - fn add(self, rhs: M31) -> Self::Output { - self + AVXM31::pack_full(rhs) - } -} - -impl From for AVXM31 { - #[inline(always)] - fn from(x: u32) -> Self { - AVXM31::pack_full(M31::from(x)) - } -} - -impl Neg for AVXM31 { - type Output = AVXM31; - #[inline(always)] - fn neg(self) -> Self::Output { - AVXM31 { - v: unsafe { [_mm256_xor_epi32(self.v[0], PACKED_MOD), _mm256_xor_epi32(self.v[1], PACKED_MOD)] }, - } - } -} - -#[inline] -#[must_use] -fn movehdup_epi32(x: __m256i) -> __m256i { - // The instruction is only available in the floating-point flavor; this distinction is only for - // historical reasons and no longer matters. We cast to floats, duplicate, and cast back. - unsafe { _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(x))) } -} - -#[inline(always)] -fn add_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { - unsafe { - let mut result = [_mm256_add_epi32(a.v[0], b.v[0]), _mm256_add_epi32(a.v[1], b.v[1])]; - let subx = [_mm256_sub_epi32(result[0], PACKED_MOD), _mm256_sub_epi32(result[1], PACKED_MOD)]; - result = [_mm256_min_epu32(result[0], subx[0]), _mm256_min_epu32(result[1], subx[1])]; - - AVXM31 { v: result } - } -} - -#[inline(always)] -fn sub_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { - AVXM31 { - v: unsafe { - let t = [_mm256_sub_epi32(a.v[0], b.v[0]), _mm256_sub_epi32(a.v[1], b.v[1])]; - let subx = [_mm256_add_epi32(t[0], PACKED_MOD), _mm256_add_epi32(t[1], PACKED_MOD)]; - [_mm256_min_epu32(t[0], subx[0]), _mm256_min_epu32(t[1], subx[1])] - }, - } -} - -#[inline] -fn mul_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { - // credit: https://github.com/Plonky3/Plonky3/blob/eeb4e37b20127c4daa871b2bad0df30a7c7380db/mersenne-31/src/x86_64_avx2/packing.rs#L154 - unsafe { - let mut res: [__m256i; 2] = [_mm256_setzero_si256(); 2]; - for i in 0..2{ - let rhs_evn = b.v[i]; - let lhs_odd_dbl = _mm256_srli_epi64(a.v[i], 31); - let lhs_evn_dbl = _mm256_add_epi32(a.v[i], a.v[i]); - let rhs_odd = movehdup_epi32(b.v[i]); - - let prod_odd_dbl = _mm256_mul_epu32(lhs_odd_dbl, rhs_odd); - let prod_evn_dbl = _mm256_mul_epu32(lhs_evn_dbl, rhs_evn); - - let prod_lo_dbl = mask_moveldup_epi32(prod_evn_dbl, ODDS, prod_odd_dbl); - let prod_hi = mask_movehdup_epi32(prod_odd_dbl, EVENS, prod_evn_dbl); - // Right shift to undo the doubling. - let prod_lo = _mm256_srli_epi32::<1>(prod_lo_dbl); - - // Standard addition of two 31-bit values. - res[i] = add(prod_lo, prod_hi); - } - AVXM31 { v: res } - } -} diff --git a/arith/src/field/m31/m31_avx256_redraft.rs b/arith/src/field/m31/m31_avx256_redraft.rs deleted file mode 100644 index b3ac9bf5..00000000 --- a/arith/src/field/m31/m31_avx256_redraft.rs +++ /dev/null @@ -1,448 +0,0 @@ -use std::{ - arch::x86_64::*, - fmt::Debug, - io::{Read, Write}, - iter::{Product, Sum}, - mem::transmute, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; - -use rand::{Rng, RngCore}; - -use crate::{field_common, Field, FieldSerde, FieldSerdeResult, SimdField, M31, M31_MOD}; - -const M31_PACK_SIZE: usize = 16; -const PACKED_MOD: __m256i = unsafe { transmute([M31_MOD; M31_PACK_SIZE / 2]) }; -const PACKED_0: __m256i = unsafe { transmute([0; M31_PACK_SIZE / 2]) }; -const PACKED_INV_2: __m256i = unsafe { transmute([1 << 30; M31_PACK_SIZE / 2]) }; - -#[inline(always)] -unsafe fn mod_reduce_epi32(x: __m256i) -> __m256i { - _mm256_add_epi32(_mm256_and_si256(x, PACKED_MOD), _mm256_srli_epi32(x, 31)) -} - -#[inline(always)] -unsafe fn mod_reduce_epi32_2(x: [__m256i; 2]) -> [__m256i; 2] { - [mod_reduce_epi32(x[0]), mod_reduce_epi32(x[1])] -} - -#[derive(Clone, Copy)] -pub struct AVXM31 { - pub v: [__m256i; 2], -} - -impl AVXM31 { - #[inline(always)] - pub(crate) fn pack_full(x: M31) -> AVXM31 { - AVXM31 { - v: unsafe { _mm256_set1_epi32(x.v as i32) }, - } - } -} - -field_common!(AVXM31); - -impl FieldSerde for AVXM31 { - const SERIALIZED_SIZE: usize = 512 / 8; - - #[inline(always)] - /// serialize self into bytes - fn serialize_into(&self, mut writer: W) -> FieldSerdeResult<()> { - let data = unsafe { transmute::<[__m256i; 2], [u8; 64]>(self.v) }; - writer.write_all(&data)?; - Ok(()) - } - - /// deserialize bytes into field - #[inline(always)] - fn deserialize_from(mut reader: R) -> FieldSerdeResult { - let mut data = [0; Self::SERIALIZED_SIZE]; - reader.read_exact(&mut data)?; - unsafe { - let mut value = transmute::<[u8; Self::SERIALIZED_SIZE], __m256i>(data); - value = mod_reduce_epi32(value); - Ok(AVXM31 { v: value }) - } - } - - #[inline(always)] - fn try_deserialize_from_ecc_format(mut reader: R) -> FieldSerdeResult { - let mut buf = [0u8; 32]; - reader.read_exact(&mut buf)?; - assert!( - buf.iter().skip(4).all(|&x| x == 0), - "non-zero byte found in witness byte" - ); - Ok(Self::pack_full( - u32::from_le_bytes(buf[..4].try_into().unwrap()).into(), - )) - } -} - -impl Field for AVXM31 { - const NAME: &'static str = "AVX Packed Mersenne 31"; - - // size in bytes - const SIZE: usize = 512 / 8; - - const ZERO: Self = Self { v: [PACKED_0, PACKED_0] }; - - const ONE: Self = Self { - v: unsafe { transmute::<[u32; 16], [__m256i; 2]>([1; M31_PACK_SIZE]) }, - }; - - const INV_2: Self = Self { v: [PACKED_INV_2, PACKED_INV_2] }; - - const FIELD_SIZE: usize = 32; - - #[inline(always)] - fn zero() -> Self { - AVXM31 { - v: unsafe { [_mm256_set1_epi32(0), _mm256_set1_epi32(0)] }, - } - } - - #[inline(always)] - fn is_zero(&self) -> bool { - // value is either zero or 0x7FFFFFFF - unsafe { - let pcmp = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_0) & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_0); - let pcmp2 = _mm256_cmpeq_epi32_mask(self.v[0], PACKED_MOD) & _mm256_cmpeq_epi32_mask(self.v[1], PACKED_MOD); - (pcmp | pcmp2) == 0xFF - } - } - - #[inline(always)] - fn one() -> Self { - AVXM31 { - v: unsafe { [_mm256_set1_epi32(1), _mm256_set1_epi32(1)] }, - } - } - - #[inline(always)] - // this function is for internal testing only. it is not - // a source for uniformly random field elements and - // should not be used in production. - fn random_unsafe(mut rng: impl RngCore) -> Self { - // Caution: this may not produce uniformly random elements - unsafe { - let mut v = [ - _mm256_setr_epi32( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - }, - _mm256_setr_epi32( - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - ), - ]; - v = mod_reduce_epi32_2(v); - v = mod_reduce_epi32_2(v); - AVXM31 { v } - } - } - - #[inline(always)] - fn random_bool(mut rng: impl RngCore) -> Self { - // TODO: optimize this code - AVXM31 { - v: unsafe {[ - _mm256_setr_epi32( - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - ), - _mm256_setr_epi32( - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - rng.gen::() as i32, - ), - ]}, - } - } - - fn exp(&self, _exponent: u128) -> Self { - unimplemented!("exp not implemented for AVXM31") - } - - #[inline(always)] - fn double(&self) -> Self { - self.mul_by_2() - } - - #[inline(always)] - fn mul_by_5(&self) -> AVXM31 { - let double = unsafe { [mod_reduce_epi32(_mm256_slli_epi32::<1>(self.v[0])), mod_reduce_epi32(_mm256_slli_epi32::<1>(self.v[1]))] }; - let quad = unsafe { [mod_reduce_epi32(_mm256_slli_epi32::<1>(double[0])), mod_reduce_epi32(_mm256_slli_epi32::<1>(double[1]))] }; - let res = unsafe { [mod_reduce_epi32(_mm256_add_epi32(self.v, quad[0])), mod_reduce_epi32(_mm256_add_epi32(self.v, quad[1]))] }; - Self { v: res } - } - - #[inline(always)] - fn inv(&self) -> Option { - // slow, should not be used in production - let mut m31_vec = unsafe { transmute::<[__m256i; 2], [M31; 16]>(self.v) }; - let is_non_zero = m31_vec.iter().all(|x| !x.is_zero()); - if !is_non_zero { - return None; - } - - m31_vec.iter_mut().for_each(|x| *x = x.inv().unwrap()); // safe unwrap - Some(Self { - v: unsafe { transmute::<[M31; 16], [__m256i; 2]>(m31_vec) }, - }) - } - - fn as_u32_unchecked(&self) -> u32 { - unimplemented!("self is a vector, cannot convert to u32") - } - - #[inline] - fn from_uniform_bytes(bytes: &[u8; 32]) -> Self { - let m = M31::from_uniform_bytes(bytes); - Self { - v: unsafe { [_mm256_set1_epi32(m.v[0] as i32), _mm256_set1_epi32(m.v[1] as i32)] }, - } - } - - #[inline(always)] - fn mul_by_3(&self) -> AVXM31 { - let double = unsafe { [mod_reduce_epi32(_mm256_slli_epi32::<1>(self.v[0])), mod_reduce_epi32(_mm256_slli_epi32::<1>(self.v[1]))] }; - let res = unsafe { [mod_reduce_epi32(_mm256_add_epi32(self.v[0], double)), mod_reduce_epi32(_mm256_add_epi32(self.v[1], double))] }; - Self { v: res } - } -} - -impl SimdField for AVXM31 { - type Scalar = M31; - - #[inline] - fn scale(&self, challenge: &Self::Scalar) -> Self { - *self * *challenge - } - - #[inline(always)] - fn pack_size() -> usize { - M31_PACK_SIZE - } -} - -impl From for AVXM31 { - #[inline(always)] - fn from(x: M31) -> Self { - AVXM31::pack_full(x) - } -} - -impl Debug for AVXM31 { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut data = [0; M31_PACK_SIZE]; - unsafe { - _mm256_storeu_si256(data.as_mut_ptr() as *mut __m256i, self.v[0]); - _mm256_storeu_si256(data.as_mut_ptr().offset(8) as *mut __m256i, self.v[1]); - } - // if all data is the same, print only one - if data.iter().all(|&x| x == data[0]) { - write!( - f, - "mm256i<8 x {}>", - if M31_MOD - data[0] > 1024 { - format!("{}", data[0]) - } else { - format!("-{}", M31_MOD - data[0]) - } - ) - } else { - write!(f, "mm256i<{:?}>", data) - } - } -} - -impl Default for AVXM31 { - fn default() -> Self { - AVXM31::zero() - } -} - -impl PartialEq for AVXM31 { - #[inline(always)] - fn eq(&self, other: &Self) -> bool { - unsafe { - let pcmp = _mm512_cmpeq_epi32_mask(self.v, other.v); - pcmp == 0xFFFF - } - } -} - -#[inline] -#[must_use] -fn mask_movehdup_epi32(src: __m512i, k: __mmask16, a: __m512i) -> __m512i { - // The instruction is only available in the floating-point flavor; this distinction is only for - // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. - unsafe { - let src = _mm512_castsi512_ps(src); - let a = _mm512_castsi512_ps(a); - _mm512_castps_si512(_mm512_mask_movehdup_ps(src, k, a)) - } -} - -#[inline] -#[must_use] -fn mask_moveldup_epi32(src: __m512i, k: __mmask16, a: __m512i) -> __m512i { - // The instruction is only available in the floating-point flavor; this distinction is only for - // historical reasons and no longer matters. We cast to floats, do the thing, and cast back. - unsafe { - let src = _mm512_castsi512_ps(src); - let a = _mm512_castsi512_ps(a); - _mm512_castps_si512(_mm512_mask_moveldup_ps(src, k, a)) - } -} - -#[inline] -#[must_use] -fn add(lhs: __m512i, rhs: __m512i) -> __m512i { - unsafe { - let t = _mm512_add_epi32(lhs, rhs); - let u = _mm512_sub_epi32(t, PACKED_MOD); - _mm512_min_epu32(t, u) - } -} - -const EVENS: __mmask16 = 0b0101010101010101; -const ODDS: __mmask16 = 0b1010101010101010; - -impl Mul<&M31> for AVXM31 { - type Output = AVXM31; - - #[inline(always)] - fn mul(self, rhs: &M31) -> Self::Output { - let rhsv = AVXM31::pack_full(*rhs); - unsafe { - let rhs_evn = rhsv.v; - let lhs_odd_dbl = _mm512_srli_epi64(self.v, 31); - let lhs_evn_dbl = _mm512_add_epi32(self.v, self.v); - let rhs_odd = movehdup_epi32(rhsv.v); - - let prod_odd_dbl = _mm512_mul_epu32(lhs_odd_dbl, rhs_odd); - let prod_evn_dbl = _mm512_mul_epu32(lhs_evn_dbl, rhs_evn); - - let prod_lo_dbl = mask_moveldup_epi32(prod_evn_dbl, ODDS, prod_odd_dbl); - let prod_hi = mask_movehdup_epi32(prod_odd_dbl, EVENS, prod_evn_dbl); - // Right shift to undo the doubling. - let prod_lo = _mm512_srli_epi32::<1>(prod_lo_dbl); - - // Standard addition of two 31-bit values. - let res = add(prod_lo, prod_hi); - AVXM31 { v: res } - } - } -} - -impl Mul for AVXM31 { - type Output = AVXM31; - #[inline(always)] - fn mul(self, rhs: M31) -> Self::Output { - self * &rhs - } -} - -impl Add for AVXM31 { - type Output = AVXM31; - #[inline(always)] - #[allow(clippy::op_ref)] - fn add(self, rhs: M31) -> Self::Output { - self + AVXM31::pack_full(rhs) - } -} - -impl From for AVXM31 { - #[inline(always)] - fn from(x: u32) -> Self { - AVXM31::pack_full(M31::from(x)) - } -} - -impl Neg for AVXM31 { - type Output = AVXM31; - #[inline(always)] - fn neg(self) -> Self::Output { - AVXM31 { - v: unsafe { _mm512_xor_epi32(self.v, PACKED_MOD) }, - } - } -} - -#[inline] -#[must_use] -fn movehdup_epi32(x: __m512i) -> __m512i { - // The instruction is only available in the floating-point flavor; this distinction is only for - // historical reasons and no longer matters. We cast to floats, duplicate, and cast back. - unsafe { _mm512_castps_si512(_mm512_movehdup_ps(_mm512_castsi512_ps(x))) } -} - -#[inline(always)] -fn add_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { - unsafe { - let mut result = _mm512_add_epi32(a.v, b.v); - let subx = _mm512_sub_epi32(result, PACKED_MOD); - result = _mm512_min_epu32(result, subx); - - AVXM31 { v: result } - } -} - -#[inline(always)] -fn sub_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { - AVXM31 { - v: unsafe { - let t = _mm512_sub_epi32(a.v, b.v); - let subx = _mm512_add_epi32(t, PACKED_MOD); - _mm512_min_epu32(t, subx) - }, - } -} - -#[inline] -fn mul_internal(a: &AVXM31, b: &AVXM31) -> AVXM31 { - // credit: https://github.com/Plonky3/Plonky3/blob/eeb4e37b20127c4daa871b2bad0df30a7c7380db/mersenne-31/src/x86_64_avx2/packing.rs#L154 - unsafe { - let rhs_evn = b.v; - let lhs_odd_dbl = _mm512_srli_epi64(a.v, 31); - let lhs_evn_dbl = _mm512_add_epi32(a.v, a.v); - let rhs_odd = movehdup_epi32(b.v); - - let prod_odd_dbl = _mm512_mul_epu32(lhs_odd_dbl, rhs_odd); - let prod_evn_dbl = _mm512_mul_epu32(lhs_evn_dbl, rhs_evn); - - let prod_lo_dbl = mask_moveldup_epi32(prod_evn_dbl, ODDS, prod_odd_dbl); - let prod_hi = mask_movehdup_epi32(prod_odd_dbl, EVENS, prod_evn_dbl); - // Right shift to undo the doubling. - let prod_lo = _mm512_srli_epi32::<1>(prod_lo_dbl); - - // Standard addition of two 31-bit values. - let res = add(prod_lo, prod_hi); - AVXM31 { v: res } - } -} diff --git a/arith/src/field/m31/m31x16.rs b/arith/src/field/m31/m31x16.rs index f6709b3e..7b51943b 100644 --- a/arith/src/field/m31/m31x16.rs +++ b/arith/src/field/m31/m31x16.rs @@ -1,6 +1,6 @@ -/// A M31x16 stores 512 bits of data. -/// With AVX it stores a single __m512i element. -/// With NEON it stores four uint32x4_t elements. +// A M31x16 stores 512 bits of data. +// With AVX it stores a single __m512i element. +// With NEON it stores four uint32x4_t elements. #[cfg(target_arch = "x86_64")] cfg_if::cfg_if! { if #[cfg(feature = "avx256")] { @@ -10,7 +10,6 @@ cfg_if::cfg_if! { } } - #[cfg(target_arch = "aarch64")] pub type M31x16 = super::m31_neon::NeonM31; @@ -44,4 +43,3 @@ fn has_avx512() { } } */ - diff --git a/arith/src/tests/field.rs b/arith/src/tests/field.rs index 18957139..470efacb 100644 --- a/arith/src/tests/field.rs +++ b/arith/src/tests/field.rs @@ -8,17 +8,30 @@ use crate::{Field, FieldSerde}; pub fn random_field_tests(type_name: String) { let mut rng = test_rng(); + println!("{}", std::any::type_name::()); + random_multiplication_tests::(&mut rng, type_name.clone()); + println!("{} random_multiplication_tests done", type_name); random_addition_tests::(&mut rng, type_name.clone()); + println!("{} random_addition_tests done", type_name); random_subtraction_tests::(&mut rng, type_name.clone()); + println!("{} random_subtraction_tests done", type_name); random_negation_tests::(&mut rng, type_name.clone()); + println!("{} random_negation_tests done", type_name); random_doubling_tests::(&mut rng, type_name.clone()); + println!("{} random_doubling_tests done", type_name); random_squaring_tests::(&mut rng, type_name.clone()); + println!("{} random_squaring_tests done", type_name); random_expansion_tests::(&mut rng, type_name.clone()); // also serve as distributivity tests + println!("{} random_expansion_tests done", type_name); random_serde_tests::(&mut rng, type_name.clone()); + println!("{} random_serde_tests done", type_name); associativity_tests::(&mut rng, type_name.clone()); + println!("{} associativity_tests done", type_name); commutativity_tests::(&mut rng, type_name.clone()); + println!("{} commutativity_tests done", type_name); identity_tests::(&mut rng, type_name.clone()); + println!("{} identity_tests done", type_name); //inverse_tests::(&mut rng, type_name.clone()); assert_eq!(F::zero().is_zero(), true); diff --git a/arith/src/tests/m31_ext.rs b/arith/src/tests/m31_ext.rs index d08345e4..ea8ad03d 100644 --- a/arith/src/tests/m31_ext.rs +++ b/arith/src/tests/m31_ext.rs @@ -9,8 +9,11 @@ use super::{ fn test_field() { random_field_tests::("M31 Ext3".to_string()); random_extension_field_tests::("M31 Ext3".to_string()); - + println!("M31Ext3x16 Starting"); random_field_tests::("Simd M31 Ext3".to_string()); + println!("M31Ext3x16 Starting 2"); random_extension_field_tests::("Simd M31 Ext3".to_string()); + println!("M31Ext3x16 Starting 3"); random_simd_field_tests::("Simd M31 Ext3".to_string()); + println!("M31Ext3x16 Done"); }