From d65b774db16fa3730589126efd2adbb20c360217 Mon Sep 17 00:00:00 2001
From: Alexander Efimov
Date: Mon, 22 Jan 2024 16:39:51 +0000
Subject: [PATCH] [ReductionOp][MFMA] fix reduction of mfma64x4 layout

This PR fixes reduction of the mfma 64x4 layout and enables related tests.
---
 lib/Analysis/Utility.cpp                      | 21 +++++++++++
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp        | 18 +---------
 .../TritonGPUToLLVM/TritonGPUToLLVMBase.h     | 13 ++++---
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 36 ++++++++++++-------
 python/test/unit/language/test_core_amd.py    | 23 +++++++-----
 5 files changed, 69 insertions(+), 42 deletions(-)

diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index a8236dda77b2..6dbc10b943a6 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -54,6 +54,25 @@ SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
 unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
   auto srcLayout = getSrcLayout();
 
+// TODO fix mfma order
+#ifdef USE_ROCM
+  if (auto mfmaLayout =
+          srcLayout.dyn_cast<triton::gpu::MfmaEncodingAttr>()) {
+    auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
+    std::vector<unsigned> order = {1, 0};
+    if (mfmaLayout.getIsTransposed())
+      std::swap(order[0], order[1]);
+
+    unsigned threadOffset = 1;
+    for (unsigned i = 0; i < order.size(); i++) {
+      if (order[i] == axis)
+        break;
+      threadOffset *= threadsPerWarp[order[i]];
+    }
+    return threadOffset;
+  }
+#endif
+
   // If the reduction axis is the fast axis of the parent layout
   if (isReductionOnLayoutFastAxis()) {
     return 1;
@@ -543,6 +562,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
   auto dstLayout = dstTy.getEncoding();
   auto mfmaLayout = srcLayout.cast<triton::gpu::MfmaEncodingAttr>();
   auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
+  auto dstParentLayout =
+      dotOperandLayout.getParent().cast<triton::gpu::MfmaEncodingAttr>();
   // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
   // improved. In addition, we can enable this shortcut for regular MFMA
   // layout when opIdx == 1.
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index fb6f8135b366..1fe623efab7e 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -286,24 +286,8 @@ struct ReduceOpConversion
     for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
       SmallVector<Value> shfl(acc.size());
       unsigned shuffleIdx = N;
-#ifdef USE_ROCM
-      auto srcTys = op.getInputTypes();
-      auto inputTy = srcTys[0].cast<RankedTensorType>();
-      auto inMfma =
-          inputTy.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
-      if (inMfma && inMfma.getIsTransposed()) {
-        assert(numLaneToReduce == 2 || numLaneToReduce == 4);
-        // for mfma 32x32 adjacent threads in y dimension in transposed MFMA
-        // layout are 32 apart: [[0 0 0 0 32 32 32 32 ...] [1 1 1 1 33 33 33 33
-        // ...] ...]. for mfma 16x16 adjacent threads in y dimension in
-        // transposed MFMA layout are 16 apart: [[0 0 0 0 16 16 16 16 32 32 32
-        // 32 ...] [1 1 1 1 33 33 33 33 ...] ...].
-        const int warpSize = 64;
-        shuffleIdx = warpSize / N / 2;
-      }
-#endif
       for (unsigned i = 0; i < acc.size(); ++i) {
-        shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx * interleave);
+        shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx * interleave);
       }
       accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
     }
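The removed USE_ROCM branch special-cased the shuffle distance for transposed MFMA layouts; with the lane offset for MFMA now computed generically in ReduceOpHelper::getThreadOffsetOnReductionAxis (first hunk above), the plain butterfly reduction appears to be enough. A minimal Python sketch of that butterfly pattern, assuming shflSync lowers to an XOR (lane ^ offset) shuffle; warp_reduce and its arguments are illustrative names, not Triton APIs:

    WARP_SIZE = 64

    def warp_reduce(lane_values, num_lane_to_reduce, interleave):
        # Each lane combines with lane ^ (N * interleave) for N = num/2, ..., 1.
        acc = list(lane_values)
        n = num_lane_to_reduce // 2
        while n > 0:
            shfl = [acc[lane ^ (n * interleave)] for lane in range(WARP_SIZE)]
            acc = [a + b for a, b in zip(acc, shfl)]
            n //= 2
        return acc

    # 16 lanes hold one reduction group, neighbouring elements 4 lanes apart:
    # afterwards every lane of the group holds the group sum.
    vals = list(range(WARP_SIZE))
    out = warp_reduce(vals, num_lane_to_reduce=16, interleave=4)
    assert out[0] == sum(v for lane, v in enumerate(vals) if lane % 4 == 0)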
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index 5d2dbae1718e..b7c4811cbd2d 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -1209,10 +1209,13 @@ class ConvertTritonGPUOpToLLVMPatternBase {
     Value laneId = urem(threadId, effectiveWarpSize);
     Value warpId = udiv(threadId, warpSize);
 
-    Value warpId0 =
-        urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / mDim));
-    Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
-                         i32_val(shape[1] / nDim));
+    Value limitWarpId0 =
+        i32_val(std::max(static_cast<int64_t>(1), shape[0] / mDim));
+    Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), limitWarpId0);
+    Value limitWarpId1 =
+        i32_val(std::max(static_cast<int64_t>(1), shape[1] / nDim));
+    Value warpId1 =
+        urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]), limitWarpId1);
 
     Value offWarp0 = mul(warpId0, i32_val(mDim));
     Value offWarp1 = mul(warpId1, i32_val(nDim));
@@ -1221,7 +1224,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
     if (mfmaLayout.getIsTransposed()) {
       multiDimBase[1] =
           add(mul(i32_val(4), udiv(laneId, i32_val(mDim))), offWarp1);
-      multiDimBase[0] = add(urem(laneId, i32_val(nDim)), offWarp0);
+      multiDimBase[0] = add(urem(laneId, i32_val(mDim)), offWarp0);
     } else {
       multiDimBase[0] =
           add(mul(i32_val(4), udiv(laneId, i32_val(nDim))), offWarp0);
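The clamp above matters once an MFMA tile can be larger than the tensor along a dimension, e.g. the new 64x4 layout on a 32x128 shape: shape[0] / mDim is then 0 and the old urem by it is undefined. A small Python sketch of the clamped warp-offset computation as I read the hunk; warp_tile_offsets is an illustrative name, not a Triton API:

    def warp_tile_offsets(warp_id, warps_per_cta, shape, m_dim, n_dim):
        # Mirrors warpId0/warpId1 and offWarp0/offWarp1 above.
        limit0 = max(1, shape[0] // m_dim)
        limit1 = max(1, shape[1] // n_dim)
        warp0 = (warp_id % warps_per_cta[0]) % limit0
        warp1 = ((warp_id // warps_per_cta[0]) % warps_per_cta[1]) % limit1
        return warp0 * m_dim, warp1 * n_dim

    # Four warps laid out as warps_per_cta = [4, 1] on a 32x128 tensor with a
    # 64x4 tile: every warp maps to row offset 0 instead of dividing by zero.
    offsets = [warp_tile_offsets(w, [4, 1], [32, 128], 64, 4) for w in range(4)]
    assert all(off0 == 0 for off0, _ in offsets)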
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 6e06a043ed4f..87e8bb218bc9 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -103,6 +103,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
     return {8, 4};
   }
   if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
+    // cols counts how many threads along axis 0 (i.e. along one column)
+    // rows counts how many threads along axis 1 (i.e. along one row)
     unsigned rows = -1, cols = -1;
     unsigned mDim = mfmaLayout.getMDim();
     unsigned nDim = mfmaLayout.getNDim();
@@ -117,21 +119,28 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
         cols = 4;
         rows = 16;
       }
+      if (mfmaLayout.getIsTransposed())
+        std::swap(cols, rows);
     } else {
       if (mDim == 64 && nDim == 4) {
-        cols = 16;
-        rows = 4;
+        if (!mfmaLayout.getIsTransposed()) {
+          cols = 16;
+          rows = 4;
+        } else {
+          cols = 64;
+          rows = 1;
+        }
       } else if (mDim == 4 && nDim == 64) {
-        cols = 4;
-        rows = 16;
+        if (mfmaLayout.getIsTransposed()) {
+          cols = 4;
+          rows = 16;
+        } else {
+          cols = 1;
+          rows = 64;
+        }
       }
     }
-
-    if (mfmaLayout.getIsTransposed()) {
-      return {rows, cols};
-    } else {
-      return {cols, rows};
-    }
+    return {cols, rows};
   }
   if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
     auto parent = sliceLayout.getParent();
@@ -321,8 +330,11 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
   if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
     assert(mmaLayout.isVolta() || mmaLayout.isAmpere() || mmaLayout.isHopper());
     return {1, 2};
-  } else if (layout.isa<MfmaEncodingAttr>()) {
-    return {1, 1};
+  } else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
+    if (mfmaLayout.getIsTransposed())
+      return {1, 4};
+    else
+      return {4, 1};
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    auto parentLayout = sliceLayout.getParent();
    return getContigPerThread(parentLayout);
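The threads-per-warp table above is the core of the 64x4 / 4x64 fix: together with the axis order used in ReduceOpHelper::getThreadOffsetOnReductionAxis it determines the lane stride the reduction shuffles over. A Python mirror of the two routines as I read the hunks (illustrative names; only the 64x4 and 4x64 cases are covered):

    def threads_per_warp(m_dim, n_dim, is_transposed):
        # Element [0] counts lanes along axis 0, element [1] along axis 1.
        if (m_dim, n_dim) == (64, 4):
            return [64, 1] if is_transposed else [16, 4]
        if (m_dim, n_dim) == (4, 64):
            return [4, 16] if is_transposed else [1, 64]
        raise ValueError("only the 64x4 / 4x64 cases are sketched here")

    def thread_offset_on_axis(m_dim, n_dim, is_transposed, axis):
        # Lane stride between neighbouring elements along `axis`.
        tpw = threads_per_warp(m_dim, n_dim, is_transposed)
        order = [0, 1] if is_transposed else [1, 0]  # fastest-varying axis first
        offset = 1
        for dim in order:
            if dim == axis:
                break
            offset *= tpw[dim]
        return offset

    # Non-transposed 64x4: lanes covering the same column are 4 apart, so the
    # warp reduction over axis 0 needs an interleave of 4.
    assert thread_offset_on_axis(64, 4, False, axis=0) == 4
    assert thread_offset_on_axis(64, 4, False, axis=1) == 1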
"%7" if axis == 0 else "%1" - blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP//32], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) ir = f""" #blocked = {blocked} #src = {src_layout} @@ -3109,7 +3116,7 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'): kernel = triton.compile(f.name) rs = RandomState(17) - x = rs.randint(0, 4, (M, N)).astype('float32') + x = rs.randint(0, 1024, (M, N)).astype('float32') x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32') if axis == 0: