Skip to content

Commit

Permalink
[compiler] fix inf/nan convert on x86_64 arch
Browse files Browse the repository at this point in the history
  • Loading branch information
jianwenyyy committed Jun 28, 2024
1 parent bed7740 commit 329efd7
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 23 deletions.
6 changes: 5 additions & 1 deletion compiler/include/byteir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ def HloFusionToLinalg : Pass<"hlo-fusion-to-linalg", "func::FuncOp"> {
Option<"enablePrimitiveOps", "enable-primitive-ops", "bool",
/*default=*/"false",
"Lower to primitive Linalg ops (map, reduce and "
"transpose) when possible, instead of linalg.generic">
"transpose) when possible, instead of linalg.generic">,
Option<"target", "target", "std::string", /*default*/ "",
"Specificy the target">,
Option<"arch", "arch", "std::string", /*default*/ "",
"Specificy the target arch">
];
}

Expand Down
10 changes: 6 additions & 4 deletions compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ void populateTensorToLinalgConversionPatterns(RewritePatternSet &patterns);
void populateLinalgExtToLinalgConversionPatterns(RewritePatternSet &patterns);

void populateHloToLinalgExtConversionPattern(TypeConverter &typeConverter,
RewritePatternSet &patterns);
RewritePatternSet &patterns,
const std::string &target = "",
const std::string &arch = "");

std::unique_ptr<OperationPass<func::FuncOp>>
createHloFusionToLinalgPass(llvm::StringRef anchorTag = "",
bool enablePrimitiveOps = false);
std::unique_ptr<OperationPass<func::FuncOp>> createHloFusionToLinalgPass(
llvm::StringRef anchorTag = "", bool enablePrimitiveOps = false,
const std::string &target = "", const std::string &arch = "");

std::unique_ptr<OperationPass<func::FuncOp>> createUnrealizedCastToLinalgPass();

Expand Down
3 changes: 3 additions & 0 deletions compiler/include/byteir/Pipelines/LinalgTensorOpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ struct LinalgTensorOptPipelineOptions
*this, "target",
llvm::cl::desc("An optional attribute to speicify target."),
llvm::cl::init("")};
Option<std::string> arch{
*this, "arch", llvm::cl::desc("An optional attribute to speicify arch."),
llvm::cl::init("")};
};

void createLinalgTensorOptPipeline(
Expand Down
137 changes: 125 additions & 12 deletions compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,13 +1267,116 @@ class ByteirRepeatCustomCallConverter
}
};

/// Code below is copied from legalize_to_linalg.cc
/// Remove this when upstream FPToSIOp solves inf/nan convert.
Value coerceTensorShape(OpBuilder &builder, Location loc,
TypedValue<ShapedType> value, ShapedType targetType) {
return builder.createOrFold<tensor::CastOp>(
loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()),
value);
}

inline Value mapFPToSIConvertOpToStdScalarOp(Location loc,
ArrayRef<Type> targetTypes,
ArrayRef<Type> resultTypes,
ValueRange args, OpBuilder *b) {
assert(targetTypes.size() == 1 && "ConvertOp should return a single result");
assert(resultTypes.size() == 1 && "ConvertOp should return a single result");
assert(args.size() == 1 && "ConvertOp should take a single argument");

Type targetType = getElementTypeOrSelf(targetTypes.front());
Type convertedSourceType = getElementTypeOrSelf(args.front());

if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType,
targetType)) {
Value infValue = b->create<mlir::arith::ConstantOp>(
loc,
b->getFloatAttr(
convertedSourceType,
APFloat::getInf(
dyn_cast<FloatType>(convertedSourceType).getFloatSemantics())));
Value isInf = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
args.front(), infValue);
Value isNan = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
args.front(), args.front());
Value maxIntval = b->create<arith::ConstantOp>(
loc,
b->getIntegerAttr(targetType,
APInt::getSignedMaxValue(
dyn_cast<IntegerType>(targetType).getWidth())));
Value zeroIntval =
b->create<arith::ConstantOp>(loc, b->getZeroAttr(targetType));
return b->create<::mlir::arith::SelectOp>(
loc, isInf, maxIntval,
b->create<::mlir::arith::SelectOp>(
loc, isNan, zeroIntval,
b->create<mlir::arith::FPToSIOp>(loc, resultTypes, args,
std::nullopt)));
}
return nullptr;
}

class FPToSIConvertOpConverter : public OpConversionPattern<mhlo::ConvertOp> {
public:
using OpConversionPattern<mhlo::ConvertOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(mhlo::ConvertOp op, typename mhlo::ConvertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op.getLoc();
RankedTensorType type = dyn_cast<RankedTensorType>(op.getType());
if (!type || !type.hasStaticShape()) {
return failure();
}
// Apply only if convert type is FPToInt32
if (!mlir::arith::FPToSIOp::areCastCompatible(op.getOperand().getType(),
op.getType())) {
return failure();
}
auto targetType = op.getType().getElementType();
if (isa<IntegerType>(targetType) &&
(cast<IntegerType>(targetType).getWidth() != 32 ||
cast<IntegerType>(targetType).isUnsigned())) {
return failure();
}
// Find input/output values and types.
std::optional<ShapedType> resultTy =
this->typeConverter->convertType(op->getResultTypes().front())
.template dyn_cast<ShapedType>();
Value emptyTensor =
getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());
// Mapped inputs are cast to the same shape as the init tensor.
SmallVector<Value> mappedInputs;
for (Value input : adaptor.getOperands()) {
mappedInputs.push_back(
coerceTensorShape(rewriter, loc, cast<TypedValue<ShapedType>>(input),
cast<ShapedType>(emptyTensor.getType())));
}

auto mapOp = rewriter.create<linalg::MapOp>(
loc, mappedInputs, emptyTensor,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value innerResult = mapFPToSIConvertOpToStdScalarOp(
op.getLoc(), op.getType(), getElementTypeOrSelf(emptyTensor),
args, &b);
b.create<linalg::YieldOp>(loc, innerResult);
},
linalg::getPrunedAttributeList(op));
rewriter.replaceOp(op, mapOp->getResults());
return success();
}
};

struct HloFusionToLinalgPass
: public HloFusionToLinalgBase<HloFusionToLinalgPass> {

HloFusionToLinalgPass(StringRef tag, bool enablePrimitiveOps)
HloFusionToLinalgPass(StringRef tag, bool enablePrimitiveOps,
StringRef target, StringRef arch)
: HloFusionToLinalgBase() {
anchorTag = tag.str();
this->enablePrimitiveOps = enablePrimitiveOps;
this->target = target.str();
this->arch = arch.str();
}

void getDependentDialects(DialectRegistry &registry) const final {
Expand All @@ -1293,13 +1396,13 @@ struct HloFusionToLinalgPass

MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
ConversionTarget target(ctx);
target.addLegalDialect<
ConversionTarget conversionTarget(ctx);
conversionTarget.addLegalDialect<
arith::ArithDialect, cf::ControlFlowDialect, func::FuncDialect,
linalg::LinalgDialect, math::MathDialect, tensor::TensorDialect,
scf::SCFDialect, shape::ShapeDialect, linalg_ext::LinalgExtDialect>();

target.addLegalOp<UnrealizedConversionCastOp>();
conversionTarget.addLegalOp<UnrealizedConversionCastOp>();

auto typeConverter = createHloToLinalgTypeConverter();

Expand All @@ -1308,22 +1411,31 @@ struct HloFusionToLinalgPass
[](Operation *op) { return isInBodyOfLinalgOps(op); });
mhlo::populateHloToLinalgConversionPattern(&ctx, *typeConverter, &patterns,
enablePrimitiveOps);
populateHloToLinalgExtConversionPattern(*typeConverter, patterns);
populateHloToLinalgExtConversionPattern(*typeConverter, patterns,
this->target, this->arch);

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
if (failed(applyPartialConversion(func, target, frozenPatterns))) {
if (failed(
applyPartialConversion(func, conversionTarget, frozenPatterns))) {
signalPassFailure();
}
}
};

} // namespace

void mlir::populateHloToLinalgExtConversionPattern(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
void mlir::populateHloToLinalgExtConversionPattern(TypeConverter &typeConverter,
RewritePatternSet &patterns,
const std::string &target,
const std::string &arch) {
auto ctx = patterns.getContext();
patterns.add<ReduceWindowOpConversion>(typeConverter, ctx, PatternBenefit(2));
patterns.add<DotGeneralLinalgExtBatchMatMulOpConversion>(typeConverter, ctx,
PatternBenefit(2));
if (target == "cpu" && arch == "x86_64") {
patterns.add<FPToSIConvertOpConverter>(typeConverter, ctx,
PatternBenefit(2));
}
patterns.add<SoftmaxCustomCallConverter>(ctx);
patterns.add<ScatterOpConversion>(ctx);
patterns.add<LayerNormCustomCallConverter>(ctx);
Expand All @@ -1333,8 +1445,9 @@ void mlir::populateHloToLinalgExtConversionPattern(
patterns.add<ByteirRepeatCustomCallConverter>(ctx);
}

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::createHloFusionToLinalgPass(llvm::StringRef anchorTag,
bool enablePrimitiveOps) {
return std::make_unique<HloFusionToLinalgPass>(anchorTag, enablePrimitiveOps);
std::unique_ptr<OperationPass<func::FuncOp>> mlir::createHloFusionToLinalgPass(
llvm::StringRef anchorTag, bool enablePrimitiveOps,
const std::string &target, const std::string &arch) {
return std::make_unique<HloFusionToLinalgPass>(anchorTag, enablePrimitiveOps,
target, arch);
}
12 changes: 7 additions & 5 deletions compiler/lib/Pipelines/LinalgTensorOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,10 @@ void addGenericLinalgPasses(OpPassManager &pm) {
}
}

void addCPULinalgOptPasses(OpPassManager &pm) {
void addCPULinalgOptPasses(OpPassManager &pm, const std::string &target,
const std::string &arch) {
pm.addNestedPass<func::FuncOp>(createHloFusionToLinalgPass(
getByteIRHloAggressiveFusionAttrName(), true));
getByteIRHloAggressiveFusionAttrName(), true, target, arch));
pm.addNestedPass<func::FuncOp>(createUnrealizedCastToLinalgPass());
{
TileAndVectorizeTransposeOptions options;
Expand All @@ -248,9 +249,10 @@ void addCPULinalgOptPasses(OpPassManager &pm) {
}

void createLinalgTensorOptPipelineImpl(OpPassManager &pm,
const std::string &target) {
const std::string &target,
const std::string &arch) {
if (target == "cpu") {
addCPULinalgOptPasses(pm);
addCPULinalgOptPasses(pm, target, arch);
} else {
addGenericLinalgPasses(pm);
}
Expand All @@ -260,5 +262,5 @@ void createLinalgTensorOptPipelineImpl(OpPassManager &pm,
void mlir::createLinalgTensorOptPipeline(
OpPassManager &pm, const LinalgTensorOptPipelineOptions &options) {
invokeOpPassPipelineBuilder(createLinalgTensorOptPipelineImpl, pm,
options.target);
options.target, options.arch);
}
3 changes: 2 additions & 1 deletion compiler/python/byteir/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,15 @@ def _compile_cpu(

entry_func_str = "entry-func={}".format(entry_func)
target_str = "target={}".format(target)
arch_str="arch={}".format(cpu_arch)
with context:
PassManager().parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation)
_print_verbose(module, "// IR Dump After Hlo Graph Opt:") if verbose else ...
with context:
PassManager().parse("builtin.module(hlo-fusion-opt{" + entry_func_str + " " + target_str + " outline-single-elemwise-op})").run(module.operation)
_print_verbose(module, "// IR Dump After Hlo Fusion Opt:") if verbose else ...
with context:
PassManager.parse("builtin.module(linalg-tensor-opt{" + target_str + "})").run(module.operation)
PassManager.parse("builtin.module(linalg-tensor-opt{" + target_str + " " + arch_str + "})").run(module.operation)
_print_verbose(module, "// IR Dump After Linalg Tensor Opt:") if verbose else ...
with context:
PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(module.operation)
Expand Down
3 changes: 3 additions & 0 deletions tests/numerical_test/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
MLIR_TEST_SPECIAL_INPUTS = {
"cpu@log_plus_one.mlir": [
np.random.uniform(low=0.5, high=1.0, size=(256, 64)).astype(np.float16)
],
"cpu@convert_f32_i32_special_val.mlir": [
np.array([[np.inf, -np.inf, np.nan], [1., 999.999, -np.inf]], dtype=np.float32),
]
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
func.func @convert_f32_i32_special_val(%arg0 : tensor<2x3xf32>) -> tensor<2x3xi32> {
%0 = stablehlo.convert %arg0 : (tensor<2x3xf32>) -> tensor<2x3xi32>
func.return %0 : tensor<2x3xi32>
}

0 comments on commit 329efd7

Please sign in to comment.