-
Notifications
You must be signed in to change notification settings - Fork 88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
reorder_slice_add_mul matcher #3478
base: develop
Are you sure you want to change the base?
Changes from 33 commits
709daf0
94c624b
8304922
2e75e28
cd6d0bc
bc51516
f8a4505
b0edd34
69457bd
ba4ed3b
6a99db9
e17bb63
3a67dab
dd94dbe
4dd052d
7d7d5d4
4e74cf4
5f90a6f
95b31f7
4cd618a
ab5282b
f49a289
48b1f67
cf62a4e
1101f35
be20e61
63b9737
35c9441
095192f
aff69c3
d1fe6a1
4a787b3
4aaa2b9
a979c11
4e11892
7143f53
e9b8c98
e4ec31d
f8b4c05
ffec7d8
0d5a6d5
8f4b2b3
4270a27
8788ba2
7e28351
9a31825
2a10591
696b842
cec662f
fa12fcc
b05e0ba
38b51e1
7c30359
271bd2b
293dd51
cef6880
f7f0996
892a26d
e8e46d1
f478e04
e85f850
ec4cb3e
318e36a
6be46f4
ad1bb34
9804425
7ebf297
b7f3037
256456a
9eceb9f
6a50c51
9b00fd8
35a20e0
58a4a0f
ce65be9
3d070cb
75da7c3
28eb64d
c59d2fa
0eeb663
820bb12
3950cd3
0d99d23
8bd694e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -412,24 +412,83 @@ | |
} | ||
}; | ||
|
||
|
||
|
||
auto fusable_split() | ||
{ | ||
return match::make_basic_pred_matcher([](instruction_ref ins) { | ||
|
||
std::vector<instruction_ref> slices; | ||
|
||
for (const auto& output : ins->outputs()) | ||
{ | ||
if (output->name() == "slice") | ||
{ | ||
slices.push_back(output); | ||
} | ||
} | ||
|
||
if (slices.empty()) | ||
return false; | ||
|
||
std::vector<instruction_ref> add_instructions; | ||
for (const auto& slice : slices) | ||
{ | ||
bool used_by_add = false; | ||
for (auto& user : slice->outputs()) | ||
{ | ||
if (user->name() == "add") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dont hardcode |
||
{ | ||
used_by_add = true; | ||
add_instructions.push_back(user); | ||
} | ||
} | ||
if (!used_by_add) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This flag isn't necessary. Please just check for add_instructions.empty() later. |
||
return false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should return true as we want to skip the outputs that are not slices. |
||
} | ||
|
||
bool any_add_followed_by_mul = false; | ||
for (const auto& add_ins : add_instructions) | ||
{ | ||
for (auto& user : add_ins->outputs()) | ||
{ | ||
if (user->name() == "mul") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dont check for |
||
{ | ||
any_add_followed_by_mul = true; | ||
ins->outputs(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does this line, without any assignment do? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It does nothing. We should probably add |
||
} | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All this is really hard to follow and the intermediate vectors are really unnecessary. This can be done with nested all_of+any_of algorithms: auto fusable_split(const std::string& name)
{
return match::make_basic_pred_matcher([&](instruction_ref ins) {
return all_of(ins->outputs(), [&](instruction_ref slice) {
if(output->name() != "slice")
return true;
return any_of(slice->outputs(), [&](instruction_ref x) {
return x->name() == name;
});
});
});
} You still might need an extra |
||
|
||
if (!any_add_followed_by_mul) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of the if/else clause just |
||
return false; | ||
|
||
return true; | ||
}); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These matchers should not be moved out into seperate functions. They should be contained in the find class that uses them. |
||
|
||
|
||
// ****************************** | ||
// a * (x + b) => a * x + a * b | ||
// ****************************** | ||
// When a * (x + b) is followed by another add of constant, then the | ||
// additional add can be const folded. Also, better fusions can be applied | ||
// when the add comes after. | ||
|
||
struct find_mul_add | ||
{ | ||
auto matcher() const | ||
{ | ||
return match::name("mul")(match::either_arg(0, 1)( | ||
match::name("add")( | ||
match::either_arg(0, 1)( | ||
match::any().bind("x"), | ||
match::any_of(conv_const_weights(), match::is_constant()).bind("b")), | ||
match::none_of(match::args(match::is_constant(), match::is_constant())), | ||
match::used_once()), | ||
match::is_constant().bind("a"))); | ||
match::name("add")( | ||
match::either_arg(0, 1)( | ||
match::none_of(match::name("slice")(match::arg(0)( | ||
fusable_split().bind("slc") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do bind this? Its not used. It should be removed. |
||
))).bind("x"), | ||
match::any_of(conv_const_weights(), match::is_constant()).bind("b")), | ||
match::none_of(match::args(match::is_constant(), match::is_constant())), | ||
match::used_once()), | ||
match::is_constant().bind("a"))); | ||
aarushjain29 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
void apply(module& m, const match::matcher_result& r) const | ||
|
@@ -439,13 +498,48 @@ | |
auto b_ins = r.instructions["b"]; | ||
auto x_ins = r.instructions["x"]; | ||
assert(x_ins != b_ins); | ||
|
||
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins); | ||
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins); | ||
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins); | ||
|
||
} | ||
}; | ||
|
||
|
||
struct find_slice_add_mul | ||
aarushjain29 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
|
||
auto matcher() const | ||
{ | ||
return match::name("mul")(match::either_arg(0, 1)( | ||
match::name("add")( | ||
match::either_arg(0, 1)( | ||
match::name("slice")(match::arg(0)( | ||
fusable_split().bind("slc") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here. |
||
)).bind("x"), | ||
match::any_of(conv_const_weights(), match::is_constant()).bind("b")), | ||
match::none_of(match::args(match::is_constant(), match::is_constant())), | ||
match::used_once()), | ||
match::is_constant().bind("a"))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please break up this matcher using variables like above in the |
||
} | ||
void apply(module& m, const match::matcher_result& r) const | ||
{ | ||
|
||
auto ins = r.result; | ||
auto a_ins = r.instructions["a"]; | ||
auto b_ins = r.instructions["b"]; | ||
auto x_ins = r.instructions["x"]; | ||
|
||
assert(x_ins != b_ins); | ||
|
||
auto ax_ins = m.insert_instruction(ins, make_op("add"), x_ins, b_ins); | ||
m.replace_instruction(ins, make_op("mul"), ax_ins, a_ins); | ||
|
||
} | ||
}; | ||
|
||
|
||
struct find_dot_add | ||
{ | ||
auto matcher() const | ||
|
@@ -1569,6 +1663,7 @@ | |
|
||
auto outputs = ins->outputs(); | ||
group_by(outputs.begin(), outputs.end(), each, pred); | ||
|
||
} | ||
}; | ||
|
||
|
@@ -1981,6 +2076,7 @@ | |
find_dot_slice{}, | ||
find_dot_mul{}, | ||
find_mul_add{}, | ||
find_slice_add_mul{}, | ||
find_unit_ops{}, | ||
find_neg_unit_ops{}, | ||
eliminate_zero_point{}, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -509,6 +509,7 @@ | |
EXPECT(m1 == m2); | ||
} | ||
|
||
|
||
TEST_CASE(simplify_conv_add) | ||
{ | ||
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}}; | ||
|
@@ -4133,6 +4134,46 @@ | |
EXPECT(m1.sort() == m2.sort()); | ||
} | ||
|
||
TEST_CASE(my_optim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test case will test that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use better names for the test case names. |
||
{ | ||
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}}; | ||
|
||
migraphx::module m2; | ||
{ | ||
auto a = m2.add_parameter("b", as); | ||
|
||
auto slice_a = m2.add_instruction( | ||
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {768}}}), a); | ||
auto slice_b = m2.add_instruction( | ||
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {768}}, {"ends", {1536}}}), a); | ||
auto slice_c = m2.add_instruction( | ||
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1536}}, {"ends", {2304}}}), a); | ||
|
||
auto one = m2.add_literal( | ||
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f)); | ||
auto two = m2.add_literal( | ||
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); | ||
auto three = m2.add_literal( | ||
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f)); | ||
auto four = m2.add_literal( | ||
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f)); | ||
|
||
|
||
auto add1 = m2.add_instruction(migraphx::make_op("add"), one, slice_a); | ||
auto mul1 = m2.add_instruction(migraphx::make_op("mul"), add1, four); | ||
|
||
auto add2 = m2.add_instruction(migraphx::make_op("add"), two, slice_b); | ||
Check warning on line 4165 in test/simplify_algebra_test.cpp GitHub Actions / tidy
|
||
auto add3 = m2.add_instruction(migraphx::make_op("add"), three, slice_c); | ||
Check warning on line 4166 in test/simplify_algebra_test.cpp GitHub Actions / tidy
|
||
|
||
m2.add_return({mul1}); | ||
}; | ||
migraphx::module m1 = m2; | ||
run_pass(m2); | ||
|
||
EXPECT(m1.sort() == m2.sort()); | ||
|
||
} | ||
aarushjain29 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
TEST_CASE(dot_slice_ab) | ||
{ | ||
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; | ||
|
@@ -4195,6 +4236,61 @@ | |
EXPECT(m1.sort() == m2.sort()); | ||
} | ||
|
||
|
||
TEST_CASE(complex_graph_operations) | ||
aarushjain29 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
migraphx::module m; | ||
|
||
auto x_0 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 0)); | ||
auto x_1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 1)); | ||
auto x_2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 2)); | ||
auto x_3 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 3)); | ||
auto x_4 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 4)); | ||
auto x_5 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 5)); | ||
auto x_6 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 6)); | ||
auto x_7 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 7)); | ||
auto x_8 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 8)); | ||
auto x_9 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 9)); | ||
auto x_10 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 10)); | ||
auto x_11 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 11)); | ||
auto x_12 = m.add_literal(migraphx::literal({migraphx::shape::float_type, {1}}, {0.125})); | ||
|
||
|
||
auto p_x = m.add_parameter("x", {migraphx::shape::float_type, {64, 64}}); | ||
|
||
|
||
auto x_14 = m.add_instruction( | ||
migraphx::make_op("multibroadcast", {{"out_lens", {64, 512}}}), x_12); | ||
|
||
|
||
auto x_15 = m.add_instruction(migraphx::make_op("dot"), p_x, x_11); | ||
auto x_16 = m.add_instruction(migraphx::make_op("add"), x_15, x_10); | ||
auto x_17 = m.add_instruction(migraphx::make_op("add"), x_16, x_9); | ||
auto x_18 = m.add_instruction(migraphx::make_op("add"), x_17, x_8); | ||
|
||
auto x_19 = m.add_instruction(migraphx::make_op("dot"), p_x, x_7); | ||
auto x_20 = m.add_instruction(migraphx::make_op("add"), x_19, x_6); | ||
auto x_21 = m.add_instruction(migraphx::make_op("add"), x_20, x_5); | ||
auto x_22 = m.add_instruction(migraphx::make_op("add"), x_21, x_4); | ||
|
||
auto x_23 = m.add_instruction(migraphx::make_op("dot"), p_x, x_3); | ||
auto x_24 = m.add_instruction(migraphx::make_op("add"), x_23, x_2); | ||
auto x_25 = m.add_instruction(migraphx::make_op("add"), x_24, x_1); | ||
auto x_26 = m.add_instruction(migraphx::make_op("add"), x_25, x_0); | ||
|
||
auto x_27 = m.add_instruction(migraphx::make_op("mul"), x_26, x_14); | ||
|
||
m.add_return({x_18, x_22, x_27}); | ||
|
||
run_pass(m); | ||
|
||
EXPECT(m.get_output_shapes().size() == 3); | ||
aarushjain29 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
} | ||
|
||
|
||
|
||
TEST_CASE(dot_slice_not_applicable_1) | ||
{ | ||
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be called
fusable_slice_add_mul_split
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes with the current implementation that is true, but we should add a parameter to take the operator name, and I dont see a reason to check for
mul
.