make vtpq work on top of available kernels
kali committed Jan 29, 2025
1 parent 044f72b commit 40d2e63
Showing 1 changed file with 62 additions and 34 deletions.
96 changes: 62 additions & 34 deletions core/src/ops/vptq.rs
@@ -57,14 +57,21 @@ impl VPTQGemm {
pre_shift_pack_tensor_shape.push(1);

let mut out = shift_right_zero_and_1(
- pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(),
+ pack_tensor
+ .clone()
+ .into_shape(&pre_shift_pack_tensor_shape)?
+ .into(),
wf.into(),
)?;

let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone();
let pval = post_shift_pack_tensor_shape.pop().unwrap();
post_shift_pack_tensor_shape.push(32 * pval);
- out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue();
+ out = out
+ .into_tensor()
+ .clone()
+ .into_shape(&post_shift_pack_tensor_shape)?
+ .into_tvalue();

let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements);
if pad_size > 0 {
@@ -78,10 +85,15 @@ impl VPTQGemm {
let auto = out.shape().last().unwrap() / index_bits;
post_pad_pack_tensor_shape.push(auto);
post_pad_pack_tensor_shape.push(index_bits);
- out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into();
+ out = out
+ .into_tensor()
+ .into_shape(&post_pad_pack_tensor_shape)?
+ .into();

let wf1 = Tensor::from(
- Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(),
+ Array1::from_iter(0..(index_bits as i32))
+ .to_shape([1, 1, 1, index_bits])?
+ .into_owned(),
);

out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap();
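
The hunks above touch the bit-unpacking path: VPTQ stores codebook indices packed into 32-bit words, and the shift_right / shift_left sequence appears to split each word into bits and regroup them `index_bits` at a time. A minimal standalone sketch of that idea, with illustrative names only (not this crate's API) and assuming LSB-first packing:

// Illustrative sketch only (hypothetical helper, not this crate's API): unpack
// `count` indices of `index_bits` bits each from u32 words, assuming the bits
// are packed LSB-first and back-to-back.
fn unpack_indices(packed: &[u32], index_bits: usize, count: usize) -> Vec<u32> {
    // 1. explode every 32-bit word into its individual bits
    let bits: Vec<u32> = packed
        .iter()
        .flat_map(|&w| (0..32u32).map(move |b| (w >> b) & 1))
        .collect();
    // 2. regroup consecutive runs of `index_bits` bits into one index each,
    //    dropping any trailing padding bits
    bits.chunks(index_bits)
        .take(count)
        .map(|chunk| {
            chunk
                .iter()
                .enumerate()
                .fold(0u32, |acc, (i, &bit)| acc | (bit << i))
        })
        .collect()
}

For example, with index_bits = 12, unpack_indices(&words, 12, n) recovers n twelve-bit indices and discards any padding bits at the end.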
@@ -146,7 +158,7 @@ impl VPTQGemm {
.into_shape(&[num_codebooks, remain, group_size, vector_len])?
.permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory)
.into_shape(&[num_codebooks, remain * vector_len, group_size])?
- .permute_axes(&[1, 0, 2])?// NOTE: costly in tract (applied in memory)
+ .permute_axes(&[1, 0, 2])? // NOTE: costly in tract (applied in memory)
.into_shape(&[vector_len * remain, num_codebooks * group_size])?;

let dim0 = qweight.shape()[0];
@@ -210,14 +222,27 @@ impl EvalOp for VPTQGemm {
assert_eq!(outlier_centroids.rank(), 3);
assert!(outlier_centroids.datum_type().is_float());
}
- let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()];
+ let _fdtypes = [
+ input.datum_type(),
+ centroids.datum_type(),
+ outlier_centroids.datum_type(),
+ ];
let fdtypes = HashSet::from(_fdtypes);
if fdtypes.len() != 1 {
log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type());
log::warn!(
"force cast centroids to be same type as input: {:?}",
input.datum_type()
);
centroids = centroids.cast_to_dt(input.datum_type())?.into_owned();
- outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned();
+ outlier_centroids = outlier_centroids
+ .cast_to_dt(input.datum_type())?
+ .into_owned();
}
- let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()];
+ let _fdtypes = [
+ input.datum_type(),
+ centroids.datum_type(),
+ outlier_centroids.datum_type(),
+ ];
let fdtypes = HashSet::from(_fdtypes);
assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}");

@@ -245,16 +270,23 @@ impl EvalOp for VPTQGemm {
if enable_perm {
let axis = 0;
let dim = perm.shape()[0];
- let top_k = Topk { axis, largest: false, fallback_k: dim.into() };
- let invert_perm =
- top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(0);
+ let top_k = Topk {
+ axis,
+ largest: false,
+ fallback_k: dim.into(),
+ };
+ let invert_perm = top_k
+ .eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?
+ .remove(0);
// TODO: manage case with quant dim == 'in' ?
// if self.vector_quant_dim == "in":
// assert True, "Not implemented"
// qweight = qweight[invert_perm, :]

let perm_gather_axis = 1;
- let gather_perm = Gather { axis: perm_gather_axis };
+ let gather_perm = Gather {
+ axis: perm_gather_axis,
+ };
qweight = gather_perm
.eval(tvec!(qweight.into(), invert_perm))?
.pop()
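
For context on the permutation handling in this hunk: taking the full-length ascending Topk of perm is an argsort, and the argsort of a permutation is its inverse, which the Gather then applies along the columns of qweight. A tiny standalone equivalent, illustrative only and not the op used here:

// Illustrative sketch only (hypothetical helper, not the op used above):
// the argsort of a permutation is its inverse, i.e. inv[perm[i]] == i.
fn invert_permutation(perm: &[usize]) -> Vec<usize> {
    let mut inv = vec![0; perm.len()];
    for (i, &p) in perm.iter().enumerate() {
        inv[p] = i;
    }
    inv
}
// e.g. invert_permutation(&[2, 0, 1]) == vec![1, 2, 0]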
@@ -280,19 +312,9 @@ impl EvalOp for VPTQGemm {

let &n = qweight.shape().last().unwrap();

- let (&[m, k], out_shape) = match ishape.len() {
- 2 => {
- let &[m, k] = ishape else {
- bail!("unexpected rank: {:?}", ishape.len());
- };
- (&[m, k], vec![m, n])
- }
- 3 => {
- let &[b, m, k] = ishape else {
- bail!("unexpected rank: {:?}", ishape.len());
- };
- (&[m, k], vec![b, m, n])
- }
+ let (m, k, out_shape) = match ishape {
+ &[m, k] => (m, k, vec![m, n]),
+ &[b, m, k] => (m, k, vec![b, m, n]),
_ => {
bail!("unexpected rank {:?}", ishape.len())
}
@@ -301,15 +323,25 @@ impl EvalOp for VPTQGemm {
let input_offset = input.rank() - 2;
let weight_offset = qweight.rank() - 2;

+ /* this would be better for Intel where there is no f16 support, but the kernel selection
+ APIs are not up to the task (yet)
+ let acc_type = if tract_linalg::has_fp16() {
+ f16::datum_type()
+ } else {
+ f32::datum_type()
+ };
+ */
let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap();
let (pack_a, pack_b) = &mmm.packings()[0];

let cstore = unsafe { mmm.c_view(input_offset, 1 + input_offset) };

let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?;
let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?;
- let last = unsafe {
- let out = Tensor::uninitialized::<f32>(out_shape.iter().as_slice())?;
+ unsafe {
+ let out = Tensor::uninitialized_dt(data_type, &out_shape)?;
let non_linear = &[
FusedSpec::AddMatMul {
a: tract_linalg::mmm::AsInputValue::Owned(a),
@@ -319,12 +351,8 @@ impl EvalOp for VPTQGemm {
FusedSpec::Store(cstore.wrap(&out.view())),
];
mmm.run(m, n, non_linear)?;

- out
- };
- // force down cast for now
- let last_cdt = last.cast_to_dt(input.datum_type())?.into_owned().into_tvalue();
- Ok(tvec!(last_cdt))
+ Ok(tvec!(out.into_tvalue()))
}
}
}
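
The last two hunks carry the point of the commit: the matmul kernel is requested for data_type via op.mmm(data_type, Some(m), Some(k), Some(n)), and the output tensor is now allocated directly in that type (Tensor::uninitialized_dt), so the old allocate-as-f32-then-downcast path is gone. For readers unfamiliar with tract-linalg's fused specs, the AddMatMul + Store pair amounts to a plain matrix multiplication written straight into the pre-allocated output view; a naive, purely illustrative equivalent (not the kernel API):

// Purely illustrative equivalent (not tract-linalg's kernel API) of the
// AddMatMul + Store pair: C[m, n] = A[m, k] * B[k, n], with each result
// written straight into the pre-allocated output buffer.
fn matmul_store(a: &[f32], b: &[f32], out: &mut [f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            out[i * n + j] = acc; // the Store step
        }
    }
}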
