diff --git a/src/gemm.rs b/src/gemm.rs index 9961eb98..a91a3ae5 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -27,6 +27,7 @@ pub use errors::GemmError; pub use im2col::{ColOffsets, Im2Col, RowOffsets}; use kernels::generic::GenericKernel; use kernels::Kernel; +pub use kernels::QuantParams; use packing::{PackElem, PackingBuffer}; pub type GemmResult = Result<(), GemmError>; @@ -174,6 +175,8 @@ impl GemmInputA<'_, T> { /// Trait implemented by GEMM input types. pub trait GemmInT: Copy + Default + Send + Sync + Identities + Pod {} +impl GemmInT for i8 {} +impl GemmInT for u8 {} impl GemmInT for f32 {} /// Trait implemented by GEMM output types. @@ -188,6 +191,7 @@ pub trait GemmOutT: + Pod { } +impl GemmOutT for i32 {} impl GemmOutT for f32 {} /// Right-hand or "B" input for a GEMM operation. @@ -260,6 +264,8 @@ where alpha, beta, None, // bias + None, // a_quant + None, // b_quant ) } @@ -327,11 +333,11 @@ impl GemmExecutor(&self, alloc: A, a: Matrix) -> PackedAMatrix { let depth_block = depth_block_size(a.cols()); - let layout = self.kernel.packed_a_layout(a, a.rows(), depth_block); + let layout = self.kernel.packed_a_layout(a, a.rows(), depth_block, None); let tail_layout = if a.cols() % depth_block != 0 { Some( self.kernel - .packed_a_layout(a, a.rows(), a.cols() % depth_block), + .packed_a_layout(a, a.rows(), a.cols() % depth_block, None), ) } else { None @@ -353,7 +359,7 @@ impl GemmExecutor GemmExecutor(&self, alloc: A, b: Matrix) -> PackedBMatrix { let depth_block = depth_block_size(b.rows()); - let layout = self.kernel.packed_b_layout(depth_block, b.cols()); + let layout = self.kernel.packed_b_layout(depth_block, b.cols(), None); let tail_layout = if b.rows() % depth_block != 0 { Some( self.kernel - .packed_b_layout(b.rows() % depth_block, b.cols()), + .packed_b_layout(b.rows() % depth_block, b.cols(), None), ) } else { None @@ -422,7 +428,7 @@ impl GemmExecutor GemmExecutor>, + a_quant: Option>, + b_quant: Option>, ) -> GemmResult { gemm_impl( &*self.kernel, @@ -476,6 +484,8 @@ impl GemmExecutor GemmExecutor, alpha: f32, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) -> GemmResult { gemm_impl( &*self.kernel, @@ -501,8 +513,17 @@ impl GemmExecutor + 'static>(kernel_type: KernelType) -> Option { + K::new().map(|kernel| GemmExecutor { + kernel: Box::new(kernel), + kernel_type, + }) + } } impl GemmExecutor { @@ -533,37 +554,24 @@ impl GemmExecutor { /// kernel is not supported. #[allow(dead_code)] // Currently only used in tests pub fn with_kernel(hint: KernelType) -> Option { - fn make_kernel + 'static>( - kernel_type: KernelType, - ) -> Option { - K::new().map(|kernel| GemmExecutor { - kernel: Box::new(kernel), - kernel_type, - }) - } - match hint { #[cfg(feature = "avx512")] #[cfg(target_arch = "x86_64")] - KernelType::Avx512 => make_kernel::(hint), + KernelType::Avx512 => Self::from_kernel::(hint), #[cfg(target_arch = "x86_64")] - KernelType::Fma => make_kernel::(hint), + KernelType::Fma => Self::from_kernel::(hint), #[cfg(target_arch = "aarch64")] - KernelType::ArmNeon => make_kernel::(hint), + KernelType::ArmNeon => Self::from_kernel::(hint), #[cfg(target_arch = "wasm32")] #[cfg(target_feature = "simd128")] - KernelType::Wasm => make_kernel::(hint), + KernelType::Wasm => Self::from_kernel::(hint), KernelType::Generic => Some(Self::with_generic_kernel()), } } /// Construct a GemmExecutor that uses the generic kernel. 
fn with_generic_kernel() -> Self { - let kernel = GenericKernel::new().unwrap(); - GemmExecutor { - kernel: Box::new(kernel), - kernel_type: KernelType::Generic, - } + Self::from_kernel::(KernelType::Generic).unwrap() } } @@ -573,6 +581,12 @@ impl Default for GemmExecutor { } } +impl Default for GemmExecutor { + fn default() -> Self { + Self::from_kernel::(KernelType::Generic).unwrap() + } +} + /// Return the block size for the K / depth dimension of a GEMM operation. /// /// This is chosen such that a `depth_block_size * nr` panel of B fits in the L1 @@ -704,6 +718,8 @@ fn gemv( alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) { assert!(output_mat.is_contiguous()); @@ -736,7 +752,15 @@ fn gemv( range_chunks(0..a_cols, k_block_size).zip(a_data.chunks(k_block_size)) { let b_block = b.slice((k_block, col_block.clone())); - kernel.gemv_kernel(out_chunk, a_block, b_block, alpha, effective_beta); + kernel.gemv_kernel( + out_chunk, + a_block, + b_block, + alpha, + effective_beta, + a_quant, + b_quant, + ); // Reset `beta` so that subsequent updates for each column // accumulate into the first update. @@ -800,6 +824,8 @@ fn gemm_impl( alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) -> GemmResult { if a.cols() != b.rows() { return Err(GemmError::KSizeMismatch); @@ -814,6 +840,24 @@ fn gemm_impl( return Err(GemmError::WrongBiasSize); } + match a_quant { + Some(quant) => { + if quant.zero_point.len() != a.rows() { + return Err(GemmError::WrongQuantParamSize); + } + } + None => {} + } + + match b_quant { + Some(quant) => { + if quant.zero_point.len() != b.cols() { + return Err(GemmError::WrongQuantParamSize); + } + } + None => {} + } + // Handle case where output is empty. if a.rows() == 0 || b.cols() == 0 { return Ok(()); @@ -873,6 +917,8 @@ fn gemm_impl( // nb. We checked above that, if present, the bias length matches // `a.rows()` or `b.cols()` as appropriate. bias, + a_quant, + b_quant, ); return Ok(()); } @@ -939,7 +985,8 @@ fn gemm_impl( GemmInputB::Unpacked(_) | GemmInputB::Im2Col(_) => PACKED_B.with(|cell| { let mut packed_b = cell.take(); - let layout = kernel.packed_b_layout(depth_range.len(), col_end - col_start); + let layout = + kernel.packed_b_layout(depth_range.len(), col_end - col_start, b_quant); let packed_uninit = packed_b.alloc(layout.size(), layout.align()); match b { @@ -948,6 +995,7 @@ fn gemm_impl( b, depth_range.clone(), col_start..col_end, + b_quant, ), GemmInputB::Im2Col(im) => kernel.pack_im2col( packed_uninit, @@ -998,6 +1046,7 @@ fn gemm_impl( a, row_end - row_start, depth_range.len(), + a_quant, ); if !layout.must_pack { return LhsBlock::Unpacked(a); @@ -1011,6 +1060,7 @@ fn gemm_impl( a, row_start..row_end, depth_range.clone(), + a_quant, ); // Safety: We initialized `layout.size` bytes. @@ -1037,6 +1087,8 @@ fn gemm_impl( alpha, effective_beta, bias, + a_quant, + b_quant, ); if let Some(packed_a) = thread_local_packed_a { @@ -1097,6 +1149,8 @@ fn gemm_block( alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, ) { let (mr, nr) = (kernel.mr(), kernel.nr()); @@ -1114,6 +1168,12 @@ fn gemm_block( .for_each(|(block_col_tile, col_tile)| { let b_panel_offset = block_col_tile * b.panel_stride; let b_panel = &b.data[b_panel_offset..b_panel_offset + b.panel_stride]; + let b_quant_tile = b_quant.map(|bq| { + let col_range = col_tile * nr..(col_tile * nr + nr).min(bq.zero_point.len()); + QuantParams { + zero_point: &bq.zero_point[col_range], + } + }); // Loop over row tiles. 
for (block_row_tile, row_tile) in row_tiles.clone().enumerate() { @@ -1122,6 +1182,13 @@ fn gemm_block( // every output tile is processed by one thread at a time. let out_tile = unsafe { output.tile(row_tile, col_tile) }; + let a_quant_tile = a_quant.map(|aq| { + let row_range = row_tile * mr..(row_tile * mr + mr).min(aq.zero_point.len()); + QuantParams { + zero_point: &aq.zero_point[row_range], + } + }); + let kernel_lhs = match a { LhsBlock::Packed { data, @@ -1160,6 +1227,8 @@ fn gemm_block( depth_range.len(), alpha, beta, + a_quant_tile, + b_quant_tile, ); } @@ -1217,7 +1286,7 @@ mod tests { use super::{ BiasVector, ColOffsets, GemmError, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, - Im2Col, KernelType, RowOffsets, + Im2Col, KernelType, QuantParams, RowOffsets, }; /// Scale a possibly non-float value by a float. @@ -1242,7 +1311,14 @@ mod tests { /// Type that can be used as the output for the reference GEMM /// implementation. trait RefGemmOutT: - Default + GemmOutT + From + From + MulFloat + ApproxEq + std::fmt::Debug + Default + + GemmOutT + + From + + From + + MulFloat + + ApproxEq + + std::fmt::Debug + + std::ops::Sub { } @@ -1253,19 +1329,30 @@ mod tests { { } + impl RefGemmOutT for i32 + where + i32: From, + i32: From, + { + } + #[derive(Clone)] - struct GemmOpts<'a, OutT> { + struct GemmOpts<'a, LhsT, RhsT, OutT> { alpha: f32, beta: OutT, bias: Option>, + a_quant: Option>, + b_quant: Option>, } - impl Default for GemmOpts<'_, OutT> { + impl Default for GemmOpts<'_, LhsT, RhsT, OutT> { fn default() -> Self { GemmOpts { alpha: 1., beta: OutT::zero(), bias: None, + a_quant: None, + b_quant: None, } } } @@ -1278,19 +1365,36 @@ mod tests { mut output: MatrixMut, a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, ) where LhsT: GemmInT, RhsT: GemmInT, OutT: RefGemmOutT, { - let GemmOpts { alpha, beta, bias } = opts.unwrap_or_default(); + let GemmOpts { + alpha, + beta, + bias, + a_quant, + b_quant, + } = opts.unwrap_or_default(); for r in 0..a.rows() { + let a_zero = a_quant + .as_ref() + .map(|aq| OutT::from(aq.zero_point[r])) + .unwrap_or(OutT::zero()); for c in 0..b.cols() { + let b_zero = b_quant + .as_ref() + .map(|bq| OutT::from(bq.zero_point[c])) + .unwrap_or(OutT::zero()); + let mut accum = OutT::zero(); for k in 0..a.cols() { - accum = accum + OutT::from(a[[r, k]]) * OutT::from(b[[k, c]]); + let a_el = OutT::from(a[[r, k]]) - a_zero; + let b_el = OutT::from(b[[k, c]]) - b_zero; + accum = accum + a_el * b_el; } let bias = match bias { Some(BiasVector::Row(b)) => b[c], @@ -1305,7 +1409,7 @@ mod tests { fn reference_matmul( a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, ) -> NdTensor where LhsT: GemmInT, @@ -1340,15 +1444,22 @@ mod tests { mut output: MatrixMut, a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, gemm: Option<&GemmExecutor>, - ) where + ) -> super::GemmResult + where GemmExecutor: Default, { let out_row_stride = output.stride(0); let default_gemm = GemmExecutor::default(); let gemm = gemm.unwrap_or(&default_gemm); - let GemmOpts { alpha, beta, bias } = opts.unwrap_or_default(); + let GemmOpts { + alpha, + beta, + bias, + a_quant, + b_quant, + } = opts.unwrap_or_default(); gemm.gemm( output.data_mut().expect("expected contiguous input"), @@ -1358,22 +1469,23 @@ mod tests { alpha, beta, bias, + a_quant, + b_quant, ) - .unwrap(); } fn run_matmul( a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, gemm: Option<&GemmExecutor>, - ) -> NdTensor + ) -> Result, GemmError> where GemmExecutor: Default, { let mut output = 
NdTensor::zeros([a.rows(), b.cols()]); - run_gemm(output.view_mut(), a, b, opts, gemm); - output + run_gemm(output.view_mut(), a, b, opts, gemm)?; + Ok(output) } /// Run a matmul with the reference and real implementations and verify @@ -1381,19 +1493,19 @@ mod tests { fn run_compare_matmul>( a: Matrix, b: Matrix, - opts: Option>, + opts: Option>, gemm: Option<&GemmExecutor>, ) where GemmExecutor: Default, { - let result = run_matmul(a.view(), b.view(), opts.clone(), gemm); + let result = run_matmul(a.view(), b.view(), opts.clone(), gemm).unwrap(); let expected = reference_matmul(a.view(), b.view(), opts); expect_equal(&result, &expected).unwrap(); } // Simplest possible test case for easy debugging. #[test] - fn test_simple_gemm() -> Result<(), Box> { + fn test_simple_gemm_f32() -> Result<(), Box> { let a = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]); let b = NdTensor::from_data([2, 2], vec![5., 6., 7., 8.]); run_compare_matmul(a.view(), b.view(), None, None); @@ -1406,6 +1518,14 @@ mod tests { Ok(()) } + #[test] + fn test_simple_gemm_u8i8_i32() -> Result<(), Box> { + let a = NdTensor::from_data([2, 2], vec![1, 2, 3, 4]); + let b = NdTensor::from_data([2, 2], vec![5, 6, 7, 8]); + run_compare_matmul::(a.view(), b.view(), None, None); + Ok(()) + } + #[test] fn test_gemm_input_errors() { struct Case { @@ -1452,6 +1572,8 @@ mod tests { 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ); assert_eq!(result, Err(expected)); } @@ -1507,7 +1629,7 @@ mod tests { let a = NdTensor::::rand(lhs_size, &mut rng); let b = NdTensor::::rand(rhs_size, &mut rng); - let result = run_matmul(a.view(), b.view(), None, gemm); + let result = run_matmul(a.view(), b.view(), None, gemm).unwrap(); let expected = reference_matmul(a.view(), b.view(), None); if let Err(err) = expect_equal(&result, &expected) { @@ -1558,10 +1680,93 @@ mod tests { test_gemm_various_input_sizes(Some(&gemm)) } + #[test] + fn test_gemm_u8i8_i32() -> Result<(), Box> { + let gemm = GemmExecutor::::default(); + test_gemm_various_input_sizes(Some(&gemm)) + } + + #[test] + fn test_gemm_u8i8_i32_zero_point() { + let mut rng = XorShiftRng::new(1234); + + struct Case { + m: usize, + n: usize, + k: usize, + } + + let cases = [ + // Matrix-matrix + Case { m: 5, n: 7, k: 10 }, + // Vector-matrix + Case { m: 1, n: 5, k: 10 }, + ]; + + for Case { m, n, k } in cases { + let a = NdTensor::::rand([m, k], &mut rng); + let b = NdTensor::::rand([k, n], &mut rng); + let a_zero_point: Vec<_> = (0..a.rows() as u8).collect(); + let b_zero_point: Vec<_> = (0..b.cols() as i8).collect(); + let opts = Some(GemmOpts { + a_quant: Some(QuantParams { + zero_point: &a_zero_point, + }), + b_quant: Some(QuantParams { + zero_point: &b_zero_point, + }), + ..Default::default() + }); + run_compare_matmul(a.view(), b.view(), opts, None); + } + } + + #[test] + fn test_gemm_u8i8_i32_invalid_zero_point() { + let mut rng = XorShiftRng::new(1234); + let a = NdTensor::::rand([5, 10], &mut rng); + let b = NdTensor::::rand([10, 3], &mut rng); + + fn gemm_opts<'a>( + a_zero_point: &'a [u8], + b_zero_point: &'a [i8], + ) -> GemmOpts<'a, u8, i8, i32> { + GemmOpts { + a_quant: Some(QuantParams { + zero_point: a_zero_point, + }), + b_quant: Some(QuantParams { + zero_point: b_zero_point, + }), + ..Default::default() + } + } + let a_zero_point: Vec<_> = (0..a.rows()).map(|row| row as u8).collect(); + let b_zero_point: Vec<_> = (0..b.cols()).map(|col| col as i8).collect(); + + // LHS zero point does not match LHS rows. 
+ let result = run_matmul( + a.view(), + b.view(), + Some(gemm_opts(&[1, 2, 3], &b_zero_point)), + None, + ); + assert_eq!(result, Err(GemmError::WrongQuantParamSize)); + + // RHS zero point does not match RHS columns. + let result = run_matmul( + a.view(), + b.view(), + Some(gemm_opts(&a_zero_point, &[1, 2, 3, 4])), + None, + ); + assert_eq!(result, Err(GemmError::WrongQuantParamSize)); + } + #[test] fn test_gemm_transposed() -> Result<(), Box> { let mut rng = XorShiftRng::new(1234); - let mut a = NdTensor::rand([20, 30], &mut rng); + let mut a = NdTensor::::rand([20, 30], &mut rng); let mut b = NdTensor::rand([10, 20], &mut rng); // Transpose the input matrices. This will alter their row and column @@ -1632,7 +1837,7 @@ mod tests { b.view(), opts.clone(), Some(&gemm), - ); + )?; reference_gemm(expected.view_mut(), a.view(), b.view(), opts); expect_equal(&result, &expected)?; @@ -1680,7 +1885,7 @@ mod tests { alpha, ..Default::default() }); - run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None); + run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None)?; let expected = reference_matmul(a.view(), b.view(), opts); expect_equal(&result, &expected)?; } @@ -1785,6 +1990,8 @@ mod tests { 1., // alpha 1., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -1802,6 +2009,8 @@ mod tests { 1., // alpha 1., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -1885,6 +2094,8 @@ mod tests { 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -2087,7 +2298,7 @@ mod tests { ..Default::default() }); - run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None); + run_gemm(result.view_mut(), a.view(), b.view(), opts.clone(), None).unwrap(); let mut expected = NdTensor::zeros([1, b.size(1)]); reference_gemm(expected.view_mut(), a.view(), b.view(), opts); @@ -2141,7 +2352,7 @@ mod tests { let mut rng = XorShiftRng::new(1234); let mut result = NdTensor::zeros([m, n]); - let a = NdTensor::rand([m, k], &mut rng); + let a = NdTensor::::rand([m, k], &mut rng); let b = if transpose_b { let mut b = NdTensor::rand([n, k], &mut rng); b.transpose(); @@ -2152,7 +2363,7 @@ mod tests { let start = Instant::now(); for _i in 0..iters { - run_gemm(result.view_mut(), a.view(), b.view(), None, None); + run_gemm(result.view_mut(), a.view(), b.view(), None, None).unwrap(); } let duration = start.elapsed(); diff --git a/src/gemm/errors.rs b/src/gemm/errors.rs index 0a919ec9..0ba895cd 100644 --- a/src/gemm/errors.rs +++ b/src/gemm/errors.rs @@ -9,6 +9,8 @@ pub enum GemmError { KSizeMismatch, /// Bias vector length does not match the corresponding output matrix size. WrongBiasSize, + /// Quantization parameter size does not match corresponding input size. + WrongQuantParamSize, /// The buffer provided for the output is too short. 
OutputNotLargeEnough, /// The data was packed with a kernel that uses a different layout than @@ -26,6 +28,9 @@ impl Display for GemmError { write!(fmt, "columns of matrix `a` must match rows of matrix `b`") } Self::WrongBiasSize => write!(fmt, "bias vector length is incorrect"), + Self::WrongQuantParamSize => { + write!(fmt, "quantization parameter size does not match input") + } Self::OutputNotLargeEnough => write!(fmt, "output buffer is too small"), Self::PackedDataKernelMismatch => { write!(fmt, "matrix was packed with a different kernel") diff --git a/src/gemm/kernels.rs b/src/gemm/kernels.rs index 79d4cb0c..43e81d1b 100644 --- a/src/gemm/kernels.rs +++ b/src/gemm/kernels.rs @@ -87,6 +87,23 @@ impl PackedLayout { } } +/// Parameters required to perform matrix multiplication on quantized tensors. +#[derive(Debug)] +pub struct QuantParams<'a, T> { + /// Values that correspond to zero in each row (for LHS inputs) or column + /// (for RHS inputs). + pub zero_point: &'a [T], +} + +// Make QuantParams Copy/Clone regardless of `T`. +impl Clone for QuantParams<'_, T> { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for QuantParams<'_, T> {} + /// Kernel that computes a small tile of a general matrix multiplication (GEMM) /// or general matrix-vector multiplication (GEMV). /// @@ -132,7 +149,13 @@ pub unsafe trait Kernel: Sync { fn name(&self) -> &'static str; /// Return the layout of a packing buffer required to pack an A / LHS input. - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout; + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + quant: Option>, + ) -> PackedLayout; /// Pack a block of the LHS / "A" input for use by this kernel. fn pack_a_block( @@ -141,6 +164,7 @@ pub unsafe trait Kernel: Sync { a: Matrix, rows: Range, cols: Range, + quant: Option>, ); /// Return the layout of a packing buffer required to pack a block of a "B" @@ -149,7 +173,12 @@ pub unsafe trait Kernel: Sync { /// Unlike `packed_a_layout` this doesn't take the matrix as an argument. /// `packed_a_layout` may use this to indicate that the A input does not /// need to be packed. For the B input it is assumed this is always packed. - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout; + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + quant: Option>, + ) -> PackedLayout; /// Pack a block of the RHS / "B" input for use by this kernel. fn pack_b_block( @@ -158,6 +187,7 @@ pub unsafe trait Kernel: Sync { b: Matrix, rows: Range, cols: Range, + quant: Option>, ); /// Pack a block of an image as the B input for use by this kernel, using @@ -200,6 +230,8 @@ pub unsafe trait Kernel: Sync { depth: usize, alpha: f32, beta: OutT, + a_quant: Option>, + b_quant: Option>, ); /// Compute an output block of a vector-matrix product ("gemv"). 
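// Illustrative sketch (not part of this patch): the computation that `QuantParams`
// describes for the integer kernels. The LHS carries one zero point per row and the
// RHS one zero point per column (lengths M and N respectively), and the kernels
// accumulate (a - a_zero) * (b - b_zero) into an i32 output, matching
// `reference_matmul` in the tests. All names below are hypothetical.
fn reference_quantized_gemm(
    a: &[u8],      // row-major M x K LHS
    b: &[i8],      // row-major K x N RHS
    a_zero: &[u8], // zero point for each LHS row (len M)
    b_zero: &[i8], // zero point for each RHS column (len N)
    m: usize,
    k: usize,
    n: usize,
) -> Vec<i32> {
    let mut out = vec![0i32; m * n];
    for r in 0..m {
        for c in 0..n {
            let mut acc = 0i32;
            for d in 0..k {
                let a_el = a[r * k + d] as i32 - a_zero[r] as i32;
                let b_el = b[d * n + c] as i32 - b_zero[c] as i32;
                acc += a_el * b_el;
            }
            out[r * n + c] = acc;
        }
    }
    out
}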
@@ -226,6 +258,8 @@ pub unsafe trait Kernel: Sync { b: Matrix, alpha: f32, beta: OutT, + a_quant: Option>, + b_quant: Option>, ); } diff --git a/src/gemm/kernels/aarch64.rs b/src/gemm/kernels/aarch64.rs index a9184dec..6facf613 100644 --- a/src/gemm/kernels/aarch64.rs +++ b/src/gemm/kernels/aarch64.rs @@ -6,7 +6,7 @@ use rten_simd::vec_count; use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -39,7 +39,13 @@ unsafe impl Kernel for ArmNeonKernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -51,12 +57,18 @@ unsafe impl Kernel for ArmNeonKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); pack_a_block::(out, a, rows, cols); } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -66,6 +78,7 @@ unsafe impl Kernel for ArmNeonKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); pack_b_block::(out, b, rows, cols); @@ -98,6 +111,8 @@ unsafe impl Kernel for ArmNeonKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = ArmNeonKernel::MR; const NR: usize = ArmNeonKernel::NR; @@ -152,6 +167,8 @@ unsafe impl Kernel for ArmNeonKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { // Safety - float32x4_t is supported if this kernel was constructed. 
unsafe { diff --git a/src/gemm/kernels/generic.rs b/src/gemm/kernels/generic.rs index d11fc477..a32c5dbf 100644 --- a/src/gemm/kernels/generic.rs +++ b/src/gemm/kernels/generic.rs @@ -5,7 +5,7 @@ use rten_simd::vec_count; use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -44,7 +44,13 @@ unsafe impl Kernel for GenericKernel { "base" } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -56,12 +62,18 @@ unsafe impl Kernel for GenericKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_a_block::(out, a, rows, cols); } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -71,6 +83,7 @@ unsafe impl Kernel for GenericKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_b_block::(out, b, rows, cols); @@ -103,6 +116,8 @@ unsafe impl Kernel for GenericKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = GenericKernel::MR; const NR: usize = GenericKernel::NR; @@ -156,6 +171,8 @@ unsafe impl Kernel for GenericKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { // Safety - f32 "SIMD" type is always supported unsafe { @@ -163,3 +180,207 @@ unsafe impl Kernel for GenericKernel { } } } + +unsafe impl Kernel for GenericKernel { + fn new() -> Option { + Some(GenericKernel { _private: () }) + } + + fn mr(&self) -> usize { + Self::MR + } + + fn nr(&self) -> usize { + Self::NR + } + + fn name(&self) -> &'static str { + "generic-i8" + } + + fn packed_a_layout( + &self, + _a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { + let mut info = packed_a_layout::(rows, cols); + info.must_pack = true; + info + } + + fn pack_a_block( + &self, + out: &mut [MaybeUninit], + a: Matrix, + rows: Range, + cols: Range, + _quant: Option>, + ) { + let out = cast_pod_mut_slice(out).unwrap(); + pack_a_block::(out, a, rows, cols); + } + + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { + packed_b_layout::(rows, cols) + } + + fn pack_b_block( + &self, + out: &mut [MaybeUninit], + b: Matrix, + rows: Range, + cols: Range, + _quant: Option>, + ) { + let out = cast_pod_mut_slice(out).unwrap(); + pack_b_block::(out, b, rows, cols); + } + + fn pack_im2col( + &self, + _out: &mut [MaybeUninit], + _image: &Im2Col, + _rows: Range, + _cols: Range, + ) { + unimplemented!("im2col packing not implemented"); + } + + unsafe fn kernel( + &self, + tile_ptr: *mut i32, + tile_row_stride: usize, + a: Lhs, + b: &[u8], + used_rows: usize, + used_cols: usize, + depth: usize, + alpha: f32, + beta: i32, + a_quant: Option>, + b_quant: Option>, + ) { + assert_eq!(alpha, 1.); + assert!(beta == 0 || beta == 1, "unsupported beta value"); + 
assert!(used_rows <= MR); + assert!(used_cols <= NR); + + const MR: usize = GenericKernel::MR; + const NR: usize = GenericKernel::NR; + + let a_data = match a { + Lhs::Packed(packed) => packed, + Lhs::Unpacked { .. } => panic!("inputs must be packed"), + }; + let a_row_stride = depth; + + let mut a_zero_point = [0u8; MR]; + if let Some(a_quant) = a_quant { + #[allow(clippy::manual_memcpy)] + for row in 0..used_rows { + a_zero_point[row] = a_quant.zero_point[row]; + } + } + let mut b_zero_point = [0i8; NR]; + if let Some(b_quant) = b_quant { + #[allow(clippy::manual_memcpy)] + for col in 0..used_cols { + b_zero_point[col] = b_quant.zero_point[col]; + } + } + + let b: &[i8] = cast_pod_slice(b).unwrap(); + let use_tmp_tile = used_cols < NR || used_rows < MR; + + let mut tmp_tile = TempTile::::new(); + let (dest_ptr, dest_row_stride, dest_beta) = if !use_tmp_tile { + (tile_ptr, tile_row_stride, beta) + } else { + (tmp_tile.as_mut_ptr() as *mut i32, NR, 0) + }; + + let mut tmp = [[0i32; NR]; MR]; + for k in 0..depth { + for row in 0..MR { + let a_i32 = unsafe { *a_data.get_unchecked(row * a_row_stride + k) } as i32 + - a_zero_point[row] as i32; + for col in 0..NR { + let b_i32 = + unsafe { *b.get_unchecked(k * NR + col) } as i32 - b_zero_point[col] as i32; + tmp[row][col] += a_i32 * b_i32; + } + } + } + + if dest_beta == 0 { + for row in 0..used_rows { + for col in 0..used_cols { + dest_ptr + .add(row * dest_row_stride + col) + .write(tmp[row][col]); + } + } + } else { + // nb. We require that beta is 0 or 1, so here it is 1. + for row in 0..used_rows { + for col in 0..used_cols { + *dest_ptr.add(row * dest_row_stride + col) += tmp[row][col]; + } + } + } + + if use_tmp_tile { + tmp_tile.accumulate_into( + tile_ptr as *mut MaybeUninit, + used_rows, + used_cols, + tile_row_stride, + beta, + ); + } + } + + fn gemv_kernel( + &self, + out: &mut [MaybeUninit], + a: &[u8], + b: Matrix, + alpha: f32, + beta: i32, + a_quant: Option>, + b_quant: Option>, + ) { + assert!(beta == 0 || beta == 1); + assert_eq!(alpha, 1.); + assert_eq!(b.rows(), a.len()); + assert_eq!(out.len(), b.cols()); + + let a_zero = a_quant.map(|aq| aq.zero_point[0] as i32).unwrap_or(0); + let depth = a.len(); + + for (out, col) in out.iter_mut().zip(0..b.cols()) { + let b_zero = b_quant.map(|bq| bq.zero_point[col] as i32).unwrap_or(0); + let mut acc = 0; + for k in 0..depth { + let a_el = unsafe { *a.get_unchecked(k) } as i32 - a_zero; + let b_el = unsafe { *b.get_unchecked([k, col]) } as i32 - b_zero; + acc += a_el * b_el; + } + if beta == 0 { + out.write(acc); + } else { + // Safety: Output is initialized when beta is non-zero + unsafe { + out.write(out.assume_init() + acc); + } + } + } + } +} diff --git a/src/gemm/kernels/wasm.rs b/src/gemm/kernels/wasm.rs index cf9d2919..50613a63 100644 --- a/src/gemm/kernels/wasm.rs +++ b/src/gemm/kernels/wasm.rs @@ -6,7 +6,7 @@ use rten_simd::vec_count; use rten_tensor::{Matrix, MatrixLayout}; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -43,7 +43,13 @@ unsafe impl Kernel for WasmKernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut 
info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -55,12 +61,18 @@ unsafe impl Kernel for WasmKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_a_block::(out, a, rows, cols); } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -70,6 +82,7 @@ unsafe impl Kernel for WasmKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); pack_b_block::(out, b, rows, cols); @@ -102,6 +115,8 @@ unsafe impl Kernel for WasmKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = WasmKernel::MR; const NR: usize = WasmKernel::NR; @@ -155,6 +170,8 @@ unsafe impl Kernel for WasmKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { // Safety - WASM SIMD types are supported if this kernel was constructed. unsafe { diff --git a/src/gemm/kernels/x86_64.rs b/src/gemm/kernels/x86_64.rs index c5045aaf..d0b31617 100644 --- a/src/gemm/kernels/x86_64.rs +++ b/src/gemm/kernels/x86_64.rs @@ -12,7 +12,7 @@ use rten_tensor::{Matrix, MatrixLayout}; use rten_simd::isa_detection::is_avx512_supported; use super::simd_generic::{simd_gemv, GemmDispatch}; -use super::{Kernel, Lhs, PackedLayout, TempTile}; +use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile}; use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout}; use crate::gemm::Im2Col; use crate::number::{cast_pod_mut_slice, cast_pod_slice}; @@ -84,7 +84,13 @@ unsafe impl Kernel for FmaKernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -96,6 +102,7 @@ unsafe impl Kernel for FmaKernel { a: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); @@ -105,7 +112,12 @@ unsafe impl Kernel for FmaKernel { } } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -115,6 +127,7 @@ unsafe impl Kernel for FmaKernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).unwrap(); @@ -153,6 +166,8 @@ unsafe impl Kernel for FmaKernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = FmaKernel::MR; const NR: usize = FmaKernel::NR; @@ -206,6 +221,8 @@ unsafe impl Kernel for FmaKernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { #[target_feature(enable = "avx2")] #[target_feature(enable = "fma")] @@ -275,7 +292,13 @@ unsafe impl Kernel for Avx512Kernel { Self::NR } - fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout { + fn packed_a_layout( + &self, + a: Matrix, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { let mut info = packed_a_layout::(rows, cols); info.must_pack = a.col_stride() != 1; info @@ -287,6 +310,7 @@ unsafe impl Kernel for Avx512Kernel { a: Matrix, rows: Range, cols: 
Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); @@ -296,7 +320,12 @@ unsafe impl Kernel for Avx512Kernel { } } - fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout { + fn packed_b_layout( + &self, + rows: usize, + cols: usize, + _quant: Option>, + ) -> PackedLayout { packed_b_layout::(rows, cols) } @@ -306,6 +335,7 @@ unsafe impl Kernel for Avx512Kernel { b: Matrix, rows: Range, cols: Range, + _quant: Option>, ) { let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer"); @@ -344,6 +374,8 @@ unsafe impl Kernel for Avx512Kernel { depth: usize, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { const MR: usize = Avx512Kernel::MR; const NR: usize = Avx512Kernel::NR; @@ -397,6 +429,8 @@ unsafe impl Kernel for Avx512Kernel { b: Matrix, alpha: f32, beta: f32, + _a_quant: Option>, + _b_quant: Option>, ) { #[target_feature(enable = "avx512f")] #[target_feature(enable = "avx512vl")] diff --git a/src/number.rs b/src/number.rs index 5cace137..5e9fe731 100644 --- a/src/number.rs +++ b/src/number.rs @@ -43,25 +43,39 @@ pub trait Identities { fn zero() -> Self; } -impl Identities for f32 { - fn one() -> f32 { - 1. - } +macro_rules! impl_float_identities { + ($type:ty) => { + impl Identities for $type { + fn one() -> Self { + 1. + } - fn zero() -> f32 { - 0. - } + fn zero() -> Self { + 0. + } + } + }; } -impl Identities for i32 { - fn one() -> i32 { - 1 - } - fn zero() -> i32 { - 0 - } +macro_rules! impl_int_identities { + ($type:ty) => { + impl Identities for $type { + fn one() -> Self { + 1 + } + + fn zero() -> Self { + 0 + } + } + }; } +impl_float_identities!(f32); +impl_int_identities!(i32); +impl_int_identities!(i8); +impl_int_identities!(u8); + /// Test if a number is a float NaN ("Not a number") value. pub trait IsNaN { /// Return true if the current value is a NaN. See [`f32::is_nan`]. diff --git a/src/ops/conv.rs b/src/ops/conv.rs index 00772cdd..2620c4ed 100644 --- a/src/ops/conv.rs +++ b/src/ops/conv.rs @@ -57,6 +57,8 @@ where GemmInputB::Unpacked(in_mat.view()), 1., // alpha bias_vec, + None, // a_quant + None, // b_quant ) .unwrap(); n_init += out_item.len(); @@ -287,6 +289,8 @@ where GemmInputB::Im2Col(&im2col), 1., // alpha bias_vec, + None, // a_quant + None, // b_quant ) .unwrap(); n_init.fetch_add(out_mat.len(), Ordering::SeqCst); @@ -319,7 +323,7 @@ impl Operator for Conv { let input = inputs.require_as(0)?; let weight = inputs.require_as(1)?; let bias = inputs.get_as(2)?; - conv( + conv::( pool, input, weight, @@ -564,6 +568,8 @@ pub fn conv_transpose( GemmInputB::Unpacked(input_mat.view()), 1., // alpha None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); @@ -1623,8 +1629,8 @@ mod tests { // This has a small spatial shape, so it measures overhead around the // inner loop. A larger spatial shape would be more affected by the // efficiency of the innermost loops. 
- let input = Tensor::rand(&[1, 576, 14, 14], &mut rng); - let kernel = Tensor::rand(&[576, 1, 3, 3], &mut rng); + let input = Tensor::::rand(&[1, 576, 14, 14], &mut rng); + let kernel = Tensor::::rand(&[576, 1, 3, 3], &mut rng); let n_groups = input.size(1); let padding = Padding::Fixed([1, 1, 1, 1].into()); diff --git a/src/ops/matmul.rs b/src/ops/matmul.rs index df204657..a94e1f77 100644 --- a/src/ops/matmul.rs +++ b/src/ops/matmul.rs @@ -5,9 +5,8 @@ use rten_tensor::{Matrix, NdTensorView, Tensor, TensorView}; use smallvec::SmallVec; use crate::gemm::{ - BiasVector, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, PackedBMatrix, + BiasVector, GemmExecutor, GemmInT, GemmInputA, GemmInputB, GemmOutT, PackedBMatrix, QuantParams, }; -use crate::iter_util::range_chunks; use crate::ops::binary_elementwise::broadcast_shapes; use crate::ops::layout::expand_to; use crate::ops::{ @@ -60,6 +59,8 @@ where alpha, beta, None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); output @@ -74,6 +75,8 @@ where GemmInputB::Unpacked(b.nd_view()), alpha, None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); // Safety: `gemm_uninit` initialized all elements @@ -144,7 +147,17 @@ pub fn matmul( where GemmExecutor: Default, { - matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, None, None) + matmul_impl( + pool, + a, + b, + packed_b, + MatmulStrategy::Auto, + None, + None, + None, /* a_quant */ + None, /* b_quant */ + ) } fn matmul_impl( @@ -155,6 +168,8 @@ fn matmul_impl( strategy: MatmulStrategy, bias: Option>, alpha: Option, + a_quant: Option>, + b_quant: Option>, ) -> Result, OpError> where GemmExecutor: Default, @@ -215,6 +230,8 @@ where strategy, bias, alpha, + a_quant, + b_quant, )?; output.reshape(out_shape); return Ok(output); @@ -288,6 +305,8 @@ where b_input, alpha.unwrap_or(1.), bias, + a_quant, + b_quant, ) .unwrap(); }); @@ -347,7 +366,17 @@ pub fn matmul_fused( where GemmExecutor: Default, { - matmul_impl(pool, a, b, packed_b, MatmulStrategy::Auto, bias, alpha) + matmul_impl( + pool, + a, + b, + packed_b, + MatmulStrategy::Auto, + bias, + alpha, + None, /* a_quant */ + None, /* b_quant */ + ) } /// MatMul with fused addition of bias and scaling of result. @@ -400,42 +429,6 @@ pub fn matmul_integer( a_zero_point: Option>, b_zero_point: Option>, ) -> Result, OpError> { - if a.ndim() < 2 || b.ndim() < 2 { - return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions")); - } - - let a_rows = a.size(a.ndim() - 2); - let a_cols = a.size(a.ndim() - 1); - - let b_rows = b.size(b.ndim() - 2); - let b_cols = b.size(b.ndim() - 1); - - if a_cols != b_rows { - return Err(OpError::IncompatibleInputShapes( - "Columns of first matrix does not match rows of second matrix", - )); - } - - let a_prefix = &a.shape()[..a.ndim() - 2]; - let b_prefix = &b.shape()[..b.ndim() - 2]; - - let out_prefix = broadcast_shapes(a_prefix, b_prefix) - .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast shapes"))?; - let out_shape = &[out_prefix.as_slice(), &[a_rows, b_cols]].concat(); - - let mut output = Tensor::::uninit_in(pool, out_shape); - if output.is_empty() { - // nb. We don't need to alloc from the pool here, since the buffer - // is already empty. 
- return Ok(Tensor::zeros(out_shape)); - } - - let a_broadcast_shape = [out_prefix.as_slice(), &[a_rows, a_cols]].concat(); - let b_broadcast_shape = [out_prefix.as_slice(), &[b_rows, b_cols]].concat(); - - let a_broadcast = a.broadcast(a_broadcast_shape.as_slice()); - let b_broadcast = b.broadcast(b_broadcast_shape.as_slice()); - // Convert the zero point to a vector. // // The spec allows for the zero point to be a scalar, vector or a batch of @@ -460,58 +453,33 @@ pub fn matmul_integer( } } - let a_zero = zero_point_to_vec(a_zero_point, a_rows)?; - let b_zero = zero_point_to_vec(b_zero_point, b_cols)?; + if a.ndim() < 2 || b.ndim() < 2 { + return Err(OpError::InvalidValue("Inputs must have >= 2 dimensions")); + } + let a_rows = a.size(a.ndim() - 2); + let b_cols = b.size(b.ndim() - 1); - a_broadcast - .inner_iter::<2>() - .zip(b_broadcast.inner_iter::<2>()) - .zip(output.inner_iter_mut::<2>()) - .par_bridge() - .for_each(|((a_mat, b_mat), mut out_mat)| { - let [m, k] = a_mat.shape(); - let [bk, n] = b_mat.shape(); - assert_eq!(k, bk); - assert_eq!(out_mat.shape(), [m, n]); - - // Do some extremely rudimentary cache blocking. - for col_block in range_chunks(0..n, 32) { - for depth_block in range_chunks(0..k, 32) { - for row_block in range_chunks(0..m, 32) { - for j in col_block.clone() { - let b_zero = b_zero.as_ref().map(|zp| zp[j]).unwrap_or(0) as i32; - - for i in row_block.clone() { - let a_zero = a_zero.as_ref().map(|zp| zp[i]).unwrap_or(0) as i32; - - let mut out = 0i32; - for k in depth_block.clone() { - // Safety: `[i, k]` is in-bounds for `a_mat`. - let a = unsafe { *a_mat.get_unchecked([i, k]) } as i32 - a_zero; - // Safety: `[k, j]` is in-bounds for `b_mat`. - let b = unsafe { *b_mat.get_unchecked([k, j]) } as i32 - b_zero; - out += a * b; - } - unsafe { - // Safety: `[i, j]` is in-bounds for `b_mat`. - let el = out_mat.get_unchecked_mut([i, j]); - if depth_block.start == 0 { - el.write(out); - } else { - el.write(el.assume_init() + out); - } - } - } - } - } - } - } - }); + let a_zero = zero_point_to_vec(a_zero_point, a_rows)?.map(|zp| zp.to_contiguous()); + let a_quant = a_zero.as_ref().map(|zp| QuantParams { + zero_point: zp.data().unwrap(), + }); - // Safety: Loop above initialized all output elements. 
- let output = unsafe { output.assume_init() }; + let b_zero = zero_point_to_vec(b_zero_point, b_cols)?.map(|zp| zp.to_contiguous()); + let b_quant = b_zero.as_ref().map(|zp| QuantParams { + zero_point: zp.data().unwrap(), + }); - Ok(output) + matmul_impl( + pool, + a, + b, + None, + MatmulStrategy::Auto, + None, + None, + a_quant, + b_quant, + ) } #[derive(Debug)] @@ -621,6 +589,8 @@ mod tests { alpha.unwrap_or(1.), 0., /* beta */ bias, + None, // a_quant + None, // b_quant ) .unwrap() }); @@ -1018,8 +988,8 @@ mod tests { } in cases { let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(a_shape, &mut rng); - let b = Tensor::rand(b_shape, &mut rng); + let a = Tensor::::rand(a_shape, &mut rng); + let b = Tensor::::rand(b_shape, &mut rng); let result = matmul(&pool, a.view(), b.view(), None); assert_eq!(result, Err(error)); @@ -1045,8 +1015,8 @@ mod tests { let pool = new_pool(); for Case { m, n, k } in cases { let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(&[m, k], &mut rng); - let b = Tensor::rand(&[k, n], &mut rng); + let a = Tensor::::rand(&[m, k], &mut rng); + let b = Tensor::::rand(&[k, n], &mut rng); let result = matmul(&pool, a.view(), b.view(), None).unwrap(); assert_eq!(result.shape(), &[m, n]); @@ -1091,6 +1061,15 @@ mod tests { b_zero_point: Some(Tensor::from([3, 4])), expected_err: None, }, + // A input which is a row vector + Case { + a: Tensor::from([[1, 2, 3, 4]]), + b: Tensor::from([[5, 6], [7, 8], [9, 10], [11, 12]]), + a_zero_point: Some(Tensor::from([1])), + b_zero_point: Some(Tensor::from([3, 4])), + expected_err: None, + }, + // Incorrect zero point size Case { a: Tensor::from([[1, 2], [3, 4]]), b: Tensor::from([[5, 6], [7, 8]]), @@ -1228,8 +1207,8 @@ mod tests { } in cases { let mut rng = XorShiftRng::new(1234); - let a = Tensor::rand(&[a_batch, a_rows, a_cols], &mut rng); - let b = Tensor::rand(&[a_cols, b_cols], &mut rng); + let a = Tensor::::rand(&[a_batch, a_rows, a_cols], &mut rng); + let b = Tensor::::rand(&[a_cols, b_cols], &mut rng); let run_trial = |strategy| { let trials = 10; @@ -1238,9 +1217,19 @@ mod tests { ); let pool = new_pool(); run_bench(trials, Some(&desc), || { - matmul_impl(&pool, a.view(), b.view(), None, strategy, None, None) - .unwrap() - .auto_return(&pool); + matmul_impl( + &pool, + a.view(), + b.view(), + None, + strategy, + None, + None, + None, + None, + ) + .unwrap() + .auto_return(&pool); }); }; diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index 3c9e7107..736950b9 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -241,6 +241,8 @@ pub fn gru( 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(input_bias) = input_bias { @@ -257,6 +259,8 @@ pub fn gru( 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(hidden_bias) = hidden_bias { @@ -498,6 +502,8 @@ pub fn lstm( 1., // alpha 0., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(input_bias) = input_bias { @@ -512,6 +518,8 @@ pub fn lstm( 1., // alpha 1., // beta None, // bias + None, // a_quant + None, // b_quant ) .unwrap(); if let Some(hidden_bias) = hidden_bias {
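// Illustrative usage sketch (not part of this patch): how a caller passes the new
// `a_quant` / `b_quant` arguments, mirroring `run_gemm` in the gemm tests above.
// The `use` paths are assumptions (these types live in the crate-internal `gemm`
// module); the tensor helpers come from `rten_tensor` as elsewhere in this diff.
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB, QuantParams};

fn quantized_gemm_example() {
    let a = NdTensor::<u8, 2>::zeros([2, 3]);
    let b = NdTensor::<i8, 2>::zeros([3, 4]);
    let a_zero_point = [0u8; 2]; // one zero point per LHS row
    let b_zero_point = [0i8; 4]; // one zero point per RHS column

    let mut output = NdTensor::<i32, 2>::zeros([2, 4]);
    let out_row_stride = output.stride(0);

    // Uses the u8 x i8 -> i32 executor added by this patch.
    let gemm = GemmExecutor::<u8, i8, i32>::default();
    gemm.gemm(
        output.data_mut().expect("expected contiguous output"),
        out_row_stride,
        GemmInputA::Unpacked(a.view()),
        GemmInputB::Unpacked(b.view()),
        1.,   // alpha
        0,    // beta
        None, // bias
        Some(QuantParams {
            zero_point: &a_zero_point,
        }),
        Some(QuantParams {
            zero_point: &b_zero_point,
        }),
    )
    .unwrap();
}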