Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rotary embedding fusion rule (part 1) #1981

Merged
merged 23 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,28 @@
return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
shape = _get_input(node, 1)
if input is None or shape is None:
return None

Check warning on line 314 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L314

Added line #L314 was not covered by tests
input_shape = input.shape
if input_shape is None:
return None

Check warning on line 317 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L317

Added line #L317 was not covered by tests
input_shape_dims = list(input_shape.dims)
if any(not isinstance(dim, int) for dim in input_shape_dims):
return None

Check warning on line 320 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L320

Added line #L320 was not covered by tests
shape_value = _get_numpy_value(shape)
if shape_value is None:
return None

Check warning on line 323 in onnxscript/optimizer/_constant_folding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/optimizer/_constant_folding.py#L323

Added line #L323 was not covered by tests
target_shape_dims = shape_value.tolist()
if input_shape_dims == target_shape_dims:
# No need to check for special values like -1, 0, etc. here
return op.Identity(input)
return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
Expand Down
27 changes: 27 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# Licensed under the MIT License.
from __future__ import annotations

import math
from typing import Callable

import numpy as np

import onnxscript.ir as ir
Expand Down Expand Up @@ -77,3 +80,27 @@
if np_val is not None and np_val.size == 1:
return np_val.item()
return None


def is_singleton_value(
val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
) -> bool:
"""Returns True if the value is a single element tensor with given value, and False otherwise."""
scalar = get_singleton_value(val)
if scalar is None:
return False

Check warning on line 91 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L91

Added line #L91 was not covered by tests
if isinstance(expected, Callable):
Fixed Show fixed Hide fixed
return expected(scalar)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
if isinstance(expected, int):
return expected == scalar
# rtol must be specified for float comparison
assert rtol is not None
return math.isclose(scalar, expected, rtol=rtol)

Check warning on line 98 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L97-L98

Added lines #L97 - L98 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed


def has_rank(value: ir.Value | None, rank: int) -> bool:
"""Returns True if the value is statically known to have the given rank, and False otherwise."""
if value is None:
return False

Check warning on line 104 in onnxscript/rewriter/_ir_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_ir_utils.py#L104

Added line #L104 was not covered by tests
shape = value.shape
return (shape is not None) and (shape.rank() == rank)
12 changes: 12 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

__all__ = [
"fuse_rms_normalization",
"fuse_normalization",
"fuse_rotary_embedding",
"fuse_cos_sin_cache",
]
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/xformers/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ort_run(model_name: str, model, inputs):
providers = ["CPUExecutionProvider"]
with tempfile.TemporaryDirectory() as temp_dir:
model_path = os.path.join(temp_dir, f"{model_name}.onnx")
io.save(model, model_path)
_save(model, model_path)
# Run model
session = onnxruntime.InferenceSession(model_path, providers=providers)
ort_outputs = session.run(None, inputs)
Expand Down
90 changes: 90 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.optimizer import remove_unused_nodes
from onnxscript.rewriter import _ir_utils, pattern

# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.

# Original code (from transformers) for computing cos/sin cache for RoPE:
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
# position_ids_expanded = position_ids[:, None, :].float()
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos = emb.cos()
# sin = emb.sin()


class CosSinCacheFusion(pattern.RewriteRuleClassBase):
def __init__(self, name: str, max_pos_id: int):
super().__init__(name)
self._max_pos_id = max_pos_id
self.remove_nodes = False
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads):
position_ids_expanded = op.Unsqueeze(position_ids, 1)
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
freqs = op.MatMul(inv_freq, position_ids_expanded)
freqs = op.Transpose(freqs, perm=[0, 2, 1])
emb = op.Concat(freqs, freqs, axis=-1)
cos = op.Cos(emb)
sin = op.Sin(emb)
cos_4d = op.Unsqueeze(cos, 1) # convert
sin_4d = op.Unsqueeze(sin, 1)
return op.RotaryEmbedding(
x,
cos_4d,
sin_4d,
interleaved=interleaved,
num_heads=num_heads,
_domain="ai.onnxruntime.fusion",
)

def check(self, context, inv_freq, position_ids, **_):
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
if not _ir_utils.has_rank(position_ids, 2):
return False

Check warning on line 50 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L50

Added line #L50 was not covered by tests
if not _ir_utils.has_rank(inv_freq, 3):
return False

Check warning on line 52 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L52

Added line #L52 was not covered by tests
inv_freq_shape = inv_freq.shape
if inv_freq.const_value is None:
return False

Check warning on line 55 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py#L55

Added line #L55 was not covered by tests
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1

def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_):
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
cos_value = np.concatenate([cos_value, cos_value], axis=-1)
sin_value = np.sin(angles)
sin_value = np.concatenate([sin_value, sin_value], axis=-1)
cos_2d = op.Constant(value=ir.tensor(cos_value))
# cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))
# sin = op.Gather(sin_2d, position_ids, axis=0)
return op.RotaryEmbedding(
x,
position_ids,
cos_2d,
sin_2d,
interleaved=interleaved,
num_heads=num_heads,
_domain="com.microsoft",
)


_rule = CosSinCacheFusion.rule("CosSinCache", 2048)

cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])


def fuse_cos_sin_cache(model: ir.Model) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
print(f"CosSinCache count: {count}")
remove_unused_nodes(model)
return count
29 changes: 29 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run


class TestCosSinCacheTransform(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
count = fuse_rotary_embedding(model)
Fixed Show fixed Hide fixed
self.assertGreater(count, 0)
count = fuse_cos_sin_cache(model)
self.assertGreater(count, 0)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)


if __name__ == "__main__":
unittest.main()

Check warning on line 29 in onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py#L29

Added line #L29 was not covered by tests
8 changes: 2 additions & 6 deletions onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
cast_input: Whether to cast input to do the normalization in a different precision.
cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
"""
self._name = name
super().__init__(name=name)
self._cast_input = cast_input
self._cast_normalized = cast_normalized

@property
def name(self):
return self._name

def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
if self._cast_input:
x = op.Cast(x, to=compute_dtype)
Expand Down Expand Up @@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):


def fuse_rms_normalization(model: ir.Model) -> None:
count = rms_normalization_ruleset.apply_to_model(model, verbose=5)
count = rms_normalization_ruleset.apply_to_model(model)
print(f"RMS Normalization count: {count}")
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@

import unittest

import onnx

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization


def model_repr(self):
return f"Model({self.graph.name})"


onnx.ModelProto.__repr__ = model_repr


class TestRmsNormalization(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
Expand Down
64 changes: 64 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
# for full rotation without interleaving.
# TODO(rama): Add pattern variations to handle other cases.

# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet,
# so it can't be tested by running against ORT. Unfortunately, this is the new pattern out
# of current version of transformers (not yet supported by ORT).


def _rotate_half_pattern(op, x, start1, end1, start2, end2):
# Slice(input, starts, ends, axes, steps)
x1 = op.Slice(x, start1, end1, [3], [1])
x2 = op.Slice(x, start2, end2, [3], [1])
minus_x2 = op.Neg(x2)
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
rotated_x = op.Concat(minus_x2, x1, axis=-1)
return rotated_x


class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

def check(self, op, x, start1, end1, start2, end2, **_):
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
if x is None or x.shape is None or len(x.shape) != 4:
return False

Check warning on line 33 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L33

Added line #L33 was not covered by tests
if not isinstance(x.shape[1], int):
return False

Check warning on line 35 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L35

Added line #L35 was not covered by tests
head_size = x.shape[3]
if not isinstance(head_size, int):
return False

Check warning on line 38 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py#L38

Added line #L38 was not covered by tests
half_head_size = head_size // 2

# Check that x is being split into two equal halves of size half_head_size
return (
_ir_utils.is_singleton_value(start1, 0)
and _ir_utils.is_singleton_value(end1, half_head_size)
and _ir_utils.is_singleton_value(start2, half_head_size)
and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size)
)

def rewrite(self, op, x, cos, sin, **_):
num_heads = x.shape[1]
return op.RotaryEmbedding(
x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
)


_rule = RotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])


def fuse_rotary_embedding(model: ir.Model) -> int:
count = rotary_embedding_rules.apply_to_model(model)
print(f"Rotary Embedding count: {count}")
return count
23 changes: 23 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
fuse_rotary_embedding(model)
op_types = [n.op_type for n in model.graph]
self.assertIn("RotaryEmbedding", op_types)


if __name__ == "__main__":
unittest.main()

Check warning on line 23 in onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py#L23

Added line #L23 was not covered by tests
Loading
Loading