Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rotary embedding fusion rule (part 1) #1981

Merged
merged 23 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,28 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
shape = _get_input(node, 1)
if input is None or shape is None:
return None
input_shape = input.shape
if input_shape is None:
return None
input_shape_dims = list(input_shape.dims)
if any(not isinstance(dim, int) for dim in input_shape_dims):
return None
shape_value = _get_numpy_value(shape)
if shape_value is None:
return None
target_shape_dims = shape_value.tolist()
if input_shape_dims == target_shape_dims:
# No need to check for special values like -1, 0, etc. here
return op.Identity(input)
return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input = _get_input(node, 0)
Expand Down
27 changes: 27 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# Licensed under the MIT License.
from __future__ import annotations

import math
from typing import Callable

import numpy as np

import onnxscript.ir as ir
Expand Down Expand Up @@ -77,3 +80,27 @@
if np_val is not None and np_val.size == 1:
return np_val.item()
return None


def is_singleton_value(
val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
) -> bool:
"""Returns True if the value is a single element tensor with given value, and False otherwise."""
scalar = get_singleton_value(val)
if scalar is None:
return False
if isinstance(expected, Callable):
Fixed Show fixed Hide fixed
return expected(scalar)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
if isinstance(expected, int):
return expected == scalar
# rtol must be specified for float comparison
assert rtol is not None
return math.isclose(scalar, expected, rtol=rtol)
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed


def has_rank(value: ir.Value | None, rank: int) -> bool:
"""Returns True if the value is statically known to have the given rank, and False otherwise."""
if value is None:
return False
shape = value.shape
return (shape is not None) and (shape.rank() == rank)
3 changes: 2 additions & 1 deletion onnxscript/rewriter/onnxruntime/xformers/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

def ort_run(model_name: str, model, inputs):
providers = ["CPUExecutionProvider"]
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = r"C:\Users\grama\OneDrive - Microsoft\0L-Torch\model\smollm-1L-debug"
with tempfile.TemporaryDirectory() as temp_dir2:
Fixed Show fixed Hide fixed
model_path = os.path.join(temp_dir, f"{model_name}.onnx")
io.save(model, model_path)
# Run model
Expand Down
70 changes: 70 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations

import numpy as np

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.

# Original code (from transformers) for computing cos/sin cache for RoPE:
# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135
# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
# position_ids_expanded = position_ids[:, None, :].float()
# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
# emb = torch.cat((freqs, freqs), dim=-1)
# cos = emb.cos()
# sin = emb.sin()


class CosSinCacheFusion(pattern.RewriteRuleClassBase):
def __init__(self, name: str, max_pos_id: int):
super().__init__(name)
self._max_pos_id = max_pos_id

def pattern(self, op, inv_freq, position_ids):
position_ids_expanded = op.Unsqueeze(position_ids, [1])
position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT)
freqs = op.MatMul(inv_freq, position_ids_expanded)
freqs = op.Transpose(freqs, perm=[0, 2, 1])
emb = op.Concat(freqs, freqs, axis=-1)
cos = op.Cos(emb)
sin = op.Sin(emb)
return cos, sin

def check(self, context, inv_freq, position_ids, **_):
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
if not _ir_utils.has_rank(position_ids, 2):
return False
if not _ir_utils.has_rank(inv_freq, 3):
return False
inv_freq_shape = inv_freq.shape
if inv_freq.const_value is None:
return False
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1

def rewrite(self, op, inv_freq, position_ids, **_):
inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1)
pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1)
angles = np.matmul(pos_id_range, inv_freq_values)
cos_value = np.cos(angles)
cos_value = np.concatenate([cos_value, cos_value], axis=-1)
sin_value = np.sin(angles)
sin_value = np.concatenate([sin_value, sin_value], axis=-1)
cos_2d = op.Constant(value=ir.tensor(cos_value))
cos = op.Gather(cos_2d, position_ids, axis=0)
sin_2d = op.Constant(value=ir.tensor(sin_value))
sin = op.Gather(sin_2d, position_ids, axis=0)
return cos, sin


_rule = CosSinCacheFusion.rule("CosSinCache", 2048)

cos_sin_cache_rules = pattern.RewriteRuleSet([_rule])


def fuse_cos_sin_cache(model: ir.Model) -> int:
count = cos_sin_cache_rules.apply_to_model(model)
print(f"CosSinCache count: {count}")
return count
29 changes: 29 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
Fixed Show fixed Hide fixed
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnx
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache


class TestCosSinCacheTransform(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
inputs = smollm_test.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)
count = fuse_cos_sin_cache(model)
self.assertGreater(count, 0)
new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)


if __name__ == "__main__":
unittest.main()
8 changes: 2 additions & 6 deletions onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
cast_input: Whether to cast input to do the normalization in a different precision.
cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
"""
self._name = name
super().__init__(name=name)
self._cast_input = cast_input
self._cast_normalized = cast_normalized

@property
def name(self):
return self._name

def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
if self._cast_input:
x = op.Cast(x, to=compute_dtype)
Expand Down Expand Up @@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):


def fuse_rms_normalization(model: ir.Model) -> None:
count = rms_normalization_ruleset.apply_to_model(model, verbose=5)
count = rms_normalization_ruleset.apply_to_model(model)
print(f"RMS Normalization count: {count}")
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@

import unittest

import onnx

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization


def model_repr(self):
return f"Model({self.graph.name})"


onnx.ModelProto.__repr__ = model_repr


class TestRmsNormalization(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
Expand Down
58 changes: 58 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
# for full rotation without interleaving.
# TODO(rama): Add pattern variations to handle other cases.

# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet,
# so it can't be tested by running against ORT. Unfortunately, this is the new pattern out
# of current version of transformers (not yet supported by ORT).


def _rotate_half_pattern(op, x, start1, end1, start2, end2):
# Slice(input, starts, ends, axes, steps)
x1 = op.Slice(x, start1, end1, [3], [1])
x2 = op.Slice(x, start2, end2, [3], [1])
minus_x2 = op.Neg(x2)
gramalingam marked this conversation as resolved.
Show resolved Hide resolved
rotated_x = op.Concat(minus_x2, x1, axis=-1)
return rotated_x


class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

def check(self, op, x, start1, end1, start2, end2, **_):
# x needs to be a 4D tensor with known last dimension size (== head_size)
if x is None or x.shape is None or len(x.shape) != 4:
return False
head_size = x.shape[3]
if not isinstance(head_size, int):
return False
half_head_size = head_size // 2

# Check that x is being split into two equal halves of size half_head_size
return (
_ir_utils.is_singleton_value(start1, 0)
and _ir_utils.is_singleton_value(end1, half_head_size)
and _ir_utils.is_singleton_value(start2, half_head_size)
and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size)
)

def rewrite(self, op, x, cos, sin, **_):
return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion")


_rule = RotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])


def fuse_rotary_embedding(model: ir.Model) -> None:
count = rotary_embedding_rules.apply_to_model(model)
print(f"Rotary Embedding count: {count}")
23 changes: 23 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
def test_smollm(self):
smollm_test = _SmollmTestData()
model = smollm_test.get_onnx_model()
onnxscript.optimizer.optimize(model)
fuse_rotary_embedding(model)
op_types = [n.op_type for n in model.graph]
self.assertIn("RotaryEmbedding", op_types)


if __name__ == "__main__":
unittest.main()
6 changes: 2 additions & 4 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,10 +1443,8 @@ def rule(cls, *args, **kwargs):
instance.pattern, instance.rewrite, instance.check, name=instance.name
)

@property
def name(self):
"""Default implementation of name property."""
return self.__class__.__name__
def __init__(self, name: str | None = None) -> None:
self.name = name or self.__class__.__name__

def pattern(self, op, *args, **kwargs):
raise NotImplementedError("Method 'pattern' must be implemented by derived class.")
Expand Down
Loading