diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index 4489b2688fa6..d8e567767c6c 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -45,14 +45,17 @@ def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, def append_kv_cache(cache_k, cache_v, key_states, value_states): - new_size = (cache_k.size(0), - cache_k.size(1), - cache_k.size(2) + key_states.size(2), - cache_k.size(3)) + size_0, size_1, old_length, size_3 = cache_k.size() + k_size_2 = key_states.size(2) + new_length = old_length + k_size_2 + new_size = (size_0, + size_1, + new_length, + size_3) new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0) - new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states + new_cache_k[:, :, old_length:new_length, :] = key_states new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0) - new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states + new_cache_v[:, :, old_length:new_length, :] = value_states return new_cache_k, new_cache_v