diff --git a/mistralrs-quant/kernels/bitsandbytes/dequant.cu b/mistralrs-quant/kernels/bitsandbytes/dequant.cu
index 6f6d4bfe1..7dce07d8d 100644
--- a/mistralrs-quant/kernels/bitsandbytes/dequant.cu
+++ b/mistralrs-quant/kernels/bitsandbytes/dequant.cu
@@ -150,8 +150,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
         {
           vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
           vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
-          // vals[j*2] = 10000;
-          // vals[j*2 + 1] = 10000;
         }
         break;
       }
diff --git a/mistralrs-quant/src/bitsandbytes/mod.rs b/mistralrs-quant/src/bitsandbytes/mod.rs
index 3a886b004..f38c82c7b 100644
--- a/mistralrs-quant/src/bitsandbytes/mod.rs
+++ b/mistralrs-quant/src/bitsandbytes/mod.rs
@@ -4,11 +4,14 @@ use std::{
     sync::{atomic::AtomicUsize, Arc},
 };
 
-use candle_core::{Context, DType, Device, Result, Shape, Tensor};
+use candle_core::{
+    quantized::{GgmlDType, QTensor},
+    Context, DType, Device, Result, Shape, Tensor, D,
+};
 use candle_nn::VarBuilder;
 use serde::Deserialize;
 
-use crate::{IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde};
+use crate::{GgufMatMul, IsqType, QuantMethod, QuantMethodConfig, QuantizedSerde};
 
 #[cfg(feature = "cuda")]
 mod ffi;
@@ -219,14 +222,18 @@ impl QuantMethod for BnbLinear {
             }),
         }
     }
+
+    fn dequantize_w(&self) -> Result<Tensor> {
+        Self::dequantize(&self.weight, &self.params, self.quant_ty)
+    }
+
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let w = Self::dequantize(&self.weight, &self.params, self.quant_ty)?
             .t()?
             .to_dtype(xs.dtype())?;
-        // dbg!(&w.mean_all());
         let res = xs.broadcast_matmul(&w)?;
         if let Some(bias) = &self.bias {
-            res + bias
+            res.broadcast_add(bias)
         } else {
             Ok(res)
         }
     }
@@ -261,6 +268,29 @@ impl QuantMethod for BnbLinear {
     fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
         None
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        let weight = Self::dequantize(&self.weight, &self.params, self.quant_ty)?;
+        let bias = self.bias.clone();
+
+        let last_dim = weight.dim(D::Minus1)?;
+        let dtype = match self.quant_ty {
+            BnbQuantType::Fp4 | BnbQuantType::Nf4 if last_dim % 256 == 0 => GgmlDType::Q4K,
+            BnbQuantType::Fp4 | BnbQuantType::Nf4 if last_dim % 64 == 0 && last_dim % 256 != 0 => {
+                GgmlDType::Q4_0
+            }
+            BnbQuantType::Fp4 | BnbQuantType::Nf4 if last_dim % 64 != 0 && last_dim % 256 != 0 => {
+                GgmlDType::F32
+            }
+            BnbQuantType::Int8 => GgmlDType::Q8_0,
+            _ => unreachable!(),
+        };
+        let qmatmul = QTensor::quantize(&weight, dtype)?;
+        Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
+            q_weight: Arc::new(qmatmul),
+            b: bias,
+        })?))
+    }
 }
 
 impl QuantizedSerde for BnbLinear {
diff --git a/mistralrs-quant/src/dummy/mod.rs b/mistralrs-quant/src/dummy/mod.rs
index a9af0e9c5..8eaf27fb1 100644
--- a/mistralrs-quant/src/dummy/mod.rs
+++ b/mistralrs-quant/src/dummy/mod.rs
@@ -1,3 +1,7 @@
+use std::sync::Arc;
+
+use candle_core::Result;
+
 use crate::{QuantMethod, QuantizedSerde};
 
 #[derive(Debug)]
@@ -10,6 +14,9 @@ impl QuantMethod for DummyLayer {
     {
         Ok(Self)
     }
+    fn dequantize_w(&self) -> Result<candle_core::Tensor> {
+        candle_core::bail!("DummyLayer cannot be dequantized!")
+    }
     fn add_delta_w(
         &self,
         _delta: &candle_core::Tensor,
@@ -46,6 +53,10 @@ impl QuantMethod for DummyLayer {
     fn quantized_act_type(&self) -> Option<candle_core::DType> {
         None
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 impl QuantizedSerde for DummyLayer {
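Aside on the `BnbLinear::maybe_to_gguf_quant` hunk above: it picks the densest GGUF type the dequantized weight's row length allows, because GGML block formats constrain row sizes (Q4K packs 256-element super-blocks, Q4_0 uses much smaller blocks). A minimal sketch of that selection rule, with a hypothetical `pick_gguf_dtype` helper that is not part of this diff; the divisibility checks mirror the match guards above:

```rust
use candle_core::quantized::GgmlDType;

/// Hypothetical helper mirroring BnbLinear::maybe_to_gguf_quant's dtype rule:
/// Fp4/Nf4 weights become Q4K when rows divide into 256-element super-blocks,
/// Q4_0 when they only divide by 64, and fall back to unquantized F32 otherwise.
fn pick_gguf_dtype(last_dim: usize, is_int8: bool) -> GgmlDType {
    if is_int8 {
        GgmlDType::Q8_0
    } else if last_dim % 256 == 0 {
        GgmlDType::Q4K
    } else if last_dim % 64 == 0 {
        GgmlDType::Q4_0
    } else {
        GgmlDType::F32
    }
}

fn main() {
    // A 4096-wide row divides by 256, so the densest 4-bit layout is chosen.
    assert_eq!(pick_gguf_dtype(4096, false), GgmlDType::Q4K);
    // 320 divides by 64 but not by 256, so only Q4_0 fits.
    assert_eq!(pick_gguf_dtype(320, false), GgmlDType::Q4_0);
    // Int8 weights re-quantize to Q8_0 regardless of shape.
    assert_eq!(pick_gguf_dtype(320, true), GgmlDType::Q8_0);
}
```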
diff --git a/mistralrs-quant/src/fp8/mod.rs b/mistralrs-quant/src/fp8/mod.rs
index 15ffb0d48..661f0eb44 100644
--- a/mistralrs-quant/src/fp8/mod.rs
+++ b/mistralrs-quant/src/fp8/mod.rs
@@ -59,6 +59,9 @@ impl QuantMethod for FP8Linear {
             }
         }
     }
+    fn dequantize_w(&self) -> Result<Tensor> {
+        Ok(self.dequantize(DType::F32)?.weight().clone())
+    }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         // Batch matrix multiplication
@@ -176,6 +179,10 @@ impl QuantMethod for FP8Linear {
             | IsqType::HQQ8 => None,
         }
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 // Serialization structure:
diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs
index 2e8f3d738..b7f43f80d 100644
--- a/mistralrs-quant/src/gguf/mod.rs
+++ b/mistralrs-quant/src/gguf/mod.rs
@@ -43,6 +43,10 @@ impl QuantMethod for GgufMatMul {
         }
     }
 
+    fn dequantize_w(&self) -> Result<Tensor> {
+        self.w.dequantize_f16()?.to_dtype(DType::F32)
+    }
+
     fn forward(&self, a: &Tensor) -> Result<Tensor> {
         let x = self.w.forward(a)?;
         if let Some(ref b) = self.b {
@@ -148,6 +152,10 @@ impl QuantMethod for GgufMatMul {
     fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
         None
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 // Serialization structure:
diff --git a/mistralrs-quant/src/gptq/gptq_cpu.rs b/mistralrs-quant/src/gptq/gptq_cpu.rs
index 04a822b49..a14a7c321 100644
--- a/mistralrs-quant/src/gptq/gptq_cpu.rs
+++ b/mistralrs-quant/src/gptq/gptq_cpu.rs
@@ -27,6 +27,10 @@ impl QuantMethod for GptqLayer {
         }
     }
 
+    fn dequantize_w(&self) -> Result<Tensor> {
+        todo!()
+    }
+
     fn forward(&self, _a: &Tensor) -> Result<Tensor> {
         todo!()
     }
@@ -60,6 +64,10 @@ impl QuantMethod for GptqLayer {
     fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
         todo!()
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 impl QuantizedSerde for GptqLayer {
diff --git a/mistralrs-quant/src/gptq/gptq_cuda.rs b/mistralrs-quant/src/gptq/gptq_cuda.rs
index dd2d07fe5..7c95d5ea4 100644
--- a/mistralrs-quant/src/gptq/gptq_cuda.rs
+++ b/mistralrs-quant/src/gptq/gptq_cuda.rs
@@ -268,6 +268,11 @@ impl QuantMethod for GptqLayer {
         }
     }
 
+    fn dequantize_w(&self) -> Result<Tensor> {
+        // TODO
+        candle_core::bail!("GptqLayer cannot be dequantized!");
+    }
+
     fn forward(&self, a: &Tensor) -> Result<Tensor> {
         // https://github.com/vllm-project/vllm/blob/ba991d5c84adbc0685075af88333c688ddb06011/vllm/model_executor/layers/quantization/gptq.py#L200
         let out_shape = Shape::from_dims(
@@ -342,6 +347,10 @@ impl QuantMethod for GptqLayer {
     fn get_max_isq_cpu_threads(&self, _dtype: IsqType) -> Option<NonZeroUsize> {
         None
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 impl QuantizedSerde for GptqLayer {
diff --git a/mistralrs-quant/src/hqq/mod.rs b/mistralrs-quant/src/hqq/mod.rs
index 82b609ba3..6a9de09ec 100644
--- a/mistralrs-quant/src/hqq/mod.rs
+++ b/mistralrs-quant/src/hqq/mod.rs
@@ -560,6 +560,10 @@ impl QuantMethod for HqqLayer {
         }
     }
 
+    fn dequantize_w(&self) -> Result<Tensor> {
+        self.dequantize()
+    }
+
     fn forward(&self, a: &Tensor) -> Result<Tensor> {
         /*
         if self.cfg.force_dequantize {
@@ -631,6 +635,10 @@ impl QuantMethod for HqqLayer {
         // Use 1 because we quantize on the GPU
         Some(1.try_into().unwrap())
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 // Serialization structure:
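Note that every backend above except `BnbLinear` satisfies the new method with `Ok(self.clone())`. Because the receiver is `self: Arc<Self>`, that clone is a reference-count bump, not a copy of the quantized weights. A toy sketch of the pattern, with hypothetical `ToGguf`/`AlreadyGguf` names standing in for `QuantMethod` and the GGUF-backed layers:

```rust
use std::sync::Arc;

// Hypothetical stand-in for the QuantMethod trait: the `self: Arc<Self>`
// receiver keeps the method object-safe while making the no-op path cheap.
trait ToGguf {
    fn maybe_to_gguf_quant(self: Arc<Self>) -> Arc<dyn ToGguf>;
}

struct AlreadyGguf; // e.g. GgufMatMul: already in a GGUF format

impl ToGguf for AlreadyGguf {
    fn maybe_to_gguf_quant(self: Arc<Self>) -> Arc<dyn ToGguf> {
        // Cloning an Arc only increments the refcount; the (potentially
        // large) quantized tensor behind it is shared, not duplicated.
        self.clone()
    }
}

fn main() {
    let layer: Arc<AlreadyGguf> = Arc::new(AlreadyGguf);
    // Same allocation, one more handle: strong_count goes from 1 to 2.
    let _converted: Arc<dyn ToGguf> = layer.clone().maybe_to_gguf_quant();
    assert_eq!(Arc::strong_count(&layer), 2);
}
```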
diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs
index a9e912701..6b8f5bdef 100644
--- a/mistralrs-quant/src/lib.rs
+++ b/mistralrs-quant/src/lib.rs
@@ -231,6 +231,20 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
     where
         Self: Sized;
 
+    fn dequantize_w(&self) -> Result<Tensor>;
+
+    /// Compute matmul of `self` and `a`. `self` should contain the weights.
+    /// Automatically casts to the required quantization activation type and back.
+    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
+        let original_ty = a.dtype();
+        let a = if let Some(t) = self.quantized_act_type() {
+            a.to_dtype(t)?
+        } else {
+            a.clone()
+        };
+        self.forward(&a)?.to_dtype(original_ty)
+    }
+
     /// Compute matmul of `self` and `a`. `self` should contain the weights.
     fn forward(&self, a: &Tensor) -> Result<Tensor>;
 
@@ -258,6 +272,9 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
         imatrix_weight: Option<Vec<f32>>,
     ) -> Result<Arc<dyn QuantMethod>>;
 
+    /// Convert to an equivalent gguf quantization, if applicable.
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>>;
+
     /// If the quant is backed by a qmatmul.
     fn get_bias_mut(&mut self) -> Option<&mut Tensor>;
@@ -323,7 +340,7 @@ pub fn linear(
     match quant_conf.quant_method {
         QuantMethodType::Gptq => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
         QuantMethodType::Bitsandbytes => {
-            Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
+            Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
         }
         QuantMethodType::Unreachable => unreachable!(),
     }
diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs
index 2b0320014..1da356782 100644
--- a/mistralrs-quant/src/unquantized/mod.rs
+++ b/mistralrs-quant/src/unquantized/mod.rs
@@ -45,6 +45,10 @@ impl QuantMethod for UnquantLinear {
         }
     }
 
+    fn dequantize_w(&self) -> Result<Tensor> {
+        Ok(self.w.clone())
+    }
+
     fn forward(&self, a: &Tensor) -> Result<Tensor> {
         // Batch matrix multiplication
         maybe_init_cublas_lt_wrapper();
@@ -269,6 +273,10 @@ impl QuantMethod for UnquantLinear {
             candle_core::bail!("`{}` does not support tracking stats.", self.name())
         }
     }
+
+    fn maybe_to_gguf_quant(self: Arc<Self>) -> Result<Arc<dyn QuantMethod>> {
+        Ok(self.clone())
+    }
 }
 
 // Serialization structure:
diff --git a/mistralrs/examples/simple_stream/main.rs b/mistralrs/examples/simple_stream/main.rs
index 55771c564..58136385c 100644
--- a/mistralrs/examples/simple_stream/main.rs
+++ b/mistralrs/examples/simple_stream/main.rs
@@ -45,7 +45,7 @@ async fn main() -> Result<()> {
     let mut buf = std::io::BufWriter::new(lock);
     while let Some(chunk) = stream.next().await {
         if let Response::Chunk(chunk) = chunk {
-            buf.write(chunk.choices[0].delta.content.as_bytes())?;
+            buf.write_all(chunk.choices[0].delta.content.as_bytes())?;
         } else {
             // Handle errors
         }
diff --git a/mistralrs/src/lib.rs b/mistralrs/src/lib.rs
index f197926d4..75e2fa1e9 100644
--- a/mistralrs/src/lib.rs
+++ b/mistralrs/src/lib.rs
@@ -54,7 +54,7 @@
 //! ```no_run
 //! use anyhow::Result;
 //! use mistralrs::{
-//!     IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages, TextModelBuilder,
+//!     IsqType, PagedAttentionMetaBuilder, TextMessageRole, TextMessages, TextModelBuilder, Response
 //! };
 //!
 //! #[tokio::main]
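For readers skimming the `mistralrs-quant/src/lib.rs` hunk above: `forward_autocast` is a default method, so every `QuantMethod` impl gains it for free. A self-contained sketch of the same cast/forward/cast sequence, using a hypothetical `QuantLike` trait in place of the real `QuantMethod`:

```rust
use candle_core::{DType, Result, Tensor};

// Hypothetical stand-in for QuantMethod, reduced to the two methods the
// autocast wrapper actually touches.
trait QuantLike {
    /// The dtype the quantized kernel wants activations in, if any.
    fn quantized_act_type(&self) -> Option<DType>;
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Same shape as the default method added in this diff: cast in,
    /// run the matmul, cast back to the caller's dtype.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = match self.quantized_act_type() {
            Some(t) => a.to_dtype(t)?,
            None => a.clone(),
        };
        self.forward(&a)?.to_dtype(original_ty)
    }
}
```

For example, a BF16 activation fed to a layer whose `quantized_act_type()` is `Some(DType::F32)` is computed in F32 but comes back as BF16, so call sites no longer need manual `to_dtype` pairs around every quantized matmul.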