From 8de723195b15e5d1074022e16416214c5b053310 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 18 Dec 2024 10:45:06 -0800 Subject: [PATCH 01/17] First version --- onnxscript/rewriter/_ir_utils.py | 6 + .../rewriter/onnxruntime/xformers/mha.py | 141 ++++++++++++++++++ .../rewriter/onnxruntime/xformers/sdpa.py | 63 ++++++++ 3 files changed, 210 insertions(+) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/mha.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/sdpa.py diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 7c303556a..002d0efc5 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations +import math import numpy as np import onnxscript.ir as ir @@ -77,3 +78,8 @@ def get_singleton_value(val: ir.Value | None): if np_val is not None and np_val.size == 1: return np_val.item() return None + +def is_singleton_value(val: ir.Value | None, expected_value: float, *, rtol: float) -> bool: + """Returns True if the value is a single element tensor with given value, and False otherwise.""" + scalar = get_singleton_value(val) + return scalar is not None and math.isclose(scalar, expected_value, rtol=rtol) diff --git a/onnxscript/rewriter/onnxruntime/xformers/mha.py b/onnxscript/rewriter/onnxruntime/xformers/mha.py new file mode 100644 index 000000000..8870913f2 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/mha.py @@ -0,0 +1,141 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Iterable +import onnxscript.ir as ir +from onnxscript.rewriter import pattern + +""" +The MultiHeadAttention pattern: + +B: Batch size +S: Sequence length +D: input embedding dimension +H: number of heads +d_h: head size (usually, D = H * d_h) + +thus, weights are usually of shape (D, D) and (D, D) and (D, D) + +for each of Q, K, and V, we have the following pattern: + MatMul (Input, W), producing output of shape (B, S, D) + Reshape to produce a matrix of shape (B, S, H, d_h) + Transpose middle two axes to produce a matrix of shape (B, H, S, d_h) + +This is followed by a RotaryEmbedding pattern for Q and K + +The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) + +The dot-product attention is then computed using SDPA + +Finally, the output is transposed and reshaped back to (B, S, D) shape +""" + + +def _project_transpose_head(op, input, weight, reshape_var: str): + """Applied to each of Q, K, and V.""" + # input_2d = op.Reshape(input, _allow_other_inputs=True, _allow_other_attributes=True) + projected = op.MatMul(input, weight) + # Reshape into 3D tensor (B, S, D) + # reshaped_3d = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) + # Reshape from (B, S, D) to (B, S, H, D/H) + reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True, _outputs=[reshape_var]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) + return transposed + + +def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): + query = _project_transpose_head(op, input, query_weight, "query_mm_reshaped") + query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") + key = _project_transpose_head(op, input, key_weight, "key_mm_reshaped") + key_rope = op.RotaryEmbedding(key, cos, 
sin, _domain="local") + # Transpose last two axes of key_rope to compute dot-product via matmul. + key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]) + key_reshaped_transposed = op.Transpose(key_reshaped) + key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]) + value = _project_transpose_head(op, input, value_weight, "value_mm_reshaped") + attention = op.SDPA( + query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" + ) + # Transpose back to (B, S, H, D/H) + attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]) + return attention_reshaped, key_rope, value + +def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool: + if val.shape is None: + return False + if val.shape.rank() != len(shape): + return False + for actual, expected in zip(val.shape, shape): + if expected not in bindings: + bindings[expected] = actual + elif actual != bindings[expected]: + return False + return True + +def _mha_validation(op, query_mm_reshaped, key_mm_reshaped, value_mm_reshaped, key_reshaped, key_transposed, attention_reshaped, **_): + bindings : dict[str, int] = {} + check = ( + _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) and + _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) and + _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) and + _check_shape(bindings, key_reshaped, ["B*H", "S", "d_h"]) and + _check_shape(bindings, key_transposed, ["B", "H", "d_h", "S"]) and + _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) + ) + if not check: + return False + if bindings["B"] * bindings["H"] != bindings["B*H"]: + return False + if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: + return False + return True + +def _multi_head_attention_pattern2( + op, input, query_weight, key_weight, value_weight, cos, sin +): + """Variation of first pattern with Reshape omitted.""" + query = _project_transpose_head(op, input, query_weight) + query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") + key = _project_transpose_head(op, input, key_weight) + key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") + # Transpose last two axes of key_rope to compute dot-product via matmul. + # Reshape omitted here. 
+ key_transposed = op.Transpose(key_rope) + # Reshape omitted here + value = _project_transpose_head(op, input, value_weight) + attention = op.SDPA( + query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" + ) + # Transpose back to (B, S, H, D/H) + attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) + return attention_reshaped, key_rope, value + + +def _multi_head_attention( + op, + input, + query_weight, + key_weight, + value_weight, + cos, + sin, + **_ +): + # TODO: other checks and concatenation of weights + return op.MultiHeadAttention( + input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3 + ) + + +_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention, _mha_validation) + +# TODO: _rule2 validation conditions +# _rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) + +mha_rules = pattern.RewriteRuleSet([_rule1]) diff --git a/onnxscript/rewriter/onnxruntime/xformers/sdpa.py b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py new file mode 100644 index 000000000..76791086f --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import math + +from onnxscript.rewriter import _ir_utils, pattern + +class SDPA(pattern.RewriteRuleClassBase): + def __init__(self, name: str, *, use_mask: bool, pre_scale: bool): + self._name = name + self._use_mask = use_mask + self._pre_scale = pre_scale + + def pattern(self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale): + if self.pre_scale: + # Some implementations scale the query and key before computing the dot product + query = op.Mul(query, query_scale) + key_transposed = op.Mul(key_transposed, key_scale) + attn_score = op.MatMul(query, key_transposed) + if not self.pre_scale: + # Some implementations scale the dot product. + attn_score = op.Div(attn_score, qk_scale) + if self.use_mask: + # Some implementations add a mask to the dot product. + attn_score = op.Add(attn_score, mask) + attn_weight = op.Softmax(attn_score, axis=-1) + attn_output = op.MatMul(attn_weight, value) + return attn_output + + def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale): + # Check that the scaling factors match what SDPA implements: + + # We need to know the hidden size to check the scaling factors. 
+ if query is None or query.shape is None or len(query.shape) < 2: + return False + hidden_size = query.shape[-1] + if not isinstance(hidden_size, int) : + return False + expected_scaling_factor = math.sqrt(hidden_size) + + if self.pre_scale: + # Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(hidden_size)) + sqrt_scaling_factor = 1.0 / math.sqrt(expected_scaling_factor) + if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3): + return False + if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3): + return False + else: + # Check if qk_scale is a scalar == sqrt(hidden_size) + if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3): + return False + + # check ranks/shapes + + return True + + def rewrite(self, op, query, key_transposed, value, mask, **_): + return op.SDPA(query, key_transposed, value, mask, _domain="local") + +masked_pre_mul_sdpa_rule = SDPA.rule("masked_pre_mul_sdpa", use_mask=True, pre_scale=True) + +sdpa_rules = pattern.RewriteRuleSet([rule, rule2, rule3]) From a20b903088cd9e3493a3cf1ec50c3ede6de48a0b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 18 Dec 2024 15:07:47 -0800 Subject: [PATCH 02/17] Add rotary embedding --- onnxscript/rewriter/_ir_utils.py | 17 +++++- .../rewriter/onnxruntime/xformers/mha.py | 59 ++++++++++++------- .../onnxruntime/xformers/rms_normalization.py | 8 +-- .../onnxruntime/xformers/rotary_embedding.py | 49 +++++++++++++++ .../xformers/rotary_embedding_test.py | 24 ++++++++ .../rewriter/onnxruntime/xformers/sdpa.py | 23 +++++--- onnxscript/rewriter/pattern.py | 6 +- 7 files changed, 145 insertions(+), 41 deletions(-) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py create mode 100644 onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 002d0efc5..25f5bacc6 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -3,6 +3,8 @@ from __future__ import annotations import math +from typing import Callable + import numpy as np import onnxscript.ir as ir @@ -79,7 +81,18 @@ def get_singleton_value(val: ir.Value | None): return np_val.item() return None -def is_singleton_value(val: ir.Value | None, expected_value: float, *, rtol: float) -> bool: + +def is_singleton_value( + val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None +) -> bool: """Returns True if the value is a single element tensor with given value, and False otherwise.""" scalar = get_singleton_value(val) - return scalar is not None and math.isclose(scalar, expected_value, rtol=rtol) + if scalar is None: + return False + if isinstance(expected, Callable): + return expected(scalar) + if isinstance(expected, int): + return expected == scalar + # rtol must be specified for float comparison + assert rtol is not None + return math.isclose(scalar, expected, rtol=rtol) diff --git a/onnxscript/rewriter/onnxruntime/xformers/mha.py b/onnxscript/rewriter/onnxruntime/xformers/mha.py index 8870913f2..9c31022d1 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/mha.py +++ b/onnxscript/rewriter/onnxruntime/xformers/mha.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import Iterable + import onnxscript.ir as ir from onnxscript.rewriter import pattern @@ -39,7 +40,12 @@ def _project_transpose_head(op, input, weight, reshape_var: str): # Reshape into 3D tensor (B, S, D) # reshaped_3d = 
op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) # Reshape from (B, S, D) to (B, S, H, D/H) - reshaped = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True, _outputs=[reshape_var]) + reshaped = op.Reshape( + projected, + _allow_other_inputs=True, + _allow_other_attributes=True, + _outputs=[reshape_var], + ) # Transpose from (B, S, H, D/H) to (B, H, S, D/H) transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) return transposed @@ -53,7 +59,9 @@ def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_wei # Transpose last two axes of key_rope to compute dot-product via matmul. key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]) key_reshaped_transposed = op.Transpose(key_reshaped) - key_transposed = op.Reshape(key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"]) + key_transposed = op.Reshape( + key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"] + ) value = _project_transpose_head(op, input, value_weight, "value_mm_reshaped") attention = op.SDPA( query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" @@ -61,9 +69,12 @@ def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_wei # Transpose back to (B, S, H, D/H) attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) # Reshape back to (B, S, D) - attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"]) + attention_reshaped = op.Reshape( + attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"] + ) return attention_reshaped, key_rope, value + def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool: if val.shape is None: return False @@ -76,15 +87,25 @@ def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) return False return True -def _mha_validation(op, query_mm_reshaped, key_mm_reshaped, value_mm_reshaped, key_reshaped, key_transposed, attention_reshaped, **_): - bindings : dict[str, int] = {} + +def _mha_validation( + op, + query_mm_reshaped, + key_mm_reshaped, + value_mm_reshaped, + key_reshaped, + key_transposed, + attention_reshaped, + **_, +): + bindings: dict[str, int] = {} check = ( - _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) and - _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) and - _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) and - _check_shape(bindings, key_reshaped, ["B*H", "S", "d_h"]) and - _check_shape(bindings, key_transposed, ["B", "H", "d_h", "S"]) and - _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) + _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) + and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) + and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) + and _check_shape(bindings, key_reshaped, ["B*H", "S", "d_h"]) + and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "S"]) + and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) ) if not check: return False @@ -94,6 +115,7 @@ def _mha_validation(op, query_mm_reshaped, key_mm_reshaped, value_mm_reshaped, k return False return True + def _multi_head_attention_pattern2( op, input, query_weight, key_weight, value_weight, cos, sin ): @@ -117,23 +139,16 @@ def _multi_head_attention_pattern2( return attention_reshaped, key_rope, value -def _multi_head_attention( - op, - 
input, - query_weight, - key_weight, - value_weight, - cos, - sin, - **_ -): +def _multi_head_attention(op, input, query_weight, key_weight, value_weight, cos, sin, **_): # TODO: other checks and concatenation of weights return op.MultiHeadAttention( input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3 ) -_rule1 = pattern.RewriteRule(_multi_head_attention_pattern, _multi_head_attention, _mha_validation) +_rule1 = pattern.RewriteRule( + _multi_head_attention_pattern, _multi_head_attention, _mha_validation +) # TODO: _rule2 validation conditions # _rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index 1f7a96df1..1e348acfb 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool): cast_input: Whether to cast input to do the normalization in a different precision. cast_normalized: Whether to cast the normalized output to the target dtype (same as scale). """ - self._name = name + super().__init__(name=name) self._cast_input = cast_input self._cast_normalized = cast_normalized - @property - def name(self): - return self._name - def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): if self._cast_input: x = op.Cast(x, to=compute_dtype) @@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): def fuse_rms_normalization(model: ir.Model) -> None: - count = rms_normalization_ruleset.apply_to_model(model, verbose=5) + count = rms_normalization_ruleset.apply_to_model(model) print(f"RMS Normalization count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py new file mode 100644 index 000000000..3ed2a2450 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter import _ir_utils, pattern + +def _rotate_half_pattern(op, x, start1, end1, start2, end2): + # Slice(input, starts, ends, axes, steps) + x1 = op.Slice(x, start1, end1, [3], [1]) + x2 = op.Slice(x, start2, end2, [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + return rotated_x + +class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + + def pattern(self, op, x, cos, sin, start1, end1, start2, end2): + return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + + def check(self, op, x, start1, end1, start2, end2, **_): + # x needs to be a 4D tensor with known last dimension size (== head_size) + if x is None or x.shape is None or len(x.shape) != 4: + return False + head_size = x.shape[3] + if not isinstance(head_size, int): + return False + half_head_size = head_size // 2 + + # Check that x is being split into two equal halves of size half_head_size + return ( + _ir_utils.is_singleton_value(start1, 0) and + _ir_utils.is_singleton_value(end1, half_head_size) and + _ir_utils.is_singleton_value(start2, half_head_size) and + _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) + ) + + def rewrite(self, op, x, cos, sin, **_): + return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion") + + + +_rule = RotaryEmbeddingFusion.rule() + +rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) + +def fuse_rotary_embedding(model: ir.Model) -> None: + count = rotary_embedding_rules.apply_to_model(model) + print(f"Rotary Embedding count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py new file mode 100644 index 000000000..f16f092be --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding + + +class TestRotaryEmbedding(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/sdpa.py b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py index 76791086f..4ef8d7a3d 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/sdpa.py +++ b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py @@ -6,13 +6,16 @@ from onnxscript.rewriter import _ir_utils, pattern + class SDPA(pattern.RewriteRuleClassBase): def __init__(self, name: str, *, use_mask: bool, pre_scale: bool): - self._name = name + super().__init__(name=name) self._use_mask = use_mask self._pre_scale = pre_scale - - def pattern(self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale): + + def pattern( + self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale + ): if self.pre_scale: # Some implementations scale the query and key before computing the dot product query = op.Mul(query, query_scale) @@ -35,7 +38,7 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, if query is None or query.shape is None or len(query.shape) < 2: return False hidden_size = query.shape[-1] - if not isinstance(hidden_size, int) : + if not isinstance(hidden_size, int): return False expected_scaling_factor = math.sqrt(hidden_size) @@ -52,12 +55,18 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, return False # check ranks/shapes - + return True def rewrite(self, op, query, key_transposed, value, mask, **_): - return op.SDPA(query, key_transposed, value, mask, _domain="local") + return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion") + masked_pre_mul_sdpa_rule = SDPA.rule("masked_pre_mul_sdpa", use_mask=True, pre_scale=True) -sdpa_rules = pattern.RewriteRuleSet([rule, rule2, rule3]) +sdpa_rules = pattern.RewriteRuleSet([masked_pre_mul_sdpa_rule]) + + +def fuse_sdpa(model: ir.Model) -> None: + count = sdpa_rules.apply_to_model(model) + print(f"SDPA count: {count}") diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index b9d5d002a..617341707 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1443,10 +1443,8 @@ def rule(cls, *args, **kwargs): instance.pattern, instance.rewrite, instance.check, name=instance.name ) - @property - def name(self): - """Default implementation of name property.""" - return self.__class__.__name__ + def __init__(self, name: str | None = None) -> None: + self.name = name or self.__class__.__name__ def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") From b8f7a08dbfa884bbfaab64aeb9d5f48b5e9e405d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 18 Dec 2024 15:08:57 -0800 Subject: [PATCH 03/17] Remove SDPA --- .../rewriter/onnxruntime/xformers/sdpa.py | 72 ------------------- 1 file changed, 72 deletions(-) delete mode 100644 onnxscript/rewriter/onnxruntime/xformers/sdpa.py diff 
--git a/onnxscript/rewriter/onnxruntime/xformers/sdpa.py b/onnxscript/rewriter/onnxruntime/xformers/sdpa.py deleted file mode 100644 index 4ef8d7a3d..000000000 --- a/onnxscript/rewriter/onnxruntime/xformers/sdpa.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import math - -from onnxscript.rewriter import _ir_utils, pattern - - -class SDPA(pattern.RewriteRuleClassBase): - def __init__(self, name: str, *, use_mask: bool, pre_scale: bool): - super().__init__(name=name) - self._use_mask = use_mask - self._pre_scale = pre_scale - - def pattern( - self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale - ): - if self.pre_scale: - # Some implementations scale the query and key before computing the dot product - query = op.Mul(query, query_scale) - key_transposed = op.Mul(key_transposed, key_scale) - attn_score = op.MatMul(query, key_transposed) - if not self.pre_scale: - # Some implementations scale the dot product. - attn_score = op.Div(attn_score, qk_scale) - if self.use_mask: - # Some implementations add a mask to the dot product. - attn_score = op.Add(attn_score, mask) - attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) - return attn_output - - def check(self, op, query, key_transposed, value, mask, query_scale, key_scale, qk_scale): - # Check that the scaling factors match what SDPA implements: - - # We need to know the hidden size to check the scaling factors. - if query is None or query.shape is None or len(query.shape) < 2: - return False - hidden_size = query.shape[-1] - if not isinstance(hidden_size, int): - return False - expected_scaling_factor = math.sqrt(hidden_size) - - if self.pre_scale: - # Check if query_scale and key_scale are scalars == 1/sqrt(sqrt(hidden_size)) - sqrt_scaling_factor = 1.0 / math.sqrt(expected_scaling_factor) - if not _ir_utils.is_singleton_value(query_scale, sqrt_scaling_factor, rtol=1e-3): - return False - if not _ir_utils.is_singleton_value(key_scale, sqrt_scaling_factor, rtol=1e-3): - return False - else: - # Check if qk_scale is a scalar == sqrt(hidden_size) - if not _ir_utils.is_singleton_value(qk_scale, expected_scaling_factor, rtol=1e-3): - return False - - # check ranks/shapes - - return True - - def rewrite(self, op, query, key_transposed, value, mask, **_): - return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion") - - -masked_pre_mul_sdpa_rule = SDPA.rule("masked_pre_mul_sdpa", use_mask=True, pre_scale=True) - -sdpa_rules = pattern.RewriteRuleSet([masked_pre_mul_sdpa_rule]) - - -def fuse_sdpa(model: ir.Model) -> None: - count = sdpa_rules.apply_to_model(model) - print(f"SDPA count: {count}") From 315c94e7e4a74769e2811ea1511a625d21dfa6dc Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 18 Dec 2024 15:13:28 -0800 Subject: [PATCH 04/17] Add comment --- .../onnxruntime/xformers/rotary_embedding.py | 21 +++++++++++++------ .../xformers/rotary_embedding_test.py | 1 - 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 3ed2a2450..83749cb5d 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -5,6 +5,15 @@ import onnxscript.ir as ir from onnxscript.rewriter import _ir_utils, pattern +# Add first version of the 
RotaryEmbeddingFusion rule. This considers only one simple pattern +# for full rotation without interleaving. +# TODO(rama): Add pattern variations to handle other cases. + +# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet, +# so it can't be tested by running against ORT. Unfortunately, this is the new pattern out +# of current version of transformers (not yet supported by ORT). + + def _rotate_half_pattern(op, x, start1, end1, start2, end2): # Slice(input, starts, ends, axes, steps) x1 = op.Slice(x, start1, end1, [3], [1]) @@ -13,8 +22,8 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2): rotated_x = op.Concat(minus_x2, x1, axis=-1) return rotated_x -class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): +class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): def pattern(self, op, x, cos, sin, start1, end1, start2, end2): return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin @@ -29,21 +38,21 @@ def check(self, op, x, start1, end1, start2, end2, **_): # Check that x is being split into two equal halves of size half_head_size return ( - _ir_utils.is_singleton_value(start1, 0) and - _ir_utils.is_singleton_value(end1, half_head_size) and - _ir_utils.is_singleton_value(start2, half_head_size) and - _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.is_singleton_value(end1, half_head_size) + and _ir_utils.is_singleton_value(start2, half_head_size) + and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) ) def rewrite(self, op, x, cos, sin, **_): return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion") - _rule = RotaryEmbeddingFusion.rule() rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) + def fuse_rotary_embedding(model: ir.Model) -> None: count = rotary_embedding_rules.apply_to_model(model) print(f"Rotary Embedding count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py index f16f092be..6f8d37dee 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py @@ -19,6 +19,5 @@ def test_smollm(self): self.assertIn("RotaryEmbedding", op_types) - if __name__ == "__main__": unittest.main() From 2219fd35016be43f3bab005c6587a0d97105fc9e Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 18 Dec 2024 15:14:46 -0800 Subject: [PATCH 05/17] Remove MHA --- .../rewriter/onnxruntime/xformers/mha.py | 156 ------------------ 1 file changed, 156 deletions(-) delete mode 100644 onnxscript/rewriter/onnxruntime/xformers/mha.py diff --git a/onnxscript/rewriter/onnxruntime/xformers/mha.py b/onnxscript/rewriter/onnxruntime/xformers/mha.py deleted file mode 100644 index 9c31022d1..000000000 --- a/onnxscript/rewriter/onnxruntime/xformers/mha.py +++ /dev/null @@ -1,156 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-from __future__ import annotations - -from typing import Iterable - -import onnxscript.ir as ir -from onnxscript.rewriter import pattern - -""" -The MultiHeadAttention pattern: - -B: Batch size -S: Sequence length -D: input embedding dimension -H: number of heads -d_h: head size (usually, D = H * d_h) - -thus, weights are usually of shape (D, D) and (D, D) and (D, D) - -for each of Q, K, and V, we have the following pattern: - MatMul (Input, W), producing output of shape (B, S, D) - Reshape to produce a matrix of shape (B, S, H, d_h) - Transpose middle two axes to produce a matrix of shape (B, H, S, d_h) - -This is followed by a RotaryEmbedding pattern for Q and K - -The last two axes of the key-embedding are then swapped (using a Reshape/Transpose/Reshape sequence) - -The dot-product attention is then computed using SDPA - -Finally, the output is transposed and reshaped back to (B, S, D) shape -""" - - -def _project_transpose_head(op, input, weight, reshape_var: str): - """Applied to each of Q, K, and V.""" - # input_2d = op.Reshape(input, _allow_other_inputs=True, _allow_other_attributes=True) - projected = op.MatMul(input, weight) - # Reshape into 3D tensor (B, S, D) - # reshaped_3d = op.Reshape(projected, _allow_other_inputs=True, _allow_other_attributes=True) - # Reshape from (B, S, D) to (B, S, H, D/H) - reshaped = op.Reshape( - projected, - _allow_other_inputs=True, - _allow_other_attributes=True, - _outputs=[reshape_var], - ) - # Transpose from (B, S, H, D/H) to (B, H, S, D/H) - transposed = op.Transpose(reshaped, perm=[0, 2, 1, 3]) - return transposed - - -def _multi_head_attention_pattern(op, input, query_weight, key_weight, value_weight, cos, sin): - query = _project_transpose_head(op, input, query_weight, "query_mm_reshaped") - query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") - key = _project_transpose_head(op, input, key_weight, "key_mm_reshaped") - key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") - # Transpose last two axes of key_rope to compute dot-product via matmul. 
- key_reshaped = op.Reshape(key_rope, _allow_other_inputs=True, _outputs=["key_reshaped"]) - key_reshaped_transposed = op.Transpose(key_reshaped) - key_transposed = op.Reshape( - key_reshaped_transposed, _allow_other_inputs=True, _outputs=["key_transposed"] - ) - value = _project_transpose_head(op, input, value_weight, "value_mm_reshaped") - attention = op.SDPA( - query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" - ) - # Transpose back to (B, S, H, D/H) - attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) - # Reshape back to (B, S, D) - attention_reshaped = op.Reshape( - attention_transposed, _allow_other_inputs=True, _outputs=["attention_reshaped"] - ) - return attention_reshaped, key_rope, value - - -def _check_shape(bindings: dict[str, int], val: ir.Value, shape: Iterable[str]) -> bool: - if val.shape is None: - return False - if val.shape.rank() != len(shape): - return False - for actual, expected in zip(val.shape, shape): - if expected not in bindings: - bindings[expected] = actual - elif actual != bindings[expected]: - return False - return True - - -def _mha_validation( - op, - query_mm_reshaped, - key_mm_reshaped, - value_mm_reshaped, - key_reshaped, - key_transposed, - attention_reshaped, - **_, -): - bindings: dict[str, int] = {} - check = ( - _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"]) - and _check_shape(bindings, key_mm_reshaped, ["B", "S", "H", "d_h"]) - and _check_shape(bindings, value_mm_reshaped, ["B", "S", "H", "d_h"]) - and _check_shape(bindings, key_reshaped, ["B*H", "S", "d_h"]) - and _check_shape(bindings, key_transposed, ["B", "H", "d_h", "S"]) - and _check_shape(bindings, attention_reshaped, ["B", "S", "H*d_h"]) - ) - if not check: - return False - if bindings["B"] * bindings["H"] != bindings["B*H"]: - return False - if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]: - return False - return True - - -def _multi_head_attention_pattern2( - op, input, query_weight, key_weight, value_weight, cos, sin -): - """Variation of first pattern with Reshape omitted.""" - query = _project_transpose_head(op, input, query_weight) - query_rope = op.RotaryEmbedding(query, cos, sin, _domain="local") - key = _project_transpose_head(op, input, key_weight) - key_rope = op.RotaryEmbedding(key, cos, sin, _domain="local") - # Transpose last two axes of key_rope to compute dot-product via matmul. - # Reshape omitted here. 
- key_transposed = op.Transpose(key_rope) - # Reshape omitted here - value = _project_transpose_head(op, input, value_weight) - attention = op.SDPA( - query_rope, key_transposed, value, _allow_other_inputs=True, _domain="local" - ) - # Transpose back to (B, S, H, D/H) - attention_transposed = op.Transpose(attention, perm=[0, 2, 1, 3]) - # Reshape back to (B, S, D) - attention_reshaped = op.Reshape(attention_transposed, _allow_other_inputs=True) - return attention_reshaped, key_rope, value - - -def _multi_head_attention(op, input, query_weight, key_weight, value_weight, cos, sin, **_): - # TODO: other checks and concatenation of weights - return op.MultiHeadAttention( - input, query_weight, key_weight, value_weight, cos, sin, _domain="local", _outputs=3 - ) - - -_rule1 = pattern.RewriteRule( - _multi_head_attention_pattern, _multi_head_attention, _mha_validation -) - -# TODO: _rule2 validation conditions -# _rule2 = pattern.RewriteRule(_multi_head_attention_pattern2, _multi_head_attention) - -mha_rules = pattern.RewriteRuleSet([_rule1]) From 5ec9d1e0461a87207ff457fb2ab43705037b8a9c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 19 Dec 2024 21:16:43 -0800 Subject: [PATCH 06/17] Add rewrite for cos-sin computation --- onnxscript/optimizer/_constant_folding.py | 20 ++++++ onnxscript/rewriter/_ir_utils.py | 7 +++ .../onnxruntime/xformers/cos_sin_cache.py | 63 +++++++++++++++++++ 3 files changed, 90 insertions(+) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 4053bb2a1..5f261ce1f 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -305,6 +305,26 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return None return default +@register("Reshape") +def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + input = _get_input(node, 0) + shape = _get_input(node, 1) + if input is None or shape is None: + return None + input_shape = input.shape + if input_shape is None: + return None + input_shape_dims = list(input_shape.dims) + if any(not isinstance(dim, int) for dim in input_shape_dims): + return None + shape_value = _get_numpy_value(shape) + if shape_value is None: + return None + target_shape_dims = shape_value.tolist() + if input_shape_dims == target_shape_dims: + # No need to check for special values like -1, 0, etc. 
here + return op.Identity(input) + return None @register("Cast") def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 25f5bacc6..418ced84c 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -96,3 +96,10 @@ def is_singleton_value( # rtol must be specified for float comparison assert rtol is not None return math.isclose(scalar, expected, rtol=rtol) + +def has_rank(value: ir.Value | None, rank: int) -> bool: + """Returns True if the value is statically known to have the given rank, and False otherwise.""" + if value is None: + return False + shape = value.shape + return (shape is not None) and (shape.rank() == rank) \ No newline at end of file diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py new file mode 100644 index 000000000..c3ab3da32 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np +import onnxscript.ir as ir +from onnxscript.rewriter import _ir_utils, pattern + +# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. + +# Original code (from transformers) for computing cos/sin cache for RoPE: +# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 +# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) +# position_ids_expanded = position_ids[:, None, :].float() +# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) +# emb = torch.cat((freqs, freqs), dim=-1) +# cos = emb.cos() +# sin = emb.sin() + +class CosSinCacheFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, max_pos_id: int): + super().__init__(name) + self._max_pos_id = max_pos_id + + def pattern(self, op, inv_freq, position_ids): + position_ids_expanded = op.Unsqueeze(position_ids, [1]) + position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq, position_ids_expanded) + freqs = op.Transpose(freqs, perm=[0, 2, 1]) + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + return cos, sin + + def check(self, context, inv_freq, position_ids, **_): + if not _ir_utils.has_rank(position_ids, 2): + return False + if not _ir_utils.has_rank(inv_freq, 3): + return False + inv_freq_shape = inv_freq.shape + if inv_freq.const_value is None: + return False + return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 + + def rewrite(self, op, inv_freq, position_ids, **_): + inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) + pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) + angles = np.matmul(pos_id_range, inv_freq_values) + cos_value = np.cos(angles) + sin_value = np.sin(angles) + cos_2d= op.Constant(value=ir.tensor(cos_value)) + cos = op.Gather(cos_2d, position_ids, axis=0) + sin_2d = op.Constant(value=ir.tensor(sin_value)) + sin = op.Gather(sin_2d, position_ids, axis=0) + return cos, sin + +_rule = CosSinCacheFusion.rule("CosSinCache", 2048) + +cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) + +def fuse_cos_sin_cache(model: ir.Model) -> None: + count = cos_sin_cache_rules.apply_to_model(model) + print(f"CosSinCache count: 
{count}") \ No newline at end of file From 1fdc19b2b075b8c37a82e8a06e9ee61f7d105788 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 19 Dec 2024 21:17:59 -0800 Subject: [PATCH 07/17] Run lint --- onnxscript/optimizer/_constant_folding.py | 2 ++ onnxscript/rewriter/_ir_utils.py | 3 ++- .../rewriter/onnxruntime/xformers/cos_sin_cache.py | 10 +++++++--- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 5f261ce1f..1225b08d1 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -305,6 +305,7 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return None return default + @register("Reshape") def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) @@ -326,6 +327,7 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return op.Identity(input) return None + @register("Cast") def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 418ced84c..9bee21642 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -97,9 +97,10 @@ def is_singleton_value( assert rtol is not None return math.isclose(scalar, expected, rtol=rtol) + def has_rank(value: ir.Value | None, rank: int) -> bool: """Returns True if the value is statically known to have the given rank, and False otherwise.""" if value is None: return False shape = value.shape - return (shape is not None) and (shape.rank() == rank) \ No newline at end of file + return (shape is not None) and (shape.rank() == rank) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index c3ab3da32..a0e73730b 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -3,13 +3,14 @@ from __future__ import annotations import numpy as np + import onnxscript.ir as ir from onnxscript.rewriter import _ir_utils, pattern # Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. 
# Original code (from transformers) for computing cos/sin cache for RoPE: -# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 +# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 # inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # position_ids_expanded = position_ids[:, None, :].float() # freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -17,6 +18,7 @@ # cos = emb.cos() # sin = emb.sin() + class CosSinCacheFusion(pattern.RewriteRuleClassBase): def __init__(self, name: str, max_pos_id: int): super().__init__(name) @@ -48,16 +50,18 @@ def rewrite(self, op, inv_freq, position_ids, **_): angles = np.matmul(pos_id_range, inv_freq_values) cos_value = np.cos(angles) sin_value = np.sin(angles) - cos_2d= op.Constant(value=ir.tensor(cos_value)) + cos_2d = op.Constant(value=ir.tensor(cos_value)) cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value)) sin = op.Gather(sin_2d, position_ids, axis=0) return cos, sin + _rule = CosSinCacheFusion.rule("CosSinCache", 2048) cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) + def fuse_cos_sin_cache(model: ir.Model) -> None: count = cos_sin_cache_rules.apply_to_model(model) - print(f"CosSinCache count: {count}") \ No newline at end of file + print(f"CosSinCache count: {count}") From eb916b8809fac3ad85e783faf17cfbc3f87ac503 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Dec 2024 11:16:28 -0800 Subject: [PATCH 08/17] Add cos sin test --- .../onnxruntime/xformers/_test_utils.py | 3 +- .../onnxruntime/xformers/cos_sin_cache.py | 5 +++- .../xformers/cos_sin_cache_test.py | 29 +++++++++++++++++++ .../xformers/rms_normalization_test.py | 9 ------ 4 files changed, 35 insertions(+), 11 deletions(-) create mode 100644 onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py diff --git a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py index 0b4e2c55f..37618522a 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py @@ -23,7 +23,8 @@ def _save(model, modelpath): def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] - with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = r"C:\Users\grama\OneDrive - Microsoft\0L-Torch\model\smollm-1L-debug" + with tempfile.TemporaryDirectory() as temp_dir2: model_path = os.path.join(temp_dir, f"{model_name}.onnx") io.save(model, model_path) # Run model diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index a0e73730b..7ddae004c 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -49,7 +49,9 @@ def rewrite(self, op, inv_freq, position_ids, **_): pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) cos_value = np.cos(angles) + cos_value = np.concatenate([cos_value, cos_value], axis=-1) sin_value = np.sin(angles) + sin_value = np.concatenate([sin_value, sin_value], axis=-1) cos_2d = op.Constant(value=ir.tensor(cos_value)) cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = 
op.Constant(value=ir.tensor(sin_value)) @@ -62,6 +64,7 @@ def rewrite(self, op, inv_freq, position_ids, **_): cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) -def fuse_cos_sin_cache(model: ir.Model) -> None: +def fuse_cos_sin_cache(model: ir.Model) -> int: count = cos_sin_cache_rules.apply_to_model(model) print(f"CosSinCache count: {count}") + return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py new file mode 100644 index 000000000..9a84f45f1 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnx + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run +from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache + + +class TestCosSinCacheTransform(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py index 79a966838..30080474c 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py @@ -4,21 +4,12 @@ import unittest -import onnx - import onnxscript.optimizer from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization -def model_repr(self): - return f"Model({self.graph.name})" - - -onnx.ModelProto.__repr__ = model_repr - - class TestRmsNormalization(unittest.TestCase): def test_smollm(self): smollm_test = _SmollmTestData() From d874dbce0b7375002d5e680bd7f977684071d91d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Dec 2024 13:18:43 -0800 Subject: [PATCH 09/17] Extend rewriter to support node reuse --- .../rewriter/onnxruntime/xformers/__init__.py | 12 +++++++ .../onnxruntime/xformers/_test_utils.py | 5 ++- .../onnxruntime/xformers/cos_sin_cache.py | 17 ++++++---- .../xformers/cos_sin_cache_test.py | 6 ++-- onnxscript/rewriter/pattern.py | 34 ++++++++++++++----- 5 files changed, 52 insertions(+), 22 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py index 44b5591d8..53c8f2701 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -1,3 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from __future__ import annotations + +from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization + +__all__ = [ + "fuse_rms_normalization", + "fuse_normalization", + "fuse_rotary_embedding", + "fuse_cos_sin_cache", +] diff --git a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py index 37618522a..b9ed0aecf 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py @@ -23,10 +23,9 @@ def _save(model, modelpath): def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] - temp_dir = r"C:\Users\grama\OneDrive - Microsoft\0L-Torch\model\smollm-1L-debug" - with tempfile.TemporaryDirectory() as temp_dir2: + with tempfile.TemporaryDirectory() as temp_dir: model_path = os.path.join(temp_dir, f"{model_name}.onnx") - io.save(model, model_path) + _save(model, model_path) # Run model session = onnxruntime.InferenceSession(model_path, providers=providers) ort_outputs = session.run(None, inputs) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index 7ddae004c..5125a359f 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -23,16 +23,19 @@ class CosSinCacheFusion(pattern.RewriteRuleClassBase): def __init__(self, name: str, max_pos_id: int): super().__init__(name) self._max_pos_id = max_pos_id + self.remove_nodes = False - def pattern(self, op, inv_freq, position_ids): - position_ids_expanded = op.Unsqueeze(position_ids, [1]) + def pattern(self, op, x, inv_freq, position_ids): + position_ids_expanded = op.Unsqueeze(position_ids, 1) position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) freqs = op.MatMul(inv_freq, position_ids_expanded) freqs = op.Transpose(freqs, perm=[0, 2, 1]) emb = op.Concat(freqs, freqs, axis=-1) cos = op.Cos(emb) sin = op.Sin(emb) - return cos, sin + cos_4d = op.Unsqueeze(cos, 1) # convert + sin_4d = op.Unsqueeze(sin, 1) + return op.RotaryEmbedding(x, cos_4d, sin_4d, interleaved=0, _domain="ai.onnxruntime.fusion") def check(self, context, inv_freq, position_ids, **_): if not _ir_utils.has_rank(position_ids, 2): @@ -44,7 +47,7 @@ def check(self, context, inv_freq, position_ids, **_): return False return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 - def rewrite(self, op, inv_freq, position_ids, **_): + def rewrite(self, op, x, inv_freq, position_ids, **_): inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) @@ -53,10 +56,10 @@ def rewrite(self, op, inv_freq, position_ids, **_): sin_value = np.sin(angles) sin_value = np.concatenate([sin_value, sin_value], axis=-1) cos_2d = op.Constant(value=ir.tensor(cos_value)) - cos = op.Gather(cos_2d, position_ids, axis=0) + # cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value)) - sin = op.Gather(sin_2d, position_ids, axis=0) - return cos, sin + # sin = op.Gather(sin_2d, position_ids, axis=0) + return op.RotaryEmbedding(x, 
cos_2d, sin_2d, position_ids, interleaved=0, _domain="ai.onnxruntime.fusion") _rule = CosSinCacheFusion.rule("CosSinCache", 2048) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py index 9a84f45f1..dfe6625a8 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py @@ -4,12 +4,10 @@ import unittest -import onnx - import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run -from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache class TestCosSinCacheTransform(unittest.TestCase): @@ -19,6 +17,8 @@ def test_smollm(self): onnxscript.optimizer.optimize(model) inputs = smollm_test.get_ort_inputs() original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) count = fuse_cos_sin_cache(model) self.assertGreater(count, 0) new_outputs = ort_run("optimized", model, inputs) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 617341707..282c7d714 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -946,6 +946,7 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node.""" @@ -1144,6 +1145,7 @@ def _match_single_output_node( model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, + check_removable: bool, ) -> MatchResult: del model del graph_or_function @@ -1162,13 +1164,13 @@ def _match_single_output_node( output_values = self._get_output_values() if output_values is None: return match - if not _valid_to_replace(match.nodes, output_values): + if check_removable and not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) return match - def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: + def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> MatchResult: """Find a match for a pattern with multiple output nodes. For a pattern with K output nodes, the input candidate should specify K nodes @@ -1185,7 +1187,7 @@ def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: if output_values is None: return match - if not _valid_to_replace(match.nodes, output_values): + if check_removable and not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) @@ -1197,6 +1199,7 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node. 
@@ -1216,7 +1219,9 @@ def match( if self.pattern.has_single_output_node: self._init_match(verbose) - return self._match_single_output_node(model, graph_or_function, node) + return self._match_single_output_node( + model, graph_or_function, node, check_removable=remove_nodes + ) else: # Note: This is a potentially expensive algorithm for matching patterns with # multiple output nodes. For patterns with N output nodes, we try all possible @@ -1243,7 +1248,7 @@ def get_nodes(pattern_node): match = None for combination in itertools.product(*candidates): self._init_match(verbose) - match = self._multi_match(combination) + match = self._multi_match(combination, check_removable=remove_nodes) if match: return match if match is None: @@ -1260,6 +1265,7 @@ def __init__( matcher: PatternMatcher | Callable[[GraphPattern], PatternMatcher] | None = None, verbose: int = 0, name: str | None = None, + remove_nodes: bool = True, ) -> None: """Create a rewrite rule. @@ -1275,6 +1281,7 @@ def __init__( If not provided, a default matcher will be used. verbose: The verbosity level of the rule. name: An optional name for the pattern that will show up in verbose logging. + remove_nodes: If True, the matched nodes will be removed from the graph. """ if not isinstance(target_pattern, GraphPattern): @@ -1298,6 +1305,7 @@ def __init__( self._matcher = matcher(self._target_pattern) self._verbose = verbose self.name = name + self.remove_nodes = remove_nodes def __str__(self) -> str: if self.name: @@ -1317,7 +1325,9 @@ def try_rewrite( if verbose and verbose > 2: print(f"[try_rewrite] {self}") verbose = verbose if verbose is not None else self._verbose - match = self._matcher.match(model, graph_or_function, node, verbose=verbose) + match = self._matcher.match( + model, graph_or_function, node, verbose=verbose, remove_nodes=self.remove_nodes + ) if match: context = None # TODO(rama) for var in self._target_pattern.inputs: @@ -1440,17 +1450,23 @@ class RewriteRuleClassBase: def rule(cls, *args, **kwargs): instance = cls(*args, **kwargs) return RewriteRule( - instance.pattern, instance.rewrite, instance.check, name=instance.name + instance.pattern, + instance.rewrite, + instance.check, + name=instance.name, + remove_nodes=instance.remove_nodes, ) def __init__(self, name: str | None = None) -> None: self.name = name or self.__class__.__name__ + self.remove_nodes = True def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") def check(self, op, *args, **kwargs): - raise NotImplementedError("Method 'check' must be implemented by derived class.") + # Default check function that always returns True. 
+ return True def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") @@ -1486,7 +1502,7 @@ def _apply_to_graph_or_function( _convenience.replace_nodes_and_values( graph_or_function, node, - delta.match.nodes, + delta.match.nodes if rule.remove_nodes else [], delta.new_nodes, delta.match.outputs, delta.new_outputs, From a745039f27a7f2e12ffdc6eaf92302673871d444 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 20 Dec 2024 18:36:59 -0800 Subject: [PATCH 10/17] Minor fixes --- .../onnxruntime/xformers/cos_sin_cache.py | 25 ++++++++++++++++--- .../onnxruntime/xformers/rotary_embedding.py | 12 ++++++--- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index 5125a359f..440f4a111 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -5,6 +5,7 @@ import numpy as np import onnxscript.ir as ir +from onnxscript.optimizer import remove_unused_nodes from onnxscript.rewriter import _ir_utils, pattern # Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. @@ -25,7 +26,7 @@ def __init__(self, name: str, max_pos_id: int): self._max_pos_id = max_pos_id self.remove_nodes = False - def pattern(self, op, x, inv_freq, position_ids): + def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): position_ids_expanded = op.Unsqueeze(position_ids, 1) position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) freqs = op.MatMul(inv_freq, position_ids_expanded) @@ -35,7 +36,14 @@ def pattern(self, op, x, inv_freq, position_ids): sin = op.Sin(emb) cos_4d = op.Unsqueeze(cos, 1) # convert sin_4d = op.Unsqueeze(sin, 1) - return op.RotaryEmbedding(x, cos_4d, sin_4d, interleaved=0, _domain="ai.onnxruntime.fusion") + return op.RotaryEmbedding( + x, + cos_4d, + sin_4d, + interleaved=interleaved, + num_heads=num_heads, + _domain="ai.onnxruntime.fusion", + ) def check(self, context, inv_freq, position_ids, **_): if not _ir_utils.has_rank(position_ids, 2): @@ -47,7 +55,7 @@ def check(self, context, inv_freq, position_ids, **_): return False return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 - def rewrite(self, op, x, inv_freq, position_ids, **_): + def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) @@ -59,7 +67,15 @@ def rewrite(self, op, x, inv_freq, position_ids, **_): # cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value)) # sin = op.Gather(sin_2d, position_ids, axis=0) - return op.RotaryEmbedding(x, cos_2d, sin_2d, position_ids, interleaved=0, _domain="ai.onnxruntime.fusion") + return op.RotaryEmbedding( + x, + position_ids, + cos_2d, + sin_2d, + interleaved=interleaved, + num_heads=num_heads, + _domain="com.microsoft", + ) _rule = CosSinCacheFusion.rule("CosSinCache", 2048) @@ -70,4 +86,5 @@ def rewrite(self, op, x, inv_freq, position_ids, **_): def fuse_cos_sin_cache(model: ir.Model) -> int: count = cos_sin_cache_rules.apply_to_model(model) print(f"CosSinCache count: {count}") + remove_unused_nodes(model) return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py 
b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 83749cb5d..22e6bfeee 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -28,9 +28,11 @@ def pattern(self, op, x, cos, sin, start1, end1, start2, end2): return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin def check(self, op, x, start1, end1, start2, end2, **_): - # x needs to be a 4D tensor with known last dimension size (== head_size) + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: return False + if not isinstance(x.shape[1], int): + return False head_size = x.shape[3] if not isinstance(head_size, int): return False @@ -45,7 +47,10 @@ def check(self, op, x, start1, end1, start2, end2, **_): ) def rewrite(self, op, x, cos, sin, **_): - return op.RotaryEmbedding(x, cos, sin, interleaved=0, _domain="ai.onnxruntime.fusion") + num_heads = x.shape[1] + return op.RotaryEmbedding( + x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion" + ) _rule = RotaryEmbeddingFusion.rule() @@ -53,6 +58,7 @@ def rewrite(self, op, x, cos, sin, **_): rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) -def fuse_rotary_embedding(model: ir.Model) -> None: +def fuse_rotary_embedding(model: ir.Model) -> int: count = rotary_embedding_rules.apply_to_model(model) print(f"Rotary Embedding count: {count}") + return count From 17c06c395d266bee562bf5a265b04b34bdd62330 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 21 Dec 2024 22:17:57 -0800 Subject: [PATCH 11/17] Fix concat bug in rotary embedding --- onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index 440f4a111..538070feb 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -60,9 +60,9 @@ def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) cos_value = np.cos(angles) - cos_value = np.concatenate([cos_value, cos_value], axis=-1) + # cos_value = np.concatenate([cos_value, cos_value], axis=-1) sin_value = np.sin(angles) - sin_value = np.concatenate([sin_value, sin_value], axis=-1) + # sin_value = np.concatenate([sin_value, sin_value], axis=-1) cos_2d = op.Constant(value=ir.tensor(cos_value)) # cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value)) From c7c7c79f2d76689e9d608f28467dadc6d749f4ba Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 22 Dec 2024 21:52:13 -0800 Subject: [PATCH 12/17] Minor cleanup --- onnxscript/optimizer/_constant_folding.py | 1 + .../onnxruntime/xformers/cos_sin_cache.py | 26 ++++++++++++++----- .../onnxruntime/xformers/rotary_embedding.py | 8 +++--- onnxscript/rewriter/pattern.py | 4 +-- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 1225b08d1..d6051ddee 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -308,6 +308,7 @@ def _get_int_attribute(node: ir.Node, 
name: str, default: int | None = None) -> @register("Reshape") def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Reshape node by Identity when applicable.""" input = _get_input(node, 0) shape = _get_input(node, 1) if input is None or shape is None: diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index 538070feb..b8ebd7841 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -10,21 +10,37 @@ # Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. +# We match against the following code pattern: # Original code (from transformers) for computing cos/sin cache for RoPE: # https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 -# inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # position_ids_expanded = position_ids[:, None, :].float() # freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # emb = torch.cat((freqs, freqs), dim=-1) # cos = emb.cos() # sin = emb.sin() +# +# We rewrite this pattern into the following form: +# inv_freq_values = inv_freq_expanded.reshape(1, -1) +# pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1) +# angles = np.matmul(pos_id_range, inv_freq_values) +# cos_value = np.cos(angles) +# sin_value = np.sin(angles) +# cos_2d = op.Constant(value=ir.tensor(cos_value)) +# sin_2d = op.Constant(value=ir.tensor(sin_value)) +# +# This produces cos/sin values in a form that can be used by ORT's custom ops. + +# TODO: To apply the pattern-rewrite, we need to know the maximum position id. +# Need to find a way to get this information from the model or its config. class CosSinCacheFusion(pattern.RewriteRuleClassBase): def __init__(self, name: str, max_pos_id: int): - super().__init__(name) + # This pattern makes use of shared Cos/Sin values. So, we can't remove the + # matched nodes as part of the rewrite-step. We apply a separate final + # pass to remove unused nodes. + super().__init__(name, remove_nodes=False) self._max_pos_id = max_pos_id - self.remove_nodes = False def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): position_ids_expanded = op.Unsqueeze(position_ids, 1) @@ -60,13 +76,9 @@ def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) angles = np.matmul(pos_id_range, inv_freq_values) cos_value = np.cos(angles) - # cos_value = np.concatenate([cos_value, cos_value], axis=-1) sin_value = np.sin(angles) - # sin_value = np.concatenate([sin_value, sin_value], axis=-1) cos_2d = op.Constant(value=ir.tensor(cos_value)) - # cos = op.Gather(cos_2d, position_ids, axis=0) sin_2d = op.Constant(value=ir.tensor(sin_value)) - # sin = op.Gather(sin_2d, position_ids, axis=0) return op.RotaryEmbedding( x, position_ids, diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py index 22e6bfeee..b36cf2c9b 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -7,11 +7,11 @@ # Add first version of the RotaryEmbeddingFusion rule. 
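# (The rotate-half form splits the last axis of x into two halves (x1, x2) and computes x*cos + rotate_half(x)*sin = (x1*cos - x2*sin, x2*cos + x1*sin), which is the expression the pattern below matches.)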
This considers only one simple pattern # for full rotation without interleaving. -# TODO(rama): Add pattern variations to handle other cases. +# TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation). -# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet, -# so it can't be tested by running against ORT. Unfortunately, this is the new pattern out -# of current version of transformers (not yet supported by ORT). +# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet. +# so it can't be tested by running against ORT. See cos_sin_cache.py for a transformation that +# rewrites the pattern into one that can be run against ORT. def _rotate_half_pattern(op, x, start1, end1, start2, end2): diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 282c7d714..eec45ab02 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1457,9 +1457,9 @@ def rule(cls, *args, **kwargs): remove_nodes=instance.remove_nodes, ) - def __init__(self, name: str | None = None) -> None: + def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None: self.name = name or self.__class__.__name__ - self.remove_nodes = True + self.remove_nodes = remove_nodes def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") From 9a4a58e2a9d36a125ddfd8f7071cfc59bda8b583 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 23 Dec 2024 10:08:23 -0800 Subject: [PATCH 13/17] Use callable to test callable --- onnxscript/rewriter/_ir_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 9bee21642..cefbec823 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -89,7 +89,7 @@ def is_singleton_value( scalar = get_singleton_value(val) if scalar is None: return False - if isinstance(expected, Callable): + if callable(expected): return expected(scalar) if isinstance(expected, int): return expected == scalar From 766791dbf117d3538eab3ba1583473f8811a26f6 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 23 Dec 2024 11:10:11 -0800 Subject: [PATCH 14/17] Fix lint issues --- onnxscript/rewriter/_ir_utils.py | 2 +- onnxscript/rewriter/pattern.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index cefbec823..1d657a5ab 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -95,7 +95,7 @@ def is_singleton_value( return expected == scalar # rtol must be specified for float comparison assert rtol is not None - return math.isclose(scalar, expected, rtol=rtol) + return math.isclose(scalar, expected, rel_tol=rtol) def has_rank(value: ir.Value | None, rank: int) -> bool: diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index eec45ab02..fa43c19d5 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1178,6 +1178,8 @@ def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> M Args: candidate: An iterable of nodes that will be matched against the pattern output nodes. + check_removable: If True, check that the matched nodes can be removed (that is, that + they are not used elsewhere in the graph). 
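+ When False, a match is reported even if the matched nodes have other uses; the rewrite step is then expected to leave those nodes in place (this is what rules created with remove_nodes=False rely on).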
""" match = self._match for pattern_node, node in zip(self.pattern.output_nodes, candidate): From 2b5309a9e6b93f01f8a4720ea53b44c298197a6c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 27 Dec 2024 16:57:50 -0800 Subject: [PATCH 15/17] Update generic matcher for new parameter --- onnxscript/rewriter/generic_pattern.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 2926f5964..de06d7a22 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -551,7 +551,12 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> orp.MatchResult | None: + if not remove_nodes: + raise NotImplementedError( + "remove_nodes=False is not implemented in GenericPatternMatcher" + ) del model del graph_or_function self.verbose = verbose From ed781dfd479f97a35fd68379d2cfe89189704e84 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 7 Jan 2025 10:43:33 -0800 Subject: [PATCH 16/17] Update onnxscript/rewriter/onnxruntime/xformers/__init__.py Co-authored-by: Justin Chu --- onnxscript/rewriter/onnxruntime/xformers/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py index 53c8f2701..43cec1352 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -2,14 +2,14 @@ # Licensed under the MIT License. from __future__ import annotations -from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache -from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization -from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding -from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization - __all__ = [ "fuse_rms_normalization", "fuse_normalization", "fuse_rotary_embedding", "fuse_cos_sin_cache", ] + +from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization From f346d0b2480049ac673b2fa26c861be887b8e7fd Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Tue, 7 Jan 2025 10:43:52 -0800 Subject: [PATCH 17/17] Update onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py Co-authored-by: Justin Chu --- onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py index b8ebd7841..46272ccf9 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -61,7 +61,7 @@ def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): _domain="ai.onnxruntime.fusion", ) - def check(self, context, inv_freq, position_ids, **_): + def check(self, context, inv_freq, position_ids, **_) -> bool: if not _ir_utils.has_rank(position_ids, 2): return False if not _ir_utils.has_rank(inv_freq, 3):