Commit 16a743c

release multi-latent attention
lucidrains committed Feb 4, 2025
1 parent c6edc65 commit 16a743c
Showing 3 changed files with 26 additions and 23 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.44.8',
+  version = '2.0.0',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py (10 changes: 6 additions & 4 deletions)

@@ -651,18 +651,20 @@ def test_hybrid(hybrid_axial_dim):
     mask = torch.randint(0, 2, (2, 1024)).bool()
     embed = enc(x, mask = mask)
 
-def test_latent_q_and_kv():
+def test_multi_latent_attention():
     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
         attn_layers = Decoder(
             dim = 128,
             depth = 6,
             heads = 8,
-            attn_dim_latent_q = 64,
             attn_use_latent_q = True,
-            attn_dim_latent_kv = 64,
-            attn_use_latent_kv = True
+            attn_dim_latent_q = 128,
+            attn_use_latent_kv = True,
+            attn_dim_latent_kv = 128,
+            attn_latent_rope_subheads = 4,
+            rotary_pos_emb = False
         )
     )
 
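As a point of reference (not part of the commit), a minimal sketch of how the model configured in this test might be exercised. The forward call and tensor shapes mirror the other tests in this file; the output shape comment is an assumption based on TransformerWrapper's default behaviour.

# hypothetical usage sketch, assuming the configuration from the new test above
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_use_latent_q = True,
        attn_dim_latent_q = 128,
        attn_use_latent_kv = True,
        attn_dim_latent_kv = 128,
        attn_latent_rope_subheads = 4,
        rotary_pos_emb = False
    )
)

x = torch.randint(0, 20000, (2, 1024))   # batch of random token ids
logits = model(x)                        # expected shape (2, 1024, 20000)
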
x_transformers/x_transformers.py (37 changes: 19 additions & 18 deletions)

@@ -1524,7 +1524,7 @@ def forward(
             latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
 
             if exists(maybe_cached_k_sub_heads):
-                k_sub_heads = cat((k_sub_heads, maybe_cached_k_sub_heads), dim = -2)
+                k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)
 
         if return_intermediates:
             cached_kv = (latent_kv_input, k_sub_heads)
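The one-line change above fixes the concatenation order for the cached rotary sub-head keys: during incremental decoding, previously cached positions must precede the current step along the sequence dimension. A self-contained illustration of that ordering (made-up shapes, not the library's code):

# minimal sketch: cached positions go first when concatenating along the sequence axis
import torch

cached_k = torch.randn(1, 4, 10, 64)   # keys for the 10 positions decoded so far
new_k    = torch.randn(1, 4, 1, 64)    # key for the current position

# prepend the cache so positions stay in chronological order
k = torch.cat((cached_k, new_k), dim = -2)
assert k.shape == (1, 4, 11, 64)
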
@@ -1568,23 +1568,24 @@ def forward(
 
         # take care of caching
 
-        if not is_multi_latent_attn and exists(cache):
-            ck, cv = cache.cached_kv
+        if not is_multi_latent_attn:
+            if exists(cache):
+                ck, cv = cache.cached_kv
 
-            if exists(mem):
-                mk, k = unpack(k, mem_packed_shape, 'b h * d')
-                mv, v = unpack(v, mem_packed_shape, 'b h * d')
+                if exists(mem):
+                    mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                    mv, v = unpack(v, mem_packed_shape, 'b h * d')
 
-            k = cat((ck, k), dim = -2)
-            v = cat((cv, v), dim = -2)
+                k = cat((ck, k), dim = -2)
+                v = cat((cv, v), dim = -2)
 
-            if exists(mem):
-                k = cat((mk, k), dim = -2)
-                v = cat((mv, v), dim = -2)
+                if exists(mem):
+                    k = cat((mk, k), dim = -2)
+                    v = cat((mv, v), dim = -2)
 
-        if not is_multi_latent_attn and return_intermediates:
-            mem_len = mem.shape[-2] if exists(mem) else 0
-            cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
+            if return_intermediates:
+                mem_len = mem.shape[-2] if exists(mem) else 0
+                cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
 
         if exists(rotary_pos_emb):
             rotate_num_heads = self.rotate_num_heads
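The hunk above is a control-flow reshuffle: the plain (non multi-latent) cache handling now lives under a single "if not is_multi_latent_attn:" branch, with the same append-to-cache behaviour as before. A rough sketch of that behaviour with a hypothetical step helper, using torch's built-in attention rather than the library's internals:

# simplified decode-time KV caching, not x-transformers' actual API
import torch
import torch.nn.functional as F

def step(q, k, v, cached_kv = None):
    # append the incoming keys / values onto whatever was cached so far
    if cached_kv is not None:
        ck, cv = cached_kv
        k = torch.cat((ck, k), dim = -2)
        v = torch.cat((cv, v), dim = -2)
    out = F.scaled_dot_product_attention(q, k, v)
    return out, (k, v)            # (k, v) becomes the cache for the next token

# two decoding steps, one token at a time (batch 1, 8 heads, head dim 64)
q1 = k1 = v1 = torch.randn(1, 8, 1, 64)
out, cache = step(q1, k1, v1)

q2, k2, v2 = (torch.randn(1, 8, 1, 64) for _ in range(3))
out, cache = step(q2, k2, v2, cache)
assert cache[0].shape[-2] == 2    # cache now holds keys for both positions
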
@@ -1594,8 +1595,8 @@ def forward(
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
 
             if partial_rotate_heads:
-                q, q_rest = q[:, -rotate_num_heads:], q[:, :-rotate_num_heads]
-                k, k_rest = k[:, -rotate_num_heads:], k[:, :-rotate_num_heads]
+                q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+                k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]
 
             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)
 
@@ -1608,8 +1609,8 @@ def forward(
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
 
             if partial_rotate_heads:
-                q = cat((q, q_rest), dim = 1)
-                k = cat((k, k_rest), dim = 1)
+                q = cat((q_rest, q), dim = 1)
+                k = cat((k_rest, k), dim = 1)
 
         input_mask = context_mask
 
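The last two hunks keep the partial-rotary bookkeeping consistent: only the last rotate_num_heads heads receive rotary embeddings, and after rotation the untouched heads are concatenated back in front so the head order is preserved. A standalone sketch of that split / rotate / recombine pattern, with a simplified rotary helper rather than x-transformers' apply_rotary_pos_emb (shapes and helper are assumptions):

# illustration of applying rotary embeddings to only the last few heads
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rope(t, freqs):
    # standard rotary embedding applied over the feature dimension
    return t * freqs.cos() + rotate_half(t) * freqs.sin()

b, h, n, d = 2, 8, 16, 64
rotate_num_heads = 4                      # only these heads are rotated

q = torch.randn(b, h, n, d)
freqs = torch.randn(n, d)                 # per-position rotation angles

# split off the heads that stay as-is, rotate the rest
q_rest, q_rot = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
q_rot = apply_rope(q_rot, freqs)

# recombine with the untouched heads first, preserving head order
q = torch.cat((q_rest, q_rot), dim = 1)
assert q.shape == (b, h, n, d)
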
