-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rotary embedding fusion rule (part 1) (#1981)
Initial version of fusion for rotary embedding. Limitations: currently addresses only non-interleaved and full rotation. Other: * Add support for rewriting rules where the matched nodes are not removed. Useful in cases where matched nodes include some shared nodes. * Add optimization to eliminate redundant Reshape (helps simplify pattern).
- Loading branch information
1 parent
e92e02a
commit 343161a
Showing
12 changed files
with
317 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,15 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
__all__ = [ | ||
"fuse_rms_normalization", | ||
"fuse_normalization", | ||
"fuse_rotary_embedding", | ||
"fuse_cos_sin_cache", | ||
] | ||
|
||
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache | ||
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
102 changes: 102 additions & 0 deletions
102
onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.optimizer import remove_unused_nodes | ||
from onnxscript.rewriter import _ir_utils, pattern | ||
|
||
# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. | ||
|
||
# We match against the following code pattern: | ||
# Original code (from transformers) for computing cos/sin cache for RoPE: | ||
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 | ||
# position_ids_expanded = position_ids[:, None, :].float() | ||
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) | ||
# emb = torch.cat((freqs, freqs), dim=-1) | ||
# cos = emb.cos() | ||
# sin = emb.sin() | ||
# | ||
# We rewrite this pattern into the following form: | ||
# inv_freq_values = inv_freq_expanded.reshape(1, -1) | ||
# pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1) | ||
# angles = np.matmul(pos_id_range, inv_freq_values) | ||
# cos_value = np.cos(angles) | ||
# sin_value = np.sin(angles) | ||
# cos_2d = op.Constant(value=ir.tensor(cos_value)) | ||
# sin_2d = op.Constant(value=ir.tensor(sin_value)) | ||
# | ||
# This produces cos/sin values in a form that can be used by ORT's custom ops. | ||
|
||
# TODO: To apply the pattern-rewrite, we need to know the maximum position id. | ||
# Need to find a way to get this information from the model or its config. | ||
|
||
|
||
class CosSinCacheFusion(pattern.RewriteRuleClassBase):
    """Replace the on-the-fly cos/sin computation feeding RotaryEmbedding with constants.

    The matched subgraph derives cos/sin from ``inv_freq`` and ``position_ids`` and
    feeds them to a ``RotaryEmbedding`` node in the ``ai.onnxruntime.fusion`` domain.
    The replacement is a ``com.microsoft`` ``RotaryEmbedding`` node that takes
    ``position_ids`` plus two constant tables computed with numpy for positions
    ``0 .. max_pos_id - 1``.
    """

    def __init__(self, name: str, max_pos_id: int):
        """Create the rule.

        Args:
            name: Rule name, forwarded to the base class.
            max_pos_id: Number of positions for which cos/sin values are precomputed.
        """
        # The Cos/Sin values matched by this pattern are shared, so the matched
        # nodes cannot be removed during the rewrite step itself. Unused nodes
        # are cleaned up afterwards by a separate pass.
        super().__init__(name, remove_nodes=False)
        self._max_pos_id = max_pos_id

    def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads):
        """Describe the subgraph to match: cos/sin computation followed by RotaryEmbedding."""
        ids_as_float = op.Cast(op.Unsqueeze(position_ids, 1), to=ir.DataType.FLOAT)
        angles = op.Transpose(op.MatMul(inv_freq, ids_as_float), perm=[0, 2, 1])
        # The same angles are concatenated with themselves along the last axis.
        doubled_angles = op.Concat(angles, angles, axis=-1)
        cos = op.Cos(doubled_angles)
        sin = op.Sin(doubled_angles)
        # Insert an extra axis at position 1 (rank 3 -> rank 4 once `check` has
        # confirmed rank-2 position_ids and rank-3 inv_freq).
        cos_4d = op.Unsqueeze(cos, 1)
        sin_4d = op.Unsqueeze(sin, 1)
        return op.RotaryEmbedding(
            x,
            cos_4d,
            sin_4d,
            interleaved=interleaved,
            num_heads=num_heads,
            _domain="ai.onnxruntime.fusion",
        )

    def check(self, context, inv_freq, position_ids, **_) -> bool:
        """Accept the match only for rank-2 position_ids and a constant (1, N, 1) inv_freq."""
        ranks_ok = _ir_utils.has_rank(position_ids, 2) and _ir_utils.has_rank(inv_freq, 3)
        if not ranks_ok:
            return False
        dims = inv_freq.shape
        # The rewrite evaluates inv_freq with numpy, so its value must be known.
        if inv_freq.const_value is None:
            return False
        return dims[0] == 1 and dims[2] == 1

    def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_):
        """Emit constant cos/sin tables and the ``com.microsoft`` RotaryEmbedding node."""
        # Row vector of inverse frequencies, shape (1, N).
        freq_row = inv_freq.const_value.numpy().reshape(1, -1)
        # Column vector of positions 0 .. max_pos_id - 1, shape (max_pos_id, 1).
        positions = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
        angles = np.matmul(positions, freq_row)
        cos_table = np.cos(angles)
        sin_table = np.sin(angles)
        cos_2d = op.Constant(value=ir.tensor(cos_table))
        sin_2d = op.Constant(value=ir.tensor(sin_table))
        return op.RotaryEmbedding(
            x,
            position_ids,
            cos_2d,
            sin_2d,
            interleaved=interleaved,
            num_heads=num_heads,
            _domain="com.microsoft",
        )
|
||
|
||
# Module-level rule instance named "CosSinCache". The second argument (2048) is presumably | ||
# forwarded to CosSinCacheFusion.__init__ as max_pos_id (the `rule` factory is defined on | ||
# the base class and not shown here; confirm). It is hard-coded for now; see the TODO | ||
# above about obtaining the maximum position id from the model or its config. | ||
_rule = CosSinCacheFusion.rule("CosSinCache", 2048) | ||
|
||
# Rule set applied to a model by fuse_cos_sin_cache. | ||
cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) | ||
|
||
|
||
def fuse_cos_sin_cache(model: ir.Model) -> int:
    """Apply the cos/sin cache rewrite rules to ``model`` in place.

    Prints the number of rule applications, then removes nodes left unused by
    the rewrite (the rule itself does not remove the nodes it matched).

    Args:
        model: The IR model to transform.

    Returns:
        The count reported by ``cos_sin_cache_rules.apply_to_model``.
    """
    num_applied = cos_sin_cache_rules.apply_to_model(model)
    print(f"CosSinCache count: {num_applied}")
    # Matched nodes were kept during the rewrite; drop whatever is now dead.
    remove_unused_nodes(model)
    return num_applied
29 changes: 29 additions & 0 deletions
29
onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding | ||
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run | ||
|
||
|
||
class TestCosSinCacheTransform(unittest.TestCase):
    """Checks that the rotary-embedding and cos/sin cache fusions preserve outputs."""

    def test_smollm(self):
        """Both fusions fire on the Smollm test model and ORT outputs stay close."""
        test_data = _SmollmTestData()
        model = test_data.get_onnx_model()
        onnxscript.optimizer.optimize(model)
        inputs = test_data.get_ort_inputs()
        # Reference outputs are taken before either fusion is applied.
        original_outputs = ort_run("original", model, inputs)

        rotary_count = fuse_rotary_embedding(model)
        self.assertGreater(rotary_count, 0)
        cache_count = fuse_cos_sin_cache(model)
        self.assertGreater(cache_count, 0)

        new_outputs = ort_run("optimized", model, inputs)
        assert_allclose(new_outputs, original_outputs)
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.rewriter import _ir_utils, pattern | ||
|
||
# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern | ||
# for full rotation without interleaving. | ||
# TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation). | ||
|
||
# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet, | ||
# so it can't be tested by running against ORT. See cos_sin_cache.py for a transformation that | ||
# rewrites the pattern into one that can be run against ORT. | ||
|
||
|
||
def _rotate_half_pattern(op, x, start1, end1, start2, end2): | ||
# Slice(input, starts, ends, axes, steps) | ||
x1 = op.Slice(x, start1, end1, [3], [1]) | ||
x2 = op.Slice(x, start2, end2, [3], [1]) | ||
minus_x2 = op.Neg(x2) | ||
rotated_x = op.Concat(minus_x2, x1, axis=-1) | ||
return rotated_x | ||
|
||
|
||
class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
    """Fuse ``x * cos + rotate_half(x) * sin`` into a single RotaryEmbedding node.

    Only the non-interleaved, full-rotation form is handled: ``x`` must be split
    along its last axis into two equal halves.
    """

    def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
        """Describe the subgraph to match: ``x * cos + rotate_half(x) * sin``."""
        return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

    def check(self, op, x, start1, end1, start2, end2, **_):
        """Validate that the matched slices split ``x`` into two equal halves.

        Returns:
            True only if ``x`` is 4-D with statically known dimensions 1 (num_heads)
            and 3 (head_size), head_size is even, and the slice bounds are
            ``[0, head_size / 2)`` and ``[head_size / 2, >= head_size)``.
        """
        # x needs to be a 4D tensor with known last dimension size (== head_size)
        # and known second dimension (num_heads).
        if x is None or x.shape is None or len(x.shape) != 4:
            return False
        if not isinstance(x.shape[1], int):
            return False
        head_size = x.shape[3]
        if not isinstance(head_size, int):
            return False
        # With an odd head_size, floor division would accept an uneven split
        # (e.g. 5 -> slices of size 2 and 3), which is not a rotate-half.
        if head_size % 2 != 0:
            return False
        half_head_size = head_size // 2

        # Check that x is being split into two equal halves of size half_head_size.
        # The end of the second slice may exceed head_size; anything at or past it
        # is accepted.
        return (
            _ir_utils.is_singleton_value(start1, 0)
            and _ir_utils.is_singleton_value(end1, half_head_size)
            and _ir_utils.is_singleton_value(start2, half_head_size)
            and _ir_utils.is_singleton_value(end2, lambda end: end >= head_size)
        )

    def rewrite(self, op, x, cos, sin, **_):
        """Emit the fused node, taking num_heads from the second dimension of ``x``."""
        num_heads = x.shape[1]
        return op.RotaryEmbedding(
            x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
        )
|
||
|
||
# Module-level instance of the rotary-embedding fusion rule, built with no arguments. | ||
_rule = RotaryEmbeddingFusion.rule() | ||
|
||
# Rule set applied to a model by fuse_rotary_embedding. | ||
rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) | ||
|
||
|
||
def fuse_rotary_embedding(model: ir.Model) -> int:
    """Apply the rotary-embedding fusion rules to ``model`` in place.

    Prints the number of rule applications.

    Args:
        model: The IR model to transform.

    Returns:
        The count reported by ``rotary_embedding_rules.apply_to_model``.
    """
    num_applied = rotary_embedding_rules.apply_to_model(model)
    print(f"Rotary Embedding count: {num_applied}")
    return num_applied
23 changes: 23 additions & 0 deletions
23
onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import unittest | ||
|
||
import onnxscript.optimizer | ||
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData | ||
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding | ||
|
||
|
||
class TestRotaryEmbedding(unittest.TestCase):
    """Checks that rotary-embedding fusion produces a RotaryEmbedding node."""

    def test_smollm(self):
        """After fusion, the Smollm test model's graph contains a RotaryEmbedding op."""
        test_data = _SmollmTestData()
        model = test_data.get_onnx_model()
        onnxscript.optimizer.optimize(model)
        fuse_rotary_embedding(model)
        # The fused op is not runnable in ORT yet, so only the graph structure is checked.
        graph_op_types = [node.op_type for node in model.graph]
        self.assertIn("RotaryEmbedding", graph_op_types)
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.