From e634897dd1f533310e3587102e1aef550107e05e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 24 Jan 2025 13:44:50 -0800 Subject: [PATCH 1/3] Refactor ort specific fusions --- onnxscript/rewriter/__init__.py | 10 +--------- onnxscript/rewriter/onnxruntime/README.md | 1 + onnxscript/rewriter/onnxruntime/__init__.py | 3 ++- onnxscript/rewriter/{ => onnxruntime}/function_rule.py | 0 .../rewriter/onnxruntime/transformers/__init__.py | 2 +- .../rewriter/onnxruntime/transformers/biassplitgelu.py | 2 +- .../rewriter/onnxruntime/transformers/fastgelu.py | 2 +- .../rewriter/onnxruntime/transformers/layernorm.py | 2 +- .../onnxruntime/transformers/multihead_attention.py | 2 +- .../fused_matmul_rule_sets.py | 0 .../fused_matmul_rule_sets_test.py | 2 +- .../instance_to_group_normalization.py | 0 .../instance_to_group_normalization_test.py | 2 +- .../rewriter/{onnxruntime => ort_fusions}/softmax.py | 0 .../{onnxruntime => ort_fusions}/softmax_test.py | 2 +- 15 files changed, 12 insertions(+), 18 deletions(-) create mode 100644 onnxscript/rewriter/onnxruntime/README.md rename onnxscript/rewriter/{ => onnxruntime}/function_rule.py (100%) rename onnxscript/rewriter/{onnxruntime => ort_fusions}/fused_matmul_rule_sets.py (100%) rename onnxscript/rewriter/{onnxruntime => ort_fusions}/fused_matmul_rule_sets_test.py (99%) rename onnxscript/rewriter/{onnxruntime => ort_fusions}/instance_to_group_normalization.py (100%) rename onnxscript/rewriter/{onnxruntime => ort_fusions}/instance_to_group_normalization_test.py (99%) rename onnxscript/rewriter/{onnxruntime => ort_fusions}/softmax.py (100%) rename onnxscript/rewriter/{onnxruntime => ort_fusions}/softmax_test.py (98%) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 421535553c..896a30b58f 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -6,7 +6,6 @@ __all__ = [ # Modules - "function_rule", "pattern", # Functions "rewrite", @@ -16,18 +15,16 @@ from onnxscript import ir from onnxscript.optimizer import _remove_unused, _remove_unused_function -from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter import pattern RewriteRuleSet = pattern.RewriteRuleSet PatternRewriteRule = pattern.RewriteRule -FunctionRewriteRule = function_rule.FunctionRewriteRule ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model) def rewrite( model: ModelProtoOrIr, - function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (), pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (), ) -> ModelProtoOrIr: if isinstance(model, onnx.ModelProto): @@ -36,11 +33,6 @@ def rewrite( else: model_ir = model proto = False - if function_rewrite_rules: - for rule_cls in function_rewrite_rules: - count, model_ir = rule_cls().apply_to_model(model_ir) - if count > 0: - print(f"Applied {count} of rewrite rules.") if pattern_rewrite_rules: if not isinstance(pattern_rewrite_rules, RewriteRuleSet): # Create a pattern rule-set using provided rules diff --git a/onnxscript/rewriter/onnxruntime/README.md b/onnxscript/rewriter/onnxruntime/README.md new file mode 100644 index 0000000000..20290b22b1 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/README.md @@ -0,0 +1 @@ +This folder (and function_rule based rewrites) are deprecated. The folder will be removed soon. \ No newline at end of file diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index aa7b9a0ae9..cf49d07153 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -4,9 +4,10 @@ import onnx -from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter import pattern from onnxscript.rewriter import rewrite as _rewrite from onnxscript.rewriter.onnxruntime import ( + function_rule, fused_matmul_rule_sets, group_normalization_merge_silu, instance_to_group_normalization, diff --git a/onnxscript/rewriter/function_rule.py b/onnxscript/rewriter/onnxruntime/function_rule.py similarity index 100% rename from onnxscript/rewriter/function_rule.py rename to onnxscript/rewriter/onnxruntime/function_rule.py diff --git a/onnxscript/rewriter/onnxruntime/transformers/__init__.py b/onnxscript/rewriter/onnxruntime/transformers/__init__.py index be0085ae07..fdf0b207fd 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/transformers/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule from onnxscript.rewriter.onnxruntime.transformers import ( biassplitgelu, fastgelu, diff --git a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py index b63eb0cce5..6ec5c44bbc 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py +++ b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py @@ -6,7 +6,7 @@ import onnxscript from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py index b0967c7ed4..a79adcfc46 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py +++ b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py @@ -6,7 +6,7 @@ import onnxscript from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py index fb56c9f6c7..2b2738221b 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py @@ -8,7 +8,7 @@ import onnxscript.ir.convenience import onnxscript.rewriter._ir_utils as _ir_utils from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py index b6c6f0a969..4b992360e2 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py @@ -59,7 +59,7 @@ import onnxscript.ir.convenience import onnxscript.rewriter._ir_utils as _ir_utils from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py rename to onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py similarity index 99% rename from onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py rename to onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index a7d170e69e..04210e8537 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -10,7 +10,7 @@ import onnx.reference import onnx.reference.op_run -import onnxscript.rewriter.onnxruntime.fused_matmul_rule_sets as fused_matmul_rule_sets +import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets from onnxscript import ir FLOAT = onnx.TensorProto.FLOAT diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py rename to onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py similarity index 99% rename from onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py rename to onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py index 81a20a984d..e5754d78d6 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py +++ b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py @@ -6,7 +6,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter.onnxruntime import instance_to_group_normalization +from onnxscript.rewriter.ort_fusions import instance_to_group_normalization class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase): diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/ort_fusions/softmax.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/softmax.py rename to onnxscript/rewriter/ort_fusions/softmax.py diff --git a/onnxscript/rewriter/onnxruntime/softmax_test.py b/onnxscript/rewriter/ort_fusions/softmax_test.py similarity index 98% rename from onnxscript/rewriter/onnxruntime/softmax_test.py rename to onnxscript/rewriter/ort_fusions/softmax_test.py index f2aa37c1ff..e94657d573 100644 --- a/onnxscript/rewriter/onnxruntime/softmax_test.py +++ b/onnxscript/rewriter/ort_fusions/softmax_test.py @@ -6,7 +6,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter.onnxruntime import softmax +from onnxscript.rewriter.ort_fusions import softmax class SoftmaxUpcastRemovalTest(unittest.TestCase): From ad2332fc643ca5efaf41ec713f621192d724ed5b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 24 Jan 2025 14:06:35 -0800 Subject: [PATCH 2/3] Remove onnxruntime init --- onnxscript/rewriter/onnxruntime/__init__.py | 49 --------------------- onnxscript/rewriter/ort_fusions/__init__.py | 16 +++++++ 2 files changed, 16 insertions(+), 49 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index cf49d07153..59e481eb93 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -1,51 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from __future__ import annotations - -import onnx - -from onnxscript.rewriter import pattern -from onnxscript.rewriter import rewrite as _rewrite -from onnxscript.rewriter.onnxruntime import ( - function_rule, - fused_matmul_rule_sets, - group_normalization_merge_silu, - instance_to_group_normalization, - softmax, - transformers, -) - -ORT_FUNCTION_REWRITE_RULES = [*transformers.TRANSFORMERS_FUNCTION_REWRITE_RULES] - -ORT_PATTERN_REWRITE_RULES = [ - *softmax.rules.rules, - *instance_to_group_normalization.rules.rules, - # NOTE: group normalization merge silu should be applied after instance to group normalization - *group_normalization_merge_silu.rules.rules, - *fused_matmul_rule_sets.fused_matmul_rule_sets(), -] - - -def rewrite( - model_proto: onnx.ModelProto, - /, - function_rules: list[type[function_rule.FunctionRewriteRule]] | None = None, - pattern_rules: list[pattern.RewriteRule] | None = None, -) -> onnx.ModelProto: - """Rewrite the model using the given rules. - - Args: - model_proto: The model to rewrite. - function_rules: The function rewrite rules to apply. If None, the default rules - for onnxruntime are used. - pattern_rules: The pattern rewrite rules to apply. If None, the default rules - for onnxruntime are used. - - Returns: - The rewritten model. - """ - function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES - pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES - return _rewrite( - model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules - ) diff --git a/onnxscript/rewriter/ort_fusions/__init__.py b/onnxscript/rewriter/ort_fusions/__init__.py index ef72e4beae..55081bb072 100644 --- a/onnxscript/rewriter/ort_fusions/__init__.py +++ b/onnxscript/rewriter/ort_fusions/__init__.py @@ -4,6 +4,22 @@ __all__ = [ "optimize_for_ort", + "transformers", ] +from onnxscript.rewriter.ort_fusions import ( + fused_matmul_rule_sets, + # group_normalization_merge_silu, + instance_to_group_normalization, + softmax, + transformers, +) from onnxscript.rewriter.ort_fusions._core import optimize_for_ort + +ORT_PATTERN_REWRITE_RULES = [ + *softmax.rules.rules, + *instance_to_group_normalization.rules.rules, + # NOTE: group normalization merge silu should be applied after instance to group normalization + # *group_normalization_merge_silu.rules.rules, + *fused_matmul_rule_sets.fused_matmul_rule_sets(), +] From 0cd5c73ac393e08f6f562712a434f3d336df7a97 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 24 Jan 2025 15:09:22 -0800 Subject: [PATCH 3/3] Cleanup --- onnxscript/rewriter/onnxruntime/README.md | 2 +- onnxscript/rewriter/ort_fusions/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/README.md b/onnxscript/rewriter/onnxruntime/README.md index 20290b22b1..b1a5d205a0 100644 --- a/onnxscript/rewriter/onnxruntime/README.md +++ b/onnxscript/rewriter/onnxruntime/README.md @@ -1 +1 @@ -This folder (and function_rule based rewrites) are deprecated. The folder will be removed soon. \ No newline at end of file +This folder (and function_rule based rewrites) are deprecated. The folder will be removed soon. diff --git a/onnxscript/rewriter/ort_fusions/__init__.py b/onnxscript/rewriter/ort_fusions/__init__.py index 55081bb072..f4d8c0df42 100644 --- a/onnxscript/rewriter/ort_fusions/__init__.py +++ b/onnxscript/rewriter/ort_fusions/__init__.py @@ -12,7 +12,6 @@ # group_normalization_merge_silu, instance_to_group_normalization, softmax, - transformers, ) from onnxscript.rewriter.ort_fusions._core import optimize_for_ort