Skip to content

Commit

Permalink
Resize onnx operator: Optimization for Compute and Space performance …
Browse files Browse the repository at this point in the history
…of its linear option. (#3773)

Optimize the space overhead required for the linear Resize operation: for 2D images it is now 4x smaller. Previously the intermediate data structures were very large, growing to over 16 times the total number of input pixels for a 4D tensor. They are now 4x smaller in size, and fewer reduction steps are needed.
  • Loading branch information
lakhinderwalia authored Jan 27, 2025
1 parent 2503040 commit 165bd1d
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 273 deletions.
179 changes: 85 additions & 94 deletions src/onnx/parse_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,102 +21,82 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/op/resize.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <bitset>
#include <vector>
#include <map>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/*
* Algorithm of calc_neighbor_points():
* Input: vvv_ind, 3-layer vector to compose vector of indices.
* in_s, shape to get space index from, using the composed vector of indices.
* Output: vector contains the result of space index.
*
* From vvv_ind:
* layer-1: size of 1st dimension, caller will pass as n_bits
* layer-2: hardcode to 2 by caller
* layer-3: a vector of out_elements (caller pass) integers.
* vvv_ind = {
* {{...}, {...}},
* {{...}, {...}},
* {{...}, {...}},
* ...
* {{...}, {...}}
* };
*
* To Compose a series of vector of indices, which will further be used to get space index from
* the input shape.
* indices{} has (2^n_bits) * out_elements members, each member is a vector of n_bits indices.
* indices = {
* {...},
* {...},
* {...},
* ...
* {...}
* };
*
* Notate vvv_ind as:
* 0-1
* A B
* C D
* E F
* G H
* Notate A' as A's transpose.
* i.e. A = {0,1,1,0,1};
* A' = {{0},
* {1},
* {1},
* {0},
* {1}
* };
* Input: vvv_ind, a collection of neighbors per resized dimension as:
* layer-1: (# resized dimensions, vector)
* layer-2: (A vector of 2 of: hi/low)
* layer-3: Neighbor index of every pixel in that output dimension (vector)
* in_s, the original input tensor shape (vector)
* out_s, the output tensor shape (vector)
* resized_m, lens indices that have to be resized (map)
*
* Outer loop:
* Iterate all values within range [0, (2^n_bits)) and maps to bitset for inner loop (MSB to LSB).
* Middle loop:
* Transform all elements in layer-3: take indices from inner loop to get index from input shape,
append to vec_ind.
* Inner loop:
* Compose a vector of indices by iterating all layer-1 using current bitset from current element.
*
* i.e. val = 6 -> bitset 0110b -> indices: pick each value from A'D'F'G' -> in_s.index(indices)
* Output: per resized pixel, its neighboring hi/lo indexes (vector): all permutations.
* This api stitches all the neighbors (for every dimension) for a resized pixel,
* to yield its neighbor index w.r.t. the input shape, in_s.
*/

static std::vector<int>
calc_neighbor_points(const std::vector<std::vector<std::vector<std::size_t>>>& vvv_ind,
const shape& in_s)
const shape& in_s,
const shape& out_s,
const std::map<size_t, size_t>& resized_m)
{
std::size_t n_bits = vvv_ind.size();
std::size_t m_elements = vvv_ind[0][0].size();
std::vector<int> vec_ind;
std::size_t ndims = out_s.ndim();
const auto& strides = out_s.strides();
std::size_t elements_ct = vvv_ind[0][0].size();

if(n_bits >= std::numeric_limits<std::size_t>::digits)
{
MIGRAPHX_THROW("PARSE_RESIZE: Shape dimension " + std::to_string(n_bits) + " exceeds " +
std::to_string(std::numeric_limits<std::size_t>::digits));
}
// This function computes, for each element, all permutations of its neighbor indices into a
// perm block in one go (instead of computing each permutation in isolation per element).
size_t permutations = 1u << resized_m.size();
std::vector<std::vector<std::size_t>> perm_blk(permutations, std::vector<size_t>(strides));

for(std::size_t val = 0; val < (std::size_t{1} << n_bits); val++)
// final outputted vector: permutations of neighbors.
std::vector<int> out_idx_vec(permutations * elements_ct);

for(size_t e_idx = 0; e_idx < elements_ct; ++e_idx)
{
std::bitset<std::numeric_limits<std::size_t>::digits> bits_val = val;
std::vector<std::size_t> indices(n_bits);
transform(range(m_elements), std::back_inserter(vec_ind), [&](std::size_t i_element) {
transform(
vvv_ind, range(n_bits), indices.begin(), [&](const auto& vv_ind, std::size_t bit) {
return vv_ind[bits_val[bit]][i_element];
});
return in_s.index(indices);
});
size_t t_idx = e_idx;
for(size_t l_idx = 0; l_idx != ndims; ++l_idx)
{
auto entry = resized_m.find(l_idx);
if(entry != resized_m.end())
{
size_t hi_cmp_bit = 1u << entry->second;
auto lo = vvv_ind[entry->second][0][e_idx];
auto hi = vvv_ind[entry->second][1][e_idx];
for(size_t i = 0; i < permutations; i++)
perm_blk[i][l_idx] = ((i & hi_cmp_bit) != 0) ? hi : lo;
}
else
{
size_t idx = t_idx / strides[l_idx];
// no permutations in an unmodified lens index, so idx is copied over:
for(size_t i = 0; i < permutations; i++)
perm_blk[i][l_idx] = idx;
}
t_idx %= strides[l_idx];
}
// write out the permuted indices, calculated off the perm_blk:
for(size_t i = 0; i < permutations; i++)
out_idx_vec[e_idx + elements_ct * i] = in_s.index(perm_blk[i]);
}

return vec_ind;
return out_idx_vec;
}

static std::string get_coord_trans_mode(const onnx_parser::attribute_map& attr)
Expand Down Expand Up @@ -391,7 +371,6 @@ struct parse_resize : op_parser<parse_resize>
": linear mode not supported for non-constant inputs");

shape out_s{in_s.type(), out_lens};
std::size_t out_elements = out_s.elements();

// reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
Expand All @@ -400,41 +379,55 @@ struct parse_resize : op_parser<parse_resize>
auto nearest_floor = op::resize::get_nearest_op("floor");
auto nearest_ceil = op::resize::get_nearest_op("ceil");

// get the number of dimensions
std::size_t n_dim = out_lens.size();
std::vector<size_t> resized_axes; // vector of dimensions to be resized
std::size_t out_elements = 1; // total number of elements to be resized
size_t resized_ct = 0;
std::map<size_t, size_t> resized_m; // modified indices --> vvv_ind index below
for(std::size_t axis = 0; axis != out_lens.size(); ++axis)
{
out_elements *= out_lens[axis];
if(in_lens[axis] == out_lens[axis])
continue;
resized_axes.push_back(axis);
resized_m[axis] = resized_ct++;
}

// Neighbor indices. For an axis. Two sets of max/min per element:
std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements));
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
// Neighbor indices. For all resized axes:
std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(resized_ct, vv_ind);
// Delta list. For each resized axis - per element.
std::vector<std::vector<float>> delta(resized_ct, std::vector<float>(out_elements));

shape_for_each(out_s, [&](const auto& out_idx_v, size_t out_idx) {
for(auto ii = 0; ii < in_lens.size(); ++ii)
shape_for_each(out_s, [&](const auto& out_idx_v, std::size_t out_idx) {
for(size_t ii = 0; ii != resized_ct; ++ii)
{
auto idx_val = idx_op(in_lens[ii], out_lens[ii], out_idx_v[ii], vec_scale[ii]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[ii], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[ii], idx_val);
auto idx = resized_axes[ii];
auto idx_val =
idx_op(in_lens[idx], out_lens[idx], out_idx_v[idx], vec_scale[idx]);
vvv_ind[ii][0][out_idx] = nearest_floor(in_lens[idx], idx_val);
vvv_ind[ii][1][out_idx] = nearest_ceil(in_lens[idx], idx_val);
delta[ii][out_idx] = idx_val - vvv_ind[ii][0][out_idx];
}
});

auto ind = calc_neighbor_points(vvv_ind, in_s);
auto ind = calc_neighbor_points(vvv_ind, in_s, out_s, resized_m);

auto ind_lens = out_lens;
ind_lens[0] *= (std::size_t{1} << n_dim);
shape ind_s{shape::int32_type, ind_lens};
auto dim_lens = out_lens;
// indices matrix size grows 2x per resized-axis:
dim_lens[0] *= (1u << resized_ct);
shape ind_s{shape::int32_type, dim_lens};
auto ins_ind = info.add_literal(literal(ind_s, ind));
auto data = info.add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);

auto dim_lens = out_lens;
dim_lens[0] *= (std::size_t{1} << (n_dim - 1));
for(std::size_t i = 0; i < n_dim; ++i)
for(auto idx = resized_ct; idx != 0u; --idx)
{
dim_lens[0] /= 2; // halved for 2 slices of data (hi & low below)
shape dim_s{shape::float_type, dim_lens};
const auto& dim_delta = delta[n_dim - i - 1];
const auto& dim_delta = delta[idx - 1];
std::vector<float> delta_data;
for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j)
{
delta_data.insert(delta_data.begin(), dim_delta.begin(), dim_delta.end());
}
auto ins_delta = info.add_literal(dim_s, delta_data);

// slice the data
Expand All @@ -449,9 +442,7 @@ struct parse_resize : op_parser<parse_resize>
auto diff = info.add_instruction(make_op("sub"), hi, low);
auto ddf = info.add_instruction(make_op("mul"), diff, ins_delta);
data = info.add_instruction(make_op("add"), ddf, low);
dim_lens[0] /= 2;
}

return data;
}
}
Expand Down
14 changes: 14 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11457,6 +11457,20 @@ def resize_upsample_linear_test():
return ([node], [X], [Y], [scales_tensor])


@onnx_test()
def resize_upsample_linear_large_test():
    """Build a linear-mode Resize graph with a large input.

    The input 'X' is a float tensor of shape [1, 1, 1024, 1024]; the
    'scales' initializer is [1, 1, 2, 2], so only the last two axes are
    scaled, and the declared output 'Y' has shape [1, 1, 2048, 2048].

    Returns a tuple of (nodes, inputs, outputs, initializers).
    """
    in_dims = [1, 1, 1024, 1024]
    out_dims = [1, 1, 2048, 2048]

    input_info = helper.make_tensor_value_info('X', TensorProto.FLOAT,
                                               in_dims)
    scales_init = helper.make_tensor('scales', TensorProto.FLOAT, [4],
                                     [1, 1, 2, 2])
    output_info = helper.make_tensor_value_info('Y', TensorProto.FLOAT,
                                                out_dims)

    # The empty string leaves the optional second input (roi) unset.
    resize_node = onnx.helper.make_node('Resize',
                                        inputs=['X', '', 'scales'],
                                        outputs=['Y'],
                                        mode='linear')

    return ([resize_node], [input_info], [output_info], [scales_init])


@onnx_test()
def resize_upsample_pf_test():
scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32)
Expand Down
Loading

0 comments on commit 165bd1d

Please sign in to comment.