# AddOp(linalg_vector_norm) | feat(torchlib) (#908)
Also updated the `should_skip_xfail` logic in the tests to account for data
types.
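To illustrate the new behavior, here is a minimal sketch of the matching rule with stand-in types (`_SkipMeta` and `_entry_applies` are hypothetical names, not the actual test-harness classes): a skip/xfail entry that declares `dtypes` is now consulted only when the sample's dtype is in that list.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Sequence

import torch


@dataclass
class _SkipMeta:  # hypothetical stand-in for the real decorator metadata
    op_name: str
    dtypes: Optional[Sequence[torch.dtype]]
    matcher: Callable[[Any], bool]


def _entry_applies(meta: _SkipMeta, op_name: str, sample: Any, dtype: torch.dtype) -> bool:
    # An entry restricted to specific dtypes no longer hides samples of other dtypes.
    if meta.op_name != op_name:
        return False
    if meta.dtypes is not None and dtype not in meta.dtypes:
        return False
    return meta.matcher(sample)
```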

## Notes

One kind of test case has to be skipped: `ord=6` with `dtype=float16`.
In PyTorch:
```
>>> b
tensor([[2.3730, 0.9316, 0.6240, 6.1523, 0.1758],
        [5.3984, 7.9375, 4.9062, 1.8809, 6.1016],
        [4.0234, 8.2344, 3.2695, 0.8701, 1.3447]], dtype=torch.float16)
>>> la.vector_norm(a, dim=0, ord=6)
tensor([5.5483, 9.0838, 4.9754, 6.1532, 6.1017])
>>> la.vector_norm(b, dim=0, ord=6)
tensor([5.5469,    inf, 4.9727, 6.1523, 6.0977], dtype=torch.float16)
>>>
```

But in ORT, the result is:
```
tensor([5.5469,  9.08, 4.9727, 6.1523, 6.0977], dtype=float16)
```

The second element, `9.08`, should be `inf`. The test passes in eager
mode and fails only in FullGraph mode.
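For what it's worth, the float16 `inf` appears to come from overflow in the intermediate power: `|x|**6` for the largest entries of that column exceeds float16's maximum finite value (65504). A rough NumPy illustration (the column values are taken from `b` above; the float32 accumulation path is an assumption about what the backend does internally):

```python
import numpy as np

col = np.array([0.9316, 7.9375, 8.2344], dtype=np.float16)  # second column of b

# Computed in float16, |x|**6 overflows (7.9375**6 ~ 2.5e5 > 65504),
# so the sum and the final sixth root are inf -- PyTorch's float16 answer.
powed_f16 = np.abs(col) ** np.float16(6.0)
print(powed_f16)                     # [~0.654, inf, inf]
print(powed_f16.sum() ** (1 / 6.0))  # inf

# Accumulating in float32 and casting back keeps the intermediates finite,
# which would explain a finite ~9.08 instead of inf.
powed_f32 = np.abs(col.astype(np.float32)) ** 6.0
print(powed_f32.sum() ** (1 / 6.0))  # ~9.08
```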

---------

Co-authored-by: Justin Chu <[email protected]>
xiaowuhu and justinchuby authored Jul 25, 2023
1 parent f51abd2 commit 241260f
Showing 3 changed files with 100 additions and 7 deletions.
### onnxscript/function_libs/torch_lib/ops/linalg.py (74 additions, 5 deletions)
```diff
@@ -13,6 +13,10 @@
 
 from typing import Optional, Sequence
 
+from onnxscript import BOOL, FLOAT, INT64
+from onnxscript.function_libs.torch_lib.registration import torch_op
+from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
+from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import TensorType
 
 
@@ -305,13 +309,78 @@ def aten_linalg_vecdot(x: TensorType, y: TensorType, dim: int = -1) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::linalg_vector_norm", trace_only=True)
 def aten_linalg_vector_norm(
-    self: TensorType,
-    ord: float = 2,
+    self: TFloat,
+    ord: float = 2.0,
     dim: Optional[int] = None,
     keepdim: bool = False,
-    dtype: Optional[int] = None,
-) -> TensorType:
+    dtype: int = -1,
+) -> TFloat:
     """linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
 
-    raise NotImplementedError()
+    if dtype != -1:
+        self = op.Cast(self, to=dtype)
+    if dim is None or (isinstance(dim, tuple) and len(dim) == 0):
+        self = op.Reshape(self, op.Constant(value_ints=[-1]))
+        keepdim = False
+        return _aten_linalg_vector_norm_no_dim_onnx(self, ord, keepdim)
+    else:
+        return _aten_linalg_vector_norm_onnx(self, ord, dim, keepdim)
+
+
+@torch_op("aten::linalg_vector_norm", private=True)
+def _aten_linalg_vector_norm_no_dim_onnx(self: TFloat, ord: float, keepdim: bool) -> TFloat:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Unsqueeze(self, axes=[0])
+
+    self = op.Abs(self)
+    ord = op.Cast(ord, to=FLOAT.dtype)  # Must be FLOAT, due to op.IsInf() needs FLOAT
+    if op.IsInf(ord, detect_negative=0, detect_positive=1):
+        result = op.ReduceMax(self, keepdims=keepdim)
+    elif op.IsInf(ord, detect_negative=1, detect_positive=0):
+        result = op.ReduceMin(self, keepdims=keepdim)
+    elif ord == 0.0:  # sum(x!=0) means count non-zero elements
+        self_bool = op.Cast(self, to=BOOL.dtype)
+        self_0_1 = op.CastLike(self_bool, self)
+        result = op.ReduceSum(self_0_1, keepdims=False)
+    else:
+        ord_float = op.CastLike(ord, self)
+        self_pow = op.Pow(self, ord_float)
+        result = op.Pow(op.ReduceSum(self_pow, keepdims=keepdim), op.Div(1.0, ord_float))
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result
+
+
+@torch_op("aten::linalg_vector_norm", private=True)
+def _aten_linalg_vector_norm_onnx(
+    self: TFloat, ord: float, dim: INT64, keepdim: bool
+) -> TFloat:
+    self_rank = op.Size(op.Shape(self))
+    if self_rank == 0:
+        self = op.Unsqueeze(self, axes=[0])
+
+    dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
+    self = op.Abs(self)
+    ord = op.Cast(ord, to=FLOAT.dtype)  # Must be FLOAT, due to op.IsInf() needs FLOAT
+    if op.IsInf(ord, detect_negative=0, detect_positive=1):
+        result = op.ReduceMax(self, dim, keepdims=keepdim)
+    elif op.IsInf(ord, detect_negative=1, detect_positive=0):
+        result = op.ReduceMin(self, dim, keepdims=keepdim)
+    elif ord == 0.0:  # sum(x!=0) means count non-zero elements
+        self_bool = op.Cast(self, to=BOOL.dtype)
+        self_0_1 = op.CastLike(self_bool, self)
+        result = op.ReduceSum(self_0_1, dim, keepdims=keepdim)
+    else:
+        ord_float = op.CastLike(ord, self)
+        self_pow = op.Pow(self, ord_float)
+        result = op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), op.Div(1.0, ord_float))
+
+    if self_rank == 0:
+        result = op.Squeeze(result)
+
+    return result
```
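As a reading aid (not part of the change), the branching above follows the usual vector-norm definition. A minimal NumPy sketch of the same dispatch, with `vector_norm_reference` as a hypothetical name:

```python
from typing import Optional, Sequence

import numpy as np


def vector_norm_reference(
    x: np.ndarray,
    ord: float = 2.0,
    dim: Optional[Sequence[int]] = None,
    keepdim: bool = False,
) -> np.ndarray:
    """NumPy equivalent of the branches used by the ONNX decomposition above."""
    a = np.abs(x)
    axis = tuple(dim) if dim is not None else None
    if ord == float("inf"):
        return a.max(axis=axis, keepdims=keepdim)
    if ord == float("-inf"):
        return a.min(axis=axis, keepdims=keepdim)
    if ord == 0.0:
        # ord=0 counts the non-zero elements.
        return (a != 0).astype(x.dtype).sum(axis=axis, keepdims=keepdim)
    return (a**ord).sum(axis=axis, keepdims=keepdim) ** (1.0 / ord)
```

On float32 inputs this matches `torch.linalg.vector_norm` for the cases exercised here; the extra rank-0 handling (Unsqueeze/Squeeze) in the ONNX version appears to be there so that the Reduce ops always have an axis to work on.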
### onnxscript/tests/function_libs/torch_lib/ops_test.py (5 additions, 2 deletions)
```diff
@@ -58,7 +58,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
 
 
 def _should_skip_xfail_test_sample(
-    op_name: str, sample
+    op_name: str, sample, dtype: torch.dtype
 ) -> Tuple[Optional[str], Optional[str]]:
     """Returns a reason if a test sample should be skipped."""
     if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
@@ -67,6 +67,9 @@ def _should_skip_xfail_test_sample(
         # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
         if decorator_meta.op_name == op_name:
             assert decorator_meta.matcher is not None, "Matcher must be defined"
+            if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
+                # Not applicable for this dtype
+                continue
             if decorator_meta.matcher(sample):
                 return decorator_meta.test_behavior, decorator_meta.reason
     return None, None
@@ -184,7 +187,7 @@ def run_test_output_match(
             ),
             kwargs=repr(cpu_sample.kwargs),
         ):
-            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample)
+            test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample, dtype)
 
             with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                 input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
```
### onnxscript/tests/function_libs/torch_lib/ops_test_data.py (21 additions, 0 deletions)
```diff
@@ -46,6 +46,7 @@
 from typing_extensions import Self
 
 from onnxscript.function_libs.torch_lib.ops import core as core_ops
+from onnxscript.function_libs.torch_lib.ops import linalg as linalg_ops
 from onnxscript.function_libs.torch_lib.ops import nn as nn_ops
 from onnxscript.function_libs.torch_lib.ops import special as special_ops
 from onnxscript.tests.function_libs.torch_lib import extra_opinfo, ops_test_common
@@ -270,6 +271,15 @@ def _grid_sample_input_wrangler(
     return args, kwargs
 
 
+def _linalg_vector_norm_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    # Make the dims as tensor
+    if "dim" in kwargs:
+        kwargs["dim"] = np.array(kwargs["dim"], dtype=np.int64)
+    return args, kwargs
+
+
 def _max_pool_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
@@ -630,6 +640,17 @@ def _where_input_wrangler(
     TorchLibOpInfo("isnan", core_ops.aten_isnan),
     TorchLibOpInfo("isneginf", core_ops.aten_isneginf),
     TorchLibOpInfo("isposinf", core_ops.aten_isposinf),
+    TorchLibOpInfo(
+        "linalg.vector_norm",
+        linalg_ops.aten_linalg_vector_norm,
+        trace_only=True,
+        tolerance={torch.float16: (2e-3, 2e-3)},
+        input_wrangler=_linalg_vector_norm_input_wrangler,
+    ).skip(
+        matcher=lambda sample: sample.kwargs.get("ord") == 6,
+        dtypes=[torch.float16],
+        reason="ORT returns a more accurate value for float16 with ord=6 (expected=Inf, actual=9.48).",
+    ),
     TorchLibOpInfo(
         "linspace",
         core_ops.aten_linspace,
```
