feat: kv cache
dest1n1s committed Nov 21, 2024 · 1 parent df17501 · commit 7478c93
Showing 4 changed files with 114 additions and 18 deletions.
32 changes: 20 additions & 12 deletions src/xlens/components/attention.py
@@ -21,6 +21,9 @@ class Attention(nnx.Module):
repeat_kv_heads: Optional[int]
rotary_sin: Optional[nnx.Variable[Float[jax.Array, "n_ctx rotary_dim"]]]
rotary_cos: Optional[nnx.Variable[Float[jax.Array, "n_ctx rotary_dim"]]]
past_kv_cache: nnx.Variable[
Optional[tuple[Float[jax.Array, "batch kv_pos d_model"], Float[jax.Array, "batch kv_pos d_model"]]]
]

W_Q: nnx.Param[Float[jax.Array, "n_heads d_model d_head"]]
W_K: nnx.Param[Float[jax.Array, "n_heads d_model d_head"] | Float[jax.Array, "n_key_value_heads d_model d_head"]]
@@ -143,6 +146,8 @@ def __init__(
self.rotary_sin = None
self.rotary_cos = None

self.past_kv_cache = nnx.Variable(None)

def __call__(
self,
query_input: Union[
@@ -160,17 +165,22 @@ def __call__(
Float[jax.Array, "batch kv_pos kv_head_index d_model"],
],
additive_attention_mask: Optional[Float[jax.Array, "batch 1 1 kv_pos"]] = None,
attention_mask: Optional[Int[jax.Array, "batch offset_pos"]] = None,
attention_mask: Optional[Int[jax.Array, "batch kv_pos"]] = None,
) -> Float[jax.Array, "batch pos d_model"]:
"""Forward pass for attention.
additive_attention_mask is an optional mask to add to the attention weights. Defaults to None.
attention_mask is the attention mask for padded tokens. Defaults to None.
"""

q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

kv_cache_pos_offset = 0
kv_cache_pos_offset = 0 if self.past_kv_cache.value is None else self.past_kv_cache.value[0].shape[1]
if self.past_kv_cache.value is not None:
k_cache, v_cache = self.past_kv_cache.value
k = jnp.concatenate([k_cache, k], axis=1)
v = jnp.concatenate([v_cache, v], axis=1)
self.past_kv_cache.value = (k, v)
# print(q[0, -1, 0, :5])

if self.cfg.positional_embedding_type == "rotary":
assert self.hook_rot_k is not None and self.hook_rot_q is not None, "Rotary hooks must be defined"
@@ -204,19 +214,15 @@ def __call__(
pattern = jnp.where(jnp.isnan(pattern), jnp.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
w = einops.rearrange(
self.W_O.value,
"head_index d_head d_model -> d_model head_index d_head",
)
result = self.hook_result(
einops.einsum(
z,
w,
"... head_index d_head, d_model head_index d_head -> ... head_index d_model",
self.W_O.value,
"... head_index d_head, head_index d_head d_model -> ... head_index d_model",
)
) # [batch, pos, head_index, d_model]
out = (
einops.reduce(result, "batch position index model->batch position model", "sum") + self.b_O
einops.reduce(result, "batch pos index model -> batch pos model", "sum") + self.b_O
) # [batch, pos, d_model]
return out

@@ -295,7 +301,7 @@ def apply_causal_mask(
self,
attn_scores: Float[jax.Array, "batch head_index pos pos_plus_past_kv_pos_offset"],
past_kv_pos_offset: int = 0,
attention_mask: Optional[Int[jax.Array, "batch offset_pos"]] = None,
attention_mask: Optional[Int[jax.Array, "batch kv_pos"]] = None,
):
# The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different.
query_ctx_length = attn_scores.shape[-2]
@@ -313,7 +319,9 @@
if attention_mask is not None:
# Apply a causal mask to the attention scores considering the padding
final_mask = einops.einsum(
final_mask, attention_mask, "batch head pos offset_pos, batch offset_pos -> batch head pos offset_pos"
final_mask,
attention_mask,
"batch head pos kv_pos, batch kv_pos -> batch head pos kv_pos",
).astype(bool)

return jnp.where(final_mask, attn_scores, -jnp.inf)
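For context, the change above gives each `Attention` layer its own `past_kv_cache` variable: freshly computed keys and values are concatenated onto the cached ones, attention scores are taken over every key position, and `apply_causal_mask` shifts the causal boundary by the cache length. A minimal standalone sketch of that pattern (illustrative names and shapes, not the module's exact code; scaling and softmax are omitted for brevity):

```python
import jax.numpy as jnp


def cached_attention_scores(q, k_new, v_new, kv_cache=None):
    """Sketch of the KV-caching pattern used in Attention.__call__.

    q, k_new, v_new: [batch, new_pos, n_heads, d_head] for the new tokens only.
    kv_cache: optional (k_cache, v_cache), each [batch, past_pos, n_heads, d_head].
    Returns masked scores plus the updated cache.
    """
    past = 0 if kv_cache is None else kv_cache[0].shape[1]
    if kv_cache is not None:
        k = jnp.concatenate([kv_cache[0], k_new], axis=1)
        v = jnp.concatenate([kv_cache[1], v_new], axis=1)
    else:
        k, v = k_new, v_new

    # Scores over every key position (past + new): [batch, heads, new_pos, kv_pos]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k)

    # Causal mask shifted by the cache length: query i may attend to keys 0..past+i.
    q_idx = jnp.arange(q.shape[1])[:, None] + past
    k_idx = jnp.arange(k.shape[1])[None, :]
    scores = jnp.where(k_idx <= q_idx, scores, -jnp.inf)
    return scores, (k, v)
```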
3 changes: 1 addition & 2 deletions src/xlens/components/transformer_block.py
@@ -92,13 +92,12 @@ def identity(x: jax.Array):
def __call__(
self,
resid_pre: Float[jax.Array, "batch pos d_model"],
attention_mask: Optional[Int[jax.Array, "batch offset_pos"]] = None,
attention_mask: Optional[Int[jax.Array, "batch kv_pos"]] = None,
) -> Float[jax.Array, "batch pos d_model"]:
"""A single Transformer block.
Args:
resid_pre (jax.Array): The residual stream - shape [batch, pos, d_model]
past_kv_cache_entry (HookedTransformerKeyValueCache): A cache of previous keys and values, used only when generating text. Defaults to None.
attention_mask (jax.Array, optional): The attention mask for padded tokens. Defaults to None.
Returns:
36 changes: 32 additions & 4 deletions src/xlens/hooked_transformer.py
@@ -1,5 +1,5 @@
import logging
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast

import flax.nnx as nnx
import jax
@@ -75,7 +75,7 @@ def __init__(self, cfg: HookedTransformerConfig):
def __call__(
self,
input_ids: Int[jax.Array, "batch pos"],
attention_mask: Optional[jax.Array] = None, # [batch pos]
attention_mask: Optional[Int[jax.Array, "batch kv_pos"]] = None,
) -> Float[jax.Array, "batch pos d_vocab"]:
"""Forward Pass.
@@ -96,10 +96,18 @@ def __call__(

tokens = self.hook_tokens(input_ids) # [batch, pos]
embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
pos_embed = self.hook_pos_embed(self.pos_embed(tokens, 0, attention_mask)) # [batch, pos, d_model]
self._check_kv_cache_consistency() # Check that the KV cache is consistent
past_kv_pos_offset = (
0
if self.blocks[0].attn.past_kv_cache.value is None
else self.blocks[0].attn.past_kv_cache.value[0].shape[1]
)
pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, past_kv_pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed + pos_embed

for _, block in list(zip(range(self.cfg.n_layers), self.blocks)):
for block in self.blocks:
# Note that each block includes skip connections, so we don't need
# residual + block(residual)
residual = block(
@@ -114,6 +122,26 @@ def __call__(
logits = self.unembed(residual) # [batch, pos, d_vocab]
return logits

def _check_kv_cache_consistency(self):
"""Check if the KV cache is consistent across blocks.
This is to ensure that the KV cache is either:
- None for all blocks
- Non-None and has the same shape for all blocks
"""
all_kv_cache_values = [block.attn.past_kv_cache.value for block in self.blocks]
if any(kv_cache_value is not None for kv_cache_value in all_kv_cache_values):
assert all(
kv_cache_value is not None for kv_cache_value in all_kv_cache_values
), "All cached values must be non-None if any are set"
first_cache = cast(tuple[jax.Array, jax.Array], all_kv_cache_values[0])
first_k_shape = first_cache[0].shape
for k_cache, v_cache in cast(list[tuple[jax.Array, jax.Array]], all_kv_cache_values):
assert k_cache.shape == first_k_shape and v_cache.shape == first_k_shape, (
"All cached values must have the same shape. Found shapes %s and %s while first shape is %s"
% (k_cache.shape, v_cache.shape, first_k_shape)
)

def run_with_hooks(
self,
input_ids: Int[jax.Array, "batch pos"],
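With the forward pass now offsetting positional embeddings by the cache length and `_check_kv_cache_consistency` guarding that every block's cache stays in sync, incremental decoding becomes a two-phase pattern: one call over the prompt, then one call per new token, as the tests below exercise. A rough greedy-decoding sketch under those assumptions; the cache reset at the end is an assumed manual step, not an API added by this commit:

```python
import jax
import jax.numpy as jnp

from xlens import HookedTransformer


def greedy_decode(model: HookedTransformer, input_ids: jax.Array, max_new_tokens: int = 20) -> jax.Array:
    """Illustrative greedy decoding loop that leans on the per-block KV caches."""
    logits = model(input_ids)  # first call processes the whole prompt and fills every block's cache
    generated = input_ids
    for _ in range(max_new_tokens):
        next_token = jnp.argmax(logits[:, -1, :], axis=-1)[:, None]  # [batch, 1]
        generated = jnp.concatenate([generated, next_token], axis=1)
        logits = model(next_token)  # only the new token; cached keys/values supply the context

    # Assumed manual reset before the next prompt (not part of this commit's API).
    for block in model.blocks:
        block.attn.past_kv_cache.value = None
    return generated
```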
61 changes: 61 additions & 0 deletions tests/integration/test_kv_cache.py
@@ -0,0 +1,61 @@
import jax
import jax.numpy as jnp
from transformers import GPT2Tokenizer

from xlens import HookedTransformer
from xlens.components.attention import Attention
from xlens.utilities.functional import functional


def test_kv_cache_attention():
model = HookedTransformer.from_pretrained("gpt2")
input = jax.random.normal(jax.random.PRNGKey(0), (1, 10, 768))
attention = model.blocks[0].attn

@functional
def no_cache_forward(attention: Attention, input: jax.Array) -> jax.Array:
return attention(input, input, input, attention_mask=jnp.ones((1, 10)))

no_cache_result = no_cache_forward(attention, input)
assert attention.past_kv_cache.value is None

@functional
def cache_forward(attention: Attention, input: jax.Array) -> jax.Array:
assert attention.past_kv_cache.value is None
logits_head = attention(input[:, :-1], input[:, :-1], input[:, :-1])
assert attention.past_kv_cache.value is not None
logits_tail = attention(input[:, -1:], input[:, -1:], input[:, -1:])
return jnp.concatenate([logits_head, logits_tail], axis=1)

cache_result = cache_forward(attention, input)

assert jnp.allclose(no_cache_result, cache_result, atol=1e-4)


def test_kv_cache():
model = HookedTransformer.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input = jnp.array(tokenizer("Hello, my dog is cute", return_tensors="np")["input_ids"])

@functional
def no_cache_forward(model: HookedTransformer, input: jax.Array) -> jax.Array:
return model(input)

no_cache_logits = no_cache_forward(model, input)

@functional
def cache_forward(model: HookedTransformer, input: jax.Array) -> jax.Array:
logits_head = model(input[:, :-1])
assert model.blocks[0].attn.past_kv_cache.value is not None
logits_tail = model(input[:, -1:])
return jnp.concatenate([logits_head, logits_tail], axis=1)

cache_logits = cache_forward(model, input)
print("No Cache Logits: ", no_cache_logits[0, -1, :5])
print("Cache Logits: ", cache_logits[0, -1, :5])

assert jnp.allclose(no_cache_logits, cache_logits, atol=1e-4)


if __name__ == "__main__":
test_kv_cache()
