Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute input min/max with a single vectorized pass in DynamicQuantizeLinear #531

Merged
merged 5 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions rten-simd/src/arch/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::arch::aarch64::{
float32x4_t, int32x4_t, uint32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vaddvq_f32, vandq_u32,
vbslq_f32, vbslq_s32, vceqq_s32, vcgeq_f32, vcgeq_s32, vcgtq_s32, vcleq_f32, vcleq_s32,
vcltq_f32, vcltq_s32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32, vfmaq_f32, vld1q_f32,
vld1q_s32, vld1q_u32, vmaxq_f32, vmaxq_s32, vminq_s32, vmulq_f32, vreinterpretq_f32_s32,
vshlq_n_s32, vst1q_f32, vst1q_s32, vst1q_u32, vsubq_f32, vsubq_s32,
vld1q_s32, vld1q_u32, vmaxq_f32, vmaxq_s32, vminq_f32, vminq_s32, vmulq_f32,
vreinterpretq_f32_s32, vshlq_n_s32, vst1q_f32, vst1q_s32, vst1q_u32, vsubq_f32, vsubq_s32,
};

use crate::{Simd, SimdFloat, SimdInt, SimdMask};
Expand Down Expand Up @@ -222,6 +222,11 @@ impl SimdFloat for float32x4_t {
vmaxq_f32(self, rhs)
}

#[inline]
unsafe fn min(self, rhs: Self) -> Self {
vminq_f32(self, rhs)
}

#[inline]
unsafe fn gather_mask(src: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self {
super::simd_gather_mask::<_, _, _, { Self::LEN }>(src, offsets, mask)
Expand Down
5 changes: 5 additions & 0 deletions rten-simd/src/arch/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ impl SimdFloat for f32 {
f32::max(self, rhs)
}

#[inline]
unsafe fn min(self, rhs: Self) -> Self {
f32::min(self, rhs)
}

#[inline]
unsafe fn gather_mask(ptr: *const f32, offset: i32, mask: Self::Mask) -> Self {
if mask {
Expand Down
9 changes: 7 additions & 2 deletions rten-simd/src/arch/wasm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::arch::wasm32::{
f32x4_abs, f32x4_add, f32x4_div, f32x4_extract_lane, f32x4_ge, f32x4_le, f32x4_lt, f32x4_max,
f32x4_mul, f32x4_splat, f32x4_sub, i32x4, i32x4_add, i32x4_eq, i32x4_ge, i32x4_gt, i32x4_le,
i32x4_lt, i32x4_max, i32x4_min, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub,
f32x4_min, f32x4_mul, f32x4_splat, f32x4_sub, i32x4, i32x4_add, i32x4_eq, i32x4_ge, i32x4_gt,
i32x4_le, i32x4_lt, i32x4_max, i32x4_min, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub,
i32x4_trunc_sat_f32x4, v128, v128_and, v128_bitselect, v128_load, v128_store,
};

Expand Down Expand Up @@ -230,6 +230,11 @@ impl SimdFloat for v128f {
Self(f32x4_max(self.0, rhs.0))
}

#[inline]
unsafe fn min(self, rhs: Self) -> Self {
Self(f32x4_min(self.0, rhs.0))
}

#[inline]
unsafe fn gather_mask(src: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self {
super::simd_gather_mask::<_, _, _, { Self::LEN }>(src, offsets, mask)
Expand Down
24 changes: 18 additions & 6 deletions rten-simd/src/arch/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use std::arch::x86_64::{
_mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi256_ps,
_mm256_cmp_ps, _mm256_cmpeq_epi32, _mm256_cmpgt_epi32, _mm256_cvttps_epi32, _mm256_div_ps,
_mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_loadu_si256, _mm256_max_epi32,
_mm256_max_ps, _mm256_min_epi32, _mm256_mul_ps, _mm256_or_si256, _mm256_set1_epi32,
_mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256,
_mm256_sub_epi32, _mm256_sub_ps, _mm_add_ps, _mm_cvtss_f32, _mm_movehl_ps, _mm_prefetch,
_mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
_mm256_max_ps, _mm256_min_epi32, _mm256_min_ps, _mm256_mul_ps, _mm256_or_si256,
_mm256_set1_epi32, _mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps,
_mm256_storeu_si256, _mm256_sub_epi32, _mm256_sub_ps, _mm_add_ps, _mm_cvtss_f32, _mm_movehl_ps,
_mm_prefetch, _mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
};
use std::mem::transmute;

Expand Down Expand Up @@ -285,6 +285,12 @@ impl SimdFloat for __m256 {
_mm256_max_ps(self, rhs)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn min(self, rhs: Self) -> Self {
_mm256_min_ps(self, rhs)
}

#[inline]
#[target_feature(enable = "avx2")]
unsafe fn gather_mask(src: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self {
Expand Down Expand Up @@ -322,8 +328,8 @@ use std::arch::x86_64::{
_mm512_castsi512_ps, _mm512_cmp_epi32_mask, _mm512_cmp_ps_mask, _mm512_cvttps_epi32,
_mm512_div_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_loadu_si512, _mm512_mask_blend_epi32,
_mm512_mask_blend_ps, _mm512_mask_i32gather_ps, _mm512_max_epi32, _mm512_max_ps,
_mm512_min_epi32, _mm512_mul_ps, _mm512_reduce_add_ps, _mm512_set1_epi32, _mm512_set1_ps,
_mm512_setzero_si512, _mm512_sllv_epi32, _mm512_storeu_epi32, _mm512_storeu_ps,
_mm512_min_epi32, _mm512_min_ps, _mm512_mul_ps, _mm512_reduce_add_ps, _mm512_set1_epi32,
_mm512_set1_ps, _mm512_setzero_si512, _mm512_sllv_epi32, _mm512_storeu_epi32, _mm512_storeu_ps,
_mm512_sub_epi32, _mm512_sub_ps, _MM_CMPINT_EQ, _MM_CMPINT_LE, _MM_CMPINT_LT,
};

Expand Down Expand Up @@ -596,6 +602,12 @@ impl SimdFloat for __m512 {
_mm512_max_ps(self, rhs)
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn mix(self, rhs: Self) -> Self {
_mm512_min_ps(self, rhs)
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn gather_mask(ptr: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self {
Expand Down
37 changes: 37 additions & 0 deletions rten-simd/src/functional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,40 @@ pub unsafe fn simd_fold<S: Simd, Op: Fn(S, S) -> S>(

accum
}

/// A variant of [`simd_fold`] where the accumulator is an array of values
/// instead of just one.
///
/// # Safety
///
/// The caller must ensure that `S` is a supported SIMD vector type on the
/// current system.
#[inline(always)]
pub unsafe fn simd_fold_array<S: Simd, const N: usize, Op: Fn([S; N], S) -> [S; N]>(
xs: PtrLen<S::Elem>,
mut accum: [S; N],
simd_op: Op,
) -> [S; N] {
let mut n = xs.len();
let mut x_ptr = xs.ptr();

while n >= S::LEN {
let x = S::load(x_ptr);
accum = simd_op(accum, x);
n -= S::LEN;
x_ptr = x_ptr.add(S::LEN);
}

let n_mask = S::Mask::first_n(n);
if n > 0 {
let x = S::load_partial(x_ptr, n);
let prev_accum = accum;
let new_accum = simd_op(accum, x);

for i in 0..N {
accum[i] = prev_accum[i].blend(new_accum[i], n_mask);
}
}

accum
}
8 changes: 7 additions & 1 deletion rten-simd/src/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ pub trait Simd: Copy + Sized {
/// This type should always be `[Self::ELEM; Self::LEN]`. The `to_array`
/// method returns this associated type rather than a concrete array due to
/// const generics limitations.
type Array: Copy + std::fmt::Debug + std::ops::Index<usize, Output = Self::Elem>;
type Array: Copy
+ std::fmt::Debug
+ std::ops::Index<usize, Output = Self::Elem>
+ AsRef<[Self::Elem]>;

/// Combine elements of `self` and `rhs` according to a mask.
///
Expand Down Expand Up @@ -275,6 +278,9 @@ pub trait SimdFloat: Simd<Elem = f32> {
/// Compute a mask containing `self < rhs`.
unsafe fn lt(self, rhs: Self) -> Self::Mask;

/// Compute the minimum of `self` and `rhs`.
unsafe fn min(self, rhs: Self) -> Self;

/// Compute the maximum of `self` and `rhs`.
unsafe fn max(self, rhs: Self) -> Self;

Expand Down
2 changes: 2 additions & 0 deletions rten-vecmath/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@

mod erf;
mod exp;
mod min_max;
mod normalize;
mod softmax;
mod sum;
Expand All @@ -100,6 +101,7 @@ pub use exp::{Exp, Sigmoid, Silu, Swish};
pub use tanh::Tanh;

// Normalization and reduction functions.
pub use min_max::MinMax;
pub use normalize::{Normalize, NormalizeOptions};
pub use softmax::Softmax;
pub use sum::{Sum, SumSquare, SumSquareSub};
63 changes: 63 additions & 0 deletions rten-vecmath/src/min_max.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use rten_simd::dispatch::SimdOp;
use rten_simd::functional::simd_fold_array;
use rten_simd::SimdFloat;

/// Compute the minimum and maximum values in a slice of floats.
pub struct MinMax<'a> {
input: &'a [f32],
}

impl<'a> MinMax<'a> {
pub fn new(input: &'a [f32]) -> Self {
MinMax { input }
}
}

impl SimdOp for MinMax<'_> {
type Output = (f32, f32);

#[inline(always)]
unsafe fn eval<S: SimdFloat>(self) -> Self::Output {
let [vec_min, vec_max] = simd_fold_array(
self.input.into(),
[S::splat(f32::MAX), S::splat(f32::MIN)],
#[inline(always)]
|[min, max], x| [x.min(min), x.max(max)],
);
let min = vec_min
.to_array()
.as_ref()
.iter()
.fold(f32::MAX, |min, x| x.min(min));
let max = vec_max
.to_array()
.as_ref()
.iter()
.fold(f32::MIN, |max, x| x.max(max));
(min, max)
}
}

#[cfg(test)]
mod tests {
use super::MinMax;
use rten_simd::dispatch::SimdOp;

// Chosen to not be a multiple of vector size, so that tail handling is
// exercised.
const LEN: usize = 100;

fn reference_min_max(xs: &[f32]) -> (f32, f32) {
let min = xs.iter().fold(f32::MAX, |min, x| x.min(min));
let max = xs.iter().fold(f32::MIN, |max, x| x.max(max));
(min, max)
}

#[test]
fn test_min_max() {
let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect();
let expected = reference_min_max(&xs);
let min_max = MinMax::new(&xs).dispatch();
assert_eq!(min_max, expected);
}
}
18 changes: 6 additions & 12 deletions src/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,22 +840,16 @@ fn gemm_impl<LhsT: GemmInT, RhsT: GemmInT, OutT: GemmOutT>(
return Err(GemmError::WrongBiasSize);
}

match a_quant {
Some(quant) => {
if quant.zero_point.len() != a.rows() {
return Err(GemmError::WrongQuantParamSize);
}
if let Some(a_quant) = a_quant {
if a_quant.zero_point.len() != a.rows() {
return Err(GemmError::WrongQuantParamSize);
}
None => {}
}

match b_quant {
Some(quant) => {
if quant.zero_point.len() != b.cols() {
return Err(GemmError::WrongQuantParamSize);
}
if let Some(b_quant) = b_quant {
if b_quant.zero_point.len() != b.cols() {
return Err(GemmError::WrongQuantParamSize);
}
None => {}
}

// Handle case where output is empty.
Expand Down
23 changes: 7 additions & 16 deletions src/ops/quantize.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use rten_simd::dispatch::SimdOp;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Scalar, Tensor, TensorView};
use rten_vecmath as vecmath;

use crate::ops::{
reduce_max, reduce_min, resolve_axis, DataType, Input, InputList, IntoOpResult, OpError,
Operator, Output, OutputList,
resolve_axis, DataType, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList,
};
use crate::tensor_pool::{AutoReturn, TensorPool};
use crate::tensor_pool::TensorPool;

/// Convert a quantized tensor element to a higher precision value.
pub trait Dequantize<To> {
Expand Down Expand Up @@ -309,19 +310,9 @@ where
let q_min = 0.;
let q_max = 255.;

// Get the range of the input. This implementation is simple but sub-optimal
// as it makes two passes over the same data to get the min/max.
let x_min = reduce_min(pool, input.view(), None, false /* keep_dims */)?
.auto_return(pool)
.item()
.copied()
.unwrap();
let input = input.to_contiguous_in(pool);
let (x_min, x_max) = vecmath::MinMax::new(input.data().unwrap()).dispatch();
let x_min_adjusted = x_min.min(q_min);
let x_max = reduce_max(pool, input.view(), None, false /* keep_dims */)?
.auto_return(pool)
.item()
.copied()
.unwrap();
let x_max_adjusted = x_max.max(q_min);
let x_range = x_max_adjusted - x_min_adjusted;
let scale = x_range / q_max;
Expand All @@ -335,7 +326,7 @@ where
let zero_point_tensor = Tensor::from(zero_point);
let quantized = quantize_linear(
pool,
input,
input.view(),
scale_tensor.view(),
Some(zero_point_tensor.view()),
1, /* axis */
Expand Down
Loading