From 947f24225d1d98c76baf28960f6c1c748d0d6aa3 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Thu, 9 Jan 2025 08:25:26 +0000 Subject: [PATCH 1/7] Impl `Identities` for i8, u8 types --- src/number.rs | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/src/number.rs b/src/number.rs index 5cace137..5e9fe731 100644 --- a/src/number.rs +++ b/src/number.rs @@ -43,25 +43,39 @@ pub trait Identities { fn zero() -> Self; } -impl Identities for f32 { - fn one() -> f32 { - 1. - } +macro_rules! impl_float_identities { + ($type:ty) => { + impl Identities for $type { + fn one() -> Self { + 1. + } - fn zero() -> f32 { - 0. - } + fn zero() -> Self { + 0. + } + } + }; } -impl Identities for i32 { - fn one() -> i32 { - 1 - } - fn zero() -> i32 { - 0 - } +macro_rules! impl_int_identities { + ($type:ty) => { + impl Identities for $type { + fn one() -> Self { + 1 + } + + fn zero() -> Self { + 0 + } + } + }; } +impl_float_identities!(f32); +impl_int_identities!(i32); +impl_int_identities!(i8); +impl_int_identities!(u8); + /// Test if a number is a float NaN ("Not a number") value. pub trait IsNaN { /// Return true if the current value is a NaN. See [`f32::is_nan`]. From d69d9167b6193d48cd866528137592fa0cd6db4a Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Thu, 9 Jan 2025 08:26:11 +0000 Subject: [PATCH 2/7] Impl `GemmInT` for i8, u8 and `GemmOutT` for i32 --- src/gemm.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gemm.rs b/src/gemm.rs index 9961eb98..d3d1c4f6 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -174,6 +174,8 @@ impl GemmInputA<'_, T> { /// Trait implemented by GEMM input types. pub trait GemmInT: Copy + Default + Send + Sync + Identities + Pod {} +impl GemmInT for i8 {} +impl GemmInT for u8 {} impl GemmInT for f32 {} /// Trait implemented by GEMM output types. @@ -188,6 +190,7 @@ pub trait GemmOutT: + Pod { } +impl GemmOutT for i32 {} impl GemmOutT for f32 {} /// Right-hand or "B" input for a GEMM operation. From 3c68f6ff390b1c57f01bcd887c693348888c6e12 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 08:13:02 +0000 Subject: [PATCH 3/7] Add type annotations to disambiguate which GEMM kernel to use These tests will run into type errors when new `GemmExecutor` impls are added. --- src/gemm.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index d3d1c4f6..b433c17d 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -1564,7 +1564,7 @@ mod tests { #[test] fn test_gemm_transposed() -> Result<(), Box> { let mut rng = XorShiftRng::new(1234); - let mut a = NdTensor::rand([20, 30], &mut rng); + let mut a = NdTensor::::rand([20, 30], &mut rng); let mut b = NdTensor::rand([10, 20], &mut rng); // Transpose the input matrices. 
This will alter their row and column @@ -2144,7 +2144,7 @@ mod tests { let mut rng = XorShiftRng::new(1234); let mut result = NdTensor::zeros([m, n]); - let a = NdTensor::rand([m, k], &mut rng); + let a = NdTensor::::rand([m, k], &mut rng); let b = if transpose_b { let mut b = NdTensor::rand([n, k], &mut rng); b.transpose(); From 0a6b95707b77e495f1d03bf47bd44ce5cecdbbf2 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Thu, 9 Jan 2025 08:26:52 +0000 Subject: [PATCH 4/7] Refactor `GemmExecutor` construction to make supporting new types easier --- src/gemm.rs | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index b433c17d..6f95d2e3 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -506,6 +506,13 @@ impl GemmExecutor + 'static>(kernel_type: KernelType) -> Option { + K::new().map(|kernel| GemmExecutor { + kernel: Box::new(kernel), + kernel_type, + }) + } } impl GemmExecutor { @@ -536,37 +543,24 @@ impl GemmExecutor { /// kernel is not supported. #[allow(dead_code)] // Currently only used in tests pub fn with_kernel(hint: KernelType) -> Option { - fn make_kernel + 'static>( - kernel_type: KernelType, - ) -> Option { - K::new().map(|kernel| GemmExecutor { - kernel: Box::new(kernel), - kernel_type, - }) - } - match hint { #[cfg(feature = "avx512")] #[cfg(target_arch = "x86_64")] - KernelType::Avx512 => make_kernel::(hint), + KernelType::Avx512 => Self::from_kernel::(hint), #[cfg(target_arch = "x86_64")] - KernelType::Fma => make_kernel::(hint), + KernelType::Fma => Self::from_kernel::(hint), #[cfg(target_arch = "aarch64")] - KernelType::ArmNeon => make_kernel::(hint), + KernelType::ArmNeon => Self::from_kernel::(hint), #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] - KernelType::Wasm => make_kernel::(hint), + KernelType::Wasm => Self::from_kernel::(hint), KernelType::Generic => Some(Self::with_generic_kernel()), } } /// Construct a GemmExecutor that uses the generic kernel. fn with_generic_kernel() -> Self { - let kernel = GenericKernel::new().unwrap(); - GemmExecutor { - kernel: Box::new(kernel), - kernel_type: KernelType::Generic, - } + Self::from_kernel::(KernelType::Generic).unwrap() } } From d1cb76dc4508a7cea9ffc8cfc9193c57d4666dbc Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Wed, 8 Jan 2025 07:52:50 +0000 Subject: [PATCH 5/7] Support quantization params in GEMM kernels, add basic i8 x u8 -> i32 kernel Add arguments to `Kernel` trait methods to support passing zero point to packing and kernel methods. Use this to implement a generic (and slow) i8 x u8 -> i32 kernel. Convert the MatMulInteger operator to use `GemmExecutor` for performing matmuls using this new kernel. Several other parts of the code needed to have generic arguments added to disambiguate which kind of `GemmExecutor` to instantiate. 
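
For reference, the accumulation the new generic kernel performs for each
output element is the zero-point-adjusted dot product sketched below. This is
an illustrative reference only (the function name is hypothetical); the real
kernel in src/gemm/kernels/generic.rs operates on packed panels and
accumulates into a temporary tile for partial edge tiles:

    /// Reference computation for one element of a u8 x i8 -> i32 matmul with
    /// a per-row LHS zero point (`a_zero`) and a per-column RHS zero point
    /// (`b_zero`).
    fn quantized_dot(a_row: &[u8], b_col: &[i8], a_zero: u8, b_zero: i8) -> i32 {
        a_row
            .iter()
            .zip(b_col)
            .map(|(&a, &b)| (a as i32 - a_zero as i32) * (b as i32 - b_zero as i32))
            .sum()
    }

Subtracting the zero points before the multiply keeps the accumulation exact
in i32 and avoids a separate correction pass over row and column sums.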
--- src/gemm.rs | 80 +++++++++++-- src/gemm/kernels.rs | 38 +++++- src/gemm/kernels/aarch64.rs | 23 +++- src/gemm/kernels/generic.rs | 227 +++++++++++++++++++++++++++++++++++- src/gemm/kernels/wasm.rs | 23 +++- src/gemm/kernels/x86_64.rs | 44 ++++++- src/ops/conv.rs | 12 +- src/ops/matmul.rs | 185 ++++++++++++++--------------- src/ops/rnn.rs | 8 ++ 9 files changed, 515 insertions(+), 125 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index 6f95d2e3..d507343c 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -27,6 +27,7 @@ pub use errors::GemmError; pub use im2col::{ColOffsets, Im2Col, RowOffsets}; use kernels::generic::GenericKernel; use kernels::Kernel; +pub use kernels::QuantParams; use packing::{PackElem, PackingBuffer}; pub type GemmResult = Result<(), GemmError>; @@ -263,6 +264,8 @@ where alpha, beta, None, // bias + None, // a_quant + None, // b_quant ) } @@ -330,11 +333,11 @@ impl GemmExecutor(&self, alloc: A, a: Matrix) -> PackedAMatrix { let depth_block = depth_block_size(a.cols()); - let layout = self.kernel.packed_a_layout(a, a.rows(), depth_block); + let layout = self.kernel.packed_a_layout(a, a.rows(), depth_block, None); let tail_layout = if a.cols() % depth_block != 0 { Some( self.kernel - .packed_a_layout(a, a.rows(), a.cols() % depth_block), + .packed_a_layout(a, a.rows(), a.cols() % depth_block, None), ) } else { None @@ -356,7 +359,7 @@ impl GemmExecutor GemmExecutor(&self, alloc: A, b: Matrix) -> PackedBMatrix { let depth_block = depth_block_size(b.rows()); - let layout = self.kernel.packed_b_layout(depth_block, b.cols()); + let layout = self.kernel.packed_b_layout(depth_block, b.cols(), None); let tail_layout = if b.rows() % depth_block != 0 { Some( self.kernel - .packed_b_layout(b.rows() % depth_block, b.cols()), + .packed_b_layout(b.rows() % depth_block, b.cols(), None), ) } else { None @@ -425,7 +428,7 @@ impl GemmExecutor GemmExecutor>, + a_quant: Option>, + b_quant: Option>, ) -> GemmResult { gemm_impl( &*self.kernel, @@ -479,6 +484,8 @@ impl GemmExecutor GemmExecutor, alpha: f32, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) -> GemmResult { gemm_impl( &*self.kernel, @@ -504,6 +513,8 @@ impl GemmExecutor { } } +impl Default for GemmExecutor { + fn default() -> Self { + Self::from_kernel::(KernelType::Generic).unwrap() + } +} + /// Return the block size for the K / depth dimension of a GEMM operation. /// /// This is chosen such that a `depth_block_size * nr` panel of B fits in the L1 @@ -701,6 +718,8 @@ fn gemv( alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) { assert!(output_mat.is_contiguous()); @@ -733,7 +752,15 @@ fn gemv( range_chunks(0..a_cols, k_block_size).zip(a_data.chunks(k_block_size)) { let b_block = b.slice((k_block, col_block.clone())); - kernel.gemv_kernel(out_chunk, a_block, b_block, alpha, effective_beta); + kernel.gemv_kernel( + out_chunk, + a_block, + b_block, + alpha, + effective_beta, + a_quant, + b_quant, + ); // Reset `beta` so that subsequent updates for each column // accumulate into the first update. @@ -797,6 +824,8 @@ fn gemm_impl( alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) -> GemmResult { if a.cols() != b.rows() { return Err(GemmError::KSizeMismatch); @@ -870,6 +899,8 @@ fn gemm_impl( // nb. We checked above that, if present, the bias length matches // `a.rows()` or `b.cols()` as appropriate. 
bias, + a_quant, + b_quant, ); return Ok(()); } @@ -936,7 +967,8 @@ fn gemm_impl( GemmInputB::Unpacked(_) | GemmInputB::Im2Col(_) => PACKED_B.with(|cell| { let mut packed_b = cell.take(); - let layout = kernel.packed_b_layout(depth_range.len(), col_end - col_start); + let layout = + kernel.packed_b_layout(depth_range.len(), col_end - col_start, b_quant); let packed_uninit = packed_b.alloc(layout.size(), layout.align()); match b { @@ -945,6 +977,7 @@ fn gemm_impl( b, depth_range.clone(), col_start..col_end, + b_quant, ), GemmInputB::Im2Col(im) => kernel.pack_im2col( packed_uninit, @@ -995,6 +1028,7 @@ fn gemm_impl( a, row_end - row_start, depth_range.len(), + a_quant, ); if !layout.must_pack { return LhsBlock::Unpacked(a); @@ -1008,6 +1042,7 @@ fn gemm_impl( a, row_start..row_end, depth_range.clone(), + a_quant, ); // Safety: We initialized `layout.size` bytes. @@ -1034,6 +1069,8 @@ fn gemm_impl( alpha, effective_beta, bias, + a_quant, + b_quant, ); if let Some(packed_a) = thread_local_packed_a { @@ -1094,6 +1131,8 @@ fn gemm_block( alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) { let (mr, nr) = (kernel.mr(), kernel.nr()); @@ -1111,6 +1150,12 @@ fn gemm_block( .for_each(|(block_col_tile, col_tile)| { let b_panel_offset = block_col_tile * b.panel_stride; let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_stride]; + let b_quant_tile = b_quant.map(|bq| { + let col_range = col_tile * nr..(col_tile * nr + nr).min(bq.zero_point.len()); + QuantParams { + zero_point: &bq.zero_point[col_range], + } + }); // Loop over row tiles. for (block_row_tile, row_tile) in row_tiles.clone().enumerate() { @@ -1119,6 +1164,13 @@ fn gemm_block( // every output tile is processed by one thread at a time. let out_tile = unsafe { output.tile(row_tile, col_tile) }; + let a_quant_tile = a_quant.map(|aq| { + let row_range = row_tile * mr..(row_tile * mr + mr).min(aq.zero_point.len()); + QuantParams { + zero_point: &aq.zero_point[row_range], + } + }); + let kernel_lhs = match a { LhsBlock::Packed { data, @@ -1157,6 +1209,8 @@ fn gemm_block( depth_range.len(), alpha, beta, + a_quant_tile, + b_quant_tile, ); } @@ -1355,6 +1409,8 @@ mod tests { alpha, beta, bias, + None, // a_quant + None, // b_quant ) .unwrap(); } @@ -1449,6 +1505,8 @@ mod tests { 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ); assert_eq!(result, Err(expected)); } @@ -1782,6 +1840,8 @@ mod tests { 1., // alpha 1., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -1799,6 +1859,8 @@ mod tests { 1., // alpha 1., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -1882,6 +1944,8 @@ mod tests { 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); diff --git a/src/gemm/kernels.rs b/src/gemm/kernels.rs index 79d4cb0c..43e81d1b 100644 --- a/src/gemm/kernels.rs +++ b/src/gemm/kernels.rs @@ -87,6 +87,23 @@ impl PackedLayout { } } +/// Parameters required to perform matrix multiplication on quantized tensors. +#[derive(Debug)] +pub struct QuantParams<'a, T> { + /// Values that correspond to zero in each row (for LHS inputs) or column + /// (for RHS inputs). + pub zero_point: &'a [T], +} + +// Make QuantParams Copy/Clone regardless of `T`. +impl Clone for QuantParams<'_, T> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for QuantParams<'_, T> {} + /// Kernel that computes a small tile of a general matrix multiplication (GEMM) /// or general matrix-vector multiplication (GEMV). 
/// @@ -132,7 +149,13 @@ pub unsafe trait Kernel: Sync { fn name(&self) -> &'static str; /// Return the layout of a packing buffer required to pack an A / LHS input. - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout; + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + quant: Option>, + ) -> PackedLayout; /// Pack a block of the LHS / "A" input for use by this kernel. fn pack_a_block( @@ -141,6 +164,7 @@ pub unsafe trait Kernel: Sync { a: Matrix, rows: Range, cols: Range, + quant: Option>, ); /// Return the layout of a packing buffer required to pack a block of a "B" @@ -149,7 +173,12 @@ pub unsafe trait Kernel: Sync { /// Unlike `packed_a_layout` this doesn't take the matrix as an argument. /// `packed_a_layout` may use this to indicate that the A input does not /// need to be packed. For the B input it is assumed this is always packed. - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout; + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + quant: Option>, + ) -> PackedLayout; /// Pack a block of the RHS / "B" input for use by this kernel. fn pack_b_block( @@ -158,6 +187,7 @@ pub unsafe trait Kernel: Sync { b: Matrix, rows: Range, cols: Range, + quant: Option>, ); /// Pack a block of an image as the B input for use by this kernel, using @@ -200,6 +230,8 @@ pub unsafe trait Kernel: Sync { depth: usize, alpha: f32, beta: OutT, + a_quant: Option>, + b_quant: Option>, ); /// Compute an output block of a vector-matrix product ("gemv"). @@ -226,6 +258,8 @@ pub unsafe trait Kernel: Sync { b: Matrix, alpha: f32, beta: OutT, + a_quant: Option>, + b_quant: Option>, ); } diff --git a/src/gemm/kernels/aarch64.rs b/src/gemm/kernels/aarch64.rs index a9184dec..6facf613 100644 --- a/src/gemm/kernels/aarch64.rs +++ b/src/gemm/kernels/aarch64.rs @@ -6,7 +6,7 @@ use rten_simd::vec_count; use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -39,7 +39,13 @@ unsafe impl Kernel for ArmNeonKernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -51,12 +57,18 @@ unsafe impl Kernel for ArmNeonKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); pack_a_block::(out, a, rows, cols); } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -66,6 +78,7 @@ unsafe impl Kernel for ArmNeonKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); pack_b_block::(out, b, rows, cols); @@ -98,6 +111,8 @@ unsafe impl Kernel for ArmNeonKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = ArmNeonKernel::MR; const NR: usize = ArmNeonKernel::NR; @@ -152,6 +167,8 @@ unsafe impl 
Kernel for ArmNeonKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { // Safety - float32x4_t is supported if this kernel was constructed. unsafe { diff --git a/src/gemm/kernels/generic.rs b/src/gemm/kernels/generic.rs index d11fc477..a32c5dbf 100644 --- a/src/gemm/kernels/generic.rs +++ b/src/gemm/kernels/generic.rs @@ -5,7 +5,7 @@ use rten_simd::vec_count; use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -44,7 +44,13 @@ unsafe impl Kernel for GenericKernel { "base" } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -56,12 +62,18 @@ unsafe impl Kernel for GenericKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_a_block::(out, a, rows, cols); } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -71,6 +83,7 @@ unsafe impl Kernel for GenericKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_b_block::(out, b, rows, cols); @@ -103,6 +116,8 @@ unsafe impl Kernel for GenericKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = GenericKernel::MR; const NR: usize = GenericKernel::NR; @@ -156,6 +171,8 @@ unsafe impl Kernel for GenericKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { // Safety - f32 "SIMD" type is always supported unsafe { @@ -163,3 +180,207 @@ unsafe impl Kernel for GenericKernel { } } } + +unsafe impl Kernel for GenericKernel { + fn new() -> Option { + Some(GenericKernel { _private: () }) + } + + fn mr(&self) -> usize { + Self::MR + } + + fn nr(&self) -> usize { + Self::NR + } + + fn name(&self) -> &'static str { + "generic-i8" + } + + fn packed_a_layout( + &self, + _a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { + let mut info = packed_a_layout::(rows, cols); + info.must_pack = true; + info + } + + fn pack_a_block( + &self, + out: &mut [MaybeUninit], + a: Matrix, + rows: Range, + cols: Range, + _quant: Option>, + ) { + let out = cast_pod_mut_slice(out).unwrap(); + pack_a_block::(out, a, rows, cols); + } + + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { + packed_b_layout::(rows, cols) + } + + fn pack_b_block( + &self, + out: &mut [MaybeUninit], + b: Matrix, + rows: Range, + cols: Range, + _quant: Option>, + ) { + let out = cast_pod_mut_slice(out).unwrap(); + pack_b_block::(out, b, rows, cols); + } + + fn pack_im2col( + &self, + _out: &mut [MaybeUninit], + _image: &Im2Col, + _rows: Range, + _cols: Range, + ) { + unimplemented!("im2col packing not implemented"); + } + + unsafe fn kernel( + &self, + tile_ptr: *mut i32, + tile_row_stride: usize, + a: Lhs, + b: &[u8], + used_rows: usize, + used_cols: usize, + depth: 
usize, + alpha: f32, + beta: i32, + a_quant: Option>, + b_quant: Option>, + ) { + assert_eq!(alpha, 1.); + assert!(beta == 0 || beta == 1, "unsupported beta value"); + assert!(used_rows <= MR); + assert!(used_cols <= NR); + + const MR: usize = GenericKernel::MR; + const NR: usize = GenericKernel::NR; + + let a_data = match a { + Lhs::Packed(packed) => packed, + Lhs::Unpacked { .. } => panic!("inputs must be packed"), + }; + let a_row_stride = depth; + + let mut a_zero_point = [0u8; MR]; + if let Some(a_quant) = a_quant { + #[allow(clippy::manual_memcpy)] + for row in 0..used_rows { + a_zero_point[row] = a_quant.zero_point[row]; + } + } + let mut b_zero_point = [0i8; NR]; + if let Some(b_quant) = b_quant { + #[allow(clippy::manual_memcpy)] + for col in 0..used_cols { + b_zero_point[col] = b_quant.zero_point[col]; + } + } + + let b: &[i8] = cast_pod_slice(b).unwrap(); + let use_tmp_tile = used_cols < NR || used_rows < MR; + + let mut tmp_tile = TempTile::::new(); + let (dest_ptr, dest_row_stride, dest_beta) = if !use_tmp_tile { + (tile_ptr, tile_row_stride, beta) + } else { + (tmp_tile.as_mut_ptr() as *mut i32, NR, 0) + }; + + let mut tmp = [[0i32; NR]; MR]; + for k in 0..depth { + for row in 0..MR { + let a_i32 = unsafe { *a_data.get_unchecked(row * a_row_stride + k) } as i32 + - a_zero_point[row] as i32; + for col in 0..NR { + let b_i32 = + unsafe { *b.get_unchecked(k * NR + col) } as i32 - b_zero_point[col] as i32; + tmp[row][col] += a_i32 * b_i32; + } + } + } + + if dest_beta == 0 { + for row in 0..used_rows { + for col in 0..used_cols { + dest_ptr + .add(row * dest_row_stride + col) + .write(tmp[row][col]); + } + } + } else { + // nb. We require that beta is 0 or 1, so here it is 1. + for row in 0..used_rows { + for col in 0..used_cols { + *dest_ptr.add(row * dest_row_stride + col) += tmp[row][col]; + } + } + } + + if use_tmp_tile { + tmp_tile.accumulate_into( + tile_ptr as *mut MaybeUninit, + used_rows, + used_cols, + tile_row_stride, + beta, + ); + } + } + + fn gemv_kernel( + &self, + out: &mut [MaybeUninit], + a: &[u8], + b: Matrix, + alpha: f32, + beta: i32, + a_quant: Option>, + b_quant: Option>, + ) { + assert!(beta == 0 || beta == 1); + assert_eq!(alpha, 1.); + assert_eq!(b.rows(), a.len()); + assert_eq!(out.len(), b.cols()); + + let a_zero = a_quant.map(|aq| aq.zero_point[0] as i32).unwrap_or(0); + let depth = a.len(); + + for (out, col) in out.iter_mut().zip(0..b.cols()) { + let b_zero = b_quant.map(|bq| bq.zero_point[col] as i32).unwrap_or(0); + let mut acc = 0; + for k in 0..depth { + let a_el = unsafe { *a.get_unchecked(k) } as i32 - a_zero; + let b_el = unsafe { *b.get_unchecked([k, col]) } as i32 - b_zero; + acc += a_el * b_el; + } + if beta == 0 { + out.write(acc); + } else { + // Safety: Output is initialized when beta is non-zero + unsafe { + out.write(out.assume_init() + acc); + } + } + } + } +} diff --git a/src/gemm/kernels/wasm.rs b/src/gemm/kernels/wasm.rs index cf9d2919..50613a63 100644 --- a/src/gemm/kernels/wasm.rs +++ b/src/gemm/kernels/wasm.rs @@ -6,7 +6,7 @@ use rten_simd::vec_count; use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -43,7 +43,13 @@ unsafe impl Kernel for WasmKernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, 
rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -55,12 +61,18 @@ unsafe impl Kernel for WasmKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_a_block::(out, a, rows, cols); } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -70,6 +82,7 @@ unsafe impl Kernel for WasmKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_b_block::(out, b, rows, cols); @@ -102,6 +115,8 @@ unsafe impl Kernel for WasmKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = WasmKernel::MR; const NR: usize = WasmKernel::NR; @@ -155,6 +170,8 @@ unsafe impl Kernel for WasmKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { // Safety - WASM SIMD types are supported if this kernel was constructed. unsafe { diff --git a/src/gemm/kernels/x86_64.rs b/src/gemm/kernels/x86_64.rs index c5045aaf..d0b31617 100644 --- a/src/gemm/kernels/x86_64.rs +++ b/src/gemm/kernels/x86_64.rs @@ -12,7 +12,7 @@ use rten_tensor::{Matrix, MatrixLayout}; use rten_simd::isa_detection::is_avx512_supported; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -84,7 +84,13 @@ unsafe impl Kernel for FmaKernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -96,6 +102,7 @@ unsafe impl Kernel for FmaKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); @@ -105,7 +112,12 @@ unsafe impl Kernel for FmaKernel { } } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -115,6 +127,7 @@ unsafe impl Kernel for FmaKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); @@ -153,6 +166,8 @@ unsafe impl Kernel for FmaKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = FmaKernel::MR; const NR: usize = FmaKernel::NR; @@ -206,6 +221,8 @@ unsafe impl Kernel for FmaKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] @@ -275,7 +292,13 @@ unsafe impl Kernel for Avx512Kernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = 
packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -287,6 +310,7 @@ unsafe impl Kernel for Avx512Kernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); @@ -296,7 +320,12 @@ unsafe impl Kernel for Avx512Kernel { } } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -306,6 +335,7 @@ unsafe impl Kernel for Avx512Kernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); @@ -344,6 +374,8 @@ unsafe impl Kernel for Avx512Kernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = Avx512Kernel::MR; const NR: usize = Avx512Kernel::NR; @@ -397,6 +429,8 @@ unsafe impl Kernel for Avx512Kernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { #[target_feature(enable = "avx512f")] #[target_feature(enable = "avx512vl")] diff --git a/src/ops/conv.rs b/src/ops/conv.rs index 00772cdd..2620c4ed 100644 --- a/src/ops/conv.rs +++ b/src/ops/conv.rs @@ -57,6 +57,8 @@ where GemmInputB::Unpacked(in_mat.view()), 1., // alpha bias_vec, + None, // a_quant + None, // b_quant ) .unwrap(); n_init += out_item.len(); @@ -287,6 +289,8 @@ where GemmInputB::Im2Col(&im2col), 1., // alpha bias_vec, + None, // a_quant + None, // b_quant ) .unwrap(); n_init.fetch_add(out_mat.len(), Ordering::SeqCst); @@ -319,7 +323,7 @@ impl Operator for Conv { let input = inputs.require_as(0)?; let weight = inputs.require_as(1)?; let bias = inputs.get_as(2)?; - conv( + conv::( pool, input, weight, @@ -564,6 +568,8 @@ pub fn conv_transpose( GemmInputB::Unpacked(input_mat.view()), 1., // alpha None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -1623,8 +1629,8 @@ mod tests { // This has a small spatial shape, so it measures overhead around the // inner loop. A larger spatial shape would be more affected by the // efficiency of the innermost loops. 
- let input = Tensor::rand(&[1, 576, 14, 14], &mut rng); - let kernel = Tensor::rand(&[576, 1, 3, 3], &mut rng); + let input = Tensor::::rand(&[1, 576, 14, 14], &mut rng); + let kernel = Tensor::::rand(&[576, 1, 3, 3], &mut rng); let n_groups = input.size(1); let padding = Padding::Fixed([1, 1, 1, 1].into()); diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs index df204657..a94e1f77 100644 --- a/src/ops/matmul.rs +++ b/src/ops/matmul.rs @@ -5,9 +5,8 @@ use rten_tensor::{Matrix, NdTensorView, Tensor, TensorView}; use smallvec::SmallVec; use crate::gemm::{ - BiasVector, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, PackedBMatrix, + BiasVector, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, PackedBMatrix, QuantParams, }; -use crate::iter_util::range_chunks; use crate::ops::binary_elementwise::broadcast_shapes; use crate::ops::layout::expand_to; use crate::ops::{ @@ -60,6 +59,8 @@ where alpha, beta, None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); output @@ -74,6 +75,8 @@ where GemmInputB::Unpacked(b.nd_view()), alpha, None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); // Safety: `gemm_uninit` initialized all elements @@ -144,7 +147,17 @@ pub fn matmul( where GemmExecutor: Default, { - matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None, None) + matmul_impl( + pool, + a, + b, + packed_b, + MatmulStrategy::Auto, + None, + None, + None, /* a_quant */ + None, /* b_quant */ + ) } fn matmul_impl( @@ -155,6 +168,8 @@ fn matmul_impl( strategy: MatmulStrategy, bias: Option>, alpha: Option, + a_quant: Option>, + b_quant: Option>, ) -> Result, OpError> where GemmExecutor: Default, @@ -215,6 +230,8 @@ where strategy, bias, alpha, + a_quant, + b_quant, )?; output.reshape(out_shape); return Ok(output); @@ -288,6 +305,8 @@ where b_input, alpha.unwrap_or(1.), bias, + a_quant, + b_quant, ) .unwrap(); }); @@ -347,7 +366,17 @@ pub fn matmul_fused( where GemmExecutor: Default, { - matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, bias, alpha) + matmul_impl( + pool, + a, + b, + packed_b, + MatmulStrategy::Auto, + bias, + alpha, + None, /* a_quant */ + None, /* b_quant */ + ) } /// MatMul with fused addition of bias and scaling of result. @@ -400,42 +429,6 @@ pub fn matmul_integer( a_zero_point: Option>, b_zero_point: Option>, ) -> Result, OpError> { - if a.ndim() < 2 || b.ndim() < 2 { - return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions")); - } - - let a_rows = a.size(a.ndim() - 2); - let a_cols = a.size(a.ndim() - 1); - - let b_rows = b.size(b.ndim() - 2); - let b_cols = b.size(b.ndim() - 1); - - if a_cols != b_rows { - return Err(OpError::IncompatibleInputShapes( - "Columns of first matrix does not match rows of second matrix", - )); - } - - let a_prefix = &a.shape()[..a.ndim() - 2]; - let b_prefix = &b.shape()[..b.ndim() - 2]; - - let out_prefix = broadcast_shapes(a_prefix, b_prefix) - .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?; - let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat(); - - let mut output = Tensor::::uninit_in(pool, out_shape); - if output.is_empty() { - // nb. We don't need to alloc from the pool here, since the buffer - // is already empty. 
- return Ok(Tensor::zeros(out_shape)); - } - - let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat(); - let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat(); - - let a_broadcast = a.broadcast(a_broadcast_shape.as_slice()); - let b_broadcast = b.broadcast(b_broadcast_shape.as_slice()); - // Convert the zero point to a vector. // // The spec allows for the zero point to be a scalar, vector or a batch of @@ -460,58 +453,33 @@ pub fn matmul_integer( } } - let a_zero = zero_point_to_vec(a_zero_point, a_rows)?; - let b_zero = zero_point_to_vec(b_zero_point, b_cols)?; + if a.ndim() < 2 || b.ndim() < 2 { + return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions")); + } + let a_rows = a.size(a.ndim() - 2); + let b_cols = b.size(b.ndim() - 1); - a_broadcast - .inner_iter::<2>() - .zip(b_broadcast.inner_iter::<2>()) - .zip(output.inner_iter_mut::<2>()) - .par_bridge() - .for_each(|((a_mat, b_mat), mut out_mat)| { - let [m, k] = a_mat.shape(); - let [bk, n] = b_mat.shape(); - assert_eq!(k, bk); - assert_eq!(out_mat.shape(), [m, n]); - - // Do some extremely rudimentary cache blocking. - for col_block in range_chunks(0..n, 32) { - for depth_block in range_chunks(0..k, 32) { - for row_block in range_chunks(0..m, 32) { - for j in col_block.clone() { - let b_zero = b_zero.as_ref().map(|zp| zp[j]).unwrap_or(0) as i32; - - for i in row_block.clone() { - let a_zero = a_zero.as_ref().map(|zp| zp[i]).unwrap_or(0) as i32; - - let mut out = 0i32; - for k in depth_block.clone() { - // Safety: `[i, k]` is in-bounds for `a_mat`. - let a = unsafe { *a_mat.get_unchecked([i, k]) } as i32 - a_zero; - // Safety: `[k, j]` is in-bounds for `b_mat`. - let b = unsafe { *b_mat.get_unchecked([k, j]) } as i32 - b_zero; - out += a * b; - } - unsafe { - // Safety: `[i, j]` is in-bounds for `b_mat`. - let el = out_mat.get_unchecked_mut([i, j]); - if depth_block.start == 0 { - el.write(out); - } else { - el.write(el.assume_init() + out); - } - } - } - } - } - } - } - }); + let a_zero = zero_point_to_vec(a_zero_point, a_rows)?.map(|zp| zp.to_contiguous()); + let a_quant = a_zero.as_ref().map(|zp| QuantParams { + zero_point: zp.data().unwrap(), + }); - // Safety: Loop above initialized all output elements. 
- let output = unsafe { output.assume_init() }; + let b_zero = zero_point_to_vec(b_zero_point, b_cols)?.map(|zp| zp.to_contiguous()); + let b_quant = b_zero.as_ref().map(|zp| QuantParams { + zero_point: zp.data().unwrap(), + }); - Ok(output) + matmul_impl( + pool, + a, + b, + None, + MatmulStrategy::Auto, + None, + None, + a_quant, + b_quant, + ) } #[derive(Debug)] @@ -621,6 +589,8 @@ mod tests { alpha.unwrap_or(1.), 0., /* beta */ bias, + None, // a_quant + None, // b_quant ) .unwrap() }); @@ -1018,8 +988,8 @@ mod tests { } in cases { let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(a_shape, &mut rng); - let b = Tensor::rand(b_shape, &mut rng); + let a = Tensor::::rand(a_shape, &mut rng); + let b = Tensor::::rand(b_shape, &mut rng); let result = matmul(&pool, a.view(), b.view(), None); assert_eq!(result, Err(error)); @@ -1045,8 +1015,8 @@ mod tests { let pool = new_pool(); for Case { m, n, k } in cases { let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(&[m, k], &mut rng); - let b = Tensor::rand(&[k, n], &mut rng); + let a = Tensor::::rand(&[m, k], &mut rng); + let b = Tensor::::rand(&[k, n], &mut rng); let result = matmul(&pool, a.view(), b.view(), None).unwrap(); assert_eq!(result.shape(), &[m, n]); @@ -1091,6 +1061,15 @@ mod tests { b_zero_point: Some(Tensor::from([3, 4])), expected_err: None, }, + // A input which is a row vector + Case { + a: Tensor::from([[1, 2, 3, 4]]), + b: Tensor::from([[5, 6], [7, 8], [9, 10], [11, 12]]), + a_zero_point: Some(Tensor::from([1])), + b_zero_point: Some(Tensor::from([3, 4])), + expected_err: None, + }, + // Incorrect zero point size Case { a: Tensor::from([[1, 2], [3, 4]]), b: Tensor::from([[5, 6], [7, 8]]), @@ -1228,8 +1207,8 @@ mod tests { } in cases { let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(&[a_batch, a_rows, a_cols], &mut rng); - let b = Tensor::rand(&[a_cols, b_cols], &mut rng); + let a = Tensor::::rand(&[a_batch, a_rows, a_cols], &mut rng); + let b = Tensor::::rand(&[a_cols, b_cols], &mut rng); let run_trial = |strategy| { let trials = 10; @@ -1238,9 +1217,19 @@ mod tests { ); let pool = new_pool(); run_bench(trials, Some(&desc), || { - matmul_impl(&pool, a.view(), b.view(), None, strategy, None, None) - .unwrap() - .auto_return(&pool); + matmul_impl( + &pool, + a.view(), + b.view(), + None, + strategy, + None, + None, + None, + None, + ) + .unwrap() + .auto_return(&pool); }); }; diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index 3c9e7107..736950b9 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -241,6 +241,8 @@ pub fn gru( 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(input_bias) = input_bias { @@ -257,6 +259,8 @@ pub fn gru( 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(hidden_bias) = hidden_bias { @@ -498,6 +502,8 @@ pub fn lstm( 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(input_bias) = input_bias { @@ -512,6 +518,8 @@ pub fn lstm( 1., // alpha 1., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(hidden_bias) = hidden_bias { From 882615ce618626e557ee286ad3a670622ec874d2 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 08:18:02 +0000 Subject: [PATCH 6/7] Add u8 x i8 -> i32 GEMM tests --- src/gemm.rs | 120 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 105 insertions(+), 15 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index d507343c..f2be5395 
100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -1268,7 +1268,7 @@ mod tests { use super::{ BiasVector, ColOffsets, GemmError, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, - Im2Col, KernelType, RowOffsets, + Im2Col, KernelType, QuantParams, RowOffsets, }; /// Scale a possibly non-float value by a float. @@ -1293,7 +1293,14 @@ mod tests { /// Type that can be used as the output for the reference GEMM /// implementation. trait RefGemmOutT: - Default + GemmOutT + From + From + MulFloat + ApproxEq + std::fmt::Debug + Default + + GemmOutT + + From + + From + + MulFloat + + ApproxEq + + std::fmt::Debug + + std::ops::Sub { } @@ -1304,19 +1311,30 @@ mod tests { { } + impl RefGemmOutT for i32 + where + i32: From, + i32: From, + { + } + #[derive(Clone)] - struct GemmOpts<'a, OutT> { + struct GemmOpts<'a, LhsT, RhsT, OutT> { alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, } - impl Default for GemmOpts<'_, OutT> { + impl Default for GemmOpts<'_, LhsT, RhsT, OutT> { fn default() -> Self { GemmOpts { alpha: 1., beta: OutT::zero(), bias: None, + a_quant: None, + b_quant: None, } } } @@ -1329,19 +1347,36 @@ mod tests { mut output: MatrixMut, a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, ) where LhsT: GemmInT, RhsT: GemmInT, OutT: RefGemmOutT, { - let GemmOpts { alpha, beta, bias } = opts.unwrap_or_default(); + let GemmOpts { + alpha, + beta, + bias, + a_quant, + b_quant, + } = opts.unwrap_or_default(); for r in 0..a.rows() { + let a_zero = a_quant + .as_ref() + .map(|aq| OutT::from(aq.zero_point[r])) + .unwrap_or(OutT::zero()); for c in 0..b.cols() { + let b_zero = b_quant + .as_ref() + .map(|bq| OutT::from(bq.zero_point[c])) + .unwrap_or(OutT::zero()); + let mut accum = OutT::zero(); for k in 0..a.cols() { - accum = accum + OutT::from(a[[r, k]]) * OutT::from(b[[k, c]]); + let a_el = OutT::from(a[[r, k]]) - a_zero; + let b_el = OutT::from(b[[k, c]]) - b_zero; + accum = accum + a_el * b_el; } let bias = match bias { Some(BiasVector::Row(b)) => b[c], @@ -1356,7 +1391,7 @@ mod tests { fn reference_matmul( a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, ) -> NdTensor where LhsT: GemmInT, @@ -1391,7 +1426,7 @@ mod tests { mut output: MatrixMut, a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, gemm: Option<&GemmExecutor>, ) where GemmExecutor: Default, @@ -1399,7 +1434,13 @@ mod tests { let out_row_stride = output.stride(0); let default_gemm = GemmExecutor::default(); let gemm = gemm.unwrap_or(&default_gemm); - let GemmOpts { alpha, beta, bias } = opts.unwrap_or_default(); + let GemmOpts { + alpha, + beta, + bias, + a_quant, + b_quant, + } = opts.unwrap_or_default(); gemm.gemm( output.data_mut().expect("expected contiguous input"), @@ -1409,8 +1450,8 @@ mod tests { alpha, beta, bias, - None, // a_quant - None, // b_quant + a_quant, + b_quant, ) .unwrap(); } @@ -1418,7 +1459,7 @@ mod tests { fn run_matmul( a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, gemm: Option<&GemmExecutor>, ) -> NdTensor where @@ -1434,7 +1475,7 @@ mod tests { fn run_compare_matmul>( a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, gemm: Option<&GemmExecutor>, ) where GemmExecutor: Default, @@ -1446,7 +1487,7 @@ mod tests { // Simplest possible test case for easy debugging. 
#[test] - fn test_simple_gemm() -> Result<(), Box> { + fn test_simple_gemm_f32() -> Result<(), Box> { let a = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]); let b = NdTensor::from_data([2, 2], vec![5., 6., 7., 8.]); run_compare_matmul(a.view(), b.view(), None, None); @@ -1459,6 +1500,14 @@ mod tests { Ok(()) } + #[test] + fn test_simple_gemm_u8i8_i32() -> Result<(), Box> { + let a = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let b = NdTensor::from_data([2, 2], vec![5, 6, 7, 8]); + run_compare_matmul::(a.view(), b.view(), None, None); + Ok(()) + } + #[test] fn test_gemm_input_errors() { struct Case { @@ -1613,6 +1662,47 @@ mod tests { test_gemm_various_input_sizes(Some(&gemm)) } + #[test] + fn test_gemm_u8i8_i32() -> Result<(), Box> { + let gemm = GemmExecutor::::default(); + test_gemm_various_input_sizes(Some(&gemm)) + } + + #[test] + fn test_gemm_u8i8_i32_zero_point() { + let mut rng = XorShiftRng::new(1234); + + struct Case { + m: usize, + n: usize, + k: usize, + } + + let cases = [ + // Matrix-matrix + Case { m: 5, n: 7, k: 10 }, + // Vector-matrix + Case { m: 1, n: 5, k: 10 }, + ]; + + for Case { m, n, k } in cases { + let a = NdTensor::::rand([m, k], &mut rng); + let b = NdTensor::::rand([k, n], &mut rng); + let a_zero_point: Vec<_> = (0..a.rows() as u8).collect(); + let b_zero_point: Vec<_> = (0..b.cols() as i8).collect(); + let opts = Some(GemmOpts { + a_quant: Some(QuantParams { + zero_point: &a_zero_point, + }), + b_quant: Some(QuantParams { + zero_point: &b_zero_point, + }), + ..Default::default() + }); + run_compare_matmul(a.view(), b.view(), opts, None); + } + } + #[test] fn test_gemm_transposed() -> Result<(), Box> { let mut rng = XorShiftRng::new(1234); From e05f2984a0d7b2858892352d7e6cf228771c550f Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Fri, 10 Jan 2025 09:18:52 +0000 Subject: [PATCH 7/7] Validate zero point slice length matches corresponding GEMM input --- src/gemm.rs | 82 +++++++++++++++++++++++++++++++++++++++------- src/gemm/errors.rs | 5 +++ 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/src/gemm.rs b/src/gemm.rs index f2be5395..a91a3ae5 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -840,6 +840,24 @@ fn gemm_impl( return Err(GemmError::WrongBiasSize); } + match a_quant { + Some(quant) => { + if quant.zero_point.len() != a.rows() { + return Err(GemmError::WrongQuantParamSize); + } + } + None => {} + } + + match b_quant { + Some(quant) => { + if quant.zero_point.len() != b.cols() { + return Err(GemmError::WrongQuantParamSize); + } + } + None => {} + } + // Handle case where output is empty. 
if a.rows() == 0 || b.cols() == 0 { return Ok(()); @@ -1428,7 +1446,8 @@ mod tests { b: Matrix, opts: Option>, gemm: Option<&GemmExecutor>, - ) where + ) -> super::GemmResult + where GemmExecutor: Default, { let out_row_stride = output.stride(0); @@ -1453,7 +1472,6 @@ mod tests { a_quant, b_quant, ) - .unwrap(); } fn run_matmul( @@ -1461,13 +1479,13 @@ mod tests { b: Matrix, opts: Option>, gemm: Option<&GemmExecutor>, - ) -> NdTensor + ) -> Result, GemmError> where GemmExecutor: Default, { let mut output = NdTensor::zeros([a.rows(), b.cols()]); - run_gemm(output.view_mut(), a, b, opts, gemm); - output + run_gemm(output.view_mut(), a, b, opts, gemm)?; + Ok(output) } /// Run a matmul with the reference and real implementations and verify @@ -1480,7 +1498,7 @@ mod tests { ) where GemmExecutor: Default, { - let result = run_matmul(a.view(), b.view(), opts.clone(), gemm); + let result = run_matmul(a.view(), b.view(), opts.clone(), gemm).unwrap(); let expected = reference_matmul(a.view(), b.view(), opts); expect_equal(&result, &expected).unwrap(); } @@ -1611,7 +1629,7 @@ mod tests { let a = NdTensor::::rand(lhs_size, &mut rng); let b = NdTensor::::rand(rhs_size, &mut rng); - let result = run_matmul(a.view(), b.view(), None, gemm); + let result = run_matmul(a.view(), b.view(), None, gemm).unwrap(); let expected = reference_matmul(a.view(), b.view(), None); if let Err(err) = expect_equal(&result, &expected) { @@ -1703,6 +1721,48 @@ mod tests { } } + #[test] + fn test_gemm_u8i8_i32_invalid_zero_point() { + let mut rng = XorShiftRng::new(1234); + let a = NdTensor::::rand([5, 10], &mut rng); + let b = NdTensor::::rand([10, 3], &mut rng); + + fn gemm_opts<'a>( + a_zero_point: &'a [u8], + b_zero_point: &'a [i8], + ) -> GemmOpts<'a, u8, i8, i32> { + GemmOpts { + a_quant: Some(QuantParams { + zero_point: a_zero_point, + }), + b_quant: Some(QuantParams { + zero_point: b_zero_point, + }), + ..Default::default() + } + } + let a_zero_point: Vec<_> = (0..a.rows()).map(|row| row as u8).collect(); + let b_zero_point: Vec<_> = (0..b.cols()).map(|col| col as i8).collect(); + + // LHS zero point does not match LHS rows. + let result = run_matmul( + a.view(), + b.view(), + Some(gemm_opts(&[1, 2, 3], &b_zero_point)), + None, + ); + assert_eq!(result, Err(GemmError::WrongQuantParamSize)); + + // RHS zero point does not match RHS columns. 
+ let result = run_matmul( + a.view(), + b.view(), + Some(gemm_opts(&a_zero_point, &[1, 2, 3, 4])), + None, + ); + assert_eq!(result, Err(GemmError::WrongQuantParamSize)); + } + #[test] fn test_gemm_transposed() -> Result<(), Box> { let mut rng = XorShiftRng::new(1234); @@ -1777,7 +1837,7 @@ mod tests { b.view(), opts.clone(), Some(&gemm), - ); + )?; reference_gemm(expected.view_mut(), a.view(), b.view(), opts); expect_equal(&result, &expected)?; @@ -1825,7 +1885,7 @@ mod tests { alpha, ..Default::default() }); - run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None); + run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None)?; let expected = reference_matmul(a.view(), b.view(), opts); expect_equal(&result, &expected)?; } @@ -2238,7 +2298,7 @@ mod tests { ..Default::default() }); - run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None); + run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None).unwrap(); let mut expected = NdTensor::zeros([1, b.size(1)]); reference_gemm(expected.view_mut(), a.view(), b.view(), opts); @@ -2303,7 +2363,7 @@ mod tests { let start = Instant::now(); for _i in 0..iters { - run_gemm(result.view_mut(), a.view(), b.view(), None, None); + run_gemm(result.view_mut(), a.view(), b.view(), None, None).unwrap(); } let duration = start.elapsed(); diff --git a/src/gemm/errors.rs b/src/gemm/errors.rs index 0a919ec9..0ba895cd 100644 --- a/src/gemm/errors.rs +++ b/src/gemm/errors.rs @@ -9,6 +9,8 @@ pub enum GemmError { KSizeMismatch, /// Bias vector length does not match the corresponding output matrix size. WrongBiasSize, + /// Quantization parameter size does not match corresponding input size. + WrongQuantParamSize, /// The buffer provided for the output is too short. OutputNotLargeEnough, /// The data was packed with a kernel that uses a different layout than @@ -26,6 +28,9 @@ impl Display for GemmError { write!(fmt, "columns of matrix `a` must match rows of matrix `b`") } Self::WrongBiasSize => write!(fmt, "bias vector length is incorrect"), + Self::WrongQuantParamSize => { + write!(fmt, "quantization parameter size does not match input") + } Self::OutputNotLargeEnough => write!(fmt, "output buffer is too small"), Self::PackedDataKernelMismatch => { write!(fmt, "matrix was packed with a different kernel")
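
Taken together, the series exposes the quantized path through the same
`GemmExecutor` API as the existing f32 path. Below is a minimal usage sketch,
following the call shape used in the new tests; the crate-internal import
paths are assumed, and per the validation added in the final patch the
zero-point slices must match `a.rows()` and `b.cols()` respectively:

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensor;

    use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB, QuantParams};

    fn quantized_matmul(
        a: NdTensor<u8, 2>,
        b: NdTensor<i8, 2>,
        a_zero: &[u8],
        b_zero: &[i8],
    ) -> NdTensor<i32, 2> {
        let gemm = GemmExecutor::<u8, i8, i32>::default();
        let mut out = NdTensor::zeros([a.rows(), b.cols()]);
        let out_row_stride = out.stride(0);
        gemm.gemm(
            out.data_mut().expect("expected contiguous output"),
            out_row_stride,
            GemmInputA::Unpacked(a.view()),
            GemmInputB::Unpacked(b.view()),
            1.,   // alpha (the generic int kernel requires alpha == 1.0)
            0,    // beta (the generic int kernel requires beta == 0 or 1)
            None, // bias
            Some(QuantParams { zero_point: a_zero }),
            Some(QuantParams { zero_point: b_zero }),
        )
        .unwrap();
        out
    }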