From 844bbfe632de3d630c632e7970b3b02c94034c5c Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 12 Oct 2024 15:25:18 -0700 Subject: [PATCH 01/37] Add methods to shape transform descriptor to get the common shape --- .../migraphx/shape_transform_descriptor.hpp | 4 + src/shape_transform_descriptor.cpp | 137 +++++++++++++++++- 2 files changed, 136 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index a8851de753f..cb9b46d8d16 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -82,6 +82,10 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor void simplify(); std::size_t elements() const; std::vector generate() const; + std::vector generate_common_from_src(const std::vector& input_dims) const; + std::vector generate_common_from_dst(const std::vector& input_dims) const; + std::vector> common_axes_map_from_src() const; + std::vector> common_axes_map_from_dst() const; struct MIGRAPHX_EXPORT dimension { diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index e336674d660..016e40bbea9 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -535,6 +535,15 @@ void shape_transform_descriptor::simplify() collapse_1_dims(dimensions); } +static std::size_t get_len(const dimension::sub& s, const std::vector& input_dims) +{ + if(not s.axis.empty() and not input_dims.empty() and input_dims.at(s.axis.front()) == 1) + { + return 1; + } + return s.len; +} + static operation make_reshape_squeeze(const std::vector& new_dims) { // Can use squeeze @@ -605,7 +614,7 @@ static void flatten_broadcasted_dim(dimension::sub& s) } } -static operation make_reshape_unsqueeze(const std::vector& subs) +static operation make_reshape_unsqueeze(const std::vector& subs, const std::vector& input_dims = {}) { bool use_reshape = false; // Check if split dimensions are all additional 1s @@ -630,7 +639,7 @@ static operation make_reshape_unsqueeze(const std::vector& subs) return; // Number of elements that are 1 auto n1 = - std::count_if(start, last, [](const dimension::sub& s) { return s.len == 1; }); + std::count_if(start, last, [&](const dimension::sub& s) { return get_len(s, input_dims) == 1; }); use_reshape |= std::max(0, n - n1 - 1) > 0; }, by_axis); @@ -641,10 +650,10 @@ static operation make_reshape_unsqueeze(const std::vector& subs) std::transform(subs.begin(), subs.end(), std::back_inserter(dims), - [](const dimension::sub& s) -> std::size_t { + [&](const dimension::sub& s) -> std::size_t { if(s.axis.empty()) return 1; - return s.len; + return get_len(s, input_dims); }); return make_op("reshape", {{"dims", dims}}); } @@ -656,7 +665,7 @@ static operation make_reshape_unsqueeze(const std::vector& subs) const auto& sub = subs[i]; if(sub.axis.size() == 1) continue; - if(sub.len != 1 and not sub.axis.empty()) + if(get_len(sub, input_dims) != 1 and not sub.axis.empty()) continue; axes.push_back(i); } @@ -792,6 +801,124 @@ std::vector shape_transform_descriptor::generate() const return result; } +std::vector shape_transform_descriptor::generate_common_from_src(const std::vector& input_dims) const +{ + std::vector result; + auto subs = get_all_subdimensions(dimensions); + // Need multibroadcast + if(std::any_of(subs.begin(), subs.end(), [&](const dimension::sub& s) { + return s.axis.empty() and get_len(s, input_dims) != 1; + })) + { + std::vector out_lens; + std::transform(subs.begin(), + subs.end(), + std::back_inserter(out_lens), + [&](const dimension::sub& s) { return get_len(s, input_dims); }); + result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); + } + + // Flatten broadcasted subdimensions + std::for_each(subs.begin(), subs.end(), &flatten_broadcasted_dim); + + auto tsubs = subs; + // Inject additonal axis to compute transpose permutation better + auto is_empty_axis = [](const auto& s) { return s.axis.empty(); }; + group_find(tsubs.begin(), tsubs.end(), is_empty_axis, [&](auto start, auto last) { + if(start == tsubs.begin()) + return; + auto base = std::prev(start); + auto axis = base->axis; + axis.push_back(0); + std::for_each(start, last, [&](auto& s) { + s.axis = axis; + axis.back()++; + }); + }); + + auto compare_sub = [](auto f) { + return by(f, [](const dimension::sub& s) -> const auto& { return s.axis; }); + }; + // Need transpose + if(not std::is_sorted(tsubs.begin(), tsubs.end(), compare_sub(std::less<>{}))) + { + auto permutation = sort_permutation(tsubs, compare_sub(std::less<>{})); + result.push_back(make_op("transpose", {{"permutation", invert_permutation(permutation)}})); + subs = reorder_dims(subs, permutation); + } + // Need reshape unsqueeze + if(std::any_of( + subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) + { + result.push_back(make_reshape_unsqueeze(subs, input_dims)); + } + std::reverse(result.begin(), result.end()); + return result; +} +std::vector shape_transform_descriptor::generate_common_from_dst(const std::vector& input_dims) const +{ + std::vector result; + auto subs = get_all_subdimensions(dimensions); + // Need reshape unsqueeze + if(std::any_of( + subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) + { + result.push_back(make_reshape_unsqueeze(subs, input_dims)); + } + std::reverse(result.begin(), result.end()); + return result; +} + +std::vector> shape_transform_descriptor::common_axes_map_from_src() const +{ + std::vector> result; + auto subs = get_all_subdimensions(dimensions); + std::map> axes_map; + for(const auto& s:subs) + { + std::size_t axis = -1; + if(s.axis.empty()) + { + if(s.hidden_axis.has_value()) + axis = s.hidden_axis.value(); + else + continue; + } + else + { + axis = s.axis.front(); + } + axes_map[axis].push_back(&s); + } + for(auto&& p:axes_map) + { + std::sort(p.second.begin(), p.second.end(), by(std::less<>{}, [](const dimension::sub* s) { + return s->axis; + })); + } + auto max_axis = std::prev(axes_map.end())->first; + result.resize(max_axis); + for(auto&& p:axes_map) + { + std::transform(p.second.begin(), p.second.end(), std::back_inserter(result[p.first]), [&](const dimension::sub* s) { + return s - subs.data(); + }); + } + return result; +} +std::vector> shape_transform_descriptor::common_axes_map_from_dst() const +{ + std::vector> result; + std::size_t start = 0; + for(const auto& d:dimensions) + { + auto& v = result.emplace_back(d.subdimensions.size()); + std::iota(v.begin(), v.end(), start); + start += d.subdimensions.size(); + } + return result; +} + std::size_t dimension::len() const { return transform_accumulate(subdimensions.begin(), From 21c5fd0ecd89c82ed2329a937ebcd001f327ef87 Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 12 Oct 2024 15:25:28 -0700 Subject: [PATCH 02/37] Format --- .../migraphx/shape_transform_descriptor.hpp | 6 ++-- src/shape_transform_descriptor.cpp | 33 +++++++++++-------- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index cb9b46d8d16..2452e8953d8 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -82,8 +82,10 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor void simplify(); std::size_t elements() const; std::vector generate() const; - std::vector generate_common_from_src(const std::vector& input_dims) const; - std::vector generate_common_from_dst(const std::vector& input_dims) const; + std::vector + generate_common_from_src(const std::vector& input_dims) const; + std::vector + generate_common_from_dst(const std::vector& input_dims) const; std::vector> common_axes_map_from_src() const; std::vector> common_axes_map_from_dst() const; diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 016e40bbea9..43c44097c3c 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -614,7 +614,8 @@ static void flatten_broadcasted_dim(dimension::sub& s) } } -static operation make_reshape_unsqueeze(const std::vector& subs, const std::vector& input_dims = {}) +static operation make_reshape_unsqueeze(const std::vector& subs, + const std::vector& input_dims = {}) { bool use_reshape = false; // Check if split dimensions are all additional 1s @@ -638,8 +639,9 @@ static operation make_reshape_unsqueeze(const std::vector& subs, if(n < 2) return; // Number of elements that are 1 - auto n1 = - std::count_if(start, last, [&](const dimension::sub& s) { return get_len(s, input_dims) == 1; }); + auto n1 = std::count_if(start, last, [&](const dimension::sub& s) { + return get_len(s, input_dims) == 1; + }); use_reshape |= std::max(0, n - n1 - 1) > 0; }, by_axis); @@ -801,7 +803,8 @@ std::vector shape_transform_descriptor::generate() const return result; } -std::vector shape_transform_descriptor::generate_common_from_src(const std::vector& input_dims) const +std::vector shape_transform_descriptor::generate_common_from_src( + const std::vector& input_dims) const { std::vector result; auto subs = get_all_subdimensions(dimensions); @@ -855,7 +858,8 @@ std::vector shape_transform_descriptor::generate_common_from_src(cons std::reverse(result.begin(), result.end()); return result; } -std::vector shape_transform_descriptor::generate_common_from_dst(const std::vector& input_dims) const +std::vector shape_transform_descriptor::generate_common_from_dst( + const std::vector& input_dims) const { std::vector result; auto subs = get_all_subdimensions(dimensions); @@ -874,7 +878,7 @@ std::vector> shape_transform_descriptor::common_axes_ma std::vector> result; auto subs = get_all_subdimensions(dimensions); std::map> axes_map; - for(const auto& s:subs) + for(const auto& s : subs) { std::size_t axis = -1; if(s.axis.empty()) @@ -890,19 +894,20 @@ std::vector> shape_transform_descriptor::common_axes_ma } axes_map[axis].push_back(&s); } - for(auto&& p:axes_map) + for(auto&& p : axes_map) { std::sort(p.second.begin(), p.second.end(), by(std::less<>{}, [](const dimension::sub* s) { - return s->axis; - })); + return s->axis; + })); } auto max_axis = std::prev(axes_map.end())->first; result.resize(max_axis); - for(auto&& p:axes_map) + for(auto&& p : axes_map) { - std::transform(p.second.begin(), p.second.end(), std::back_inserter(result[p.first]), [&](const dimension::sub* s) { - return s - subs.data(); - }); + std::transform(p.second.begin(), + p.second.end(), + std::back_inserter(result[p.first]), + [&](const dimension::sub* s) { return s - subs.data(); }); } return result; } @@ -910,7 +915,7 @@ std::vector> shape_transform_descriptor::common_axes_ma { std::vector> result; std::size_t start = 0; - for(const auto& d:dimensions) + for(const auto& d : dimensions) { auto& v = result.emplace_back(d.subdimensions.size()); std::iota(v.begin(), v.end(), start); From 6cf67aff6bd6c17994da04d6f8ecbe46be869d89 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 13 Oct 2024 15:52:05 -0700 Subject: [PATCH 03/37] Update rewrite_reshapes --- src/include/migraphx/rewrite_reshapes.hpp | 83 +++++++------------ .../migraphx/shape_transform_descriptor.hpp | 6 ++ src/shape_transform_descriptor.cpp | 76 ++++++++++++++++- 3 files changed, 110 insertions(+), 55 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index e44413a09a1..35475b2e01a 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -33,6 +33,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -72,18 +73,11 @@ struct rewrite_reshapes auto matcher() const { - auto reshape = - match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once()); - auto skip_contiguous_broadcast = - match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); - auto skip_contiguous_broadcast_arg = [&](auto... ms) { - return match::arg(0)(skip_contiguous_broadcast(ms...)); - }; + auto reshapes = + match::name("reshape", "squeeze", "unsqueeze", "flatten", "transpose")(match::used_once()); auto pointwise = match::name(op1)(match::used_once()); - auto reshape_pointwise = - reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape"); - return match::name(op2)(match::any_of[match::inputs()]( - skip_contiguous_broadcast(reshape_pointwise).bind("input"))); + auto reshapes_pointwise = reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); + return match::name(op2)(match::any_of[match::inputs()](reshapes_pointwise.bind("input"))); } template @@ -124,61 +118,46 @@ struct rewrite_reshapes { auto ins = r.result; auto x_ins = r.instructions["x"]; - auto reshape_ins = r.instructions["reshape"]; auto input_ins = r.instructions["input"]; - const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins); - const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins); - if(not has_broadcast_before_reshape.has_value()) - return; - if(not has_broadcast_after_reshape.has_value()) - return; - if(*has_broadcast_after_reshape and *has_broadcast_before_reshape) - return; - const bool has_broadcast = - *has_broadcast_after_reshape or *has_broadcast_before_reshape; - - auto dims1 = T::base_dims(ins); - auto dims2 = T::base_dims(x_ins); - - if(elements(dims1) != elements(dims2)) - return; - - auto cd = common_dims::compute(T::base_dims(ins), T::base_dims(x_ins)); - if(cd.dims.empty()) - return; + std::vector ops; + auto reshape_ins = input_ins; + while(reshape_ins != x_ins) + { + ops.push_back(reshape_ins->get_operator()); + reshape_ins = reshape_ins->inputs().front(); + } + assert(reshape_ins == x_ins); + std::reverse(ops.begin(), ops.end()); - if(ins->name() != "pointwise" and not T::supports(ins, cd.dims, cd.axes_map1)) - return; - if(x_ins->name() != "pointwise" and not T::supports(x_ins, cd.dims, cd.axes_map2)) + auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops); + if(desc.empty()) return; - - auto reshape_input = [&](const auto& ins_to_insert) { - return [&](auto input) { - auto dims = cd.get_dimensions_for(input->get_shape().lens()); - return mpm.get_module().insert_instruction( - ins_to_insert, make_op("reshape", {{"dims", dims}}), input); + auto reshape_input = [&](const auto& ins_to_insert, auto generate) { + return [&, generate](auto input) { + auto gops = std::invoke(generate, desc, input->get_shape().lens()); + auto start = input; + for(auto op:gops) + { + start = mpm.get_module().insert_instruction(ins_to_insert, op, start); + } + return start; }; }; auto x_inputs = x_ins->inputs(); std::transform( - x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins)); - auto new_x_ins = insert(mpm, x_ins, x_inputs, cd.axes_map2); - if(has_broadcast) - { - new_x_ins = mpm.get_module().insert_instruction( - x_ins, make_op("multibroadcast", {{"out_lens", cd.dims}}), new_x_ins); - } + x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src)); + auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src()); auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { if(input == input_ins) return new_x_ins; - return reshape_input(ins)(input); + return reshape_input(ins, &shape_transform_descriptor::generate_common_from_dst)(input); }); - auto pw = insert(mpm, ins, inputs, cd.axes_map1); - mpm.get_module().replace_instruction( - ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw); + auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst()); + auto rins = reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw); + mpm.get_module().replace_instruction(ins, rins); } static bool same_dims(instruction_ref ins) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 2452e8953d8..2f9e5a01d97 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -74,6 +74,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor shape_transform_descriptor() = default; explicit shape_transform_descriptor(const std::vector& dims); + static shape_transform_descriptor create(const std::vector& dims, const std::vector& ops); + bool apply(const std::vector& ops); bool apply_reshape(const std::vector& rdims); bool apply_transpose(const std::vector& permutation); @@ -86,9 +88,13 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor generate_common_from_src(const std::vector& input_dims) const; std::vector generate_common_from_dst(const std::vector& input_dims) const; + std::vector + generate_dst_from_common(const std::vector& input_dims) const; std::vector> common_axes_map_from_src() const; std::vector> common_axes_map_from_dst() const; + bool empty() const; + struct MIGRAPHX_EXPORT dimension { void simplify(); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 43c44097c3c..1d41332996e 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -99,6 +99,16 @@ std::vector compute_dims(const std::vector& ops, return s.lens(); } +shape_transform_descriptor shape_transform_descriptor::create(const std::vector& dims, const std::vector& ops) +{ + shape_transform_descriptor result{dims}; + if(not result.apply(ops)) + return {}; + result.simplify(); + assert(compute_dims(ops, dims) == compute_dims(result.generate(), dims)); + return result; +} + bool shape_transform_descriptor::apply(const std::vector& ops) { std::vector dims; @@ -538,9 +548,9 @@ void shape_transform_descriptor::simplify() static std::size_t get_len(const dimension::sub& s, const std::vector& input_dims) { if(not s.axis.empty() and not input_dims.empty() and input_dims.at(s.axis.front()) == 1) - { return 1; - } + if(s.axis.size() == 1) + return input_dims.at(s.axis.front()); return s.len; } @@ -872,6 +882,59 @@ std::vector shape_transform_descriptor::generate_common_from_dst( std::reverse(result.begin(), result.end()); return result; } +std::vector shape_transform_descriptor::generate_dst_from_common( + const std::vector& input_dims) const +{ + std::vector result; + std::vector new_dims = dimensions; + // Need broadcast + if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) + { + std::vector out_lens; + std::transform(new_dims.begin(), + new_dims.end(), + std::back_inserter(out_lens), + [](const dimension& d) { return d.len(); }); + auto startb = std::find_if_not(new_dims.begin(), new_dims.end(), &has_no_axes); + auto trailb = std::find_if_not(startb, new_dims.end(), &has_axes); + auto axis = std::distance(new_dims.begin(), startb); + auto extra_dims = axis + std::distance(trailb, new_dims.end()); + // Use broadcast instead of multibroadcast + if(std::all_of(trailb, new_dims.end(), &has_no_axes) and extra_dims > 0 and + axis < new_dims.size()) + { + result.push_back(make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}})); + new_dims.erase(trailb, new_dims.end()); + new_dims.erase(new_dims.begin(), new_dims.begin() + axis); + } + else + { + result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); + } + } + // If all the dimensions have no axes then there isnt anthing else to do + // so just clear the new_dims + if(std::all_of(new_dims.begin(), new_dims.end(), &has_no_axes)) + new_dims.clear(); + // Flatten broadcasted dimensions + for(auto& d : new_dims) + { + if(d.subdimensions.size() != 1) + continue; + flatten_broadcasted_dim(d.subdimensions.front()); + } + // Need squeeze reshape + if(std::any_of(new_dims.begin(), new_dims.end(), [](const dimension& d) { + if(d.subdimensions.size() != 1) + return true; + return is_broadcast_dim(d); + })) + { + result.push_back(make_reshape_squeeze(new_dims)); + } + std::reverse(result.begin(), result.end()); + return result; +} std::vector> shape_transform_descriptor::common_axes_map_from_src() const { @@ -900,10 +963,12 @@ std::vector> shape_transform_descriptor::common_axes_ma return s->axis; })); } + assert(not axes_map.empty()); auto max_axis = std::prev(axes_map.end())->first; - result.resize(max_axis); + result.resize(max_axis + 1); for(auto&& p : axes_map) { + assert(p.first < result.size()); std::transform(p.second.begin(), p.second.end(), std::back_inserter(result[p.first]), @@ -924,6 +989,11 @@ std::vector> shape_transform_descriptor::common_axes_ma return result; } +bool shape_transform_descriptor::empty() const +{ + return dimensions.empty(); +} + std::size_t dimension::len() const { return transform_accumulate(subdimensions.begin(), From d556255e57d884c61b864e40b90ef85ccd13679f Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 13 Oct 2024 15:52:08 -0700 Subject: [PATCH 04/37] Format --- src/include/migraphx/rewrite_reshapes.hpp | 25 ++++++++++++------- .../migraphx/shape_transform_descriptor.hpp | 3 ++- src/shape_transform_descriptor.cpp | 8 +++--- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 35475b2e01a..c90a4570c21 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -73,11 +73,13 @@ struct rewrite_reshapes auto matcher() const { - auto reshapes = - match::name("reshape", "squeeze", "unsqueeze", "flatten", "transpose")(match::used_once()); + auto reshapes = match::name("reshape", "squeeze", "unsqueeze", "flatten", "transpose")( + match::used_once()); auto pointwise = match::name(op1)(match::used_once()); - auto reshapes_pointwise = reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); - return match::name(op2)(match::any_of[match::inputs()](reshapes_pointwise.bind("input"))); + auto reshapes_pointwise = + reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); + return match::name(op2)( + match::any_of[match::inputs()](reshapes_pointwise.bind("input"))); } template @@ -135,9 +137,9 @@ struct rewrite_reshapes return; auto reshape_input = [&](const auto& ins_to_insert, auto generate) { return [&, generate](auto input) { - auto gops = std::invoke(generate, desc, input->get_shape().lens()); + auto gops = std::invoke(generate, desc, input->get_shape().lens()); auto start = input; - for(auto op:gops) + for(auto op : gops) { start = mpm.get_module().insert_instruction(ins_to_insert, op, start); } @@ -146,17 +148,22 @@ struct rewrite_reshapes }; auto x_inputs = x_ins->inputs(); std::transform( - x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src)); + x_inputs.begin(), + x_inputs.end(), + x_inputs.begin(), + reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src)); auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src()); auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { if(input == input_ins) return new_x_ins; - return reshape_input(ins, &shape_transform_descriptor::generate_common_from_dst)(input); + return reshape_input(ins, + &shape_transform_descriptor::generate_common_from_dst)(input); }); auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst()); - auto rins = reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw); + auto rins = + reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw); mpm.get_module().replace_instruction(ins, rins); } diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 2f9e5a01d97..a6b1b7b6d19 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -74,7 +74,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor shape_transform_descriptor() = default; explicit shape_transform_descriptor(const std::vector& dims); - static shape_transform_descriptor create(const std::vector& dims, const std::vector& ops); + static shape_transform_descriptor create(const std::vector& dims, + const std::vector& ops); bool apply(const std::vector& ops); bool apply_reshape(const std::vector& rdims); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 1d41332996e..5c3b962e44d 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -99,7 +99,8 @@ std::vector compute_dims(const std::vector& ops, return s.lens(); } -shape_transform_descriptor shape_transform_descriptor::create(const std::vector& dims, const std::vector& ops) +shape_transform_descriptor shape_transform_descriptor::create(const std::vector& dims, + const std::vector& ops) { shape_transform_descriptor result{dims}; if(not result.apply(ops)) @@ -989,10 +990,7 @@ std::vector> shape_transform_descriptor::common_axes_ma return result; } -bool shape_transform_descriptor::empty() const -{ - return dimensions.empty(); -} +bool shape_transform_descriptor::empty() const { return dimensions.empty(); } std::size_t dimension::len() const { From 083f22b062aec2a93e7aae669369df560c2efa04 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 22 Oct 2024 13:44:42 -0700 Subject: [PATCH 05/37] Handle empty dims --- src/shape_transform_descriptor.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 910c73c486f..fb149384c26 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -562,7 +562,9 @@ void shape_transform_descriptor::simplify() static std::size_t get_len(const dimension::sub& s, const std::vector& input_dims) { - if(not s.axis.empty() and not input_dims.empty() and input_dims.at(s.axis.front()) == 1) + if(input_dims.empty()) + return s.len; + if(not s.axis.empty() and input_dims.at(s.axis.front()) == 1) return 1; if(s.axis.size() == 1) return input_dims.at(s.axis.front()); From 5059a6c27d9eec20c86f59b395db40bae030c0e6 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 22 Oct 2024 14:07:35 -0700 Subject: [PATCH 06/37] Fix dims from dst --- src/shape_transform_descriptor.cpp | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index fb149384c26..c8ffe3e5668 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -564,10 +564,15 @@ static std::size_t get_len(const dimension::sub& s, const std::vector shape_transform_descriptor::generate_common_from_src( std::vector shape_transform_descriptor::generate_common_from_dst( const std::vector& input_dims) const { - std::vector result; auto subs = get_all_subdimensions(dimensions); // Need reshape unsqueeze if(std::any_of( subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) { - result.push_back(make_reshape_unsqueeze(subs, input_dims)); + // Map the input dims back to the src input if possible + std::vector src_input_dims(this->rank); + for(auto i:range(dimensions.size())) + { + const auto& d = dimensions[i]; + if(d.subdimensions.size() != 1) + continue; + const auto& sub = d.subdimensions.front(); + if(sub.axis.size() != 1) + continue; + auto axis = sub.axis.front(); + src_input_dims[axis] = input_dims[i]; + } + return {make_reshape_unsqueeze(subs, src_input_dims)}; } - std::reverse(result.begin(), result.end()); - return result; + return {}; } std::vector shape_transform_descriptor::generate_dst_from_common( const std::vector& input_dims) const From b209a004608573fefef88b87e865d57308eceec9 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 22 Oct 2024 14:07:41 -0700 Subject: [PATCH 07/37] Format --- src/shape_transform_descriptor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index c8ffe3e5668..9143a09bbb2 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -900,7 +900,7 @@ std::vector shape_transform_descriptor::generate_common_from_dst( { // Map the input dims back to the src input if possible std::vector src_input_dims(this->rank); - for(auto i:range(dimensions.size())) + for(auto i : range(dimensions.size())) { const auto& d = dimensions[i]; if(d.subdimensions.size() != 1) @@ -908,7 +908,7 @@ std::vector shape_transform_descriptor::generate_common_from_dst( const auto& sub = d.subdimensions.front(); if(sub.axis.size() != 1) continue; - auto axis = sub.axis.front(); + auto axis = sub.axis.front(); src_input_dims[axis] = input_dims[i]; } return {make_reshape_unsqueeze(subs, src_input_dims)}; From 1e7e48140b248eb898744ec92f69fcb4db5ec8f6 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 22 Oct 2024 14:13:14 -0700 Subject: [PATCH 08/37] Add contiguous --- src/include/migraphx/rewrite_reshapes.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index c90a4570c21..922c20175db 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -73,7 +73,7 @@ struct rewrite_reshapes auto matcher() const { - auto reshapes = match::name("reshape", "squeeze", "unsqueeze", "flatten", "transpose")( + auto reshapes = match::name("reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous")( match::used_once()); auto pointwise = match::name(op1)(match::used_once()); auto reshapes_pointwise = From 985f34219e89cf4068eac89dafdd843d83bb501c Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 22 Oct 2024 14:14:58 -0700 Subject: [PATCH 09/37] Format --- src/include/migraphx/rewrite_reshapes.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 922c20175db..ebf9c628fe1 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -73,7 +73,8 @@ struct rewrite_reshapes auto matcher() const { - auto reshapes = match::name("reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous")( + auto reshapes = match::name( + "reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous")( match::used_once()); auto pointwise = match::name(op1)(match::used_once()); auto reshapes_pointwise = From 93dec17e0df9cb3977e30f4625c1a2e365c72151 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 23 Oct 2024 09:50:12 -0700 Subject: [PATCH 10/37] Fix fuse_pointwise test --- src/shape_transform_descriptor.cpp | 35 +++++++++++++++--------------- test/fuse_pointwise.cpp | 4 ++-- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 9143a09bbb2..7d2cc7a0c12 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -893,27 +893,26 @@ std::vector shape_transform_descriptor::generate_common_from_src( std::vector shape_transform_descriptor::generate_common_from_dst( const std::vector& input_dims) const { + // Need reshape + if(std::all_of(dimensions.begin(), dimensions.end(), [](const dimension& d) { + return d.subdimensions.size() == 1; + })) + return {}; auto subs = get_all_subdimensions(dimensions); - // Need reshape unsqueeze - if(std::any_of( - subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) + // Map the input dims back to the src input if possible + std::vector src_input_dims(this->rank); + for(auto i : range(dimensions.size())) { - // Map the input dims back to the src input if possible - std::vector src_input_dims(this->rank); - for(auto i : range(dimensions.size())) - { - const auto& d = dimensions[i]; - if(d.subdimensions.size() != 1) - continue; - const auto& sub = d.subdimensions.front(); - if(sub.axis.size() != 1) - continue; - auto axis = sub.axis.front(); - src_input_dims[axis] = input_dims[i]; - } - return {make_reshape_unsqueeze(subs, src_input_dims)}; + const auto& d = dimensions[i]; + if(d.subdimensions.size() != 1) + continue; + const auto& sub = d.subdimensions.front(); + if(sub.axis.size() != 1) + continue; + auto axis = sub.axis.front(); + src_input_dims[axis] = input_dims[i]; } - return {}; + return {make_reshape_unsqueeze(subs, src_input_dims)}; } std::vector shape_transform_descriptor::generate_dst_from_common( const std::vector& input_dims) const diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 93881036efa..3410bfd7235 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -531,8 +531,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); - auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y); + auto x2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), y); auto fadd = add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); From bf65516b4a398e7a585c1b989ff07b5f18ee0c4c Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 23 Oct 2024 09:50:20 -0700 Subject: [PATCH 11/37] Foramt --- src/shape_transform_descriptor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 7d2cc7a0c12..aca17cbe51b 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -895,8 +895,8 @@ std::vector shape_transform_descriptor::generate_common_from_dst( { // Need reshape if(std::all_of(dimensions.begin(), dimensions.end(), [](const dimension& d) { - return d.subdimensions.size() == 1; - })) + return d.subdimensions.size() == 1; + })) return {}; auto subs = get_all_subdimensions(dimensions); // Map the input dims back to the src input if possible From e260aca92f15ab6da9e8d205b8a0f8dd5a2c0704 Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 26 Oct 2024 16:05:38 -0700 Subject: [PATCH 12/37] Add transpose test --- test/fuse_pointwise.cpp | 46 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 3410bfd7235..cddf2cd9770 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -427,6 +427,48 @@ TEST_CASE(add_reshape_add) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(add_transpose_reshape_add) +{ + migraphx::shape s1{migraphx::shape::float_type, {3, 16, 10}}; + migraphx::shape s2{migraphx::shape::float_type, {3, 40, 2, 2}}; + migraphx::shape s3{migraphx::shape::float_type, {3, 10, 4, 2, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); + auto transpose = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), add1); + auto reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), transpose); + auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); + mm->add_return({add2}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto z = mm->add_parameter("z", s2); + auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), x); + auto x3 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), x2); + auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), y); + auto y3 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), y2); + auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); + auto fadd = + add_pointwise(p2, "main:pointwise0", {x3, y3, z2}, [=](auto* pm, const auto& inputs) { + auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); + return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]); + }); + auto reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), fadd); + mm->add_return({reshape}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(add_contiguous_reshape_add) { auto s1 = @@ -531,8 +573,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto x2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x); - auto y2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), y); + auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y); auto fadd = add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); From 1588f26cdc31efc236e8d38d578cfa112536ebb4 Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 26 Oct 2024 16:05:45 -0700 Subject: [PATCH 13/37] Format --- test/fuse_pointwise.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index cddf2cd9770..5a330c8b4f5 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -439,7 +439,8 @@ TEST_CASE(add_transpose_reshape_add) auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y); - auto transpose = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), add1); + auto transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), add1); auto reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), transpose); auto add2 = mm->add_instruction(migraphx::make_op("add"), reshape, z); @@ -452,11 +453,15 @@ TEST_CASE(add_transpose_reshape_add) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), x); - auto x3 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), x2); - auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), y); - auto y3 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), y2); - auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); + auto x2 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), x); + auto x3 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), x2); + auto y2 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, 4, 2, 2, 10}}}), y); + auto y3 = mm->add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 4, 1, 2, 3}}}), y2); + auto z2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s3.lens()}}), z); auto fadd = add_pointwise(p2, "main:pointwise0", {x3, y3, z2}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); From 78c542d1bd324f0f3fef1b1237209a0308f30fd6 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 08:57:23 -0700 Subject: [PATCH 14/37] Handle broadcasts --- src/include/migraphx/rewrite_reshapes.hpp | 47 ++++++++++++++++--- .../migraphx/shape_transform_descriptor.hpp | 1 + src/shape_transform_descriptor.cpp | 11 +++++ test/fuse_pointwise.cpp | 4 +- 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index ebf9c628fe1..6ec164d41ee 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -76,11 +76,14 @@ struct rewrite_reshapes auto reshapes = match::name( "reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous")( match::used_once()); + auto skip_broadcasts = + match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); auto pointwise = match::name(op1)(match::used_once()); auto reshapes_pointwise = - reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); + reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))).bind("reshape"); + auto broadcast_reshapes_pointwise = skip_broadcasts(reshapes_pointwise); return match::name(op2)( - match::any_of[match::inputs()](reshapes_pointwise.bind("input"))); + match::any_of[match::inputs()](broadcast_reshapes_pointwise.bind("input"))); } template @@ -97,6 +100,13 @@ struct rewrite_reshapes return last; } + template + static bool any_input_of(instruction_ref start, instruction_ref last, F f) + { + return find_input_if(start, last, f) != last; + } + + static bool match_input(instruction_ref ins, instruction_ref x_ins) { if(ins->inputs().empty()) @@ -117,20 +127,38 @@ struct rewrite_reshapes return result; } + static bool is_broadcast(instruction_ref ins) + { + return ins->name() == "multibroadcast"; + } + void apply(module_pass_manager& mpm, const match::matcher_result& r) const { auto ins = r.result; auto x_ins = r.instructions["x"]; auto input_ins = r.instructions["input"]; + auto reshape_ins = r.instructions["reshape"]; + + auto broadcast_ins = + find_input_if(input_ins, reshape_ins, &is_broadcast); + if(broadcast_ins != reshape_ins and any_input_of(broadcast_ins, reshape_ins, &is_broadcast)) + return; + const auto has_broadcast = broadcast_ins != reshape_ins and broadcast_ins->name() == "multibroadcast"; + + auto dims1 = T::base_dims(ins); + auto dims2 = T::base_dims(x_ins); + + if(elements(dims1) != elements(dims2)) + return; std::vector ops; - auto reshape_ins = input_ins; - while(reshape_ins != x_ins) + auto next_ins = input_ins; + while(next_ins != x_ins) { - ops.push_back(reshape_ins->get_operator()); - reshape_ins = reshape_ins->inputs().front(); + ops.push_back(next_ins->get_operator()); + next_ins = next_ins->inputs().front(); } - assert(reshape_ins == x_ins); + assert(next_ins == x_ins); std::reverse(ops.begin(), ops.end()); auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops); @@ -154,6 +182,11 @@ struct rewrite_reshapes x_inputs.begin(), reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src)); auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src()); + if(has_broadcast) + { + new_x_ins = mpm.get_module().insert_instruction( + x_ins, make_op("multibroadcast", {{"out_lens", desc.common_dims(dims1)}}), new_x_ins); + } auto inputs = ins->inputs(); std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) { diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index c369a232ad1..4de2b07c666 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -85,6 +85,7 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor optional axis = nullopt); void simplify(); std::size_t elements() const; + std::vector common_dims(const std::vector& input_dims = {}) const; std::vector generate() const; std::vector generate_common_from_src(const std::vector& input_dims) const; diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index aca17cbe51b..ce0c6968884 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -1040,6 +1040,17 @@ std::size_t shape_transform_descriptor::elements() const std::multiplies<>{}, [](const auto& s) { return s.len(); }); } +std::vector shape_transform_descriptor::common_dims(const std::vector& input_dims) const +{ + std::vector result; + for(const auto& d : dimensions) + { + std::transform(d.subdimensions.begin(), d.subdimensions.end(), std::back_inserter(result), [&](const dimension::sub& s) { + return get_len(s, input_dims); + }); + } + return result; +} bool operator==(const dimension::sub& x, const dimension::sub& y) { diff --git a/test/fuse_pointwise.cpp b/test/fuse_pointwise.cpp index 5a330c8b4f5..69a698ddec0 100644 --- a/test/fuse_pointwise.cpp +++ b/test/fuse_pointwise.cpp @@ -578,8 +578,8 @@ TEST_CASE(add_unsqueeze_add_nonstandard) auto x = mm->add_parameter("x", s1); auto y = mm->add_parameter("y", s1); auto z = mm->add_parameter("z", s2); - auto x2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); - auto y2 = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), y); + auto x2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x); + auto y2 = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), y); auto fadd = add_pointwise(p2, "main:pointwise0", {x2, y2, z}, [=](auto* pm, const auto& inputs) { auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]); From 4d35f193c2984607bf1284066b834717c2c0df93 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 08:57:44 -0700 Subject: [PATCH 15/37] Format --- src/include/migraphx/rewrite_reshapes.hpp | 22 +++++++++++----------- src/shape_transform_descriptor.cpp | 10 ++++++---- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 6ec164d41ee..675e424994a 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -80,7 +80,8 @@ struct rewrite_reshapes match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); auto pointwise = match::name(op1)(match::used_once()); auto reshapes_pointwise = - reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))).bind("reshape"); + reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))) + .bind("reshape"); auto broadcast_reshapes_pointwise = skip_broadcasts(reshapes_pointwise); return match::name(op2)( match::any_of[match::inputs()](broadcast_reshapes_pointwise.bind("input"))); @@ -106,7 +107,6 @@ struct rewrite_reshapes return find_input_if(start, last, f) != last; } - static bool match_input(instruction_ref ins, instruction_ref x_ins) { if(ins->inputs().empty()) @@ -127,10 +127,7 @@ struct rewrite_reshapes return result; } - static bool is_broadcast(instruction_ref ins) - { - return ins->name() == "multibroadcast"; - } + static bool is_broadcast(instruction_ref ins) { return ins->name() == "multibroadcast"; } void apply(module_pass_manager& mpm, const match::matcher_result& r) const { @@ -139,11 +136,12 @@ struct rewrite_reshapes auto input_ins = r.instructions["input"]; auto reshape_ins = r.instructions["reshape"]; - auto broadcast_ins = - find_input_if(input_ins, reshape_ins, &is_broadcast); - if(broadcast_ins != reshape_ins and any_input_of(broadcast_ins, reshape_ins, &is_broadcast)) + auto broadcast_ins = find_input_if(input_ins, reshape_ins, &is_broadcast); + if(broadcast_ins != reshape_ins and + any_input_of(broadcast_ins, reshape_ins, &is_broadcast)) return; - const auto has_broadcast = broadcast_ins != reshape_ins and broadcast_ins->name() == "multibroadcast"; + const auto has_broadcast = + broadcast_ins != reshape_ins and broadcast_ins->name() == "multibroadcast"; auto dims1 = T::base_dims(ins); auto dims2 = T::base_dims(x_ins); @@ -185,7 +183,9 @@ struct rewrite_reshapes if(has_broadcast) { new_x_ins = mpm.get_module().insert_instruction( - x_ins, make_op("multibroadcast", {{"out_lens", desc.common_dims(dims1)}}), new_x_ins); + x_ins, + make_op("multibroadcast", {{"out_lens", desc.common_dims(dims1)}}), + new_x_ins); } auto inputs = ins->inputs(); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index ce0c6968884..bd2f700d5c5 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -1040,14 +1040,16 @@ std::size_t shape_transform_descriptor::elements() const std::multiplies<>{}, [](const auto& s) { return s.len(); }); } -std::vector shape_transform_descriptor::common_dims(const std::vector& input_dims) const +std::vector +shape_transform_descriptor::common_dims(const std::vector& input_dims) const { std::vector result; for(const auto& d : dimensions) { - std::transform(d.subdimensions.begin(), d.subdimensions.end(), std::back_inserter(result), [&](const dimension::sub& s) { - return get_len(s, input_dims); - }); + std::transform(d.subdimensions.begin(), + d.subdimensions.end(), + std::back_inserter(result), + [&](const dimension::sub& s) { return get_len(s, input_dims); }); } return result; } From 98f68666019bfcda0ba4ad98ff23964b56712a15 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 14:18:06 -0700 Subject: [PATCH 16/37] Try to use broadcast in shape transform --- src/include/migraphx/rewrite_reshapes.hpp | 24 +++++++------------ .../migraphx/shape_transform_descriptor.hpp | 6 ++++- src/shape_transform_descriptor.cpp | 16 ++++++++++++- 3 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 675e424994a..ff392b2d874 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -74,17 +74,13 @@ struct rewrite_reshapes auto matcher() const { auto reshapes = match::name( - "reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous")( + "reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous", "multibroadcast", "broadcast")( match::used_once()); - auto skip_broadcasts = - match::skip(match::name("contiguous", "multibroadcast")(match::used_once())); auto pointwise = match::name(op1)(match::used_once()); auto reshapes_pointwise = - reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))) - .bind("reshape"); - auto broadcast_reshapes_pointwise = skip_broadcasts(reshapes_pointwise); + reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); return match::name(op2)( - match::any_of[match::inputs()](broadcast_reshapes_pointwise.bind("input"))); + match::any_of[match::inputs()](reshapes_pointwise.bind("input"))); } template @@ -134,14 +130,6 @@ struct rewrite_reshapes auto ins = r.result; auto x_ins = r.instructions["x"]; auto input_ins = r.instructions["input"]; - auto reshape_ins = r.instructions["reshape"]; - - auto broadcast_ins = find_input_if(input_ins, reshape_ins, &is_broadcast); - if(broadcast_ins != reshape_ins and - any_input_of(broadcast_ins, reshape_ins, &is_broadcast)) - return; - const auto has_broadcast = - broadcast_ins != reshape_ins and broadcast_ins->name() == "multibroadcast"; auto dims1 = T::base_dims(ins); auto dims2 = T::base_dims(x_ins); @@ -162,6 +150,9 @@ struct rewrite_reshapes auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops); if(desc.empty()) return; + const auto has_broadcast = desc.has_broadcast(); + if(has_broadcast) + desc.flatten_broadcast(); auto reshape_input = [&](const auto& ins_to_insert, auto generate) { return [&, generate](auto input) { auto gops = std::invoke(generate, desc, input->get_shape().lens()); @@ -184,7 +175,7 @@ struct rewrite_reshapes { new_x_ins = mpm.get_module().insert_instruction( x_ins, - make_op("multibroadcast", {{"out_lens", desc.common_dims(dims1)}}), + make_op("multibroadcast", {{"out_lens", desc.common_dims(dims2)}}), new_x_ins); } @@ -195,6 +186,7 @@ struct rewrite_reshapes return reshape_input(ins, &shape_transform_descriptor::generate_common_from_dst)(input); }); + mpm.get_module().debug_print(inputs); auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst()); auto rins = reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw); diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 4de2b07c666..0136652a732 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -85,8 +85,12 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor optional axis = nullopt); void simplify(); std::size_t elements() const; - std::vector common_dims(const std::vector& input_dims = {}) const; std::vector generate() const; + + bool has_broadcast() const; + void flatten_broadcast(); + + std::vector common_dims(const std::vector& input_dims = {}) const; std::vector generate_common_from_src(const std::vector& input_dims) const; std::vector diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index bd2f700d5c5..b8fe0058d90 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -47,7 +47,7 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim, Proj x *= proj(d); return x == dim; }); - if(it != last) + if(it != last) return it; return start; } @@ -835,6 +835,20 @@ std::vector shape_transform_descriptor::generate() const return result; } +bool shape_transform_descriptor::has_broadcast() const +{ + return std::any_of(dimensions.begin(), dimensions.end(), [&](const dimension& d) { + return std::any_of(d.subdimensions.begin(), d.subdimensions.end(), [&](const dimension::sub& s) { + return s.axis.empty() and s.len != 1; + }); + }); +} +void shape_transform_descriptor::flatten_broadcast() +{ + for(auto& d:dimensions) + std::for_each(d.subdimensions.begin(), d.subdimensions.end(), &flatten_broadcasted_dim); +} + std::vector shape_transform_descriptor::generate_common_from_src( const std::vector& input_dims) const { From d26ab72016846f44e1094fd967f23bdcd41b7c40 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 14:18:12 -0700 Subject: [PATCH 17/37] Format --- src/include/migraphx/rewrite_reshapes.hpp | 11 ++++++++--- src/shape_transform_descriptor.cpp | 10 +++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index ff392b2d874..cb8b7f23bb3 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -73,9 +73,14 @@ struct rewrite_reshapes auto matcher() const { - auto reshapes = match::name( - "reshape", "squeeze", "unsqueeze", "flatten", "transpose", "contiguous", "multibroadcast", "broadcast")( - match::used_once()); + auto reshapes = match::name("reshape", + "squeeze", + "unsqueeze", + "flatten", + "transpose", + "contiguous", + "multibroadcast", + "broadcast")(match::used_once()); auto pointwise = match::name(op1)(match::used_once()); auto reshapes_pointwise = reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x")))); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index b8fe0058d90..981ca017d8a 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -47,7 +47,7 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim, Proj x *= proj(d); return x == dim; }); - if(it != last) + if(it != last) return it; return start; } @@ -838,14 +838,14 @@ std::vector shape_transform_descriptor::generate() const bool shape_transform_descriptor::has_broadcast() const { return std::any_of(dimensions.begin(), dimensions.end(), [&](const dimension& d) { - return std::any_of(d.subdimensions.begin(), d.subdimensions.end(), [&](const dimension::sub& s) { - return s.axis.empty() and s.len != 1; - }); + return std::any_of(d.subdimensions.begin(), + d.subdimensions.end(), + [&](const dimension::sub& s) { return s.axis.empty() and s.len != 1; }); }); } void shape_transform_descriptor::flatten_broadcast() { - for(auto& d:dimensions) + for(auto& d : dimensions) std::for_each(d.subdimensions.begin(), d.subdimensions.end(), &flatten_broadcasted_dim); } From ea93fafec73eae5c5a447e81d78458d107979dac Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 15:31:07 -0700 Subject: [PATCH 18/37] Handle rebase --- src/include/migraphx/rewrite_reshapes.hpp | 13 +++---- .../migraphx/shape_transform_descriptor.hpp | 3 ++ src/shape_transform_descriptor.cpp | 39 +++++++++++++++++++ 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index cb8b7f23bb3..f5edc1a3db3 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -141,7 +141,7 @@ struct rewrite_reshapes if(elements(dims1) != elements(dims2)) return; - + std::vector ops; auto next_ins = input_ins; while(next_ins != x_ins) @@ -152,12 +152,10 @@ struct rewrite_reshapes assert(next_ins == x_ins); std::reverse(ops.begin(), ops.end()); - auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops); + auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); if(desc.empty()) return; - const auto has_broadcast = desc.has_broadcast(); - if(has_broadcast) - desc.flatten_broadcast(); + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, auto generate) { return [&, generate](auto input) { auto gops = std::invoke(generate, desc, input->get_shape().lens()); @@ -176,11 +174,11 @@ struct rewrite_reshapes x_inputs.begin(), reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src)); auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src()); - if(has_broadcast) + if(new_x_ins->get_shape().lens() != cdims) { new_x_ins = mpm.get_module().insert_instruction( x_ins, - make_op("multibroadcast", {{"out_lens", desc.common_dims(dims2)}}), + make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins); } @@ -191,7 +189,6 @@ struct rewrite_reshapes return reshape_input(ins, &shape_transform_descriptor::generate_common_from_dst)(input); }); - mpm.get_module().debug_print(inputs); auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst()); auto rins = reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw); diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 0136652a732..41cfeb0aab8 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -77,6 +77,8 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor static shape_transform_descriptor create(const std::vector& dims, const std::vector& ops); + shape_transform_descriptor rebase(const std::vector& dims) const; + bool apply(const std::vector& ops); bool apply_reshape(const std::vector& rdims); bool apply_reshape_impl(const std::vector& rdims); @@ -101,6 +103,7 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor std::vector> common_axes_map_from_dst() const; bool empty() const; + std::vector lens() const; struct MIGRAPHX_EXPORT dimension { diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 981ca017d8a..4ced663470f 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -110,6 +110,36 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector< return result; } +shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector& dims) const +{ + auto result = *this; + for(auto& d:result.dimensions) + { + for(auto& sub:d.subdimensions) + { + if(sub.axis.empty() and sub.hidden_axis.has_value()) + { + sub.len = dims.at(sub.hidden_axis.value()); + sub.axis = {sub.hidden_axis.value()}; + sub.hidden_axis = nullopt; + } + else if(sub.axis.size() == 1) + { + sub.len = dims.at(sub.axis.front()); + } + } + } + // TODO: Handle resizes + result.flatten_broadcast(); + if(not result.apply_reshape(this->common_dims())) + return {}; + if(not result.apply_reshape(this->lens())) + return {}; + result.simplify(); + + return result; +} + bool shape_transform_descriptor::apply(const std::vector& ops) { std::vector dims; @@ -1037,6 +1067,15 @@ std::vector> shape_transform_descriptor::common_axes_ma bool shape_transform_descriptor::empty() const { return dimensions.empty(); } +std::vector shape_transform_descriptor::lens() const +{ + std::vector result; + std::transform(dimensions.begin(), dimensions.end(), std::back_inserter(result), [](const dimension& d) { + return d.len(); + }); + return result; +} + std::size_t dimension::len() const { return transform_accumulate(subdimensions.begin(), From 1e2856a1a52e9242ce274afb1d4a8a4e8f2233d1 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 15:31:13 -0700 Subject: [PATCH 19/37] Format --- src/include/migraphx/rewrite_reshapes.hpp | 11 +++++------ src/shape_transform_descriptor.cpp | 22 ++++++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index f5edc1a3db3..241a22861ee 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -141,7 +141,7 @@ struct rewrite_reshapes if(elements(dims1) != elements(dims2)) return; - + std::vector ops; auto next_ins = input_ins; while(next_ins != x_ins) @@ -152,10 +152,11 @@ struct rewrite_reshapes assert(next_ins == x_ins); std::reverse(ops.begin(), ops.end()); - auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); + auto desc = + shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); if(desc.empty()) return; - auto cdims = desc.common_dims(); + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, auto generate) { return [&, generate](auto input) { auto gops = std::invoke(generate, desc, input->get_shape().lens()); @@ -177,9 +178,7 @@ struct rewrite_reshapes if(new_x_ins->get_shape().lens() != cdims) { new_x_ins = mpm.get_module().insert_instruction( - x_ins, - make_op("multibroadcast", {{"out_lens", cdims}}), - new_x_ins); + x_ins, make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins); } auto inputs = ins->inputs(); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 4ced663470f..015982bf2f7 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -110,17 +110,18 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector< return result; } -shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector& dims) const +shape_transform_descriptor +shape_transform_descriptor::rebase(const std::vector& dims) const { auto result = *this; - for(auto& d:result.dimensions) + for(auto& d : result.dimensions) { - for(auto& sub:d.subdimensions) + for(auto& sub : d.subdimensions) { if(sub.axis.empty() and sub.hidden_axis.has_value()) { - sub.len = dims.at(sub.hidden_axis.value()); - sub.axis = {sub.hidden_axis.value()}; + sub.len = dims.at(sub.hidden_axis.value()); + sub.axis = {sub.hidden_axis.value()}; sub.hidden_axis = nullopt; } else if(sub.axis.size() == 1) @@ -1067,12 +1068,13 @@ std::vector> shape_transform_descriptor::common_axes_ma bool shape_transform_descriptor::empty() const { return dimensions.empty(); } -std::vector shape_transform_descriptor::lens() const -{ +std::vector shape_transform_descriptor::lens() const +{ std::vector result; - std::transform(dimensions.begin(), dimensions.end(), std::back_inserter(result), [](const dimension& d) { - return d.len(); - }); + std::transform(dimensions.begin(), + dimensions.end(), + std::back_inserter(result), + [](const dimension& d) { return d.len(); }); return result; } From edbe407b3ad267f7b9a13e81cfe79ddf2cc7662f Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 15:40:40 -0700 Subject: [PATCH 20/37] Check for only broadcast --- src/include/migraphx/rewrite_reshapes.hpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 241a22861ee..6f4eded55fd 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -136,12 +136,18 @@ struct rewrite_reshapes auto x_ins = r.instructions["x"]; auto input_ins = r.instructions["input"]; + // If its just a broadcast then skip + if(not any_input_of(input_ins, x_ins, [](instruction_ref x) { + return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name()); + })) + return; + auto dims1 = T::base_dims(ins); auto dims2 = T::base_dims(x_ins); if(elements(dims1) != elements(dims2)) return; - + std::vector ops; auto next_ins = input_ins; while(next_ins != x_ins) @@ -152,11 +158,10 @@ struct rewrite_reshapes assert(next_ins == x_ins); std::reverse(ops.begin(), ops.end()); - auto desc = - shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); + auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); if(desc.empty()) return; - auto cdims = desc.common_dims(); + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, auto generate) { return [&, generate](auto input) { auto gops = std::invoke(generate, desc, input->get_shape().lens()); @@ -178,7 +183,9 @@ struct rewrite_reshapes if(new_x_ins->get_shape().lens() != cdims) { new_x_ins = mpm.get_module().insert_instruction( - x_ins, make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins); + x_ins, + make_op("multibroadcast", {{"out_lens", cdims}}), + new_x_ins); } auto inputs = ins->inputs(); From 3ec0846a0f35f0bd12b080c79878de1adeeeb004 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 29 Oct 2024 15:40:44 -0700 Subject: [PATCH 21/37] Format --- src/include/migraphx/rewrite_reshapes.hpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 6f4eded55fd..30496a7ddfa 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -138,8 +138,8 @@ struct rewrite_reshapes // If its just a broadcast then skip if(not any_input_of(input_ins, x_ins, [](instruction_ref x) { - return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name()); - })) + return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name()); + })) return; auto dims1 = T::base_dims(ins); @@ -147,7 +147,7 @@ struct rewrite_reshapes if(elements(dims1) != elements(dims2)) return; - + std::vector ops; auto next_ins = input_ins; while(next_ins != x_ins) @@ -158,10 +158,11 @@ struct rewrite_reshapes assert(next_ins == x_ins); std::reverse(ops.begin(), ops.end()); - auto desc = shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); + auto desc = + shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2); if(desc.empty()) return; - auto cdims = desc.common_dims(); + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, auto generate) { return [&, generate](auto input) { auto gops = std::invoke(generate, desc, input->get_shape().lens()); @@ -183,9 +184,7 @@ struct rewrite_reshapes if(new_x_ins->get_shape().lens() != cdims) { new_x_ins = mpm.get_module().insert_instruction( - x_ins, - make_op("multibroadcast", {{"out_lens", cdims}}), - new_x_ins); + x_ins, make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins); } auto inputs = ins->inputs(); From 43005c55eeda7cad5a4a6163d6b6e40018486d7a Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 1 Nov 2024 15:36:46 -0700 Subject: [PATCH 22/37] Fix sizes --- src/shape_transform_descriptor.cpp | 43 +++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 015982bf2f7..030332d659b 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -84,6 +84,23 @@ static std::vector get_all_subdimensions(const std::vector +static void for_each_subdimension(Dimensions&& dimensions, Range&& r, F f) +{ + auto start = r.begin(); + auto last = r.end(); + for(auto& dim : dimensions) + { + for(auto& s:dim.subdimensions) + { + if(start == last) + return; + f(s, *start); + start++; + } + } +} + std::vector compute_dims(const operation& op, const std::vector& idims) { shape s{shape::float_type, idims}; @@ -943,27 +960,29 @@ std::vector shape_transform_descriptor::generate_common_from_dst( return d.subdimensions.size() == 1; })) return {}; - auto subs = get_all_subdimensions(dimensions); - // Map the input dims back to the src input if possible - std::vector src_input_dims(this->rank); - for(auto i : range(dimensions.size())) + std::vector subs; + // Update axes to point to the destination + for(std::size_t i : range(dimensions.size())) { const auto& d = dimensions[i]; - if(d.subdimensions.size() != 1) - continue; - const auto& sub = d.subdimensions.front(); - if(sub.axis.size() != 1) - continue; - auto axis = sub.axis.front(); - src_input_dims[axis] = input_dims[i]; + std::transform(d.subdimensions.begin(), d.subdimensions.end(), range(d.subdimensions.size()).begin(), std::back_inserter(subs), [&](dimension::sub s, auto j) { + s.axis = {i}; + if(d.subdimensions.size() > 1) + s.axis.push_back(j); + return s; + }); } - return {make_reshape_unsqueeze(subs, src_input_dims)}; + return {make_reshape_unsqueeze(subs, input_dims)}; } std::vector shape_transform_descriptor::generate_dst_from_common( const std::vector& input_dims) const { std::vector result; std::vector new_dims = dimensions; + for_each_subdimension(new_dims, input_dims, [&](auto& s, auto dim) { + s.len = dim; + }); + // Need broadcast if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) { From 7283cbf70657e075d3a0b9f281f1e07da43d1cd1 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 1 Nov 2024 15:36:54 -0700 Subject: [PATCH 23/37] Format --- src/shape_transform_descriptor.cpp | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 030332d659b..8d4afc85050 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -84,14 +84,14 @@ static std::vector get_all_subdimensions(const std::vector +template static void for_each_subdimension(Dimensions&& dimensions, Range&& r, F f) { auto start = r.begin(); - auto last = r.end(); + auto last = r.end(); for(auto& dim : dimensions) { - for(auto& s:dim.subdimensions) + for(auto& s : dim.subdimensions) { if(start == last) return; @@ -965,12 +965,16 @@ std::vector shape_transform_descriptor::generate_common_from_dst( for(std::size_t i : range(dimensions.size())) { const auto& d = dimensions[i]; - std::transform(d.subdimensions.begin(), d.subdimensions.end(), range(d.subdimensions.size()).begin(), std::back_inserter(subs), [&](dimension::sub s, auto j) { - s.axis = {i}; - if(d.subdimensions.size() > 1) - s.axis.push_back(j); - return s; - }); + std::transform(d.subdimensions.begin(), + d.subdimensions.end(), + range(d.subdimensions.size()).begin(), + std::back_inserter(subs), + [&](dimension::sub s, auto j) { + s.axis = {i}; + if(d.subdimensions.size() > 1) + s.axis.push_back(j); + return s; + }); } return {make_reshape_unsqueeze(subs, input_dims)}; } @@ -979,9 +983,7 @@ std::vector shape_transform_descriptor::generate_dst_from_common( { std::vector result; std::vector new_dims = dimensions; - for_each_subdimension(new_dims, input_dims, [&](auto& s, auto dim) { - s.len = dim; - }); + for_each_subdimension(new_dims, input_dims, [&](auto& s, auto dim) { s.len = dim; }); // Need broadcast if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) From 7693133d39707795958cc20f16ab55c0faff1b31 Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 2 Nov 2024 09:35:55 -0700 Subject: [PATCH 24/37] Update test --- src/include/migraphx/rewrite_reshapes.hpp | 1 + test/fuse_reduce.cpp | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 30496a7ddfa..60cc9093822 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -136,6 +136,7 @@ struct rewrite_reshapes auto x_ins = r.instructions["x"]; auto input_ins = r.instructions["input"]; + // If its just a broadcast then skip if(not any_input_of(input_ins, x_ins, [](instruction_ref x) { return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name()); diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 55fad66e5ac..160cd2a4b91 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -954,7 +954,7 @@ TEST_CASE(reshape_reduce_reduce_reduce_diff_axes) auto reduce0 = add_reduce( p2, "main:pointwise0:main:pointwise1:main:reduce_sum1:main:pointwise2:main:reduce_sum0:" - "main:pointwise3:main:pointwise4:main:pointwise5:main:pointwise6_reshape_reshape", + "main:pointwise3:main:pointwise4:main:pointwise5:main:pointwise6_reshape", {l2_mb, x1, x2, l1_mb}, {2}, [&](auto* rm, const auto& inputs, const auto& axes) { @@ -982,7 +982,7 @@ TEST_CASE(reshape_reduce_reduce_reduce_diff_axes) auto reduce1 = add_reduce(p2, - "main:reduce_sum2_reshape", + "main:reduce_sum2", {reduce0}, {1}, [&](auto* rm, const auto& inputs, const auto& axes) { From 15644a94d4eec273badd629446c5acb4f930c55b Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 2 Nov 2024 09:36:02 -0700 Subject: [PATCH 25/37] Remove line --- src/include/migraphx/rewrite_reshapes.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 60cc9093822..30496a7ddfa 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -136,7 +136,6 @@ struct rewrite_reshapes auto x_ins = r.instructions["x"]; auto input_ins = r.instructions["input"]; - // If its just a broadcast then skip if(not any_input_of(input_ins, x_ins, [](instruction_ref x) { return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name()); From f171e4f1d87e1ac8781e895e5369828e81a04659 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 4 Dec 2024 11:39:51 -0600 Subject: [PATCH 26/37] Use origin_axis --- src/shape_transform_descriptor.cpp | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index da7c074cb30..d40a60fcc0a 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -137,15 +137,15 @@ shape_transform_descriptor::rebase(const std::vector& dims) const { for(auto& sub : d.subdimensions) { - if(sub.axis.empty() and sub.hidden_axis.has_value()) + const auto& axis = sub.origin_axis(); + if(axis.size() == 1 or sub.has_hidden_axis()) { - sub.len = dims.at(sub.hidden_axis.value()); - sub.axis = {sub.hidden_axis.value()}; - sub.hidden_axis = nullopt; - } - else if(sub.axis.size() == 1) - { - sub.len = dims.at(sub.axis.front()); + sub.len = dims.at(axis.front()); + if(sub.has_hidden_axis()) + { + sub.axis = axis; + sub.hidden_axis.clear(); + } } } } @@ -1149,17 +1149,8 @@ std::vector> shape_transform_descriptor::common_axes_ma for(const auto& s : subs) { std::size_t axis = -1; - if(s.axis.empty()) - { - if(s.hidden_axis.has_value()) - axis = s.hidden_axis.value(); - else - continue; - } - else - { - axis = s.axis.front(); - } + if(not s.origin_axis().empty()) + axis = s.origin_axis().front(); axes_map[axis].push_back(&s); } for(auto&& p : axes_map) From fd4299e12db01ba9e576db7c7e919720557adb71 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 4 Dec 2024 12:03:33 -0600 Subject: [PATCH 27/37] Reuse create function --- src/shape_transform_descriptor.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index d40a60fcc0a..6b5d0690aac 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -1287,13 +1287,10 @@ std::ostream& operator<<(std::ostream& os, const shape_transform_descriptor& x) std::vector optimize_shape_transforms(const std::vector& dims, const std::vector& ops) { - shape_transform_descriptor sd{dims}; - if(not sd.apply(ops)) + auto sd = shape_transform_descriptor::create(dims, ops); + if(sd.empty()) return ops; - sd.simplify(); - auto result = sd.generate(); - assert(compute_dims(ops, dims) == compute_dims(result, dims)); - return result; + return sd.generate(); } } // namespace MIGRAPHX_INLINE_NS From 3d5515ec49d14de1372ce69c0d4ded7b3270479f Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 4 Dec 2024 12:06:03 -0600 Subject: [PATCH 28/37] Add default arguments --- src/include/migraphx/shape_transform_descriptor.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index ace2f4c1a28..7f2fc91e3dd 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -94,11 +94,11 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor std::vector common_dims(const std::vector& input_dims = {}) const; std::vector - generate_common_from_src(const std::vector& input_dims) const; + generate_common_from_src(const std::vector& input_dims = {}) const; std::vector - generate_common_from_dst(const std::vector& input_dims) const; + generate_common_from_dst(const std::vector& input_dims = {}) const; std::vector - generate_dst_from_common(const std::vector& input_dims) const; + generate_dst_from_common(const std::vector& input_dims = {}) const; std::vector> common_axes_map_from_src() const; std::vector> common_axes_map_from_dst() const; From 0a5776dea6007918cde1b0f452f98b7931d01a77 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 5 Dec 2024 10:35:40 -0600 Subject: [PATCH 29/37] Add unit tests --- test/shape_transform_descriptor.cpp | 106 ++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index d6168ecbe7c..1e69c5b139e 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -37,6 +37,7 @@ using d_axes = std::vector>; using ops = std::vector; using dimension = shape_transform_descriptor::dimension; using sub = dimension::sub; +using axes_map = std::vector>; all_lens get_all_lens(const shape_transform_descriptor& d) { @@ -562,4 +563,109 @@ TEST_CASE(optimize_squeeze_multibroadcast_transpose) }); } +TEST_CASE(common_dims_reshape_less) +{ + auto desc = make_descriptor({2, 32, 40, 8}, make_op("reshape", {{"dims", {2, 1280, 8}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8}); + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {3}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}}); + EXPECT(desc.generate_common_from_src() == ops{}); + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8}}})}); + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 1280, 8}}})}); +} + +TEST_CASE(common_dims_reshape1) +{ + auto desc = make_descriptor({2, 32, 2560}, make_op("reshape", {{"dims", {2, 1280, 8, 8}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); + EXPECT(desc.common_axes_map_from_src() == axes_map{{{0}, {1}, {2, 3, 4}}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}, {4}}); + EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 1280, 8, 8}}})}); +} + +TEST_CASE(common_dims_reshape2) +{ + auto desc = make_descriptor({2, 1280, 8, 8}, make_op("reshape", {{"dims", {2, 32, 2560}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3}, {4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{{0}, {1}, {2, 3, 4}}}); + EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 2560}}})}); +} + +TEST_CASE(common_dims_reshape3) +{ + auto desc = make_descriptor({2, 32, 4096}, make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + + EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); + EXPECT(desc.common_dims({2, 1, 4096}) == final_lens{2, 1, 1, 64, 64}); + EXPECT(desc.common_dims({2, 32, 1}) == final_lens{2, 2, 16, 1, 1}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3, 4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0, 1}, {2}, {3}, {4}}); + + EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({2, 32, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 4096}) == ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); + + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == ops{make_op("squeeze", {{"axes", {1}}})}); +} + +TEST_CASE(common_dims_reshape4) +{ + auto desc = make_descriptor({4, 16, 64, 64}, make_op("reshape", {{"dims", {2, 32, 4096}}})); + + EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); + EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{2, 2, 16, 1, 1}); + EXPECT(desc.common_dims({4, 1, 64, 64}) == final_lens{2, 2, 1, 64, 64}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0, 1}, {2}, {3}, {4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3, 4}}); + + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({2, 32, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({2, 1, 4096}) == ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); + + EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_src({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 4096}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 4096}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 32, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 16, 4096}}})}); +} + +TEST_CASE(common_dims_transpose_reshape) +{ + auto desc = make_descriptor({2, 16, 64, 64}, make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), make_op("reshape", {{"dims", {2, 32, 2048}}})); + EXPECT(desc.common_dims() == final_lens{2, 32, 2, 64, 16}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {4}, {1, 2}, {3}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1}, {2, 3, 4}}); + + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 2, 64, 16}}})}); + EXPECT(desc.generate_common_from_dst({2, 32, 1}) == ops{make_op("unsqueeze", {{"axes", {3, 4}}})}); + EXPECT(desc.generate_common_from_dst({2, 1, 2048}) == ops{make_op("reshape", {{"dims", {2, 1, 2, 64, 16}}})}); + + EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 16, 32, 2, 64}}}), make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 16, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {3}}}), make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 1, 32, 2, 64}}}), make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 2048}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 2, 64, 16}) == ops{make_op("reshape", {{"dims", {2, 1, 2048}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 1, 1, 16}) == ops{make_op("squeeze", {{"axes", {2, 3}}})}); + EXPECT(desc.generate_dst_from_common({2, 32, 2, 64, 1}) == ops{make_op("reshape", {{"dims", {2, 32, 128}}})}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From a7f87a8b03ad92d21f1f3b021f64fb95b4141bdd Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 10 Dec 2024 17:13:36 -0600 Subject: [PATCH 30/37] Add more tests --- src/shape_transform_descriptor.cpp | 42 ++++-------------- test/shape_transform_descriptor.cpp | 69 ++++++++++++++++++++++++++--- 2 files changed, 71 insertions(+), 40 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 6b5d0690aac..6777716f3a0 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -1092,41 +1092,17 @@ std::vector shape_transform_descriptor::generate_dst_from_common( std::vector new_dims = dimensions; for_each_subdimension(new_dims, input_dims, [&](auto& s, auto dim) { s.len = dim; }); - // Need broadcast - if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) - { - std::vector out_lens; - std::transform(new_dims.begin(), - new_dims.end(), - std::back_inserter(out_lens), - [](const dimension& d) { return d.len(); }); - auto startb = std::find_if_not(new_dims.begin(), new_dims.end(), &has_no_axes); - auto trailb = std::find_if_not(startb, new_dims.end(), &has_axes); - auto axis = std::distance(new_dims.begin(), startb); - auto extra_dims = axis + std::distance(trailb, new_dims.end()); - // Use broadcast instead of multibroadcast - if(std::all_of(trailb, new_dims.end(), &has_no_axes) and extra_dims > 0 and - axis < new_dims.size()) - { - result.push_back(make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}})); - new_dims.erase(trailb, new_dims.end()); - new_dims.erase(new_dims.begin(), new_dims.begin() + axis); - } - else - { - result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); - } - } - // If all the dimensions have no axes then there isnt anthing else to do - // so just clear the new_dims - if(std::all_of(new_dims.begin(), new_dims.end(), &has_no_axes)) - new_dims.clear(); - // Flatten broadcasted dimensions + // Remove broadcasted dimensions for(auto& d : new_dims) { if(d.subdimensions.size() != 1) continue; - flatten_broadcasted_dim(d.subdimensions.front()); + auto& s = d.subdimensions.front(); + if(s.axis.empty()) + { + s.axis = s.hidden_axis; + s.hidden_axis.clear(); + } } // Need squeeze reshape if(std::any_of(new_dims.begin(), new_dims.end(), [](const dimension& d) { @@ -1148,10 +1124,8 @@ std::vector> shape_transform_descriptor::common_axes_ma std::map> axes_map; for(const auto& s : subs) { - std::size_t axis = -1; if(not s.origin_axis().empty()) - axis = s.origin_axis().front(); - axes_map[axis].push_back(&s); + axes_map[s.origin_axis().front()].push_back(&s); } for(auto&& p : axes_map) { diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 1e69c5b139e..28772c5fef8 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -115,6 +115,14 @@ shape_transform_descriptor make_descriptor(const std::vector& dims, return desc; } +template +shape_transform_descriptor make_simple_descriptor(const std::vector& dims, const Ts&... xs) +{ + auto desc = make_descriptor(dims, xs...); + desc.simplify(); + return desc; +} + TEST_CASE(dimension_len) { dimension dim; @@ -565,7 +573,7 @@ TEST_CASE(optimize_squeeze_multibroadcast_transpose) TEST_CASE(common_dims_reshape_less) { - auto desc = make_descriptor({2, 32, 40, 8}, make_op("reshape", {{"dims", {2, 1280, 8}}})); + auto desc = make_simple_descriptor({2, 32, 40, 8}, make_op("reshape", {{"dims", {2, 1280, 8}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8}); EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {3}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}}); @@ -576,7 +584,7 @@ TEST_CASE(common_dims_reshape_less) TEST_CASE(common_dims_reshape1) { - auto desc = make_descriptor({2, 32, 2560}, make_op("reshape", {{"dims", {2, 1280, 8, 8}}})); + auto desc = make_simple_descriptor({2, 32, 2560}, make_op("reshape", {{"dims", {2, 1280, 8, 8}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); EXPECT(desc.common_axes_map_from_src() == axes_map{{{0}, {1}, {2, 3, 4}}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}, {4}}); @@ -587,7 +595,7 @@ TEST_CASE(common_dims_reshape1) TEST_CASE(common_dims_reshape2) { - auto desc = make_descriptor({2, 1280, 8, 8}, make_op("reshape", {{"dims", {2, 32, 2560}}})); + auto desc = make_simple_descriptor({2, 1280, 8, 8}, make_op("reshape", {{"dims", {2, 32, 2560}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3}, {4}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{{0}, {1}, {2, 3, 4}}}); @@ -598,7 +606,7 @@ TEST_CASE(common_dims_reshape2) TEST_CASE(common_dims_reshape3) { - auto desc = make_descriptor({2, 32, 4096}, make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + auto desc = make_simple_descriptor({2, 32, 4096}, make_op("reshape", {{"dims", {4, 16, 64, 64}}})); EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); EXPECT(desc.common_dims({2, 1, 4096}) == final_lens{2, 1, 1, 64, 64}); @@ -623,7 +631,7 @@ TEST_CASE(common_dims_reshape3) TEST_CASE(common_dims_reshape4) { - auto desc = make_descriptor({4, 16, 64, 64}, make_op("reshape", {{"dims", {2, 32, 4096}}})); + auto desc = make_simple_descriptor({4, 16, 64, 64}, make_op("reshape", {{"dims", {2, 32, 4096}}})); EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{2, 2, 16, 1, 1}); @@ -648,7 +656,7 @@ TEST_CASE(common_dims_reshape4) TEST_CASE(common_dims_transpose_reshape) { - auto desc = make_descriptor({2, 16, 64, 64}, make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), make_op("reshape", {{"dims", {2, 32, 2048}}})); + auto desc = make_simple_descriptor({2, 16, 64, 64}, make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), make_op("reshape", {{"dims", {2, 32, 2048}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 2, 64, 16}); EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {4}, {1, 2}, {3}}); @@ -668,4 +676,53 @@ TEST_CASE(common_dims_transpose_reshape) EXPECT(desc.generate_dst_from_common({2, 32, 2, 64, 1}) == ops{make_op("reshape", {{"dims", {2, 32, 128}}})}); } +TEST_CASE(common_dims_broadcast_reshape) +{ + auto desc = make_simple_descriptor({2, 32, 1}, make_op("multibroadcast", {{"out_lens", {2, 32, 4096}}}), make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + + EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); + EXPECT(desc.common_dims({2, 1, 1}) == final_lens{2, 1, 1, 64, 64}); + EXPECT(desc.common_dims({2, 1, 4096}) == final_lens{2, 1, 1, 64, 64}); + EXPECT(desc.common_dims({2, 32, 4096}) == final_lens{2, 2, 16, 64, 64}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3, 4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0, 1}, {2}, {3}, {4}}); + + EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {2, 4}}}), make_op("multibroadcast", {{"out_lens", {2, 1, 1, 64, 64}}})}); + + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == ops{make_op("squeeze", {{"axes", {1}}})}); +} + +TEST_CASE(common_dims_resize) +{ + auto desc = make_simple_descriptor({4, 16, 32, 32}, make_op("reshape", {{"dims", {4, 16, 32, 1, 32, 1}}}), make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}}), make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + + EXPECT(desc.common_dims() == final_lens{4, 16, 32, 2, 32, 2}); + EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{4, 16, 1, 2, 1, 2}); + EXPECT(desc.common_dims({4, 1, 32, 32}) == final_lens{4, 1, 32, 2, 32, 2}); + + EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {4}}); + EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1}, {2, 3}, {4, 5}}); + + EXPECT(desc.generate_common_from_src() == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}})}); + EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("multibroadcast", {{"out_lens", {4, 16, 1, 2, 1, 2}}})}); + EXPECT(desc.generate_common_from_src({4, 1, 32, 32}) == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("multibroadcast", {{"out_lens", {4, 1, 32, 2, 32, 2}}})}); + + EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {4, 16, 32, 2, 32, 2}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {3, 5}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {4, 1, 32, 2, 32, 2}}})}); + + EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({4, 16, 1, 2, 1, 2}) == ops{make_op("squeeze", {{"axes", {2, 4}}})}); + EXPECT(desc.generate_dst_from_common({4, 1, 32, 2, 32, 2}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 4560be92dc2118df9b9baf9e000c50a6f91ae244 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 10 Dec 2024 19:03:28 -0600 Subject: [PATCH 31/37] Fix rebase --- .../migraphx/shape_transform_descriptor.hpp | 3 + src/shape_transform_descriptor.cpp | 86 +++++++++++-------- 2 files changed, 55 insertions(+), 34 deletions(-) diff --git a/src/include/migraphx/shape_transform_descriptor.hpp b/src/include/migraphx/shape_transform_descriptor.hpp index 7f2fc91e3dd..dd1ff256b9b 100644 --- a/src/include/migraphx/shape_transform_descriptor.hpp +++ b/src/include/migraphx/shape_transform_descriptor.hpp @@ -126,6 +126,9 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor void add_split_axis(std::size_t i); + void expose(); + void hide(); + MIGRAPHX_EXPORT friend bool operator==(const sub& x, const sub& y); MIGRAPHX_EXPORT friend bool operator!=(const sub& x, const sub& y); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const sub& x); diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 6777716f3a0..47fa1e071a1 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -103,6 +103,24 @@ static void for_each_subdimension(Dimensions&& dimensions, Range&& r, F f) } } +// Group all axes into a map with a key of the axis and the value is vector of +// all subdimensions that have that axis. +static std::map> +group_axes(std::vector& dimensions) +{ + std::map> axes_map; + for(auto& d : dimensions) + { + for(auto& s : d.subdimensions) + { + if(s.origin_axis().empty()) + continue; + axes_map[s.origin_axis().front()].push_back(&s); + } + } + return axes_map; +} + std::vector compute_dims(const operation& op, const std::vector& idims) { shape s{shape::float_type, idims}; @@ -133,28 +151,28 @@ shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector& dims) const { auto result = *this; - for(auto& d : result.dimensions) + auto axes_map = group_axes(result.dimensions); + for(auto&[axis, subs]:axes_map) { - for(auto& sub : d.subdimensions) + assert(axis < dims.size()); + auto dim = dims[axis]; + auto final_dim = transform_accumulate(subs.begin(), subs.end(), std::size_t{1}, std::multiplies<>{}, [](const dimension::sub* s) { + return s->len; + }); + if(dim == final_dim) + { + for(auto* sub:subs) + sub->expose(); + } + else if(dim == 1) { - const auto& axis = sub.origin_axis(); - if(axis.size() == 1 or sub.has_hidden_axis()) + for(auto* sub:subs) { - sub.len = dims.at(axis.front()); - if(sub.has_hidden_axis()) - { - sub.axis = axis; - sub.hidden_axis.clear(); - } + if(not sub->has_hidden_axis()) + sub->len = 1; } } } - // TODO: Handle resizes - result.flatten_broadcast(); - if(not result.apply_reshape(this->common_dims())) - return {}; - if(not result.apply_reshape(this->lens())) - return {}; result.simplify(); return result; @@ -456,24 +474,6 @@ static void set_broadcast_dim(dimension& d, std::size_t axis) } } -// Group all axes into a map with a key of the axis and the value is vector of -// all subdimensions that have that axis. -static std::map> -group_axes(std::vector& dimensions) -{ - std::map> axes_map; - for(auto& d : dimensions) - { - for(auto& s : d.subdimensions) - { - if(s.origin_axis().empty()) - continue; - axes_map[s.origin_axis().front()].push_back(&s); - } - } - return axes_map; -} - static void set_origin_axis(dimension::sub& s, const std::vector& axis) { if(s.has_hidden_axis()) @@ -1219,6 +1219,24 @@ void shape_transform_descriptor::dimension::sub::add_split_axis(std::size_t i) hidden_axis.push_back(i); } +void shape_transform_descriptor::dimension::sub::expose() +{ + if(has_hidden_axis()) + { + axis = hidden_axis; + hidden_axis.clear(); + } +} + +void shape_transform_descriptor::dimension::sub::hide() +{ + if(not has_hidden_axis()) + { + hidden_axis = axis; + axis.clear(); + } +} + bool operator==(const dimension::sub& x, const dimension::sub& y) { return by(std::equal_to<>{}, From 73c56970c7e879150de7dcaa803bb172ef1e4e7c Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 11 Dec 2024 12:22:18 -0600 Subject: [PATCH 32/37] Add unit tests for rebase --- src/shape_transform_descriptor.cpp | 7 +++++++ test/shape_transform_descriptor.cpp | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 47fa1e071a1..d4420857352 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -172,6 +172,13 @@ shape_transform_descriptor::rebase(const std::vector& dims) const sub->len = 1; } } + else if(subs.size() == 1) + { + subs.front()->len = dim; + subs.front()->expose(); + } + else + MIGRAPHX_THROW("Invalid rebase"); } result.simplify(); diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 28772c5fef8..a11af01f5de 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -725,4 +725,30 @@ TEST_CASE(common_dims_resize) EXPECT(desc.generate_dst_from_common({4, 1, 32, 2, 32, 2}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); } +TEST_CASE(rebase_reshape_broadcast) +{ + auto base_desc = make_simple_descriptor({3, 4, 64, 1}, make_op("reshape", {{"dims", {12, 8, 8, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {12, 8, 8, 2, 2}}})); + + { + auto desc = base_desc.rebase({3, 4, 64, 4}); + EXPECT(get_final_lens(desc) == final_lens{12, 8, 8, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3, 4}, {8}, {8}, {2}, {2}}); + EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 4, 8, 8, 2, 2}}}), make_op("reshape", {{"dims", {12, 8, 8, 2, 2}}})}); + } + + { + auto desc = base_desc.rebase({3, 5, 64, 1}); + EXPECT(get_final_lens(desc) == final_lens{15, 8, 8, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3, 5}, {8}, {8}, {2}, {2}}); + EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 5, 8, 8, 1, 1}}}), make_op("reshape", {{"dims", {15, 8, 8, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {15, 8, 8, 2, 2}}})}); + } + + { + auto desc = base_desc.rebase({3, 4, 1, 1}); + EXPECT(get_final_lens(desc) == final_lens{12, 1, 1, 2, 2}); + EXPECT(get_all_lens(desc) == all_lens{{3, 4}, {1}, {1}, {2}, {2}}); + EXPECT(desc.generate() == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("reshape", {{"dims", {12, 1, 1, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {12, 1, 1, 2, 2}}})}); + } +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From df9a9beb2320656f45cf5f31906210391c885c86 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 11 Dec 2024 13:09:43 -0600 Subject: [PATCH 33/37] Reuse code --- src/shape_transform_descriptor.cpp | 182 +++++++++++++---------------- 1 file changed, 81 insertions(+), 101 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index d4420857352..be5eeed512f 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -869,6 +869,30 @@ static operation make_reshape_unsqueeze(const std::vector& subs, } } +namespace { +struct operation_list +{ + std::vector ops; + + void push_back(const operation& op) + { + ops.push_back(op); + } + + std::vector to_vector() const & + { + return {ops.crbegin(), ops.crend()}; + } + + std::vector to_vector() && + { + std::reverse(ops.begin(), ops.end()); + return std::move(ops);; + } +}; + +} // namespace + static bool has_no_axes(const dimension& d) { return std::all_of(d.subdimensions.begin(), d.subdimensions.end(), [](const dimension::sub& s) { @@ -882,6 +906,57 @@ static bool has_axes(const dimension& d) }); } +static void generate_from_subdimensions(operation_list& result, std::vector subs, const std::vector& input_dims = {}) +{ + // Need multibroadcast + if(std::any_of(subs.begin(), subs.end(), [&](const dimension::sub& s) { + return s.axis.empty() and get_len(s, input_dims) != 1; + })) + { + std::vector out_lens; + std::transform(subs.begin(), + subs.end(), + std::back_inserter(out_lens), + [&](const dimension::sub& s) { return get_len(s, input_dims); }); + result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); + } + + // Flatten broadcasted subdimensions + std::for_each(subs.begin(), subs.end(), &flatten_broadcasted_dim); + + auto tsubs = subs; + // Inject additonal axis to compute transpose permutation better + auto is_empty_axis = [](const auto& s) { return s.axis.empty(); }; + group_find(tsubs.begin(), tsubs.end(), is_empty_axis, [&](auto start, auto last) { + if(start == tsubs.begin()) + return; + auto base = std::prev(start); + auto axis = base->axis; + axis.push_back(0); + std::for_each(start, last, [&](auto& s) { + s.axis = axis; + axis.back()++; + }); + }); + + auto compare_sub = [](auto f) { + return by(f, [](const dimension::sub& s) -> const auto& { return s.axis; }); + }; + // Need transpose + if(not std::is_sorted(tsubs.begin(), tsubs.end(), compare_sub(std::less<>{}))) + { + auto permutation = sort_permutation(tsubs, compare_sub(std::less<>{})); + result.push_back(make_op("transpose", {{"permutation", invert_permutation(permutation)}})); + subs = reorder_dims(subs, permutation); + } + // Need reshape unsqueeze + if(std::any_of( + subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) + { + result.push_back(make_reshape_unsqueeze(subs, input_dims)); + } +} + // This will generate the operators to apply the shape transformation that is // represented by this class. This is the order of operators that will be // generated if needed: @@ -897,7 +972,7 @@ static bool has_axes(const dimension& d) // dimensions. std::vector shape_transform_descriptor::generate() const { - std::vector result; + operation_list result; std::vector new_dims = dimensions; // Need broadcast if(std::any_of(new_dims.begin(), new_dims.end(), &is_broadcast_dim)) @@ -946,55 +1021,8 @@ std::vector shape_transform_descriptor::generate() const } auto subs = get_all_subdimensions(new_dims); - // Need multibroadcast - if(std::any_of(subs.begin(), subs.end(), [](const dimension::sub& s) { - return s.axis.empty() and s.len != 1; - })) - { - std::vector out_lens; - std::transform(subs.begin(), - subs.end(), - std::back_inserter(out_lens), - [](const dimension::sub& s) { return s.len; }); - result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); - } - - // Flatten broadcasted subdimensions - std::for_each(subs.begin(), subs.end(), &flatten_broadcasted_dim); - - auto tsubs = subs; - // Inject additonal axis to compute transpose permutation better - auto is_empty_axis = [](const auto& s) { return s.axis.empty(); }; - group_find(tsubs.begin(), tsubs.end(), is_empty_axis, [&](auto start, auto last) { - if(start == tsubs.begin()) - return; - auto base = std::prev(start); - auto axis = base->axis; - axis.push_back(0); - std::for_each(start, last, [&](auto& s) { - s.axis = axis; - axis.back()++; - }); - }); - - auto compare_sub = [](auto f) { - return by(f, [](const dimension::sub& s) -> const auto& { return s.axis; }); - }; - // Need transpose - if(not std::is_sorted(tsubs.begin(), tsubs.end(), compare_sub(std::less<>{}))) - { - auto permutation = sort_permutation(tsubs, compare_sub(std::less<>{})); - result.push_back(make_op("transpose", {{"permutation", invert_permutation(permutation)}})); - subs = reorder_dims(subs, permutation); - } - // Need reshape unsqueeze - if(std::any_of( - subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) - { - result.push_back(make_reshape_unsqueeze(subs)); - } - std::reverse(result.begin(), result.end()); - return result; + generate_from_subdimensions(result, subs); + return std::move(result).to_vector(); } bool shape_transform_descriptor::has_broadcast() const @@ -1014,57 +1042,10 @@ void shape_transform_descriptor::flatten_broadcast() std::vector shape_transform_descriptor::generate_common_from_src( const std::vector& input_dims) const { - std::vector result; + operation_list result; auto subs = get_all_subdimensions(dimensions); - // Need multibroadcast - if(std::any_of(subs.begin(), subs.end(), [&](const dimension::sub& s) { - return s.axis.empty() and get_len(s, input_dims) != 1; - })) - { - std::vector out_lens; - std::transform(subs.begin(), - subs.end(), - std::back_inserter(out_lens), - [&](const dimension::sub& s) { return get_len(s, input_dims); }); - result.push_back(make_op("multibroadcast", {{"out_lens", out_lens}})); - } - - // Flatten broadcasted subdimensions - std::for_each(subs.begin(), subs.end(), &flatten_broadcasted_dim); - - auto tsubs = subs; - // Inject additonal axis to compute transpose permutation better - auto is_empty_axis = [](const auto& s) { return s.axis.empty(); }; - group_find(tsubs.begin(), tsubs.end(), is_empty_axis, [&](auto start, auto last) { - if(start == tsubs.begin()) - return; - auto base = std::prev(start); - auto axis = base->axis; - axis.push_back(0); - std::for_each(start, last, [&](auto& s) { - s.axis = axis; - axis.back()++; - }); - }); - - auto compare_sub = [](auto f) { - return by(f, [](const dimension::sub& s) -> const auto& { return s.axis; }); - }; - // Need transpose - if(not std::is_sorted(tsubs.begin(), tsubs.end(), compare_sub(std::less<>{}))) - { - auto permutation = sort_permutation(tsubs, compare_sub(std::less<>{})); - result.push_back(make_op("transpose", {{"permutation", invert_permutation(permutation)}})); - subs = reorder_dims(subs, permutation); - } - // Need reshape unsqueeze - if(std::any_of( - subs.begin(), subs.end(), [](const dimension::sub& s) { return s.axis.size() != 1; })) - { - result.push_back(make_reshape_unsqueeze(subs, input_dims)); - } - std::reverse(result.begin(), result.end()); - return result; + generate_from_subdimensions(result, subs, input_dims); + return std::move(result).to_vector(); } std::vector shape_transform_descriptor::generate_common_from_dst( const std::vector& input_dims) const @@ -1120,7 +1101,6 @@ std::vector shape_transform_descriptor::generate_dst_from_common( { result.push_back(make_reshape_squeeze(new_dims)); } - std::reverse(result.begin(), result.end()); return result; } From 4767d16c9e4122afac82d233aa948e0b8a1690ec Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 11 Dec 2024 13:18:08 -0600 Subject: [PATCH 34/37] Use expose/hide API --- src/shape_transform_descriptor.cpp | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index be5eeed512f..e0be62931e4 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -361,11 +361,7 @@ bool shape_transform_descriptor::apply_broadcast(const std::vector& } for(auto& s : new_subs) { - if(not s.axis.empty()) - { - s.hidden_axis = s.axis; - s.axis.clear(); - } + s.hide(); } return {new_subs}; }); @@ -798,8 +794,7 @@ static void flatten_broadcasted_dim(dimension::sub& s) if(s.axis.empty()) { s.len = 1; - s.axis = s.hidden_axis; - s.hidden_axis.clear(); + s.expose(); } } @@ -1086,11 +1081,7 @@ std::vector shape_transform_descriptor::generate_dst_from_common( if(d.subdimensions.size() != 1) continue; auto& s = d.subdimensions.front(); - if(s.axis.empty()) - { - s.axis = s.hidden_axis; - s.hidden_axis.clear(); - } + s.expose(); } // Need squeeze reshape if(std::any_of(new_dims.begin(), new_dims.end(), [](const dimension& d) { From 5e0ee644b78c6d43ca19c62fcaddd54299a2e40b Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 11 Dec 2024 13:23:33 -0600 Subject: [PATCH 35/37] Format --- src/shape_transform_descriptor.cpp | 37 +++-- test/shape_transform_descriptor.cpp | 205 +++++++++++++++++++--------- 2 files changed, 159 insertions(+), 83 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index e0be62931e4..89d88b50fec 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -150,23 +150,25 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector< shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector& dims) const { - auto result = *this; + auto result = *this; auto axes_map = group_axes(result.dimensions); - for(auto&[axis, subs]:axes_map) + for(auto& [axis, subs] : axes_map) { assert(axis < dims.size()); - auto dim = dims[axis]; - auto final_dim = transform_accumulate(subs.begin(), subs.end(), std::size_t{1}, std::multiplies<>{}, [](const dimension::sub* s) { - return s->len; - }); + auto dim = dims[axis]; + auto final_dim = transform_accumulate(subs.begin(), + subs.end(), + std::size_t{1}, + std::multiplies<>{}, + [](const dimension::sub* s) { return s->len; }); if(dim == final_dim) { - for(auto* sub:subs) + for(auto* sub : subs) sub->expose(); } else if(dim == 1) { - for(auto* sub:subs) + for(auto* sub : subs) { if(not sub->has_hidden_axis()) sub->len = 1; @@ -361,7 +363,7 @@ bool shape_transform_descriptor::apply_broadcast(const std::vector& } for(auto& s : new_subs) { - s.hide(); + s.hide(); } return {new_subs}; }); @@ -869,20 +871,15 @@ struct operation_list { std::vector ops; - void push_back(const operation& op) - { - ops.push_back(op); - } + void push_back(const operation& op) { ops.push_back(op); } - std::vector to_vector() const & - { - return {ops.crbegin(), ops.crend()}; - } + std::vector to_vector() const& { return {ops.crbegin(), ops.crend()}; } std::vector to_vector() && { std::reverse(ops.begin(), ops.end()); - return std::move(ops);; + return std::move(ops); + ; } }; @@ -901,7 +898,9 @@ static bool has_axes(const dimension& d) }); } -static void generate_from_subdimensions(operation_list& result, std::vector subs, const std::vector& input_dims = {}) +static void generate_from_subdimensions(operation_list& result, + std::vector subs, + const std::vector& input_dims = {}) { // Need multibroadcast if(std::any_of(subs.begin(), subs.end(), [&](const dimension::sub& s) { diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index a11af01f5de..ca9be15c80d 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -37,7 +37,7 @@ using d_axes = std::vector>; using ops = std::vector; using dimension = shape_transform_descriptor::dimension; using sub = dimension::sub; -using axes_map = std::vector>; +using axes_map = std::vector>; all_lens get_all_lens(const shape_transform_descriptor& d) { @@ -116,7 +116,8 @@ shape_transform_descriptor make_descriptor(const std::vector& dims, } template -shape_transform_descriptor make_simple_descriptor(const std::vector& dims, const Ts&... xs) +shape_transform_descriptor make_simple_descriptor(const std::vector& dims, + const Ts&... xs) { auto desc = make_descriptor(dims, xs...); desc.simplify(); @@ -573,7 +574,8 @@ TEST_CASE(optimize_squeeze_multibroadcast_transpose) TEST_CASE(common_dims_reshape_less) { - auto desc = make_simple_descriptor({2, 32, 40, 8}, make_op("reshape", {{"dims", {2, 1280, 8}}})); + auto desc = + make_simple_descriptor({2, 32, 40, 8}, make_op("reshape", {{"dims", {2, 1280, 8}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8}); EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {3}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}}); @@ -584,29 +586,36 @@ TEST_CASE(common_dims_reshape_less) TEST_CASE(common_dims_reshape1) { - auto desc = make_simple_descriptor({2, 32, 2560}, make_op("reshape", {{"dims", {2, 1280, 8, 8}}})); + auto desc = + make_simple_descriptor({2, 32, 2560}, make_op("reshape", {{"dims", {2, 1280, 8, 8}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); EXPECT(desc.common_axes_map_from_src() == axes_map{{{0}, {1}, {2, 3, 4}}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3}, {4}}); - EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 1280, 8, 8}}})}); } TEST_CASE(common_dims_reshape2) { - auto desc = make_simple_descriptor({2, 1280, 8, 8}, make_op("reshape", {{"dims", {2, 32, 2560}}})); + auto desc = + make_simple_descriptor({2, 1280, 8, 8}, make_op("reshape", {{"dims", {2, 32, 2560}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 40, 8, 8}); EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3}, {4}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{{0}, {1}, {2, 3, 4}}}); - EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 32, 40, 8, 8}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 2560}}})}); } TEST_CASE(common_dims_reshape3) { - auto desc = make_simple_descriptor({2, 32, 4096}, make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + auto desc = + make_simple_descriptor({2, 32, 4096}, make_op("reshape", {{"dims", {4, 16, 64, 64}}})); EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); EXPECT(desc.common_dims({2, 1, 4096}) == final_lens{2, 1, 1, 64, 64}); @@ -615,23 +624,33 @@ TEST_CASE(common_dims_reshape3) EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3, 4}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0, 1}, {2}, {3}, {4}}); - EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); - EXPECT(desc.generate_common_from_src({2, 32, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); - EXPECT(desc.generate_common_from_src({2, 1, 4096}) == ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({2, 32, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 4096}) == + ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); - EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); - EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); - EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); - EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); - EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == ops{make_op("squeeze", {{"axes", {1}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == + ops{make_op("squeeze", {{"axes", {1}}})}); } TEST_CASE(common_dims_reshape4) { - auto desc = make_simple_descriptor({4, 16, 64, 64}, make_op("reshape", {{"dims", {2, 32, 4096}}})); + auto desc = + make_simple_descriptor({4, 16, 64, 64}, make_op("reshape", {{"dims", {2, 32, 4096}}})); EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{2, 2, 16, 1, 1}); @@ -640,45 +659,70 @@ TEST_CASE(common_dims_reshape4) EXPECT(desc.common_axes_map_from_src() == axes_map{{0, 1}, {2}, {3}, {4}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1, 2}, {3, 4}}); - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); - EXPECT(desc.generate_common_from_dst({2, 32, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); - EXPECT(desc.generate_common_from_dst({2, 1, 4096}) == ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({2, 32, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({2, 1, 4096}) == + ops{make_op("reshape", {{"dims", {2, 1, 1, 64, 64}}})}); - EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); - EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); - EXPECT(desc.generate_common_from_src({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_src({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 4096}}})}); - EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 4096}}})}); - EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 32, 1}}})}); - EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 16, 4096}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 4096}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 32, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 16, 4096}}})}); } TEST_CASE(common_dims_transpose_reshape) { - auto desc = make_simple_descriptor({2, 16, 64, 64}, make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), make_op("reshape", {{"dims", {2, 32, 2048}}})); + auto desc = make_simple_descriptor({2, 16, 64, 64}, + make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), + make_op("reshape", {{"dims", {2, 32, 2048}}})); EXPECT(desc.common_dims() == final_lens{2, 32, 2, 64, 16}); EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {4}, {1, 2}, {3}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1}, {2, 3, 4}}); - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 32, 2, 64, 16}}})}); - EXPECT(desc.generate_common_from_dst({2, 32, 1}) == ops{make_op("unsqueeze", {{"axes", {3, 4}}})}); - EXPECT(desc.generate_common_from_dst({2, 1, 2048}) == ops{make_op("reshape", {{"dims", {2, 1, 2, 64, 16}}})}); - - EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 16, 32, 2, 64}}}), make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); - EXPECT(desc.generate_common_from_src({2, 16, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {3}}}), make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); - EXPECT(desc.generate_common_from_src({2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 1, 32, 2, 64}}}), make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 32, 2, 64, 16}}})}); + EXPECT(desc.generate_common_from_dst({2, 32, 1}) == + ops{make_op("unsqueeze", {{"axes", {3, 4}}})}); + EXPECT(desc.generate_common_from_dst({2, 1, 2048}) == + ops{make_op("reshape", {{"dims", {2, 1, 2, 64, 16}}})}); + + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 16, 32, 2, 64}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 16, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {3}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 1, 32, 2, 64}}}), + make_op("transpose", {{"permutation", {0, 2, 3, 4, 1}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {2, 32, 2048}}})}); - EXPECT(desc.generate_dst_from_common({2, 1, 2, 64, 16}) == ops{make_op("reshape", {{"dims", {2, 1, 2048}}})}); - EXPECT(desc.generate_dst_from_common({2, 1, 1, 1, 16}) == ops{make_op("squeeze", {{"axes", {2, 3}}})}); - EXPECT(desc.generate_dst_from_common({2, 32, 2, 64, 1}) == ops{make_op("reshape", {{"dims", {2, 32, 128}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 2, 64, 16}) == + ops{make_op("reshape", {{"dims", {2, 1, 2048}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 1, 1, 16}) == + ops{make_op("squeeze", {{"axes", {2, 3}}})}); + EXPECT(desc.generate_dst_from_common({2, 32, 2, 64, 1}) == + ops{make_op("reshape", {{"dims", {2, 32, 128}}})}); } TEST_CASE(common_dims_broadcast_reshape) { - auto desc = make_simple_descriptor({2, 32, 1}, make_op("multibroadcast", {{"out_lens", {2, 32, 4096}}}), make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + auto desc = make_simple_descriptor({2, 32, 1}, + make_op("multibroadcast", {{"out_lens", {2, 32, 4096}}}), + make_op("reshape", {{"dims", {4, 16, 64, 64}}})); EXPECT(desc.common_dims() == final_lens{2, 2, 16, 64, 64}); EXPECT(desc.common_dims({2, 1, 1}) == final_lens{2, 1, 1, 64, 64}); @@ -688,23 +732,37 @@ TEST_CASE(common_dims_broadcast_reshape) EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1, 2}, {3, 4}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0, 1}, {2}, {3}, {4}}); - EXPECT(desc.generate_common_from_src() == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {2, 2, 16, 64, 64}}})}); - EXPECT(desc.generate_common_from_src({2, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {2, 4}}}), make_op("multibroadcast", {{"out_lens", {2, 1, 1, 64, 64}}})}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_src({2, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {2, 4}}}), + make_op("multibroadcast", {{"out_lens", {2, 1, 1, 64, 64}}})}); - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); - EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); - EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {2, 2, 16, 64, 64}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {2, 2, 16, 1, 1}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {2, 2, 1, 64, 64}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); - EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); - EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); - EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == ops{make_op("squeeze", {{"axes", {1}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({2, 2, 16, 1, 1}) == + ops{make_op("reshape", {{"dims", {4, 16, 1, 1}}})}); + EXPECT(desc.generate_dst_from_common({2, 1, 16, 64, 64}) == + ops{make_op("squeeze", {{"axes", {1}}})}); } TEST_CASE(common_dims_resize) { - auto desc = make_simple_descriptor({4, 16, 32, 32}, make_op("reshape", {{"dims", {4, 16, 32, 1, 32, 1}}}), make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}}), make_op("reshape", {{"dims", {4, 16, 64, 64}}})); - + auto desc = + make_simple_descriptor({4, 16, 32, 32}, + make_op("reshape", {{"dims", {4, 16, 32, 1, 32, 1}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}}), + make_op("reshape", {{"dims", {4, 16, 64, 64}}})); + EXPECT(desc.common_dims() == final_lens{4, 16, 32, 2, 32, 2}); EXPECT(desc.common_dims({4, 16, 1, 1}) == final_lens{4, 16, 1, 2, 1, 2}); EXPECT(desc.common_dims({4, 1, 32, 32}) == final_lens{4, 1, 32, 2, 32, 2}); @@ -712,42 +770,61 @@ TEST_CASE(common_dims_resize) EXPECT(desc.common_axes_map_from_src() == axes_map{{0}, {1}, {2}, {4}}); EXPECT(desc.common_axes_map_from_dst() == axes_map{{0}, {1}, {2, 3}, {4, 5}}); - EXPECT(desc.generate_common_from_src() == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}})}); - EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("multibroadcast", {{"out_lens", {4, 16, 1, 2, 1, 2}}})}); - EXPECT(desc.generate_common_from_src({4, 1, 32, 32}) == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("multibroadcast", {{"out_lens", {4, 1, 32, 2, 32, 2}}})}); - - EXPECT(desc.generate_common_from_dst() == ops{make_op("reshape", {{"dims", {4, 16, 32, 2, 32, 2}}})}); - EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == ops{make_op("unsqueeze", {{"axes", {3, 5}}})}); - EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == ops{make_op("reshape", {{"dims", {4, 1, 32, 2, 32, 2}}})}); + EXPECT(desc.generate_common_from_src() == + ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 32, 2, 32, 2}}})}); + EXPECT(desc.generate_common_from_src({4, 16, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("multibroadcast", {{"out_lens", {4, 16, 1, 2, 1, 2}}})}); + EXPECT(desc.generate_common_from_src({4, 1, 32, 32}) == + ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("multibroadcast", {{"out_lens", {4, 1, 32, 2, 32, 2}}})}); + + EXPECT(desc.generate_common_from_dst() == + ops{make_op("reshape", {{"dims", {4, 16, 32, 2, 32, 2}}})}); + EXPECT(desc.generate_common_from_dst({4, 16, 1, 1}) == + ops{make_op("unsqueeze", {{"axes", {3, 5}}})}); + EXPECT(desc.generate_common_from_dst({4, 1, 64, 64}) == + ops{make_op("reshape", {{"dims", {4, 1, 32, 2, 32, 2}}})}); EXPECT(desc.generate_dst_from_common() == ops{make_op("reshape", {{"dims", {4, 16, 64, 64}}})}); - EXPECT(desc.generate_dst_from_common({4, 16, 1, 2, 1, 2}) == ops{make_op("squeeze", {{"axes", {2, 4}}})}); - EXPECT(desc.generate_dst_from_common({4, 1, 32, 2, 32, 2}) == ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); + EXPECT(desc.generate_dst_from_common({4, 16, 1, 2, 1, 2}) == + ops{make_op("squeeze", {{"axes", {2, 4}}})}); + EXPECT(desc.generate_dst_from_common({4, 1, 32, 2, 32, 2}) == + ops{make_op("reshape", {{"dims", {4, 1, 64, 64}}})}); } TEST_CASE(rebase_reshape_broadcast) { - auto base_desc = make_simple_descriptor({3, 4, 64, 1}, make_op("reshape", {{"dims", {12, 8, 8, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {12, 8, 8, 2, 2}}})); + auto base_desc = + make_simple_descriptor({3, 4, 64, 1}, + make_op("reshape", {{"dims", {12, 8, 8, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {12, 8, 8, 2, 2}}})); { auto desc = base_desc.rebase({3, 4, 64, 4}); EXPECT(get_final_lens(desc) == final_lens{12, 8, 8, 2, 2}); EXPECT(get_all_lens(desc) == all_lens{{3, 4}, {8}, {8}, {2}, {2}}); - EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 4, 8, 8, 2, 2}}}), make_op("reshape", {{"dims", {12, 8, 8, 2, 2}}})}); + EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 4, 8, 8, 2, 2}}}), + make_op("reshape", {{"dims", {12, 8, 8, 2, 2}}})}); } { auto desc = base_desc.rebase({3, 5, 64, 1}); EXPECT(get_final_lens(desc) == final_lens{15, 8, 8, 2, 2}); EXPECT(get_all_lens(desc) == all_lens{{3, 5}, {8}, {8}, {2}, {2}}); - EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 5, 8, 8, 1, 1}}}), make_op("reshape", {{"dims", {15, 8, 8, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {15, 8, 8, 2, 2}}})}); + EXPECT(desc.generate() == ops{make_op("reshape", {{"dims", {3, 5, 8, 8, 1, 1}}}), + make_op("reshape", {{"dims", {15, 8, 8, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {15, 8, 8, 2, 2}}})}); } { auto desc = base_desc.rebase({3, 4, 1, 1}); EXPECT(get_final_lens(desc) == final_lens{12, 1, 1, 2, 2}); EXPECT(get_all_lens(desc) == all_lens{{3, 4}, {1}, {1}, {2}, {2}}); - EXPECT(desc.generate() == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), make_op("reshape", {{"dims", {12, 1, 1, 1, 1}}}), make_op("multibroadcast", {{"out_lens", {12, 1, 1, 2, 2}}})}); + EXPECT(desc.generate() == ops{make_op("unsqueeze", {{"axes", {3, 5}}}), + make_op("reshape", {{"dims", {12, 1, 1, 1, 1}}}), + make_op("multibroadcast", {{"out_lens", {12, 1, 1, 2, 2}}})}); } } From 935e3eb1ec5e9c2b86c85f2c1d256ce677795507 Mon Sep 17 00:00:00 2001 From: Paul Date: Wed, 11 Dec 2024 16:33:52 -0600 Subject: [PATCH 36/37] Cleanup --- src/shape_transform_descriptor.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 89d88b50fec..987cc891573 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -873,13 +873,10 @@ struct operation_list void push_back(const operation& op) { ops.push_back(op); } - std::vector to_vector() const& { return {ops.crbegin(), ops.crend()}; } - std::vector to_vector() && { std::reverse(ops.begin(), ops.end()); return std::move(ops); - ; } }; From a502b67a240176d0f17eed2f75f9cfc11da0af89 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 12 Dec 2024 14:57:32 -0600 Subject: [PATCH 37/37] Fix tidy issues --- src/include/migraphx/rewrite_reshapes.hpp | 2 +- tools/cppcheck/migraphx.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index 30496a7ddfa..c064091c279 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -167,7 +167,7 @@ struct rewrite_reshapes return [&, generate](auto input) { auto gops = std::invoke(generate, desc, input->get_shape().lens()); auto start = input; - for(auto op : gops) + for(const auto& op : gops) { start = mpm.get_module().insert_instruction(ins_to_insert, op, start); } diff --git a/tools/cppcheck/migraphx.py b/tools/cppcheck/migraphx.py index 3be73d5b384..787279044a1 100644 --- a/tools/cppcheck/migraphx.py +++ b/tools/cppcheck/migraphx.py @@ -436,6 +436,8 @@ def MatcherNestedParentheses(cfg, data): for tok2 in token.tokAt(4).forward(token.linkAt(4)): if not simpleMatch(tok2, ") ) ) )"): continue + if simpleMatch(tok2.link.previous, "bind"): + continue cppcheck.reportError( tok2, "style", "Too many nested parentheses can affect readability; consider using variables instead."