From 69457bd20ecbc54c23fbcbfa4d3d50cf221ca844 Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Thu, 3 Oct 2024 12:11:11 -0500 Subject: [PATCH] Logic improved --- src/simplify_algebra.cpp | 93 ++++++++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index e6b6b46a73d..259dba110fc 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -422,14 +422,16 @@ struct find_mul_add { auto matcher() const { - return match::name("mul")(match::either_arg(0, 1)( - match::name("add")( - match::either_arg(0, 1)( - match::any().bind("x"), - match::any_of(conv_const_weights(), match::is_constant()).bind("b")), - match::none_of(match::args(match::is_constant(), match::is_constant())), - match::used_once()), - match::is_constant().bind("a"))); + return match::name("mul")( + match::none_of[match::outputs()](match::name("convolution")), + match::either_arg(0, 1)( + match::name("add")( + match::either_arg(0, 1)( + match::any().bind("x"), + match::any_of(conv_const_weights(), match::is_constant()).bind("b")), + match::none_of(match::args(match::is_constant(), match::is_constant())), + match::used_once()), + match::is_constant().bind("a"))); } void apply(module& m, const match::matcher_result& r) const @@ -440,15 +442,42 @@ struct find_mul_add auto x_ins = r.instructions["x"]; assert(x_ins != b_ins); - if(a_ins->get_shape().scalar()){ - return; - } + auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins); + auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins); + m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins); + } +}; +/* Delete this later +struct find_mul_add +{ + auto matcher() const + { + return match::name("mul")( + match::none_of[match::outputs()](match::name("convolution")), + match::either_arg(0, 1)( + match::name("add")( + match::either_arg(0, 1)( + match::any().bind("x"), + match::any_of(conv_const_weights(), match::is_constant()).bind("b")), + match::none_of(match::args(match::is_constant(), match::is_constant())), + match::used_once()), + match::is_constant().bind("a"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto a_ins = r.instructions["a"]; + auto b_ins = r.instructions["b"]; + auto x_ins = r.instructions["x"]; + assert(x_ins != b_ins); auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins); auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins); m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins); } }; +*/ struct find_scalar_mul_conv { @@ -456,39 +485,31 @@ struct find_scalar_mul_conv { return match::name("mul")( match::either_arg(0, 1)( - conv_const_weights().bind("conv"), - match::either_arg(0, 1)( - match::name("broadcast", "multibroadcast", "constant").bind("scalar"), - match::any().bind("scalar") - ) - ) - ); - } - void apply(module& m, const match::matcher_result& r) const - { - auto ins = r.result; - auto conv_ins = r.instructions["conv"]; - auto scalar_ins = r.instructions["scalar"]; - auto w_ins = r.instructions["w"]; - - if(scalar_ins->get_shape().elements() != 1) - return; - const auto& w_shape = w_ins->get_shape().lens(); + match::is_constant().bind("scalar"), + match::name("convolution").bind("conv") + )); + } - if(scalar_ins->get_shape().ndim() != w_shape.size()) - { - scalar_ins = m.insert_instruction(ins, make_op("broadcast", {{"axis", 0}, {"out_lens", w_shape}}), scalar_ins); - } + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto scalar = r.instructions["scalar"]; + auto conv_ins = r.instructions["conv"]; - auto new_weights = m.insert_instruction(ins, make_op("mul"), scalar_ins, w_ins); + // Get the convol's input and weights + auto conv_input = conv_ins->inputs().front(); // input to conv + auto conv_weights = conv_ins->inputs().back(); // weights of the conv - auto new_conv = m.insert_instruction( - ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_weights); + auto scaled_weights = m.insert_instruction(ins, make_op("mul"), scalar, conv_weights); + + // new conv with modified weights + auto new_conv = m.insert_instruction(ins, conv_ins->get_operator(), conv_input, scaled_weights); m.replace_instruction(ins, new_conv); } }; + struct find_dot_add { auto matcher() const