[torch] Fix _operator::{truediv/floordiv} (#2029)
Create separate implementations for `_operator::{truediv/floordiv}` to
handle SymInts. Also drop the `_operator::{add/mul/sub}` aliases from the
complex, boolean, and bitwise overloads, which do not apply to SymInts.

---------

Co-authored-by: Justin Chu <[email protected]>
xadupre and justinchuby authored Jan 21, 2025
1 parent 969c078 commit 23093b0
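
For context (a minimal sketch of the Python semantics involved, not part of the commit): `torch.SymInt` arithmetic follows Python's operator semantics, in which true division of two integers always produces a float while floor division stays integral. A single registration shared with the `aten` overloads therefore cannot represent both behaviors, which is what the two new `_operator` implementations below address.

```python
import operator

# Python (and SymInt) semantics the new ops must mirror:
print(operator.truediv(7, 2))   # 3.5 -- always a float, even for int inputs
print(operator.floordiv(7, 2))  # 3   -- stays an int
```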
Showing 2 changed files with 11 additions and 15 deletions.
25 changes: 11 additions & 14 deletions onnxscript/function_libs/torch_lib/ops/core.py
```diff
@@ -169,9 +169,7 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     return op.Add(self, other)
 
 
-@torch_op(
-    ("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True, complex=True
-)
+@torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
 def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
@@ -2749,7 +2747,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
         "aten::divide.Scalar",
         "aten::true_divide.Tensor",
         "aten::true_divide.Scalar",
-        "_operator::truediv",
     )
 )
 def aten_div(self: TFloat, other: TFloat) -> TFloat:
@@ -2759,6 +2756,11 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
     return op.Div(self, other)
 
 
+@torch_op("_operator::truediv", traceable=True)
+def operator_truediv(self: TensorType, other: TensorType) -> FLOAT:
+    return op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))
+
+
 @torch_op(
     (
         "aten::div.Tensor",
```
```diff
@@ -2767,7 +2769,6 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
         "aten::divide.Scalar",
         "aten::true_divide.Tensor",
         "aten::true_divide.Scalar",
-        "_operator::truediv",
     ),
     complex=True,
 )
@@ -3597,17 +3598,15 @@ def python_math_floor(self: TFloat) -> TInt:
     return op.Cast(floor, to=INT64.dtype)
 
 
-@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
+@torch_op("aten::floor_divide", traceable=True)
 def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
     return op.Floor(op.Div(self, other))
 
 
-@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
-def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
-    """floor_divide(Tensor self, Tensor other) -> Tensor"""
-
+@torch_op("_operator::floordiv", traceable=True)
+def operator_floordiv(self: INT64, other: INT64) -> INT64:
     # We implement floor_divide only for positive inputs (using integer division)
     # because that is the usual intended case and is the most efficient.
     return op.Div(self, other)
```
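
Note the caveat carried over in `operator_floordiv`: a plain integer `Div` matches Python's `//` only when the quotient is non-negative, since Python floors toward negative infinity while integer division in ONNX truncates toward zero (assuming the usual C-style semantics, which is what the in-code comment guards against). A small illustration, not part of the commit:

```python
import operator

print(operator.floordiv(7, 2))   # 3  -- agrees with truncating division
print(operator.floordiv(-7, 2))  # -4 -- Python floors; truncation would give -3
```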
```diff
@@ -4940,7 +4939,6 @@ def aten_logical_not(self: BOOL) -> BOOL:
         "aten::bitwise_or.Scalar_Tensor",
         "aten::add.Tensor",
         "aten::add.Scalar",
-        "_operator::add",
     ),
     traceable=True,
 )
@@ -5658,7 +5656,7 @@ def aten_mul(self: TReal, other: TReal) -> TReal:
 
 
 @torch_op(
-    ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
+    ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
     traceable=True,
 )
 def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
@@ -5671,7 +5669,7 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
 
 
 @torch_op(
-    ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
+    ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
     traceable=True,
     complex=True,
 )
@@ -8044,7 +8042,6 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
         "aten::sub.Scalar",
         "aten::subtract.Tensor",
         "aten::subtract.Scalar",
-        "_operator::sub",
     ),
     trace_only=True,
     complex=True,
```
1 change: 0 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
```diff
@@ -829,7 +829,6 @@ def _where_input_wrangler(
         test_class_name="TestOutputConsistencyEager",
         reason="fixme: off-by-one issue due to numerical precision. https://github.com/microsoft/onnxscript/issues/989",
     ),
-    TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int),
     TorchLibOpInfo("fmod", core_ops.aten_fmod),
     TorchLibOpInfo("frac", core_ops.aten_frac),
     TorchLibOpInfo("full", core_ops.aten_full),
```
