Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reorder_slice_add_mul matcher #3478

Open
wants to merge 84 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
709daf0
find_scalar_mul_conv added
aarushjain29 Sep 25, 2024
94c624b
Formatting
aarushjain29 Sep 25, 2024
8304922
Support for scalar in simplify_reshape
aarushjain29 Sep 26, 2024
2e75e28
check for scalar in find_mul_add
aarushjain29 Sep 26, 2024
cd6d0bc
Changes in simplify reshape
aarushjain29 Sep 26, 2024
bc51516
check if output of mul_add is fed in conv
aarushjain29 Sep 26, 2024
f8a4505
Fix bugs in conv
aarushjain29 Sep 26, 2024
b0edd34
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
kahmed10 Oct 2, 2024
69457bd
Logic improved
aarushijai Oct 3, 2024
ba4ed3b
Revert "Logic improved"
aarushjain29 Oct 23, 2024
6a99db9
slicing added
aarushjain29 Oct 30, 2024
e17bb63
removed conv changes
aarushjain29 Oct 30, 2024
3a67dab
removed old code
aarushjain29 Oct 30, 2024
dd94dbe
Merge remote-tracking branch 'origin/develop' into 3432-improve-simpl…
aarushjain29 Oct 30, 2024
4dd052d
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Oct 30, 2024
7d7d5d4
Merge branch '3432-improve-simplify_algebra-to-find-more-horizontal-f…
aarushjain29 Oct 30, 2024
4e74cf4
Blocked any input operation for a
aarushjain29 Oct 30, 2024
5f90a6f
merge slice add
aarushjain29 Nov 8, 2024
95b31f7
slice mul add producing mlir_slice_mul_reshape
aarushjain29 Nov 8, 2024
4cd618a
Formatting
aarushjain29 Nov 8, 2024
ab5282b
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
f49a289
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
48b1f67
Working fusable_split for slice_add_mul
aarushjain29 Nov 13, 2024
cf62a4e
fusable split for find_mul_add matcher
aarushjain29 Nov 18, 2024
1101f35
formatting
aarushjain29 Nov 18, 2024
be20e61
corrected test case
aarushjain29 Nov 19, 2024
63b9737
corrected test case
aarushjain29 Nov 19, 2024
35c9441
formatting
aarushjain29 Nov 19, 2024
095192f
formatting
aarushjain29 Nov 19, 2024
aff69c3
Formatting errors
aarushjain29 Nov 19, 2024
d1fe6a1
Formatting errors
aarushjain29 Nov 19, 2024
4a787b3
Format
aarushjain29 Nov 19, 2024
4aaa2b9
format
aarushjain29 Nov 19, 2024
a979c11
skip_mul_add_for_horizontal_add_fusion
aarushjain29 Nov 20, 2024
4e11892
change fusable split
aarushjain29 Nov 21, 2024
7143f53
fusable_slice_add_mul_split
aarushjain29 Nov 22, 2024
e9b8c98
working slice_add_mul
aarushjain29 Nov 22, 2024
e4ec31d
refactor
aarushjain29 Nov 22, 2024
f8b4c05
refactor
aarushjain29 Nov 22, 2024
ffec7d8
Refactor
aarushjain29 Nov 22, 2024
0d5a6d5
skip the slice
aarushjain29 Nov 22, 2024
8f4b2b3
formatting
aarushjain29 Nov 22, 2024
4270a27
Formatting
aarushjain29 Nov 22, 2024
8788ba2
description for slice_add_mul
aarushjain29 Nov 22, 2024
7e28351
Removed 1 of the test which is not required
aarushjain29 Nov 25, 2024
9a31825
Changed some code by mistake
aarushjain29 Nov 25, 2024
2a10591
2 more test case added
aarushjain29 Nov 26, 2024
696b842
Formatting
aarushjain29 Nov 26, 2024
cec662f
test cases for clip model
aarushjain29 Nov 26, 2024
fa12fcc
add_dot_add_mul
aarushjain29 Nov 26, 2024
b05e0ba
formatting
aarushjain29 Nov 26, 2024
38b51e1
changed the name to slice_add
aarushjain29 Nov 26, 2024
7c30359
Formatting
aarushjain29 Nov 26, 2024
271bd2b
Formatting
aarushjain29 Nov 26, 2024
293dd51
change in mul_add struct - none_of added
aarushjain29 Nov 26, 2024
cef6880
changes in fusable_split
aarushjain29 Nov 27, 2024
f7f0996
Changed slice_mul_add to merge slice_mul
aarushijai Nov 27, 2024
892a26d
Revert "Changed slice_mul_add to merge slice_mul"
aarushijai Nov 27, 2024
e8e46d1
Changed slice_add_mul test case to fuse mul
aarushjain29 Nov 27, 2024
f478e04
test case for add_dot_add for each branch added
aarushjain29 Nov 27, 2024
e85f850
loop for literals
aarushjain29 Nov 27, 2024
ec4cb3e
removed older test case dot_add_mul
aarushjain29 Nov 27, 2024
318e36a
Added test for dot_add_mul with loop
aarushjain29 Nov 27, 2024
6be46f4
loop changes in add_dot_add_mul_2
aarushjain29 Nov 27, 2024
ad1bb34
seed values changed
aarushjain29 Nov 27, 2024
9804425
loop added for add_dot_add_mul_1
aarushjain29 Nov 27, 2024
7ebf297
shape s1 and s2
aarushjain29 Nov 27, 2024
b7f3037
CI Formatting
aarushjain29 Dec 2, 2024
256456a
CI Formatting
aarushjain29 Dec 2, 2024
9eceb9f
removing tidy errors
aarushjain29 Dec 2, 2024
6a50c51
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
9b00fd8
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
35a20e0
semicolon missing added
aarushjain29 Dec 2, 2024
58a4a0f
Removing tidy errors CI
aarushjain29 Dec 2, 2024
ce65be9
semicolon missing added
aarushjain29 Dec 2, 2024
3d070cb
Formatting
aarushjain29 Dec 2, 2024
75da7c3
Formatting
aarushjain29 Dec 2, 2024
28eb64d
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
c59d2fa
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
0eeb663
WIP change in add_dot_add_mul_2
aarushjain29 Dec 3, 2024
820bb12
loops added
aarushjain29 Dec 3, 2024
3950cd3
Formatting
aarushjain29 Dec 3, 2024
0d99d23
Formatting
aarushjain29 Dec 4, 2024
8bd694e
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,21 @@
}
};

auto fusable_split(const std::string& name)
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved
{
return match::make_basic_pred_matcher([&](instruction_ref ins) {
return all_of(ins->outputs(), [&](instruction_ref slice) {
if(slice->name() != "slice")
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should return true as we want to skip the outputs that are not slices.

return any_of(slice->outputs(), [&](instruction_ref x) {
if (x->name() == name)
return true;
});
});
});
}


// ******************************
// a * (x + b) => a * x + a * b
// ******************************
Expand All @@ -421,15 +436,22 @@
struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
{
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved
auto slice_1 = match::skip(match::name("slice"))(match::none_of(fusable_split("add")));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be match::none_of(match::name("slice")(match::arg(0)(fusable_split("add")))).


Check warning on line 441 in src/simplify_algebra.cpp

View workflow job for this annotation

GitHub Actions / tidy

non-void lambda does not return a value in all control paths [clang-diagnostic-return-type,-warnings-as-errors]
return match::name("mul")(
match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
slice_1.bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")
),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()
),
match::is_constant().bind("a")
)
);
}

void apply(module& m, const match::matcher_result& r) const
Expand All @@ -443,6 +465,46 @@
auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);

}
};

// When a slice is followed by a*(x+b), a⋅(x+b), this matcher performs
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved
// the add first, followed by the mul.
// If multiple slices originate from the same instruction and are followed by add,
// all the adds can be const folded and performed before the slicing.
struct find_slice_add_mul
{
auto matcher() const
{
auto slice_1 = match::name("slice")(match::arg(0)(fusable_split("add")));

return match::name("mul")(
match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
slice_1.bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")
),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()
),
match::is_constant().bind("a")
)
);
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];

assert(x_ins != b_ins);

auto ax_ins = m.insert_instruction(ins, make_op("add"), x_ins, b_ins);
m.replace_instruction(ins, make_op("mul"), ax_ins, a_ins);

}
};

Expand Down Expand Up @@ -1569,6 +1631,7 @@

auto outputs = ins->outputs();
group_by(outputs.begin(), outputs.end(), each, pred);

}
};

Expand Down Expand Up @@ -1981,6 +2044,7 @@
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_slice_add_mul{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
Expand Down
116 changes: 115 additions & 1 deletion test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@
EXPECT(m1 == m2);
}


TEST_CASE(simplify_conv_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 32, 32}};
Expand Down Expand Up @@ -4133,6 +4134,63 @@
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(skip_mul_add_for_horizontal_add_fusion)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};

migraphx::module m1;
{
auto a = m1.add_parameter("a", as);

auto slice_a = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a);

auto one = m1.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m1.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m1.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto add1 = m1.add_instruction(migraphx::make_op("add"), slice_a, one);

auto add2 = m1.add_instruction(migraphx::make_op("add"), slice_b, two);
auto add3 = m1.add_instruction(migraphx::make_op("add"), slice_c, three);

m1.add_return({add1});

};

run_pass(m1);

migraphx::module m2;
{
auto a = m2.add_parameter("a", as);

auto one = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), one, two, three);

auto add = m2.add_instruction(migraphx::make_op("add"), a, concat);

auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), add);

m2.add_return({slice_a});
};

EXPECT(m1.sort() == m2.sort());
}
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved

TEST_CASE(dot_slice_ab)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
Expand Down Expand Up @@ -4195,6 +4253,61 @@
EXPECT(m1.sort() == m2.sort());
}


TEST_CASE(complex_graph_operations)
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved
{
migraphx::module m;

auto x_0 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 0));
auto x_1 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 1));
auto x_2 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 2));
auto x_3 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 3));
auto x_4 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 4));
auto x_5 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 5));
auto x_6 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 6));
auto x_7 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 7));
auto x_8 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 8));
auto x_9 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 9));
auto x_10 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 10));
auto x_11 = m.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {64, 512}}, 11));
auto x_12 = m.add_literal(migraphx::literal({migraphx::shape::float_type, {1}}, {0.125}));


auto p_x = m.add_parameter("x", {migraphx::shape::float_type, {64, 64}});


auto x_14 = m.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {64, 512}}}), x_12);


auto x_15 = m.add_instruction(migraphx::make_op("dot"), p_x, x_11);
auto x_16 = m.add_instruction(migraphx::make_op("add"), x_15, x_10);

Check warning on line 4284 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

Value stored to 'add2' during its initialization is never read [clang-analyzer-deadcode.DeadStores,-warnings-as-errors]

Check warning on line 4284 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

unused variable 'add2' [clang-diagnostic-unused-variable,-warnings-as-errors]
auto x_17 = m.add_instruction(migraphx::make_op("add"), x_16, x_9);

Check warning on line 4285 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

Value stored to 'add3' during its initialization is never read [clang-analyzer-deadcode.DeadStores,-warnings-as-errors]

Check warning on line 4285 in test/simplify_algebra_test.cpp

View workflow job for this annotation

GitHub Actions / tidy

unused variable 'add3' [clang-diagnostic-unused-variable,-warnings-as-errors]
auto x_18 = m.add_instruction(migraphx::make_op("add"), x_17, x_8);

auto x_19 = m.add_instruction(migraphx::make_op("dot"), p_x, x_7);
auto x_20 = m.add_instruction(migraphx::make_op("add"), x_19, x_6);
auto x_21 = m.add_instruction(migraphx::make_op("add"), x_20, x_5);
auto x_22 = m.add_instruction(migraphx::make_op("add"), x_21, x_4);

auto x_23 = m.add_instruction(migraphx::make_op("dot"), p_x, x_3);
auto x_24 = m.add_instruction(migraphx::make_op("add"), x_23, x_2);
auto x_25 = m.add_instruction(migraphx::make_op("add"), x_24, x_1);
auto x_26 = m.add_instruction(migraphx::make_op("add"), x_25, x_0);

auto x_27 = m.add_instruction(migraphx::make_op("mul"), x_26, x_14);

m.add_return({x_18, x_22, x_27});

run_pass(m);

EXPECT(m.get_output_shapes().size() == 3);
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved


}



TEST_CASE(dot_slice_not_applicable_1)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
Expand All @@ -4209,7 +4322,8 @@
{{"axes", {1, 2}}, {"starts", {64, 64}}, {"ends", {128, 128}}}),
dot);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), dot);
migraphx::make_op("slice", {{"axes", {0}}, {"starts"
, {0}}, {"ends", {1}}}), dot);

m1.add_return({slice1, slice2});
};
Expand Down
Loading