Cuda test fixes
jafioti committed May 28, 2024
1 parent ad2d73f commit aebbd9c
Showing 5 changed files with 41 additions and 23 deletions.
21 changes: 7 additions & 14 deletions crates/luminal_cuda/src/elementwise_fusion.rs
@@ -435,6 +435,7 @@ extern \"C\" __global__ void kernel({} {type_name}* out, const int n_elements{re
             .join("\n "),
         op.subexpressions.last().unwrap().0
     );
+    println!("{kernel}");
     op.kernel = Some(compile_and_load_kernel(kernel, &device));
     op.dyn_chars = dyn_chars;
 }
@@ -903,9 +904,9 @@ mod tests {

     pub struct TransformerBlock {
         pub attention: SelfAttention,
-        pub attention_norm: RMSNorm<HIDDEN_DIM>,
+        pub attention_norm: LayerNorm<HIDDEN_DIM>,
         pub feed_forward: Mlp<MLP_DIM, HIDDEN_DIM>,
-        pub feed_forward_norm: RMSNorm<HIDDEN_DIM>,
+        pub feed_forward_norm: LayerNorm<HIDDEN_DIM>,
     }

     impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
@@ -948,17 +949,9 @@ mod tests {
         fn initialize(cx: &mut Graph) -> Self {
             Self {
                 attention: InitModule::initialize(cx),
-                attention_norm: {
-                    let mut norm = RMSNorm::initialize(cx);
-                    norm.epsilon = 1e-5;
-                    norm
-                },
+                attention_norm: LayerNorm::init(true, false, false, 1e-5, cx),
                 feed_forward: InitModule::initialize(cx),
-                feed_forward_norm: {
-                    let mut norm = RMSNorm::initialize(cx);
-                    norm.epsilon = 1e-5;
-                    norm
-                },
+                feed_forward_norm: LayerNorm::init(true, false, false, 1e-5, cx),
             }
         }
     }
@@ -967,7 +960,7 @@
         // Transformer layers
         pub layers: Vec<TransformerBlock>,
         // Final Norm layer
-        pub norm: RMSNorm<HIDDEN_DIM>,
+        pub norm: LayerNorm<HIDDEN_DIM>,
     }

     impl<Batch: Dimension, CurSeq: Dimension, PrevSeq: Dimension, TotSeq: Dimension>
@@ -1007,7 +1000,7 @@
     impl InitModule for MistralLM {
         fn initialize(cx: &mut Graph) -> Self {
             Self {
-                norm: RMSNorm::initialize(cx),
+                norm: LayerNorm::init(true, false, false, 1e-5, cx),
                 layers: (0..NUM_LAYERS)
                     .map(|_| InitModule::initialize(cx))
                     .collect(),
4 changes: 2 additions & 2 deletions crates/luminal_cuda/src/tests/fp16.rs
@@ -553,8 +553,8 @@ fn test_rms_norm() {
     let mut cx = Graph::new();
     let a = cx.tensor::<R2<15, 32>>().set(inp_data.clone());

-    let model = luminal_nn::RMSNorm::<32>::initialize(&mut cx);
-    model.weight.set(weight_data.clone());
+    let model = luminal_nn::LayerNorm::<32>::new(true, false, false, 1e-5, &mut cx);
+    model.weight.unwrap().set(weight_data.clone());
     let mut b = model.forward(a).retrieve();

     cx.compile(<(GenericCompiler, CudaCompiler<f16>)>::default(), &mut b);
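Note on this test change: test_rms_norm keeps its name but now builds a LayerNorm with a weight, no bias, and mean-centering disabled, which presumably reduces to RMS normalization, x * w / sqrt(mean(x^2) + eps). A minimal reference sketch of that computation in plain Rust; rms_norm_row is a hypothetical helper for checking one row against the graph output, not part of luminal:

fn rms_norm_row(x: &[f32], w: &[f32], eps: f32) -> Vec<f32> {
    // Root-mean-square of the row, without subtracting the mean.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Scale each element by 1/rms and the learned weight.
    x.iter().zip(w).map(|(v, wi)| v * inv_rms * wi).collect()
}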
4 changes: 2 additions & 2 deletions crates/luminal_cuda/src/tests/fp32.rs
@@ -510,8 +510,8 @@ fn test_rms_norm() {
     let mut cx = Graph::new();
     let a = cx.tensor::<R2<15, 32>>().set(inp_data.clone());

-    let model = luminal_nn::RMSNorm::<32>::initialize(&mut cx);
-    model.weight.set(weight_data.clone());
+    let model = luminal_nn::LayerNorm::<32>::new(true, false, false, 1e-5, &mut cx);
+    model.weight.unwrap().set(weight_data.clone());
     let mut b = model.forward(a).retrieve();

     cx.compile(<(GenericCompiler, CudaCompiler<f32>)>::default(), &mut b);
27 changes: 26 additions & 1 deletion crates/luminal_nn/src/norm.rs
@@ -1,4 +1,5 @@
-use luminal::prelude::*;
+use luminal::{prelude::*, tests::random_vec_rng};
+use rand::thread_rng;

 /// A simple layer norm with an optional weight and bias
 #[derive(Default)]
@@ -26,6 +27,30 @@ impl<const DIM: usize> LayerNorm<DIM> {
             epsilon,
         }
     }
+    pub fn init(weight: bool, bias: bool, mean_norm: bool, epsilon: f32, cx: &mut Graph) -> Self {
+        // Init weight as uniform(-1, 1)
+        let mut rng = thread_rng();
+        Self {
+            weight: if weight {
+                Some(
+                    cx.named_tensor("LayerNorm Weight")
+                        .set(random_vec_rng(DIM, &mut rng)),
+                )
+            } else {
+                None
+            },
+            bias: if bias {
+                Some(
+                    cx.named_tensor("LayerNorm Bias")
+                        .set(random_vec_rng(DIM, &mut rng)),
+                )
+            } else {
+                None
+            },
+            mean_norm,
+            epsilon,
+        }
+    }
 }

 impl<const DIM: usize, S: Shape> Module<GraphTensor<S>> for LayerNorm<DIM>
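For orientation, a minimal usage sketch of the init constructor added above, mirroring how the Mistral test module and the CUDA tests in this commit call it. The flags are weight, bias, and mean_norm in that order; rms_style_norm is a hypothetical helper name, and the imports assume luminal_nn re-exports LayerNorm at the crate root, as the tests do:

use luminal::prelude::*;
use luminal_nn::LayerNorm;

// Weight-only layer norm with no bias and no mean-centering,
// the configuration used in place of the old RMSNorm.
fn rms_style_norm(cx: &mut Graph) -> LayerNorm<32> {
    LayerNorm::init(true, false, false, 1e-5, cx)
}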
8 changes: 4 additions & 4 deletions examples/whisper/src/main.rs
@@ -65,7 +65,7 @@ fn main() {
            #[cfg(feature = "metal")]
            luminal_metal::MetalCompiler::<f16>::default(),
            #[cfg(feature = "cuda")]
-            luminal_cuda::CudaCompiler::<f16>::default(),
+            luminal_cuda::CudaCompiler::<f32>::default(),
            #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
            luminal_cpu::CPUCompiler::default(),
        ),
@@ -77,7 +77,7 @@
            #[cfg(feature = "metal")]
            luminal_metal::MetalCompiler::<f16>::default(),
            #[cfg(feature = "cuda")]
-            luminal_cuda::CudaCompiler::<f16>::default(),
+            luminal_cuda::CudaCompiler::<f32>::default(),
            #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
            luminal_cpu::CPUCompiler::default(),
        ),
@@ -102,7 +102,7 @@
     let now = std::time::Instant::now();
     audio_input.set_dyn(vec![0.; 160], &[1, 80, 2]);
     enc_cx.set_dyn_dim('d', 1);
-    enc_cx.execute();
+    enc_cx.execute_debug();
     delete_inputs(downstream(encoder_params, &enc_cx), &mut enc_cx);
     text_input.set_dyn(vec![0.], &[1, 1]);
     dec_cx.set_dyn_dim('e', 1);
@@ -131,7 +131,7 @@

     audio_input.set_dyn(mel, &[1, 80, mel_len / 80]);
     enc_cx.set_dyn_dim('d', (mel_len / 80) / 2);
-    enc_cx.execute();
+    enc_cx.execute_debug();
     transfer_data(encoded, &mut enc_cx, encoder_output, &mut dec_cx);
     println!("\t\t - {}ms", start_encoding.elapsed().as_millis());

