Commit
Add ops
EricLBuehler committed Jan 27, 2025
1 parent e367cd7 commit dcfc563
Showing 2 changed files with 164 additions and 44 deletions.
45 changes: 1 addition & 44 deletions candle-transformers/src/models/deepseekv3/model.rs
@@ -33,50 +33,6 @@ impl NonZero {
    }
}

#[cfg(feature = "cuda")]
fn count_nonzero_cuda(dtype: candle::DType, d_in: *const c_void, n: u32) -> u32 {
    unsafe {
        match dtype {
            candle::DType::U8 => ffi::count_nonzero_u8(d_in, n),
            candle::DType::U32 => ffi::count_nonzero_u32(d_in, n),
            candle::DType::I64 => ffi::count_nonzero_i64(d_in, n),
            candle::DType::I16 => ffi::count_nonzero_i16(d_in, n),
            candle::DType::I32 => ffi::count_nonzero_i32(d_in, n),
            candle::DType::BF16 => ffi::count_nonzero_bf16(d_in, n),
            candle::DType::F16 => ffi::count_nonzero_f16(d_in, n),
            candle::DType::F32 => ffi::count_nonzero_f32(d_in, n),
            candle::DType::F64 => ffi::count_nonzero_f64(d_in, n),
            candle::DType::F8E4M3 => todo!(),
        }
    }
}

#[cfg(feature = "cuda")]
fn nonzero_cuda(
    dtype: candle::DType,
    d_in: *const c_void,
    n: u32,
    num_nonzero: u32,
    dims: *const c_void,
    num_dims: u32,
    d_out: *mut c_void,
) {
    unsafe {
        match dtype {
            candle::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::U32 => ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::I64 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::I32 => ffi::nonzero_i32(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::I16 => ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::BF16 => ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::F16 => ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::F32 => ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::F64 => ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out),
            candle::DType::F8E4M3 => todo!(),
        }
    }
}

impl CustomOp1 for NonZero {
    fn name(&self) -> &'static str {
        "nonzero"
@@ -94,6 +50,7 @@ impl CustomOp1 for NonZero {
            candle::CpuStorage::F16(vs) => self.nonzero(vs, layout),
            candle::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle::CpuStorage::F64(vs) => self.nonzero(vs, layout),
            candle::CpuStorage::F8E4M3(vs) => self.nonzero(vs, layout),
        };
        let index_len = layout.dims().len();
        let result_len = result.len() / index_len;
163 changes: 163 additions & 0 deletions candle-transformers/src/models/deepseekv3/ops.rs
@@ -0,0 +1,163 @@
use candle::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
use float8::F8E4M3;
use rayon::iter::{IntoParallelIterator, ParallelIterator};

struct Fp8BlockwiseDequantize {
    weight_block_size: Vec<usize>,
    out_ty: DType,
}

impl Fp8BlockwiseDequantize {
    fn dispatch_dequant_blockwise<T: WithDType>(
        &self,
        weight: &[F8E4M3],
        scale: &[f32],
        weight_l: &candle::Layout,
        scale_l: &candle::Layout,
    ) -> candle::Result<Vec<T>> {
        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);

        let res = vec![T::zero(); weight.len()];

        (0..grid_y).into_par_iter().for_each(|y| {
            (0..grid_x).into_par_iter().for_each(|x| {
                let res_ptr = res.as_ptr() as *mut T;

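                // One inverse scale per (y, x) block; the scale grid is
                // row-major, so row y begins at y * scale_l.stride()[0].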
                let scale = scale[y * scale_l.stride()[0] + x];

                let start_y = y * self.weight_block_size[0];
                let end_y = start_y + self.weight_block_size[0];

                let start_x = x * self.weight_block_size[1];
                let end_x = start_x + self.weight_block_size[1];

                for weight_y in start_y..end_y {
                    if weight_y >= weight_l.dims()[0] {
                        break;
                    }

                    let row_offset = weight_y * weight_l.stride()[0];
                    for weight_x in start_x..end_x {
                        if weight_x >= weight_l.dims()[1] {
                            break;
                        }

                        let weight_pos = row_offset + weight_x;

                        // SAFETY: We know each thread will only update independent values!
                        unsafe {
                            *res_ptr.wrapping_add(weight_pos) =
                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
                        }
                    }
                }
            });
        });

        Ok(res)
    }
}
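
// Worked example of the mapping above (mirrors the test at the bottom of this
// file): for a 5x5 weight with weight_block_size = [2, 2], the scale grid is
// ceil(5/2) x ceil(5/2) = 3x3, and scale[y][x] is applied to the weight tile
// covering rows 2y..2y+2 and cols 2x..2x+2, with the last row and column of
// tiles clipped to the weight bounds by the checks in the inner loops.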

impl CustomOp2 for Fp8BlockwiseDequantize {
    fn name(&self) -> &'static str {
        "fp8-blockwise-dequantize"
    }

    fn cpu_fwd(
        &self,
        scale_s: &candle::CpuStorage,
        scale_l: &candle::Layout,
        weight_s: &candle::CpuStorage,
        weight_l: &candle::Layout,
    ) -> candle::Result<(candle::CpuStorage, candle::Shape)> {
        let candle::CpuStorage::F8E4M3(weight) = weight_s else {
            candle::bail!("Expected F8E4M3 weight!");
        };
        let candle::CpuStorage::F32(scale) = scale_s else {
            candle::bail!("Expected F32 scale!");
        };
        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
            candle::bail!("Expected weight to have start offset 0 and be contiguous");
        }
        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
            candle::bail!("Expected scales to have start offset 0 and be contiguous");
        }
        if weight_l.dims().len() != 2 {
            candle::bail!("Expected weight to be rank 2");
        }
        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
            candle::bail!("Expected scale and weight_block_size to be rank 2");
        }

        match self.out_ty {
            DType::F32 => Ok((
                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            DType::BF16 => Ok((
                CpuStorage::BF16(
                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
                ),
                weight_l.shape().clone(),
            )),
            DType::F16 => Ok((
                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
                weight_l.shape().clone(),
            )),
            other => candle::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
        }
    }
}

/// FP8 blockwise dequantize.
/// - Expects weight to be fp8
/// - Expects inv_scales to be f32
/// - weight * inv_scale = dequantized
/// - Only works on the CPU
pub fn fp8_blockwise_dequantize(
    weight: &Tensor,
    inv_scales: &Tensor,
    weight_block_size: Vec<usize>,
    out_ty: DType,
) -> Result<Tensor> {
    inv_scales.apply_op2_no_bwd(
        weight,
        &Fp8BlockwiseDequantize {
            weight_block_size,
            out_ty,
        },
    )
}
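
// Note: `apply_op2_no_bwd` runs the custom op without registering a backward
// pass, so the dequantized tensor does not propagate gradients back to the
// fp8 weight or the inv_scales (appropriate for inference-time use).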

#[cfg(test)]
mod tests {
    use candle::{DType, Device, Result, Tensor};

    use crate::models::deepseekv3::ops;

    #[test]
    fn test_fp8_blockwise_dequant() -> Result<()> {
        let dev = &Device::Cpu;
        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
        let weight_block_size = vec![2, 2];
        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;

        let dequant =
            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;

        let res = dequant.to_vec2::<f32>()?;
        assert_eq!(
            res,
            vec![
                vec![0., 0., 1., 1., 2.],
                vec![0., 0., 1., 1., 2.],
                vec![3., 3., 4., 4., 5.],
                vec![3., 3., 4., 4., 5.],
                vec![6., 6., 7., 7., 8.],
            ]
        );

        Ok(())
    }
}
