reorder_slice_add_mul matcher #3478

Open
wants to merge 84 commits into
base: develop
Changes from 33 commits
709daf0
find_scalar_mul_conv added
aarushjain29 Sep 25, 2024
94c624b
Formatting
aarushjain29 Sep 25, 2024
8304922
Support for scalar in simplify_reshape
aarushjain29 Sep 26, 2024
2e75e28
check for scalar in find_mul_add
aarushjain29 Sep 26, 2024
cd6d0bc
Changes in simplify reshape
aarushjain29 Sep 26, 2024
bc51516
check if output of mul_add is fed in conv
aarushjain29 Sep 26, 2024
f8a4505
Fix bugs in conv
aarushjain29 Sep 26, 2024
b0edd34
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
kahmed10 Oct 2, 2024
69457bd
Logic improved
aarushijai Oct 3, 2024
ba4ed3b
Revert "Logic improved"
aarushjain29 Oct 23, 2024
6a99db9
slicing added
aarushjain29 Oct 30, 2024
e17bb63
removed conv changes
aarushjain29 Oct 30, 2024
3a67dab
removed old code
aarushjain29 Oct 30, 2024
dd94dbe
Merge remote-tracking branch 'origin/develop' into 3432-improve-simpl…
aarushjain29 Oct 30, 2024
4dd052d
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Oct 30, 2024
7d7d5d4
Merge branch '3432-improve-simplify_algebra-to-find-more-horizontal-f…
aarushjain29 Oct 30, 2024
4e74cf4
Blocked any input operation for a
aarushjain29 Oct 30, 2024
5f90a6f
merge slice add
aarushjain29 Nov 8, 2024
95b31f7
slice mul add producing mlir_slice_mul_reshape
aarushjain29 Nov 8, 2024
4cd618a
Formatting
aarushjain29 Nov 8, 2024
ab5282b
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
f49a289
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
48b1f67
Working fusable_split for slice_add_mul
aarushjain29 Nov 13, 2024
cf62a4e
fusable split for find_mul_add matcher
aarushjain29 Nov 18, 2024
1101f35
formatting
aarushjain29 Nov 18, 2024
be20e61
corrected test case
aarushjain29 Nov 19, 2024
63b9737
corrected test case
aarushjain29 Nov 19, 2024
35c9441
formatting
aarushjain29 Nov 19, 2024
095192f
formatting
aarushjain29 Nov 19, 2024
aff69c3
Formatting errors
aarushjain29 Nov 19, 2024
d1fe6a1
Formatting errors
aarushjain29 Nov 19, 2024
4a787b3
Format
aarushjain29 Nov 19, 2024
4aaa2b9
format
aarushjain29 Nov 19, 2024
a979c11
skip_mul_add_for_horizontal_add_fusion
aarushjain29 Nov 20, 2024
4e11892
change fusable split
aarushjain29 Nov 21, 2024
7143f53
fusable_slice_add_mul_split
aarushjain29 Nov 22, 2024
e9b8c98
working slice_add_mul
aarushjain29 Nov 22, 2024
e4ec31d
refactor
aarushjain29 Nov 22, 2024
f8b4c05
refactor
aarushjain29 Nov 22, 2024
ffec7d8
Refactor
aarushjain29 Nov 22, 2024
0d5a6d5
skip the slice
aarushjain29 Nov 22, 2024
8f4b2b3
formatting
aarushjain29 Nov 22, 2024
4270a27
Formatting
aarushjain29 Nov 22, 2024
8788ba2
description for slice_add_mul
aarushjain29 Nov 22, 2024
7e28351
Removed 1 of the test which is not required
aarushjain29 Nov 25, 2024
9a31825
Changed some code by mistake
aarushjain29 Nov 25, 2024
2a10591
2 more test case added
aarushjain29 Nov 26, 2024
696b842
Formatting
aarushjain29 Nov 26, 2024
cec662f
test cases for clip model
aarushjain29 Nov 26, 2024
fa12fcc
add_dot_add_mul
aarushjain29 Nov 26, 2024
b05e0ba
formatting
aarushjain29 Nov 26, 2024
38b51e1
changed the name to slice_add
aarushjain29 Nov 26, 2024
7c30359
Formatting
aarushjain29 Nov 26, 2024
271bd2b
Formatting
aarushjain29 Nov 26, 2024
293dd51
change in mul_add struct - none_of added
aarushjain29 Nov 26, 2024
cef6880
changes in fusable_split
aarushjain29 Nov 27, 2024
f7f0996
Changed slice_mul_add to merge slice_mul
aarushijai Nov 27, 2024
892a26d
Revert "Changed slice_mul_add to merge slice_mul"
aarushijai Nov 27, 2024
e8e46d1
Changed slice_add_mul test case to fuse mul
aarushjain29 Nov 27, 2024
f478e04
test case for add_dot_add for each branch added
aarushjain29 Nov 27, 2024
e85f850
loop for literals
aarushjain29 Nov 27, 2024
ec4cb3e
removed older test case dot_add_mul
aarushjain29 Nov 27, 2024
318e36a
Added test for dot_add_mul with loop
aarushjain29 Nov 27, 2024
6be46f4
loop changes in add_dot_add_mul_2
aarushjain29 Nov 27, 2024
ad1bb34
seed values changed
aarushjain29 Nov 27, 2024
9804425
loop added for add_dot_add_mul_1
aarushjain29 Nov 27, 2024
7ebf297
shape s1 and s2
aarushjain29 Nov 27, 2024
b7f3037
CI Formatting
aarushjain29 Dec 2, 2024
256456a
CI Formatting
aarushjain29 Dec 2, 2024
9eceb9f
removing tidy errors
aarushjain29 Dec 2, 2024
6a50c51
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
9b00fd8
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
35a20e0
semicolon missing added
aarushjain29 Dec 2, 2024
58a4a0f
Removing tidy errors CI
aarushjain29 Dec 2, 2024
ce65be9
semicolon missing added
aarushjain29 Dec 2, 2024
3d070cb
Formatting
aarushjain29 Dec 2, 2024
75da7c3
Formatting
aarushjain29 Dec 2, 2024
28eb64d
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
c59d2fa
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
0eeb663
WIP change in add_dot_add_mul_2
aarushjain29 Dec 3, 2024
820bb12
loops added
aarushjain29 Dec 3, 2024
3950cd3
Formatting
aarushjain29 Dec 3, 2024
0d99d23
Formatting
aarushjain29 Dec 4, 2024
8bd694e
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Dec 30, 2024
112 changes: 104 additions & 8 deletions src/simplify_algebra.cpp
@@ -412,24 +412,83 @@
}
};



auto fusable_split()
Contributor

Should this be called fusable_slice_add_mul_split?

Collaborator

> Should this be called fusable_slice_add_mul_split?

Yes, with the current implementation that is true, but we should add a parameter to take the operator name, and I don't see a reason to check for mul.

{
return match::make_basic_pred_matcher([](instruction_ref ins) {

std::vector<instruction_ref> slices;

for (const auto& output : ins->outputs())
{
if (output->name() == "slice")
{
slices.push_back(output);

Check warning on line 427 in src/simplify_algebra.cpp (GitHub Actions / cppcheck): style: Consider using std::copy_if algorithm instead of a raw loop. [useStlAlgorithm]
}
}

if (slices.empty())
return false;

std::vector<instruction_ref> add_instructions;
for (const auto& slice : slices)
{
bool used_by_add = false;
for (auto& user : slice->outputs())

Check warning on line 438 in src/simplify_algebra.cpp (GitHub Actions / tidy): 'auto &user' can be declared as 'const auto &user' [readability-qualified-auto,-warnings-as-errors]
{
if (user->name() == "add")
Collaborator

Don't hardcode add here. Add a parameter to the function for the operator name.

{
used_by_add = true;
add_instructions.push_back(user);
}
}
if (!used_by_add)

Check warning on line 446 in src/simplify_algebra.cpp (GitHub Actions / cppcheck): style: Use 'not' instead of ! [UseNamedLogicOperator]
Contributor

This flag isn't necessary. Please just check for add_instructions.empty() later.

return false;
Collaborator

This should return true as we want to skip the outputs that are not slices.

}

bool any_add_followed_by_mul = false;
for (const auto& add_ins : add_instructions)
{
for (auto& user : add_ins->outputs())

Check warning on line 453 in src/simplify_algebra.cpp (GitHub Actions / tidy): 'auto &user' can be declared as 'const auto &user' [readability-qualified-auto,-warnings-as-errors]
{
if (user->name() == "mul")
Collaborator

Don't check for mul; that's outside the scope of this predicate.

{
any_add_followed_by_mul = true;
ins->outputs();
Contributor

What does this line, without any assignment, do?

Collaborator

> What does this line, without any assignment, do?

It does nothing. We should probably add [[nodiscard]] to those functions (in another PR) so we get a warning for this.
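
For illustration only, here is a rough sketch of what the [[nodiscard]] suggestion would buy. The `instruction_sketch` class and `count_outputs` helper are hypothetical stand-ins invented for this example, not the MIGraphX instruction API:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for an instruction; not the MIGraphX class.
class instruction_sketch
{
    public:
    explicit instruction_sketch(std::vector<int> outs) : output_ids(std::move(outs)) {}

    // [[nodiscard]] (C++17) asks the compiler to diagnose calls whose result is ignored.
    [[nodiscard]] const std::vector<int>& outputs() const { return output_ids; }

    private:
    std::vector<int> output_ids;
};

// Hypothetical helper that actually uses the returned value.
inline std::size_t count_outputs(const instruction_sketch& ins)
{
    // ins.outputs();  // a bare call like this would typically produce a warning
    return ins.outputs().size();
}
```

With the attribute in place, a stray statement such as the `ins->outputs();` line in this diff would most likely be flagged at compile time instead of in review.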

}
}
}
Collaborator

All of this is really hard to follow, and the intermediate vectors are unnecessary. This can be done with nested all_of + any_of algorithms:

auto fusable_split(const std::string& name)
{
    // Capture by value: the returned matcher outlives this call.
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        return all_of(ins->outputs(), [&](instruction_ref slice) {
            if(slice->name() != "slice")
                return true;
            return any_of(slice->outputs(), [&](instruction_ref x) {
                return x->name() == name;
            });
        });
    });
}

You still might need an extra any_of check for the case when there are no slices.
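
To make the shape of that predicate concrete, here is a minimal standalone sketch using only the standard library. The `node` struct and `fusable_split_sketch` function are hypothetical stand-ins for the IR and the matcher, and the leading any_of guard is one possible way to handle the no-slices case mentioned above:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR instruction; not the MIGraphX type.
struct node
{
    std::string name;
    std::vector<const node*> outputs;
};

// True when `ins` feeds at least one slice and every slice output of `ins`
// is consumed by at least one instruction named `op_name`.
// Outputs that are not slices are skipped rather than rejected.
inline bool fusable_split_sketch(const node& ins, const std::string& op_name)
{
    const auto is_slice = [](const node* out) { return out->name == "slice"; };
    // Guard: all_of alone is vacuously true when there are no slices.
    if(not std::any_of(ins.outputs.begin(), ins.outputs.end(), is_slice))
        return false;
    return std::all_of(ins.outputs.begin(), ins.outputs.end(), [&](const node* out) {
        if(not is_slice(out))
            return true;
        return std::any_of(out->outputs.begin(), out->outputs.end(), [&](const node* user) {
            return user->name == op_name;
        });
    });
}
```

Passing the operator name as a parameter keeps the predicate reusable for other ops, which is what the earlier comments about not hardcoding add and mul ask for.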


if (!any_add_followed_by_mul)

Check warning on line 463 in src/simplify_algebra.cpp (GitHub Actions / cppcheck): style: Use 'not' instead of ! [UseNamedLogicOperator]
Contributor

Instead of the if/else clause, just returning any_add_followed_by_mul might be okay.

return false;

Check warning on line 464 in src/simplify_algebra.cpp (GitHub Actions / tidy): redundant boolean literal in conditional return statement [readability-simplify-boolean-expr,-warnings-as-errors]

return true;
});
}
Collaborator

@pfultz2 Nov 22, 2024

These matchers should not be moved out into separate functions. They should be contained in the find class that uses them.



// ******************************
// a * (x + b) => a * x + a * b
// ******************************
// When a * (x + b) is followed by another add of constant, then the
// additional add can be const folded. Also, better fusions can be applied
// when the add comes after.

struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
match::name("add")(
match::either_arg(0, 1)(
match::none_of(match::name("slice")(match::arg(0)(
fusable_split().bind("slc")

Check warning on line 486 in src/simplify_algebra.cpp (GitHub Actions / cppcheck): style: Too many nested parentheses can affect readability; consider using variables instead. [migraphx-MatcherNestedParentheses]
Collaborator

Why bind this? It's not used. It should be removed.

))).bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
aarushjain29 marked this conversation as resolved.
}

void apply(module& m, const match::matcher_result& r) const
@@ -439,13 +498,48 @@
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);

}
};
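
As a quick sanity check of the comment above find_mul_add, the following sketch works through why distributing the multiply lets a trailing constant add be folded. It uses integers to avoid floating-point rounding, and the function names `before` and `after` are made up for the example:

```cpp
#include <cstdint>

// Original form: a * (x + b) + c, where a, b and c are constants and x is not.
constexpr std::int64_t before(std::int64_t a, std::int64_t x, std::int64_t b, std::int64_t c)
{
    return a * (x + b) + c;
}

// After a * (x + b) => a * x + a * b, the terms a * b and c are both constant,
// so they can be folded into a single constant ahead of time.
constexpr std::int64_t after(std::int64_t a, std::int64_t x, std::int64_t b, std::int64_t c)
{
    const std::int64_t folded = a * b + c; // depends only on constants
    return a * x + folded;
}

static_assert(before(3, 5, 2, 7) == after(3, 5, 2, 7), "both forms agree");
```

Only one multiply and one add involving x remain at run time in the rewritten form, which is the motivation stated in the source comment.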


struct find_slice_add_mul
aarushjain29 marked this conversation as resolved.
{

auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::name("slice")(match::arg(0)(
fusable_split().bind("slc")
Collaborator

Same here.

)).bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
Collaborator

Please break up this matcher using variables like above in the find_dot_mul matcher.

}
void apply(module& m, const match::matcher_result& r) const
{

auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];

assert(x_ins != b_ins);

auto ax_ins = m.insert_instruction(ins, make_op("add"), x_ins, b_ins);
m.replace_instruction(ins, make_op("mul"), ax_ins, a_ins);

}
};


struct find_dot_add
{
auto matcher() const
@@ -1569,6 +1663,7 @@

auto outputs = ins->outputs();
group_by(outputs.begin(), outputs.end(), each, pred);

}
};

@@ -1981,6 +2076,7 @@
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_slice_add_mul{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
96 changes: 96 additions & 0 deletions test/simplify_algebra_test.cpp
@@ -509,6 +509,7 @@
EXPECT(m1 == m2);
}


TEST_CASE(simplify_conv_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
@@ -4133,6 +4134,46 @@
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(my_optim)
Collaborator

This test case checks that find_mul_add will not swap the mul and add. A good name for this test would be something like skip_mul_add_for_horizontal_add_fusion. Add another similar test case that checks that find_slice_add_mul will swap them if it finds the mul before the add.

Collaborator

Use better names for the test case names.

{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};

migraphx::module m2;
{
auto a = m2.add_parameter("b", as);

auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {1536}}, {"ends", {2304}}}), a);

auto one = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));
auto four = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));


auto add1 = m2.add_instruction(migraphx::make_op("add"), one, slice_a);
auto mul1 = m2.add_instruction(migraphx::make_op("mul"), add1, four);

auto add2 = m2.add_instruction(migraphx::make_op("add"), two, slice_b);

Check warning on line 4165 in test/simplify_algebra_test.cpp (GitHub Actions / tidy): Value stored to 'add2' during its initialization is never read [clang-analyzer-deadcode.DeadStores,-warnings-as-errors]

Check warning on line 4165 in test/simplify_algebra_test.cpp (GitHub Actions / tidy): unused variable 'add2' [clang-diagnostic-unused-variable,-warnings-as-errors]
auto add3 = m2.add_instruction(migraphx::make_op("add"), three, slice_c);

Check warning on line 4166 in test/simplify_algebra_test.cpp (GitHub Actions / tidy): Value stored to 'add3' during its initialization is never read [clang-analyzer-deadcode.DeadStores,-warnings-as-errors]

Check warning on line 4166 in test/simplify_algebra_test.cpp (GitHub Actions / tidy): unused variable 'add3' [clang-diagnostic-unused-variable,-warnings-as-errors]

m2.add_return({mul1});
};
migraphx::module m1 = m2;
run_pass(m2);

EXPECT(m1.sort() == m2.sort());

}
aarushjain29 marked this conversation as resolved.

TEST_CASE(dot_slice_ab)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
@@ -4195,6 +4236,61 @@
EXPECT(m1.sort() == m2.sort());
}


TEST_CASE(complex_graph_operations)
aarushjain29 marked this conversation as resolved.
{
migraphx::module m;

auto x_0 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 0));
auto x_1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 1));
auto x_2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 2));
auto x_3 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 3));
auto x_4 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 4));
auto x_5 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 5));
auto x_6 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 6));
auto x_7 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 7));
auto x_8 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 8));
auto x_9 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 9));
auto x_10 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 10));
auto x_11 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 11));
auto x_12 = m.add_literal(migraphx::literal({migraphx::shape::float_type, {1}}, {0.125}));


auto p_x = m.add_parameter("x", {migraphx::shape::float_type, {64, 64}});


auto x_14 = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {64, 512}}}), x_12);


auto x_15 = m.add_instruction(migraphx::make_op("dot"), p_x, x_11);
auto x_16 = m.add_instruction(migraphx::make_op("add"), x_15, x_10);
auto x_17 = m.add_instruction(migraphx::make_op("add"), x_16, x_9);
auto x_18 = m.add_instruction(migraphx::make_op("add"), x_17, x_8);

auto x_19 = m.add_instruction(migraphx::make_op("dot"), p_x, x_7);
auto x_20 = m.add_instruction(migraphx::make_op("add"), x_19, x_6);
auto x_21 = m.add_instruction(migraphx::make_op("add"), x_20, x_5);
auto x_22 = m.add_instruction(migraphx::make_op("add"), x_21, x_4);

auto x_23 = m.add_instruction(migraphx::make_op("dot"), p_x, x_3);
auto x_24 = m.add_instruction(migraphx::make_op("add"), x_23, x_2);
auto x_25 = m.add_instruction(migraphx::make_op("add"), x_24, x_1);
auto x_26 = m.add_instruction(migraphx::make_op("add"), x_25, x_0);

auto x_27 = m.add_instruction(migraphx::make_op("mul"), x_26, x_14);

m.add_return({x_18, x_22, x_27});

run_pass(m);

EXPECT(m.get_output_shapes().size() == 3);
aarushjain29 marked this conversation as resolved.


}



TEST_CASE(dot_slice_not_applicable_1)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};