From 343161afd6c3a93991e4da40be98582fce01cdec Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 7 Jan 2025 11:05:52 -0800 Subject: [PATCH] Add rotary embedding fusion rule (part 1) (#1981) Initial version of fusion for rotary embedding. Limitations: currently addresses only non-interleaved and full rotation. Other: * Add support for rewriting rules where the matched nodes are not removed. Useful in cases where matched nodes include some shared nodes. * Add optimization to eliminate redundant Reshape (helps simplify pattern). --- onnxscript/optimizer/_constant_folding.py | 23 ++++ onnxscript/rewriter/_ir_utils.py | 27 +++++ onnxscript/rewriter/generic_pattern.py | 5 + .../rewriter/onnxruntime/xformers/__init__.py | 12 +++ .../onnxruntime/xformers/_test_utils.py | 2 +- .../onnxruntime/xformers/cos_sin_cache.py | 102 ++++++++++++++++++ .../xformers/cos_sin_cache_test.py | 29 +++++ .../onnxruntime/xformers/rms_normalization.py | 8 +- .../xformers/rms_normalization_test.py | 9 -- .../onnxruntime/xformers/rotary_embedding.py | 64 +++++++++++ .../xformers/rotary_embedding_test.py | 23 ++++ onnxscript/rewriter/pattern.py | 42 +++++--- 12 files changed, 317 insertions(+), 29 deletions(-) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 661a5cd82..1ecfa0911 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -300,6 +300,29 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return default +@register("Reshape") +def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Reshape node by 
Identity when applicable.""" + input = _get_input(node, 0) + shape = _get_input(node, 1) + if input is None or shape is None: + return None + input_shape = input.shape + if input_shape is None: + return None + input_shape_dims = list(input_shape.dims) + if any(not isinstance(dim, int) for dim in input_shape_dims): + return None + shape_value = _get_numpy_value(shape) + if shape_value is None: + return None + target_shape_dims = shape_value.tolist() + if input_shape_dims == target_shape_dims: + # No need to check for special values like -1, 0, etc. here + return op.Identity(input) + return None + + @register("Cast") def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 7c303556a..1d657a5ab 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -2,6 +2,9 @@ # Licensed under the MIT License. from __future__ import annotations +import math +from typing import Callable + import numpy as np import onnxscript.ir as ir @@ -77,3 +80,27 @@ def get_singleton_value(val: ir.Value | None): if np_val is not None and np_val.size == 1: return np_val.item() return None + + +def is_singleton_value( + val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None +) -> bool: + """Returns True if the value is a single element tensor with given value, and False otherwise.""" + scalar = get_singleton_value(val) + if scalar is None: + return False + if callable(expected): + return expected(scalar) + if isinstance(expected, int): + return expected == scalar + # rtol must be specified for float comparison + assert rtol is not None + return math.isclose(scalar, expected, rel_tol=rtol) + + +def has_rank(value: ir.Value | None, rank: int) -> bool: + """Returns True if the value is statically known to have the given rank, and False otherwise.""" + if value is None: + return False + shape = value.shape + return 
(shape is not None) and (shape.rank() == rank) diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 2926f5964..de06d7a22 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -551,7 +551,12 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> orp.MatchResult | None: + if not remove_nodes: + raise NotImplementedError( + "remove_nodes=False is not implemented in GenericPatternMatcher" + ) del model del graph_or_function self.verbose = verbose diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py index 44b5591d8..43cec1352 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -1,3 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import annotations + +__all__ = [ + "fuse_rms_normalization", + "fuse_normalization", + "fuse_rotary_embedding", + "fuse_cos_sin_cache", +] + +from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization diff --git a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py index 0b4e2c55f..b9ed0aecf 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py @@ -25,7 +25,7 @@ def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] with tempfile.TemporaryDirectory() as temp_dir: model_path = os.path.join(temp_dir, f"{model_name}.onnx") - 
io.save(model, model_path) + _save(model, model_path) # Run model session = onnxruntime.InferenceSession(model_path, providers=providers) ort_outputs = session.run(None, inputs) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py new file mode 100644 index 000000000..46272ccf9 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np + +import onnxscript.ir as ir +from onnxscript.optimizer import remove_unused_nodes +from onnxscript.rewriter import _ir_utils, pattern + +# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. + +# We match against the following code pattern: +# Original code (from transformers) for computing cos/sin cache for RoPE: +# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 +# position_ids_expanded = position_ids[:, None, :].float() +# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) +# emb = torch.cat((freqs, freqs), dim=-1) +# cos = emb.cos() +# sin = emb.sin() +# +# We rewrite this pattern into the following form: +# inv_freq_values = inv_freq_expanded.reshape(1, -1) +# pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1) +# angles = np.matmul(pos_id_range, inv_freq_values) +# cos_value = np.cos(angles) +# sin_value = np.sin(angles) +# cos_2d = op.Constant(value=ir.tensor(cos_value)) +# sin_2d = op.Constant(value=ir.tensor(sin_value)) +# +# This produces cos/sin values in a form that can be used by ORT's custom ops. + +# TODO: To apply the pattern-rewrite, we need to know the maximum position id. +# Need to find a way to get this information from the model or its config. 
+ + +class CosSinCacheFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, max_pos_id: int): + # This pattern makes use of shared Cos/Sin values. So, we can't remove the + # matched nodes as part of the rewrite-step. We apply a separate final + # pass to remove unused nodes. + super().__init__(name, remove_nodes=False) + self._max_pos_id = max_pos_id + + def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): + position_ids_expanded = op.Unsqueeze(position_ids, 1) + position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq, position_ids_expanded) + freqs = op.Transpose(freqs, perm=[0, 2, 1]) + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + cos_4d = op.Unsqueeze(cos, 1) # convert + sin_4d = op.Unsqueeze(sin, 1) + return op.RotaryEmbedding( + x, + cos_4d, + sin_4d, + interleaved=interleaved, + num_heads=num_heads, + _domain="ai.onnxruntime.fusion", + ) + + def check(self, context, inv_freq, position_ids, **_) -> bool: + if not _ir_utils.has_rank(position_ids, 2): + return False + if not _ir_utils.has_rank(inv_freq, 3): + return False + inv_freq_shape = inv_freq.shape + if inv_freq.const_value is None: + return False + return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 + + def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): + inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) + pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) + angles = np.matmul(pos_id_range, inv_freq_values) + cos_value = np.cos(angles) + sin_value = np.sin(angles) + cos_2d = op.Constant(value=ir.tensor(cos_value)) + sin_2d = op.Constant(value=ir.tensor(sin_value)) + return op.RotaryEmbedding( + x, + position_ids, + cos_2d, + sin_2d, + interleaved=interleaved, + num_heads=num_heads, + _domain="com.microsoft", + ) + + +_rule = CosSinCacheFusion.rule("CosSinCache", 2048) + +cos_sin_cache_rules = 
pattern.RewriteRuleSet([_rule]) + + +def fuse_cos_sin_cache(model: ir.Model) -> int: + count = cos_sin_cache_rules.apply_to_model(model) + print(f"CosSinCache count: {count}") + remove_unused_nodes(model) + return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py new file mode 100644 index 000000000..dfe6625a8 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run + + +class TestCosSinCacheTransform(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index 1f7a96df1..1e348acfb 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool): cast_input: Whether to cast input to do the normalization in a different 
precision. cast_normalized: Whether to cast the normalized output to the target dtype (same as scale). """ - self._name = name + super().__init__(name=name) self._cast_input = cast_input self._cast_normalized = cast_normalized - @property - def name(self): - return self._name - def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): if self._cast_input: x = op.Cast(x, to=compute_dtype) @@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): def fuse_rms_normalization(model: ir.Model) -> None: - count = rms_normalization_ruleset.apply_to_model(model, verbose=5) + count = rms_normalization_ruleset.apply_to_model(model) print(f"RMS Normalization count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py index 79a966838..30080474c 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py @@ -4,21 +4,12 @@ import unittest -import onnx - import onnxscript.optimizer from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization -def model_repr(self): - return f"Model({self.graph.name})" - - -onnx.ModelProto.__repr__ = model_repr - - class TestRmsNormalization(unittest.TestCase): def test_smollm(self): smollm_test = _SmollmTestData() diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py new file mode 100644 index 000000000..b36cf2c9b --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter import _ir_utils, pattern + +# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern +# for full rotation without interleaving. +# TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation). + +# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet, +# so it can't be tested by running against ORT. See cos_sin_cache.py for a transformation that +# rewrites the pattern into one that can be run against ORT. + + +def _rotate_half_pattern(op, x, start1, end1, start2, end2): + # Slice(input, starts, ends, axes, steps) + x1 = op.Slice(x, start1, end1, [3], [1]) + x2 = op.Slice(x, start2, end2, [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + return rotated_x + + +class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, cos, sin, start1, end1, start2, end2): + return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + + def check(self, op, x, start1, end1, start2, end2, **_): + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) + if x is None or x.shape is None or len(x.shape) != 4: + return False + if not isinstance(x.shape[1], int): + return False + head_size = x.shape[3] + if not isinstance(head_size, int): + return False + half_head_size = head_size // 2 + + # Check that x is being split into two equal halves of size half_head_size + return ( + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.is_singleton_value(end1, half_head_size) + and _ir_utils.is_singleton_value(start2, half_head_size) + and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) + ) + + def rewrite(self, op, x, cos, sin, **_): + num_heads = x.shape[1] + return op.RotaryEmbedding( + x, cos, sin, interleaved=0, num_heads=num_heads, 
_domain="ai.onnxruntime.fusion" + ) + + +_rule = RotaryEmbeddingFusion.rule() + +rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) + + +def fuse_rotary_embedding(model: ir.Model) -> int: + count = rotary_embedding_rules.apply_to_model(model) + print(f"Rotary Embedding count: {count}") + return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py new file mode 100644 index 000000000..6f8d37dee --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding + + +class TestRotaryEmbedding(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index f2faf77c3..a961ae872 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -946,6 +946,7 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node.""" @@ -1144,6 +1145,7 @@ def _match_single_output_node( model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, + check_removable: bool, ) -> MatchResult: del model del graph_or_function @@ -1162,13 +1164,13 @@ def _match_single_output_node( output_values = 
self._get_output_values() if output_values is None: return match - if not _valid_to_replace(match.nodes, output_values): + if check_removable and not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) return match - def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: + def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> MatchResult: """Find a match for a pattern with multiple output nodes. For a pattern with K output nodes, the input candidate should specify K nodes @@ -1176,6 +1178,8 @@ def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: Args: candidate: An iterable of nodes that will be matched against the pattern output nodes. + check_removable: If True, check that the matched nodes can be removed (that is, that + they are not used elsewhere in the graph). """ match = self._match for pattern_node, node in zip(self.pattern.output_nodes, candidate): @@ -1185,7 +1189,7 @@ def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: if output_values is None: return match - if not _valid_to_replace(match.nodes, output_values): + if check_removable and not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) @@ -1197,6 +1201,7 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node. 
@@ -1216,7 +1221,9 @@ def match( if self.pattern.has_single_output_node: self._init_match(verbose) - return self._match_single_output_node(model, graph_or_function, node) + return self._match_single_output_node( + model, graph_or_function, node, check_removable=remove_nodes + ) else: # Note: This is a potentially expensive algorithm for matching patterns with # multiple output nodes. For patterns with N output nodes, we try all possible @@ -1243,7 +1250,7 @@ def get_nodes(pattern_node): match = None for combination in itertools.product(*candidates): self._init_match(verbose) - match = self._multi_match(combination) + match = self._multi_match(combination, check_removable=remove_nodes) if match: return match if match is None: @@ -1260,6 +1267,7 @@ def __init__( matcher: PatternMatcher | Callable[[GraphPattern], PatternMatcher] | None = None, verbose: int = 0, name: str | None = None, + remove_nodes: bool = True, ) -> None: """Create a rewrite rule. @@ -1275,6 +1283,7 @@ def __init__( If not provided, a default matcher will be used. verbose: The verbosity level of the rule. name: An optional name for the pattern that will show up in verbose logging. + remove_nodes: If True, the matched nodes will be removed from the graph. 
""" if not isinstance(target_pattern, GraphPattern): @@ -1298,6 +1307,7 @@ def __init__( self._matcher = matcher(self._target_pattern) self._verbose = verbose self.name = name + self.remove_nodes = remove_nodes def __str__(self) -> str: if self.name: @@ -1317,7 +1327,9 @@ def try_rewrite( if verbose and verbose > 2: print(f"[try_rewrite] {self}") verbose = verbose if verbose is not None else self._verbose - match = self._matcher.match(model, graph_or_function, node, verbose=verbose) + match = self._matcher.match( + model, graph_or_function, node, verbose=verbose, remove_nodes=self.remove_nodes + ) if match: context = None # TODO(rama) for var in self._target_pattern.inputs: @@ -1440,19 +1452,23 @@ class RewriteRuleClassBase: def rule(cls, *args, **kwargs): instance = cls(*args, **kwargs) return RewriteRule( - instance.pattern, instance.rewrite, instance.check, name=instance.name + instance.pattern, + instance.rewrite, + instance.check, + name=instance.name, + remove_nodes=instance.remove_nodes, ) - @property - def name(self): - """Default implementation of name property.""" - return self.__class__.__name__ + def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None: + self.name = name or self.__class__.__name__ + self.remove_nodes = remove_nodes def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") def check(self, op, *args, **kwargs): - raise NotImplementedError("Method 'check' must be implemented by derived class.") + # Default check function that always returns True. + return True def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") @@ -1488,7 +1504,7 @@ def _apply_to_graph_or_function( _convenience.replace_nodes_and_values( graph_or_function, node, - delta.match.nodes, + delta.match.nodes if rule.remove_nodes else [], delta.new_nodes, delta.match.outputs, delta.new_outputs,