Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/jafioti/luminal
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed May 2, 2024
2 parents b7e40c1 + dc01280 commit bb97aab
Show file tree
Hide file tree
Showing 39 changed files with 2,001 additions and 380 deletions.
7 changes: 1 addition & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ description = "Deep learning at the speed of light."
license = "MIT OR Apache-2.0"

[dependencies]
luminal_symbolic = {path="./crates/luminal_symbolic"}
itertools = "0.11.0"
num-traits = "0.2.16"
petgraph = "0.6.4"
Expand All @@ -32,9 +31,5 @@ members = [
"crates/luminal_cpu",
"crates/luminal_nn",
"crates/luminal_training",
"crates/luminal_symbolic",
]
exclude = [
"crates/luminal_metal",
"crates/luminal_cuda",
]
exclude = ["crates/luminal_metal", "crates/luminal_cuda"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# [luminal](https://luminalai.com)
![image](https://github.com/jafioti/luminal/blob/main/docs/dag.jpeg)
![image](https://github.com/jafioti/luminal/blob/main/docs/images/dag.jpeg)
[![Website](https://img.shields.io/badge/Docs-Website-blue?style=for-the-badge&color=0D9373)](https://luminalai.com)
[![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/Sidekick-AI/dataflow/actions)
[![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal)
Expand Down
13 changes: 6 additions & 7 deletions crates/luminal_metal/src/quantized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,13 +613,12 @@ mod tests {
})
.collect::<Vec<_>>();
let dev = Device::system_default().unwrap();
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));

cx.compile(
MetalQuantizedCompiler::<f32>::new(vec![weights.id]),
&mut out,
);
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));
cx.execute();

let mut cx1 = Graph::new();
Expand Down Expand Up @@ -659,13 +658,13 @@ mod tests {
})
.collect::<Vec<_>>();
let dev = Device::system_default().unwrap();
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));

cx.compile(
MetalQuantizedCompiler::<f32>::new(vec![weights.id]),
&mut out,
);
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));
cx.execute();

let cpu = dfdx::tensor::Cpu::default();
Expand Down Expand Up @@ -706,13 +705,13 @@ mod tests {
})
.collect::<Vec<_>>();
let dev = Device::system_default().unwrap();
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));

cx.compile(
MetalQuantizedCompiler::<f16>::new(vec![weights.id]),
&mut out,
);
cx.tensors
.insert((weights.id, 0), quantized_buffer(&blocks, &dev));
cx.execute();

let cpu = dfdx::tensor::Cpu::default();
Expand Down
87 changes: 86 additions & 1 deletion crates/luminal_metal/src/tests/fp16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ fn test_pad_contig() {
.set_dyn(a_data, &[m, k])
.retrieve();
let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a
.pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')])
.pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')])
.contiguous()
.retrieve();
let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
Expand Down Expand Up @@ -791,3 +791,88 @@ fn test_movement() {

assert_exact(&c.data(), &d_c.as_vec());
}

#[test]
fn test_conv2d() {
let mut cx = Graph::new();

const CH_IN: usize = 5;
const CH_OUT: usize = 2;
const KERNELX: usize = 2;
const KERNELY: usize = 2;
const STRIDEX: usize = KERNELX;
const STRIDEY: usize = KERNELY;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DIMX_IN: usize = 16;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_IN: usize = 9;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;

let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
8., 5., 9., 0., 9., 5., 6., 8., 9., 5., 4., 1., 9., 7., 2., 2., 7., 9., 3., 1., 2., 8., 4.,
0., 8., 0., 5., 6., 7., 7., 4., 3., 4., 6., 8., 3., 7., 8., 8., 7., 1., 5., 1., 8., 0., 1.,
1., 7., 3., 2., 1., 0., 4., 5., 4., 3., 2., 5., 4., 2., 4., 1., 9., 4., 1., 9., 7., 7., 1.,
2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7., 9., 0., 9., 0., 1., 4., 2., 4., 9., 6., 8.,
6., 1., 6., 3., 8., 3., 4., 5., 0., 2., 1., 8., 2., 2., 8., 7., 0., 7., 7., 3., 4., 5., 0.,
7., 2., 1., 1., 4., 2., 9., 9., 6., 1., 5., 4., 6., 9., 5., 4., 1., 9., 1., 5., 5., 5., 8.,
8., 0., 1., 3., 0., 8., 8., 5., 1., 6., 1., 5., 6., 4., 4., 4., 0., 1., 1., 5., 1., 7., 2.,
3., 5., 5., 4., 9., 1., 3., 7., 6., 7., 1., 5., 3., 8., 6., 6., 6., 7., 3., 2., 2., 8., 1.,
3., 0., 2., 7., 6., 5., 7., 5., 7., 8., 1., 2., 2., 5., 0., 2., 9., 1., 5., 3., 8., 7., 9.,
7., 2., 8., 8., 8., 6., 3., 2., 7., 7., 0., 3., 7., 8., 3., 7., 2., 3., 2., 7., 5., 5., 6.,
0., 9., 0., 9., 9., 1., 8., 7., 9., 6., 8., 7., 5., 4., 9., 5., 6., 3., 2., 8., 3., 0., 6.,
3., 8., 3., 1., 8., 7., 2., 0., 7., 7., 7., 7., 8., 0., 4., 9., 8., 2., 0., 4., 4., 3., 5.,
5., 3., 0., 3., 6., 3., 1., 2., 9., 9., 6., 8., 1., 2., 6., 8., 6., 0., 0., 2., 8., 8., 5.,
0., 5., 9., 0., 8., 1., 1., 3., 5., 9., 3., 5., 8., 6., 3., 2., 9., 4., 8., 3., 9., 5., 2.,
9., 0., 1., 6., 8., 0., 3., 0., 1., 2., 1., 0., 1., 4., 1., 1., 0., 6., 9., 2., 7., 2., 6.,
0., 4., 8., 2., 6., 7., 2., 2., 7., 4., 5., 8., 1., 4., 7., 5., 9., 7., 2., 5., 9., 1., 6.,
1., 7., 9., 5., 6., 9., 3., 5., 1., 6., 1., 3., 3., 9., 3., 9., 0., 1., 8., 1., 9., 8., 5.,
3., 4., 4., 1., 5., 5., 4., 4., 5., 8., 7., 1., 1., 7., 3., 9., 0., 1., 3., 4., 8., 4., 0.,
5., 6., 2., 0., 7., 8., 2., 6., 2., 9., 6., 2., 0., 3., 7., 5., 7., 1., 8., 5., 5., 9., 1.,
0., 3., 5., 7., 5., 3., 2., 8., 6., 3., 0., 5., 8., 5., 7., 8., 8., 2., 9., 0., 1., 8., 6.,
0., 3., 2., 5., 2., 9., 8., 9., 6., 2., 0., 3., 2., 5., 9., 1., 3., 6., 5., 2., 8., 2., 2.,
1., 8., 6., 4., 1., 6., 0., 7., 3., 0., 9., 6., 5., 5., 5., 2., 4., 2., 8., 3., 0., 6., 3.,
8., 8., 4., 9., 4., 7., 0., 3., 5., 1., 4., 6., 0., 0., 5., 9., 7., 8., 6., 7., 0., 6., 7.,
0., 5., 8., 8., 6., 4., 6., 0., 2., 3., 2., 8., 7., 5., 9., 6., 6., 2., 0., 4., 4., 4., 4.,
2., 7., 5., 3., 2., 6., 3., 7., 0., 7., 2., 5., 1., 4., 4., 5., 1., 6., 7., 5., 7., 0., 7.,
8., 4., 7., 3., 9., 1., 7., 5., 6., 1., 0., 2., 0., 0., 5., 5., 8., 8., 7., 3., 7., 2., 9.,
3., 8., 4., 5., 3., 8., 5., 2., 0., 2., 0., 5., 9., 0., 3., 8., 0., 4., 1., 8., 4., 8., 9.,
1., 1., 4., 5., 0., 2., 0., 9., 4., 2., 3., 9., 0., 7., 3., 1., 5., 9., 1., 6., 5., 4., 2.,
1., 2., 1., 1., 4., 7., 2.,
]);

let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
0.0700, -0.0800, 0.1700, 0.1000, -0.0700, 0.1600, -0.1600, -0.1900, -0.0500, -0.2100,
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
]);

let mut out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
.retrieve();

cx.compile(
<(GenericCompiler, MetalCompiler<f16>)>::default(),
&mut out1,
);
cx.execute();

assert_close_precision(
&out1.data(),
&[
3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700,
4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200,
-0.7100, -0.6500, 4.3900, 0.4000, 1.0300, 0.9800, 3.1200, 2.7400, 2.5100, 0.1200,
1.8500, 2.0000, -0.7900, 1.0700, -0.3900, -0.8100, -2.5100, -2.9700, 0.2100, 1.8400,
-0.7700, -0.3900, 1.2200, 0.1900, 4.1700, -4.3600, -1.8600, 0.4800, -2.4400, 2.6300,
1.5000, -1.9700, 1.2800, -2.8200, -2.3200, 0.2200, -0.3800, 2.1800, -0.8200, -1.5700,
1.2000, -3.4200, -1.6700, 0.9000,
],
1e-2,
);
}
84 changes: 84 additions & 0 deletions crates/luminal_metal/src/tests/fp32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,87 @@ fn test_transformer_encoder_block() {

assert_close(&b.data(), &d_b.as_vec());
}

#[test]
fn test_conv2d() {
let mut cx = Graph::new();

const CH_IN: usize = 5;
const CH_OUT: usize = 2;
const KERNELX: usize = 2;
const KERNELY: usize = 2;
const STRIDEX: usize = KERNELX;
const STRIDEY: usize = KERNELY;
const DILATIONX: usize = 0;
const DILATIONY: usize = 0;
const DIMX_IN: usize = 16;
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_IN: usize = 9;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;

let inp1 = cx.tensor::<R3<CH_IN, DIMX_IN, DIMY_IN>>().set(vec![
8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8.,
5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7.,
1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5.,
8., 5., 9., 0., 9., 5., 6., 8., 9., 5., 4., 1., 9., 7., 2., 2., 7., 9., 3., 1., 2., 8., 4.,
0., 8., 0., 5., 6., 7., 7., 4., 3., 4., 6., 8., 3., 7., 8., 8., 7., 1., 5., 1., 8., 0., 1.,
1., 7., 3., 2., 1., 0., 4., 5., 4., 3., 2., 5., 4., 2., 4., 1., 9., 4., 1., 9., 7., 7., 1.,
2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7., 9., 0., 9., 0., 1., 4., 2., 4., 9., 6., 8.,
6., 1., 6., 3., 8., 3., 4., 5., 0., 2., 1., 8., 2., 2., 8., 7., 0., 7., 7., 3., 4., 5., 0.,
7., 2., 1., 1., 4., 2., 9., 9., 6., 1., 5., 4., 6., 9., 5., 4., 1., 9., 1., 5., 5., 5., 8.,
8., 0., 1., 3., 0., 8., 8., 5., 1., 6., 1., 5., 6., 4., 4., 4., 0., 1., 1., 5., 1., 7., 2.,
3., 5., 5., 4., 9., 1., 3., 7., 6., 7., 1., 5., 3., 8., 6., 6., 6., 7., 3., 2., 2., 8., 1.,
3., 0., 2., 7., 6., 5., 7., 5., 7., 8., 1., 2., 2., 5., 0., 2., 9., 1., 5., 3., 8., 7., 9.,
7., 2., 8., 8., 8., 6., 3., 2., 7., 7., 0., 3., 7., 8., 3., 7., 2., 3., 2., 7., 5., 5., 6.,
0., 9., 0., 9., 9., 1., 8., 7., 9., 6., 8., 7., 5., 4., 9., 5., 6., 3., 2., 8., 3., 0., 6.,
3., 8., 3., 1., 8., 7., 2., 0., 7., 7., 7., 7., 8., 0., 4., 9., 8., 2., 0., 4., 4., 3., 5.,
5., 3., 0., 3., 6., 3., 1., 2., 9., 9., 6., 8., 1., 2., 6., 8., 6., 0., 0., 2., 8., 8., 5.,
0., 5., 9., 0., 8., 1., 1., 3., 5., 9., 3., 5., 8., 6., 3., 2., 9., 4., 8., 3., 9., 5., 2.,
9., 0., 1., 6., 8., 0., 3., 0., 1., 2., 1., 0., 1., 4., 1., 1., 0., 6., 9., 2., 7., 2., 6.,
0., 4., 8., 2., 6., 7., 2., 2., 7., 4., 5., 8., 1., 4., 7., 5., 9., 7., 2., 5., 9., 1., 6.,
1., 7., 9., 5., 6., 9., 3., 5., 1., 6., 1., 3., 3., 9., 3., 9., 0., 1., 8., 1., 9., 8., 5.,
3., 4., 4., 1., 5., 5., 4., 4., 5., 8., 7., 1., 1., 7., 3., 9., 0., 1., 3., 4., 8., 4., 0.,
5., 6., 2., 0., 7., 8., 2., 6., 2., 9., 6., 2., 0., 3., 7., 5., 7., 1., 8., 5., 5., 9., 1.,
0., 3., 5., 7., 5., 3., 2., 8., 6., 3., 0., 5., 8., 5., 7., 8., 8., 2., 9., 0., 1., 8., 6.,
0., 3., 2., 5., 2., 9., 8., 9., 6., 2., 0., 3., 2., 5., 9., 1., 3., 6., 5., 2., 8., 2., 2.,
1., 8., 6., 4., 1., 6., 0., 7., 3., 0., 9., 6., 5., 5., 5., 2., 4., 2., 8., 3., 0., 6., 3.,
8., 8., 4., 9., 4., 7., 0., 3., 5., 1., 4., 6., 0., 0., 5., 9., 7., 8., 6., 7., 0., 6., 7.,
0., 5., 8., 8., 6., 4., 6., 0., 2., 3., 2., 8., 7., 5., 9., 6., 6., 2., 0., 4., 4., 4., 4.,
2., 7., 5., 3., 2., 6., 3., 7., 0., 7., 2., 5., 1., 4., 4., 5., 1., 6., 7., 5., 7., 0., 7.,
8., 4., 7., 3., 9., 1., 7., 5., 6., 1., 0., 2., 0., 0., 5., 5., 8., 8., 7., 3., 7., 2., 9.,
3., 8., 4., 5., 3., 8., 5., 2., 0., 2., 0., 5., 9., 0., 3., 8., 0., 4., 1., 8., 4., 8., 9.,
1., 1., 4., 5., 0., 2., 0., 9., 4., 2., 3., 9., 0., 7., 3., 1., 5., 9., 1., 6., 5., 4., 2.,
1., 2., 1., 1., 4., 7., 2.,
]);

let model = luminal_nn::Conv2D::<CH_IN, CH_OUT, KERNELX, KERNELY>::initialize(&mut cx);
model.weight.set(vec![
0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300,
0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500,
0.0700, -0.0800, 0.1700, 0.1000, -0.0700, 0.1600, -0.1600, -0.1900, -0.0500, -0.2100,
0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400,
]);

let mut out1 = model
.forward::<DIMX_IN, DIMY_IN, DIMX_OUT, DIMY_OUT>(inp1)
.retrieve();

cx.compile(
<(GenericCompiler, MetalCompiler<f32>)>::default(),
&mut out1,
);
cx.execute();

assert_close_precision(
&out1.data(),
&[
3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700,
4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200,
-0.7100, -0.6500, 4.3900, 0.4000, 1.0300, 0.9800, 3.1200, 2.7400, 2.5100, 0.1200,
1.8500, 2.0000, -0.7900, 1.0700, -0.3900, -0.8100, -2.5100, -2.9700, 0.2100, 1.8400,
-0.7700, -0.3900, 1.2200, 0.1900, 4.1700, -4.3600, -1.8600, 0.4800, -2.4400, 2.6300,
1.5000, -1.9700, 1.2800, -2.8200, -2.3200, 0.2200, -0.3800, 2.1800, -0.8200, -1.5700,
1.2000, -3.4200, -1.6700, 0.9000,
],
);
}
Loading

0 comments on commit bb97aab

Please sign in to comment.