diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index cf387bfab2..284520d4d5 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -93,6 +93,7 @@ runtime.python_library(
         "source_transformation/sdpa.py",
         "source_transformation/spin_quant.py",
         "source_transformation/vulkan_rope.py",
+        "source_transformation/attention_sink.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama",
@@ -213,3 +214,16 @@ runtime.python_test(
         "//executorch/examples/models/llama:llama_transformer",
     ],
 )
+
+runtime.python_test(
+    name = "attention_sink_test",
+    srcs = [
+        "source_transformation/test_attention_sink.py",
+    ],
+    supports_static_listing = False,
+    deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
+        "//caffe2:torch",
+        ":export_library",
+    ],
+)
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 0383c79898..1445787f5e 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -92,6 +92,22 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
+def apply_rotary_emb_to_k(
+    xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+    freqs_cos = reshape_for_broadcast(freqs_cos, xk_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xk_r)
+
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
+    return xk_out.type_as(xk)
+
+
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -160,3 +176,28 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
+
+
+def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the key tensor.
+
+    Args:
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if
+            k has the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shape of k. Similarly, if k has
+            the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
+    Returns:
+        `torch.Tensor` the key tensor rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return k_embed
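
Note: apply_rotary_emb_to_k mirrors apply_rotary_emb but rotates only the key tensor, treating head_dim as interleaved (real, imag) pairs; hf_apply_rotary_emb_to_k is the rotate_half variant, selected below when use_hf_rope is set. A minimal shape sketch follows; the sizes and the freqs construction are illustrative assumptions, not the module's own precompute helper.

import torch

from executorch.examples.models.llama.rope import apply_rotary_emb_to_k

# Illustrative sizes; any (bsz, seq_len, n_kv_heads, head_dim) works as long as
# freqs_cos/freqs_sin have shape (seq_len, head_dim // 2).
bsz, seq_len, n_kv_heads, head_dim = 1, 8, 4, 64
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
freqs_cos, freqs_sin = torch.cos(angles), torch.sin(angles)

xk = torch.rand(bsz, seq_len, n_kv_heads, head_dim)
xk_rotated = apply_rotary_emb_to_k(xk, freqs_cos, freqs_sin)  # same shape as xk
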
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
new file mode 100644
index 0000000000..94f5b47871
--- /dev/null
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Components for supporting Attention Sink. See
+# https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
+
+import torch
+
+from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
+from executorch.examples.models.llama.rope import (
+    apply_rotary_emb_to_k,
+    hf_apply_rotary_emb_to_k,
+)
+
+
+class RopeWithAttentionSink(Rope):
+    """
+    Rope that helps adjust position encoding when tokens are shifted in KVCache.
+    For AttentionSink, when tokens are shifted in KVCache, we need to use positions
+    in KVCache instead of positions in the actual text.
+    """
+
+    def __init__(self, params: ModelArgs):
+        super().__init__(params)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
+        else:
+            self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
+
+    def rerotate_k(
+        self,
+        k: torch.Tensor,
+        original_position: int,
+        new_position: int,
+    ):
+        """
+        Rerotate k from original_position to new_position. This is done by rotating
+        k by the angle delta with the rotation matrix:
+        (cos(delta), -sin(delta)
+         sin(delta),  cos(delta))
+        where delta = new_position * theta - original_position * theta
+
+        The shape of k is (batch_size, seq_len, n_local_heads, head_dim)
+
+        Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
+        """
+        seq_len = k.shape[1]
+        original_freqs_cos = self.freqs_cos.narrow(0, original_position, seq_len)
+        original_freqs_sin = self.freqs_sin.narrow(0, original_position, seq_len)
+        new_freqs_cos = self.freqs_cos.narrow(0, new_position, seq_len)
+        new_freqs_sin = self.freqs_sin.narrow(0, new_position, seq_len)
+        rerotation_cos = (
+            new_freqs_cos * original_freqs_cos + new_freqs_sin * original_freqs_sin
+        )
+        rerotation_sin = (
+            new_freqs_sin * original_freqs_cos - new_freqs_cos * original_freqs_sin
+        )
+
+        return self.apply_rotary_emb_to_k(k, rerotation_cos, rerotation_sin)
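
Note: rerotation_cos and rerotation_sin above are just the angle-difference identities cos(a - b) = cos(a)cos(b) + sin(a)sin(b) and sin(a - b) = sin(a)cos(b) - cos(a)sin(b), with a = new_position * theta and b = original_position * theta, so applying the original rotation and then this rerotation is the same as rotating for new_position directly. A standalone numerical check of that identity (illustrative only, not part of the diff):

import torch

theta = torch.rand(32, dtype=torch.float64)  # per-channel RoPE frequencies
orig, new = 128.0, 129.0
# Compose cos/sin of the two positions the same way rerotate_k does.
rerotation_cos = torch.cos(new * theta) * torch.cos(orig * theta) + torch.sin(new * theta) * torch.sin(orig * theta)
rerotation_sin = torch.sin(new * theta) * torch.cos(orig * theta) - torch.cos(new * theta) * torch.sin(orig * theta)
# ...and compare against rotating directly by the position delta.
torch.testing.assert_close(rerotation_cos, torch.cos((new - orig) * theta))
torch.testing.assert_close(rerotation_sin, torch.sin((new - orig) * theta))
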
diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py
new file mode 100644
index 0000000000..adb3bff3a5
--- /dev/null
+++ b/examples/models/llama/source_transformation/test_attention_sink.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.examples.models.llama.llama_transformer import ModelArgs
+
+from executorch.examples.models.llama.source_transformation.attention_sink import (
+    RopeWithAttentionSink,
+)
+from parameterized import parameterized
+
+
+class RopeWithAttentionSinkTest(unittest.TestCase):
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
+        self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)
+
+    @parameterized.expand(
+        [
+            [128, 127],  # Rotate left
+            [128, 128],  # No rotation
+            [128, 129],  # Rotate right
+        ]
+    )
+    def test_rotate(self, original_position, new_position):
+        seq_len = 32
+
+        q = torch.rand(
+            1, seq_len, self.params.n_heads, self.params.head_dim, dtype=torch.float32
+        )
+        k = torch.rand(
+            1,
+            seq_len,
+            self.params.n_heads,
+            self.params.head_dim,
+            dtype=torch.float32,
+        )
+        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
+            input_pos=torch.tensor([original_position], dtype=torch.int32),
+            seq_len=seq_len,
+        )
+        _, pre_rotated_k = self.rope_with_attention_sink.forward(
+            q=q,
+            k=k,
+            freqs_cos=freqs_cos,
+            freqs_sin=freqs_sin,
+        )
+
+        rerotated_k = self.rope_with_attention_sink.rerotate_k(
+            k=pre_rotated_k,
+            original_position=original_position,
+            new_position=new_position,
+        )
+
+        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
+            input_pos=torch.tensor([new_position], dtype=torch.int32),
+            seq_len=seq_len,
+        )
+        _, expected_k = self.rope_with_attention_sink.forward(
+            q=q,
+            k=k,
+            freqs_cos=freqs_cos,
+            freqs_sin=freqs_sin,
+        )
+
+        torch.testing.assert_close(rerotated_k, expected_k)
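
Note: end to end, rerotate_k is what lets a sliding KV cache with attention sinks keep its RoPE encoding consistent after eviction: cached keys move toward the sink tokens, so they must be re-rotated from the position at which they were originally encoded to the slot they now occupy. A minimal sketch of that call pattern, assuming a hypothetical eviction step; the sink_size/shift values and the cache layout are illustrative and not part of this diff.

import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
rope = RopeWithAttentionSink(params=params)

# Hypothetical eviction: these keys were rotated for position sink_size + shift,
# but after dropping `shift` old tokens they now sit right after the sink
# tokens, i.e. at position sink_size.
sink_size, shift, seq_len = 4, 8, 16
k = torch.rand(1, seq_len, params.n_heads, params.head_dim, dtype=torch.float32)

k_rerotated = rope.rerotate_k(
    k=k,
    original_position=sink_size + shift,
    new_position=sink_size,
)  # same shape as k, now encoded as if it started at position sink_size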