Fixes for bnb and more apis in mistralrs-quant (#972)
* Add a forward_autocast method

* Add a to_gguf_quant method for bnb

* Handle blocksizes

* Maybe cast

* Add QuantMethod::dequantize_w

* Debug

* Debug

* Debug

* Fix the bug maybe???

* Fix the bug maybe???

* Clippy
EricLBuehler authored Dec 12, 2024
1 parent 9d1f09f commit 458dc5f
Showing 12 changed files with 113 additions and 9 deletions.
2 changes: 0 additions & 2 deletions mistralrs-quant/kernels/bitsandbytes/dequant.cu
@@ -150,8 +150,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
{
vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
// vals[j*2] = 10000;
// vals[j*2 + 1] = 10000;
}
break;
}
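
This hunk drops two leftover debug assignments; the remaining loop body unpacks two NF4 codes from each byte (high nibble first, then low nibble) and scales each by the block's absolute maximum. A minimal CPU-side sketch of that layout in Rust, with the 16-entry NF4 codebook left abstract rather than reproduced here:

// Sketch only: `codebook` stands in for the dDequantizeNF4 lookup table and
// `absmax` is the per-block absolute maximum used as the scale.
fn dequantize_nf4_block(qvals: &[u8], absmax: f32, codebook: &[f32; 16]) -> Vec<f32> {
    let mut out = Vec::with_capacity(qvals.len() * 2);
    for &q in qvals {
        // Each byte packs two 4-bit codes: high nibble -> element 2*j, low nibble -> element 2*j + 1.
        out.push(codebook[(q >> 4) as usize] * absmax);
        out.push(codebook[(q & 0x0F) as usize] * absmax);
    }
    out
}
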
38 changes: 34 additions & 4 deletions mistralrs-quant/src/bitsandbytes/mod.rs
@@ -4,11 +4,14 @@ use std::{
sync::{atomic::AtomicUsize, Arc},
};

use candle_core::{Context, DType, Device, Result, Shape, Tensor};
use candle_core::{
quantized::{GgmlDType, QTensor},
Context, DType, Device, Result, Shape, Tensor, D,
};
use candle_nn::VarBuilder;
use serde::Deserialize;

use crate::{IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde};
use crate::{GgufMatMul, IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde};

#[cfg(feature = "cuda")]
mod ffi;
@@ -219,14 +222,18 @@ impl QuantMethod for BnbLinear {
}),
}
}

fn dequantize_w(&self) -> Result<Tensor> {
Self::dequantize(&self.weight, &self.params, self.quant_ty)
}

fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let w = Self::dequantize(&self.weight, &self.params, self.quant_ty)?
.t()?
.to_dtype(xs.dtype())?;
// dbg!(&w.mean_all());
let res = xs.broadcast_matmul(&w)?;
if let Some(bias) = &self.bias {
res + bias
res.broadcast_add(bias)
} else {
Ok(res)
}
@@ -261,6 +268,29 @@ impl QuantMethod for BnbLinear {
fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
None
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
let weight = Self::dequantize(&self.weight, &self.params, self.quant_ty)?;
let bias = self.bias.clone();

let last_dim = weight.dim(D::Minus1)?;
let dtype = match self.quant_ty {
BnbQuantType::Fp4 | BnbQuantType::Nf4 if last_dim % 256 == 0 => GgmlDType::Q4K,
BnbQuantType::Fp4 | BnbQuantType::Nf4 if last_dim % 64 == 0 && last_dim % 256 != 0 => {
GgmlDType::Q4_0
}
BnbQuantType::Fp4 | BnbQuantType::Nf4 if last_dim % 64 != 0 && last_dim % 256 != 0 => {
GgmlDType::F32
}
BnbQuantType::Int8 => GgmlDType::Q8_0,
_ => unreachable!(),
};
let qmatmul = QTensor::quantize(&weight, dtype)?;
Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: Arc::new(qmatmul),
b: bias,
})?))
}
}

impl QuantizedSerde for BnbLinear {
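
In the new maybe_to_gguf_quant above, the target GGML dtype is keyed off the weight's last dimension: for NF4/FP4 weights, a last dimension divisible by 256 maps to Q4K, one divisible by 64 but not by 256 maps to Q4_0, and anything else falls back to F32, while Int8 always maps to Q8_0. A minimal sketch of the candle calls the conversion rests on (the 16 x 4096 shape is chosen here purely for illustration):

use candle_core::quantized::{GgmlDType, QTensor};
use candle_core::{DType, Device, Result, Tensor};

fn to_q4k_example() -> Result<QTensor> {
    // A dequantized 16 x 4096 weight; 4096 % 256 == 0, so Q4K is a valid target.
    let weight = Tensor::zeros((16, 4096), DType::F32, &Device::Cpu)?;
    // Quantize to the chosen GGML dtype; GgufMatMul then wraps the resulting QTensor.
    QTensor::quantize(&weight, GgmlDType::Q4K)
}
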
11 changes: 11 additions & 0 deletions mistralrs-quant/src/dummy/mod.rs
@@ -1,3 +1,7 @@
use std::sync::Arc;

use candle_core::Result;

use crate::{QuantMethod, QuantizedSerde};

#[derive(Debug)]
@@ -10,6 +14,9 @@ impl QuantMethod for DummyLayer {
{
Ok(Self)
}
fn dequantize_w(&self) -> Result<candle_core::Tensor> {
candle_core::bail!("DummyLayer cannot be dequantized!")
}
fn add_delta_w(
&self,
_delta: &candle_core::Tensor,
@@ -46,6 +53,10 @@ impl QuantMethod for DummyLayer {
fn quantized_act_type(&self) -> Option<candle_core::DType> {
None
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

impl QuantizedSerde for DummyLayer {
7 changes: 7 additions & 0 deletions mistralrs-quant/src/fp8/mod.rs
@@ -59,6 +59,9 @@ impl QuantMethod for FP8Linear {
}
}
}
fn dequantize_w(&self) -> Result<candle_core::Tensor> {
Ok(self.dequantize(DType::F32)?.weight().clone())
}

fn forward(&self, x: &Tensor) -> Result<Tensor> {
// Batch matrix multiplication
@@ -176,6 +179,10 @@
| IsqType::HQQ8 => None,
}
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

// Serialization structure:
8 changes: 8 additions & 0 deletions mistralrs-quant/src/gguf/mod.rs
@@ -43,6 +43,10 @@ impl QuantMethod for GgufMatMul {
}
}

fn dequantize_w(&self) -> Result<Tensor> {
self.w.dequantize_f16()?.to_dtype(DType::F32)
}

fn forward(&self, a: &Tensor) -> Result<Tensor> {
let x = self.w.forward(a)?;
if let Some(ref b) = self.b {
@@ -148,6 +152,10 @@ impl QuantMethod for GgufMatMul {
fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
None
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

// Serialization structure:
8 changes: 8 additions & 0 deletions mistralrs-quant/src/gptq/gptq_cpu.rs
@@ -27,6 +27,10 @@ impl QuantMethod for GptqLayer {
}
}

fn dequantize_w(&self) -> Result<Tensor> {
todo!()
}

fn forward(&self, _a: &Tensor) -> Result<Tensor> {
todo!()
}
@@ -60,6 +64,10 @@ impl QuantMethod for GptqLayer {
fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
todo!()
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

impl QuantizedSerde for GptqLayer {
9 changes: 9 additions & 0 deletions mistralrs-quant/src/gptq/gptq_cuda.rs
@@ -268,6 +268,11 @@ impl QuantMethod for GptqLayer {
}
}

fn dequantize_w(&self) -> Result<Tensor> {
// TODO
candle_core::bail!("GptqLayer cannot be dequantized!");
}

fn forward(&self, a: &Tensor) -> Result<Tensor> {
// https://github.com/vllm-project/vllm/blob/ba991d5c84adbc0685075af88333c688ddb06011/vllm/model_executor/layers/quantization/gptq.py#L200
let out_shape = Shape::from_dims(
@@ -342,6 +347,10 @@ impl QuantMethod for GptqLayer {
fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
None
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

impl QuantizedSerde for GptqLayer {
8 changes: 8 additions & 0 deletions mistralrs-quant/src/hqq/mod.rs
@@ -560,6 +560,10 @@ impl QuantMethod for HqqLayer {
}
}

fn dequantize_w(&self) -> Result<Tensor> {
self.dequantize()
}

fn forward(&self, a: &Tensor) -> Result<Tensor> {
/*
if self.cfg.force_dequantize {
@@ -631,6 +635,10 @@
// Use 1 because we quantize on the GPU
Some(1.try_into().unwrap())
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

// Serialization structure:
19 changes: 18 additions & 1 deletion mistralrs-quant/src/lib.rs
@@ -231,6 +231,20 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
where
Self: Sized;

fn dequantize_w(&self) -> Result<Tensor>;

/// Compute matmul of `self` and `a`. `self` should contain the weights.
/// Automatically cast to the required quantization activation type and back
fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
let original_ty = a.dtype();
let a = if let Some(t) = self.quantized_act_type() {
a.to_dtype(t)?
} else {
a.clone()
};
self.forward(&a)?.to_dtype(original_ty)
}

/// Compute matmul of `self` and `a`. `self` should contain the weights.
fn forward(&self, a: &Tensor) -> Result<Tensor>;

@@ -258,6 +272,9 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
imatrix_weight: Option<Vec<f32>>,
) -> Result<Arc<dyn QuantMethod>>;

/// Convert to an equivalent gguf quantization, if applicable.
fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>>;

/// If the quant is backed by a qmatmul.
fn get_bias_mut(&mut self) -> Option<&mut Tensor>;

@@ -323,7 +340,7 @@ pub fn linear(
match quant_conf.quant_method {
QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
QuantMethodType::Bitsandbytes => {
Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
}
QuantMethodType::Unreachable => unreachable!(),
}
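
A hedged sketch of how a caller might combine the trait additions above; `layer` is assumed to be an existing Arc<dyn QuantMethod> obtained elsewhere, and the surrounding setup is hypothetical:

use std::sync::Arc;

use candle_core::{Result, Tensor};
use mistralrs_quant::QuantMethod;

fn run_layer(layer: Arc<dyn QuantMethod>, xs: &Tensor) -> Result<Tensor> {
    // New API: recover the dequantized weight matrix.
    let _w = layer.dequantize_w()?;
    // New API: convert to an equivalent GGUF quantization where applicable.
    let layer = layer.maybe_to_gguf_quant()?;
    // New default method: cast xs to the quantized activation dtype (if any),
    // run the matmul, then cast the result back to xs' original dtype.
    layer.forward_autocast(xs)
}
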
8 changes: 8 additions & 0 deletions mistralrs-quant/src/unquantized/mod.rs
@@ -45,6 +45,10 @@ impl QuantMethod for UnquantLinear {
}
}

fn dequantize_w(&self) -> Result<Tensor> {
Ok(self.w.clone())
}

fn forward(&self, a: &Tensor) -> Result<Tensor> {
// Batch matrix multiplication
maybe_init_cublas_lt_wrapper();
@@ -269,6 +273,10 @@ impl QuantMethod for UnquantLinear {
candle_core::bail!("`{}` does not support tracking stats.", self.name())
}
}

fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
Ok(self.clone())
}
}

// Serialization structure:
2 changes: 1 addition & 1 deletion mistralrs/examples/simple_stream/main.rs
@@ -45,7 +45,7 @@ async fn main() -> Result<()> {
let mut buf = std::io::BufWriter::new(lock);
while let Some(chunk) = stream.next().await {
if let Response::Chunk(chunk) = chunk {
buf.write(chunk.choices[0].delta.content.as_bytes())?;
buf.write_all(chunk.choices[0].delta.content.as_bytes())?;
} else {
// Handle errors
}
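
This is most likely the Clippy fix mentioned in the commit message: io::Write::write may perform a partial write and returns the number of bytes actually written (which the old code ignored), whereas write_all keeps writing until the entire buffer has been written or an error occurs.
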
2 changes: 1 addition & 1 deletion mistralrs/src/lib.rs
@@ -54,7 +54,7 @@
//! ```no_run
//! use anyhow::Result;
//! use mistralrs::{
//! IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages, TextModelBuilder,
//! IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages, TextModelBuilder, Response
//! };
//!
//! #[tokio::main]