[BACKEND] LL for ldmatrix part3 - ldmatrix.x2/x1 for small tiles (tri…
Jokeren authored Jan 26, 2025
1 parent b1301d6 commit 53e6e9e
Showing 3 changed files with 56 additions and 104 deletions.
37 changes: 25 additions & 12 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -1101,6 +1101,7 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
auto rank = shape.size();
auto opIdx = dot.getOpIdx();
int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
int nonKDim = (opIdx == 0) ? rank - 2 : rank - 1;

StringAttr kReg = S("register");
StringAttr kLane = S("lane");
@@ -1117,8 +1118,11 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
auto reg = 1 << logReg;
basesReg.push_back({0, reg});
}
std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
int numTileCols;
std::vector<std::vector<int>> basesLane = {
{1, 0}, {2, 0}, {4, 0}, {0, 0}, {0, 0}};
bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth;
bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth;
bool nonKX2 = shape[nonKDim] > 8;
// Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
// efficiently. opIdx=0 and opIdx=1 are handled differently.
if (opIdx == 0) {
@@ -1131,13 +1135,16 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
if (needTrans) {
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
"supported in the transposed mode");
basesLane.push_back({0, 8});
basesLane.push_back({8, 0});
if (nonKX2)
basesLane[3] = {0, 8};
if (kX2)
basesLane[4] = {8 * 16 / elemBitWidth, 0};
} else {
basesLane.push_back({8, 0});
basesLane.push_back({0, 8 * 16 / elemBitWidth});
if (nonKX2)
basesLane[3] = {8, 0};
if (kX2)
basesLane[4] = {0, 8 * 16 / elemBitWidth};
}
numTileCols = 16 * 16 / elemBitWidth;
} else {
// The matrix elements of thread 0 are distributed in the following pattern
// (fp16):
@@ -1147,14 +1154,20 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
if (needTrans) {
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
"supported in the transposed mode");
basesLane.push_back({8, 0});
basesLane.push_back({16, 0});
if (kX2)
basesLane[3] = {8, 0};
if (kX4)
basesLane[4] = {16, 0};
} else {
basesLane.push_back({0, 8 * 16 / elemBitWidth});
basesLane.push_back({0, 16 * 16 / elemBitWidth});
if (kX2)
basesLane[3] = {0, 8 * 16 / elemBitWidth};
if (kX4)
basesLane[4] = {0, 16 * 16 / elemBitWidth};
}
numTileCols = 32 * 16 / elemBitWidth;
}
int numTileCols =
(8 * 16 / elemBitWidth)
<< (static_cast<int>(kX2) + static_cast<int>(kX4 && opIdx == 1));
// Expand the `register` dimension so the size of columns matches `K`.
auto layout =
LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
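
To make the tile-selection arithmetic above concrete: a sub-tile loaded by ldmatrix is 8 rows of 8 x 16-bit elements, so `kX2`/`kX4` simply ask whether the operand's K extent spans two or four such sub-tiles, and `numTileCols` is the number of columns (K elements) covered by one ldmatrix tile. The standalone C++ sketch below is illustrative only (the `numTileCols` helper and the printed cases are not Triton code) and just reproduces that arithmetic:

```cpp
#include <cstdio>

// Sketch of the kX2/kX4/numTileCols arithmetic from chooseDotLdMatrixLayout.
// One ldmatrix sub-tile is 8 rows x (8 x 16-bit) columns, i.e. 8*16/elemBitWidth
// elements along K; kX2/kX4 ask whether the operand's K extent fills two or
// four sub-tiles per ldmatrix instruction.
int numTileCols(int shapeK, int elemBitWidth, int opIdx) {
  int subTileK = 8 * 16 / elemBitWidth;
  bool kX2 = shapeK > subTileK;
  bool kX4 = shapeK > 2 * subTileK;
  return subTileK << (static_cast<int>(kX2) +
                      static_cast<int>(kX4 && opIdx == 1));
}

int main() {
  // fp16 (16-bit elements), opIdx=1 (operand B):
  // K=8 -> 8 columns (x1), K=16 -> 16 (x2), K=32 -> 32 (x4).
  printf("%d %d %d\n", numTileCols(8, 16, 1), numTileCols(16, 16, 1),
         numTileCols(32, 16, 1));
  // fp8 (8-bit elements), opIdx=0 (operand A): one sub-tile already spans
  // 16 elements along K, and the kX4 step only applies to opIdx=1.
  printf("%d %d\n", numTileCols(16, 8, 0), numTileCols(32, 8, 0));
}
```

When the extra lane bases are not needed they stay `{0, 0}`, which is how the same construction degrades to the ldmatrix.x2 and .x1 forms this commit adds for small tiles.
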
15 changes: 10 additions & 5 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -862,8 +862,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
// CHECK: nvgpu.ldmatrix
// CHECK: nvgpu.ldmatrix
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
@@ -892,8 +893,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
// CHECK: nvgpu.ldmatrix
// CHECK: nvgpu.ldmatrix
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
@@ -974,7 +976,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
%AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
%BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
// CHECK: nvgpu.ldmatrix
// CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
// CHECK-NOT: nvgpu.ldmatrix
%AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
%BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
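
As a rough sanity check of the updated CHECK lines, here is a sketch under the assumption that each 8x(8x16-bit) tile returns one 32-bit register per lane; the `printPlan` helper is hypothetical and not part of the test. For the 16x16 f16 dot, operand A is covered by a single ldmatrix.x4, while operand B sits below the 32x16-bit K threshold introduced by this commit and is loaded with two transposed ldmatrix.x2 ops, hence one struct<(i32, i32, i32, i32)> and two struct<(i32, i32)> results.

```cpp
#include <cstdio>

// Hypothetical helper: derive the ldmatrix variant and op count for a 2-D
// dot operand, assuming each 8x(8x16-bit) tile yields one i32 per lane.
void printPlan(const char *name, int rows, int cols, int kExtent, int opIdx,
               int elemBits) {
  int tileCols = 8 * 16 / elemBits;                 // element columns per tile
  int totalTiles = (rows / 8) * (cols / tileCols);  // 8x(8x16-bit) tiles
  // This commit drops operand B to ldmatrix.x2 when K is below 32x16 bits.
  int tilesPerOp = (opIdx == 1 && kExtent < 32 * 16 / elemBits) ? 2 : 4;
  printf("%s: %d x ldmatrix.x%d -> struct of %d i32\n", name,
         totalTiles / tilesPerOp, tilesPerOp, tilesPerOp);
}

int main() {
  // The 16x16 f16 operands from @convert_dot_ldmatrix (K = 16 for both):
  printPlan("A", 16, 16, /*kExtent=*/16, /*opIdx=*/0, /*elemBits=*/16);
  printPlan("B", 16, 16, /*kExtent=*/16, /*opIdx=*/1, /*elemBits=*/16);
}
```
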
108 changes: 21 additions & 87 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp
@@ -52,98 +52,28 @@ struct LocalLoadOpConversion
auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
auto needTrans = kOrder != sharedEnc.getOrder()[0];
// Limitation 1: Cannot use ldmatrix if we need to transpose a non-fp16
// matrix
// Limitation 2: If kWidth is greater than the vector width of the dot
// operands of MMA, we don't use ldmatrix
// Limitation 3 [TODO: remove]: Shared memory with leading offset is not
// supported yet
auto canUseLdmatrixLegacy =
// Limitation 1 [TODO: remove]: Check LL bases to verify register and
// address alignment
auto canUseLdmatrix =
(kWidth == vecWidth) && (!sharedEnc.getHasLeadingOffset());
if (mmaEnc.isHopper()) {
// Limitation 4 [TODO: remove]:
// I think we should be able to remove this condition, but it's here
// as the legacy ldmatrix path does not support it
canUseLdmatrixLegacy &= srcTy.getElementTypeBitWidth() * kWidth == 32 &&
dotEnc.getOpIdx() == 0;
}
// Limitation 5: If we perform swizzling, it must be done within a single
// ldmatrix tile
auto maxPhase = sharedEnc.getMaxPhase();
auto perPhase = sharedEnc.getPerPhase();
auto vecSize = sharedEnc.getVec();
canUseLdmatrixLegacy &=
(maxPhase == 1) ||
((maxPhase / perPhase <= 8) && (vecSize * bitwidth >= 8 * 16));
canUseLdmatrix &= (sharedEnc.getMaxPhase() == 1) ||
(sharedEnc.getVec() * bitwidth >= 8 * 16);
auto shape = srcTy.getShape();
auto allocShape = srcTy.getAllocShape();
// Limitation 6 [TODO: remove]: Only support 2d matrices now but we should
// Limitation 2 [TODO: remove]: Only support 2d matrices now but we should
// be able to support 3D minor changes
auto canUseLdmatrixLL = (bitwidth <= 16 || (!needTrans)) &&
shape.size() <= 2 && canUseLdmatrixLegacy;
canUseLdmatrixLegacy &=
(bitwidth == 16 || (!needTrans)) && shape.size() <= 2;
if (dotEnc.getOpIdx() == 0) {
canUseLdmatrixLL &=
shape[kOrder] >= (16 * 16 / bitwidth) && shape[nonKOrder] >= 16;
} else {
// Limitation 8 [TODO: remove]: Due to the use of ldmatrix.x4, we need
// to read 4 tiles. For opIdx=1, a single warp load four consecutive
// tiles along the K dimension, so the minimum K size is 4 * 8 = 32.
// The legacy path doesn't have this limitation because it reads
// duplicated elements from shared memory and throw them away.
// It might be better to use ldmatrix.x2 in such a case instead of
// abandoning elements.
canUseLdmatrixLL &=
shape[kOrder] >= (32 * 16 / bitwidth) && shape[nonKOrder] >= 16;
}
// Limitation 9 [TODO: remove]:
// If we remove this one, ldmatrix will IMA. It can probably be relaxed
// though. Remove this constraint after all other limitations have been
// resolved
canUseLdmatrixLegacy &=
srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
if (canUseLdmatrixLL) {
canUseLdmatrix &= (bitwidth <= 16 || !needTrans) && shape.size() <= 2;
// Limitation 3: Minimum tile size (8)x(8x16bits)
canUseLdmatrix &=
shape[kOrder] >= (8 * 16 / bitwidth) && shape[nonKOrder] >= 8;
if (canUseLdmatrix) {
return lowerSharedToDotOperandLL(op, adaptor, getTypeConverter(),
rewriter);
} else if (canUseLdmatrixLegacy) {
return lowerSharedToDotOperandLegacy(op, adaptor, getTypeConverter(),
rewriter);
}
}
return failure();
}

private:
LogicalResult
lowerSharedToDotOperandLegacy(triton::gpu::LocalLoadOp op,
triton::gpu::LocalLoadOpAdaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto src = op.getSrc();
auto dstLayout = cast<DotOperandEncodingAttr>(op.getType().getEncoding());
auto mmaLayout = cast<NvidiaMmaEncodingAttr>(dstLayout.getParent());
auto llvmElemTy =
typeConverter->convertType(src.getType().getElementType());
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
llvmElemTy, rewriter);
Value res;
if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3
if (mmaLayout.isHopper())
assert(dstLayout.getOpIdx() == 0 &&
"Operand $b in MMAv3 can only be in shared memory");

res = SharedToDotOperandMMAv2OrV3::convertLayout(
dstLayout.getOpIdx(), rewriter, loc, src, dstLayout, smemObj,
typeConverter, getThreadId(rewriter, loc));
} else {
llvm_unreachable("Unsupported mma layout found");
}
rewriter.replaceOp(op, res);
return success();
}

LogicalResult
lowerSharedToDotOperandLL(triton::gpu::LocalLoadOp op,
triton::gpu::LocalLoadOpAdaptor adaptor,
@@ -158,6 +88,7 @@ struct LocalLoadOpConversion
auto shape = dstTy.getShape();
auto rank = dstTy.getRank();
auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
auto needTrans = kOrder != sharedEnc.getOrder()[0];

auto llvmElemTy = typeConverter->convertType(dstTy.getElementType());
@@ -169,22 +100,25 @@ struct LocalLoadOpConversion

// Emit ldmatrix load operations for values packed in i32s
SmallVector<Value> elemsI32;
// Typically we load 32x8 to use ldmatrix.x4, but the minimum tile size for
// opIdx=1 is 16x8. Therefore, we use ldmatrix.x2 instead of
// ldmatrix.x4 in this case.
auto shift = dotEnc.getOpIdx() == 1 && shape[kOrder] < (32 * 16 / bitwidth);
auto maxVecElems = 8 * 16 / bitwidth;
bool valid = emitTransferBetweenRegistersAndShared(
ldmatrixLayout, srcTy, llvmElemTy,
/*maxVecElems=*/maxVecElems, smemObj, loc, rewriter, targetInfo,
[&](VectorType vecTy, Value vecAddr) {
auto numElems = vecTy.getNumElements();
auto numElemsI32 = numElems * bitwidth / 32;
auto numElemsI32 = (numElems * bitwidth / 32) >> shift;
auto matTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsI32, i32_ty));
auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(
loc, matTy, vecAddr, /*needTrans=*/needTrans);
auto resV4 = ldMatrixOp.getResult();
elemsI32.push_back(extract_val(i32_ty, resV4, 0));
elemsI32.push_back(extract_val(i32_ty, resV4, 1));
elemsI32.push_back(extract_val(i32_ty, resV4, 2));
elemsI32.push_back(extract_val(i32_ty, resV4, 3));
auto res = ldMatrixOp.getResult();
for (auto i = 0; i < numElemsI32; ++i) {
elemsI32.push_back(extract_val(i32_ty, res, i));
}
});
assert(valid && "Failed to emit ldmatrix load operations");

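The net effect of the hunk above, sketched as plain C++ (illustrative only; `numI32PerLoad` is a hypothetical helper mirroring the diff's `shift`/`numElemsI32` expressions): when opIdx is 1 and K is smaller than the 32x(16-bit) footprint of ldmatrix.x4, each load packs half as many i32 results, and the new extraction loop reads exactly that many values instead of the previously hard-coded four.

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical helper mirroring the shift/numElemsI32 arithmetic above.
// numElems is the per-lane vector length chosen for the shared->register
// transfer (8 * 16 / bitwidth in the surrounding code).
int numI32PerLoad(int opIdx, int shapeK, int bitwidth, int numElems) {
  // Operand B with K below the ldmatrix.x4 footprint falls back to .x2,
  // halving the number of packed i32 results per load.
  bool shift = (opIdx == 1) && shapeK < (32 * 16 / bitwidth);
  return (numElems * bitwidth / 32) >> static_cast<int>(shift);
}

int main() {
  // fp16 operand B, 8 elements per lane: K = 32 keeps ldmatrix.x4 (4 i32s),
  // K = 16 takes the new ldmatrix.x2 path (2 i32s).
  assert(numI32PerLoad(1, /*shapeK=*/32, /*bitwidth=*/16, 8) == 4);
  assert(numI32PerLoad(1, /*shapeK=*/16, /*bitwidth=*/16, 8) == 2);
  printf("ok\n");
}
```

Extracting `numElemsI32` values rather than a fixed four is what keeps the same lambda correct for the .x2 and .x1 results.
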
