diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 9cb4de97d744..5afa922665ef 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -175,7 +175,8 @@ struct MfmaInsnAttr { unsigned n; unsigned k; // k_base refers to the number of elements per thread - unsigned k_base; + unsigned k_base_a; + unsigned k_base_b; llvm::StringRef insn; }; @@ -223,7 +224,8 @@ class MfmaInsn { unsigned getMDim(); unsigned getNDim(); StringRef getInsnName(); - unsigned getKBase(); + unsigned getKBaseA(); + unsigned getKBaseB(); }; } // namespace mlir diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 6dbc10b943a6..51c7ed03b1a9 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { dotOperandLayout.getOpIdx() == 0 && dotOperandLayout.getKWidth() == 4 && dotOperandLayout.getParent() == mfmaLayout && - (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) && + (mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 || + (mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) && mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16()); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 86a4153603b2..77d5f6ca5160 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -158,14 +158,12 @@ llvm::SmallVector> computeTensorElemMappingInBlock( if (iNonKDim == 32) laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0); else { - // In this configuration wave contains 16 copies of same data - if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) { + // shortcut for 64x64 tile size. + // In this case warp do not wrap, so no need to introduce this offset + if (iNonKDim == 64) laneHOffset = i32_val(0); - } else { - assert(iKDim * iNonKDim / numOfElems == 64 && - "seems no all threads in wave contain unique elements"); + else laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems)); - } } for (int loadId = 0; loadId < loadsPerThread; ++loadId) { @@ -346,7 +344,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, // 32 33 34 35 ... 63 // 32 33 34 35 ... 63 Value halfOffset; - if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) + if (iNonKDim == 64) halfOffset = i32_val(0); else halfOffset = @@ -456,6 +454,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, int numSubBlocks = 1; if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4) numSubBlocks = 16; + assert(numSubBlocks == 1 && + "after reworking layout, there should be no redundency"); int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize; assert(numOfElems >= 1); diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp index 10bec3614969..15ed28f593be 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp @@ -60,16 +60,140 @@ struct DotOpMFMAConversionHelper { return rewriter.create(loc, i32_ty, tid); } + /** + * @param mfmaInsnName + * @param valA + * @param valB + * @param valC + * @param cbsz Control Broadcast Size modifier + * @param abid A-matrix Broadcast Identifier + * @param blgp B-matrix Lane Group Pattern modifier + */ Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB, - Value valC) const { + Value valC, int cbsz = 0, int abid = 0, + int blgp = 0) const { + assert(cbsz >= 0 && cbsz <= 4); + assert(abid >= 0 && abid <= 15); + assert(blgp >= 0 && blgp <= 7); auto resType = valC.getType(); - Value zeroFlag = i32_val(0); + Value zeroVal = i32_val(0); + Value cbszFlag = cbsz != 0 ? i32_val(cbsz) : zeroVal; + Value abidFlag = abid != 0 ? i32_val(abid) : zeroVal; + Value blgpFlag = blgp != 0 ? i32_val(blgp) : zeroVal; OperationState loweredOp(loc, mfmaInsnName); loweredOp.addTypes(resType); - loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag}); + loweredOp.addOperands({valA, valB, valC, cbszFlag, abidFlag, blgpFlag}); return rewriter.create(loweredOp)->getResult(0); } + Value getSubVector(Value vec, int numSubVectors, int subVecId) const { + auto groupVecType = vec.getType().cast(); + auto elemType = groupVecType.getElementType(); + auto totalElems = groupVecType.getNumElements(); + auto elemsPerRep = totalElems / numSubVectors; + VectorType repVecType = vec_ty(elemType, elemsPerRep); + Value repVec = undef(repVecType); + for (int i = 0; i < elemsPerRep; i++) { + Value elem = + extract_element(elemType, vec, i32_val(subVecId * elemsPerRep + i)); + repVec = insert_element(repVecType, repVec, elem, i32_val(i)); + } + return repVec; + } + + Value getRepetitionValue(Value vec, int repId) const { + auto groupVecType = vec.getType().cast(); + auto elemType = groupVecType.getElementType(); + if (elemType.getIntOrFloatBitWidth() == 16) { + Value elem = getSubVector(vec, 16, repId); + return elem; + } + auto totalElems = groupVecType.getNumElements(); + assert(repId < totalElems); + Value elem = extract_element(elemType, vec, i32_val(repId)); + return elem; + } + + Value broadcastGroup(Value val, int groupId, int numGroups) const { + constexpr int waveSize = 64; + const int groupSize = waveSize / numGroups; + + Value lane = getThreadId(); + // Multiply by 4, because permute requires offset in bytes + Value laneOffset = mul(urem(lane, i32_val(groupSize)), i32_val(4)); + Value permuteAddr = add(laneOffset, i32_val(groupId * groupSize * 4)); + Type valType = val.getType(); + Value broadcasted; + if (valType.isInteger(32)) + broadcasted = rewriter.create(loc, val.getType(), + permuteAddr, val); + if (valType.isF32()) { + val = bitcast(val, i32_ty); + broadcasted = rewriter.create(loc, val.getType(), + permuteAddr, val); + broadcasted = bitcast(broadcasted, f32_ty); + } + if (valType.isa()) { + auto vecTy = valType.dyn_cast(); + auto vecBitSize = vecTy.getElementType().getIntOrFloatBitWidth() * + vecTy.getNumElements(); + const int int32VecSize = vecBitSize / 32; + + Type int32VecTy = vec_ty(i32_ty, int32VecSize); + Value int32Val = bitcast(val, int32VecTy); + Value int32Broadcasted = undef(int32VecTy); + for (int i = 0; i < int32VecSize; ++i) { + Value int32Chunk = extract_element(i32_ty, int32Val, i32_val(i)); + Value broadcastedChunk = rewriter.create( + loc, i32_ty, permuteAddr, int32Chunk); + int32Broadcasted = insert_element(int32VecTy, int32Broadcasted, + broadcastedChunk, i32_val(i)); + } + broadcasted = bitcast(int32Broadcasted, valType); + } + assert(broadcasted); + return broadcasted; + } + + Value generateMFMATile(StringRef mfmaInsnName, Value valA, Value valB, + Value valC, int mDim, int nDim, bool transpose) const { + + Value acc; + if (mDim == nDim) { + acc = transpose ? generateMFMAOp(mfmaInsnName, valB, valA, valC) + : generateMFMAOp(mfmaInsnName, valA, valB, valC); + } + if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) { + // broadcast selected kRep A operand matrix to all A matrices(2^4=16) + constexpr int broadcastCtrl = 4; + constexpr int numRepeats = 16; + acc = valC; + for (int kRep = 0; kRep < numRepeats; kRep++) { + if (mDim == 4 && !transpose) { + Value repVec = getRepetitionValue(valB, kRep); + acc = generateMFMAOp(mfmaInsnName, valA, repVec, acc, broadcastCtrl, + kRep); + } + if (mDim == 4 && transpose) { + Value repValB = getRepetitionValue(valB, kRep); + Value broadcastValA = broadcastGroup(valA, kRep, numRepeats); + acc = generateMFMAOp(mfmaInsnName, repValB, broadcastValA, acc); + } + if (nDim == 4 && !transpose) { + Value repValA = getRepetitionValue(valA, kRep); + Value broadcastValB = broadcastGroup(valB, kRep, numRepeats); + acc = generateMFMAOp(mfmaInsnName, repValA, broadcastValB, acc); + } + if (nDim == 4 && transpose) { + Value repVec = getRepetitionValue(valA, kRep); + acc = generateMFMAOp(mfmaInsnName, valB, repVec, acc, broadcastCtrl, + kRep); + } + } + } + return acc; + } + int getNumSubmatrices(Type elementType, int mDim, int nDim) const { if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) return 1; @@ -187,13 +311,14 @@ struct DotOpMFMAConversionHelper { llvm::report_fatal_error("No match found in MFMA database\n"); mfmaInsnName = (*maybeMfmaInsn).getInsnName(); - unsigned k_base = (*maybeMfmaInsn).getKBase(); + unsigned kBaseA = (*maybeMfmaInsn).getKBaseA(); + unsigned kBaseB = (*maybeMfmaInsn).getKBaseB(); auto aEncoding = aTensorTy.getEncoding().cast(); auto bEncoding = bTensorTy.getEncoding().cast(); - auto kWidth = aEncoding.getKWidth(); - assert(kWidth == bEncoding.getKWidth()); + auto kWidthA = aEncoding.getKWidth(); + auto kWidthB = bEncoding.getKWidth(); auto repA = aEncoding.getMFMARep(aTensorTy.getShape()); auto repB = bEncoding.getMFMARep(bTensorTy.getShape()); @@ -209,9 +334,9 @@ struct DotOpMFMAConversionHelper { auto numRepK = repA[1]; auto operandA = getValuesFromDotOperandLayoutStruct( - loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType()); + loadedA, numRepM, numRepK, kWidthA, kBaseA, aTensorTy.getElementType()); auto operandB = getValuesFromDotOperandLayoutStruct( - loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType()); + loadedB, numRepN, numRepK, kWidthB, kBaseB, aTensorTy.getElementType()); auto dstElemTy = dTensorTy.getElementType(); auto fc = @@ -236,12 +361,10 @@ struct DotOpMFMAConversionHelper { acc = zeroAuxiliarBlocks(subBlocks, acc); for (size_t k = 0; k < numRepK; k++) - for (int kpack = 0; kpack < kWidth / k_base; ++kpack) - acc = mfmaLayout.getIsTransposed() - ? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}], - operandA[kpack][{m, k}], acc) - : generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}], - operandB[kpack][{n, k}], acc); + for (int kpack = 0; kpack < kWidthA / kBaseA; ++kpack) + acc = generateMFMATile(mfmaInsnName, operandA[kpack][{m, k}], + operandB[kpack][{n, k}], acc, mDim, nDim, + mfmaLayout.getIsTransposed()); acc = reduceSubBlocks(subBlocks, acc); for (unsigned v = 0; v < elemsPerVec; ++v) { fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] = @@ -276,12 +399,29 @@ struct DotOpMFMAConversionHelper { extract_element(type, rawElems, i32_val(elemId + k * k_base)); vec = insert_element(vecTy, vec, val, i32_val(elemId)); } + // if (64 == k_base) { + // constexpr int numRepeats = 16; + // const int oneOpKWidth = k_base / numRepeats; + // assert(oneOpKWidth == 4); + // auto repVecTy = vec_ty(type, oneOpKWidth); + // auto operandVecTy = vec_ty(repVecTy, numRepeats); + // results.push_back(bitcast(vec, operandVecTy)); + // } if (type.getIntOrFloatBitWidth() == 8) { if (4 == k_base) // This is for int8 on pre- MI300 GPUs results.push_back(bitcast(vec, i32_ty)); if (8 == k_base) results.push_back(bitcast(vec, i64_ty)); + // In this case one tile is processed by sevelar instructions + // repack flat vector into vector of vectors + if (64 == k_base) { + constexpr int numRepeats = 16; + assert(k_base / numRepeats == 4); + auto repVecTy = i32_ty; + auto operandVecTy = vec_ty(repVecTy, numRepeats); + results.push_back(bitcast(vec, operandVecTy)); + } } else results.push_back(vec); } @@ -305,8 +445,14 @@ struct DotOpMFMAConversionHelper { auto rawElems = elems[n1 * i + j]; if (type.isF32()) { - for (int k = 0; k < kpack; ++k) { - dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k)); + if (k_base == 16) { + for (int k = 0; k < kpack; ++k) + dotOpVals[k][{i, j}] = getSubVector(rawElems, kpack, k); + } else { + for (int k = 0; k < kpack; ++k) { + dotOpVals[k][{i, j}] = + extract_element(type, rawElems, i32_val(k)); + } } } else { SmallVector vals; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 87e8bb218bc9..48171dc43822 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -304,12 +304,17 @@ SmallVector getSizePerThread(Attribute layout) { llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; } - } else if (parentLayout.isa()) { + } else if (auto mfmaLayout = parentLayout.dyn_cast()) { auto opIdx = dotLayout.getOpIdx(); + auto kWidth = dotLayout.getKWidth(); if (opIdx == 0) { - return {4, 1}; + int repeats = + (mfmaLayout.getMDim() == 64 && mfmaLayout.getNDim() == 4) ? 16 : 1; + return {1, kWidth * repeats}; } else if (opIdx == 1) { - return {1, 4}; + int repeats = + (mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64) ? 16 : 1; + return {kWidth * repeats, 1}; } else { assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1"); return {}; @@ -458,6 +463,8 @@ SmallVector getShapePerCTATile(Attribute layout, auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape); auto opIdx = dotLayout.getOpIdx(); + assert(parentMfmaLayout.getMDim() == 32); + if (opIdx == 0) { return {parentShapePerCTA[0], 32}; } else if (opIdx == 1) { @@ -1102,16 +1109,13 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const { (mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64)); int64_t kWidth = getKWidth(); constexpr int waveSize = 64; // MFMA is used on wave64 architectures only - int kGroups = -1; - if (mDim == nDim) - kGroups = waveSize / mDim; - if (mDim == 64 && nDim == 4 || mDim == 4 && nDim == 64) - kGroups = 1; + auto nonKDim = getOpIdx() == 0 ? mDim : nDim; + int kGroups = waveSize / nonKDim; int64_t kDim = kWidth * kGroups; if (getOpIdx() == 0) - return {mDim, kDim}; + return {nonKDim, kDim}; else - return {kDim, nDim}; + return {kDim, nonKDim}; } SmallVector @@ -1902,6 +1906,18 @@ struct TritonGPUInferLayoutInterface // Verify that the encodings are valid. if (!aEncoding || !bEncoding) return op->emitError("mismatching encoding between A and B operands"); +#ifdef USE_ROCM + auto aParentEncoding = + aEncoding.getParent().dyn_cast_or_null(); + auto bParentEncoding = + bEncoding.getParent().dyn_cast_or_null(); + if (aParentEncoding != bParentEncoding) + return op->emitError( + "mismatching parent encoding between A and B operands"); + if (aParentEncoding != nullptr && + aParentEncoding.getMDim() != aParentEncoding.getNDim()) + return success(); +#endif // USE_ROCM if (aEncoding.getKWidth() != bEncoding.getKWidth()) return op->emitError("mismatching kWidth between A and B operands"); return success(); diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index 3f39248597bd..32303ca748bc 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -158,9 +158,8 @@ class BlockedToMFMA : public mlir::RewritePattern { /// @brief Choose MFMA instruction parameters /// @param dot target dot operation - /// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments - std::tuple - chooseMfmaDimensions(tt::DotOp dot) const { + /// @return selected mfma instruction + MfmaInsn chooseMfmaDimensions(tt::DotOp dot) const { // number of matrix elements along k dim per one MFMA intruction unsigned kDim = 0; auto opType = dot.getA().getType().cast(); @@ -200,6 +199,8 @@ class BlockedToMFMA : public mlir::RewritePattern { nDim = 16; } if (minSize < 16) { + assert(opType.getShape()[1] >= 64 && + "k should be at least 64 to use this layout"); if (resShape[0] < 16 && resShape[1] >= 64) { mDim = 4; nDim = 64; @@ -207,8 +208,6 @@ class BlockedToMFMA : public mlir::RewritePattern { mDim = 64; nDim = 4; } else { - assert(opType.getShape()[1] >= 64 && - "k should be at least 64 to use this layout"); mDim = 4; nDim = 4; } @@ -227,7 +226,7 @@ class BlockedToMFMA : public mlir::RewritePattern { assert(mDim != 0 && nDim != 0); assert(resShape[0] % mDim == 0 && resShape[1] % nDim == 0); assert(opType.getShape()[1] % kDim == 0); - return {mDim, nDim, kDim}; + return maybeMfmaInsn.value(); } mlir::LogicalResult @@ -259,7 +258,10 @@ class BlockedToMFMA : public mlir::RewritePattern { ttg::MfmaEncodingAttr mfmaEnc; - auto [mDim, nDim, kDim] = chooseMfmaDimensions(dotOp); + auto instr = chooseMfmaDimensions(dotOp); + auto mDim = instr.getMDim(); + auto nDim = instr.getNDim(); + auto kDim = instr.getKDim(); auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); @@ -290,33 +292,24 @@ class BlockedToMFMA : public mlir::RewritePattern { // kWidth is initialized as k_base, which is the number of elements hold by // one thread per mfma instruction - auto kWidth = -1; - // in mfma 32x32 case argument matrix groups elements in 2 groups - // in mfma 16x16 case argument matrix groups elements in 4 groups - // in mfma 4x4 case argument matrix groups in 16 groups - if (mDim == 32 && nDim == 32) - kWidth = kDim / 2; - if (mDim == 16 && nDim == 16) - kWidth = kDim / 4; - if (mDim == 4 && nDim == 4) - kWidth = kDim / 16; - if (mDim == 4 && nDim == 64 || mDim == 64 && nDim == 4) - kWidth = kDim; - assert(kWidth != -1); + auto kWidthA = instr.getKBaseA(); + auto kWidthB = instr.getKBaseB(); // We want to extend kWidth by kpack (kpack=1 means no extension) // to increase ds_read vector size // However, in FA, the second dot can only use kWidth = k_bse since it's // limited by the result of the first dot, which is of mfmaLayout. - if (!isSecondDot(dotOp)) - kWidth *= kpack; + if (!isSecondDot(dotOp)) { + kWidthA *= kpack; + kWidthB *= kpack; + } auto newAType = RankedTensorType::get( oldAType.getShape(), oldAType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth)); + ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidthA)); auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), - ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidth)); + ttg::DotOperandEncodingAttr::get(ctx, 1, mfmaEnc, kWidthB)); a = rewriter.create(a.getLoc(), newAType, a); b = rewriter.create(b.getLoc(), newBType, b); auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 23f1befd2617..5a5046b1f8f0 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -669,173 +669,183 @@ using MfmaInsnGroupMap = llvm::DenseMap const MfmaInsnGroupMap & { static MfmaInsnGroupMap MfmaInsnMap{ + // MFMA tile description: + // M N K k_base_a k_base_b instr_name // f32 // mfma_f32_32x32x2f32 {{32, 32, MfmaTypeId::Fp32TyId, 1}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, {{32, 32, MfmaTypeId::Fp32TyId, 2}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, {{32, 32, MfmaTypeId::Fp32TyId, 3}, - {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, + {32, 32, 2, 1, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}}, // mfma_f32_16x16x4f32 {{16, 16, MfmaTypeId::Fp32TyId, 1}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, {{16, 16, MfmaTypeId::Fp32TyId, 2}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, {{16, 16, MfmaTypeId::Fp32TyId, 3}, - {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, + {16, 16, 4, 1, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}}, // mfma_f32_4x4x1f32 {{4, 4, MfmaTypeId::Fp32TyId, 1}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 4, MfmaTypeId::Fp32TyId, 2}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 1}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 2}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 1}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 2}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, // mfma_f32_4x4x1_16B_f32 {{4, 4, MfmaTypeId::Fp32TyId, 3}, - {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 4, 16, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{4, 64, MfmaTypeId::Fp32TyId, 3}, - {4, 64, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {4, 64, 16, 1, 16, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, {{64, 4, MfmaTypeId::Fp32TyId, 3}, - {64, 4, 1, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, + {64, 4, 16, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}}, // f16 // mfma_f32_32x32x8f16 {{32, 32, MfmaTypeId::Fp16TyId, 1}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, {{32, 32, MfmaTypeId::Fp16TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, {{32, 32, MfmaTypeId::Fp16TyId, 3}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}}, // mfma_f32_16x16x16xf16 {{16, 16, MfmaTypeId::Fp16TyId, 1}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, {{16, 16, MfmaTypeId::Fp16TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, {{16, 16, MfmaTypeId::Fp16TyId, 3}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}}, // mfma_f32_4x4x4f16 {{4, 4, MfmaTypeId::Fp16TyId, 1}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 4, MfmaTypeId::Fp16TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 4, MfmaTypeId::Fp16TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 1}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{4, 64, MfmaTypeId::Fp16TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 1}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, {{64, 4, MfmaTypeId::Fp16TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}}, // bf16 // mfma_f32_32x32x4_bf16 {{32, 32, MfmaTypeId::Bf16TyId, 1}, - {32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}}, + {32, 32, 4, 2, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}}, // mfma_f32_32x32x8_bf16_1K {{32, 32, MfmaTypeId::Bf16TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, {{32, 32, MfmaTypeId::Bf16TyId, 3}, - {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}}, // mfma_f32_16x16x8_bf16 {{16, 16, MfmaTypeId::Bf16TyId, 1}, - {16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}}, + {16, 16, 8, 2, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}}, // mfma_f32_16x16x16_bf16_1K {{16, 16, MfmaTypeId::Bf16TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, {{16, 16, MfmaTypeId::Bf16TyId, 3}, - {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}}, // mfma_f32_4x4x2_bf16 {{4, 4, MfmaTypeId::Bf16TyId, 1}, - {4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {4, 4, 32, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 1}, - {4, 64, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {4, 64, 32, 2, 32, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 1}, - {64, 4, 2, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, + {64, 4, 32, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}}, // mfma_f32_4x4x4_bf16_1K {{4, 4, MfmaTypeId::Bf16TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 4, MfmaTypeId::Bf16TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{4, 64, MfmaTypeId::Bf16TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, {{64, 4, MfmaTypeId::Bf16TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}}, // int8 // mfma_i32_32x32x8i8 {{32, 32, MfmaTypeId::I8TyId, 1}, - {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, {{32, 32, MfmaTypeId::I8TyId, 2}, - {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, + {32, 32, 8, 4, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}}, // mfma_i32_32x32x16i8 {{32, 32, MfmaTypeId::I8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}}, + {32, 32, 16, 8, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}}, // mfma_i32_16x16x16i8 {{16, 16, MfmaTypeId::I8TyId, 1}, - {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, {{16, 16, MfmaTypeId::I8TyId, 2}, - {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, + {16, 16, 16, 4, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}}, // mfma_i32_16x16x32i8 {{16, 16, MfmaTypeId::I8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}}, + {16, 16, 32, 8, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}}, // mfma_i32_4x4x4i8 {{4, 4, MfmaTypeId::I8TyId, 1}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 4, MfmaTypeId::I8TyId, 2}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 4, MfmaTypeId::I8TyId, 3}, - {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 1}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 2}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{4, 64, MfmaTypeId::I8TyId, 3}, - {4, 64, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {4, 64, 64, 4, 64, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 1}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 2}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, {{64, 4, MfmaTypeId::I8TyId, 3}, - {64, 4, 4, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, + {64, 4, 64, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}}, // fp8 * pf8 // mfma_f32_32x32x16_FP8_FP8 {{32, 32, MfmaTypeId::Fp8Fp8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}}, // mfma_f32_16x16x32_FP8_FP8 {{16, 16, MfmaTypeId::Fp8Fp8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}}, // mfma_f32_32x32x16_FP8_BF8 {{32, 32, MfmaTypeId::Fp8Bf8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}}, // mfma_f32_16x16x32_FP8_BF8 {{16, 16, MfmaTypeId::Fp8Bf8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}}, // mfma_f32_32x32x16_BF8_FP8 {{32, 32, MfmaTypeId::Bf8Fp8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}}, // mfma_f32_16x16x32_BF8_FP8 {{16, 16, MfmaTypeId::Bf8Fp8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}}, + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}}, // mfma_f32_32x32x16_BF8_BF8 {{32, 32, MfmaTypeId::Bf8Bf8TyId, 3}, - {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}}, + {32, 32, 16, 8, 8, + ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}}, // mfma_f32_16x16x32_BF8_BF8 {{16, 16, MfmaTypeId::Bf8Bf8TyId, 3}, - {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}}; + {16, 16, 32, 8, 8, + ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}}; return MfmaInsnMap; }; @@ -859,6 +869,7 @@ unsigned MfmaInsn::getKDim() { return attr.k; } unsigned MfmaInsn::getMDim() { return attr.m; } unsigned MfmaInsn::getNDim() { return attr.n; } StringRef MfmaInsn::getInsnName() { return attr.insn; } -unsigned MfmaInsn::getKBase() { return attr.k_base;} +unsigned MfmaInsn::getKBaseA() { return attr.k_base_a; } +unsigned MfmaInsn::getKBaseB() { return attr.k_base_b; } } // namespace mlir diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 0a451d539453..d991b689c6c2 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1706,8 +1706,9 @@ def kernel(X, stride_xm, stride_xn, [4, 32, 64, 4], [32, 4, 64, 2], [16, 4, 64, 8], - [64, 4, 16, 1], - [4, 64, 16, 1], + [64, 4, 64, 1], + [4, 64, 64, 1], + [4, 64, 64, 4], ] for allow_tf32 in [False, True] for col_a in [True, False]