From 18996e774f9637960cfbeb7c0026ade2691ab1e0 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 9 Jul 2024 15:37:15 -0400 Subject: [PATCH] [BACKEND] Fix hopper mma to linear layout constraints (#4283) n = 8 should be a valid option https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape --- .../TritonGPU/IR/LinearLayoutConversions.cpp | 2 +- .../TritonGPUToLLVM/EmitIndicesTest.cpp | 2 ++ .../TritonGPU/LinearLayoutConversionsTest.cpp | 26 ++++++++++++------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 2836b779bbb9..47d2c3cf3a18 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -387,7 +387,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, int n = mma.getInstrShape()[1]; int k = mma.getInstrShape()[2]; assert(m == 16); - assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256); + assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); assert(k == 8 || k == 16 || k == 32); MLIRContext *ctx = mma.getContext(); diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp index 58e52ffeb7d2..42486fc3c726 100644 --- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp @@ -1209,6 +1209,8 @@ std::vector makeNvidiaMmaV3TestCases() { // These shapes were captured from grep'ing the TTGIR generated by Triton unit // tests. + addTests({16, 8, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}}); + addTests({16, 8, 16}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}}); addTests({16, 16, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}}); addTests({16, 16, 16}, 4, {{64, 16}, {128, 16}, {128, 128}}); addTests({16, 16, 32}, 4, {{64, 16}, {128, 16}}); diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp index ecd16095acae..1bfc5b641555 100644 --- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp +++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp @@ -3,6 +3,7 @@ #include "mlir/IR/MLIRContext.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/StrUtil.h" #include "llvm/Support/Signals.h" #include #include @@ -364,16 +365,21 @@ TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) { } TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) { - EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, {16, 16, 8}, {4, 1}, {1, 1}, - {1, 1}, {1, 0})), - LinearLayout( - { - {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, - {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, - {S("warp"), {{16, 0}, {32, 0}}}, - {S("block"), {}}, - }, - {S("dim0"), S("dim1")})); + SmallVector, 4> instrShapes = { + {16, 16, 8}, {16, 16, 8}, {16, 8, 8}}; + for (auto instrShape : instrShapes) { + SCOPED_TRACE(triton::join(instrShape, ",")); + EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1}, + {1, 1}, {1, 0})), + LinearLayout( + { + {S("register"), {{0, 1}, {8, 0}, {0, 8}}}, + {S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}}, + {S("warp"), {{16, 0}, {32, 0}}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + } } TEST_F(LinearLayoutConversionsTest, MMAv3_128x16) {