Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fuse transposes in pointwise and reduce fusions #3705

Merged
merged 42 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
844bbfe
Add methods to shape transform descriptor to get the common shape
pfultz2 Oct 12, 2024
21c5fd0
Format
pfultz2 Oct 12, 2024
6cf67af
Update rewrite_reshapes
pfultz2 Oct 13, 2024
d556255
Format
pfultz2 Oct 13, 2024
983cd31
Merge branch 'develop' into rewrite_reshapes-shape-transform
pfultz2 Oct 21, 2024
083f22b
Handle empty dims
pfultz2 Oct 22, 2024
5059a6c
Fix dims from dst
pfultz2 Oct 22, 2024
b209a00
Format
pfultz2 Oct 22, 2024
1e7e481
Add contiguous
pfultz2 Oct 22, 2024
985f342
Format
pfultz2 Oct 22, 2024
93dec17
Fix fuse_pointwise test
pfultz2 Oct 23, 2024
bf65516
Format
pfultz2 Oct 23, 2024
e260aca
Add transpose test
pfultz2 Oct 26, 2024
1588f26
Format
pfultz2 Oct 26, 2024
78c542d
Handle broadcasts
pfultz2 Oct 29, 2024
4d35f19
Format
pfultz2 Oct 29, 2024
98f6866
Try to use broadcast in shape transform
pfultz2 Oct 29, 2024
d26ab72
Format
pfultz2 Oct 29, 2024
ea93faf
Handle rebase
pfultz2 Oct 29, 2024
1e2856a
Format
pfultz2 Oct 29, 2024
edbe407
Check for only broadcast
pfultz2 Oct 29, 2024
3ec0846
Format
pfultz2 Oct 29, 2024
7a8db2d
Merge branch 'develop' into rewrite_reshapes-shape-transform
pfultz2 Nov 1, 2024
43005c5
Fix sizes
pfultz2 Nov 1, 2024
7283cbf
Format
pfultz2 Nov 1, 2024
7693133
Update test
pfultz2 Nov 2, 2024
15644a9
Remove line
pfultz2 Nov 2, 2024
b24adc0
Merge
pfultz2 Dec 4, 2024
f171e4f
Use origin_axis
pfultz2 Dec 4, 2024
fd4299e
Reuse create function
pfultz2 Dec 4, 2024
3d5515e
Add default arguments
pfultz2 Dec 4, 2024
0a5776d
Add unit tests
pfultz2 Dec 5, 2024
a7f87a8
Add more tests
pfultz2 Dec 10, 2024
4560be9
Fix rebase
pfultz2 Dec 11, 2024
73c5697
Add unit tests for rebase
pfultz2 Dec 11, 2024
df9a9be
Reuse code
pfultz2 Dec 11, 2024
4767d16
Use expose/hide API
pfultz2 Dec 11, 2024
5e0ee64
Format
pfultz2 Dec 11, 2024
aadc02e
Merge branch 'develop' into rewrite_reshapes-shape-transform
pfultz2 Dec 11, 2024
935e3eb
Cleanup
pfultz2 Dec 11, 2024
a502b67
Fix tidy issues
pfultz2 Dec 12, 2024
e1a768f
Merge branch 'develop' into rewrite_reshapes-shape-transform
pfultz2 Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 60 additions & 41 deletions src/include/migraphx/rewrite_reshapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/shape_transform_descriptor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -72,18 +73,19 @@ struct rewrite_reshapes

auto matcher() const
{
auto reshape =
match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once());
auto skip_contiguous_broadcast =
match::skip(match::name("contiguous", "multibroadcast")(match::used_once()));
auto skip_contiguous_broadcast_arg = [&](auto... ms) {
return match::arg(0)(skip_contiguous_broadcast(ms...));
};
auto reshapes = match::name("reshape",
"squeeze",
"unsqueeze",
"flatten",
"transpose",
"contiguous",
"multibroadcast",
"broadcast")(match::used_once());
auto pointwise = match::name(op1)(match::used_once());
auto reshape_pointwise =
reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape");
return match::name(op2)(match::any_of[match::inputs()](
skip_contiguous_broadcast(reshape_pointwise).bind("input")));
auto reshapes_pointwise =
reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x"))));
return match::name(op2)(
match::any_of[match::inputs()](reshapes_pointwise.bind("input")));
}

template <class F>
Expand All @@ -100,6 +102,12 @@ struct rewrite_reshapes
return last;
}

template <class F>
static bool any_input_of(instruction_ref start, instruction_ref last, F f)
{
return find_input_if(start, last, f) != last;
}

static bool match_input(instruction_ref ins, instruction_ref x_ins)
{
if(ins->inputs().empty())
Expand All @@ -120,65 +128,76 @@ struct rewrite_reshapes
return result;
}

static bool is_broadcast(instruction_ref ins) { return ins->name() == "multibroadcast"; }

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto input_ins = r.instructions["input"];

const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins);
const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins);
if(not has_broadcast_before_reshape.has_value())
return;
if(not has_broadcast_after_reshape.has_value())
// If it's just a broadcast then skip
if(not any_input_of(input_ins, x_ins, [](instruction_ref x) {
return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name());
}))
return;
if(*has_broadcast_after_reshape and *has_broadcast_before_reshape)
return;
const bool has_broadcast =
*has_broadcast_after_reshape or *has_broadcast_before_reshape;

auto dims1 = T::base_dims(ins);
auto dims2 = T::base_dims(x_ins);

if(elements(dims1) != elements(dims2))
return;

auto cd = common_dims::compute(T::base_dims(ins), T::base_dims(x_ins));
if(cd.dims.empty())
return;
std::vector<operation> ops;
auto next_ins = input_ins;
while(next_ins != x_ins)
{
ops.push_back(next_ins->get_operator());
next_ins = next_ins->inputs().front();
}
assert(next_ins == x_ins);
std::reverse(ops.begin(), ops.end());

if(ins->name() != "pointwise" and not T::supports(ins, cd.dims, cd.axes_map1))
auto desc =
shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2);
if(desc.empty())
return;
if(x_ins->name() != "pointwise" and not T::supports(x_ins, cd.dims, cd.axes_map2))
return;

auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
auto dims = cd.get_dimensions_for(input->get_shape().lens());
return mpm.get_module().insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", dims}}), input);
auto cdims = desc.common_dims();
auto reshape_input = [&](const auto& ins_to_insert, auto generate) {
return [&, generate](auto input) {
auto gops = std::invoke(generate, desc, input->get_shape().lens());
auto start = input;
for(const auto& op : gops)
{
start = mpm.get_module().insert_instruction(ins_to_insert, op, start);
}
return start;
};
};
auto x_inputs = x_ins->inputs();
std::transform(
x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins));
auto new_x_ins = insert(mpm, x_ins, x_inputs, cd.axes_map2);
if(has_broadcast)
x_inputs.begin(),
x_inputs.end(),
x_inputs.begin(),
reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src));
auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src());
if(new_x_ins->get_shape().lens() != cdims)
{
new_x_ins = mpm.get_module().insert_instruction(
x_ins, make_op("multibroadcast", {{"out_lens", cd.dims}}), new_x_ins);
x_ins, make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins);
}

auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input == input_ins)
return new_x_ins;
return reshape_input(ins)(input);
return reshape_input(ins,
&shape_transform_descriptor::generate_common_from_dst)(input);
});
auto pw = insert(mpm, ins, inputs, cd.axes_map1);
mpm.get_module().replace_instruction(
ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw);
auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst());
auto rins =
reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw);
mpm.get_module().replace_instruction(ins, rins);
}

static bool same_dims(instruction_ref ins)
Expand Down
24 changes: 24 additions & 0 deletions src/include/migraphx/shape_transform_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
shape_transform_descriptor() = default;
explicit shape_transform_descriptor(const std::vector<std::size_t>& dims);

static shape_transform_descriptor create(const std::vector<std::size_t>& dims,
const std::vector<operation>& ops);

shape_transform_descriptor rebase(const std::vector<std::size_t>& dims) const;

bool apply(const std::vector<operation>& ops);
bool apply_reshape(const std::vector<std::size_t>& rdims);
bool apply_reshape_impl(const std::vector<std::size_t>& rdims);
Expand All @@ -84,6 +89,22 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
std::size_t elements() const;
std::vector<operation> generate() const;

bool has_broadcast() const;
void flatten_broadcast();

std::vector<std::size_t> common_dims(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_common_from_src(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_common_from_dst(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_dst_from_common(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<std::vector<std::size_t>> common_axes_map_from_src() const;
std::vector<std::vector<std::size_t>> common_axes_map_from_dst() const;

bool empty() const;
std::vector<std::size_t> lens() const;

struct MIGRAPHX_EXPORT dimension
{
void simplify();
Expand All @@ -105,6 +126,9 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor

void add_split_axis(std::size_t i);

void expose();
void hide();

MIGRAPHX_EXPORT friend bool operator==(const sub& x, const sub& y);
MIGRAPHX_EXPORT friend bool operator!=(const sub& x, const sub& y);
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const sub& x);
Expand Down
Loading
Loading