From dcfc5632de83cb4481018896a0891518cef01e2f Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Mon, 27 Jan 2025 13:06:50 -0500
Subject: [PATCH] Add ops

---
 .../src/models/deepseekv3/model.rs |  45 +----
 .../src/models/deepseekv3/ops.rs   | 163 ++++++++++++++++++
 2 files changed, 164 insertions(+), 44 deletions(-)
 create mode 100644 candle-transformers/src/models/deepseekv3/ops.rs

diff --git a/candle-transformers/src/models/deepseekv3/model.rs b/candle-transformers/src/models/deepseekv3/model.rs
index 55971cf36f..79ef3406f2 100644
--- a/candle-transformers/src/models/deepseekv3/model.rs
+++ b/candle-transformers/src/models/deepseekv3/model.rs
@@ -33,50 +33,6 @@ impl NonZero {
     }
 }
 
-#[cfg(feature = "cuda")]
-fn count_nonzero_cuda(dtype: candle::DType, d_in: *const c_void, n: u32) -> u32 {
-    unsafe {
-        match dtype {
-            candle::DType::U8 => ffi::count_nonzero_u8(d_in, n),
-            candle::DType::U32 => ffi::count_nonzero_u32(d_in, n),
-            candle::DType::I64 => ffi::count_nonzero_i64(d_in, n),
-            candle::DType::I16 => ffi::count_nonzero_i16(d_in, n),
-            candle::DType::I32 => ffi::count_nonzero_i32(d_in, n),
-            candle::DType::BF16 => ffi::count_nonzero_bf16(d_in, n),
-            candle::DType::F16 => ffi::count_nonzero_f16(d_in, n),
-            candle::DType::F32 => ffi::count_nonzero_f32(d_in, n),
-            candle::DType::F64 => ffi::count_nonzero_f64(d_in, n),
-            candle::DType::F8E4M3 => todo!(),
-        }
-    }
-}
-
-#[cfg(feature = "cuda")]
-fn nonzero_cuda(
-    dtype: candle::DType,
-    d_in: *const c_void,
-    n: u32,
-    num_nonzero: u32,
-    dims: *const c_void,
-    num_dims: u32,
-    d_out: *mut c_void,
-) {
-    unsafe {
-        match dtype {
-            candle::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::U32 => ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::I64 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::I32 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::I16 => ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::BF16 => ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::F16 => ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::F32 => ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::F64 => ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out),
-            candle::DType::F8E4M3 => todo!(),
-        }
-    }
-}
-
 impl CustomOp1 for NonZero {
     fn name(&self) -> &'static str {
         "nonzero"
@@ -94,6 +50,7 @@ impl CustomOp1 for NonZero {
             candle::CpuStorage::F16(vs) => self.nonzero(vs, layout),
             candle::CpuStorage::F32(vs) => self.nonzero(vs, layout),
            candle::CpuStorage::F64(vs) => self.nonzero(vs, layout),
+            candle::CpuStorage::F8E4M3(vs) => self.nonzero(vs, layout),
         };
         let index_len = layout.dims().len();
         let result_len = result.len() / index_len;
diff --git a/candle-transformers/src/models/deepseekv3/ops.rs b/candle-transformers/src/models/deepseekv3/ops.rs
new file mode 100644
index 0000000000..44090af2f7
--- /dev/null
+++ b/candle-transformers/src/models/deepseekv3/ops.rs
@@ -0,0 +1,163 @@
+use candle::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType};
+use float8::F8E4M3;
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+
+struct Fp8BlockwiseDequantize {
+    weight_block_size: Vec<usize>,
+    out_ty: DType,
+}
+
+impl Fp8BlockwiseDequantize {
+    fn dispatch_dequant_blockwise<T: WithDType>(
+        &self,
+        weight: &[F8E4M3],
+        scale: &[f32],
+        weight_l: &candle::Layout,
+        scale_l: &candle::Layout,
+    ) -> candle::Result<Vec<T>> {
+        let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]);
+        let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]);
+
+        let res = vec![T::zero(); weight.len()];
+
+        (0..grid_y).into_par_iter().for_each(|y| {
+            (0..grid_x).into_par_iter().for_each(|x| {
+                let res_ptr = res.as_ptr() as *mut T;
+
+                let scale = scale[y * scale_l.stride()[0] + x];
+
+                let start_y = y * self.weight_block_size[0];
+                let end_y = start_y + self.weight_block_size[0];
+
+                let start_x = x * self.weight_block_size[1];
+                let end_x = start_x + self.weight_block_size[1];
+
+                for weight_y in start_y..end_y {
+                    if weight_y >= weight_l.dims()[0] {
+                        break;
+                    }
+
+                    let row_offset = weight_y * weight_l.stride()[0];
+                    for weight_x in start_x..end_x {
+                        if weight_x >= weight_l.dims()[1] {
+                            break;
+                        }
+
+                        let weight_pos = row_offset + weight_x;
+
+                        // SAFETY: We know each thread will only update independent values!
+                        unsafe {
+                            *res_ptr.wrapping_add(weight_pos) =
+                                T::from_f64((weight[weight_pos].to_f32() * scale) as f64);
+                        }
+                    }
+                }
+            });
+        });
+
+        Ok(res)
+    }
+}
+
+impl CustomOp2 for Fp8BlockwiseDequantize {
+    fn name(&self) -> &'static str {
+        "fp8-blockwise-dequantize"
+    }
+
+    fn cpu_fwd(
+        &self,
+        scale_s: &candle::CpuStorage,
+        scale_l: &candle::Layout,
+        weight_s: &candle::CpuStorage,
+        weight_l: &candle::Layout,
+    ) -> candle::Result<(candle::CpuStorage, candle::Shape)> {
+        let candle::CpuStorage::F8E4M3(weight) = weight_s else {
+            candle::bail!("Expected F8E4M3 weight!");
+        };
+        let candle::CpuStorage::F32(scale) = scale_s else {
+            candle::bail!("Expected F32 scales!");
+        };
+        if weight_l.start_offset() != 0 || !weight_l.is_contiguous() {
+            candle::bail!("Expected weight to have start offset 0 and be contiguous");
+        }
+        if scale_l.start_offset() != 0 || !scale_l.is_contiguous() {
+            candle::bail!("Expected scales to have start offset 0 and be contiguous");
+        }
+        if weight_l.dims().len() != 2 {
+            candle::bail!("Expected weight to be rank 2");
+        }
+        if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 {
+            candle::bail!("Expected scale to be rank 2");
+        }
+
+        match self.out_ty {
+            DType::F32 => Ok((
+                CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
+                weight_l.shape().clone(),
+            )),
+            DType::BF16 => Ok((
+                CpuStorage::BF16(
+                    self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?,
+                ),
+                weight_l.shape().clone(),
+            )),
+            DType::F16 => Ok((
+                CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?),
+                weight_l.shape().clone(),
+            )),
+            other => candle::bail!("unexpected out type of fp8 blockwise dequant {other:?}"),
+        }
+    }
+}
+
+/// FP8 blockwise dequantize.
+/// - Expects weight to be fp8
+/// - Expects inv_scales to be f32
+/// - weight * inv_scale = dequantized
+/// - Only works on the CPU
+pub fn fp8_blockwise_dequantize(
+    weight: &Tensor,
+    inv_scales: &Tensor,
+    weight_block_size: Vec<usize>,
+    out_ty: DType,
+) -> Result<Tensor> {
+    inv_scales.apply_op2_no_bwd(
+        weight,
+        &Fp8BlockwiseDequantize {
+            weight_block_size,
+            out_ty,
+        },
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use candle::{DType, Device, Result, Tensor};
+
+    use crate::models::deepseekv3::ops;
+
+    #[test]
+    fn test_fp8_blockwise_dequant() -> Result<()> {
+        let dev = &Device::Cpu;
+        let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?;
+        let weight_block_size = vec![2, 2];
+        let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?;
+
+        let dequant =
+            ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?;
+
+        let res = dequant.to_vec2::<f32>()?;
+        assert_eq!(
+            res,
+            vec![
+                vec![0., 0., 1., 1., 2.],
+                vec![0., 0., 1., 1., 2.],
+                vec![3., 3., 4., 4., 5.],
+                vec![3., 3., 4., 4., 5.],
+                vec![6., 6., 7., 7., 8.],
+            ]
+        );
+
+        Ok(())
+    }
+}
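
A minimal usage sketch for the new ops::fp8_blockwise_dequantize entry point, shown outside the diff above. The fp8_linear helper, its argument names, and the [128, 128] scale block size are assumptions made for this sketch only; they are not introduced by the patch.

    use candle::{DType, Result, Tensor};
    use crate::models::deepseekv3::ops;

    // Hypothetical helper: dequantize an FP8 weight with its per-block inverse
    // scales (assumed [128, 128] blocks), then apply it as a plain f32 linear layer.
    fn fp8_linear(x: &Tensor, weight_fp8: &Tensor, inv_scales: &Tensor) -> Result<Tensor> {
        let w = ops::fp8_blockwise_dequantize(weight_fp8, inv_scales, vec![128, 128], DType::F32)?;
        x.matmul(&w.t()?)
    }

Since the op only has a CPU forward, the dequantized tensor would typically be produced once at load time and cached rather than recomputed on every forward pass.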