Skip to content

Commit

Permalink
conv: Add return conversion pattern.
Browse files Browse the repository at this point in the history
  • Loading branch information
xlauko committed Nov 22, 2024
1 parent d3dfbfb commit c6510d6
Showing 1 changed file with 56 additions and 15 deletions.
71 changes: 56 additions & 15 deletions lib/vast/Conversion/Parser/ToParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,20 @@ namespace vast::conv {
: base(mctx), models(models)
{}

static std::optional< function_model > get_model(
const function_models &models, string_ref name
) {
if (auto kv = models.find(name); kv != models.end()) {
return kv->second;
}

return std::nullopt;
}

std::optional< function_model > get_model(string_ref name) const {
return get_model(models, name);
}

const function_models &models;
};

Expand Down Expand Up @@ -408,34 +422,60 @@ namespace vast::conv {
}
};

struct FuncConversion
: parser_conversion_pattern_base< hl::FuncOp >
, tc::op_type_conversion< hl::FuncOp, function_type_converter >
struct ReturnConversion
: parser_conversion_pattern_base< hl::ReturnOp >
{
using op_t = hl::FuncOp;
using op_t = hl::ReturnOp;
using base = parser_conversion_pattern_base< op_t >;
using base::base;

using adaptor_t = typename op_t::Adaptor;

static std::optional< function_model > get_model(
const function_models &models, op_t op
) {
if (auto kv = models.find(op.getSymName()); kv != models.end()) {
return kv->second;
}
logical_result matchAndRewrite(
op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
) const override {
auto func = op->getParentOfType< hl::FuncOp >();
auto model = get_model(func.getSymName());

return std::nullopt;
auto rty = model
? model->get_return_type(rewriter.getContext())
: pr::MaybeDataType::get(rewriter.getContext());

rewriter.replaceOpWithNewOp< op_t >(
op, convert_value_types(adaptor.getOperands(), rty, rewriter)
);

return mlir::success();
}

std::optional< function_model > get_model(op_t op) const {
return get_model(models, op);
static void legalize(parser_conversion_config &cfg) {
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) {
for (auto ty : op.getResult().getType()) {
if (!is_parser_type(ty)) {
return false;
}
}
return true;
});
}
};

struct FuncConversion
: parser_conversion_pattern_base< hl::FuncOp >
, tc::op_type_conversion< hl::FuncOp, function_type_converter >
{
using op_t = hl::FuncOp;
using base = parser_conversion_pattern_base< op_t >;
using base::base;

using adaptor_t = typename op_t::Adaptor;


logical_result matchAndRewrite(
op_t op, adaptor_t adaptor, conversion_rewriter &rewriter
) const override {
auto tc = function_type_converter(*rewriter.getContext(), get_model(op));
auto tc = function_type_converter(*rewriter.getContext(), get_model(op.getSymName()));
if (auto func_op = mlir::dyn_cast< core::function_op_interface >(op.getOperation())) {
return this->replace(func_op, rewriter, tc);
}
Expand All @@ -447,7 +487,7 @@ namespace vast::conv {
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
return function_type_converter(
*op.getContext(), get_model(models, op)
*op.getContext(), get_model(models, op.getSymName())
).isLegal(op.getFunctionType());
});
}
Expand All @@ -459,6 +499,7 @@ namespace vast::conv {
FuncConversion,
ParamConversion,
DeclRefConversion,
ReturnConversion,
CallConversion
>;

Expand Down

0 comments on commit c6510d6

Please sign in to comment.