[ReductionOp][MFMA] fix reduction of mfma64x4 layout
This PR fixes reductions over the MFMA 64x4 layout and enables the related tests.
binarman committed Mar 18, 2024
1 parent b3c9d8d commit d65b774
Showing 5 changed files with 69 additions and 42 deletions.
21 changes: 21 additions & 0 deletions lib/Analysis/Utility.cpp
@@ -54,6 +54,25 @@ SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
auto srcLayout = getSrcLayout();

// TODO fix mfma order
#ifdef USE_ROCM
if (auto mfmaLayout =
srcLayout.dyn_cast<mlir::triton::gpu::MfmaEncodingAttr>()) {
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
std::vector<int> order = {1, 0};
if (mfmaLayout.getIsTransposed())
std::swap(order[0], order[1]);

unsigned threadOffset = 1;
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
threadOffset *= threadsPerWarp[order[i]];
}
return threadOffset;
}
#endif

// If the reduction axis is the fast axis of the parent layout
if (isReductionOnLayoutFastAxis()) {
return 1;
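The added branch derives the lane offset between neighboring elements on the reduction axis from the layout's thread order: it multiplies the per-warp thread counts of every dimension that precedes the reduction axis. A minimal Python sketch of the same loop, using the 64x4 threads-per-warp values introduced in Dialect.cpp below; the helper name and inputs are illustrative, not part of the Triton API.

# Minimal sketch of the loop above (illustrative only).
def thread_offset_on_reduction_axis(threads_per_warp, is_transposed, axis):
    order = [1, 0]
    if is_transposed:
        order[0], order[1] = order[1], order[0]
    offset = 1
    for dim in order:
        if dim == axis:
            break
        offset *= threads_per_warp[dim]
    return offset

# mfma 64x4, not transposed: 16 threads along axis 0 and 4 along axis 1.
assert thread_offset_on_reduction_axis([16, 4], False, axis=0) == 4
assert thread_offset_on_reduction_axis([16, 4], False, axis=1) == 1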
@@ -543,6 +562,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
auto dstLayout = dstTy.getEncoding();
auto mfmaLayout = srcLayout.cast<triton::gpu::MfmaEncodingAttr>();
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
auto dstParentLayout =
dotOperandLayout.getParent().cast<triton::gpu::MfmaEncodingAttr>();
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
// improved. In addition, we can enable this shortcut for regular MFMA
// layout when opIdx == 1.
18 changes: 1 addition & 17 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -286,24 +286,8 @@ struct ReduceOpConversion
for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
unsigned shuffleIdx = N;
#ifdef USE_ROCM
auto srcTys = op.getInputTypes();
auto inputTy = srcTys[0].cast<RankedTensorType>();
auto inMfma =
inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
if (inMfma && inMfma.getIsTransposed()) {
assert(numLaneToReduce == 2 || numLaneToReduce == 4);
// For mfma 32x32, adjacent threads in the y dimension of the transposed
// MFMA layout are 32 apart:
// [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
// For mfma 16x16, adjacent threads in the y dimension of the transposed
// MFMA layout are 16 apart:
// [[0 0 0 0 16 16 16 16 32 32 32 32 ...] [1 1 1 1 17 17 17 17 33 33 33 33 ...] ...].
const int warpSize = 64;
shuffleIdx = warpSize / N / 2;
}
#endif
for (unsigned i = 0; i < acc.size(); ++i) {
shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx * interleave);
}
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
}
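This hunk drops the ROCm-specific shuffle-index adjustment for transposed MFMA layouts, leaving the default shuffleIdx = N for every layout. For intuition, here is a toy Python simulation of the halving loop, assuming shflSync acts as an XOR (butterfly) shuffle and using addition as the combine op; both are assumptions made only for this illustration.

# Toy simulation of the warp-level tree reduction above (illustrative only).
def butterfly_reduce(lane_values, num_lane_to_reduce, interleave=1):
    vals = list(lane_values)
    n = num_lane_to_reduce // 2
    while n > 0:
        shuffle_idx = n
        # Each lane reads the value of the lane whose id differs by the offset.
        shfl = [vals[lane ^ (shuffle_idx * interleave)] for lane in range(len(vals))]
        vals = [a + b for a, b in zip(vals, shfl)]
        n >>= 1
    return vals

# Eight lanes, reducing groups of four adjacent lanes:
assert butterfly_reduce(list(range(8)), num_lane_to_reduce=4) == [6, 6, 6, 6, 22, 22, 22, 22]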
13 changes: 8 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -1209,10 +1209,13 @@ class ConvertTritonGPUOpToLLVMPatternBase {
Value laneId = urem(threadId, effectiveWarpSize);

Value warpId = udiv(threadId, warpSize);
Value warpId0 =
urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / mDim));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
i32_val(shape[1] / nDim));
Value limitWarpId0 =
i32_val(std::max(static_cast<int64_t>(1), shape[0] / mDim));
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), limitWarpId0);
Value limitWarpId1 =
i32_val(std::max(static_cast<int64_t>(1), shape[1] / nDim));
Value warpId1 =
urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]), limitWarpId1);

Value offWarp0 = mul(warpId0, i32_val(mDim));
Value offWarp1 = mul(warpId1, i32_val(nDim));
@@ -1221,7 +1224,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
if (mfmaLayout.getIsTransposed()) {
multiDimBase[1] =
add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1);
multiDimBase[0] = add(urem(laneId, i32_val(nDim)), offWarp0);
multiDimBase[0] = add(urem(laneId, i32_val(mDim)), offWarp0);
} else {
multiDimBase[0] =
add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0);
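Two fixes land here: the divisor used to wrap the warp ids is clamped to at least 1, so tensors smaller than the MFMA tile along a dimension no longer produce a remainder by zero, and the transposed branch now takes laneId modulo mDim instead of nDim. A hedged sketch of the resulting per-lane base for a transposed tile, with the warp offsets set to zero; the helper name is illustrative, not Triton API.

# Sketch of the per-lane base computed above for a transposed MFMA tile,
# with offWarp0/offWarp1 == 0 (illustrative only).
def lane_base_transposed(lane_id, m_dim):
    base1 = 4 * (lane_id // m_dim)   # multiDimBase[1]
    base0 = lane_id % m_dim          # multiDimBase[0], now modulo mDim (the fix)
    return base0, base1

# For a transposed 64x4 tile (mDim == 64) every lane gets its own row and
# starts at column 0, matching the {64, 1} threads-per-warp value below.
assert {lane_base_transposed(l, 64) for l in range(64)} == {(r, 0) for r in range(64)}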
36 changes: 24 additions & 12 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -103,6 +103,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
return {8, 4};
}
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
// cols counts how many threads along axis 0 (i.e. along one column)
// rows counts how many threads along axis 1 (i.e. along one row)
unsigned rows = -1, cols = -1;
unsigned mDim = mfmaLayout.getMDim();
unsigned nDim = mfmaLayout.getNDim();
@@ -117,21 +119,28 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
cols = 4;
rows = 16;
}
if (mfmaLayout.getIsTransposed())
std::swap(cols, rows);
} else {
if (mDim == 64 && nDim == 4) {
cols = 16;
rows = 4;
if (!mfmaLayout.getIsTransposed()) {
cols = 16;
rows = 4;
} else {
cols = 64;
rows = 1;
}
} else if (mDim == 4 && nDim == 64) {
cols = 4;
rows = 16;
if (mfmaLayout.getIsTransposed()) {
cols = 4;
rows = 16;
} else {
cols = 1;
rows = 64;
}
}
}

if (mfmaLayout.getIsTransposed()) {
return {rows, cols};
} else {
return {cols, rows};
}
return {cols, rows};
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
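The 64x4 and 4x64 cases now report thread counts that depend on isTransposed, and the blanket swap for transposed layouts is gone. The four cases handled above, spelled out as a reference table only (not part of the code base):

# Keys are (mDim, nDim, isTransposed); values are threads along (axis 0, axis 1).
MFMA_THREADS_PER_WARP = {
    (64, 4, False): (16, 4),
    (64, 4, True):  (64, 1),
    (4, 64, False): (1, 64),
    (4, 64, True):  (4, 16),
}
# Every combination must account for all 64 lanes of the wavefront.
assert all(a * b == 64 for a, b in MFMA_THREADS_PER_WARP.values())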
@@ -321,8 +330,11 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.isVolta() || mmaLayout.isAmpere() || mmaLayout.isHopper());
return {1, 2};
} else if (layout.isa<MfmaEncodingAttr>()) {
return {1, 1};
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
if (mfmaLayout.getIsTransposed())
return {1, 4};
else
return {4, 1};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
return getContigPerThread(parentLayout);
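MFMA layouts now advertise four contiguous elements per thread, along axis 0 for the regular layout and along axis 1 for the transposed one. A small consistency check (illustrative only) that ties this to the per-lane base from the TritonGPUToLLVMBase.h change for the transposed 64x4 tile:

# Each of the 64 lanes should cover 4 contiguous elements along axis 1,
# and together they should cover the 64x4 tile exactly once.
m_dim, n_dim = 64, 4
covered = set()
for lane in range(64):
    row = lane % m_dim              # multiDimBase[0] with offWarp0 == 0
    col = 4 * (lane // m_dim)       # multiDimBase[1] with offWarp1 == 0
    for c in range(4):              # contigPerThread == {1, 4} when transposed
        covered.add((row, col + c))
assert len(covered) == m_dim * n_dim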
23 changes: 15 additions & 8 deletions python/test/unit/language/test_core_amd.py
@@ -3040,6 +3040,14 @@ def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile
layouts = [
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[32,32], is_transposed=True),
MfmaLayout(version=(2,0), warps_per_cta=[2, 2], instr_shape=[32,32], is_transposed=False),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[16,16], is_transposed=True),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[16,16], is_transposed=False),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[4,4], is_transposed=True),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[4,4], is_transposed=False),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[4,64], is_transposed=True),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[4,64], is_transposed=False),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[64,4], is_transposed=True),
MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[64,4], is_transposed=False),
]
shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
else:
@@ -3058,16 +3066,15 @@ def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
if is_hip():
pytest.skip("Skiping test_reduce_layouts for now.")

if torch.version.hip is not None and _get_warp_size() == 64:
if src_layout.is_transposed and axis == 0:
pytest.skip("Reduce along axis 0 is not supported in transposed mfma layout")
if torch.version.hip is not None:
if src_layout.warps_per_cta != "[1, 4]" and axis == 0:
pytest.skip("Reduce between warps is not supported")
if src_layout.warps_per_cta != "[4, 1]" and axis == 1:
pytest.skip("Reduce between warps is not supported")
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
rdims_1d = f"{N}" if axis == 0 else f"{M}"
store_range = "%7" if axis == 0 else "%1"
blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1])
blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP//32], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1])
ir = f"""
#blocked = {blocked}
#src = {src_layout}
@@ -3109,7 +3116,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
kernel = triton.compile(f.name)

rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('float32')
x = rs.randint(0, 1024, (M, N)).astype('float32')
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')

if axis == 0:
