diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp
index abe7f45a4c9e..02aa61672bab 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeEpilogue.cpp
@@ -25,25 +25,34 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
 
 public:
   explicit BypassEpilogueSMEM(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::StoreOp::getOperationName(), 1, context) {}
+      : mlir::RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
   mlir::LogicalResult
   matchAndRewrite(mlir::Operation *op,
                   mlir::PatternRewriter &rewriter) const override {
-
-    auto stOp = dyn_cast<triton::StoreOp>(op);
-    if (!stOp)
+    Value ptr, val, mask;
+    RankedTensorType ptrType, valType;
+    triton::gpu::ConvertLayoutOp cvtOp;
+
+    if (auto stOp = dyn_cast<triton::StoreOp>(op)){
+      ptr = stOp.getPtr();
+      val = stOp.getValue();
+      mask = stOp.getMask();
+    } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
+      ptr = atomicRMWOp.getPtr();
+      val = atomicRMWOp.getVal();
+      mask = atomicRMWOp.getMask();
+    } else {
       return mlir::failure();
-    Value ptr = stOp.getPtr();
-    Value val = stOp.getValue();
-    Value mask = stOp.getMask();
-    auto ptrType = ptr.getType().dyn_cast<RankedTensorType>();
-    auto valType = val.getType().dyn_cast<RankedTensorType>();
+    }
+
+    ptrType = ptr.getType().dyn_cast<RankedTensorType>();
+    valType = val.getType().dyn_cast<RankedTensorType>();
     if (!ptrType || !valType ||
         !ptrType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() ||
         !valType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>())
       return mlir::failure();
 
-    auto cvtOp = dyn_cast<triton::gpu::ConvertLayoutOp>(val.getDefiningOp());
+    cvtOp = dyn_cast<triton::gpu::ConvertLayoutOp>(val.getDefiningOp());
     if (!cvtOp)
       return mlir::failure();
 
@@ -80,8 +89,18 @@ class BypassEpilogueSMEM : public mlir::RewritePattern {
           mask.getLoc(), newMaskType, mask);
     }
 
-    rewriter.replaceOpWithNewOp<triton::StoreOp>(
-        stOp, newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict());
+    if (auto stOp = dyn_cast<triton::StoreOp>(op)) {
+      rewriter.replaceOpWithNewOp<triton::StoreOp>(
+          stOp, newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict());
+    } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
+      auto result = atomicRMWOp.getResult();
+      auto resultType = result.getType().dyn_cast<RankedTensorType>();
+      auto newResultType = RankedTensorType::get(
+          resultType.getShape(), resultType.getElementType(), newEncoding);
+      rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
+          atomicRMWOp, newResultType, atomicRMWOp.getAtomicRmwOpAttr(), newPtr, newVal, newMask, atomicRMWOp.getSemAttr(), atomicRMWOp.getScopeAttr());
+    }
+
     return mlir::success();
   }
 };