Skip to content

Commit

Permalink
Merge pull request #528 from robertknight/gemm-kernel-quant
Browse files Browse the repository at this point in the history
Implement minimal u8 x i8 -> i32 quantized GEMM support
  • Loading branch information
robertknight authored Jan 10, 2025
2 parents 1f3213c + e05f298 commit b6889c5
Show file tree
Hide file tree
Showing 11 changed files with 739 additions and 183 deletions.
315 changes: 263 additions & 52 deletions src/gemm.rs

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions src/gemm/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub enum GemmError {
KSizeMismatch,
/// Bias vector length does not match the corresponding output matrix size.
WrongBiasSize,
/// Quantization parameter size does not match corresponding input size.
WrongQuantParamSize,
/// The buffer provided for the output is too short.
OutputNotLargeEnough,
/// The data was packed with a kernel that uses a different layout than
Expand All @@ -26,6 +28,9 @@ impl Display for GemmError {
write!(fmt, "columns of matrix `a` must match rows of matrix `b`")
}
Self::WrongBiasSize => write!(fmt, "bias vector length is incorrect"),
Self::WrongQuantParamSize => {
write!(fmt, "quantization parameter size does not match input")
}
Self::OutputNotLargeEnough => write!(fmt, "output buffer is too small"),
Self::PackedDataKernelMismatch => {
write!(fmt, "matrix was packed with a different kernel")
Expand Down
38 changes: 36 additions & 2 deletions src/gemm/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,23 @@ impl PackedLayout {
}
}

/// Parameters required to perform matrix multiplication on quantized tensors.
#[derive(Debug)]
pub struct QuantParams<'a, T> {
/// Values that correspond to zero in each row (for LHS inputs) or column
/// (for RHS inputs).
pub zero_point: &'a [T],
}

// Make QuantParams Copy/Clone regardless of `T`.
impl<T> Clone for QuantParams<'_, T> {
fn clone(&self) -> Self {
*self
}
}

impl<T> Copy for QuantParams<'_, T> {}

/// Kernel that computes a small tile of a general matrix multiplication (GEMM)
/// or general matrix-vector multiplication (GEMV).
///
Expand Down Expand Up @@ -132,7 +149,13 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
fn name(&self) -> &'static str;

/// Return the layout of a packing buffer required to pack an A / LHS input.
fn packed_a_layout(&self, a: Matrix<LhsT>, rows: usize, cols: usize) -> PackedLayout;
fn packed_a_layout(
&self,
a: Matrix<LhsT>,
rows: usize,
cols: usize,
quant: Option<QuantParams<LhsT>>,
) -> PackedLayout;

/// Pack a block of the LHS / "A" input for use by this kernel.
fn pack_a_block(
Expand All @@ -141,6 +164,7 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
a: Matrix<LhsT>,
rows: Range<usize>,
cols: Range<usize>,
quant: Option<QuantParams<LhsT>>,
);

/// Return the layout of a packing buffer required to pack a block of a "B"
Expand All @@ -149,7 +173,12 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
/// Unlike `packed_a_layout` this doesn't take the matrix as an argument.
/// `packed_a_layout` may use this to indicate that the A input does not
/// need to be packed. For the B input it is assumed this is always packed.
fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout;
fn packed_b_layout(
&self,
rows: usize,
cols: usize,
quant: Option<QuantParams<RhsT>>,
) -> PackedLayout;

/// Pack a block of the RHS / "B" input for use by this kernel.
fn pack_b_block(
Expand All @@ -158,6 +187,7 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
b: Matrix<RhsT>,
rows: Range<usize>,
cols: Range<usize>,
quant: Option<QuantParams<RhsT>>,
);

/// Pack a block of an image as the B input for use by this kernel, using
Expand Down Expand Up @@ -200,6 +230,8 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
depth: usize,
alpha: f32,
beta: OutT,
a_quant: Option<QuantParams<LhsT>>,
b_quant: Option<QuantParams<RhsT>>,
);

/// Compute an output block of a vector-matrix product ("gemv").
Expand All @@ -226,6 +258,8 @@ pub unsafe trait Kernel<LhsT, RhsT, OutT>: Sync {
b: Matrix<RhsT>,
alpha: f32,
beta: OutT,
a_quant: Option<QuantParams<LhsT>>,
b_quant: Option<QuantParams<RhsT>>,
);
}

Expand Down
23 changes: 20 additions & 3 deletions src/gemm/kernels/aarch64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rten_simd::vec_count;
use rten_tensor::{Matrix, MatrixLayout};

use super::simd_generic::{simd_gemv, GemmDispatch};
use super::{Kernel, Lhs, PackedLayout, TempTile};
use super::{Kernel, Lhs, PackedLayout, QuantParams, TempTile};
use crate::gemm::packing::{pack_a_block, pack_b_block, packed_a_layout, packed_b_layout};
use crate::gemm::Im2Col;
use crate::number::{cast_pod_mut_slice, cast_pod_slice};
Expand Down Expand Up @@ -39,7 +39,13 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
Self::NR
}

fn packed_a_layout(&self, a: Matrix, rows: usize, cols: usize) -> PackedLayout {
fn packed_a_layout(
&self,
a: Matrix,
rows: usize,
cols: usize,
_quant: Option<QuantParams<f32>>,
) -> PackedLayout {
let mut info = packed_a_layout::<f32, { Self::MR }>(rows, cols);
info.must_pack = a.col_stride() != 1;
info
Expand All @@ -51,12 +57,18 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
a: Matrix,
rows: Range<usize>,
cols: Range<usize>,
_quant: Option<QuantParams<f32>>,
) {
let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer");
pack_a_block::<f32, { Self::MR }>(out, a, rows, cols);
}

fn packed_b_layout(&self, rows: usize, cols: usize) -> PackedLayout {
fn packed_b_layout(
&self,
rows: usize,
cols: usize,
_quant: Option<QuantParams<f32>>,
) -> PackedLayout {
packed_b_layout::<f32, { Self::NR }>(rows, cols)
}

Expand All @@ -66,6 +78,7 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
b: Matrix,
rows: Range<usize>,
cols: Range<usize>,
_quant: Option<QuantParams<f32>>,
) {
let out = cast_pod_mut_slice(out).expect("incorrect alignment for packing buffer");
pack_b_block::<f32, { Self::NR }>(out, b, rows, cols);
Expand Down Expand Up @@ -98,6 +111,8 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
depth: usize,
alpha: f32,
beta: f32,
_a_quant: Option<QuantParams<f32>>,
_b_quant: Option<QuantParams<f32>>,
) {
const MR: usize = ArmNeonKernel::MR;
const NR: usize = ArmNeonKernel::NR;
Expand Down Expand Up @@ -152,6 +167,8 @@ unsafe impl Kernel<f32, f32, f32> for ArmNeonKernel {
b: Matrix,
alpha: f32,
beta: f32,
_a_quant: Option<QuantParams<f32>>,
_b_quant: Option<QuantParams<f32>>,
) {
// Safety - float32x4_t is supported if this kernel was constructed.
unsafe {
Expand Down
Loading

0 comments on commit b6889c5

Please sign in to comment.