Skip to content

Commit

Permalink
Refactor of tiling layout (#487)
Browse files Browse the repository at this point in the history
* refactor tiling order into enum

* encapsulate tiling order
  • Loading branch information
louisfd authored Feb 17, 2025
1 parent c305e93 commit 97f658e
Show file tree
Hide file tree
Showing 23 changed files with 141 additions and 156 deletions.
4 changes: 2 additions & 2 deletions crates/cubecl-linalg/src/matmul/components/global/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use pipeline::Pipeline;

use crate::matmul::components::stage::{self, StageWriter, TilingOrderConfig};
use crate::matmul::components::stage::{self, StageWriter, TilingLayout};
use crate::matmul::components::{config::MatmulConfig, tile};
use crate::matmul::components::{Ident, MatrixLayout};
use crate::matmul::components::{InvalidConfigError, MatmulConfigFactory};
Expand Down Expand Up @@ -177,7 +177,7 @@ pub trait GlobalConfig: MatmulConfig {
fn plane_dim(&self) -> u32;

/// Returns the order in which tiles should be loaded to the stage
fn tiling_order(&self, ident: Ident) -> TilingOrderConfig;
fn tiling_layout(&self, ident: Ident) -> TilingLayout;

/// Whether to check if accessing a row would exceed bounds.
fn check_row_bounds(&self, ident: Ident) -> bool;
Expand Down
6 changes: 3 additions & 3 deletions crates/cubecl-linalg/src/matmul/components/global/config.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::matmul::components::{
stage::{self, TilingOrderConfig},
stage::{self, TilingLayout},
Ident, MatmulConfig, MatrixLayout, StageTiling,
};

Expand Down Expand Up @@ -66,8 +66,8 @@ impl<S: stage::StageConfig> super::GlobalConfig for CommonGlobalConfig<S> {
self.smm_config.plane_dim()
}

fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
self.smm_config.tiling_order(ident)
fn tiling_layout(&self, ident: Ident) -> TilingLayout {
self.smm_config.tiling_layout(ident)
}

fn check_row_bounds(&self, ident: Ident) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::matmul::components::global::multi_stage::buffer_loading::BufferLoadin
use crate::matmul::components::global::tensor_view::TensorReader;
use crate::matmul::components::global::{CommonGlobalConfig, InputLoader};
use crate::matmul::components::stage::single_buffer::{LhsBufferReader, RhsBufferReader};
use crate::matmul::components::stage::TilingOrderConfig;
use crate::matmul::components::stage::{self, Stage};
use crate::matmul::components::stage::{TilingLayout, TilingOrder};
use crate::matmul::components::{global, Ident, InvalidConfigError};
use crate::tensor::VirtualTensor;
use cubecl_core as cubecl;
Expand Down Expand Up @@ -182,14 +182,14 @@ pub fn check_buffers_contiguous<G: global::GlobalConfig>(
) -> Result<(), InvalidConfigError> {
match ident.as_input() {
InputIdent::Lhs => {
if let TilingOrderConfig::RowMajor = config.tiling_order(ident) {
if let TilingLayout::Contiguous(TilingOrder::RowMajor) = config.tiling_layout(ident) {
return Err(Box::new(
"Lhs must have ColMajor tiling order in pipelined setting",
));
}
}
InputIdent::Rhs => {
if let TilingOrderConfig::ColMajor = config.tiling_order(ident) {
if let TilingLayout::Contiguous(TilingOrder::ColMajor) = config.tiling_layout(ident) {
return Err(Box::new(
"Rhs must have RowMajor tiling order in pipelined setting",
));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::matmul::components::{
global::{self, GlobalConfig, LoadMode},
stage::{self, TilingOrderConfig},
stage::{self, TilingLayout},
Ident, MatmulConfig, MatrixLayout, StageTiling,
};

Expand Down Expand Up @@ -59,8 +59,8 @@ impl<S: stage::StageConfig> global::GlobalConfig for Config<S> {
self.smm_config.plane_dim()
}

fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
self.smm_config.tiling_order(ident)
fn tiling_layout(&self, ident: Ident) -> TilingLayout {
self.smm_config.tiling_layout(ident)
}

fn check_row_bounds(&self, ident: Ident) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::matmul::components::global::multi_stage::buffer_loading::BufferLoadin
use crate::matmul::components::global::tensor_view::TensorReader;
use crate::matmul::components::global::InputLoader;
use crate::matmul::components::stage::single_buffer::{LhsBufferReader, RhsBufferReader};
use crate::matmul::components::stage::TilingOrderConfig;
use crate::matmul::components::stage::{self, Stage};
use crate::matmul::components::stage::{TilingLayout, TilingOrder};
use crate::matmul::components::{global, Ident};
use crate::tensor::VirtualTensor;
use cubecl_core as cubecl;
Expand Down Expand Up @@ -193,12 +193,12 @@ fn load_buffer<EG: Numeric, ES: Numeric, S: stage::StageConfig>(
fn check_buffers_contiguous<G: global::GlobalConfig>(ident: Ident, config: G) {
match ident.as_input() {
InputIdent::Lhs => {
if let TilingOrderConfig::RowMajor = config.tiling_order(ident) {
if let TilingLayout::Contiguous(TilingOrder::RowMajor) = config.tiling_layout(ident) {
panic!("Lhs must have ColMajor tiling order in producer consumer setting")
}
}
InputIdent::Rhs => {
if let TilingOrderConfig::ColMajor = config.tiling_order(ident) {
if let TilingLayout::Contiguous(TilingOrder::ColMajor) = config.tiling_layout(ident) {
panic!("Rhs must have RowMajor tiling order in producer consumer setting")
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::matmul::components::{
global::{self, LoadMode},
stage::{self, TilingOrderConfig},
stage::{self, TilingLayout},
Ident, MatmulConfig, MatrixLayout, StageTiling,
};

Expand Down Expand Up @@ -59,8 +59,8 @@ impl<S: stage::StageConfig> global::GlobalConfig for Config<S> {
self.smm_config.plane_dim()
}

fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
self.smm_config.tiling_order(ident)
fn tiling_layout(&self, ident: Ident) -> TilingLayout {
self.smm_config.tiling_layout(ident)
}

fn check_row_bounds(&self, ident: Ident) -> bool {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::matmul::components::global::tensor_view::TensorReader;
use crate::matmul::components::global::{GlobalConfig, LoadMode, LoadingValidation};
use crate::matmul::components::stage::{
ColMajorTiling, RowMajorTiling, TilingOrder, TilingOrderConfig,
};
use crate::matmul::components::stage::TilingLayout;
use crate::matmul::components::{Ident, InvalidConfigError, MatrixLayout};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
Expand Down Expand Up @@ -88,18 +86,12 @@ impl LoadingStrategy for CyclicLoading {
let slice_index = unit_id + total_units * i;

let nth_tile = slice_index / num_slices_per_tile;
let (tile_x, tile_y) = match config.tiling_order(ident) {
TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y(
nth_tile,
stage_dim.tile_count_row(),
stage_dim.tile_count_col(),
),
TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y(
nth_tile,
stage_dim.tile_count_row(),
stage_dim.tile_count_col(),
),
};
let (tile_x, tile_y) = TilingLayout::to_x_y(
config.tiling_layout(ident),
nth_tile,
stage_dim.tile_count_row(),
stage_dim.tile_count_col(),
);
let nth_slice = slice_index % num_slices_per_tile;

// TODO make branching comptime conditional
Expand Down Expand Up @@ -147,18 +139,12 @@ impl LoadingStrategy for CyclicLoading {
let nth_tile = unit_position / tile_num_elements;
let pos_within_tile = unit_position % tile_num_elements;

let (tile_x, tile_y) = match config.tiling_order(ident) {
TilingOrderConfig::RowMajor => RowMajorTiling::to_x_y(
nth_tile,
tiling.tile_count_row(),
tiling.tile_count_col(),
),
TilingOrderConfig::ColMajor => ColMajorTiling::to_x_y(
nth_tile,
tiling.tile_count_row(),
tiling.tile_count_col(),
),
};
let (tile_x, tile_y) = TilingLayout::to_x_y(
config.tiling_layout(ident),
nth_tile,
tiling.tile_count_row(),
tiling.tile_count_col(),
);

let line_read =
read_view.load_coalesced::<G>(tile_x, tile_y, pos_within_tile, ident, config);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::matmul::components::global::tensor_view::TensorReader;
use crate::matmul::components::global::{GlobalConfig, LoadMode, LoadingValidation};
use crate::matmul::components::stage::{
ColMajorTiling, RowMajorTiling, TilingOrder, TilingOrderConfig,
};
use crate::matmul::components::stage::TilingLayout;
use crate::matmul::components::{FormattedConfigError, Ident, InvalidConfigError};
use cubecl_core as cubecl;
use cubecl_core::prelude::*;
Expand Down Expand Up @@ -82,14 +80,12 @@ impl LoadingStrategy for TilewiseLoading {

let num_loads_per_unit = num_lines_per_tile / config.plane_dim();

let (tile_x, tile_y) = match config.tiling_order(ident) {
TilingOrderConfig::RowMajor => {
RowMajorTiling::to_x_y(nth_tile, tiling.tile_count_row(), tiling.tile_count_col())
}
TilingOrderConfig::ColMajor => {
ColMajorTiling::to_x_y(nth_tile, tiling.tile_count_row(), tiling.tile_count_col())
}
};
let (tile_x, tile_y) = TilingLayout::to_x_y(
config.tiling_layout(ident),
nth_tile,
tiling.tile_count_row(),
tiling.tile_count_col(),
);

for i in 0..num_loads_per_unit {
let pos_within_tile = i * config.plane_dim() + UNIT_POS_X;
Expand Down
4 changes: 2 additions & 2 deletions crates/cubecl-linalg/src/matmul/components/stage/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::matmul::components::{global, MatmulConfigFactory};
use crate::matmul::components::{Ident, MatrixLayout};
use crate::matmul::components::{MatmulSize, StageTiling};

use super::tiling_order::TilingOrderConfig;
use super::TilingLayout;

pub trait ReaderFamily {
type Reader<I: Numeric>: CubeType;
Expand Down Expand Up @@ -151,7 +151,7 @@ pub trait StageConfig: MatmulConfig {
fn plane_dim(&self) -> u32;

/// Returns the order in which tiles should be loaded to the stage
fn tiling_order(&self, ident: Ident) -> TilingOrderConfig;
fn tiling_layout(&self, ident: Ident) -> TilingLayout;

fn tile_count(&self) -> &MatmulSize;
}
54 changes: 54 additions & 0 deletions crates/cubecl-linalg/src/matmul/components/stage/layout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl};

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// How the tiles are stored in shared memory
pub enum TilingLayout {
/// Each tile is stored contiguously in memory.
/// Tiles are placed sequentially in memory according to the specified `TilingOrder`.
Contiguous(TilingOrder),

/// Tiles follow the memory layout of the underlying global memory,
/// meaning elements within a tile may be interleaved with elements from other tiles.
Strided,
}

#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
/// Layout in which to store tiles within the stage
pub enum TilingOrder {
/// Tiles are conceptually stored in row-major order, regardless of the actual data layout.
RowMajor,
/// Tiles are conceptually stored in column-major order, regardless of the actual data layout.
ColMajor,
}

#[cube]
impl TilingLayout {
/// Converts a tile index in the stage to its (x,y) position
pub fn to_x_y(#[comptime] this: TilingLayout, nth: u32, num_x: u32, num_y: u32) -> (u32, u32) {
match comptime!(this) {
TilingLayout::Contiguous(tiling_order) => match comptime!(tiling_order) {
TilingOrder::RowMajor => (nth / num_y, nth % num_y),
TilingOrder::ColMajor => (nth % num_x, nth / num_x),
},
TilingLayout::Strided => todo!(),
}
}

/// Converts an (x,y) position to its tile index in the stage
pub fn to_nth_tile(
#[comptime] this: TilingLayout,
x: u32,
y: u32,
num_x: u32,
num_y: u32,
) -> u32 {
match comptime!(this) {
TilingLayout::Contiguous(tiling_order) => match comptime!(tiling_order) {
TilingOrder::RowMajor => x * num_y + y,
TilingOrder::ColMajor => y * num_x + x,
},
TilingLayout::Strided => todo!(),
}
}
}
4 changes: 2 additions & 2 deletions crates/cubecl-linalg/src/matmul/components/stage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ pub mod multi_buffer;
pub mod single_buffer;

mod base;
mod layout;
pub(super) mod shared;
mod staging;
mod tiling_order;

pub use base::*;
pub use layout::*;
pub use staging::Stage;
pub use tiling_order::*;
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ impl<TMM: TileMatmulFamily> MatmulConfigFactory for MultiBufferMatmulFamily<TMM>
tmm_config,
tiling,
cube_dim.y,
advanced_config.lhs_tiling_order,
advanced_config.rhs_tiling_order,
advanced_config.lhs_tiling_layout,
advanced_config.rhs_tiling_layout,
quantized,
)
}
Expand Down
20 changes: 10 additions & 10 deletions crates/cubecl-linalg/src/matmul/components/stage/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ use crate::matmul::components::{
MatrixLayout, StageTiling,
};

use super::{StageConfig, TilingOrderConfig};
use super::{StageConfig, TilingLayout};

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// Configuration for the single buffer matmul
pub struct CommonStageConfig<T: TileConfig> {
pub tmm_config: T,
pub tiling: CompleteStageTiling,
pub num_planes: u32,
pub lhs_tiling_order: TilingOrderConfig,
pub rhs_tiling_order: TilingOrderConfig,
pub lhs_tiling_layout: TilingLayout,
pub rhs_tiling_layout: TilingLayout,
pub quantized: bool,
}

Expand Down Expand Up @@ -43,10 +43,10 @@ impl<T: TileConfig> StageConfig for CommonStageConfig<T> {
self.tmm_config.plane_dim()
}

fn tiling_order(&self, ident: Ident) -> TilingOrderConfig {
fn tiling_layout(&self, ident: Ident) -> TilingLayout {
match ident.as_input() {
InputIdent::Lhs => self.lhs_tiling_order,
InputIdent::Rhs => self.rhs_tiling_order,
InputIdent::Lhs => self.lhs_tiling_layout,
InputIdent::Rhs => self.rhs_tiling_layout,
}
}

Expand All @@ -63,16 +63,16 @@ impl<T: TileConfig> CommonStageConfig<T> {
tmm_config: T,
tiling: CompleteStageTiling,
num_planes: u32,
lhs_tiling_order: TilingOrderConfig,
rhs_tiling_order: TilingOrderConfig,
lhs_tiling_layout: TilingLayout,
rhs_tiling_layout: TilingLayout,
quantized: bool,
) -> Self {
Self {
tmm_config,
tiling,
num_planes,
lhs_tiling_order,
rhs_tiling_order,
lhs_tiling_layout,
rhs_tiling_layout,
quantized,
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ where
tmm_config,
tiling,
tile_count.m,
advanced_config.lhs_tiling_order,
advanced_config.rhs_tiling_order,
advanced_config.lhs_tiling_layout,
advanced_config.rhs_tiling_layout,
quantized,
)
}
Expand Down
Loading

0 comments on commit 97f658e

Please sign in to comment.