From 2d93bfae9ed264f8bc509489dc6c6d1882a505a9 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 12 Dec 2024 16:48:09 -0800
Subject: [PATCH 1/3] fix lm_eval issue of llama

Signed-off-by: Wang, Yi A
---
 .../models/llama/modeling_llama.py | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 72ce034ef..01e852d28 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -385,15 +385,18 @@ def forward(
 class LlamaKVCache(KVCache):
     @staticmethod
     def update(prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
+        if inp_seq_len != -1:
+            #reuse cache logic
+            orig_cur = cur
+            if prev.shape == cur.shape:
+                prev.copy_(cur)
+                return orig_cur
+            if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+                # Initialize
+                prev[:, :, :inp_seq_len, :].copy_(cur)
+                return orig_cur
         if idx is not None:
+            # 2+ tokenizer logic if model is static shape optimized
             prev.index_copy_(dim, idx - 1, cur)
             return prev
         else:
@@ -627,9 +630,11 @@ def pre_attn_forward(
                 past_value = torch.zeros(
                     key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
                 )
+                past_key.copy_(key_states)
+                past_value.copy_(value_states)
                 # Return list instead of tuple
                 past_key_value = [past_key, past_value]
-            if (
+            elif (
                 token_idx is not None
                 and num_virtual_tokens is not None
                 and num_virtual_tokens == past_key_value[0].shape[-2]

From 00f269b7e3bfabdab71f6d46c10485061f643c0d Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 12 Dec 2024 17:24:52 -0800
Subject: [PATCH 2/3] fmt

Signed-off-by: Wang, Yi A
---
 optimum/habana/transformers/models/llama/modeling_llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 01e852d28..f27868449 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -33,6 +33,7 @@
 from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module
 from .configuration_llama import LlamaConfig
 
+
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE  # noqa
 
@@ -57,6 +58,7 @@
 
 import habana_frameworks.torch.core as htcore
 
+
 def gaudi_llama_rmsnorm_forward(self, hidden_states):
     """
     Copied from LlamaRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -382,11 +384,12 @@ def forward(
         padding_side,
     )
+
 class LlamaKVCache(KVCache):
     @staticmethod
     def update(prev, cur, dim, idx, inp_seq_len):
         if inp_seq_len != -1:
-            #reuse cache logic
+            # reuse cache logic
             orig_cur = cur
             if prev.shape == cur.shape:
                 prev.copy_(cur)
                 return orig_cur
@@ -402,6 +405,7 @@ def update(prev, cur, dim, idx, inp_seq_len):
         else:
             return torch.cat((prev, cur), dim=dim)
 
+
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
         return fused_scaled_dot_product_attention_distributed

From 65b8a7c4cde78908eff59419263ccbcbdfc55470 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Fri, 13 Dec 2024 02:49:55 -0800
Subject: [PATCH 3/3] fmt

Signed-off-by: Wang, Yi A
---
 optimum/habana/transformers/models/llama/modeling_llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 295c4dc6f..03fe62491 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -385,7 +385,6 @@ def forward(
     )
 
 
-
 class KVCache(torch.nn.Module):
     def __init__(self):
         super(KVCache, self).__init__()
@@ -429,6 +428,7 @@ def get_shape(self):
     def forward(self, cur, dim, idx):
         return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
 
+
 def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
     if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
         return fused_scaled_dot_product_attention_distributed
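
Reviewer note (not part of the patches): the behavioural change in PATCH 1/3 is easiest to see as a standalone function. The sketch below is a minimal, self-contained rendering of the new LlamaKVCache.update dispatch, assuming plain CPU tensors; the function name kv_cache_update and the toy shapes in the driver are illustrative only, not taken from the patch.

import torch


def kv_cache_update(prev, cur, dim, idx, inp_seq_len):
    # Sketch of the dispatch introduced in PATCH 1/3 (LlamaKVCache.update).
    if inp_seq_len != -1:
        # Reuse-cache / prefill path: (re)initialize the preallocated buffer.
        if prev.shape == cur.shape:
            prev.copy_(cur)
            return cur
        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
            prev[:, :, :inp_seq_len, :].copy_(cur)
            return cur
    if idx is not None:
        # Static-shape decode path: write the new token in place at idx - 1.
        prev.index_copy_(dim, idx - 1, cur)
        return prev
    # Dynamic-shape fallback: grow the cache by concatenation.
    return torch.cat((prev, cur), dim=dim)


# Toy shapes (batch=1, heads=2, max_len=8, head_dim=4), purely illustrative.
cache = torch.zeros(1, 2, 8, 4)
prefill = torch.randn(1, 2, 5, 4)
kv_cache_update(cache, prefill, dim=2, idx=torch.tensor([5]), inp_seq_len=5)  # fills slots 0..4
decode = torch.randn(1, 2, 1, 4)
kv_cache_update(cache, decode, dim=2, idx=torch.tensor([6]), inp_seq_len=-1)  # writes slot 5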
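
A second reviewer note, also outside the patches: the pre_attn_forward hunk of PATCH 1/3 copies the current key/value states into the freshly allocated zero-filled buffers, and the if -> elif change keeps the following prefix-tuning branch from also being taken when the cache was just created. The sketch below is a simplified assumption of that allocation path: init_past_key_value is a hypothetical helper, and zeros_like stands in for the real allocation that uses key_states.shape, the k_proj weight dtype, and device.

import torch


def init_past_key_value(key_states, value_states):
    # Sketch of the fixed allocation path in pre_attn_forward (PATCH 1/3).
    past_key = torch.zeros_like(key_states)
    past_value = torch.zeros_like(value_states)
    # Added by the patch: copy the current states into the fresh buffers
    # so the returned cache is not left zero-filled.
    past_key.copy_(key_states)
    past_value.copy_(value_states)
    # Return list instead of tuple, as in the patched code.
    return [past_key, past_value]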