From 165bd1d0de80b10929e6c0a4c3ee671470d61add Mon Sep 17 00:00:00 2001 From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com> Date: Mon, 27 Jan 2025 06:17:40 -0800 Subject: [PATCH] Resize onnx operator: Optimization for Compute and Space performance of its linear option. (#3773) Optimize the space overhead required for Linear Resize operation: it is now 4x smaller for its 2D images. There were very large data-structures, getting to be over 16 times the total input_pixels for a 4D tensor. And now it becomes 4x smaller in size, followed with fewer reduction steps. --- src/onnx/parse_resize.cpp | 179 +++++++++--------- test/onnx/gen_onnx.py | 14 ++ test/onnx/include/onnx_test_utils.hpp | 113 ++++++----- .../parse/resize_downsample_linear_test.cpp | 63 +++--- .../parse/resize_upsample_linear_test.cpp | 93 +-------- test/onnx/parse/upsample_linear_test.cpp | 4 +- .../resize_upsample_linear_large_test.onnx | Bin 0 -> 218 bytes 7 files changed, 193 insertions(+), 273 deletions(-) create mode 100644 test/onnx/resize_upsample_linear_large_test.onnx diff --git a/src/onnx/parse_resize.cpp b/src/onnx/parse_resize.cpp index e4c86ec40b0..0c6216ae955 100644 --- a/src/onnx/parse_resize.cpp +++ b/src/onnx/parse_resize.cpp @@ -21,6 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ + #include #include #include @@ -28,7 +29,8 @@ #include #include #include -#include +#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -36,87 +38,65 @@ namespace onnx { /* * Algorithm of calc_neighbor_points(): - * Input: vvv_ind, 3-layer vector to compose vector of indices. - * in_s, shape to get space index from, using the composed vector of indices. - * Output: vector contains the result of space index. - * - * From vvv_ind: - * layer-1: size of 1st dimension, caller will pass as n_bits - * layer-2: hardcode to 2 by caller - * layer-3: a vector of out_elements (caller pass) integers. 
- * vvv_ind = { - * {{...}, {...}}, - * {{...}, {...}}, - * {{...}, {...}}, - * ... - * {{...}, {...}} - * }; - * - * To Compose a series of vector of indices, which will further be used to get space index from - * the input shape. - * indices{} has (2^n_bits) * out_elements members, each member is a vector of n_bits indices. - * indices = { - * {...}, - * {...}, - * {...}, - * ... - * {...} - * }; - * - * Notate vvv_ind as: - * 0-1 - * A B - * C D - * E F - * G H - * Notate A' as A's transpose. - * i.e. A = {0,1,1,0,1}; - * A' = {{0}, - * {1}, - * {1}, - * {0}, - * {1} - * }; + * Input: vvv_ind, a collection of neighbors per resized dimension as: + * layer-1: (# resized dimensions, vector) + * layer-2: (A vector of 2 of: hi/low) + * layer-3: Neighbor index of every pixel in that output dimension (vector) + * in_s, the original input tensor shape (vector) + * out_s, the output tensor shape (vector) + * resized_m, lens indices that have to be resized (map) * - * Outer loop: - * Iterate all values within range [0, (2^n_bits)) and maps to bitset for inner loop (MSB to LSB). - * Middle loop: - * Transform all elements in layer-3: take indices from inner loop to get index from input shape, - append to vec_ind. - * Inner loop: - * Compose a vector of indices by iterating all layer-1 using current bitset from current element. - * - * i.e. val = 6 -> bitset 0110b -> indices: pick each value from A'D'F'G' -> in_s.index(indices) + * Output: per resized pixel, its neighboring hi/lo indexes (vector): all permutations. + * This API stitches all the neighbors (for every dimension) for a resized pixel, + * to yield its neighbor index w.r.t. the input shape, in_s. 
*/ static std::vector calc_neighbor_points(const std::vector>>& vvv_ind, - const shape& in_s) + const shape& in_s, + const shape& out_s, + const std::map& resized_m) { - std::size_t n_bits = vvv_ind.size(); - std::size_t m_elements = vvv_ind[0][0].size(); - std::vector vec_ind; + std::size_t ndims = out_s.ndim(); + const auto& strides = out_s.strides(); + std::size_t elements_ct = vvv_ind[0][0].size(); - if(n_bits >= std::numeric_limits::digits) - { - MIGRAPHX_THROW("PARSE_RESIZE: Shape dimension " + std::to_string(n_bits) + " exceeds " + - std::to_string(std::numeric_limits::digits)); - } + // This function computes for each element, all permutations of its neighbor indices into an + // Perm block in one go. (Instead of computing each permutation in isolation per element) + size_t permutations = 1u << resized_m.size(); + std::vector> perm_blk(permutations, std::vector(strides)); - for(std::size_t val = 0; val < (std::size_t{1} << n_bits); val++) + // final outputted vector: permutations of neighbors. + std::vector out_idx_vec(permutations * elements_ct); + + for(size_t e_idx = 0; e_idx < elements_ct; ++e_idx) { - std::bitset::digits> bits_val = val; - std::vector indices(n_bits); - transform(range(m_elements), std::back_inserter(vec_ind), [&](std::size_t i_element) { - transform( - vvv_ind, range(n_bits), indices.begin(), [&](const auto& vv_ind, std::size_t bit) { - return vv_ind[bits_val[bit]][i_element]; - }); - return in_s.index(indices); - }); + size_t t_idx = e_idx; + for(size_t l_idx = 0; l_idx != ndims; ++l_idx) + { + auto entry = resized_m.find(l_idx); + if(entry != resized_m.end()) + { + size_t hi_cmp_bit = 1u << entry->second; + auto lo = vvv_ind[entry->second][0][e_idx]; + auto hi = vvv_ind[entry->second][1][e_idx]; + for(size_t i = 0; i < permutations; i++) + perm_blk[i][l_idx] = ((i & hi_cmp_bit) != 0) ? 
hi : lo; + } + else + { + size_t idx = t_idx / strides[l_idx]; + // no permutations in an unmodified lens index, so idx is copied over: + for(size_t i = 0; i < permutations; i++) + perm_blk[i][l_idx] = idx; + } + t_idx %= strides[l_idx]; + } + // write out the permuted indices, calculated off the perm_blk: + for(size_t i = 0; i < permutations; i++) + out_idx_vec[e_idx + elements_ct * i] = in_s.index(perm_blk[i]); } - - return vec_ind; + return out_idx_vec; } static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr) @@ -391,7 +371,6 @@ struct parse_resize : op_parser ": linear mode not supported for non-constant inputs"); shape out_s{in_s.type(), out_lens}; - std::size_t out_elements = out_s.elements(); // reshape input to one-dimension std::vector rsp_lens = {static_cast(in_s.elements())}; @@ -400,41 +379,55 @@ struct parse_resize : op_parser auto nearest_floor = op::resize::get_nearest_op("floor"); auto nearest_ceil = op::resize::get_nearest_op("ceil"); - // get the number of dimensions - std::size_t n_dim = out_lens.size(); + std::vector resized_axes; // vector of dimensions to be resized + std::size_t out_elements = 1; // total number of elements to be resized + size_t resized_ct = 0; + std::map resized_m; // modified indices --> vvv_ind index below + for(std::size_t axis = 0; axis != out_lens.size(); ++axis) + { + out_elements *= out_lens[axis]; + if(in_lens[axis] == out_lens[axis]) + continue; + resized_axes.push_back(axis); + resized_m[axis] = resized_ct++; + } + + // Neighbor indices. For an axis. Two sets of max/min per element: std::vector> vv_ind(2, std::vector(out_elements)); - std::vector>> vvv_ind(n_dim, vv_ind); - std::vector> delta(n_dim, std::vector(out_elements)); + // Neighbor indices. For all resized axes: + std::vector>> vvv_ind(resized_ct, vv_ind); + // Delta list. For each resized axes - per element. 
+ std::vector> delta(resized_ct, std::vector(out_elements)); - shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) { - for(auto ii = 0; ii < in_lens.size(); ++ii) + shape_for_each(out_s, [&](const auto& out_idx_v, std::size_t out_idx) { + for(size_t ii = 0; ii != resized_ct; ++ii) { - auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]); - vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val); - vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val); + auto idx = resized_axes[ii]; + auto idx_val = + idx_op(in_lens[idx], out_lens[idx], out_idx_v[idx], vec_scale[idx]); + vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[idx], idx_val); + vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[idx], idx_val); delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx]; } }); - auto ind = calc_neighbor_points(vvv_ind, in_s); + auto ind = calc_neighbor_points(vvv_ind, in_s, out_s, resized_m); - auto ind_lens = out_lens; - ind_lens[0] *= (std::size_t{1} << n_dim); - shape ind_s{shape::int32_type, ind_lens}; + auto dim_lens = out_lens; + // indices matrix size grows 2x per resized-axis: + dim_lens[0] *= (1u << resized_ct); + shape ind_s{shape::int32_type, dim_lens}; auto ins_ind = info.add_literal(literal(ind_s, ind)); auto data = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind); - auto dim_lens = out_lens; - dim_lens[0] *= (std::size_t{1} << (n_dim - 1)); - for(std::size_t i = 0; i < n_dim; ++i) + for(auto idx = resized_ct; idx != 0u; --idx) { + dim_lens[0] /= 2; // halved for 2 slices of data (hi & low below) shape dim_s{shape::float_type, dim_lens}; - const auto& dim_delta = delta[n_dim - i - 1]; + const auto& dim_delta = delta[idx - 1]; std::vector delta_data; for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j) - { delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end()); - } auto ins_delta = info.add_literal(dim_s, delta_data); // slice the data @@ -449,9 +442,7 @@ struct parse_resize : 
op_parser auto diff = info.add_instruction(make_op("sub"), hi, low); auto ddf = info.add_instruction(make_op("mul"), diff, ins_delta); data = info.add_instruction(make_op("add"), ddf, low); - dim_lens[0] /= 2; } - return data; } } diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 880aac4fed0..cfdd7af7711 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -11457,6 +11457,20 @@ def resize_upsample_linear_test(): return ([node], [X], [Y], [scales_tensor]) +@onnx_test() +def resize_upsample_linear_large_test(): + x = helper.make_tensor_value_info('X', TensorProto.FLOAT, + [1, 1, 1024, 1024]) + s = helper.make_tensor('scales', TensorProto.FLOAT, [4], [1, 1, 2, 2]) + y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, + [1, 1, 2048, 2048]) + node = onnx.helper.make_node('Resize', + inputs=['X', '', 'scales'], + outputs=['Y'], + mode='linear') + return ([node], [x], [y], [s]) + + @onnx_test() def resize_upsample_pf_test(): scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32) diff --git a/test/onnx/include/onnx_test_utils.hpp b/test/onnx/include/onnx_test_utils.hpp index cffd6a44519..e997e615cc5 100644 --- a/test/onnx/include/onnx_test_utils.hpp +++ b/test/onnx/include/onnx_test_utils.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -325,6 +325,29 @@ inline migraphx::program make_quantizelinear_axis_prog() return p; } +/* Parsed IR equivalent of create_upsample_linear_prog() +module: "main" +@0 = @literal{ ... } -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1} +@1 = @literal{ ... } -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1} +@2 = @literal{ ... 
} -> int32_type, {4, 1, 4, 4}, {16, 16, 4, 1} +X = @param:X -> float_type, {1, 1, 2, 2}, {4, 4, 2, 1} +@4 = @literal{1, 1, 2, 2} -> float_type, {4}, {1} +@5 = undefined -> float_type, {}, {} +@6 = reshape[dims={4}](X) -> float_type, {4}, {1} +@7 = gather[axis=0](@6,@2) -> float_type, {4, 1, 4, 4}, {16, 16, 4, 1} +@8 = slice[axes={0},starts={0},ends={2}](@7) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1} +@9 = slice[axes={0},starts={2},ends={4}](@7) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1} +@10 = sub(@9,@8) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1} +@11 = mul(@10,@1) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1} +@12 = add(@11,@8) -> float_type, {2, 1, 4, 4}, {16, 16, 4, 1} +@13 = slice[axes={0},starts={0},ends={1}](@12) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1} +@14 = slice[axes={0},starts={1},ends={2}](@12) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1} +@15 = sub(@14,@13) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1} +@16 = mul(@15,@0) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1} +@17 = add(@16,@13) -> float_type, {1, 1, 4, 4}, {16, 16, 4, 1} +@18 = @return(@17) +*/ + inline auto create_upsample_linear_prog() { migraphx::program p; @@ -335,75 +358,51 @@ inline auto create_upsample_linear_prog() migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; auto x = mm->add_parameter("X", sx); - migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}}; - std::vector d_ind = { - 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, - 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, - 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, - 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, - 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, - 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, - 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 
1, 1, 1, 2, 3, 3, 3, 2, 3, 3, - 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, - 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3}; - auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); + migraphx::shape s_ind{migraphx::shape::int32_type, {4, 1, 4, 4}}; - migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}}; - std::vector d8 = { - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0}; - auto l8 = mm->add_literal(migraphx::literal(s8, d8)); - - migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}}; - std::vector d4 = { - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0}; - auto l4 = mm->add_literal(migraphx::literal(s4, d4)); + std::vector d_ind = {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, + 2, 3, 
2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, + 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3}; + + auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}}; - std::vector d2(32, 0); + + std::vector d2 = {-0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25, + -0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25, + -0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25, + -0.25, 0.25, 0.75, 0.25, -0.25, 0.25, 0.75, 0.25}; + auto l2 = mm->add_literal(migraphx::literal(s2, d2)); migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}}; - std::vector d1(16, 0.0f); + + std::vector d1 = {-0.25, + -0.25, + -0.25, + -0.25, + 0.25, + 0.25, + 0.25, + 0.25, + 0.75, + 0.75, + 0.75, + 0.75, + 0.25, + 0.25, + 0.25, + 0.25}; + auto l1 = mm->add_literal(migraphx::literal(s1, d1)); mm->add_instruction(migraphx::make_op("undefined")); auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); - auto slc80 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); - auto slc81 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data); - auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80); - auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8); - auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80); - auto slc40 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8); - auto slc41 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8); - auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40); - auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4); - auto add4 = 
mm->add_instruction(migraphx::make_op("add"), mul4, slc40); auto slc20 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4); + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), data); auto slc21 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4); + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), data); auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20); auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2); auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20); diff --git a/test/onnx/parse/resize_downsample_linear_test.cpp b/test/onnx/parse/resize_downsample_linear_test.cpp index 8275059d679..210e20310aa 100644 --- a/test/onnx/parse/resize_downsample_linear_test.cpp +++ b/test/onnx/parse/resize_downsample_linear_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,6 +24,29 @@ #include +/* IR for the test case below: +module: "main" +@0 = @literal{0.333333, 0.333333} -> float_type, {1, 1, 1, 2}, {2, 2, 2, 1} +@1 = @literal{0.5, 0.5, 0.5, 0.5} -> float_type, {2, 1, 1, 2}, {2, 2, 2, 1} +@2 = @literal{0, 2, 4, 6, 1, 3, 5, 7} -> int32_type, {4, 1, 1, 2}, {2, 2, 2, 1} +X = @param:X -> float_type, {1, 1, 2, 4}, {8, 8, 4, 1} +@4 = @literal{1, 1, 0.6, 0.5} -> float_type, {4}, {1} +@5 = undefined -> float_type, {}, {} +@6 = reshape[dims={8}](X) -> float_type, {8}, {1} +@7 = gather[axis=0](@6,@2) -> float_type, {4, 1, 1, 2}, {2, 2, 2, 1} +@8 = slice[axes={0},starts={0},ends={2}](@7) -> float_type, {2, 1, 1, 2}, {2, 2, 2, 1} +@9 = slice[axes={0},starts={2},ends={4}](@7) -> float_type, {2, 1, 1, 2}, {2, 2, 2, 1} +@10 = sub(@9,@8) -> float_type, {2, 1, 1, 2}, {2, 2, 2, 1} +@11 = mul(@10,@1) -> float_type, {2, 1, 1, 2}, {2, 2, 2, 1} +@12 = add(@11,@8) -> float_type, {2, 1, 1, 2}, {2, 2, 2, 1} +@13 = slice[axes={0},starts={0},ends={1}](@12) -> float_type, {1, 1, 1, 2}, {2, 2, 2, 1} +@14 = slice[axes={0},starts={1},ends={2}](@12) -> float_type, {1, 1, 1, 2}, {2, 2, 2, 1} +@15 = sub(@14,@13) -> float_type, {1, 1, 1, 2}, {2, 2, 2, 1} +@16 = mul(@15,@0) -> float_type, {1, 1, 1, 2}, {2, 2, 2, 1} +@17 = add(@16,@13) -> float_type, {1, 1, 1, 2}, {2, 2, 2, 1} +@18 = @return(@17) +*/ + TEST_CASE(resize_downsample_linear_test) { migraphx::program p; @@ -34,51 +57,31 @@ TEST_CASE(resize_downsample_linear_test) migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 4}}; auto x = mm->add_parameter("X", sx); - migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 1, 2}}; - std::vector d_ind = {0, 2, 0, 2, 0, 2, 0, 2, 4, 6, 4, 6, 4, 6, 4, 6, - 1, 3, 1, 3, 1, 3, 1, 3, 5, 7, 5, 7, 5, 7, 5, 7}; - auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); - - migraphx::shape 
s8{migraphx::shape::float_type, {8, 1, 1, 2}}; - std::vector d8(16, 0.5f); - auto l8 = mm->add_literal(migraphx::literal(s8, d8)); - migraphx::shape s4{migraphx::shape::float_type, {4, 1, 1, 2}}; - std::vector d4(8, 1.0f / 3.0f); - auto l4 = mm->add_literal(migraphx::literal(s4, d4)); + migraphx::shape s_ind{migraphx::shape::int32_type, {4, 1, 1, 2}}; + std::vector d_ind = {0, 2, 4, 6, 1, 3, 5, 7}; + auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); migraphx::shape s2{migraphx::shape::float_type, {2, 1, 1, 2}}; - std::vector d2(4, 0); + std::vector d2(4, 0.5f); auto l2 = mm->add_literal(migraphx::literal(s2, d2)); migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 2}}; - std::vector d1(2, 0.0f); + std::vector d1(2, 1.0f / 3.0f); auto l1 = mm->add_literal(migraphx::literal(s1, d1)); mm->add_instruction(migraphx::make_op("undefined")); + auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x); auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); - auto slc80 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); - auto slc81 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data); - auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80); - auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8); - auto add8 = mm->add_instruction(migraphx::make_op("add"), mul8, slc80); - auto slc40 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8); - auto slc41 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8); - auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40); - auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4); - auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40); auto slc20 = 
mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4); + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), data); auto slc21 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4); + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), data); auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20); auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2); auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20); + auto slc10 = mm->add_instruction( migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2); auto slc11 = mm->add_instruction( diff --git a/test/onnx/parse/resize_upsample_linear_test.cpp b/test/onnx/parse/resize_upsample_linear_test.cpp index b8711bb2136..18a612c9ba4 100644 --- a/test/onnx/parse/resize_upsample_linear_test.cpp +++ b/test/onnx/parse/resize_upsample_linear_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -23,98 +23,11 @@ */ #include +#include TEST_CASE(resize_upsample_linear_test) { - migraphx::program p; - auto* mm = p.get_main_module(); - migraphx::shape ss{migraphx::shape::float_type, {4}}; - std::vector ds = {1, 1, 2, 2}; - mm->add_literal(migraphx::literal(ss, ds)); - - migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}}; - auto x = mm->add_parameter("X", sx); - migraphx::shape s_ind{migraphx::shape::int32_type, {16, 1, 4, 4}}; - std::vector d_ind = { - 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, - 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, - 0, 1, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, - 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 0, 0, 1, - 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, - 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, - 3, 3, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, - 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, - 2, 3, 3, 3, 2, 3, 3, 3, 0, 1, 1, 1, 2, 3, 3, 3, 2, 3, 3, 3, 2, 3, 3, 3}; - auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind)); - - migraphx::shape s8{migraphx::shape::float_type, {8, 1, 4, 4}}; - std::vector d8 = { - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 
0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0, - 0, 1.0f / 3, 2.0f / 3, 0, 0, 1.0f / 3, 2.0f / 3, 0}; - auto l8 = mm->add_literal(migraphx::literal(s8, d8)); - - migraphx::shape s4{migraphx::shape::float_type, {4, 1, 4, 4}}; - std::vector d4 = { - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0, - 0, 0, 0, 0, 1.0f / 3, 1.0f / 3, 1.0f / 3, 1.0f / 3, - 2.0f / 3, 2.0f / 3, 2.0f / 3, 2.0f / 3, 0, 0, 0, 0}; - auto l4 = mm->add_literal(migraphx::literal(s4, d4)); - - migraphx::shape s2{migraphx::shape::float_type, {2, 1, 4, 4}}; - std::vector d2(32, 0); - auto l2 = mm->add_literal(migraphx::literal(s2, d2)); - - migraphx::shape s1{migraphx::shape::float_type, {1, 1, 4, 4}}; - std::vector d1(16, 0.0f); - auto l1 = mm->add_literal(migraphx::literal(s1, d1)); - - mm->add_instruction(migraphx::make_op("undefined")); - auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), x); - auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind); - auto slc80 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {8}}}), data); - auto slc81 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {8}}, {"ends", {16}}}), data); - auto diff8 = mm->add_instruction(migraphx::make_op("sub"), slc81, slc80); - auto mul8 = mm->add_instruction(migraphx::make_op("mul"), diff8, l8); - auto add8 
= mm->add_instruction(migraphx::make_op("add"), mul8, slc80); - auto slc40 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {4}}}), add8); - auto slc41 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {4}}, {"ends", {8}}}), add8); - auto diff4 = mm->add_instruction(migraphx::make_op("sub"), slc41, slc40); - auto mul4 = mm->add_instruction(migraphx::make_op("mul"), diff4, l4); - auto add4 = mm->add_instruction(migraphx::make_op("add"), mul4, slc40); - auto slc20 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), add4); - auto slc21 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), add4); - auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20); - auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2); - auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20); - auto slc10 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2); - auto slc11 = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2); - auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10); - auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1); - auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10); - mm->add_return({add1}); - + auto p = create_upsample_linear_prog(); // same net IR as upsample version auto prog = read_onnx("resize_upsample_linear_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/parse/upsample_linear_test.cpp b/test/onnx/parse/upsample_linear_test.cpp index cae67a0395e..365009cfb30 100644 --- a/test/onnx/parse/upsample_linear_test.cpp +++ b/test/onnx/parse/upsample_linear_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,7 @@ TEST_CASE(upsample_linear_test) { - auto p = create_upsample_linear_prog(); + auto p = create_upsample_linear_prog(); // same net IR for resize & upsample auto prog = read_onnx("upsample_linear_test.onnx"); EXPECT(p == prog); } diff --git a/test/onnx/resize_upsample_linear_large_test.onnx b/test/onnx/resize_upsample_linear_large_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d89044e31a59929912000cd00b00d9bb01bda835 GIT binary patch literal 218 zcmdhNmEzYb;jV~=IPRuRHNsZ6R%u7uyiqA7b1Ni8n1TF1zx&BYkO#lXc@ zoSc}GS}epEsl*lp)~6-N#gdz!lB&c8(YJt+nFu4b