Skip to content

Commit

Permalink
Fix index_put with boolean index (#2018)
Browse files Browse the repository at this point in the history
Related: #1749

---------

Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
xadupre and justinchuby authored Jan 24, 2025
1 parent e673351 commit 6d2b530
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 39 deletions.
51 changes: 16 additions & 35 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4261,7 +4261,7 @@ def aten_index_copy(
raise NotImplementedError()


@torch_op(("aten::index_put", "aten::_unsafe_index_put"))
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
def aten_index_put(
self: TReal,
indices: Sequence[INT64],
Expand All @@ -4275,18 +4275,18 @@ def aten_index_put(
"""

# TODO(justinchuby): Handle when indicies has more than one element
index = op.SequenceAt(indices, 0)
index = indices[0]
new_index = op.Unsqueeze(index, [-1])

if op.Cast(accumulate, to=BOOL.dtype):
if accumulate:
result = op.ScatterND(self, new_index, values, reduction="add")
else:
result = op.ScatterND(self, new_index, values)

return result


@torch_op("aten::index_put")
@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
Expand All @@ -4295,37 +4295,18 @@ def aten_index_put_bool(
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
index_int = op.Cast(index, to=INT32.dtype)
# if all False, return op.Identity(self)
if op.ReduceSum(index_int) == 0:
result = self
else:
# change array([F,F,T,F,F]) to array([2])
index = op.ArgMax(index_int) # assume index only have 1 True
# change array([2]) to array([2,2,2,2,2])
self_dim_1 = op.Shape(self, start=1, end=2)
index_dim_0 = op.Shape(index, start=0, end=1)
shape = op.Concat(self_dim_1, index_dim_0, axis=0)
new_ind = op.Expand(index, shape)
new_ind_t = op.Transpose(new_ind)

# values must have same rank with input(self)
if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator]
values = op.Unsqueeze(values, op.Constant(value_ints=[0]))

if op.Cast(accumulate, to=BOOL.dtype):
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
zeros = op.CastLike(zeros, values)
result = op.ScatterElements(zeros, new_ind_t, values)
# FIXME: type promotion
result = op.CastLike(result, self)
result = op.Add(result, self)
else:
result = op.ScatterElements(self, new_ind_t, values)

return result
# TODO: Support indices with more than 1 elements
index = indices[0]
# accumulate should be always False, True does not make sense but an assert would be great
# Reshape indices so it can be properly broadcasted
self_rank = len(self.shape)
index_rank = len(index.shape)
if self_rank > index_rank:
index_shape = op.Shape(index)
padding = op.Constant(value_ints=[1 for _ in range(self_rank - index_rank)])
padded_shape = op.Concat(index_shape, padding, axis=0)
index = op.Reshape(index, padded_shape)
return op.Where(index, values, self)


def aten_index_reduce(
Expand Down
6 changes: 2 additions & 4 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,12 +852,10 @@ def _where_input_wrangler(
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
)
.skip(
).skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
reason="this Aten overload only supports tensor(bool) as indices",
)
.skip(reason="FIXME: https://github.com/microsoft/onnxscript/issues/1749"),
),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
Expand Down

0 comments on commit 6d2b530

Please sign in to comment.