Skip to content

Commit

Permalink
avx256 (#88)
Browse files Browse the repository at this point in the history
* avx256

* fixing m31

* fix

* clear

* fix CI error

* Update m31_ext.rs

Signed-off-by: Tiancheng Xie <[email protected]>

* Update field.rs

Signed-off-by: Tiancheng Xie <[email protected]>

* Delete arith/debugout

Signed-off-by: Tiancheng Xie <[email protected]>

* Update m31x16.rs

Signed-off-by: Tiancheng Xie <[email protected]>

* Update field.rs

Signed-off-by: Tiancheng Xie <[email protected]>

* update bench

* fix CI

* fix CI

---------

Signed-off-by: Tiancheng Xie <[email protected]>
Co-authored-by: Tiancheng Xie <[email protected]>
  • Loading branch information
chonpsk and niconiconi authored Sep 10, 2024
1 parent 7670d3d commit 6fd2d95
Show file tree
Hide file tree
Showing 13 changed files with 2,070 additions and 5 deletions.
4 changes: 4 additions & 0 deletions arith/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ sha2.workspace = true
thiserror.workspace = true
ethnum.workspace = true

raw-cpuid = "11.1.0"
cfg-if = "1.0"

[dev-dependencies]
tynm.workspace = true
criterion.workspace = true
Expand All @@ -25,3 +28,4 @@ name = "ext_field"
harness = false

[features]
avx256 = []
4 changes: 4 additions & 0 deletions arith/benches/ext_field.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(target_arch = "x86_64")]
use arith::GF2_128x8_256;
use arith::{ExtensionField, Field, GF2_128x8, M31Ext3, M31Ext3x16, GF2_128};
use ark_std::test_rng;
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
Expand Down Expand Up @@ -154,6 +156,8 @@ fn ext_by_base_benchmark(c: &mut Criterion) {
bench_field::<M31Ext3x16>(c);
bench_field::<GF2_128>(c);
bench_field::<GF2_128x8>(c);
#[cfg(target_arch = "x86_64")]
bench_field::<GF2_128x8_256>(c);
}

criterion_group!(ext_by_base_benches, ext_by_base_benchmark);
Expand Down
6 changes: 6 additions & 0 deletions arith/benches/field.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// this module benchmarks the performance of different field operations

use arith::{Field, GF2_128x8, GF2x8, M31Ext3, M31Ext3x16, M31x16, GF2, GF2_128, M31};
#[cfg(target_arch = "x86_64")]
use arith::{GF2_128x8_256, M31x16_256};
use ark_std::test_rng;
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use halo2curves::bn256::Fr;
Expand Down Expand Up @@ -174,13 +176,17 @@ pub(crate) fn bench_field<F: Field>(c: &mut Criterion) {
// Runs `bench_field` once per field type, in the order listed below.
fn criterion_benchmark(c: &mut Criterion) {
bench_field::<M31>(c);
bench_field::<M31x16>(c);
// `M31x16_256` is only imported on x86_64 (the `use` is cfg-gated the same way),
// so the call has to be gated too or other targets would fail to compile.
#[cfg(target_arch = "x86_64")]
bench_field::<M31x16_256>(c);
bench_field::<M31Ext3>(c);
bench_field::<M31Ext3x16>(c);
bench_field::<Fr>(c);
bench_field::<GF2>(c);
bench_field::<GF2x8>(c);
bench_field::<GF2_128>(c);
bench_field::<GF2_128x8>(c);
// Same gating as above: `GF2_128x8_256` only exists in scope on x86_64.
#[cfg(target_arch = "x86_64")]
bench_field::<GF2_128x8_256>(c);
}

criterion_group!(benches, criterion_benchmark);
Expand Down
2 changes: 2 additions & 0 deletions arith/src/extension_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use crate::{Field, FieldSerde};

pub use gf2_128::*;
pub use gf2_128x8::GF2_128x8;
#[cfg(target_arch = "x86_64")]
pub use gf2_128x8::GF2_128x8_256;
pub use m31_ext::M31Ext3;
pub use m31_ext3x16::M31Ext3x16;

Expand Down
318 changes: 318 additions & 0 deletions arith/src/extension_field/gf2_128/avx256.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
use std::iter::{Product, Sum};
use std::{
arch::x86_64::*,
mem::transmute,
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};

use crate::{field_common, ExtensionField, Field, FieldSerde, FieldSerdeResult, GF2};

/// An element of GF(2^128) held in a single 128-bit SIMD register.
///
/// Despite the `AVX512` prefix in the name, the storage is one `__m128i`.
#[derive(Clone, Copy, Debug)]
pub struct AVX512GF2_128 {
    /// Raw 128-bit representation of the element.
    pub v: __m128i,
}

// NOTE(review): `field_common!` is defined elsewhere in the crate. It presumably
// generates the arithmetic operator impls (`Add`, `Sub`, `Mul`, their `*Assign`
// forms, `Sum`, `Product`) on top of `add_internal` / `sub_internal` /
// `mul_internal` defined in this file — confirm against the macro definition.
field_common!(AVX512GF2_128);

/// Byte-level (de)serialization: an element is written as the 16 raw bytes of
/// its `__m128i` register, in memory order.
impl FieldSerde for AVX512GF2_128 {
    const SERIALIZED_SIZE: usize = 16;

    /// Writes the 16 raw bytes of the element to `writer`.
    #[inline(always)]
    fn serialize_into<W: std::io::Write>(&self, mut writer: W) -> FieldSerdeResult<()> {
        // SAFETY: `__m128i` and `[u8; 16]` have the same size and every bit
        // pattern is valid for both.
        let raw_bytes = unsafe { transmute::<__m128i, [u8; 16]>(self.v) };
        writer.write_all(raw_bytes.as_ref())?;
        Ok(())
    }

    /// Reads exactly `SERIALIZED_SIZE` bytes from `reader` and reinterprets
    /// them as an element.
    #[inline(always)]
    fn deserialize_from<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
        let mut raw_bytes = [0u8; Self::SERIALIZED_SIZE];
        reader.read_exact(&mut raw_bytes)?;
        // SAFETY: same size on both sides, any bit pattern is a valid `__m128i`.
        let packed = unsafe { transmute::<[u8; Self::SERIALIZED_SIZE], __m128i>(raw_bytes) };
        Ok(AVX512GF2_128 { v: packed })
    }

    /// Reads a 32-byte record from `reader` and builds the element from its
    /// first 16 bytes; the trailing 16 bytes are consumed but not inspected.
    #[inline(always)]
    fn try_deserialize_from_ecc_format<R: std::io::Read>(mut reader: R) -> FieldSerdeResult<Self> {
        let mut record = [0u8; 32];
        reader.read_exact(&mut record)?;
        // The slice is exactly 16 bytes long, so the conversion cannot fail.
        let low_half: [u8; 16] = record[..16].try_into().unwrap();
        // SAFETY: same size on both sides, any bit pattern is a valid `__m128i`.
        let packed = unsafe { transmute::<[u8; 16], __m128i>(low_half) };
        Ok(AVX512GF2_128 { v: packed })
    }
}

/// `Field` implementation for GF(2^128) elements stored in one `__m128i`.
///
/// The multiplication used by `exp` and `square` comes from operator impls
/// that are not defined inside this block.
impl Field for AVX512GF2_128 {
    // NOTE(review): "Galios" looks like a misspelling of "Galois"; the string is
    // left untouched because it is observable at runtime.
    const NAME: &'static str = "Galios Field 2^128";

    /// Size of one element in bytes.
    const SIZE: usize = 128 / 8;

    const FIELD_SIZE: usize = 128; // in bits

    const ZERO: Self = AVX512GF2_128 {
        // SAFETY: the all-zero bit pattern is a valid `__m128i`.
        v: unsafe { std::mem::zeroed() },
    };

    const ONE: Self = AVX512GF2_128 {
        // SAFETY: `[i32; 4]` and `__m128i` are both 16 bytes and any bit pattern is valid.
        v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) },
    };

    const INV_2: Self = AVX512GF2_128 {
        // SAFETY: the all-zero bit pattern is a valid `__m128i`.
        v: unsafe { std::mem::zeroed() },
    }; // should not be used

    /// Returns the additive identity (all 128 bits clear).
    #[inline(always)]
    fn zero() -> Self {
        AVX512GF2_128 {
            // SAFETY: the all-zero bit pattern is a valid `__m128i`.
            v: unsafe { std::mem::zeroed() },
        }
    }

    /// Returns the multiplicative identity (only the lowest bit set).
    #[inline(always)]
    fn one() -> Self {
        AVX512GF2_128 {
            // 1 in the first bit
            // SAFETY: `[i32; 4]` and `__m128i` are both 16 bytes and any bit pattern is valid.
            v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) }, // TODO check bit order
        }
    }

    /// Fills all 16 bytes of the element from `rng`.
    #[inline(always)]
    fn random_unsafe(mut rng: impl rand::RngCore) -> Self {
        let mut u = [0u8; 16];
        rng.fill_bytes(&mut u);
        // A `[u8; 16]` is only guaranteed 1-byte alignment, so reading it through
        // a `*const __m128i` (16-byte alignment) would be a misaligned dereference.
        // Converting by value has no alignment requirement on the source.
        // SAFETY: both types are 16 bytes and any bit pattern is a valid `__m128i`.
        AVX512GF2_128 {
            v: unsafe { transmute::<[u8; 16], __m128i>(u) },
        }
    }

    /// Returns either zero or one, chosen by the parity of one `u32` drawn from `rng`.
    #[inline(always)]
    fn random_bool(mut rng: impl rand::RngCore) -> Self {
        AVX512GF2_128 {
            // SAFETY: `[u32; 4]` and `__m128i` are both 16 bytes and any bit pattern is valid.
            v: unsafe { std::mem::transmute::<[u32; 4], __m128i>([rng.next_u32() % 2, 0, 0, 0]) },
        }
    }

    /// Returns `true` when every byte of the element is zero.
    #[inline(always)]
    fn is_zero(&self) -> bool {
        // SAFETY: `__m128i` and `[u8; 16]` are both 16 bytes and any bit pattern is valid.
        unsafe { std::mem::transmute::<__m128i, [u8; 16]>(self.v) == [0; 16] }
    }

    /// Raises `self` to `exponent` by square-and-multiply, scanning the
    /// exponent from its least significant bit upwards.
    #[inline(always)]
    fn exp(&self, exponent: u128) -> Self {
        let mut e = exponent;
        let mut res = Self::one();
        let mut t = *self;
        while e > 0 {
            if e & 1 == 1 {
                res *= t;
            }
            t = t * t;
            e >>= 1;
        }
        res
    }

    /// Returns the multiplicative inverse, or `None` for zero.
    #[inline(always)]
    fn inv(&self) -> Option<Self> {
        if self.is_zero() {
            return None;
        }
        // 2^128 - 2: the nonzero elements form a group of order 2^128 - 1,
        // so a^(2^128 - 2) * a = 1.
        let p_m2 = !(0u128) - 1;
        Some(Self::exp(self, p_m2))
    }

    /// Returns `self * self`.
    #[inline(always)]
    fn square(&self) -> Self {
        self * self
    }

    /// Not supported for this field; always panics.
    #[inline(always)]
    fn as_u32_unchecked(&self) -> u32 {
        unimplemented!("u32 for GF128 doesn't make sense")
    }

    /// Builds an element from the first 16 of the 32 supplied bytes; the
    /// remaining 16 bytes are ignored.
    #[inline(always)]
    fn from_uniform_bytes(bytes: &[u8; 32]) -> Self {
        // SAFETY: the slice is exactly 16 bytes, so `try_into` cannot fail, and
        // `[u8; 16]` -> `__m128i` is a same-size, always-valid conversion.
        unsafe {
            AVX512GF2_128 {
                v: transmute::<[u8; 16], __m128i>(bytes[..16].try_into().unwrap()),
            }
        }
    }
}

/// GF(2^128) viewed as a degree-128 extension of `GF2`.
impl ExtensionField for AVX512GF2_128 {
    const DEGREE: usize = 128;

    // Same constant that `mul_by_x` folds back in when the top bit overflows.
    const W: u32 = 0x87;

    /// The element with only bit 1 set (value 2 in the lowest 32-bit lane).
    const X: Self = AVX512GF2_128 {
        // SAFETY: `[i32; 4]` and `__m128i` are both 16 bytes and any bit pattern is valid.
        v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([2, 0, 0, 0]) },
    };

    type BaseField = GF2;

    /// Multiplies by a base-field element: a zero `base` yields zero, anything
    /// else yields `self` unchanged.
    #[inline(always)]
    fn mul_by_base_field(&self, base: &Self::BaseField) -> Self {
        match base.v {
            0 => Self::zero(),
            _ => *self,
        }
    }

    /// Adds a base-field element by XOR-ing `base.v` into the low 64-bit lane.
    #[inline(always)]
    fn add_by_base_field(&self, base: &Self::BaseField) -> Self {
        let embedded = unsafe { _mm_set_epi64x(0, base.v as i64) };
        let mut sum = *self;
        sum.v = unsafe { _mm_xor_si128(sum.v, embedded) };
        sum
    }

    /// Multiplies by `X`: a one-bit left shift of the 128-bit value, with 0x87
    /// XOR-ed into the low lane when bit 127 was shifted out.
    #[inline]
    fn mul_by_x(&self) -> Self {
        unsafe {
            // Top bit of each 64-bit lane, moved down to bit 0 of that lane.
            let lane_top_bits = _mm_srli_epi64(self.v, 63);

            // Per-lane shift, then carry the low lane's top bit into the high lane.
            let carry_into_high = _mm_slli_si128(lane_top_bits, 8);
            let doubled = _mm_or_si128(_mm_slli_epi64(self.v, 1), carry_into_high);

            // Bit 127 of the input, now sitting in bit 0 of the low lane.
            let overflow_bit = _mm_srli_si128(lane_top_bits, 8);

            // The low lane of the mask is all-ones exactly when bit 127 was set.
            // The high lane compares 0 == 0 and is always all-ones, but the
            // reduction constant is zero there, so it contributes nothing.
            let overflow_mask = _mm_cmpeq_epi64(overflow_bit, _mm_set_epi64x(0, 1));
            let reduction = _mm_and_si128(overflow_mask, _mm_set_epi64x(0, 0x87));

            Self {
                v: _mm_xor_si128(doubled, reduction),
            }
        }
    }
}

/// Embeds a base-field element by placing `v.v` in the low 64-bit lane and
/// zeroing the high lane.
impl From<GF2> for AVX512GF2_128 {
    #[inline(always)]
    fn from(v: GF2) -> Self {
        let low_lane = v.v as i64;
        let packed = unsafe { _mm_set_epi64x(0, low_lane) };
        AVX512GF2_128 { v: packed }
    }
}

/// Multiplies two GF(2^128) elements, reducing modulo x^128 + x^7 + x^2 + x + 1.
///
/// The 256-bit carry-less product is built from three 64x64 carry-less
/// multiplications (Karatsuba), then the upper 128 bits are folded back into
/// the lower 128 bits.
///
/// # Safety
///
/// Calls `_mm_clmulepi64_si128` and SSE2 intrinsics; the caller must ensure the
/// running CPU supports them.
#[inline]
unsafe fn gfmul(a: __m128i, b: __m128i) -> __m128i {
    // Keeps only the lowest 32-bit lane.
    let lane0_mask = _mm_setr_epi32((0xffffffff_u32) as i32, 0x0, 0x0, 0x0);

    // With a = a0|a1 and b = b0|b1 (64-bit halves):
    let lo_prod = _mm_clmulepi64_si128(a, b, 0x00); // a0 * b0
    let hi_prod = _mm_clmulepi64_si128(a, b, 0x11); // a1 * b1

    // Shuffle 78 swaps the two 64-bit halves, so each fold holds (x0 + x1) in both halves.
    let a_fold = _mm_xor_si128(_mm_shuffle_epi32(a, 78), a);
    let b_fold = _mm_xor_si128(_mm_shuffle_epi32(b, 78), b);

    // (a0 + a1) * (b0 + b1) - a0 * b0 - a1 * b1 = a0 * b1 + a1 * b0
    let mut cross = _mm_clmulepi64_si128(a_fold, b_fold, 0x00);
    cross = _mm_xor_si128(cross, lo_prod);
    cross = _mm_xor_si128(cross, hi_prod);

    // The cross term straddles the middle of the 256-bit product: its low half
    // lands in the upper 64 bits of `low`, its high half in the lower 64 bits of `high`.
    let low = _mm_xor_si128(lo_prod, _mm_slli_si128(cross, 8));
    let high = _mm_xor_si128(hi_prod, _mm_srli_si128(cross, 8));

    // Bits that the per-lane left shifts by 1, 2 and 7 below push out of each
    // 32-bit lane of `high`.
    let mut spill = _mm_srli_epi32(high, 31);
    spill = _mm_xor_si128(spill, _mm_srli_epi32(high, 30));
    spill = _mm_xor_si128(spill, _mm_srli_epi32(high, 25));

    // Shuffle 147 rotates the lanes up by one: lane i receives lane i-1 and
    // lane 0 receives lane 3.
    let rotated = _mm_shuffle_epi32(spill, 147);
    // Lane 0 now holds what spilled out of the top lane; that part overflowed
    // 128 bits again, so it is folded back into `high` to be reduced with it.
    let wrapped = _mm_and_si128(lane0_mask, rotated);
    // The other lanes are ordinary carries into the next lane of the result.
    let carries = _mm_andnot_si128(lane0_mask, rotated);

    let high = _mm_xor_si128(high, wrapped);

    // low + carries + high * (x^7 + x^2 + x + 1)
    let mut acc = _mm_xor_si128(low, carries);
    acc = _mm_xor_si128(acc, _mm_slli_epi32(high, 1));
    acc = _mm_xor_si128(acc, _mm_slli_epi32(high, 2));
    acc = _mm_xor_si128(acc, _mm_slli_epi32(high, 7));
    _mm_xor_si128(acc, high)
}

impl Default for AVX512GF2_128 {
#[inline(always)]
fn default() -> Self {
Self::zero()
}
}

/// Two elements are equal when all 16 bytes of their registers match.
impl PartialEq for AVX512GF2_128 {
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        unsafe {
            // Each byte becomes 0xFF where the operands agree, 0x00 otherwise.
            let byte_matches = _mm_cmpeq_epi8(self.v, other.v);
            // Equal only if every bit of the comparison result is set.
            _mm_test_all_ones(byte_matches) == 1
        }
    }
}

impl Neg for AVX512GF2_128 {
type Output = Self;

#[inline(always)]
fn neg(self) -> Self {
self
}
}

/// Places `v` in the lowest 32-bit lane and zeroes the other three lanes.
impl From<u32> for AVX512GF2_128 {
    #[inline(always)]
    fn from(v: u32) -> Self {
        let lanes: [u32; 4] = [v, 0, 0, 0];
        // SAFETY: `[u32; 4]` and `__m128i` are both 16 bytes and any bit pattern is valid.
        let packed = unsafe { std::mem::transmute::<[u32; 4], __m128i>(lanes) };
        AVX512GF2_128 { v: packed }
    }
}

/// Field addition: bitwise XOR of the two 128-bit representations.
#[inline(always)]
fn add_internal(a: &AVX512GF2_128, b: &AVX512GF2_128) -> AVX512GF2_128 {
    let sum = unsafe { _mm_xor_si128(a.v, b.v) };
    AVX512GF2_128 { v: sum }
}

/// Field subtraction: bitwise XOR of the two 128-bit representations, the same
/// operation `add_internal` performs.
#[inline(always)]
fn sub_internal(a: &AVX512GF2_128, b: &AVX512GF2_128) -> AVX512GF2_128 {
    let difference = unsafe { _mm_xor_si128(a.v, b.v) };
    AVX512GF2_128 { v: difference }
}

/// Field multiplication: delegates to `gfmul` on the raw registers.
#[inline(always)]
fn mul_internal(a: &AVX512GF2_128, b: &AVX512GF2_128) -> AVX512GF2_128 {
    let product = unsafe { gfmul(a.v, b.v) };
    AVX512GF2_128 { v: product }
}
Loading

0 comments on commit 6fd2d95

Please sign in to comment.