Skip to content

Commit

Permalink
Add the f8 e4m3 dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Jan 27, 2025
1 parent 25bc793 commit e367cd7
Show file tree
Hide file tree
Showing 30 changed files with 901 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
Expand Down
3 changes: 2 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
float8 = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
memmap2 = { workspace = true }
Expand All @@ -42,7 +43,7 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Implement conversion traits for tensors
use crate::{DType, Device, Error, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

Expand Down Expand Up @@ -139,6 +140,11 @@ impl Tensor {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
DType::F8E4M3 => {
for v in vs.to_vec1::<F8E4M3>()? {
f.write_u8(v.to_bits())?
}
}
}
Ok(())
}
Expand Down
121 changes: 121 additions & 0 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;

Expand All @@ -25,6 +26,7 @@ pub enum CpuStorage {
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
F8E4M3(Vec<F8E4M3>),
}

#[derive(Debug, Clone)]
Expand All @@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> {
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
F8E4M3(&'a [F8E4M3]),
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -1623,6 +1626,17 @@ impl CpuStorage {
.concat();
Self::F64(storages)
}
Self::F8E4M3(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F8E4M3(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F8E4M3(storages)
}
};
Ok(s)
}
Expand All @@ -1640,6 +1654,7 @@ impl BackendStorage for CpuStorage {
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
Self::F8E4M3(_) => DType::F8E4M3,
}
}

Expand Down Expand Up @@ -1674,6 +1689,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, bf16::from_f64);
Ok(Self::BF16(data))
}
(Self::F8E4M3(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
Ok(Self::BF16(data))
}
(Self::U8(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
Expand Down Expand Up @@ -1702,6 +1721,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, f16::from_f64);
Ok(Self::F16(data))
}
(Self::F8E4M3(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
}
(Self::U8(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
Expand Down Expand Up @@ -1730,6 +1753,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::F8E4M3(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::U8(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U8(data))
Expand Down Expand Up @@ -1758,6 +1785,14 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::F8E4M3(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
Ok(Self::U8(data))
}
(Self::F8E4M3(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
Expand Down Expand Up @@ -1786,6 +1821,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F8E4M3(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -1814,6 +1853,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::F8E4M3(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
Expand Down Expand Up @@ -1842,6 +1885,42 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data))
}
(Self::F8E4M3(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::U8(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::U32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::I64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::BF16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f32);
Ok(Self::F8E4M3(data))
}
(Self::F64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f64);
Ok(Self::F8E4M3(data))
}
(Self::F8E4M3(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F8E4M3(data))
}
}
}

Expand Down Expand Up @@ -1955,6 +2034,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Expand All @@ -1980,6 +2063,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| elu(v, alpha));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Expand Down Expand Up @@ -2024,6 +2111,15 @@ impl BackendStorage for CpuStorage {
Ok(Self::F64(data))
}
}
Self::F8E4M3(storage) => {
if B::F8E4M3_VEC {
let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec);
Ok(Self::F8E4M3(data))
} else {
let data = unary_map(storage, layout, B::f8e4m3);
Ok(Self::F8E4M3(data))
}
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
Ok(Self::U8(data))
Expand Down Expand Up @@ -2505,6 +2601,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max));
for _i in 0..elem_count {
data.push(rng.sample::<F8E4M3, _>(uniform))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
Expand Down Expand Up @@ -2551,6 +2656,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let normal =
Expand Down Expand Up @@ -2614,6 +2728,11 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::F64(v)
}
DType::F8E4M3 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F8E4M3(v)
}
};
Ok(storage)
}
Expand All @@ -2626,6 +2745,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
};
Expand All @@ -2640,6 +2760,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
};
Expand Down
Loading

0 comments on commit e367cd7

Please sign in to comment.