-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e367cd7
commit dcfc563
Showing
2 changed files
with
164 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
use candle::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType}; | ||
use float8::F8E4M3; | ||
use rayon::iter::{IntoParallelIterator, ParallelIterator}; | ||
|
||
struct Fp8BlockwiseDequantize { | ||
weight_block_size: Vec<usize>, | ||
out_ty: DType, | ||
} | ||
|
||
impl Fp8BlockwiseDequantize { | ||
fn dispatch_dequant_blockwise<T: WithDType>( | ||
&self, | ||
weight: &[F8E4M3], | ||
scale: &[f32], | ||
weight_l: &candle::Layout, | ||
scale_l: &candle::Layout, | ||
) -> candle::Result<Vec<T>> { | ||
let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]); | ||
let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]); | ||
|
||
let res = vec![T::zero(); weight.len()]; | ||
|
||
(0..grid_y).into_par_iter().for_each(|y| { | ||
(0..grid_x).into_par_iter().for_each(|x| { | ||
let res_ptr = res.as_ptr() as *mut T; | ||
|
||
let scale = scale[y * scale_l.stride()[0] + x]; | ||
|
||
let start_y = y * self.weight_block_size[0]; | ||
let end_y = start_y + self.weight_block_size[0]; | ||
|
||
let start_x = x * self.weight_block_size[1]; | ||
let end_x = start_x + self.weight_block_size[1]; | ||
|
||
for weight_y in start_y..end_y { | ||
if weight_y >= weight_l.dims()[0] { | ||
break; | ||
} | ||
|
||
let row_offset = weight_y * weight_l.stride()[0]; | ||
for weight_x in start_x..end_x { | ||
if weight_x >= weight_l.dims()[1] { | ||
break; | ||
} | ||
|
||
let weight_pos = row_offset + weight_x; | ||
|
||
// SAFETY: We know each thread will only update indepedant values! | ||
unsafe { | ||
*res_ptr.wrapping_add(weight_pos) = | ||
T::from_f64((weight[weight_pos].to_f32() * scale) as f64); | ||
} | ||
} | ||
} | ||
}); | ||
}); | ||
|
||
Ok(res) | ||
} | ||
} | ||
|
||
impl CustomOp2 for Fp8BlockwiseDequantize { | ||
fn name(&self) -> &'static str { | ||
"fp8-blockwise-dequantize" | ||
} | ||
|
||
fn cpu_fwd( | ||
&self, | ||
scale_s: &candle::CpuStorage, | ||
scale_l: &candle::Layout, | ||
weight_s: &candle::CpuStorage, | ||
weight_l: &candle::Layout, | ||
) -> candle::Result<(candle::CpuStorage, candle::Shape)> { | ||
let candle::CpuStorage::F8E4M3(weight) = weight_s else { | ||
candle::bail!("Expected F8E4M3 weight!"); | ||
}; | ||
let candle::CpuStorage::F32(scale) = scale_s else { | ||
candle::bail!("Expected F8E4M3 weight!"); | ||
}; | ||
if weight_l.start_offset() != 0 || !weight_l.is_contiguous() { | ||
candle::bail!("Expected weight to have start offset 0, continuous"); | ||
} | ||
if scale_l.start_offset() != 0 || !scale_l.is_contiguous() { | ||
candle::bail!("Expected scales to have start offset 0, continuous"); | ||
} | ||
if weight_l.dims().len() != 2 { | ||
candle::bail!("Expected weight to be rank 2"); | ||
} | ||
if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 { | ||
candle::bail!("Expected scale to be rank 2"); | ||
} | ||
|
||
match self.out_ty { | ||
DType::F32 => Ok(( | ||
CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?), | ||
weight_l.shape().clone(), | ||
)), | ||
DType::BF16 => Ok(( | ||
CpuStorage::BF16( | ||
self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?, | ||
), | ||
weight_l.shape().clone(), | ||
)), | ||
DType::F16 => Ok(( | ||
CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?), | ||
weight_l.shape().clone(), | ||
)), | ||
other => candle::bail!("unexpected out type of fp8 blockwise dequant {other:?}"), | ||
} | ||
} | ||
} | ||
|
||
/// FP8 blockwise dequantize. | ||
/// - Expects weight to be fp8 | ||
/// - Expects inv_scales to be f32 | ||
/// - weight * inv_scale = dequantized | ||
/// - Only works on the CPU | ||
pub fn fp8_blockwise_dequantize( | ||
weight: &Tensor, | ||
inv_scales: &Tensor, | ||
weight_block_size: Vec<usize>, | ||
out_ty: DType, | ||
) -> Result<Tensor> { | ||
inv_scales.apply_op2_no_bwd( | ||
weight, | ||
&Fp8BlockwiseDequantize { | ||
weight_block_size, | ||
out_ty, | ||
}, | ||
) | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use candle::{DType, Device, Result, Tensor}; | ||
|
||
use crate::models::deepseekv3::ops; | ||
|
||
#[test] | ||
fn test_fp8_blockwise_dequant() -> Result<()> { | ||
let dev = &Device::Cpu; | ||
let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?; | ||
let weight_block_size = vec![2, 2]; | ||
let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?; | ||
|
||
let dequant = | ||
ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?; | ||
|
||
let res = dequant.to_vec2::<f32>()?; | ||
assert_eq!( | ||
res, | ||
vec![ | ||
vec![0., 0., 1., 1., 2.], | ||
vec![0., 0., 1., 1., 2.], | ||
vec![3., 3., 4., 4., 5.], | ||
vec![3., 3., 4., 4., 5.], | ||
vec![6., 6., 7., 7., 8.], | ||
] | ||
); | ||
|
||
Ok(()) | ||
} | ||
} |