Skip to content

Commit

Permalink
Async pipeline loads in matmul (#480)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Feb 14, 2025
1 parent f4a420e commit 7f07d39
Show file tree
Hide file tree
Showing 52 changed files with 1,656 additions and 666 deletions.
35 changes: 31 additions & 4 deletions crates/cubecl-core/src/frontend/container/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use super::Line;
/// # Safety
///
/// Since data can't be deallocated during kernel execution, this is safe.
#[derive(Clone)]
#[derive(Clone, Copy)]
pub struct Slice<E> {
_e: PhantomData<E>,
}
Expand All @@ -32,6 +32,7 @@ pub struct Slice<E> {
/// # Safety
///
/// Since data can be accessed by any unit during kernel execution, this can never be safe.
#[derive(Clone, Copy)]
pub struct SliceMut<E> {
_e: PhantomData<E>,
}
Expand All @@ -53,6 +54,17 @@ mod metadata {
{
unexpanded!()
}
/// Try to cast the slice to the given type and panic if the type isn't the same.
///
/// This function should only be used to satisfy the Rust type system, when two generic
/// types are supposed to be the same.
pub fn try_cast_unchecked<T>(&self) -> Slice<T>
where
E: CubePrimitive,
T: CubePrimitive,
{
unexpanded!()
}
}

impl<E> SliceMut<E> {
Expand All @@ -78,7 +90,7 @@ mod metadata {
elem.__expand_len_method(scope)
}

// Expand method of [len](Slice::to_aligned).
/// Expand method of [to_aligned](Slice::to_aligned).
pub fn __expand_to_aligned_method(
self,
_scope: &mut Scope,
Expand All @@ -89,7 +101,22 @@ mod metadata {
self.expand.into()
}

// Expand method of [clone](Clone::clone).
/// Expand method of [try_cast_unchecked](Slice::try_cast_unchecked).
pub fn __expand_try_cast_unchecked_method<T>(
self,
scope: &mut Scope,
) -> ExpandElementTyped<Slice<T>>
where
C: CubePrimitive,
T: CubePrimitive,
{
if T::as_elem(scope) != C::as_elem(scope) {
panic!("Try cast unchecked should only be used to satisfy the rust type system.")
}

self.expand.into()
}

pub fn __expand_clone_method(self, _scope: &mut Scope) -> ExpandElementTyped<Slice<Line<C>>>
where
C: CubePrimitive,
Expand All @@ -105,7 +132,7 @@ mod metadata {
elem.__expand_len_method(scope)
}

// Expand method of [len](SliceMut::into_aligned).
/// Expand method of [into_aligned](SliceMut::into_aligned).
pub fn __expand_into_aligned_method(
self,
_scope: &mut Scope,
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/src/frontend/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl<C: CubePrimitive> Default for Pipeline<C> {

impl<C: CubePrimitive> Pipeline<C> {
/// Create a pipeline instance
pub fn new(_num_stages: u32) -> Self {
pub fn new(_num_stages: u8) -> Self {
Self { _c: PhantomData }
}

Expand Down Expand Up @@ -152,9 +152,9 @@ impl<C: CubePrimitive> Pipeline<C> {
unexpanded!()
}

pub fn __expand_new(scope: &mut Scope, num_stages: u32) -> PipelineExpand<C> {
pub fn __expand_new(scope: &mut Scope, num_stages: u8) -> PipelineExpand<C> {
let elem = C::as_elem(scope);
let variable = scope.create_pipeline(Item::new(elem), num_stages as u8);
let variable = scope.create_pipeline(Item::new(elem), num_stages);
PipelineExpand {
elem: variable,
_c: PhantomData,
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-core/src/runtime_tests/memcpy_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pipeline::Pipeline;
fn one_load<F: Float>(lhs: &Tensor<Line<F>>, output: &mut Tensor<Line<F>>) {
let mut lhs_smem = SharedMemory::<F>::new_lined(4u32, 1u32);

let pipeline = Pipeline::new(1u32);
let pipeline = Pipeline::new(1);

let start = UNIT_POS_X * 2u32;
let end = start + 2u32;
Expand All @@ -32,7 +32,7 @@ fn two_loads<F: Float>(
let mut lhs_smem = SharedMemory::<F>::new_lined(num_data, 1u32);
let mut rhs_smem = SharedMemory::<F>::new_lined(num_data, 1u32);

let pipeline = Pipeline::new(1u32);
let pipeline = Pipeline::new(1);

let start = UNIT_POS_X * num_data / 2;
let end = start + num_data / 2;
Expand Down Expand Up @@ -62,7 +62,7 @@ fn two_independant_loads<F: Float>(
let mut lhs_smem = SharedMemory::<F>::new_lined(num_data, 1u32);
let mut rhs_smem = SharedMemory::<F>::new_lined(num_data, 1u32);

let pipeline = Pipeline::new(2u32);
let pipeline = Pipeline::new(2);

let start = UNIT_POS_X * num_data / 2;
let end = start + num_data / 2;
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-core/src/runtime_tests/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ fn pipelined_sum<F: Float>(
let smem_size = 2 * batch_len;
let num_batches = input.len() / batch_len;
let mut shared_memory = SharedMemory::<F>::new_lined(smem_size, input.line_size());
let pipeline = Pipeline::new(2u32);
let pipeline = Pipeline::new(2);

let mut sum = Line::<F>::empty(input.line_size()).fill(F::new(0.));

Expand Down Expand Up @@ -63,7 +63,7 @@ fn pipelined_sum<F: Float>(

#[cube(launch)]
pub fn async_copy_test<F: Float>(input: &Array<Line<F>>, output: &mut Array<Line<F>>) {
let pipeline = pipeline::Pipeline::<F>::new(2u32);
let pipeline = pipeline::Pipeline::<F>::new(2);
let mut smem = SharedMemory::<F>::new_lined(1u32, 1u32);

if UNIT_POS == 0 {
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ impl<D: Dialect> CppCompiler<D> {

match input {
Variable::Slice { .. } => Instruction::SliceLength { input, out },
Variable::SharedMemory(_id, _item, length) => {
Instruction::ConstLength { length, out }
}
_ => {
let id = match input {
Variable::GlobalInputArray(id, _) => id,
Expand Down
11 changes: 10 additions & 1 deletion crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ pub enum Instruction<D: Dialect> {
dim: Variable<D>,
out: Variable<D>,
},
ConstLength {
length: u32,
out: Variable<D>,
},
SliceLength {
input: Variable<D>,
out: Variable<D>,
Expand Down Expand Up @@ -216,7 +220,7 @@ impl<D: Dialect> Display for Instruction<D> {
out,
} => {
let item = out.item();
writeln!(f, "const uint {out}_length = {end};")?;
writeln!(f, "const uint {out}_length = {end} - {start};")?;
writeln!(f, "{item} *{out} = {input} + {start};")
}
Instruction::CheckedSlice {
Expand Down Expand Up @@ -454,6 +458,11 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
let out = out.fmt_left();
writeln!(f, "{out} = {input}_length;")
}
Instruction::ConstLength { length, out } => {
let out = out.fmt_left();
writeln!(f, "{out} = {length};")
}

Instruction::Warp(it) => write!(f, "{it}"),
Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out),
Instruction::Wmma(it) => write!(f, "{it}"),
Expand Down
30 changes: 13 additions & 17 deletions crates/cubecl-cpp/src/shared/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,52 +49,48 @@ impl<D: Dialect> Display for PipelineOps<D> {
destination,
} => {
let item = source.item();
let size = item.elem().size() * item.vectorization;
write!(f, "
cuda::memcpy_async(cooperative_groups::this_thread(), {destination}, {source}, {source}_length * {size}, {pipeline});
")
let size = format!("sizeof({item})");
write!(
f,
"
cooperative_groups::memcpy_async({pipeline}_block, {destination}, {source}, {source}_length * {size});
"
)
}
PipelineOps::Init {
pipeline,
num_stages,
} => {
PipelineOps::Init { pipeline, .. } => {
write!(
f,
"
cuda::pipeline_shared_state<cuda::thread_scope::thread_scope_block, {num_stages}> {pipeline}_state;
auto {pipeline} = cuda::make_pipeline(cooperative_groups::this_thread(), &{pipeline}_state);
auto {pipeline}_block = cooperative_groups::this_thread();
"
)
}
PipelineOps::ProducerAcquire { pipeline } => {
PipelineOps::ProducerAcquire { .. } => {
write!(
f,
"
{pipeline}.producer_acquire();
"
)
}
PipelineOps::ProducerCommit { pipeline } => {
PipelineOps::ProducerCommit { .. } => {
write!(
f,
"
{pipeline}.producer_commit();
"
)
}
PipelineOps::ConsumerWait { pipeline } => {
write!(
f,
"
{pipeline}.consumer_wait();
cooperative_groups::wait({pipeline}_block);
"
)
}
PipelineOps::ConsumerRelease { pipeline } => {
PipelineOps::ConsumerRelease { .. } => {
write!(
f,
"
{pipeline}.consumer_release();
"
)
}
Expand Down
33 changes: 20 additions & 13 deletions crates/cubecl-linalg/src/matmul/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,26 @@ use crate::tensor::TensorHandle;
use super::{
components::tile::accelerated::Accelerated,
kernels::{
matmul::{self, PipelinedSelector, SpecializedSelector, StandardSelector},
simple,
matmul::{
self, DoubleBufferingSelector, SimplePipelinedSelector, SimpleSelector,
SpecializedSelector,
},
naive,
tiling2d::{self, Tiling2dConfig},
MatmulLaunchError,
},
};

#[derive(Debug, Clone, Default)]
pub enum Strategy {
Standard,
Pipelined,
Simple,
SimplePipelined,
DoubleBuffering,
Specialized,
#[cfg(any(test, feature = "export_tests"))]
// Very slow, only use for testing.
PlaneMma,
Simple,
Naive,
Tiling2D(Tiling2dConfig),
#[default]
Auto,
Expand Down Expand Up @@ -51,32 +55,35 @@ pub fn launch_ref<R: Runtime, EG: MaybeQuantized>(
out: &TensorHandleRef<R>,
) -> Result<(), MatmulLaunchError> {
match strategy {
Strategy::Standard => {
matmul::launch_ref::<R, EG, StandardSelector<Accelerated>>(client, lhs, rhs, out)
Strategy::Simple => {
matmul::launch_ref::<R, EG, SimpleSelector<Accelerated>>(client, lhs, rhs, out)
}
Strategy::Pipelined => {
matmul::launch_ref::<R, EG, PipelinedSelector<Accelerated>>(client, lhs, rhs, out)
Strategy::SimplePipelined => {
matmul::launch_ref::<R, EG, SimplePipelinedSelector<Accelerated>>(client, lhs, rhs, out)
}
Strategy::DoubleBuffering => {
matmul::launch_ref::<R, EG, DoubleBufferingSelector<Accelerated>>(client, lhs, rhs, out)
}
Strategy::Specialized => {
matmul::launch_ref::<R, EG, SpecializedSelector<Accelerated>>(client, lhs, rhs, out)
}
#[cfg(any(test, feature = "export_tests"))]
Strategy::PlaneMma => {
matmul::launch_ref::<R, EG, StandardSelector<super::components::tile::plane::PlaneMma>>(
matmul::launch_ref::<R, EG, SimpleSelector<super::components::tile::plane::PlaneMma>>(
client, lhs, rhs, out,
)
}
Strategy::Tiling2D(config) => {
tiling2d::launch_ref::<R, EG::Numeric>(client, lhs, rhs, out, config.clone());
Ok(())
}
Strategy::Simple => {
simple::launch_ref::<R, EG::Numeric>(client, lhs, rhs, out)?;
Strategy::Naive => {
naive::launch_ref::<R, EG::Numeric>(client, lhs, rhs, out)?;
Ok(())
}
Strategy::Auto => {
if let Err(err) =
matmul::launch_ref::<R, EG, StandardSelector<Accelerated>>(client, lhs, rhs, out)
matmul::launch_ref::<R, EG, SimpleSelector<Accelerated>>(client, lhs, rhs, out)
{
match err {
super::kernels::MatmulLaunchError::Unavailable(_) => {
Expand Down
Loading

0 comments on commit 7f07d39

Please sign in to comment.