diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 55d4475a8..03fe62491 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -403,15 +403,18 @@ def allocate(self, inp_seq_len, dtype, device, shape):
 
     @staticmethod
     def update(prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
+        if inp_seq_len != -1:
+            # reuse cache logic
+            orig_cur = cur
+            if prev.shape == cur.shape:
+                prev.copy_(cur)
+                return orig_cur
+            if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+                # Initialize
+                prev[:, :, :inp_seq_len, :].copy_(cur)
+                return orig_cur
         if idx is not None:
+            # 2nd+ token (decode) logic when the model is static-shape optimized
             prev.index_copy_(dim, idx - 1, cur)
             return prev
         else:
@@ -654,9 +657,11 @@ def pre_attn_forward(
                     past_value = torch.zeros(
                         key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
                     )
+                    past_key.copy_(key_states)
+                    past_value.copy_(value_states)
                     # Return list instead of tuple
                     past_key_value = [past_key, past_value]
-                if (
+                elif (
                     token_idx is not None
                     and num_virtual_tokens is not None
                     and num_virtual_tokens == past_key_value[0].shape[-2]
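
For context, the sketch below mirrors the new update() control flow from the first hunk so the branch behavior is easier to follow: inp_seq_len != -1 gates the prefill/cache-reuse copy, while the idx branch handles the per-token decode write. The standalone name kv_cache_update, the example shapes, and the torch.cat fallback in the else branch (the hunk above ends at `else:`) are illustrative assumptions, not part of the patch.

import torch


def kv_cache_update(prev, cur, dim, idx, inp_seq_len):
    """Illustrative stand-in for the patched KV-cache update() above."""
    if inp_seq_len != -1:
        # Prefill / cache-reuse path: copy the prompt's keys or values into
        # the front of the pre-allocated cache buffer.
        orig_cur = cur
        if prev.shape == cur.shape:
            prev.copy_(cur)
            return orig_cur
        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
            prev[:, :, :inp_seq_len, :].copy_(cur)
            return orig_cur
    if idx is not None:
        # Decode path (2nd token onward) with static shapes: write the new
        # token slice at position idx - 1 along the sequence dimension.
        prev.index_copy_(dim, idx - 1, cur)
        return prev
    else:
        # Assumed dynamic-shape fallback: grow the cache by concatenation.
        return torch.cat((prev, cur), dim=dim)


# Toy usage (batch=1, heads=2, max_seq_len=8, head_dim=4):
cache = torch.zeros(1, 2, 8, 4)
prompt = torch.randn(1, 2, 5, 4)                                # 5-token prompt
kv_cache_update(cache, prompt, dim=2, idx=None, inp_seq_len=5)  # prefill copy
next_tok = torch.randn(1, 2, 1, 4)                              # one decoded token
kv_cache_update(cache, next_tok, dim=2, idx=torch.tensor([6]), inp_seq_len=-1)

Passing inp_seq_len=-1 in the last call simply exercises the decode branch; in the patch that sentinel is what distinguishes the reuse-cache path from the static-shape index_copy_ path.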