diff --git a/compiler/include/byteir/Conversion/Passes.td b/compiler/include/byteir/Conversion/Passes.td index f7d79311c..01dd52c7f 100644 --- a/compiler/include/byteir/Conversion/Passes.td +++ b/compiler/include/byteir/Conversion/Passes.td @@ -63,7 +63,11 @@ def HloFusionToLinalg : Pass<"hlo-fusion-to-linalg", "func::FuncOp"> { Option<"enablePrimitiveOps", "enable-primitive-ops", "bool", /*default=*/"false", "Lower to primitive Linalg ops (map, reduce and " - "transpose) when possible, instead of linalg.generic"> + "transpose) when possible, instead of linalg.generic">, + Option<"target", "target", "std::string", /*default=*/ "", + "Specify the target">, + Option<"arch", "arch", "std::string", /*default=*/ "", + "Specify the target arch"> ]; } diff --git a/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h b/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h index 8e64ce9d3..8a8c0f114 100644 --- a/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h +++ b/compiler/include/byteir/Conversion/ToLinalg/ToLinalg.h @@ -41,11 +41,13 @@ void populateTensorToLinalgConversionPatterns(RewritePatternSet &patterns); void populateLinalgExtToLinalgConversionPatterns(RewritePatternSet &patterns); void populateHloToLinalgExtConversionPattern(TypeConverter &typeConverter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + const std::string &target = "", + const std::string &arch = ""); -std::unique_ptr> -createHloFusionToLinalgPass(llvm::StringRef anchorTag = "", - bool enablePrimitiveOps = false); +std::unique_ptr> createHloFusionToLinalgPass( + llvm::StringRef anchorTag = "", bool enablePrimitiveOps = false, + const std::string &target = "", const std::string &arch = ""); std::unique_ptr> createUnrealizedCastToLinalgPass(); diff --git a/compiler/include/byteir/Pipelines/LinalgTensorOpt.h b/compiler/include/byteir/Pipelines/LinalgTensorOpt.h index e8428ba4b..a5d521b2d 100644 --- a/compiler/include/byteir/Pipelines/LinalgTensorOpt.h +++ 
b/compiler/include/byteir/Pipelines/LinalgTensorOpt.h @@ -30,6 +30,9 @@ struct LinalgTensorOptPipelineOptions *this, "target", llvm::cl::desc("An optional attribute to speicify target."), llvm::cl::init("")}; + Option arch{ + *this, "arch", llvm::cl::desc("An optional attribute to speicify arch."), + llvm::cl::init("")}; }; void createLinalgTensorOptPipeline( diff --git a/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp b/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp index 2cba82b6f..4f177ebaf 100644 --- a/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp +++ b/compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp @@ -1270,10 +1270,13 @@ class ByteirRepeatCustomCallConverter struct HloFusionToLinalgPass : public HloFusionToLinalgBase { - HloFusionToLinalgPass(StringRef tag, bool enablePrimitiveOps) + HloFusionToLinalgPass(StringRef tag, bool enablePrimitiveOps, + StringRef target, StringRef arch) : HloFusionToLinalgBase() { anchorTag = tag.str(); this->enablePrimitiveOps = enablePrimitiveOps; + this->target = target.str(); + this->arch = arch.str(); } void getDependentDialects(DialectRegistry ®istry) const final { @@ -1293,13 +1296,13 @@ struct HloFusionToLinalgPass MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); - ConversionTarget target(ctx); - target.addLegalDialect< + ConversionTarget conversionTarget(ctx); + conversionTarget.addLegalDialect< arith::ArithDialect, cf::ControlFlowDialect, func::FuncDialect, linalg::LinalgDialect, math::MathDialect, tensor::TensorDialect, scf::SCFDialect, shape::ShapeDialect, linalg_ext::LinalgExtDialect>(); - target.addLegalOp(); + conversionTarget.addLegalOp(); auto typeConverter = createHloToLinalgTypeConverter(); @@ -1308,22 +1311,188 @@ struct HloFusionToLinalgPass [](Operation *op) { return isInBodyOfLinalgOps(op); }); mhlo::populateHloToLinalgConversionPattern(&ctx, *typeConverter, &patterns, enablePrimitiveOps); - populateHloToLinalgExtConversionPattern(*typeConverter, patterns); + 
populateHloToLinalgExtConversionPattern(*typeConverter, patterns, + this->target, this->arch); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); - if (failed(applyPartialConversion(func, target, frozenPatterns))) { + if (failed( + applyPartialConversion(func, conversionTarget, frozenPatterns))) { signalPassFailure(); } } }; + +/// Code below is copied from legalize_to_linalg.cc +/// Remove this when upstream FPToSIOp solves inf/nan convert. +Value coerceTensorShape(OpBuilder &builder, Location loc, + TypedValue value, ShapedType targetType) { + return builder.createOrFold( + loc, targetType.cloneWith(std::nullopt, value.getType().getElementType()), + value); +} + +inline Value mapFPToSIConvertOpToStdScalarOp(Location loc, + ArrayRef targetTypes, + ArrayRef resultTypes, + ArrayRef argTypes, + ValueRange args, OpBuilder *b) { + assert(targetTypes.size() == 1 && "ConvertOp should return a single result"); + assert(resultTypes.size() == 1 && "ConvertOp should return a single result"); + assert(argTypes.size() == 1 && "ConvertOp should take a single argument"); + assert(args.size() == 1 && "ConvertOp should take a single argument"); + + Type sourceType = getElementTypeOrSelf(argTypes.front()); + Type targetType = getElementTypeOrSelf(targetTypes.front()); + Type convertedSourceType = getElementTypeOrSelf(args.front()); + + if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType, + targetType)) { + Value infValue = b->create( + loc, + b->getFloatAttr( + convertedSourceType, + APFloat::getInf( + dyn_cast(convertedSourceType).getFloatSemantics()))); + Value isInf = b->create(loc, arith::CmpFPredicate::OEQ, + args.front(), infValue); + Value isNan = b->create(loc, arith::CmpFPredicate::UNO, + args.front(), args.front()); + Value maxIntval = b->create( + loc, + b->getIntegerAttr(targetType, + APInt::getSignedMaxValue( + dyn_cast(targetType).getWidth()))); + Value zeroIntval = + b->create(loc, b->getZeroAttr(targetType)); + return 
b->create<::mlir::arith::SelectOp>( + loc, isInf, maxIntval, + b->create<::mlir::arith::SelectOp>( + loc, isNan, zeroIntval, + b->create(loc, resultTypes, args, + std::nullopt))); + } + return nullptr; +} + +class FPToSIConvertOpConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(mhlo::ConvertOp op, typename mhlo::ConvertOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + int64_t maxRank = getMaxRank(adaptor); + // Apply only if all operands are scalar or have the same rank. + if (!llvm::all_of(adaptor.getOperands(), [&](Value v) { + int64_t r = getRank(v); + return r == 0 || r == maxRank; + })) { + return rewriter.notifyMatchFailure( + op, "Operands must be of same rank or scalar."); + } + // Find result type, if on tensors. + std::optional resultTy; + resultTy = this->typeConverter->convertType(op->getResultTypes().front()) + .template dyn_cast(); + + // Check result type compatibility. + if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != maxRank || + !(resultTy->getElementType().isSignlessIntOrFloat() || + resultTy->getElementType().isa())) { + return rewriter.notifyMatchFailure( + op, "mismatched operand/result types or iterator count"); + } + // Apply only if convert type is FPToInt32 + if (!mlir::arith::FPToSIOp::areCastCompatible(op.getOperand().getType(), + op.getType())) { + return failure(); + } + auto targetType = op.getType().getElementType(); + if (isa(targetType) && + cast(targetType).getWidth() != 32) { + return failure(); + } + // All-scalar pointwise ops inside of linalg ops are processed by + // ScalarHloToArithmeticPattern. + if (maxRank == 0 && isInBodyOfLinalgOps(op)) + return failure(); + // Find input/output values and types. + Value emptyTensor = + getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands()); + // Mapped inputs are cast to the same shape as the init tensor. 
+ // Values from scalar inputs are extracted and used directly in the block. + SmallVector mappedInputs; + SmallVector scalarInputs; + for (Value input : adaptor.getOperands()) { + if (getRank(input) == maxRank) { + mappedInputs.push_back(coerceTensorShape( + rewriter, loc, cast>(input), + cast(emptyTensor.getType()))); + scalarInputs.push_back(nullptr); + } else { + scalarInputs.push_back(rewriter.create(loc, input)); + } + } + + auto mapOp = rewriter.create( + loc, mappedInputs, emptyTensor, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value innerResult = mapFPToSIConvertOpToStdScalarOp( + op.getLoc(), op.getType(), getElementTypeOrSelf(emptyTensor), + llvm::to_vector(op->getOperandTypes()), + interleaveScalarAndBlockArgs(scalarInputs, args), &b); + b.create(loc, innerResult); + }, + linalg::getPrunedAttributeList(op)); + rewriter.replaceOp(op, mapOp->getResults()); + return success(); + } + +protected: + int64_t getRank(Value v) const { + return v.getType().cast().getRank(); + } + + int64_t getMaxRank(typename mhlo::ConvertOp::Adaptor adaptor) const { + int64_t maxRank = 0; + for (auto operand : adaptor.getOperands()) { + maxRank = std::max(maxRank, getRank(operand)); + } + return maxRank; + } + + // Inserts block arguments in places where scalar inputs have a nullptr. 
+ SmallVector interleaveScalarAndBlockArgs(ValueRange scalarInputs, + ValueRange blockArgs) const { + SmallVector result; + auto argsIter = blockArgs.begin(); + for (Value scalarInput : scalarInputs) { + if (scalarInput) { + result.push_back(scalarInput); + } else { + result.push_back(*argsIter); + ++argsIter; + } + } + return result; + } +}; + } // namespace -void mlir::populateHloToLinalgExtConversionPattern( - TypeConverter &typeConverter, RewritePatternSet &patterns) { +void mlir::populateHloToLinalgExtConversionPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, + const std::string &target, + const std::string &arch) { auto ctx = patterns.getContext(); patterns.add(typeConverter, ctx, PatternBenefit(2)); patterns.add(typeConverter, ctx, PatternBenefit(2)); + if (target == "cpu" && arch == "x86_64") { + patterns.add(typeConverter, ctx, + PatternBenefit(2)); + } patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); @@ -1333,8 +1502,9 @@ void mlir::populateHloToLinalgExtConversionPattern( patterns.add(ctx); } -std::unique_ptr> -mlir::createHloFusionToLinalgPass(llvm::StringRef anchorTag, - bool enablePrimitiveOps) { - return std::make_unique(anchorTag, enablePrimitiveOps); +std::unique_ptr> mlir::createHloFusionToLinalgPass( + llvm::StringRef anchorTag, bool enablePrimitiveOps, + const std::string &target, const std::string &arch) { + return std::make_unique(anchorTag, enablePrimitiveOps, + target, arch); } diff --git a/compiler/lib/Pipelines/LinalgTensorOpt.cpp b/compiler/lib/Pipelines/LinalgTensorOpt.cpp index 4cc957771..b1d75341c 100644 --- a/compiler/lib/Pipelines/LinalgTensorOpt.cpp +++ b/compiler/lib/Pipelines/LinalgTensorOpt.cpp @@ -228,9 +228,10 @@ void addGenericLinalgPasses(OpPassManager &pm) { } } -void addCPULinalgOptPasses(OpPassManager &pm) { +void addCPULinalgOptPasses(OpPassManager &pm, const std::string &target, + const std::string &arch) { pm.addNestedPass(createHloFusionToLinalgPass( - getByteIRHloAggressiveFusionAttrName(), 
true)); + getByteIRHloAggressiveFusionAttrName(), true, target, arch)); pm.addNestedPass(createUnrealizedCastToLinalgPass()); { TileAndVectorizeTransposeOptions options; @@ -248,9 +249,10 @@ void addCPULinalgOptPasses(OpPassManager &pm) { } void createLinalgTensorOptPipelineImpl(OpPassManager &pm, - const std::string &target) { + const std::string &target, + const std::string &arch) { if (target == "cpu") { - addCPULinalgOptPasses(pm); + addCPULinalgOptPasses(pm, target, arch); } else { addGenericLinalgPasses(pm); } @@ -260,5 +262,5 @@ void createLinalgTensorOptPipelineImpl(OpPassManager &pm, void mlir::createLinalgTensorOptPipeline( OpPassManager &pm, const LinalgTensorOptPipelineOptions &options) { invokeOpPassPipelineBuilder(createLinalgTensorOptPipelineImpl, pm, - options.target); + options.target, options.arch); } diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index 1b7cbe902..7f98549e2 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -297,6 +297,7 @@ def _compile_cpu( entry_func_str = "entry-func={}".format(entry_func) target_str = "target={}".format(target) + arch_str="arch={}".format(cpu_arch) with context: PassManager().parse("builtin.module(hlo-graph-opt{" + entry_func_str + " " + target_str + "})").run(module.operation) _print_verbose(module, "// IR Dump After Hlo Graph Opt:") if verbose else ... @@ -304,7 +305,7 @@ def _compile_cpu( PassManager().parse("builtin.module(hlo-fusion-opt{" + entry_func_str + " " + target_str + " outline-single-elemwise-op})").run(module.operation) _print_verbose(module, "// IR Dump After Hlo Fusion Opt:") if verbose else ... with context: - PassManager.parse("builtin.module(linalg-tensor-opt{" + target_str + "})").run(module.operation) + PassManager.parse("builtin.module(linalg-tensor-opt{" + target_str + " " + arch_str + "})").run(module.operation) _print_verbose(module, "// IR Dump After Linalg Tensor Opt:") if verbose else ... 
with context: PassManager.parse("builtin.module(byre-tensor-opt{{append-arg-types {}}})".format(entry_func_str)).run(module.operation) diff --git a/tests/numerical_test/execute.py b/tests/numerical_test/execute.py index 3095df438..6339605e5 100644 --- a/tests/numerical_test/execute.py +++ b/tests/numerical_test/execute.py @@ -33,6 +33,9 @@ MLIR_TEST_SPECIAL_INPUTS = { "cpu@log_plus_one.mlir": [ np.random.uniform(low=0.5, high=1.0, size=(256, 64)).astype(np.float16) + ], + "cpu@convert_f32_i32_special_val.mlir": [ + np.array([[np.inf, -np.inf, np.nan], [1., 999.999, -np.inf]], dtype=np.float32), ] }