diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 421535553..896a30b58 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -6,7 +6,6 @@ __all__ = [ # Modules - "function_rule", "pattern", # Functions "rewrite", @@ -16,18 +15,16 @@ from onnxscript import ir from onnxscript.optimizer import _remove_unused, _remove_unused_function -from onnxscript.rewriter import function_rule, pattern +from onnxscript.rewriter import pattern RewriteRuleSet = pattern.RewriteRuleSet PatternRewriteRule = pattern.RewriteRule -FunctionRewriteRule = function_rule.FunctionRewriteRule ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model) def rewrite( model: ModelProtoOrIr, - function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (), pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (), ) -> ModelProtoOrIr: if isinstance(model, onnx.ModelProto): @@ -36,11 +33,6 @@ def rewrite( else: model_ir = model proto = False - if function_rewrite_rules: - for rule_cls in function_rewrite_rules: - count, model_ir = rule_cls().apply_to_model(model_ir) - if count > 0: - print(f"Applied {count} of rewrite rules.") if pattern_rewrite_rules: if not isinstance(pattern_rewrite_rules, RewriteRuleSet): # Create a pattern rule-set using provided rules diff --git a/onnxscript/rewriter/onnxruntime/README.md b/onnxscript/rewriter/onnxruntime/README.md new file mode 100644 index 000000000..b1a5d205a --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/README.md @@ -0,0 +1 @@ +This folder (and function_rule based rewrites) are deprecated. The folder will be removed soon. diff --git a/onnxscript/rewriter/onnxruntime/__init__.py b/onnxscript/rewriter/onnxruntime/__init__.py index aa7b9a0ae..59e481eb9 100644 --- a/onnxscript/rewriter/onnxruntime/__init__.py +++ b/onnxscript/rewriter/onnxruntime/__init__.py @@ -1,50 +1,2 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from __future__ import annotations - -import onnx - -from onnxscript.rewriter import function_rule, pattern -from onnxscript.rewriter import rewrite as _rewrite -from onnxscript.rewriter.onnxruntime import ( - fused_matmul_rule_sets, - group_normalization_merge_silu, - instance_to_group_normalization, - softmax, - transformers, -) - -ORT_FUNCTION_REWRITE_RULES = [*transformers.TRANSFORMERS_FUNCTION_REWRITE_RULES] - -ORT_PATTERN_REWRITE_RULES = [ - *softmax.rules.rules, - *instance_to_group_normalization.rules.rules, - # NOTE: group normalization merge silu should be applied after instance to group normalization - *group_normalization_merge_silu.rules.rules, - *fused_matmul_rule_sets.fused_matmul_rule_sets(), -] - - -def rewrite( - model_proto: onnx.ModelProto, - /, - function_rules: list[type[function_rule.FunctionRewriteRule]] | None = None, - pattern_rules: list[pattern.RewriteRule] | None = None, -) -> onnx.ModelProto: - """Rewrite the model using the given rules. - - Args: - model_proto: The model to rewrite. - function_rules: The function rewrite rules to apply. If None, the default rules - for onnxruntime are used. - pattern_rules: The pattern rewrite rules to apply. If None, the default rules - for onnxruntime are used. - - Returns: - The rewritten model. - """ - function_rules = function_rules or ORT_FUNCTION_REWRITE_RULES - pattern_rules = pattern_rules or ORT_PATTERN_REWRITE_RULES - return _rewrite( - model_proto, function_rewrite_rules=function_rules, pattern_rewrite_rules=pattern_rules - ) diff --git a/onnxscript/rewriter/function_rule.py b/onnxscript/rewriter/onnxruntime/function_rule.py similarity index 100% rename from onnxscript/rewriter/function_rule.py rename to onnxscript/rewriter/onnxruntime/function_rule.py diff --git a/onnxscript/rewriter/onnxruntime/transformers/__init__.py b/onnxscript/rewriter/onnxruntime/transformers/__init__.py index be0085ae0..fdf0b207f 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/transformers/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule from onnxscript.rewriter.onnxruntime.transformers import ( biassplitgelu, fastgelu, diff --git a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py index b63eb0cce..6ec5c44bb 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py +++ b/onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py @@ -6,7 +6,7 @@ import onnxscript from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py index b0967c7ed..a79adcfc4 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py +++ b/onnxscript/rewriter/onnxruntime/transformers/fastgelu.py @@ -6,7 +6,7 @@ import onnxscript from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py index fb56c9f6c..2b2738221 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/layernorm.py +++ b/onnxscript/rewriter/onnxruntime/transformers/layernorm.py @@ -8,7 +8,7 @@ import onnxscript.ir.convenience import onnxscript.rewriter._ir_utils as _ir_utils from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py index b6c6f0a96..4b992360e 100644 --- a/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py +++ b/onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py @@ -59,7 +59,7 @@ import onnxscript.ir.convenience import onnxscript.rewriter._ir_utils as _ir_utils from onnxscript import ir -from onnxscript.rewriter import function_rule +from onnxscript.rewriter.onnxruntime import function_rule logger = logging.getLogger(__name__) diff --git a/onnxscript/rewriter/ort_fusions/__init__.py b/onnxscript/rewriter/ort_fusions/__init__.py index ef72e4bea..f4d8c0df4 100644 --- a/onnxscript/rewriter/ort_fusions/__init__.py +++ b/onnxscript/rewriter/ort_fusions/__init__.py @@ -4,6 +4,21 @@ __all__ = [ "optimize_for_ort", + "transformers", ] +from onnxscript.rewriter.ort_fusions import ( + fused_matmul_rule_sets, + # group_normalization_merge_silu, + instance_to_group_normalization, + softmax, +) from onnxscript.rewriter.ort_fusions._core import optimize_for_ort + +ORT_PATTERN_REWRITE_RULES = [ + *softmax.rules.rules, + *instance_to_group_normalization.rules.rules, + # NOTE: group normalization merge silu should be applied after instance to group normalization + # *group_normalization_merge_silu.rules.rules, + *fused_matmul_rule_sets.fused_matmul_rule_sets(), +] diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets.py rename to onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py diff --git a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py similarity index 99% rename from onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py rename to onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index a7d170e69..04210e853 100644 --- a/onnxscript/rewriter/onnxruntime/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -10,7 +10,7 @@ import onnx.reference import onnx.reference.op_run -import onnxscript.rewriter.onnxruntime.fused_matmul_rule_sets as fused_matmul_rule_sets +import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets from onnxscript import ir FLOAT = onnx.TensorProto.FLOAT diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py rename to onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py diff --git a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py similarity index 99% rename from onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py rename to onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py index 81a20a984..e5754d78d 100644 --- a/onnxscript/rewriter/onnxruntime/instance_to_group_normalization_test.py +++ b/onnxscript/rewriter/ort_fusions/instance_to_group_normalization_test.py @@ -6,7 +6,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter.onnxruntime import instance_to_group_normalization +from onnxscript.rewriter.ort_fusions import instance_to_group_normalization class ReplaceInstanceNormWithGroupNormTest(unittest.TestCase): diff --git a/onnxscript/rewriter/onnxruntime/softmax.py b/onnxscript/rewriter/ort_fusions/softmax.py similarity index 100% rename from onnxscript/rewriter/onnxruntime/softmax.py rename to onnxscript/rewriter/ort_fusions/softmax.py diff --git a/onnxscript/rewriter/onnxruntime/softmax_test.py b/onnxscript/rewriter/ort_fusions/softmax_test.py similarity index 98% rename from onnxscript/rewriter/onnxruntime/softmax_test.py rename to onnxscript/rewriter/ort_fusions/softmax_test.py index f2aa37c1f..e94657d57 100644 --- a/onnxscript/rewriter/onnxruntime/softmax_test.py +++ b/onnxscript/rewriter/ort_fusions/softmax_test.py @@ -6,7 +6,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter.onnxruntime import softmax +from onnxscript.rewriter.ort_fusions import softmax class SoftmaxUpcastRemovalTest(unittest.TestCase):