Skip to content

Commit

Permalink
Merge branch 'main' into index_put
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Jan 22, 2025
2 parents e28e753 + 0447822 commit a6c6538
Show file tree
Hide file tree
Showing 12 changed files with 242 additions and 59 deletions.
25 changes: 11 additions & 14 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return op.Add(self, other)


@torch_op(
("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True, complex=True
)
@torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""

Expand Down Expand Up @@ -2749,7 +2747,6 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
"aten::divide.Scalar",
"aten::true_divide.Tensor",
"aten::true_divide.Scalar",
"_operator::truediv",
)
)
def aten_div(self: TFloat, other: TFloat) -> TFloat:
Expand All @@ -2759,6 +2756,11 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
return op.Div(self, other)


@torch_op("_operator::truediv", traceable=True)
def operator_truediv(self: TensorType, other: TensorType) -> FLOAT:
    # Registered for `_operator::truediv` only; the aten division overloads are
    # handled by separate functions.
    # Both operands are cast to FLOAT before dividing, so the result is a FLOAT
    # tensor regardless of the dtype of the inputs.
    return op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype))


@torch_op(
(
"aten::div.Tensor",
Expand All @@ -2767,7 +2769,6 @@ def aten_div(self: TFloat, other: TFloat) -> TFloat:
"aten::divide.Scalar",
"aten::true_divide.Tensor",
"aten::true_divide.Scalar",
"_operator::truediv",
),
complex=True,
)
Expand Down Expand Up @@ -3597,17 +3598,15 @@ def python_math_floor(self: TFloat) -> TInt:
return op.Cast(floor, to=INT64.dtype)


@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
@torch_op("aten::floor_divide", traceable=True)
def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

return op.Floor(op.Div(self, other))


@torch_op(("aten::floor_divide", "_operator::floordiv"), traceable=True)
def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
"""floor_divide(Tensor self, Tensor other) -> Tensor"""

@torch_op("_operator::floordiv", traceable=True)
def operator_floordiv(self: INT64, other: INT64) -> INT64:
    # Registered for `_operator::floordiv` on INT64 operands.
    # We implement floor_divide only for positive inputs (using integer division)
    # because that is the usual intended case and is the most efficient.
    # NOTE(review): when the operands have opposite signs, integer Div presumably
    # truncates toward zero instead of flooring — confirm against the ONNX Div spec.
    return op.Div(self, other)
Expand Down Expand Up @@ -4913,7 +4912,6 @@ def aten_logical_not(self: BOOL) -> BOOL:
"aten::bitwise_or.Scalar_Tensor",
"aten::add.Tensor",
"aten::add.Scalar",
"_operator::add",
),
traceable=True,
)
Expand Down Expand Up @@ -5631,7 +5629,7 @@ def aten_mul(self: TReal, other: TReal) -> TReal:


@torch_op(
("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
traceable=True,
)
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
Expand All @@ -5644,7 +5642,7 @@ def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:


@torch_op(
("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
traceable=True,
complex=True,
)
Expand Down Expand Up @@ -8017,7 +8015,6 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"aten::sub.Scalar",
"aten::subtract.Tensor",
"aten::subtract.Scalar",
"_operator::sub",
),
trace_only=True,
complex=True,
Expand Down
23 changes: 7 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,26 +825,17 @@ def aten_leaky_relu_backward(
raise NotImplementedError()


# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm)
def aten_linear(input: TFloat, weight: TFloat) -> TFloat:
@torch_op("aten::linear", trace_only=True)
def aten_linear(input: TFloat, weight: TFloat, bias: Optional[TFloat] = None) -> TFloat:
"""linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""

# NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
# Optimizers may consider this path and replace it with Gemm
# We do not use Gemm here because input can have batch dimensions, which Gemm does not support
weight_transposed = op.Transpose(weight, perm=[1, 0])
return op.MatMul(input, weight_transposed)


# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm)
def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:
"""linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""

# NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
# Optimizers may consider this path and replace it with Gemm
# We do not use Gemm here because input can have batch dimensions, which Gemm does not support
if len(input.shape) == 2:
# Use Gemm for the rank 2 input
return op.Gemm(input, weight, bias, transB=True)
weight_transposed = op.Transpose(weight, perm=[1, 0])
mul = op.MatMul(input, weight_transposed)
if bias is None:
return mul
return op.Add(mul, bias)


Expand Down
50 changes: 45 additions & 5 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Hashable,
Iterable,
Iterator,
NamedTuple,
OrderedDict,
Sequence,
SupportsInt,
Expand Down Expand Up @@ -1055,6 +1056,18 @@ def _quoted(string: str) -> str:
return f'"{string}"'


class Usage(NamedTuple):
    """A usage of a value in a node.

    Being a named tuple, an instance unpacks as ``(node, idx)``.

    Attributes:
        node: The node that uses the value.
        idx: The input index of the value in the node.
    """

    node: Node
    idx: int


class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
"""IR Node.
Expand Down Expand Up @@ -1293,6 +1306,25 @@ def inputs(self, _: Any) -> None:
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
)

def predecessors(self) -> Sequence[Node]:
    """Return the nodes producing this node's inputs.

    Each producer appears once, in the order it is first encountered while
    walking the inputs. Missing inputs (``None``) and values without a
    producer are skipped.
    """
    producers = (value.producer() for value in self.inputs if value is not None)
    # dict.fromkeys keeps first-seen order while dropping duplicates
    return tuple(dict.fromkeys(node for node in producers if node is not None))

def successors(self) -> Sequence[Node]:
    """Return the nodes consuming any of this node's outputs.

    Each consumer appears once, in the order it is first encountered while
    walking the outputs and, for each output, its usages.
    """
    consumers: list[Node] = []
    for output in self.outputs:
        assert output is not None, "Bug: Output values are not expected to be None"
        consumers.extend(usage.node for usage in output.uses())
    # dict.fromkeys keeps first-seen order while dropping duplicates
    return tuple(dict.fromkeys(consumers))

def replace_input_with(self, index: int, value: Value | None) -> None:
"""Replace an input with a new value."""
if index < 0 or index >= len(self.inputs):
Expand Down Expand Up @@ -1564,7 +1596,7 @@ def __init__(
# Use a collection of (Node, int) to store uses. This is needed
# because a single node can use the same value multiple times.
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
self._uses: dict[tuple[Node, int], None] = {}
self._uses: dict[Usage, None] = {}
self.doc_string = doc_string

def __repr__(self) -> str:
Expand Down Expand Up @@ -1595,31 +1627,39 @@ def producer(self) -> Node | None:
"""
return self._producer

def consumers(self) -> Sequence[Node]:
    """Return the distinct nodes that consume this value, in first-use order."""
    using_nodes = (usage.node for usage in self._uses)
    return tuple(dict.fromkeys(using_nodes))

def index(self) -> int | None:
"""The index of the output of the defining node."""
return self._index

def uses(self) -> Collection[tuple[Node, int]]:
def uses(self) -> Collection[Usage]:
"""Return a set of uses of the value.
The set contains tuples of ``(Node, index)`` where the index is the index of the input
of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
"""
return self._uses.keys()
# Create a tuple for the collection so that iteration will not
# be affected when the usage changes during graph mutation.
# This adds a small overhead but is a better user experience than
# having users call tuple().
return tuple(self._uses)

def _add_usage(self, use: Node, index: int) -> None:
"""Add a usage of this value.
This is an internal method. It should only be called by the Node class.
"""
self._uses[(use, index)] = None
self._uses[Usage(use, index)] = None

def _remove_usage(self, use: Node, index: int) -> None:
"""Remove a node from the uses of this value.
This is an internal method. It should only be called by the Node class.
"""
self._uses.pop((use, index))
self._uses.pop(Usage(use, index))

@property
def name(self) -> str | None:
Expand Down
48 changes: 44 additions & 4 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,13 @@ def test_is_dynamic_on_empty_shape(self):


class ValueTest(unittest.TestCase):
def setUp(self) -> None:
    """Build two named values and one node consuming them (v1 twice)."""
    self.v0 = _core.Value(name="v0")
    self.v1 = _core.Value(name="v1")
    # v1 appears twice so that one node holds several usages of the same value
    node_inputs = (self.v0, self.v1, self.v1)
    self.node = _core.Node("test", "TestOp", inputs=node_inputs, num_outputs=2)

def test_initialize(self):
_ = _core.Value()

Expand All @@ -732,14 +739,30 @@ def test_meta(self):
value.metadata_props["test"] = "any string"
self.assertEqual(value.metadata_props["test"], "any string")

def test_producer(self):
    """Standalone values have no producer; node outputs report their node."""
    for standalone in (self.v0, self.v1):
        self.assertEqual(standalone.producer(), None)
    for index in (0, 1):
        self.assertEqual(self.node.outputs[index].producer(), self.node)

def test_consumers(self):
    """Input values report the node once; unused outputs have no consumers."""
    for consumed in (self.v0, self.v1):
        self.assertEqual(consumed.consumers(), (self.node,))
    for index in (0, 1):
        self.assertEqual(self.node.outputs[index].consumers(), ())

# TODO(justinchuby): Test all methods


class NodeTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = _core.Value()
self.v1 = _core.Value()
self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3)
self.v0 = _core.Value(name="v0")
self.v1 = _core.Value(name="v1")
self.node = _core.Node(
"test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3
)
self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]])
self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs)

def test_it_is_hashable(self):
self.assertIsInstance(hash(self.node), int)
Expand All @@ -748,7 +771,7 @@ def test_it_is_hashable(self):
def test_init_with_values(self):
self.assertEqual(self.node.domain, "test")
self.assertEqual(self.node.op_type, "TestOp")
self.assertEqual(self.node.inputs, (self.v0, self.v1))
self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1))
self.assertEqual(len(self.node.outputs), 3)
self.assertEqual(self.node.attributes, {})

Expand Down Expand Up @@ -807,6 +830,23 @@ def test_it_is_added_to_a_graph_if_specified(self):
)
self.assertIn(self.node, graph)

def test_predecessors(self):
    """The root node has no predecessors; both downstream nodes point to it."""
    self.assertEqual(self.node.predecessors(), ())
    for downstream in (self.node_a, self.node_b):
        self.assertEqual(downstream.predecessors(), (self.node,))

def test_predecessors_are_unique(self):
    """A producer feeding several inputs is reported only once."""
    # node_b has three inputs from node, but only one predecessor
    found = self.node_b.predecessors()
    self.assertEqual(found, (self.node,))

def test_successors(self):
    """The root node lists both consumers in order; leaf nodes list none."""
    self.assertEqual(self.node.successors(), (self.node_a, self.node_b))
    for leaf in (self.node_a, self.node_b):
        self.assertEqual(leaf.successors(), ())

def test_successors_are_unique(self):
    """A consumer using several outputs of the node is reported only once."""
    expected = (self.node_a, self.node_b)
    self.assertEqual(self.node.successors(), expected)

# TODO(justinchuby): Test all methods


Expand Down
16 changes: 16 additions & 0 deletions onnxscript/ir/_tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Tape(Iterable[ir.Node]):

def __init__(self) -> None:
    """Create an empty tape with no recorded nodes or initializers."""
    # Both lists are kept in insertion order.
    self._nodes: list[ir.Node] = []
    self._initializers: list[ir.Value] = []

def __iter__(self) -> Iterator[ir.Node]:
    """Iterate over the recorded nodes (initializers are not included)."""
    return iter(self._nodes)
Expand All @@ -26,6 +27,10 @@ def __iter__(self) -> Iterator[ir.Node]:
def nodes(self) -> Sequence[ir.Node]:
return tuple(self._nodes)

@property
def initializers(self) -> Sequence[ir.Value]:
    """The initializer values recorded so far, as an immutable snapshot."""
    # A new tuple is built on each access, so callers cannot mutate the internal list.
    return tuple(self._initializers)

def op(
self,
op_type: str,
Expand Down Expand Up @@ -60,6 +65,17 @@ def op_multi_output(

return node.outputs

def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
    """Record ``tensor`` as an initializer and return the value wrapping it.

    Args:
        tensor: The constant tensor backing the initializer.
        name: Name for the value. Falls back to ``tensor.name`` when not given.

    Returns:
        The new value, carrying the tensor's shape, dtype and constant data.

    Raises:
        ValueError: If neither ``name`` nor ``tensor.name`` provides a name.
    """
    resolved_name = name or tensor.name
    if resolved_name is None:
        raise ValueError("Name must be provided for initializer.")
    # Integer dims are used as-is; any other dim contributes its ``.value``
    dims = [dim if isinstance(dim, int) else dim.value for dim in tensor.shape.dims]
    value = ir.Value(
        name=resolved_name,
        shape=ir.Shape(dims),
        type=ir.TensorType(tensor.dtype),
        const_value=tensor,
    )
    self._initializers.append(value)
    return value


# A type representing the domains/versions used in creating nodes in IR.
UsedOpsets = List[Tuple[str, Optional[int]]]
Expand Down
Loading

0 comments on commit a6c6538

Please sign in to comment.