[MFMA] MFMA 4x64 64x4 version 2
Extend K dimension of mfma4x64 and mfma64x4 dot operand layout from 4 to 64.
binarman committed Mar 18, 2024
1 parent 009215d commit 4c1bbc9
Showing 8 changed files with 301 additions and 131 deletions.
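The core of the change: for the mfma 4x64 / 64x4 dot-operand layouts, the K extent of one instruction tile now scales with the number of lane groups stacked along K instead of being pinned to 4. A minimal standalone sketch mirroring the new getMFMAElemsPerInstr logic in lib/Dialect/TritonGPU/IR/Dialect.cpp below; all names here are local to the sketch, not Triton API.

#include <cstdio>
#include <utility>

// Per-instruction operand shape for an MFMA dot-operand layout.
std::pair<long, long> elemsPerInstr(int mDim, int nDim, int kWidth, int opIdx) {
  constexpr int waveSize = 64;            // MFMA is wave64-only
  int nonKDim = opIdx == 0 ? mDim : nDim; // this operand's non-K extent
  int kGroups = waveSize / nonKDim;       // lane groups stacked along K
  long kDim = (long)kWidth * kGroups;
  return opIdx == 0 ? std::pair<long, long>(nonKDim, kDim)
                    : std::pair<long, long>(kDim, nonKDim);
}

int main() {
  // A operand of mfma 4x64 with kWidth = 4: K now spans 4 * 16 = 64.
  // Before this commit kGroups was forced to 1 for 4x64/64x4, i.e. K = 4.
  auto [m, k] = elemsPerInstr(4, 64, 4, 0);
  std::printf("A tile: %ldx%ld\n", m, k); // prints "A tile: 4x64"
}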
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -175,7 +175,8 @@ struct MfmaInsnAttr {
unsigned n;
unsigned k;
// k_base refers to the number of elements per thread
unsigned k_base;
unsigned k_base_a;
unsigned k_base_b;
llvm::StringRef insn;
};

@@ -223,7 +224,8 @@ class MfmaInsn {
unsigned getMDim();
unsigned getNDim();
StringRef getInsnName();
unsigned getKBase();
unsigned getKBaseA();
unsigned getKBaseB();
};
} // namespace mlir

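Why k_base splits into k_base_a and k_base_b: the non-square 4x64/64x4 instructions consume a different number of elements per lane on each operand. A hedged sketch of plausible values, inferred from the MFMA.cpp lowering below; the struct mirrors MfmaInsnAttr and the instruction name is abbreviated, not an exact ROCDL spelling.

struct MfmaInsnAttrSketch {
  unsigned m, n, k;
  unsigned k_base_a, k_base_b;
  const char *insn;
};

// fp16 mfma 4x64: a K=64 tile is issued as 16 chained 4x4x4 MFMAs. The
// 64-wide B operand supplies a fresh 4-element repetition to each call,
// 64 elements per lane in total; the 4-wide A operand holds only 4,
// reused across calls through the broadcast modifiers (cbsz/abid).
const MfmaInsnAttrSketch mfma4x64F16{4, 64, 64, /*k_base_a=*/4,
                                     /*k_base_b=*/64, "mfma.4x4x4f16"};
// fp16 mfma 64x4: the operand roles swap.
const MfmaInsnAttrSketch mfma64x4F16{64, 4, 64, /*k_base_a=*/64,
                                     /*k_base_b=*/4, "mfma.4x4x4f16"};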
3 changes: 2 additions & 1 deletion lib/Analysis/Utility.cpp
@@ -571,7 +571,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16 ||
(mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64)) &&
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
@@ -158,14 +158,12 @@ llvm::SmallVector<llvm::SmallVector<Value>> computeTensorElemMappingInBlock(
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else {
// In this configuration the wave contains 16 copies of the same data
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) {
// Shortcut for the 64x64 tile size:
// in this case the warp does not wrap, so there is no need to introduce this offset
if (iNonKDim == 64)
laneHOffset = i32_val(0);
} else {
assert(iKDim * iNonKDim / numOfElems == 64 &&
"seems no all threads in wave contain unique elements");
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
}
}

for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
@@ -346,7 +344,7 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
// 32 33 34 35 ... 63
// 32 33 34 35 ... 63
Value halfOffset;
if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4)
if (iNonKDim == 64)
halfOffset = i32_val(0);
else
halfOffset =
@@ -456,6 +454,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int numSubBlocks = 1;
if ((mfmaInstrK == 4 || mfmaInstrK == 1) && mfmaInstrNonK == 4)
numSubBlocks = 16;
assert(numSubBlocks == 1 &&
"after reworking layout, there should be no redundency");
int numOfElems = mfmaInstrNonK * mfmaInstrK * numSubBlocks / iWaveSize;
assert(numOfElems >= 1);

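A standalone model of the new lane-offset rule used by computeTensorElemMappingInBlock and fastPathComputeOffsets above (simplified; names are local to this sketch):

#include <cstdint>
#include <cstdio>

// Offset of a lane along the tile edge: a wave64 covers a 32-wide edge
// twice, a 64-wide edge exactly once (so no offset is needed), and for
// narrower edges waveSize/nonKDim lane groups stack along K.
uint32_t laneHOffset(uint32_t laneId, uint32_t nonKDim, uint32_t numOfElems) {
  if (nonKDim == 32)
    return laneId >= 32 ? numOfElems : 0;
  if (nonKDim == 64)
    return 0;
  return (laneId / nonKDim) * numOfElems;
}

int main() {
  // mfma 4x64 dot operand: 16 groups of 4 lanes stack along K.
  for (uint32_t lane : {0u, 3u, 4u, 63u})
    std::printf("lane %2u -> offset %u\n", lane, laneHOffset(lane, 4, 4));
}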
178 changes: 162 additions & 16 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -60,16 +60,140 @@ struct DotOpMFMAConversionHelper {
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}

/**
* Emit a single MFMA intrinsic call, optionally with broadcast modifiers.
*
* @param mfmaInsnName name of the MFMA intrinsic to emit
* @param valA A-operand value for the instruction
* @param valB B-operand value for the instruction
* @param valC accumulator operand
* @param cbsz Control Broadcast Size modifier
* @param abid A-matrix Broadcast Identifier
* @param blgp B-matrix Lane Group Pattern modifier
*/
Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
Value valC) const {
Value valC, int cbsz = 0, int abid = 0,
int blgp = 0) const {
assert(cbsz >= 0 && cbsz <= 4);
assert(abid >= 0 && abid <= 15);
assert(blgp >= 0 && blgp <= 7);
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
Value zeroVal = i32_val(0);
Value cbszFlag = cbsz != 0 ? i32_val(cbsz) : zeroVal;
Value abidFlag = abid != 0 ? i32_val(abid) : zeroVal;
Value blgpFlag = blgp != 0 ? i32_val(blgp) : zeroVal;
OperationState loweredOp(loc, mfmaInsnName);
loweredOp.addTypes(resType);
loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
loweredOp.addOperands({valA, valB, valC, cbszFlag, abidFlag, blgpFlag});
return rewriter.create(loweredOp)->getResult(0);
}
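// Illustrative use (not part of this diff): in the 4x64 path below,
// generateMFMAOp(insn, valA, repB, acc, /*cbsz=*/4, /*abid=*/kRep) asks
// the hardware to broadcast A sub-matrix kRep to all 2^4 = 16 blocks of
// one chained 4x4x4 instruction.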

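// Extract sub-vector subVecId from vec, viewing vec as numSubVectors
// equal slices (e.g. one 4-element K-repetition of a 16-repetition operand).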
Value getSubVector(Value vec, int numSubVectors, int subVecId) const {
auto groupVecType = vec.getType().cast<VectorType>();
auto elemType = groupVecType.getElementType();
auto totalElems = groupVecType.getNumElements();
auto elemsPerRep = totalElems / numSubVectors;
VectorType repVecType = vec_ty(elemType, elemsPerRep);
Value repVec = undef(repVecType);
for (int i = 0; i < elemsPerRep; i++) {
Value elem =
extract_element(elemType, vec, i32_val(subVecId * elemsPerRep + i));
repVec = insert_element(repVecType, repVec, elem, i32_val(i));
}
return repVec;
}

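// Get the repId-th K-repetition held by this lane: a packed sub-vector
// for 16-bit element types, a single scalar element otherwise.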
Value getRepetitionValue(Value vec, int repId) const {
auto groupVecType = vec.getType().cast<VectorType>();
auto elemType = groupVecType.getElementType();
if (elemType.getIntOrFloatBitWidth() == 16) {
Value elem = getSubVector(vec, 16, repId);
return elem;
}
auto totalElems = groupVecType.getNumElements();
assert(repId < totalElems);
Value elem = extract_element(elemType, vec, i32_val(repId));
return elem;
}

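// Broadcast the value held by lane group groupId to all lanes of the
// wave64 via ds_bpermute; lane addresses are multiplied by 4 because the
// permute network is byte-addressed.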
Value broadcastGroup(Value val, int groupId, int numGroups) const {
constexpr int waveSize = 64;
const int groupSize = waveSize / numGroups;

Value lane = getThreadId();
// Multiply by 4, because permute requires offset in bytes
Value laneOffset = mul(urem(lane, i32_val(groupSize)), i32_val(4));
Value permuteAddr = add(laneOffset, i32_val(groupId * groupSize * 4));
Type valType = val.getType();
Value broadcasted;
if (valType.isInteger(32))
broadcasted = rewriter.create<ROCDL::DsBpermuteOp>(loc, val.getType(),
permuteAddr, val);
if (valType.isF32()) {
val = bitcast(val, i32_ty);
broadcasted = rewriter.create<ROCDL::DsBpermuteOp>(loc, val.getType(),
permuteAddr, val);
broadcasted = bitcast(broadcasted, f32_ty);
}
if (valType.isa<VectorType>()) {
auto vecTy = valType.dyn_cast<VectorType>();
auto vecBitSize = vecTy.getElementType().getIntOrFloatBitWidth() *
vecTy.getNumElements();
const int int32VecSize = vecBitSize / 32;

Type int32VecTy = vec_ty(i32_ty, int32VecSize);
Value int32Val = bitcast(val, int32VecTy);
Value int32Broadcasted = undef(int32VecTy);
for (int i = 0; i < int32VecSize; ++i) {
Value int32Chunk = extract_element(i32_ty, int32Val, i32_val(i));
Value broadcastedChunk = rewriter.create<ROCDL::DsBpermuteOp>(
loc, i32_ty, permuteAddr, int32Chunk);
int32Broadcasted = insert_element(int32VecTy, int32Broadcasted,
broadcastedChunk, i32_val(i));
}
broadcasted = bitcast(int32Broadcasted, valType);
}
assert(broadcasted);
return broadcasted;
}

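// Emit all MFMA instructions covering one logical tile. Square tiles take
// a single instruction; 4x64/64x4 tiles take a chain of 16 4x4x4
// instructions, one per K-repetition, broadcasting the narrow operand
// either in hardware (cbsz/abid) or explicitly via ds_bpermute.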
Value generateMFMATile(StringRef mfmaInsnName, Value valA, Value valB,
Value valC, int mDim, int nDim, bool transpose) const {

Value acc;
if (mDim == nDim) {
acc = transpose ? generateMFMAOp(mfmaInsnName, valB, valA, valC)
: generateMFMAOp(mfmaInsnName, valA, valB, valC);
}
if ((mDim == 4 && nDim == 64) || (mDim == 64 && nDim == 4)) {
// broadcast the selected kRep A operand sub-matrix to all A sub-matrices (2^4 = 16)
constexpr int broadcastCtrl = 4;
constexpr int numRepeats = 16;
acc = valC;
for (int kRep = 0; kRep < numRepeats; kRep++) {
if (mDim == 4 && !transpose) {
Value repVec = getRepetitionValue(valB, kRep);
acc = generateMFMAOp(mfmaInsnName, valA, repVec, acc, broadcastCtrl,
kRep);
}
if (mDim == 4 && transpose) {
Value repValB = getRepetitionValue(valB, kRep);
Value broadcastValA = broadcastGroup(valA, kRep, numRepeats);
acc = generateMFMAOp(mfmaInsnName, repValB, broadcastValA, acc);
}
if (nDim == 4 && !transpose) {
Value repValA = getRepetitionValue(valA, kRep);
Value broadcastValB = broadcastGroup(valB, kRep, numRepeats);
acc = generateMFMAOp(mfmaInsnName, repValA, broadcastValB, acc);
}
if (nDim == 4 && transpose) {
Value repVec = getRepetitionValue(valA, kRep);
acc = generateMFMAOp(mfmaInsnName, valB, repVec, acc, broadcastCtrl,
kRep);
}
}
}
return acc;
}

int getNumSubmatrices(Type elementType, int mDim, int nDim) const {
if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64))
return 1;
@@ -187,13 +311,14 @@ struct DotOpMFMAConversionHelper {
llvm::report_fatal_error("No match found in MFMA database\n");

mfmaInsnName = (*maybeMfmaInsn).getInsnName();
unsigned k_base = (*maybeMfmaInsn).getKBase();
unsigned kBaseA = (*maybeMfmaInsn).getKBaseA();
unsigned kBaseB = (*maybeMfmaInsn).getKBaseB();

auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();

auto kWidth = aEncoding.getKWidth();
assert(kWidth == bEncoding.getKWidth());
auto kWidthA = aEncoding.getKWidth();
auto kWidthB = bEncoding.getKWidth();

auto repA = aEncoding.getMFMARep(aTensorTy.getShape());
auto repB = bEncoding.getMFMARep(bTensorTy.getShape());
@@ -209,9 +334,9 @@
auto numRepK = repA[1];

auto operandA = getValuesFromDotOperandLayoutStruct(
loadedA, numRepM, numRepK, kWidth, k_base, aTensorTy.getElementType());
loadedA, numRepM, numRepK, kWidthA, kBaseA, aTensorTy.getElementType());
auto operandB = getValuesFromDotOperandLayoutStruct(
loadedB, numRepN, numRepK, kWidth, k_base, aTensorTy.getElementType());
loadedB, numRepN, numRepK, kWidthB, kBaseB, aTensorTy.getElementType());

auto dstElemTy = dTensorTy.getElementType();
auto fc =
Expand All @@ -236,12 +361,10 @@ struct DotOpMFMAConversionHelper {

acc = zeroAuxiliarBlocks(subBlocks, acc);
for (size_t k = 0; k < numRepK; k++)
for (int kpack = 0; kpack < kWidth / k_base; ++kpack)
acc = mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaInsnName, operandB[kpack][{n, k}],
operandA[kpack][{m, k}], acc)
: generateMFMAOp(mfmaInsnName, operandA[kpack][{m, k}],
operandB[kpack][{n, k}], acc);
for (int kpack = 0; kpack < kWidthA / kBaseA; ++kpack)
acc = generateMFMATile(mfmaInsnName, operandA[kpack][{m, k}],
operandB[kpack][{n, k}], acc, mDim, nDim,
mfmaLayout.getIsTransposed());
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {
fc[m * numRepN * elemsPerVec + n * elemsPerVec + v] =
@@ -276,12 +399,29 @@
extract_element(type, rawElems, i32_val(elemId + k * k_base));
vec = insert_element(vecTy, vec, val, i32_val(elemId));
}
// if (64 == k_base) {
// constexpr int numRepeats = 16;
// const int oneOpKWidth = k_base / numRepeats;
// assert(oneOpKWidth == 4);
// auto repVecTy = vec_ty(type, oneOpKWidth);
// auto operandVecTy = vec_ty(repVecTy, numRepeats);
// results.push_back(bitcast(vec, operandVecTy));
// }
if (type.getIntOrFloatBitWidth() == 8) {
if (4 == k_base)
// This is for int8 on pre-MI300 GPUs
results.push_back(bitcast(vec, i32_ty));
if (8 == k_base)
results.push_back(bitcast(vec, i64_ty));
// In this case one tile is processed by several instructions;
// repack the flat vector into a vector of vectors
if (64 == k_base) {
constexpr int numRepeats = 16;
assert(k_base / numRepeats == 4);
auto repVecTy = i32_ty;
auto operandVecTy = vec_ty(repVecTy, numRepeats);
results.push_back(bitcast(vec, operandVecTy));
}
} else
results.push_back(vec);
}
@@ -305,8 +445,14 @@
auto rawElems = elems[n1 * i + j];

if (type.isF32()) {
for (int k = 0; k < kpack; ++k) {
dotOpVals[k][{i, j}] = extract_element(type, rawElems, i32_val(k));
if (k_base == 16) {
for (int k = 0; k < kpack; ++k)
dotOpVals[k][{i, j}] = getSubVector(rawElems, kpack, k);
} else {
for (int k = 0; k < kpack; ++k) {
dotOpVals[k][{i, j}] =
extract_element(type, rawElems, i32_val(k));
}
}
} else {
SmallVector<Value> vals;
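The k_base == 64 int8 case above regroups a flat 64-element vector into 16 dword chunks, one chunk per chained 4x4x4 instruction. A standalone sketch of the same regrouping (names local to the sketch):

#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>

// 64 x i8 -> 16 x i32: one dword (4 packed int8 K-elements) per chained
// MFMA call; this mirrors the bitcast to vec_ty(i32_ty, 16) above.
std::array<uint32_t, 16> repackInt8KBase64(const std::array<int8_t, 64> &flat) {
  std::array<uint32_t, 16> chunks{};
  static_assert(sizeof(chunks) == sizeof(flat), "same bit width");
  std::memcpy(chunks.data(), flat.data(), sizeof(chunks));
  return chunks;
}

int main() {
  std::array<int8_t, 64> v{};
  for (int i = 0; i < 64; ++i)
    v[i] = static_cast<int8_t>(i);
  auto c = repackInt8KBase64(v);
  std::printf("chunk0=0x%08x chunk15=0x%08x\n", c[0], c[15]);
}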
36 changes: 26 additions & 10 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -304,12 +304,17 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
return {};
}
} else if (parentLayout.isa<MfmaEncodingAttr>()) {
} else if (auto mfmaLayout = parentLayout.dyn_cast<MfmaEncodingAttr>()) {
auto opIdx = dotLayout.getOpIdx();
auto kWidth = dotLayout.getKWidth();
if (opIdx == 0) {
return {4, 1};
int repeats =
(mfmaLayout.getMDim() == 64 && mfmaLayout.getNDim() == 4) ? 16 : 1;
return {1, kWidth * repeats};
} else if (opIdx == 1) {
return {1, 4};
int repeats =
(mfmaLayout.getMDim() == 4 && mfmaLayout.getNDim() == 64) ? 16 : 1;
return {kWidth * repeats, 1};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
return {};
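// Illustratively (values assumed, kWidth = 4): for mfma 64x4, operand A's
// sizePerThread becomes {1, 64} (16 K-repetitions of 4 elements) while B's
// stays {4, 1}; for mfma 4x64 the roles swap.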
@@ -458,6 +463,8 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
auto parentShapePerCTA = getShapePerCTATile(parentLayout, tensorShape);
auto opIdx = dotLayout.getOpIdx();

assert(parentMfmaLayout.getMDim() == 32);

if (opIdx == 0) {
return {parentShapePerCTA[0], 32};
} else if (opIdx == 1) {
@@ -1102,16 +1109,13 @@
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));
int64_t kWidth = getKWidth();
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
int kGroups = -1;
if (mDim == nDim)
kGroups = waveSize / mDim;
if ((mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64))
kGroups = 1;
auto nonKDim = getOpIdx() == 0 ? mDim : nDim;
int kGroups = waveSize / nonKDim;
int64_t kDim = kWidth * kGroups;
if (getOpIdx() == 0)
return {mDim, kDim};
return {nonKDim, kDim};
else
return {kDim, nDim};
return {kDim, nonKDim};
}

SmallVector<int64_t>
Expand Down Expand Up @@ -1902,6 +1906,18 @@ struct TritonGPUInferLayoutInterface
// Verify that the encodings are valid.
if (!aEncoding || !bEncoding)
return op->emitError("mismatching encoding between A and B operands");
#ifdef USE_ROCM
auto aParentEncoding =
aEncoding.getParent().dyn_cast_or_null<MfmaEncodingAttr>();
auto bParentEncoding =
bEncoding.getParent().dyn_cast_or_null<MfmaEncodingAttr>();
if (aParentEncoding != bParentEncoding)
return op->emitError(
"mismatching parent encoding between A and B operands");
if (aParentEncoding != nullptr &&
aParentEncoding.getMDim() != aParentEncoding.getNDim())
return success();
#endif // USE_ROCM
if (aEncoding.getKWidth() != bEncoding.getKWidth())
return op->emitError("mismatching kWidth between A and B operands");
return success();
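The relaxed check above exists because the two operands of a non-square MFMA dot are inherently asymmetric in how many K elements each lane holds, so requiring equal kWidth would reject valid 4x64/64x4 layouts. A standalone sketch of the arithmetic, assuming an fp16 tile with K = 64 (names local to the sketch):

#include <cstdio>

// K elements per lane for one operand of a K=64 tile, given how many
// lanes of the wave64 span its non-K dimension.
int elemsPerLaneK(int nonKDim) {
  constexpr int waveSize = 64, tileK = 64;
  int kGroups = waveSize / nonKDim; // lane groups stacked along K
  return tileK / kGroups;
}

int main() {
  // mfma 4x64: the 4-wide A operand holds 4 K-elements per lane, the
  // 64-wide B operand holds all 64 -- hence per-operand k_base/kWidth.
  std::printf("A: %d, B: %d\n", elemsPerLaneK(4), elemsPerLaneK(64));
}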