diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b46d999b3..6aaddd65f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -81,7 +81,7 @@ def aten_acosh(self: TFloat) -> TFloat: return op.Acosh(self) -@torch_op("aten::add") +@torch_op(("aten::add", "aten::add.Tensor")) def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" # TODO(microsoft/onnxruntime#15977): Improve fp16 precision @@ -1235,7 +1235,7 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: return op.SplitToSequence(self, list_split, axis=dim) -@torch_op("aten::clamp", trace_only=True) +@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True) def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal: """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor""" clamped = self @@ -2184,7 +2184,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType raise NotImplementedError() -@torch_op("aten::div") +@torch_op(("aten::div", "aten::div.Tensor")) def aten_div(self: TFloat, other: TFloat) -> TFloat: """div.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2299,7 +2299,7 @@ def aten_embedding_sparse_backward( raise NotImplementedError() -@torch_op("aten::empty") +@torch_op(("aten::empty", "aten::empty.memory_format")) def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var] # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor @@ -2353,7 +2353,7 @@ def aten_empty_strided( return op.Expand(zero, size) -@torch_op("aten::eq") +@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar")) def aten_eq(self: TTensor, other: TTensor) -> BOOL: """eq.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2563,7 +2563,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType raise NotImplementedError() -@torch_op("aten::fill") +@torch_op(("aten::fill", "aten::fill.Tensor")) def aten_fill(self: TTensor, value: TTensor) -> TTensor: """fill.Tensor(Tensor self, Tensor value) -> Tensor""" @@ -2748,7 +2748,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::ge") +@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar")) def aten_ge(self: TReal, other: TReal) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -2905,7 +2905,7 @@ def aten_gru_cell( raise NotImplementedError() -@torch_op("aten::gt") +@torch_op(("aten::gt", "aten::gt.Scalar")) def aten_gt(self: TReal, other: TReal) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3595,7 +3595,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::le") +@torch_op(("aten::le", "aten::le.Tensor")) def aten_le(self: TReal, other: TReal) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3884,7 +3884,7 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op("aten::lt") +@torch_op(("aten::lt", "aten::lt.Scalar")) def aten_lt(self: TReal, other: TReal) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -3957,7 +3957,7 @@ def aten_margin_ranking_loss( raise NotImplementedError() -@torch_op("aten::masked_fill") +@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor")) def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: """masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor""" # NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types. @@ -4462,7 +4462,7 @@ def aten_msort(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::mul") +@torch_op(("aten::mul", "aten::mul.Tensor")) def aten_mul(self: TReal, other: TReal) -> TReal: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" # FIXME(titaiwang): get rid of this when we have type_promotion @@ -4470,7 +4470,7 @@ def aten_mul(self: TReal, other: TReal) -> TReal: return op.Mul(self, other) -@torch_op("aten::mul") +@torch_op(("aten::mul", "aten::mul.Tensor")) def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" @@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType: raise NotImplementedError() -@torch_op("aten::ne") +@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor")) def aten_ne(self: TReal, other: TReal) -> BOOL: """ne.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -5223,7 +5223,7 @@ def aten_positive(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::pow") +@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar")) def aten_pow(self: TReal, exponent: TTensor) -> TReal: """pow(Tensor self, Tensor exponent) -> Tensor""" @@ -5756,7 +5756,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16: return op.Reciprocal(op.Sqrt(self)) -@torch_op("aten::rsub") +@torch_op(("aten::rsub", "aten::rsub.Scalar")) def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" # FIXME(titaiwang): get rid of this when we have type_promotion @@ -5785,7 +5785,7 @@ def aten_scatter_add( return op.ScatterElements(self, index, src, axis=dim, reduction="add") -@torch_op("aten::scatter_reduce", trace_only=True) +@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True) def aten_scatter_reduce( self: TReal, dim: int, # we have to use int here because ScatterElements() will use this attribute @@ -5855,7 +5855,7 @@ def aten_segment_reduce( raise NotImplementedError() -@torch_op("aten::select") +@torch_op(("aten::select", "aten::select.int")) def aten_select(self: TTensor, dim: int, index: int) -> TTensor: """select(Tensor self, int dim, int index) -> Tensor""" @@ -5935,7 +5935,7 @@ def aten_sinh(self: TFloat) -> TFloat: return op.Sinh(self) -@torch_op("aten::slice", trace_only=True) +@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True) def aten_slice( self: TTensor, dim: int = 0, @@ -6081,7 +6081,7 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::split") +@torch_op(("aten::split", "aten::split.Tensor")) def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor: """split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]""" @@ -6309,7 +6309,7 @@ def aten_stft( return result -@torch_op("aten::sub") +@torch_op(("aten::sub", "aten::sub.Tensor")) def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" alpha = op.CastLike(alpha, other) @@ -6324,7 +6324,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> Te raise NotImplementedError() -@torch_op("aten::sum", trace_only=True) +@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True) def aten_sum_dim_IntList( self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1 ) -> TReal: @@ -6634,7 +6634,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType: raise NotImplementedError() -@torch_op("aten::transpose", trace_only=True) +@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True) def aten_transpose(self, dim0: int, dim1: int): """transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)""" @@ -6729,7 +6729,7 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::unbind") +@torch_op(("aten::unbind", "aten::unbind.int")) def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" @@ -7082,7 +7082,7 @@ def aten_vstack(tensors: Sequence[TTensor]) -> TTensor: return op.ConcatFromSequence(tensors, axis=0) -@torch_op("aten::where") +@torch_op(("aten::where", "aten::where.self")) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""