Skip to content

Commit

Permalink
IR optimizer (#1855)
Browse files Browse the repository at this point in the history
Initial version of IR-based optimizer (avoids conversion to Proto).

Still to be evaluated/debugged with real models. Adding here to enable
experimentation with benchmark models.
  • Loading branch information
gramalingam authored Sep 6, 2024
1 parent d7a6411 commit fb6d20c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 14 deletions.
38 changes: 28 additions & 10 deletions onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import logging
from typing import Any

import onnx
import onnx.shape_inference

from onnxscript import rewriter
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
from onnxscript.optimizer.constant_folding import fold_constants
from onnxscript.optimizer.remove_unused import remove_unused_nodes
from onnxscript.optimizer.remove_unused_function import remove_unused_functions
Expand All @@ -23,6 +26,13 @@

logger = logging.getLogger(__name__)

_DEFAULT_REWRITE_RULES = [
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule,
*cast_constant_of_shape.rules.rules,
]


def optimize(
model: onnx.ModelProto,
Expand Down Expand Up @@ -79,15 +89,7 @@ def optimize(
model = remove_unused_functions(model)
inline_functions_with_unused_outputs(model)
# NOTE: This is general rewrite rules
model = rewriter.rewrite(
model,
pattern_rewrite_rules=[
*no_op.rules.rules, # TODO: merge this rule into constant folding?
*broadcast_to_matmul.rules.rules,
gemm_to_matmul_add.rule,
*cast_constant_of_shape.rules.rules,
],
)
model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
if stop_if_no_change and not modified:
logger.debug("Stopping after %d iterations.", _)
break
Expand All @@ -109,8 +111,24 @@ def optimize(
return model


def optimize_ir(
model: ir.Model,
num_iterations: int = 2,
*,
onnx_shape_inference: bool = True,
stop_if_no_change: bool = True,
) -> None:
del stop_if_no_change # Looks like rewriter doesn't support this yet.
_inliner.inline(model)
for _ in range(num_iterations):
_constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
remove_unused_nodes(model)


__all__ = [
"fold_constants",
"remove_unused_nodes",
"optimize",
"optimize_ir",
]
2 changes: 1 addition & 1 deletion onnxscript/optimizer/_inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def clone_node(self, node: ir.Node) -> ir.Node:
num_outputs=len(node.outputs),
graph=None,
name=new_name,
doc_string=node.doc_string,
doc_string=node.doc_string, # type: ignore
metadata_props=new_metadata,
)
new_outputs = new_node.outputs
Expand Down
20 changes: 17 additions & 3 deletions onnxscript/optimizer/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

import onnx

import onnxscript.ir as ir
import onnxscript.optimizer as optimizer


class OptimizerTest(unittest.TestCase):
def test_static_split_to_sequence_with_uneven_split(self):
model = onnx.parser.parse_model(
def _model_proto(self) -> onnx.ModelProto:
return onnx.parser.parse_model(
"""
<
ir_version: 8,
Expand Down Expand Up @@ -59,11 +60,24 @@ def test_static_split_to_sequence_with_uneven_split(self):
}
"""
)
optimized = optimizer.optimize(model, num_iterations=1, onnx_shape_inference=False)

def test_static_split_to_sequence_with_uneven_split_proto(self):
model_proto = self._model_proto()
optimized = optimizer.optimize(
model_proto, num_iterations=1, onnx_shape_inference=False
)
self.assertEqual(len(optimized.graph.node), 2)
self.assertEqual(len(optimized.graph.node[0].output), 2)
self.assertEqual(optimized.graph.node[0].op_type, "Split")

def test_static_split_to_sequence_with_uneven_split_ir(self):
model_proto = self._model_proto()
model_ir = ir.serde.deserialize_model(model_proto)
optimizer.optimize_ir(model_ir, num_iterations=1, onnx_shape_inference=False)
self.assertEqual(len(model_ir.graph), 2)
self.assertEqual(len(model_ir.graph.node(0).outputs), 2)
self.assertEqual(model_ir.graph.node(0).op_type, "Split")


if __name__ == "__main__":
unittest.main()

0 comments on commit fb6d20c

Please sign in to comment.