Add rotary embedding fusion rule (part 1) #1981

Merged · Jan 7, 2025 · 23 commits · Changes from 6 commits
19 changes: 19 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,9 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
+import math
+from typing import Callable
+
 import numpy as np
 
 import onnxscript.ir as ir
@@ -77,3 +80,19 @@
     if np_val is not None and np_val.size == 1:
         return np_val.item()
     return None
+
+
+def is_singleton_value(
+    val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
+) -> bool:
+    """Returns True if the value is a single-element tensor with the given value, and False otherwise."""
+    scalar = get_singleton_value(val)
+    if scalar is None:
+        return False
+    if isinstance(expected, Callable):
+        return expected(scalar)
+    if isinstance(expected, int):
+        return expected == scalar
+    # rtol must be specified for float comparison
+    assert rtol is not None
+    return math.isclose(scalar, expected, rel_tol=rtol)
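
For context, a minimal self-contained sketch of the three comparison modes this helper supports; plain Python scalars stand in for the single-element tensors that get_singleton_value extracts from an ir.Value, and matches is a hypothetical stand-in written for illustration, not part of the PR:

import math
from typing import Callable

def matches(scalar, expected, *, rtol=None) -> bool:
    # Same branch structure as is_singleton_value, minus the ir.Value handling.
    if isinstance(expected, Callable):  # predicate mode, e.g. lambda v: v >= 64
        return expected(scalar)
    if isinstance(expected, int):  # exact integer comparison
        return expected == scalar
    assert rtol is not None  # floats require an explicit relative tolerance
    return math.isclose(scalar, expected, rel_tol=rtol)

assert matches(0, 0)                       # int mode
assert matches(96, lambda v: v >= 64)      # callable mode
assert matches(1.0000001, 1.0, rtol=1e-3)  # float mode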
8 changes: 2 additions & 6 deletions onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py
@@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
             cast_input: Whether to cast input to do the normalization in a different precision.
             cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
         """
-        self._name = name
+        super().__init__(name=name)
         self._cast_input = cast_input
         self._cast_normalized = cast_normalized
 
-    @property
-    def name(self):
-        return self._name
-
     def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
         if self._cast_input:
             x = op.Cast(x, to=compute_dtype)
@@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):


 def fuse_rms_normalization(model: ir.Model) -> None:
-    count = rms_normalization_ruleset.apply_to_model(model, verbose=5)
+    count = rms_normalization_ruleset.apply_to_model(model)
     print(f"RMS Normalization count: {count}")
58 changes: 58 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _ir_utils, pattern

# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern
# for full rotation without interleaving.
# TODO(rama): Add pattern variations to handle other cases.

# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet,
# so it can't be tested by running against ORT. Unfortunately, this is the pattern produced by
# the current version of transformers (not yet supported by ORT).


def _rotate_half_pattern(op, x, start1, end1, start2, end2):
    # Slice(input, starts, ends, axes, steps)
    x1 = op.Slice(x, start1, end1, [3], [1])
    x2 = op.Slice(x, start2, end2, [3], [1])
    minus_x2 = op.Neg(x2)
    rotated_x = op.Concat(minus_x2, x1, axis=-1)
    return rotated_x


class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
    def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
        return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin

    def check(self, op, x, start1, end1, start2, end2, **_):
        # x needs to be a 4D tensor with known last dimension size (== head_size)
        if x is None or x.shape is None or len(x.shape) != 4:
            return False
        head_size = x.shape[3]
        if not isinstance(head_size, int):
            return False
        half_head_size = head_size // 2

        # Check that x is being split into two equal halves of size half_head_size
        return (
            _ir_utils.is_singleton_value(start1, 0)
            and _ir_utils.is_singleton_value(end1, half_head_size)
            and _ir_utils.is_singleton_value(start2, half_head_size)
            and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size)
        )

    def rewrite(self, op, x, cos, sin, **_):
        return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion")


_rule = RotaryEmbeddingFusion.rule()

rotary_embedding_rules = pattern.RewriteRuleSet([_rule])


def fuse_rotary_embedding(model: ir.Model) -> None:
    count = rotary_embedding_rules.apply_to_model(model)
    print(f"Rotary Embedding count: {count}")
23 changes: 23 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnxscript.optimizer
from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding


class TestRotaryEmbedding(unittest.TestCase):
    def test_smollm(self):
        smollm_test = _SmollmTestData()
        model = smollm_test.get_onnx_model()
        onnxscript.optimizer.optimize(model)
        fuse_rotary_embedding(model)
        op_types = [n.op_type for n in model.graph]
        self.assertIn("RotaryEmbedding", op_types)


if __name__ == "__main__":
    unittest.main()
6 changes: 2 additions & 4 deletions onnxscript/rewriter/pattern.py
@@ -1443,10 +1443,8 @@ def rule(cls, *args, **kwargs):
             instance.pattern, instance.rewrite, instance.check, name=instance.name
         )
 
-    @property
-    def name(self):
-        """Default implementation of name property."""
-        return self.__class__.__name__
+    def __init__(self, name: str | None = None) -> None:
+        self.name = name or self.__class__.__name__
 
     def pattern(self, op, *args, **kwargs):
         raise NotImplementedError("Method 'pattern' must be implemented by derived class.")
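A minimal standalone sketch of this base-class change: the name property is replaced by an __init__ that stores an explicit name or falls back to the class name. RuleBase and the two subclasses below are simplified stand-ins written for illustration, not the real classes:

from __future__ import annotations

class RuleBase:
    def __init__(self, name: str | None = None) -> None:
        # Explicit name wins; otherwise default to the class name.
        self.name = name or self.__class__.__name__

class RotaryFusion(RuleBase):
    pass  # inherits the default: name == "RotaryFusion"

class RmsNormFusion(RuleBase):
    def __init__(self, name: str, *, cast_input: bool):
        super().__init__(name=name)  # mirrors rms_normalization.py above
        self._cast_input = cast_input

assert RotaryFusion().name == "RotaryFusion"
assert RmsNormFusion("RmsNormCast", cast_input=True).name == "RmsNormCast"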