Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Refactor ort specific fusions #2039

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

__all__ = [
# Modules
"function_rule",
"pattern",
# Functions
"rewrite",
Expand All @@ -16,18 +15,16 @@

from onnxscript import ir
from onnxscript.optimizer import _remove_unused, _remove_unused_function
from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter import pattern

RewriteRuleSet = pattern.RewriteRuleSet
PatternRewriteRule = pattern.RewriteRule
FunctionRewriteRule = function_rule.FunctionRewriteRule

ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model)


def rewrite(
model: ModelProtoOrIr,
function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (),
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (),
) -> ModelProtoOrIr:
if isinstance(model, onnx.ModelProto):
Expand All @@ -36,11 +33,6 @@ def rewrite(
else:
model_ir = model
proto = False
if function_rewrite_rules:
for rule_cls in function_rewrite_rules:
count, model_ir = rule_cls().apply_to_model(model_ir)
if count > 0:
print(f"Applied {count} of rewrite rules.")
if pattern_rewrite_rules:
if not isinstance(pattern_rewrite_rules, RewriteRuleSet):
# Create a pattern rule-set using provided rules
Expand Down
1 change: 1 addition & 0 deletions onnxscript/rewriter/onnxruntime/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This folder (and function_rule-based rewrites) is deprecated. The folder will be removed soon.
48 changes: 0 additions & 48 deletions onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx

from onnxscript.rewriter import function_rule, pattern
from onnxscript.rewriter import rewrite as _rewrite
from onnxscript.rewriter.onnxruntime import (
fused_matmul_rule_sets,
group_normalization_merge_silu,
instance_to_group_normalization,
softmax,
transformers,
)

ORT_FUNCTION_REWRITE_RULES = [*transformers.TRANSFORMERS_FUNCTION_REWRITE_RULES]

ORT_PATTERN_REWRITE_RULES = [
*softmax.rules.rules,
*instance_to_group_normalization.rules.rules,
# NOTE: group normalization merge silu should be applied after instance to group normalization
*group_normalization_merge_silu.rules.rules,
*fused_matmul_rule_sets.fused_matmul_rule_sets(),
]


def rewrite(
    model_proto: onnx.ModelProto,
    /,
    function_rules: list[type[function_rule.FunctionRewriteRule]] | None = None,
    pattern_rules: list[pattern.RewriteRule] | None = None,
) -> onnx.ModelProto:
    """Rewrite the model using the given rules.

    Args:
        model_proto: The model to rewrite.
        function_rules: The function rewrite rules to apply. If None, the default
            rules for onnxruntime are used. An explicit empty list applies no
            function rules.
        pattern_rules: The pattern rewrite rules to apply. If None, the default
            rules for onnxruntime are used. An explicit empty list applies no
            pattern rules.

    Returns:
        The rewritten model.
    """
    # Compare against None rather than relying on truthiness: `rules or DEFAULTS`
    # would silently substitute the default rule sets when a caller passes an
    # empty list to deliberately disable a rule category.
    if function_rules is None:
        function_rules = ORT_FUNCTION_REWRITE_RULES
    if pattern_rules is None:
        pattern_rules = ORT_PATTERN_REWRITE_RULES
    return _rewrite(
        model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules
    )
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter import function_rule
from onnxscript.rewriter.onnxruntime import function_rule
from onnxscript.rewriter.onnxruntime.transformers import (
biassplitgelu,
fastgelu,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import function_rule
from onnxscript.rewriter.onnxruntime import function_rule

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/transformers/fastgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import function_rule
from onnxscript.rewriter.onnxruntime import function_rule

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/transformers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import onnxscript.ir.convenience
import onnxscript.rewriter._ir_utils as _ir_utils
from onnxscript import ir
from onnxscript.rewriter import function_rule
from onnxscript.rewriter.onnxruntime import function_rule

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
import onnxscript.ir.convenience
import onnxscript.rewriter._ir_utils as _ir_utils
from onnxscript import ir
from onnxscript.rewriter import function_rule
from onnxscript.rewriter.onnxruntime import function_rule

logger = logging.getLogger(__name__)

Expand Down
15 changes: 15 additions & 0 deletions onnxscript/rewriter/ort_fusions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@

__all__ = [
"optimize_for_ort",
"transformers",

Check warning

Code scanning / lintrunner

RUFF/F822 Warning

Undefined name transformers in \_\_all\_\_.
See https://docs.astral.sh/ruff/rules/undefined-export
]

from onnxscript.rewriter.ort_fusions import (
Fixed Show fixed Hide fixed
fused_matmul_rule_sets,
# group_normalization_merge_silu,
instance_to_group_normalization,
softmax,
)
from onnxscript.rewriter.ort_fusions._core import optimize_for_ort

ORT_PATTERN_REWRITE_RULES = [
*softmax.rules.rules,
*instance_to_group_normalization.rules.rules,
# NOTE: group normalization merge silu should be applied after instance to group normalization
# *group_normalization_merge_silu.rules.rules,
*fused_matmul_rule_sets.fused_matmul_rule_sets(),
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import onnx.reference
import onnx.reference.op_run

import onnxscript.rewriter.onnxruntime.fused_matmul_rule_sets as fused_matmul_rule_sets
import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets
from onnxscript import ir

FLOAT = onnx.TensorProto.FLOAT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import onnx.parser

from onnxscript import ir
from onnxscript.rewriter.onnxruntime import instance_to_group_normalization
from onnxscript.rewriter.ort_fusions import instance_to_group_normalization


class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import parameterized

from onnxscript import ir
from onnxscript.rewriter.onnxruntime import softmax
from onnxscript.rewriter.ort_fusions import softmax


class SoftmaxUpcastRemovalTest(unittest.TestCase):
Expand Down
Loading