From c726a9bf545f7721f7861aacda373775c1caa4c5 Mon Sep 17 00:00:00 2001
From: pytorchbot
Date: Wed, 27 Nov 2024 12:06:49 -0800
Subject: [PATCH] Implement get_freqs for RopeWithAttentionSink

This PR implements the `get_freqs` function for `RopeWithAttentionSink`. It
returns the `freqs_cos` and `freqs_sin` for the given `input_pos` and `seq_len`
after shifting tokens in the pre-computed `freqs_cos` and `freqs_sin`.

Differential Revision: [D66525306](https://our.internmc.facebook.com/intern/diff/D66525306/)

ghstack-source-id: 255582545
Pull Request resolved: https://github.com/pytorch/executorch/pull/7100

Co-authored-by: Lunwen He
---
 .../source_transformation/attention_sink.py  | 29 ++++++++++-
 .../test_attention_sink.py                   | 51 ++++++++++++++++++-
 2 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 94f5b47871..8f4fd1ebd2 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -7,6 +7,8 @@
 # Components for supporting Attention Sink. See
 # https://arxiv.org/abs/2309.17453 for more details about Attention Sink.
 
+from typing import Optional
+
 import torch
 
 from executorch.examples.models.llama.llama_transformer import ModelArgs, Rope
@@ -23,12 +25,37 @@ class RopeWithAttentionSink(Rope):
     in KVCache instead of positions in the actual text.
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(
+        self,
+        params: ModelArgs,
+        window_size: int,
+        sink_size: int,
+        eviction_batch_size: int,
+    ):
         super().__init__(params)
         if self.params.use_hf_rope:
             self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
+        self.max_seq_length = window_size + sink_size
+        assert self.max_seq_length == self.params.max_seq_len
+        self.eviction_batch_size = eviction_batch_size
+        self.position_shift = 0
+
+    def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
+        assert input_pos is not None
+
+        input_pos_item = input_pos.item()
+        torch._check_is_size(input_pos_item)
+        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+            # There are not enough spaces in the cache to store the new tokens.
+            # We need to evict some old tokens and shift some recent tokens.
+            num_to_evict = max(
+                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                self.eviction_batch_size,
+            )
+            self.position_shift -= num_to_evict  # pyre-ignore [8]
+        return super().get_freqs(input_pos + self.position_shift, seq_len)
 
     def rerotate_k(
         self,
diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py
index adb3bff3a5..8eaa992dc3 100644
--- a/examples/models/llama/source_transformation/test_attention_sink.py
+++ b/examples/models/llama/source_transformation/test_attention_sink.py
@@ -17,10 +17,57 @@


 class RopeWithAttentionSinkTest(unittest.TestCase):
+    def _init_rope(self, params: ModelArgs, eviction_batch_size: int):
+        return RopeWithAttentionSink(
+            params=params,
+            window_size=252,
+            sink_size=4,
+            eviction_batch_size=eviction_batch_size,
+        )
+
     def setUp(self):
         torch.manual_seed(42)
-        self.params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True)
-        self.rope_with_attention_sink = RopeWithAttentionSink(params=self.params)
+        self.params = ModelArgs(
+            use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256
+        )
+        self.rope_with_attention_sink = self._init_rope(
+            params=self.params, eviction_batch_size=1
+        )
+
+    @parameterized.expand(
+        [
+            [0, 10, 1, 0],  # No shift
+            [250, 10, 1, 246],  # Some shift
+            [256, 10, 1, 246],  # All shift
+            [0, 10, 30, 0],  # No shift with batch eviction
+            [250, 10, 30, 220],  # Some shift with batch eviction
+            [256, 10, 30, 226],  # All shift with batch eviction
+        ]
+    )
+    def test_get_freqs(
+        self, input_pos, seq_len, eviction_batch_size, expected_result_pos
+    ):
+        self.rope_with_attention_sink = self._init_rope(
+            params=self.params, eviction_batch_size=eviction_batch_size
+        )
+
+        freqs_cos, freqs_sin = self.rope_with_attention_sink.get_freqs(
+            input_pos=torch.tensor([input_pos], dtype=torch.int32),
+            seq_len=seq_len,
+        )
+
+        torch.testing.assert_close(
+            freqs_cos,
+            self.rope_with_attention_sink.freqs_cos.narrow(
+                0, expected_result_pos, seq_len
+            ),
+        )
+        torch.testing.assert_close(
+            freqs_sin,
+            self.rope_with_attention_sink.freqs_sin.narrow(
+                0, expected_result_pos, seq_len
+            ),
+        )
 
     @parameterized.expand(
         [
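As a quick illustration of the new behavior, here is a minimal usage sketch distilled from the unit test above. The import path for `RopeWithAttentionSink` is assumed from the file location edited by this patch rather than spelled out in the diff; the constructor arguments and expected offsets mirror the test setup.

```python
# Minimal sketch of the new get_freqs behavior, based on the unit test above.
# Assumption: RopeWithAttentionSink is importable from the module this patch edits.
import torch

from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.source_transformation.attention_sink import (
    RopeWithAttentionSink,
)

params = ModelArgs(use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256)
# window_size + sink_size must equal max_seq_len (asserted in __init__).
rope = RopeWithAttentionSink(
    params=params, window_size=252, sink_size=4, eviction_batch_size=1
)

# While the KV cache still has room, the returned slice starts at input_pos itself.
freqs_cos, freqs_sin = rope.get_freqs(
    input_pos=torch.tensor([0], dtype=torch.int32), seq_len=10
)  # rows 0..9 of the precomputed freqs_cos / freqs_sin tables

# Once input_pos + seq_len would overflow max_seq_len, position_shift moves the
# effective position back, so the slice is taken from an earlier offset
# (246 here, matching the [250, 10, 1, 246] test case).
freqs_cos, freqs_sin = rope.get_freqs(
    input_pos=torch.tensor([250], dtype=torch.int32), seq_len=10
)
```

With `eviction_batch_size > 1`, the shift happens in larger steps, which is what the batch-eviction test cases cover.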