-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Implement replace_all_uses_with #1414
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
# -------------------------------------------------------------------------- | ||
"""Convenience methods for constructing and manipulating the IR. | ||
|
||
This is an internal only module. We should choose to expose some of the methods | ||
after they are proven to be useful. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, Mapping, Sequence | ||
|
||
from onnxscript.ir import _core, _protocols | ||
|
||
|
||
def convert_attributes(attrs: Mapping[str, Any]) -> list[_core.Attr]: | ||
attributes: list[_core.Attr] = [] | ||
for name, attr in attrs.items(): | ||
if isinstance(attr, int): | ||
attributes.append(_core.AttrInt64(name, attr)) | ||
elif isinstance(attr, float): | ||
attributes.append(_core.AttrFloat32(name, attr)) | ||
elif isinstance(attr, str): | ||
attributes.append(_core.AttrString(name, attr)) | ||
elif isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr): | ||
attributes.append(_core.AttrInt64s(name, attr)) | ||
elif isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr): | ||
attributes.append(_core.AttrFloat32s(name, attr)) | ||
elif isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr): | ||
attributes.append(_core.AttrStrings(name, attr)) | ||
elif isinstance(attr, _core.Attr): | ||
attributes.append(attr) | ||
else: | ||
raise TypeError(f"Unsupported attribute type: '{type(attr)}'") | ||
return attributes | ||
|
||
|
||
def replace_all_uses_with( | ||
values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], | ||
replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], | ||
) -> None: | ||
"""Replace all consumers of the given values with the replacements. | ||
|
||
This is useful when nodes in the graph are replaced with new nodes, where | ||
the old users need to be updated to use the outputs of the new nodes. | ||
|
||
For example, suppose we have the following graph:: | ||
|
||
A -> {B, C} | ||
|
||
We want to replace the node A with a new node D:: | ||
|
||
>>> from onnxscript import ir | ||
>>> input = ir.Input("input") | ||
>>> node_a = ir.Node("", "A", [input]) | ||
>>> node_b = ir.Node("", "B", node_a.outputs) | ||
>>> node_c = ir.Node("", "C", node_a.outputs) | ||
>>> node_d = ir.Node("", "D", [input]) | ||
>>> replace_all_uses_with(node_a.outputs, node_d.outputs) | ||
>>> len(node_b.inputs) | ||
1 | ||
>>> node_b.inputs[0].producer().op_type | ||
'D' | ||
>>> len(node_c.inputs) | ||
1 | ||
>>> node_c.inputs[0].producer().op_type | ||
'D' | ||
>>> len(node_a.outputs[0].consumers()) | ||
0 | ||
|
||
When values and replacements are sequences, they are zipped into pairs. All | ||
users of the first value is replaced with the first replacement, and so on. | ||
|
||
.. note:: | ||
You still need to update the graph outputs if any of the values being | ||
replaced are part of the graph outputs. Be sure to remove the old nodes | ||
from the graph using ``graph.remove()`` if they are no longer needed. | ||
|
||
Args: | ||
values: The value or values to be replaced. | ||
replacements: The new value or values to use as inputs. | ||
""" | ||
if not isinstance(values, Sequence): | ||
values = (values,) | ||
if not isinstance(replacements, Sequence): | ||
replacements = (replacements,) | ||
if len(values) != len(replacements): | ||
raise ValueError("The number of values and replacements must match.") | ||
for value, replacement in zip(values, replacements): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A question about the IR: how is a (missing) optional output represented? Does it have a corresponding Value object or is it represented as None in the node's outputs? Wondering if we should handle the possibility of None being here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't aware that an empty output was possible, so I didn't handle them anywhere. I will create a pr to implement empty output support. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correction: an empty output is represented as a Value whose name is |
||
for user_node, index in tuple(value.consumers()): | ||
user_node.replace_input_with(index, replacement) |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the same as the one in _tape.py ? We can presumably use this there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! I will create a follow up