Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Implement replace_all_uses_with #1414

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions onnxscript/ir/_convenience.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Convenience methods for constructing and manipulating the IR.

This is an internal only module. We should choose to expose some of the methods
after they are proven to be useful.
"""

from __future__ import annotations

from typing import Any, Mapping, Sequence

from onnxscript.ir import _core, _protocols


def convert_attributes(attrs: Mapping[str, Any]) -> list[_core.Attr]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the same as the one in _tape.py ? We can presumably use this there?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I will create a follow up

attributes: list[_core.Attr] = []

Check warning on line 19 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L19

Added line #L19 was not covered by tests
for name, attr in attrs.items():
if isinstance(attr, int):
attributes.append(_core.AttrInt64(name, attr))

Check warning on line 22 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L22

Added line #L22 was not covered by tests
elif isinstance(attr, float):
attributes.append(_core.AttrFloat32(name, attr))

Check warning on line 24 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L24

Added line #L24 was not covered by tests
elif isinstance(attr, str):
attributes.append(_core.AttrString(name, attr))

Check warning on line 26 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L26

Added line #L26 was not covered by tests
elif isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
attributes.append(_core.AttrInt64s(name, attr))

Check warning on line 28 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L28

Added line #L28 was not covered by tests
elif isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
attributes.append(_core.AttrFloat32s(name, attr))

Check warning on line 30 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L30

Added line #L30 was not covered by tests
elif isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
attributes.append(_core.AttrStrings(name, attr))

Check warning on line 32 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L32

Added line #L32 was not covered by tests
elif isinstance(attr, _core.Attr):
attributes.append(attr)

Check warning on line 34 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L34

Added line #L34 was not covered by tests
else:
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
return attributes

Check warning on line 37 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L36-L37

Added lines #L36 - L37 were not covered by tests


def replace_all_uses_with(
values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
) -> None:
"""Replace all consumers of the given values with the replacements.

This is useful when nodes in the graph are replaced with new nodes, where
the old users need to be updated to use the outputs of the new nodes.

For example, suppose we have the following graph::

A -> {B, C}

We want to replace the node A with a new node D::

>>> from onnxscript import ir
>>> input = ir.Input("input")
>>> node_a = ir.Node("", "A", [input])
>>> node_b = ir.Node("", "B", node_a.outputs)
>>> node_c = ir.Node("", "C", node_a.outputs)
>>> node_d = ir.Node("", "D", [input])
>>> replace_all_uses_with(node_a.outputs, node_d.outputs)
>>> len(node_b.inputs)
1
>>> node_b.inputs[0].producer().op_type
'D'
>>> len(node_c.inputs)
1
>>> node_c.inputs[0].producer().op_type
'D'
>>> len(node_a.outputs[0].consumers())
0

When values and replacements are sequences, they are zipped into pairs. All
users of the first value is replaced with the first replacement, and so on.

.. note::
You still need to update the graph outputs if any of the values being
replaced are part of the graph outputs. Be sure to remove the old nodes
from the graph using ``graph.remove()`` if they are no longer needed.

Args:
values: The value or values to be replaced.
replacements: The new value or values to use as inputs.
"""
if not isinstance(values, Sequence):
values = (values,)

Check warning on line 86 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L86

Added line #L86 was not covered by tests
if not isinstance(replacements, Sequence):
replacements = (replacements,)

Check warning on line 88 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L88

Added line #L88 was not covered by tests
if len(values) != len(replacements):
raise ValueError("The number of values and replacements must match.")

Check warning on line 90 in onnxscript/ir/_convenience.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_convenience.py#L90

Added line #L90 was not covered by tests
for value, replacement in zip(values, replacements):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A question about the IR: how is a (missing) optional output represented? Does it have a corresponding Value object or is it represented as None in the node's outputs? Wondering if we should handle the possibility of None being here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't aware that an empty output was possible, so I didn't handle them anywhere. I will create a pr to implement empty output support.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correction: an empty output is represented as a Value whose name is "". It will not have any users so we don't need to do anything about it.

for user_node, index in tuple(value.consumers()):
user_node.replace_input_with(index, replacement)
33 changes: 0 additions & 33 deletions onnxscript/ir/convenience.py

This file was deleted.

22 changes: 9 additions & 13 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import onnx.printer

from onnxscript import ir
from onnxscript.ir import _convenience
from onnxscript.rewriter import _ir_utils

# Overview of the pattern module: The classes below are used to define both
Expand Down Expand Up @@ -1023,20 +1024,15 @@ def _apply_deltas(
# TODO: simplify this
last_deleted = deleted_nodes[-1]
last_inserted = inserted_nodes[-1]
assert len(last_deleted.outputs) == len(last_inserted.outputs)
# Reconnect the users of the deleted node to use the new outputs
for last_deleted_output, last_inserted_output in zip(
last_deleted.outputs, last_inserted.outputs
):
for node, index in tuple(last_deleted_output.consumers()):
# Fix consumers because we are mutating consumers in the loop
node.replace_input_with(index, last_inserted_output)

# Update graph/function outputs if the node generates output
for old_output, new_output in zip(last_deleted.outputs, last_inserted.outputs):
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output is old_output:
graph_or_function.outputs[idx] = new_output
_convenience.replace_all_uses_with(last_deleted.outputs, last_inserted.outputs)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(last_deleted.outputs, last_inserted.outputs))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[
graph_or_function_output
]

# insert new nodes after the index node
# TODO(justinchuby): Do not access by index [i]
Expand Down
Loading