From 53e6e9e5d018d8d4a5a385095e08d5e71cfef65a Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Sun, 26 Jan 2025 11:06:33 -0500
Subject: [PATCH] [BACKEND] LL for ldmatrix part3 - ldmatrix.x2/x1 for small
 tiles (#5703)

---
 .../TritonGPU/IR/LinearLayoutConversions.cpp  |  37 ++++--
 test/Conversion/tritongpu_to_llvm.mlir        |  15 ++-
 .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp  | 108 ++++--------------
 3 files changed, 56 insertions(+), 104 deletions(-)

diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 1ee8d227cd0e..3abd7a114351 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -1101,6 +1101,7 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
   auto rank = shape.size();
   auto opIdx = dot.getOpIdx();
   int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
+  int nonKDim = (opIdx == 0) ? rank - 2 : rank - 1;
 
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
@@ -1117,8 +1118,11 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
     auto reg = 1 << logReg;
     basesReg.push_back({0, reg});
   }
-  std::vector<std::vector<int32_t>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
-  int numTileCols;
+  std::vector<std::vector<int32_t>> basesLane = {
+      {1, 0}, {2, 0}, {4, 0}, {0, 0}, {0, 0}};
+  bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth;
+  bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth;
+  bool nonKX2 = shape[nonKDim] > 8;
   // Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
   // efficiently. opIdx=0 and opIdx=1 are handled differently.
   if (opIdx == 0) {
@@ -1131,13 +1135,16 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
     if (needTrans) {
       assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
                                    "supported in the transposed mode");
-      basesLane.push_back({0, 8});
-      basesLane.push_back({8, 0});
+      if (nonKX2)
+        basesLane[3] = {0, 8};
+      if (kX2)
+        basesLane[4] = {8 * 16 / elemBitWidth, 0};
     } else {
-      basesLane.push_back({8, 0});
-      basesLane.push_back({0, 8 * 16 / elemBitWidth});
+      if (nonKX2)
+        basesLane[3] = {8, 0};
+      if (kX2)
+        basesLane[4] = {0, 8 * 16 / elemBitWidth};
     }
-    numTileCols = 16 * 16 / elemBitWidth;
   } else {
     // The matrix elements of thread 0 are distributed in the following pattern
     // (fp16):
@@ -1147,14 +1154,20 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
     if (needTrans) {
       assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
                                    "supported in the transposed mode");
-      basesLane.push_back({8, 0});
-      basesLane.push_back({16, 0});
+      if (kX2)
+        basesLane[3] = {8, 0};
+      if (kX4)
+        basesLane[4] = {16, 0};
     } else {
-      basesLane.push_back({0, 8 * 16 / elemBitWidth});
-      basesLane.push_back({0, 16 * 16 / elemBitWidth});
+      if (kX2)
+        basesLane[3] = {0, 8 * 16 / elemBitWidth};
+      if (kX4)
+        basesLane[4] = {0, 16 * 16 / elemBitWidth};
     }
-    numTileCols = 32 * 16 / elemBitWidth;
   }
+  int numTileCols =
+      (8 * 16 / elemBitWidth)
+      << (static_cast<int>(kX2) + static_cast<int>(kX4 && opIdx == 1));
   // Expand the `register` dimension so the size of columns matches `K`.
   auto layout =
       LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 2b5ae9597626..690420b832ee 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -862,8 +862,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
     // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
@@ -892,8 +893,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
    %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
     // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
@@ -974,7 +976,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   tt.func @convert_dot_fp8(%A: tensor<16x16xf8E5M2, #blocked0>, %B: tensor<16x16xf8E5M2, #blocked0>) {
    %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
    %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem>
-    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+    // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp
index b0f5d398d987..8fa340cb14d5 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp
@@ -52,98 +52,28 @@ struct LocalLoadOpConversion
       auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
       auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
       auto needTrans = kOrder != sharedEnc.getOrder()[0];
-      // Limitation 1: Cannot use ldmatrix if we need to transpose a non-fp16
-      // matrix
-      // Limitation 2: If kWidth is greater than the vector width of the dot
-      // operands of MMA, we don't use ldmatrix
-      // Limitation 3 [TODO: remove]: Shared memory with leading offset is not
-      // supported yet
-      auto canUseLdmatrixLegacy =
+      // Limitation 1 [TODO: remove]: Check LL bases to verify register and
+      // address alignment
+      auto canUseLdmatrix =
           (kWidth == vecWidth) && (!sharedEnc.getHasLeadingOffset());
-      if (mmaEnc.isHopper()) {
-        // Limitation 4 [TODO: remove]:
-        // I think we should be able to remove this condition, but it's here
-        // as the legacy ldmatrix path does not support it
-        canUseLdmatrixLegacy &= srcTy.getElementTypeBitWidth() * kWidth == 32 &&
-                                dotEnc.getOpIdx() == 0;
-      }
-      // Limitation 5: If we perform swizzling, it must be done within a single
-      // ldmatrix tile
-      auto maxPhase = sharedEnc.getMaxPhase();
-      auto perPhase = sharedEnc.getPerPhase();
-      auto vecSize = sharedEnc.getVec();
-      canUseLdmatrixLegacy &=
-          (maxPhase == 1) ||
-          ((maxPhase / perPhase <= 8) && (vecSize * bitwidth >= 8 * 16));
+      canUseLdmatrix &= (sharedEnc.getMaxPhase() == 1) ||
+                        (sharedEnc.getVec() * bitwidth >= 8 * 16);
       auto shape = srcTy.getShape();
-      auto allocShape = srcTy.getAllocShape();
-      // Limitation 6 [TODO: remove]: Only support 2d matrices now but we should
+      // Limitation 2 [TODO: remove]: Only support 2d matrices now but we should
       // be able to support 3D minor changes
-      auto canUseLdmatrixLL = (bitwidth <= 16 || (!needTrans)) &&
-                              shape.size() <= 2 && canUseLdmatrixLegacy;
-      canUseLdmatrixLegacy &=
-          (bitwidth == 16 || (!needTrans)) && shape.size() <= 2;
-      if (dotEnc.getOpIdx() == 0) {
-        canUseLdmatrixLL &=
-            shape[kOrder] >= (16 * 16 / bitwidth) && shape[nonKOrder] >= 16;
-      } else {
-        // Limitation 8 [TODO: remove]: Due to the use of ldmatrix.x4, we need
-        // to read 4 tiles. For opIdx=1, a single warp load four consecutive
-        // tiles along the K dimension, so the minimum K size is 4 * 8 = 32.
-        // The legacy path doesn't have this limitation because it reads
-        // duplicated elements from shared memory and throw them away.
-        // It might be better to use ldmatrix.x2 in such a case instead of
-        // abandoning elements.
-        canUseLdmatrixLL &=
-            shape[kOrder] >= (32 * 16 / bitwidth) && shape[nonKOrder] >= 16;
-      }
-      // Limitation 9 [TODO: remove]:
-      // If we remove this one, ldmatrix will IMA. It can probably be relaxed
-      // though. Remove this constraint after all other limitations have been
-      // resolved
-      canUseLdmatrixLegacy &=
-          srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
-      if (canUseLdmatrixLL) {
+      canUseLdmatrix &= (bitwidth <= 16 || !needTrans) && shape.size() <= 2;
+      // Limitation 3: Minimum tile size (8)x(8x16bits)
+      canUseLdmatrix &=
+          shape[kOrder] >= (8 * 16 / bitwidth) && shape[nonKOrder] >= 8;
+      if (canUseLdmatrix) {
        return lowerSharedToDotOperandLL(op, adaptor, getTypeConverter(),
                                         rewriter);
-      } else if (canUseLdmatrixLegacy) {
-        return lowerSharedToDotOperandLegacy(op, adaptor, getTypeConverter(),
-                                             rewriter);
      }
     }
     return failure();
   }
 
 private:
-  LogicalResult
-  lowerSharedToDotOperandLegacy(triton::gpu::LocalLoadOp op,
-                                triton::gpu::LocalLoadOpAdaptor adaptor,
-                                const LLVMTypeConverter *typeConverter,
-                                ConversionPatternRewriter &rewriter) const {
-    auto loc = op.getLoc();
-    auto src = op.getSrc();
-    auto dstLayout = cast<DotOperandEncodingAttr>(op.getType().getEncoding());
-    auto mmaLayout = cast<NvidiaMmaEncodingAttr>(dstLayout.getParent());
-    auto llvmElemTy =
-        typeConverter->convertType(src.getType().getElementType());
-    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
-                                                         llvmElemTy, rewriter);
-    Value res;
-    if (mmaLayout.isHopper() || mmaLayout.isAmpere()) { // tensor core v2 or v3
-      if (mmaLayout.isHopper())
-        assert(dstLayout.getOpIdx() == 0 &&
-               "Operand $b in MMAv3 can only be in shared memory");
-
-      res = SharedToDotOperandMMAv2OrV3::convertLayout(
-          dstLayout.getOpIdx(), rewriter, loc, src, dstLayout, smemObj,
-          typeConverter, getThreadId(rewriter, loc));
-    } else {
-      llvm_unreachable("Unsupported mma layout found");
-    }
-    rewriter.replaceOp(op, res);
-    return success();
-  }
-
   LogicalResult
   lowerSharedToDotOperandLL(triton::gpu::LocalLoadOp op,
                             triton::gpu::LocalLoadOpAdaptor adaptor,
                             const LLVMTypeConverter *typeConverter,
                             ConversionPatternRewriter &rewriter) const {
@@ -158,6 +88,7 @@ struct LocalLoadOpConversion
     auto shape = dstTy.getShape();
     auto rank = dstTy.getRank();
     auto kOrder = dotEnc.getOpIdx() == 0 ? rank - 1 : rank - 2;
+    auto nonKOrder = dotEnc.getOpIdx() == 0 ? rank - 2 : rank - 1;
     auto needTrans = kOrder != sharedEnc.getOrder()[0];
     auto llvmElemTy = typeConverter->convertType(dstTy.getElementType());
 
@@ -169,22 +100,25 @@ struct LocalLoadOpConversion
     // Emit ldmatrix load operations for values packed in i32s
     SmallVector<Value> elemsI32;
+    // Typically we load 32x8 to use ldmatrix.x4, but the minimum tile size for
+    // opIdx=1 is 16x8. Therefore, we use ldmatrix.x2 instead of
+    // ldmatrix.x4 in this case.
+    auto shift = dotEnc.getOpIdx() == 1 && shape[kOrder] < (32 * 16 / bitwidth);
     auto maxVecElems = 8 * 16 / bitwidth;
     bool valid = emitTransferBetweenRegistersAndShared(
         ldmatrixLayout, srcTy, llvmElemTy, /*maxVecElems=*/maxVecElems, smemObj,
         loc, rewriter, targetInfo, [&](VectorType vecTy, Value vecAddr) {
           auto numElems = vecTy.getNumElements();
-          auto numElemsI32 = numElems * bitwidth / 32;
+          auto numElemsI32 = (numElems * bitwidth / 32) >> shift;
           auto matTy = LLVM::LLVMStructType::getLiteral(
               ctx, SmallVector<Type>(numElemsI32, i32_ty));
           auto ldMatrixOp = rewriter.create<nvgpu::LoadMatrixOp>(
              loc, matTy, vecAddr, /*needTrans=*/needTrans);
-          auto resV4 = ldMatrixOp.getResult();
-          elemsI32.push_back(extract_val(i32_ty, resV4, 0));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 1));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 2));
-          elemsI32.push_back(extract_val(i32_ty, resV4, 3));
+          auto res = ldMatrixOp.getResult();
+          for (auto i = 0; i < numElemsI32; ++i) {
+            elemsI32.push_back(extract_val(i32_ty, res, i));
+          }
         });
    assert(valid && "Failed to emit ldmatrix load operations");
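
For readers skimming the diff: the new kX2/kX4/nonKX2 flags in chooseDotLdMatrixLayout and the opIdx=1 `shift` in lowerSharedToDotOperandLL reduce to a small calculation that picks between ldmatrix.x4 and ldmatrix.x2 for a given tile. The standalone C++ sketch below is not part of the patch; the helper name numI32PerLdmatrix and the assumption that each load is fully vectorized to maxVecElems are illustrative only, but the arithmetic mirrors the diff.

// Standalone sketch of the variant selection implied by the patch.
#include <cstdint>
#include <cstdio>

// Returns how many i32 results one ldmatrix produces (4 -> x4, 2 -> x2),
// assuming the transfer reaches maxVecElems = 8 * 16 / bitwidth elements.
int numI32PerLdmatrix(int64_t kSize, int elemBitWidth, int opIdx) {
  int numElems = 8 * 16 / elemBitWidth;
  // For opIdx=1, fall back from ldmatrix.x4 to ldmatrix.x2 when K is too
  // small to cover four consecutive 8-column sub-tiles along K.
  int shift = (opIdx == 1 && kSize < 32 * 16 / elemBitWidth) ? 1 : 0;
  return (numElems * elemBitWidth / 32) >> shift;
}

int main() {
  // 16x16 fp16 tile: operand A still uses ldmatrix.x4, operand B drops to x2.
  std::printf("fp16 A: %d i32s\n", numI32PerLdmatrix(16, 16, /*opIdx=*/0));
  std::printf("fp16 B: %d i32s\n", numI32PerLdmatrix(16, 16, /*opIdx=*/1));
  // 16x16 fp8E5M2 behaves the same way: x4 for A, x2 for B.
  std::printf("fp8  B: %d i32s\n", numI32PerLdmatrix(16, 8, /*opIdx=*/1));
  return 0;
}

For a 16x16 tile this prints 4 for operand A and 2 for operand B, which matches the `!llvm.struct<(i32, i32, i32, i32)>` and `!llvm.struct<(i32, i32)>` results the updated lit tests check for.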