From e367cd7894ff6d268b924a09e636e4cc728d028b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Thu, 17 Oct 2024 07:28:05 -0400 Subject: [PATCH] Add the f8 e4m3 dtype --- Cargo.toml | 1 + candle-core/Cargo.toml | 3 +- candle-core/src/convert.rs | 6 + candle-core/src/cpu_backend/mod.rs | 121 ++++++++++++++++ candle-core/src/cpu_backend/utils.rs | 82 +++++++++++ candle-core/src/cuda_backend/device.rs | 50 ++++++- candle-core/src/cuda_backend/mod.rs | 193 +++++++++++++++++++++++++ candle-core/src/cuda_backend/utils.rs | 8 + candle-core/src/display.rs | 9 ++ candle-core/src/dtype.rs | 23 ++- candle-core/src/metal_backend/mod.rs | 4 + candle-core/src/npy.rs | 10 ++ candle-core/src/op.rs | 67 +++++++++ candle-core/src/safetensors.rs | 5 + candle-core/src/sort.rs | 28 ++++ candle-kernels/src/affine.cu | 26 ++-- candle-kernels/src/binary.cu | 15 ++ candle-kernels/src/cast.cu | 86 +++++++++++ candle-kernels/src/compatibility.cuh | 1 + candle-kernels/src/conv.cu | 12 ++ candle-kernels/src/cuda_utils.cuh | 23 +++ candle-kernels/src/fill.cu | 5 + candle-kernels/src/indexing.cu | 95 ++++++++++++ candle-kernels/src/reduce.cu | 8 + candle-kernels/src/sort.cu | 3 + candle-kernels/src/ternary.cu | 6 + candle-kernels/src/unary.cu | 27 ++++ candle-pyo3/Cargo.toml | 1 + candle-pyo3/src/lib.rs | 3 + candle-transformers/src/models/mod.rs | 1 + 30 files changed, 901 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e8d1f76988..d5a527b105 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.4.1" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 4ffc869ff8..799c40308b 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } libc = { workspace = true, optional = true } memmap2 = { workspace = true } @@ -42,7 +43,7 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"] +cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 5ea5612a7c..db7bf6a4a8 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,5 +1,6 @@ //! Implement conversion traits for tensors use crate::{DType, Device, Error, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; @@ -139,6 +140,11 @@ impl Tensor { let vs = vs.to_vec1::()?; f.write_all(&vs)?; } + DType::F8E4M3 => { + for v in vs.to_vec1::()? { + f.write_u8(v.to_bits())? 
+ } + } } Ok(()) } diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 11ff1a406f..6d1185dbef 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2,6 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; use rayon::prelude::*; @@ -25,6 +26,7 @@ pub enum CpuStorage { F16(Vec), F32(Vec), F64(Vec), + F8E4M3(Vec), } #[derive(Debug, Clone)] @@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> { F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), + F8E4M3(&'a [F8E4M3]), } #[derive(Debug, Clone)] @@ -1623,6 +1626,17 @@ impl CpuStorage { .concat(); Self::F64(storages) } + Self::F8E4M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E4M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::F8E4M3(storages) + } }; Ok(s) } @@ -1640,6 +1654,7 @@ impl BackendStorage for CpuStorage { Self::F16(_) => DType::F16, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, + Self::F8E4M3(_) => DType::F8E4M3, } } @@ -1674,6 +1689,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } (Self::U8(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -1702,6 +1721,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } (Self::U8(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -1730,6 +1753,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } (Self::U8(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v); Ok(Self::U8(data)) @@ -1758,6 +1785,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } (Self::U8(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -1786,6 +1821,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -1814,6 +1853,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } 
(Self::U8(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -1842,6 +1885,42 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::U8(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::U32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::BF16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f32); + Ok(Self::F8E4M3(data)) + } + (Self::F64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f64); + Ok(Self::F8E4M3(data)) + } + (Self::F8E4M3(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F8E4M3(data)) + } } } @@ -1955,6 +2034,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v.powf(e)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), @@ -1980,6 +2063,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| elu(v, alpha)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), @@ -2024,6 +2111,15 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } } + Self::F8E4M3(storage) => { + if B::F8E4M3_VEC { + let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec); + Ok(Self::F8E4M3(data)) + } else { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); Ok(Self::U8(data)) @@ -2505,6 +2601,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min as f32, max as f32); @@ -2551,6 +2656,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let normal = 
rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let normal = @@ -2614,6 +2728,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::F64(v) } + DType::F8E4M3 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F8E4M3(v) + } }; Ok(storage) } @@ -2626,6 +2745,7 @@ impl BackendDevice for CpuDevice { DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]), DType::F32 => CpuStorage::F32(vec![1f32; elem_count]), DType::F64 => CpuStorage::F64(vec![1f64; elem_count]), }; @@ -2640,6 +2760,7 @@ impl BackendDevice for CpuDevice { DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), }; diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 3e0c69b4f7..f61005a9b0 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -15,6 +15,7 @@ pub trait Map1 { C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), } } } @@ -31,6 +32,7 @@ pub trait Map1Any { C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), } } } @@ -48,6 +50,85 @@ pub trait Map2 { (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map3 { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + v3: &mut [T], + l3: &Layout, + s: Option, + ) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + v1: &C, + l1: &Layout, + v2: &C, + l2: &Layout, + v3: &mut C, + l3: &Layout, + s: Option, + ) -> Result<()> { + let v3d = v3.dtype(); + match (v1, v2, v3) { + (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, 
l2, v3, l3, s)?), + _ => Err(Error::DTypeMismatchBinaryOp3 { + lhs: v1.dtype(), + rhs: v2.dtype(), + c: v3d, + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2Alpha { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + s: Option, + ) -> Result>; + + #[allow(clippy::too_many_arguments)] + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -71,6 +152,7 @@ pub trait Map2U8 { (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index d3bd29030e..cc597cc702 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -3,6 +3,7 @@ use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use float8::F8E4M3; use half::{bf16, f16}; use std::sync::{Arc, Mutex}; @@ -136,6 +137,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f8_e4m3", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -243,6 +252,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -256,7 +269,8 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 + | DType::F8E4M3 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", @@ -300,13 +314,17 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? 
- } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand @@ -362,6 +380,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -399,6 +421,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorageRef::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -436,6 +462,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -473,6 +503,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 2cd97c182e..8b90fc3e35 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -9,6 +9,7 @@ use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, }; +use float8::F8E4M3; use half::{bf16, f16}; #[cfg(feature = "cudnn")] @@ -54,6 +55,7 @@ pub enum CudaStorageSlice { F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), + F8E4M3(CudaSlice), } struct Clone; @@ -1033,6 +1035,7 @@ cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); cuda_dtype!(f32, F32); cuda_dtype!(f64, F64); +cuda_dtype!(F8E4M3, F8E4M3); impl CudaStorage { pub fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { @@ -1155,6 +1158,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, + CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, } } @@ -1181,6 +1185,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F8E4M3(inp) => *inp.slice(start_o..).device_ptr(), }; let inp = &inp; @@ -1229,6 +1234,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(out) } + DType::F8E4M3 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(out) + } }; Ok(Self { slice, @@ -1320,6 +1331,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F64(cpu_storage)) } + CudaStorageSlice::F8E4M3(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F8E4M3(cpu_storage)) + } } } @@ -1772,6 +1788,11 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), 
"copy2d_f64", ), + (S::F8E4M3(s), S::F8E4M3(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f8_e4m3", + ), _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, }; let func = dev.get_or_load_func(kname, kernels::FILL)?; @@ -1829,6 +1850,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::F8E4M3(src), CudaStorageSlice::F8E4M3(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_f8_e4m3", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { @@ -2084,3 +2117,163 @@ unsafe fn gemm_strided_batched_bf16( sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } + +pub struct KVConcat { + pub concat_dim: usize, +} +impl crate::CustomOp2 for KVConcat { + fn name(&self) -> &'static str { + "kvconcat" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + crate::bail!("no cpu support for kvconcat") + } + + fn cuda_fwd( + &self, + ltensor: &CudaStorage, + ltensor_l: &Layout, + rtensor: &CudaStorage, + rtensor_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len + let dev = <ensor.device; + let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count(); + let dims_l = ltensor_l.shape().dims(); + let dims_r = rtensor_l.shape().dims(); + let dim_size = dims_l.len(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + + let chunk_l = if dim_size > 3 { + dims_l[0] * dims_l[1] + } else { + dims_l[0] + }; + let chunk_r = if dim_size > 3 { + dims_r[0] * dims_r[1] + } else { + dims_r[0] + }; + let lstride = if dim_size > 3 { + dims_l[2] * dims_l[3] + } else { + dims_l[1] * dims_l[2] + }; + let rstride = if dim_size > 3 { + dims_r[2] * dims_r[3] + } else { + dims_r[1] * dims_r[2] + }; + + let slice = match (<ensor.slice, &rtensor.slice) { + (CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? 
}; + let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(out) + } + (CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?, + }; + + let mut lshape: Vec = ltensor_l.shape().dims().to_vec(); + if self.concat_dim == 0 { + lshape[0] += rtensor_l.shape().dims()[0]; + } else { + if dim_size > 3 { + lshape[2] += rtensor_l.shape().dims()[2]; + } else { + lshape[1] += rtensor_l.shape().dims()[1]; + } + } + + let device = dev.clone(); + Ok(( + CudaStorage { + slice: slice, + device, + }, + lshape.into(), + )) + } +} diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index c1210727ad..e6bb92fe13 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -24,6 +24,7 @@ pub trait Map1 { S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), + S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), }; Ok(out) } @@ -48,6 +49,7 @@ pub trait Map2 { (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) @@ -86,6 +88,9 @@ pub trait Map3 { (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { + S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) 
+ } _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, }; Ok(out) @@ -118,6 +123,7 @@ pub trait Map2InPlace { (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_s, src, src_l, d), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } @@ -141,6 +147,7 @@ pub trait Map1Any { S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, + S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, }; Ok(out) } @@ -165,6 +172,7 @@ pub trait Map2Any { (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, }; Ok(out) diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 76d39010a9..cdc930615d 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -3,6 +3,7 @@ //! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py). //! use crate::{DType, Result, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; impl Tensor { @@ -61,6 +62,7 @@ impl std::fmt::Debug for Tensor { DType::F16 => self.fmt_dt::(f), DType::F32 => self.fmt_dt::(f), DType::F64 => self.fmt_dt::(f), + DType::F8E4M3 => self.fmt_dt::(f), } } } @@ -498,6 +500,13 @@ impl std::fmt::Display for Tensor { writeln!(f)?; } } + DType::F8E4M3 => { + if let Ok(tf) = FloatFormatter::::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } }; let device_str = match self.device().location() { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index de6cddc3a3..94dc8c1062 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,11 +1,14 @@ //! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; +use crate::cpu::kernels::VecOps; use crate::{CpuStorage, CpuStorageRef, Error, Result}; /// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + // Floating-point 8 bits integer (4-bit exponent, 3-bit mantissa). + F8E4M3, // Unsigned 8 bits integer. U8, // Unsigned 32 bits integer. 
@@ -44,6 +47,7 @@ impl std::str::FromStr for DType { "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), + "f8_e4m3" => Ok(Self::F8E4M3), _ => Err(DTypeParseError(s.to_string())), } } @@ -60,6 +64,7 @@ impl DType { Self::F16 => "f16", Self::F32 => "f32", Self::F64 => "f64", + Self::F8E4M3 => "f8_e4m3", } } @@ -67,6 +72,7 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, + Self::F8E4M3 => 1, Self::U32 => 4, Self::I64 => 8, Self::BF16 => 2, @@ -79,14 +85,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { Self::U8 | Self::U32 | Self::I64 => true, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => false, } } pub fn is_float(&self) -> bool { match self { Self::U8 | Self::U32 | Self::I64 => false, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => true, } } } @@ -165,6 +171,7 @@ macro_rules! with_dtype { } }; } +use float8::F8E4M3; use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); @@ -174,6 +181,17 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); +with_dtype!(F8E4M3, F8E4M3, |v: f64| F8E4M3::from_f64(v), |v: F8E4M3| v + .to_f64()); + +impl VecOps for F8E4M3 { + fn max(self, rhs: Self) -> Self { + F8E4M3::max(self, rhs) + } + fn min(self, rhs: Self) -> Self { + F8E4M3::min(self, rhs) + } +} pub trait IntDType: WithDType { fn is_true(&self) -> bool; @@ -213,3 +231,4 @@ impl FloatDType for f16 {} impl FloatDType for bf16 {} impl FloatDType for f32 {} impl FloatDType for f64 {} +impl FloatDType for F8E4M3 {} diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 70a512bc8e..27592ad9a4 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -103,6 +103,7 @@ impl BackendStorage for MetalStorage { DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), + DType::F8E4M3 => Ok(CpuStorage::F64(self.to_cpu()?)), } } @@ -1913,6 +1914,7 @@ impl BackendDevice for MetalDevice { DType::F16 => "fill_f16", DType::BF16 => "fill_bf16", DType::F32 => "fill_f32", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), DType::F64 => { let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; return self.storage_from_cpu_storage(&cpu_storage); @@ -1948,6 +1950,7 @@ impl BackendDevice for MetalDevice { CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), }; Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE)) } @@ -1961,6 +1964,7 @@ impl BackendDevice for MetalDevice { CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), }; 
Ok(Self::Storage::new( buffer?, diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 83e4f6527f..c2d06d4d19 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -27,11 +27,13 @@ //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::Path; +use std::slice; const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY"; const NPY_SUFFIX: &str = ".npy"; @@ -88,6 +90,7 @@ impl Header { DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", + DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?, }; if !shape.is_empty() { shape.push(',') @@ -239,6 +242,13 @@ impl Tensor { reader.read_i64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::F8E4M3 => { + let mut data_t = vec![F8E4M3::ZERO; elem_count]; + let ptr = data_t.as_mut_ptr().cast::(); + let len = data_t.len(); + reader.read_i8_into(unsafe { slice::from_raw_parts_mut(ptr, len) })?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index c5fc3fc475..2170341a69 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -2,6 +2,7 @@ //! #![allow(clippy::redundant_closure_call)] use crate::Tensor; +use float8::F8E4M3; use half::{bf16, f16}; use num_traits::float::Float; @@ -189,6 +190,7 @@ pub trait UnaryOpT { fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; + fn f8e4m3(v1: F8E4M3) -> F8E4M3; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; fn i64(v1: i64) -> i64; @@ -199,6 +201,8 @@ pub trait UnaryOpT { fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3]) {} const F32_VEC: bool = false; fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; @@ -213,6 +217,7 @@ pub trait BinaryOpT { fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; fn i64(v1: i64, v2: i64) -> i64; @@ -225,6 +230,8 @@ pub trait BinaryOpT { fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs1: &[F8E4M3], __xs2: &[F8E4M3], _ys: &mut [F8E4M3]) {} const U8_VEC: bool = false; fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} const U32_VEC: bool = false; @@ -282,6 +289,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3 { + $e(v1, v2) + } + #[inline(always)] fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } @@ -362,6 +373,10 @@ macro_rules! unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn f32($a: f32) -> f32 { $e } @@ -406,6 +421,10 @@ macro_rules! 
unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } @@ -497,6 +516,17 @@ impl UnaryOpT for Gelu { )) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f32(0.5) + * v + * (F8E4M3::ONE + + F8E4M3::tanh( + F8E4M3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (F8E4M3::ONE + F8E4M3::from_f32(0.044715) * v * v), + )) + } + #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } @@ -570,6 +600,10 @@ impl UnaryOpT for Erf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -604,6 +638,10 @@ impl UnaryOpT for Silu { v / (f16::ONE + (-v).exp()) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v / (F8E4M3::ONE + (-v).exp()) + } + #[inline(always)] fn f32(v: f32) -> f32 { v / (1.0 + (-v).exp()) } @@ -675,6 +713,10 @@ impl UnaryOpT for Abs { v.abs() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.abs() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.abs() } @@ -709,6 +751,10 @@ impl UnaryOpT for Ceil { v.ceil() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.ceil() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.ceil() } @@ -743,6 +789,10 @@ impl UnaryOpT for Floor { v.floor() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.floor() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.floor() } @@ -777,6 +827,10 @@ impl UnaryOpT for Round { v.round() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.round() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.round() } @@ -811,6 +865,10 @@ impl UnaryOpT for GeluErf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -845,6 +903,10 @@ impl UnaryOpT for Relu { v.max(f16::ZERO) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.max(F8E4M3::ZERO) + } + #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } @@ -943,6 +1005,11 @@ impl UnaryOpT for Sign { f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from((v > F8E4M3::ZERO) as i8 as f32) + - F8E4M3::from((v < F8E4M3::ZERO) as i8 as f32) + } + #[inline(always)] fn f32(v: f32) -> f32 { f32::from(v > 0.) - f32::from(v < 0.) } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index d402d6b8e0..67ca079155 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -10,6 +10,7 @@ //! `Tensor::save_safetensors` method. //! 
use crate::{DType, Device, Error, Result, Tensor, WithDType}; +use float8::F8E4M3; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; use std::borrow::Cow; @@ -26,6 +27,7 @@ impl From for st::Dtype { DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, + DType::F8E4M3 => st::Dtype::F8_E4M3, } } } @@ -41,6 +43,7 @@ impl TryFrom for DType { st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), + st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -203,6 +206,7 @@ impl Tensor { DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), } } } @@ -239,6 +243,7 @@ fn convert_back(tensor: &Tensor) -> Result> { DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 0ebb18357d..9f741da84f 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -52,6 +52,32 @@ impl ArgSort { } } +impl crate::CustomOp1 for ArgSort { + fn name(&self) -> &'static str { + "argsort" + } + + fn cpu_fwd( + &self, + storage: &crate::CpuStorage, + layout: &crate::Layout, + ) -> Result<(crate::CpuStorage, crate::Shape)> { + let sort_indexes = match storage { + crate::CpuStorage::U8(vs) => self.asort(vs, layout), + crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), + crate::CpuStorage::I64(vs) => self.asort(vs, layout), + crate::CpuStorage::BF16(vs) => self.asort(vs, layout), + crate::CpuStorage::F16(vs) => self.asort(vs, layout), + crate::CpuStorage::F32(vs) => self.asort(vs, layout), + crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), + }; + let sort_indexes = crate::CpuStorage::U32(sort_indexes); + Ok((sort_indexes, layout.shape().into())) + } + #[cfg(feature = "cuda")] mod cuda { use super::*; @@ -154,6 +180,7 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_asc_u8", DType::U32 => "asort_asc_u32", DType::I64 => "asort_asc_i64", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), } } else { match storage.dtype() { @@ -164,6 +191,7 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_desc_u8", DType::U32 => "asort_desc_u32", DType::I64 => "asort_desc_i64", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), } } }; diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 540d0819f5..ef75dffd36 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -1,7 +1,7 @@ #include "cuda_utils.cuh" #include -#define AFFINE_OP(TYPENAME, FN_NAME) \ +#define AFFINE_OP(TYPENAME, FN_NAME, AFFINE) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ const size_t num_dims, \ @@ -16,28 +16,34 @@ extern "C" __global__ void FN_NAME( \ if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? 
inp[i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ else { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ TYPENAME x = inp ? inp[strided_i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ } \ #if __CUDA_ARCH__ >= 800 -AFFINE_OP(__nv_bfloat16, affine_bf16) +AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add))) #endif #if __CUDA_ARCH__ >= 530 -AFFINE_OP(__half, affine_f16) +AFFINE_OP(__half, affine_f16, x * mul + add) #endif -AFFINE_OP(float, affine_f32) -AFFINE_OP(double, affine_f64) -AFFINE_OP(uint8_t, affine_u8) -AFFINE_OP(uint32_t, affine_u32) -AFFINE_OP(int64_t, affine_i64) +AFFINE_OP(float, affine_f32, x * mul + add) +AFFINE_OP(double, affine_f64, x * mul + add) +AFFINE_OP(uint8_t, affine_u8, x * mul + add) +AFFINE_OP(uint32_t, affine_u32, x * mul + add) +AFFINE_OP(int16_t, affine_i16, x * mul + add) +AFFINE_OP(int32_t, affine_i32, x * mul + add) +AFFINE_OP(int64_t, affine_i64, x * mul + add) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index d44e3b20ee..971a2c433c 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -14,6 +14,21 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +BINARY_OP(__nv_fp8_e4m3, badd_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) + F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bdiv_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) / F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmul_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bsub_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) - F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmaximum_f8_e4m3, maxg(x, y)) +BINARY_OP(__nv_fp8_e4m3, bminimum_f8_e4m3, ming(x, y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, eq_f8_e4m3, F8E4M3_TO_FLOAT(x) == F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ne_f8_e4m3, F8E4M3_TO_FLOAT(x) != F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y)) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 90f5e7ba48..1b38f58e1c 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -24,6 +24,53 @@ __device__ void cast_( } } +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void cast_fp8_( + const size_t numel, + const size_t num_dims, + const size_t *info, + const __nv_fp8_e4m3 *inp, + T *out +) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += 
blockDim.x * gridDim.x) { + out[i] = F8E4M3_TO_FLOAT(inp[i]); + } + } + else { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); + out[i] = F8E4M3_TO_FLOAT(inp[strided_i]); + } + } +} +template +__device__ void cast_fp8_into_( + const size_t numel, + const size_t num_dims, + const size_t *info, + const S *inp, + __nv_fp8_e4m3 *out +) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + out[i] = __nv_fp8_e4m3((float)inp[i]); + } + } + else { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); + out[i] = __nv_fp8_e4m3((float)inp[strided_i]); + } + } +} + template __device__ void cast_through( const size_t numel, @@ -59,6 +106,30 @@ extern "C" __global__ void FN_NAME( \ cast_(numel, num_dims, info, inp, out); \ } \ + +#define CAST_OP_FP8(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const SRC_TYPENAME *inp, \ + DST_TYPENAME *out \ +) { \ + cast_fp8_(numel, num_dims, info, inp, out); \ +} \ + + +#define CAST_OP_FP8_INTO(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const SRC_TYPENAME *inp, \ + DST_TYPENAME *out \ +) { \ + cast_fp8_into_(numel, num_dims, info, inp, out); \ +} \ + #define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ @@ -72,6 +143,7 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16) +CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3) CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32) CAST_OP(__nv_bfloat16, float, cast_bf16_f32) @@ -83,6 +155,19 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) + +CAST_OP_FP8(__nv_fp8_e4m3, float, cast_f8_e4m3_f32) +CAST_OP_FP8_INTO(float, __nv_fp8_e4m3, cast_f32_f8_e4m3) +CAST_OP_FP8(__nv_fp8_e4m3, uint8_t, cast_f8_e4m3_u8) +CAST_OP_FP8(__nv_fp8_e4m3, __half, cast_f8_e4m3_f16) +CAST_OP_FP8(__nv_fp8_e4m3, double, cast_f8_e4m3_f64) +CAST_OP_FP8_INTO(__half, __nv_fp8_e4m3, cast_f16_f8_e4m3) +CAST_OP_FP8_INTO(double, __nv_fp8_e4m3, cast_f64_f8_e4m3) +CAST_OP_FP8_INTO(uint8_t, __nv_fp8_e4m3, cast_u8_f8_e4m3) +CAST_OP_FP8_INTO(int32_t, __nv_fp8_e4m3, cast_i32_f8_e4m3) +CAST_OP_FP8(__nv_fp8_e4m3, int32_t, cast_f8_e4m3_i32) +CAST_OP_FP8(__nv_fp8_e4m3, __nv_bfloat16, cast_f8_e4m3_bf16) +CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3) #else #include #if CUDA_VERSION >= 11000 @@ -94,6 +179,7 @@ CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) #endif #endif diff --git 
a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh index d0791749bb..1e4cf215c1 100644 --- a/candle-kernels/src/compatibility.cuh +++ b/candle-kernels/src/compatibility.cuh @@ -1,5 +1,6 @@ #include "cuda_fp16.h" #include "cuda_bf16.h" +#include "cuda_fp8.h" // Table showing which features are supported on which compute capability // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index fa834faa3a..6ca6fd7c2b 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -702,6 +702,18 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) IM2COL_OP(__nv_bfloat16, im2col_bf16) IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16) + +// NOTE: No conv ops for f8 +// CONV1D_OP(__nv_bfloat16, float, conv1d_f8_e5m) +// CONV2D_OP(__nv_fp8_e4m3, float, conv2d_f8_e5m) +// CONVT1D_OP(__nv_fp8_e4m3, float, conv_transpose1d_f8_e5m) +// CONVT2D_OP(__nv_fp8_e4m3, float, conv_transpose2d_f8_e5m) +// AVG_POOL2D_OP(__nv_fp8_e4m3, float, avg_pool2d_f8_e5m) +// MAX_POOL2D_OP(__nv_fp8_e4m3, max_pool2d_f8_e5m) +// UPSAMPLE_NEAREST2D_OP(__nv_fp8_e4m3, upsample_nearest2d_f8_e5m) +// IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m) +// IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m) +// COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 2673b8aaf1..eb1400b4da 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -198,4 +198,27 @@ __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); } __device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); } __device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); } + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +__device__ __forceinline__ __nv_fp8_e4m3 powg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(powf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ bool isnang(__nv_fp8_e4m3 a) { return isnan(F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 sqrtg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sqrtf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 cosg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(cosf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 sing(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sinf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 recipg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(1. 
/ F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 maxg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fmaxf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 tanhg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(tanhf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 erfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(erff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ceilg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(ceilf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 floorg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(floorf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 roundg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(roundf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 normcdfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(normcdff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ming(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fminf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 logg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(logf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(expf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } + + #endif diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index ca448d989f..1b72a901f2 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -43,6 +43,11 @@ COPY2D_OP(__half, copy2d_f16) #if __CUDA_ARCH__ >= 800 #include +#include + extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } COPY2D_OP(__nv_bfloat16, copy2d_bf16) + +extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3) #endif diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8af2954d13..32cc4e9ad1 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -99,6 +99,57 @@ __device__ void index_add( } } +#if __CUDA_ARCH__ >= 800 +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void scatter_add_f8( + const I *ids, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (pre * src_dim_size + j) * right_size + post; + const size_t idx = ids[src_i]; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} + +template +__device__ void index_add_f8( + const I *ids, + const size_t ids_dim_size, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const 
size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const size_t idx = ids[j]; + const size_t src_i = (pre * ids_dim_size + j) * right_size + post; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} +#endif + #define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const INDEX_TYPENAME *ids, \ @@ -111,6 +162,18 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define IA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const size_t ids_dim_size, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { index_add_f8(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + template __device__ void scatter_add( const I *ids, @@ -145,6 +208,17 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define SA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + #if __CUDA_ARCH__ >= 800 IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) @@ -159,6 +233,27 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) + +IS_OP(__nv_fp8_e4m3, int16_t, is_i16_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int32_t, is_i32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int64_t, is_i64_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint32_t, is_u32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint8_t, is_u8_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int16_t, gather_i16_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int32_t, gather_i32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int64_t, gather_i64_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint32_t, gather_u32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint8_t, gather_u8_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int16_t, ia_i16_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int32_t, ia_i32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int64_t, ia_i64_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint32_t, ia_u32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint8_t, ia_u8_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int16_t, sa_i16_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 079c370873..2738d8254e 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -578,6 +578,14 @@ LAYERNORM_OP(__nv_bfloat16, layernorm_bf16) ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) SUM_OP(__nv_bfloat16, sum_bf16) 
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) + +// NOTE: No reduce ops for f8 +// SUM_OP(__nv_fp8_e4m3, sum_fp8_e4m3) +// SOFTMAX_OP(__nv_fp8_e4m3, float, softmax_fp8_e4m3) +// RMSNORM_OP(__nv_fp8_e4m3, rmsnorm_fp8_e4m3) +// LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3) +// ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3) +// FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 08f1f9fc29..a7ad4f79c4 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -75,6 +75,9 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \ #if __CUDA_ARCH__ >= 800 ASORT_OP(__nv_bfloat16, bf16) + +// NOTE: No sort ops for f8 +// ASORT_OP(__nv_fp8_e4m3, fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index aaa8a881fb..ef4009e3e0 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -36,6 +36,12 @@ extern "C" __global__ void FN_NAME( \ WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) + +WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c82a88375d..5fcb5e2b1a 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -122,6 +122,33 @@ UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x)) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +UNARY_OP(__nv_fp8_e4m3, ucopy_f8_e4m3, x) +UNARY_OP(__nv_fp8_e4m3, uneg_fp8_e4m3, __nv_fp8_e4m3(-F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, urecip_fp8_e4m3, recipg(x)) +UNARY_OP(__nv_fp8_e4m3, uexp_fp8_e4m3, expg(x)) +UNARY_OP(__nv_fp8_e4m3, ulog_fp8_e4m3, logg(x)) +UNARY_OP(__nv_fp8_e4m3, usin_fp8_e4m3, sing(x)) +UNARY_OP(__nv_fp8_e4m3, ucos_fp8_e4m3, cosg(x)) +UNARY_OP(__nv_fp8_e4m3, utanh_fp8_e4m3, tanhg(x)) +UNARY_OP(__nv_fp8_e4m3, uerf_fp8_e4m3, erfg(x)) +UNARY_OP(__nv_fp8_e4m3, uceil_fp8_e4m3, ceilg(x)) +UNARY_OP(__nv_fp8_e4m3, ufloor_fp8_e4m3, floorg(x)) +UNARY_OP(__nv_fp8_e4m3, uround_fp8_e4m3, roundg(x)) +UNARY_OP(__nv_fp8_e4m3, unormcdf_fp8_e4m3, normcdfg(x)) +UNARY_OP(__nv_fp8_e4m3, uabs_fp8_e4m3, absg(x)) +UNARY_OP(__nv_fp8_e4m3, usqr_fp8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x)*F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, usqrt_fp8_e4m3, sqrtg(x)) +UNARY_OP(__nv_fp8_e4m3, ugelu_fp8_e4m3, __nv_fp8_e4m3(gelu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, ugelu_erf_fp8_e4m3, __nv_fp8_e4m3(gelu_erf_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, urelu_fp8_e4m3, __nv_fp8_e4m3(relu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, uelu_fp8_e4m3, __nv_fp8_e4m3(elu_fwd(F8E4M3_TO_FLOAT(x), F8E4M3_TO_FLOAT(param)))) +UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param)) 
+UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x)))) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index d91619fbb3..e884381c26 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,6 +19,7 @@ candle = { workspace = true } candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py311"] } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b8695cc8a0..85bf7c4710 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::redundant_closure_call)] +use float8::F8E4M3; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; @@ -157,6 +158,7 @@ pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +pydtype!(F8E4M3, f32::from); fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result<usize> { let dim = t.dim(dim)?; @@ -204,6 +206,7 @@ trait MapDType { DType::F16 => self.f::<f16>(t), DType::F32 => self.f::<f32>(t), DType::F64 => self.f::<f64>(t), + DType::F8E4M3 => self.f::<F8E4M3>(t), } } } diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index df1de0b276..269c2a02da 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -28,6 +28,7 @@ pub mod colpali; pub mod convmixer; pub mod convnext; pub mod dac; +pub mod deepseekv3; pub mod depth_anything_v2; pub mod dinov2; pub mod dinov2reg4;
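For reviewers, a round-trip through the new dtype can be exercised from Rust once this patch is applied. The sketch below is not part of the patch: it assumes the existing public candle tensor API (Tensor::new, to_dtype, sum_all, to_scalar) is unchanged and that a CUDA device with compute capability 8.0 or newer is available, since every f8 kernel above is compiled only for __CUDA_ARCH__ >= 800. Because the reduce and sort kernels are deliberately left out for f8 (see the commented-out SUM_OP/FAST_OP and ASORT_OP entries), the reduction here is done after casting back up to f32.

// Hypothetical smoke test for DType::F8E4M3; not included in this patch.
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Requires an sm_80+ GPU: the f8 kernels are gated behind __CUDA_ARCH__ >= 800.
    let device = Device::new_cuda(0)?;
    // Build an f32 tensor, narrow it to f8 e4m3, then widen it back to f32.
    let xs = Tensor::new(&[0.5f32, 1.0, 2.0, 3.0], &device)?;
    let xs_f8 = xs.to_dtype(DType::F8E4M3)?;
    let ys = xs_f8.to_dtype(DType::F32)?;
    // No reduce kernels exist for f8, so the reduction runs in f32.
    let sum = ys.sum_all()?.to_scalar::<f32>()?;
    println!("sum after f8 e4m3 round-trip: {sum}");
    Ok(())
}

The input values are exactly representable in e4m3, so the round-trip should reproduce them; arbitrary f32 inputs would instead be rounded to the nearest representable e4m3 value.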