Skip to content

Commit

Permalink
[BACKEND] Fix hopper mma to linear layout constraints (#4283)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Jul 9, 2024
1 parent 7a78c04 commit 18996e7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
int n = mma.getInstrShape()[1];
int k = mma.getInstrShape()[2];
assert(m == 16);
assert(n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
assert(k == 8 || k == 16 || k == 32);

MLIRContext *ctx = mma.getContext();
Expand Down
2 changes: 2 additions & 0 deletions unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,8 @@ std::vector<NvidiaMmaLLTestParams> makeNvidiaMmaV3TestCases() {

// These shapes were captured from grep'ing the TTGIR generated by Triton unit
// tests.
addTests({16, 8, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}});
addTests({16, 8, 16}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}});
addTests({16, 16, 8}, 4, {{16, 16}, {32, 16}, {32, 32}, {64, 64}});
addTests({16, 16, 16}, 4, {{64, 16}, {128, 16}, {128, 128}});
addTests({16, 16, 32}, 4, {{64, 16}, {128, 16}});
Expand Down
26 changes: 16 additions & 10 deletions unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Signals.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand Down Expand Up @@ -364,16 +365,21 @@ TEST_F(LinearLayoutConversionsTest, MMAv2_Small3D) {
}

TEST_F(LinearLayoutConversionsTest, MMAv3_64x16) {
EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, {16, 16, 8}, {4, 1}, {1, 1},
{1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{0, 1}, {8, 0}, {0, 8}}},
{S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{16, 0}, {32, 0}}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
SmallVector<SmallVector<unsigned>, 4> instrShapes = {
{16, 16, 8}, {16, 16, 8}, {16, 8, 8}};
for (auto instrShape : instrShapes) {
SCOPED_TRACE(triton::join(instrShape, ","));
EXPECT_EQ(toLinearLayout({64, 16}, mma(3, 0, instrShape, {4, 1}, {1, 1},
{1, 1}, {1, 0})),
LinearLayout(
{
{S("register"), {{0, 1}, {8, 0}, {0, 8}}},
{S("lane"), {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}},
{S("warp"), {{16, 0}, {32, 0}}},
{S("block"), {}},
},
{S("dim0"), S("dim1")}));
}
}

TEST_F(LinearLayoutConversionsTest, MMAv3_128x16) {
Expand Down

0 comments on commit 18996e7

Please sign in to comment.