Skip to content

Commit

Permalink
added set_bypass() call (#3116)
Browse files Browse the repository at this point in the history
  • Loading branch information
bpickrel authored May 28, 2024
1 parent e182a61 commit 78f55ed
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct fused_reduce
const auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported");
if(not sm->bypass())
MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
auto names = sm->get_parameter_names();
check_shapes{inputs, *this}.has(names.size()).same_ndims();
std::sort(names.begin(), names.end());
Expand Down Expand Up @@ -426,6 +428,7 @@ struct reduce_reshape : rewrite_reshapes_base
auto dims = base_dims(inputs);
auto* oldm = ins->module_inputs().front();
auto* sm = mpm.create_module(oldm->name() + "_reshape");
sm->set_bypass();
insert_module_in_submodule(sm, inputs, oldm, transform_op([&](const operation& sop) {
if(contains(sop.name(), "reduce"))
return make_op(sop.name(), {{"axes", axes}});
Expand Down
12 changes: 12 additions & 0 deletions test/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ TEST_CASE(scalar_multibroadcast)
reduce_mod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), x0);
reduce_mod->add_return({sqrtbc});

EXPECT(test::throws([&] {
mm->add_instruction(
migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod});
}));
// reduce modules must be flagged for bypass when running subsequent passes
reduce_mod->set_bypass();
auto bip = mm->add_instruction(
migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod});
mm->add_return({bip});
Expand Down Expand Up @@ -217,6 +223,12 @@ TEST_CASE(scalar_multibroadcast_contiguous)
reduce_mod->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), x0);
reduce_mod->add_return({sqrtbc});

EXPECT(test::throws([&] {
mm->add_instruction(
migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod});
}));
// reduce modules must be flagged for bypass when running subsequent passes
reduce_mod->set_bypass();
auto bip = mm->add_instruction(
migraphx::make_op("fused_reduce", {{"axes", {1, 2}}}), {pow}, {reduce_mod});
mm->add_return({bip});
Expand Down

0 comments on commit 78f55ed

Please sign in to comment.