Commit c6edc65

the rotateable subhead keys from MLA need to be cached
lucidrains committed Feb 4, 2025
1 parent 7ad71d3 commit c6edc65
Showing 1 changed file with 14 additions and 7 deletions.
x_transformers/x_transformers.py: 21 changes (14 additions, 7 deletions)
@@ -1504,20 +1504,30 @@ def forward(
         # multi-latent attention logic
         # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
 
+        k_sub_heads = None # the rotateable subheads of keys derived from base sequence
+
         if self.use_latent_q:
             q_input = self.to_latent_q(q_input)
 
         if is_multi_latent_attn:
             assert not qkv_receive_diff_residuals
+            needs_k_sub_heads = exists(self.to_rotateable_k)
 
             latent_kv_input = self.to_latent_kv(k_input)
 
+            if needs_k_sub_heads:
+                rotateable_k = self.to_rotateable_k(k_input)
+                k_sub_heads = self.split_rotateable_k_heads(rotateable_k)
+
             if exists(cache):
-                cached_latent_kv = cache.cached_kv
+                cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
                 latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
 
+                if exists(maybe_cached_k_sub_heads):
+                    k_sub_heads = cat((k_sub_heads, maybe_cached_k_sub_heads), dim = -2)
+
             if return_intermediates:
-                cached_kv = latent_kv_input
+                cached_kv = (latent_kv_input, k_sub_heads)
 
             k_input = v_input = latent_kv_input
 
@@ -1533,11 +1543,8 @@ def forward(
 
         # take care of decoupled rope from multi-latent attention
 
-        if exists(self.to_rotateable_k):
-            rotate_k = self.to_rotateable_k(k_input)
-            rotate_k = self.split_rotateable_k_heads(rotate_k)
-
-            k = cat((k, rotate_k), dim = 1)
+        if exists(k_sub_heads):
+            k = cat((k, k_sub_heads), dim = 1)
 
         # if previous values passed in for residual, either invoke resformer
 
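Below is a minimal, self-contained sketch of the caching pattern this commit addresses; it is not the x-transformers API, and the names MiniMLACache and decode_step are hypothetical. It illustrates why the change is needed: during incremental decoding the decoupled rotateable key sub-heads are derived from the incoming tokens just like the compressed latent KV, so both tensors must be carried in the cache and concatenated along the sequence dimension at every step.

from collections import namedtuple

import torch
from torch import nn, cat

# the two tensors the commit now caches together: the compressed latent for keys / values,
# and the separately projected key sub-heads that receive rotary embeddings
MiniMLACache = namedtuple('MiniMLACache', ['latent_kv', 'k_sub_heads'])

def decode_step(x, to_latent_kv, to_rotateable_k, cache = None):
    # x: (batch, num_new_tokens, dim) - only the tokens produced on this step

    latent_kv = to_latent_kv(x)        # compressed latent from which keys / values are later derived
    k_sub_heads = to_rotateable_k(x)   # decoupled keys for the rotary (positional) branch

    if cache is not None:
        # extend with what earlier steps already computed, along the sequence dimension
        latent_kv = cat((cache.latent_kv, latent_kv), dim = -2)
        k_sub_heads = cat((cache.k_sub_heads, k_sub_heads), dim = -2)

    # ... attention over the full sequence would be computed here ...

    return MiniMLACache(latent_kv, k_sub_heads)

# usage: a 4 token prompt followed by one decoded token; the cache carries both tensors forward

dim = 8
to_latent_kv = nn.Linear(dim, dim)
to_rotateable_k = nn.Linear(dim, dim)

cache = decode_step(torch.randn(1, 4, dim), to_latent_kv, to_rotateable_k)
cache = decode_step(torch.randn(1, 1, dim), to_latent_kv, to_rotateable_k, cache = cache)

assert cache.latent_kv.shape[1] == 5 and cache.k_sub_heads.shape[1] == 5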
