diff --git a/setup.py b/setup.py
index 8ee11f8f..516c7ad2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.44.8',
+  version = '2.0.0',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/tests/test_x_transformers.py b/tests/test_x_transformers.py
index d5b8c3f3..468ff19d 100644
--- a/tests/test_x_transformers.py
+++ b/tests/test_x_transformers.py
@@ -651,7 +651,7 @@ def test_hybrid(hybrid_axial_dim):
     mask = torch.randint(0, 2, (2, 1024)).bool()
     embed = enc(x, mask = mask)

-def test_latent_q_and_kv():
+def test_multi_latent_attention():
     model = TransformerWrapper(
         num_tokens = 20000,
         max_seq_len = 1024,
@@ -659,10 +659,12 @@
             dim = 128,
             depth = 6,
             heads = 8,
-            attn_dim_latent_q = 64,
             attn_use_latent_q = True,
-            attn_dim_latent_kv = 64,
-            attn_use_latent_kv = True
+            attn_dim_latent_q = 128,
+            attn_use_latent_kv = True,
+            attn_dim_latent_kv = 128,
+            attn_latent_rope_subheads = 4,
+            rotary_pos_emb = False
         )
     )

diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 198a92fd..48f5d3ed 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -1524,7 +1524,7 @@ def forward(
                 latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)

                 if exists(maybe_cached_k_sub_heads):
-                    k_sub_heads = cat((k_sub_heads, maybe_cached_k_sub_heads), dim = -2)
+                    k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)

             if return_intermediates:
                 cached_kv = (latent_kv_input, k_sub_heads)
@@ -1568,23 +1568,24 @@

         # take care of caching

-        if not is_multi_latent_attn and exists(cache):
-            ck, cv = cache.cached_kv
+        if not is_multi_latent_attn:
+            if exists(cache):
+                ck, cv = cache.cached_kv

-            if exists(mem):
-                mk, k = unpack(k, mem_packed_shape, 'b h * d')
-                mv, v = unpack(v, mem_packed_shape, 'b h * d')
+                if exists(mem):
+                    mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                    mv, v = unpack(v, mem_packed_shape, 'b h * d')

-            k = cat((ck, k), dim = -2)
-            v = cat((cv, v), dim = -2)
+                k = cat((ck, k), dim = -2)
+                v = cat((cv, v), dim = -2)

-            if exists(mem):
-                k = cat((mk, k), dim = -2)
-                v = cat((mv, v), dim = -2)
+                if exists(mem):
+                    k = cat((mk, k), dim = -2)
+                    v = cat((mv, v), dim = -2)

-        if not is_multi_latent_attn and return_intermediates:
-            mem_len = mem.shape[-2] if exists(mem) else 0
-            cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
+            if return_intermediates:
+                mem_len = mem.shape[-2] if exists(mem) else 0
+                cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

         if exists(rotary_pos_emb):
             rotate_num_heads = self.rotate_num_heads
@@ -1594,8 +1595,8 @@
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

             if partial_rotate_heads:
-                q, q_rest = q[:, -rotate_num_heads:], q[:, :-rotate_num_heads]
-                k, k_rest = k[:, -rotate_num_heads:], k[:, :-rotate_num_heads]
+                q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+                k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]

             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

@@ -1608,8 +1609,8 @@
             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

             if partial_rotate_heads:
-                q = cat((q, q_rest), dim = 1)
-                k = cat((k, k_rest), dim = 1)
+                q = cat((q_rest, q), dim = 1)
+                k = cat((k_rest, k), dim = 1)

         input_mask = context_mask
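
For reference, a minimal usage sketch of the configuration exercised by the renamed `test_multi_latent_attention` test. The `TransformerWrapper` / `Decoder` hyperparameters are copied verbatim from the updated test in this diff; the imports, the input tensor shape, and the single forward pass are illustrative assumptions, not part of the change.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# mirrors tests/test_x_transformers.py::test_multi_latent_attention after this change
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        attn_use_latent_q = True,
        attn_dim_latent_q = 128,
        attn_use_latent_kv = True,
        attn_dim_latent_kv = 128,
        attn_latent_rope_subheads = 4,  # option exercised by the updated test
        rotary_pos_emb = False          # the updated test disables layer-level rotary embeddings
    )
)

# assumed smoke-test input: batch of 2 at the full context length
x = torch.randint(0, 20000, (2, 1024))
logits = model(x)  # (2, 1024, 20000)
```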