From b030964652b18e3ce74df0bf599cdd4e13f7a080 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 22:58:49 +0000 Subject: [PATCH 1/5] Take clippy's advice about using `if let` instead of match --- src/gemm.rs | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index fb462ba6..1e1db2d8 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -840,22 +840,16 @@ fn gemm_impl( return Err(GemmError::WrongBiasSize); } - match a_quant { - Some(quant) => { - if quant.zero_point.len() != a.rows() { - return Err(GemmError::WrongQuantParamSize); - } + if let Some(a_quant) = a_quant { + if a_quant.zero_point.len() != a.rows() { + return Err(GemmError::WrongQuantParamSize); } - None => {} } - match b_quant { - Some(quant) => { - if quant.zero_point.len() != b.cols() { - return Err(GemmError::WrongQuantParamSize); - } + if let Some(b_quant) = b_quant { + if b_quant.zero_point.len() != b.cols() { + return Err(GemmError::WrongQuantParamSize); } - None => {} } // Handle case where output is empty. From b6eebb9c685019e221474d3fafec8006f0ee5b37 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 22:50:11 +0000 Subject: [PATCH 2/5] Add `SimdFloat:min` method There was already a `max` method, so this fills in a gap. --- rten-simd/src/arch/aarch64.rs | 9 +++++++-- rten-simd/src/arch/scalar.rs | 5 +++++ rten-simd/src/arch/wasm.rs | 9 +++++++-- rten-simd/src/arch/x86_64.rs | 24 ++++++++++++++++++------ rten-simd/src/vec.rs | 8 +++++++- 5 files changed, 44 insertions(+), 11 deletions(-) diff --git a/rten-simd/src/arch/aarch64.rs b/rten-simd/src/arch/aarch64.rs index 8bfde543..f35d3d65 100644 --- a/rten-simd/src/arch/aarch64.rs +++ b/rten-simd/src/arch/aarch64.rs @@ -2,8 +2,8 @@ use std::arch::aarch64::{ float32x4_t, int32x4_t, uint32x4_t, vabsq_f32, vaddq_f32, vaddq_s32, vaddvq_f32, vandq_u32, vbslq_f32, vbslq_s32, vceqq_s32, vcgeq_f32, vcgeq_s32, vcgtq_s32, vcleq_f32, vcleq_s32, vcltq_f32, vcltq_s32, vcvtq_s32_f32, vdivq_f32, vdupq_n_f32, vdupq_n_s32, vfmaq_f32, vld1q_f32, - vld1q_s32, vld1q_u32, vmaxq_f32, vmaxq_s32, vminq_s32, vmulq_f32, vreinterpretq_f32_s32, - vshlq_n_s32, vst1q_f32, vst1q_s32, vst1q_u32, vsubq_f32, vsubq_s32, + vld1q_s32, vld1q_u32, vmaxq_f32, vmaxq_s32, vminq_f32, vminq_s32, vmulq_f32, + vreinterpretq_f32_s32, vshlq_n_s32, vst1q_f32, vst1q_s32, vst1q_u32, vsubq_f32, vsubq_s32, }; use crate::{Simd, SimdFloat, SimdInt, SimdMask}; @@ -222,6 +222,11 @@ impl SimdFloat for float32x4_t { vmaxq_f32(self, rhs) } + #[inline] + unsafe fn min(self, rhs: Self) -> Self { + vminq_f32(self, rhs) + } + #[inline] unsafe fn gather_mask(src: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self { super::simd_gather_mask::<_, _, _, { Self::LEN }>(src, offsets, mask) diff --git a/rten-simd/src/arch/scalar.rs b/rten-simd/src/arch/scalar.rs index 20b06fac..4ee84efb 100644 --- a/rten-simd/src/arch/scalar.rs +++ b/rten-simd/src/arch/scalar.rs @@ -187,6 +187,11 @@ impl SimdFloat for f32 { f32::max(self, rhs) } + #[inline] + unsafe fn min(self, rhs: Self) -> Self { + f32::min(self, rhs) + } + #[inline] unsafe fn gather_mask(ptr: *const f32, offset: i32, mask: Self::Mask) -> Self { if mask { diff --git a/rten-simd/src/arch/wasm.rs b/rten-simd/src/arch/wasm.rs index c4955d16..93dc3c85 100644 --- a/rten-simd/src/arch/wasm.rs +++ b/rten-simd/src/arch/wasm.rs @@ -1,7 +1,7 @@ use std::arch::wasm32::{ f32x4_abs, f32x4_add, f32x4_div, f32x4_extract_lane, f32x4_ge, f32x4_le, f32x4_lt, f32x4_max, - f32x4_mul, f32x4_splat, f32x4_sub, i32x4, i32x4_add, i32x4_eq, i32x4_ge, i32x4_gt, i32x4_le, - i32x4_lt, i32x4_max, i32x4_min, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub, + f32x4_min, f32x4_mul, f32x4_splat, f32x4_sub, i32x4, i32x4_add, i32x4_eq, i32x4_ge, i32x4_gt, + i32x4_le, i32x4_lt, i32x4_max, i32x4_min, i32x4_shl, i32x4_shuffle, i32x4_splat, i32x4_sub, i32x4_trunc_sat_f32x4, v128, v128_and, v128_bitselect, v128_load, v128_store, }; @@ -230,6 +230,11 @@ impl SimdFloat for v128f { Self(f32x4_max(self.0, rhs.0)) } + #[inline] + unsafe fn min(self, rhs: Self) -> Self { + Self(f32x4_min(self.0, rhs.0)) + } + #[inline] unsafe fn gather_mask(src: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self { super::simd_gather_mask::<_, _, _, { Self::LEN }>(src, offsets, mask) diff --git a/rten-simd/src/arch/x86_64.rs b/rten-simd/src/arch/x86_64.rs index 78470fce..af3bccc8 100644 --- a/rten-simd/src/arch/x86_64.rs +++ b/rten-simd/src/arch/x86_64.rs @@ -3,10 +3,10 @@ use std::arch::x86_64::{ _mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi256_ps, _mm256_cmp_ps, _mm256_cmpeq_epi32, _mm256_cmpgt_epi32, _mm256_cvttps_epi32, _mm256_div_ps, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_loadu_si256, _mm256_max_epi32, - _mm256_max_ps, _mm256_min_epi32, _mm256_mul_ps, _mm256_or_si256, _mm256_set1_epi32, - _mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256, - _mm256_sub_epi32, _mm256_sub_ps, _mm_add_ps, _mm_cvtss_f32, _mm_movehl_ps, _mm_prefetch, - _mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0, + _mm256_max_ps, _mm256_min_epi32, _mm256_min_ps, _mm256_mul_ps, _mm256_or_si256, + _mm256_set1_epi32, _mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps, + _mm256_storeu_si256, _mm256_sub_epi32, _mm256_sub_ps, _mm_add_ps, _mm_cvtss_f32, _mm_movehl_ps, + _mm_prefetch, _mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0, }; use std::mem::transmute; @@ -285,6 +285,12 @@ impl SimdFloat for __m256 { _mm256_max_ps(self, rhs) } + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn min(self, rhs: Self) -> Self { + _mm256_min_ps(self, rhs) + } + #[inline] #[target_feature(enable = "avx2")] unsafe fn gather_mask(src: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self { @@ -322,8 +328,8 @@ use std::arch::x86_64::{ _mm512_castsi512_ps, _mm512_cmp_epi32_mask, _mm512_cmp_ps_mask, _mm512_cvttps_epi32, _mm512_div_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_loadu_si512, _mm512_mask_blend_epi32, _mm512_mask_blend_ps, _mm512_mask_i32gather_ps, _mm512_max_epi32, _mm512_max_ps, - _mm512_min_epi32, _mm512_mul_ps, _mm512_reduce_add_ps, _mm512_set1_epi32, _mm512_set1_ps, - _mm512_setzero_si512, _mm512_sllv_epi32, _mm512_storeu_epi32, _mm512_storeu_ps, + _mm512_min_epi32, _mm512_min_ps, _mm512_mul_ps, _mm512_reduce_add_ps, _mm512_set1_epi32, + _mm512_set1_ps, _mm512_setzero_si512, _mm512_sllv_epi32, _mm512_storeu_epi32, _mm512_storeu_ps, _mm512_sub_epi32, _mm512_sub_ps, _MM_CMPINT_EQ, _MM_CMPINT_LE, _MM_CMPINT_LT, }; @@ -596,6 +602,12 @@ impl SimdFloat for __m512 { _mm512_max_ps(self, rhs) } + #[inline] + #[target_feature(enable = "avx512f")] + unsafe fn mix(self, rhs: Self) -> Self { + _mm512_min_ps(self, rhs) + } + #[inline] #[target_feature(enable = "avx512f")] unsafe fn gather_mask(ptr: *const f32, offsets: Self::Int, mask: Self::Mask) -> Self { diff --git a/rten-simd/src/vec.rs b/rten-simd/src/vec.rs index cacef630..89aad964 100644 --- a/rten-simd/src/vec.rs +++ b/rten-simd/src/vec.rs @@ -68,7 +68,10 @@ pub trait Simd: Copy + Sized { /// This type should always be `[Self::ELEM; Self::LEN]`. The `to_array` /// method returns this associated type rather than a concrete array due to /// const generics limitations. - type Array: Copy + std::fmt::Debug + std::ops::Index; + type Array: Copy + + std::fmt::Debug + + std::ops::Index + + AsRef<[Self::Elem]>; /// Combine elements of `self` and `rhs` according to a mask. /// @@ -275,6 +278,9 @@ pub trait SimdFloat: Simd { /// Compute a mask containing `self < rhs`. unsafe fn lt(self, rhs: Self) -> Self::Mask; + /// Compute the minimum of `self` and `rhs`. + unsafe fn min(self, rhs: Self) -> Self; + /// Compute the maximum of `self` and `rhs`. unsafe fn max(self, rhs: Self) -> Self; From c5461469e4997e92ff12c87a08cdd010b278e195 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 22:50:59 +0000 Subject: [PATCH 3/5] Add `simd_fold_array` helper for vectorized ops This is useful for reductions which need to compute multiple values in one pass over the data. --- rten-simd/src/functional.rs | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/rten-simd/src/functional.rs b/rten-simd/src/functional.rs index cd1503ae..54933d66 100644 --- a/rten-simd/src/functional.rs +++ b/rten-simd/src/functional.rs @@ -82,3 +82,40 @@ pub unsafe fn simd_fold S>( accum } + +/// A variant of [`simd_fold`] where the accumulator is an array of values +/// instead of just one. +/// +/// # Safety +/// +/// The caller must ensure that `S` is a supported SIMD vector type on the +/// current system. +#[inline(always)] +pub unsafe fn simd_fold_array [S; N]>( + xs: PtrLen, + mut accum: [S; N], + simd_op: Op, +) -> [S; N] { + let mut n = xs.len(); + let mut x_ptr = xs.ptr(); + + while n >= S::LEN { + let x = S::load(x_ptr); + accum = simd_op(accum, x); + n -= S::LEN; + x_ptr = x_ptr.add(S::LEN); + } + + let n_mask = S::Mask::first_n(n); + if n > 0 { + let x = S::load_partial(x_ptr, n); + let prev_accum = accum; + let new_accum = simd_op(accum, x); + + for i in 0..N { + accum[i] = prev_accum[i].blend(new_accum[i], n_mask); + } + } + + accum +} From 395d57c040e7bd3f9aae8f3d874174486dd0b53d Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 22:51:45 +0000 Subject: [PATCH 4/5] Add `MinMax` vectorized operation This allows for computing the minimum and maximum values in a slice of floats with one pass over the slice. --- rten-vecmath/src/lib.rs | 2 ++ rten-vecmath/src/min_max.rs | 63 +++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 rten-vecmath/src/min_max.rs diff --git a/rten-vecmath/src/lib.rs b/rten-vecmath/src/lib.rs index 0b0b6e20..48b12097 100644 --- a/rten-vecmath/src/lib.rs +++ b/rten-vecmath/src/lib.rs @@ -83,6 +83,7 @@ mod erf; mod exp; +mod min_max; mod normalize; mod softmax; mod sum; @@ -100,6 +101,7 @@ pub use exp::{Exp, Sigmoid, Silu, Swish}; pub use tanh::Tanh; // Normalization and reduction functions. +pub use min_max::MinMax; pub use normalize::{Normalize, NormalizeOptions}; pub use softmax::Softmax; pub use sum::{Sum, SumSquare, SumSquareSub}; diff --git a/rten-vecmath/src/min_max.rs b/rten-vecmath/src/min_max.rs new file mode 100644 index 00000000..97721473 --- /dev/null +++ b/rten-vecmath/src/min_max.rs @@ -0,0 +1,63 @@ +use rten_simd::dispatch::SimdOp; +use rten_simd::functional::simd_fold_array; +use rten_simd::SimdFloat; + +/// Compute the minimum and maximum values in a slice of floats. +pub struct MinMax<'a> { + input: &'a [f32], +} + +impl<'a> MinMax<'a> { + pub fn new(input: &'a [f32]) -> Self { + MinMax { input } + } +} + +impl SimdOp for MinMax<'_> { + type Output = (f32, f32); + + #[inline(always)] + unsafe fn eval(self) -> Self::Output { + let [vec_min, vec_max] = simd_fold_array( + self.input.into(), + [S::splat(f32::MAX), S::splat(f32::MIN)], + #[inline(always)] + |[min, max], x| [x.min(min), x.max(max)], + ); + let min = vec_min + .to_array() + .as_ref() + .iter() + .fold(f32::MAX, |min, x| x.min(min)); + let max = vec_max + .to_array() + .as_ref() + .iter() + .fold(f32::MIN, |max, x| x.max(max)); + (min, max) + } +} + +#[cfg(test)] +mod tests { + use super::MinMax; + use rten_simd::dispatch::SimdOp; + + // Chosen to not be a multiple of vector size, so that tail handling is + // exercised. + const LEN: usize = 100; + + fn reference_min_max(xs: &[f32]) -> (f32, f32) { + let min = xs.iter().fold(f32::MAX, |min, x| x.min(min)); + let max = xs.iter().fold(f32::MIN, |max, x| x.max(max)); + (min, max) + } + + #[test] + fn test_min_max() { + let xs: Vec = (0..LEN).map(|i| i as f32 * 0.1).collect(); + let expected = reference_min_max(&xs); + let min_max = MinMax::new(&xs).dispatch(); + assert_eq!(min_max, expected); + } +} From b1a9613abd96cf41da359e9f2500a9a3229c222f Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 22:52:26 +0000 Subject: [PATCH 5/5] Optimize computation of input min/max in DynamicQuantizeLinear Previously two separate passes over the data were used to compute the min/max values. Use the `MinMax` op from rten-vecmath to compute this in one vectorized pass. In a benchmark with a quantized ModernBERT model this made DynamicQuantizeLinear 2.5-3x faster. --- src/ops/quantize.rs | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/ops/quantize.rs b/src/ops/quantize.rs index 4453ef35..c679dd8a 100644 --- a/src/ops/quantize.rs +++ b/src/ops/quantize.rs @@ -1,11 +1,12 @@ +use rten_simd::dispatch::SimdOp; use rten_tensor::prelude::*; use rten_tensor::{NdTensor, NdTensorView, Scalar, Tensor, TensorView}; +use rten_vecmath as vecmath; use crate::ops::{ - reduce_max, reduce_min, resolve_axis, DataType, Input, InputList, IntoOpResult, OpError, - Operator, Output, OutputList, + resolve_axis, DataType, Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList, }; -use crate::tensor_pool::{AutoReturn, TensorPool}; +use crate::tensor_pool::TensorPool; /// Convert a quantized tensor element to a higher precision value. pub trait Dequantize { @@ -309,19 +310,9 @@ where let q_min = 0.; let q_max = 255.; - // Get the range of the input. This implementation is simple but sub-optimal - // as it makes two passes over the same data to get the min/max. - let x_min = reduce_min(pool, input.view(), None, false /* keep_dims */)? - .auto_return(pool) - .item() - .copied() - .unwrap(); + let input = input.to_contiguous_in(pool); + let (x_min, x_max) = vecmath::MinMax::new(input.data().unwrap()).dispatch(); let x_min_adjusted = x_min.min(q_min); - let x_max = reduce_max(pool, input.view(), None, false /* keep_dims */)? - .auto_return(pool) - .item() - .copied() - .unwrap(); let x_max_adjusted = x_max.max(q_min); let x_range = x_max_adjusted - x_min_adjusted; let scale = x_range / q_max; @@ -335,7 +326,7 @@ where let zero_point_tensor = Tensor::from(zero_point); let quantized = quantize_linear( pool, - input, + input.view(), scale_tensor.view(), Some(zero_point_tensor.view()), 1, /* axis */