refactor: add hook point state
dest1n1s committed Nov 26, 2024
1 parent 42c8d3e commit 09d68b0
Showing 9 changed files with 139 additions and 101 deletions.
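
The pattern applied throughout this commit: a hook point call no longer returns just the activation, it returns a (value, hook) pair, and every call site reassigns the hook so that its captured state is threaded through the flax.nnx module tree (x = self.hook_x(x) becomes x, self.hook_x = self.hook_x(x)). The sketch below illustrates that contract only; the real xlens HookPoint is not part of this diff, and the value attribute used here to capture the activation is an assumption for illustration.

import jax
import flax.nnx as nnx


class HookPoint(nnx.Module):
    """Minimal sketch of the (value, hook) contract this commit adopts.

    Not the actual xlens implementation: the `value` attribute that records
    the last activation is illustrative only.
    """

    def __init__(self):
        self.value: jax.Array | None = None  # hypothetical captured activation

    def __call__(self, x: jax.Array) -> tuple[jax.Array, "HookPoint"]:
        self.value = x  # record the activation as module state
        return x, self  # hand the updated hook back to the caller


# Call-site change repeated across the files below:
#   before:  x = self.hook_x(x)
#   after:   x, self.hook_x = self.hook_x(x)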
18 changes: 9 additions & 9 deletions src/xlens/components/attention.py
@@ -185,10 +185,10 @@ def __call__(
if self.cfg.positional_embedding_type == "rotary":
assert self.hook_rot_k is not None and self.hook_rot_q is not None, "Rotary hooks must be defined"
assert self.cfg.rotary_dim is not None, "Rotary dim must be defined"
q = self.hook_rot_q(
q, self.hook_rot_q = self.hook_rot_q(
self.apply_rotary(q, kv_cache_pos_offset, attention_mask, rotary_dim=self.cfg.rotary_dim)
)
k = self.hook_rot_k(
k, self.hook_rot_k = self.hook_rot_k(
self.apply_rotary(k, 0, attention_mask, rotary_dim=self.cfg.rotary_dim)
) # keys are cached so no offset

@@ -206,13 +206,13 @@ def __call__(
attn_scores, kv_cache_pos_offset, attention_mask
) # [batch, head_index, query_pos, key_pos]

attn_scores = self.hook_attn_scores(attn_scores)
attn_scores, self.hook_attn_scores = self.hook_attn_scores(attn_scores)

pattern = jax.nn.softmax(attn_scores, axis=-1)
pattern = jnp.where(jnp.isnan(pattern), jnp.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern, self.hook_pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
result = self.hook_result(
result, self.hook_result = self.hook_result(
einops.einsum(
z,
self.W_O.value,
@@ -250,9 +250,9 @@ def attn_fn(
+ b
)

q = self.hook_q(attn_fn(query_input, self.W_Q.value, self.b_Q.value))
k = self.hook_k(attn_fn(key_input, self.W_K.value, self.b_K.value))
v = self.hook_v(attn_fn(value_input, self.W_V.value, self.b_V.value))
q, self.hook_q = self.hook_q(attn_fn(query_input, self.W_Q.value, self.b_Q.value))
k, self.hook_k = self.hook_k(attn_fn(key_input, self.W_K.value, self.b_K.value))
v, self.hook_v = self.hook_v(attn_fn(value_input, self.W_V.value, self.b_V.value))

return q, k, v

@@ -288,7 +288,7 @@ def calculate_z_scores(
pattern,
"batch head_index query_pos key_pos -> batch head_index query_pos key_pos",
)
z = self.hook_z(
z, self.hook_z = self.hook_z(
einops.rearrange(
pattern_ @ v_,
"batch head_index query_pos d_head -> batch query_pos head_index d_head",
10 changes: 6 additions & 4 deletions src/xlens/components/embed.py
@@ -3,7 +3,7 @@
This module contains all the component :class:`Embed`.
"""

from typing import Optional
from typing import Optional, Self

import einops
import flax.nnx as nnx
@@ -30,13 +30,15 @@ def __init__(self, cfg: HookedTransformerConfig):
# Some models (e.g. Bloom) need post embedding layer norm
self.ln = LayerNorm(self.cfg) if self.cfg.post_embedding_ln else None

def __call__(self, tokens: Int[jax.Array, "batch pos"]) -> Float[jax.Array, "batch pos d_model"]:
def __call__(self, tokens: Int[jax.Array, "batch pos"]) -> tuple[Float[jax.Array, "batch pos d_model"], Self]:
# If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d]
# B acts as a tensor of indices into the second dimension (so >=0 and <b)
if self.cfg.post_embedding_ln:
assert self.ln is not None
return self.ln(self.W_E[tokens, :])
return self.W_E[tokens, :]
embedding, self.ln = self.ln(self.W_E[tokens, :])
else:
embedding = self.W_E[tokens, :]
return embedding, self


def get_offset_position_ids(
53 changes: 27 additions & 26 deletions src/xlens/components/layer_norm.py
@@ -3,7 +3,7 @@
This module contains all the component :class:`LayerNorm`.
"""

from typing import Optional, Union
from typing import Optional, Self, Union

import flax.nnx as nnx
import jax
@@ -50,17 +50,19 @@ def __call__(
Float[jax.Array, "batch pos d_model"],
Float[jax.Array, "batch pos head_index d_model"],
],
) -> Union[
Float[jax.Array, "batch pos d_model"],
Float[jax.Array, "batch pos head_index d_model"],
) -> tuple[
Union[
Float[jax.Array, "batch pos d_model"],
Float[jax.Array, "batch pos head_index d_model"],
],
Self,
]:
x = x - x.mean(-1, keepdims=True) # [batch, pos, length]

scale: Float[jax.Array, "batch pos 1"] = self.hook_scale(
jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)
)
scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps))
x = x / scale # [batch, pos, length]
return self.hook_normalized(x * self.w + self.b)
x, self.hook_normalized = self.hook_normalized(x * self.w + self.b)
return x, self


class LayerNormPre(nnx.Module):
@@ -87,15 +89,17 @@ def __call__(
Float[jax.Array, "batch pos d_model"],
Float[jax.Array, "batch pos head_index d_model"],
],
) -> Union[
Float[jax.Array, "batch pos d_model"],
Float[jax.Array, "batch pos head_index d_model"],
) -> tuple[
Union[
Float[jax.Array, "batch pos d_model"],
Float[jax.Array, "batch pos head_index d_model"],
],
Self,
]:
x = x - x.mean(-1, keepdims=True) # [batch, pos, length]
scale: Float[jax.Array, "batch pos 1"] = self.hook_scale(
jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)
)
return self.hook_normalized(x / scale)
scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps))
x, self.hook_normalized = self.hook_normalized(x / scale)
return x, self


class RMSNorm(nnx.Module):
@@ -126,12 +130,10 @@ def __init__(self, cfg: HookedTransformerConfig, length: Optional[int] = None):
self.hook_scale = HookPoint() # [batch, pos, 1]
self.hook_normalized = HookPoint() # [batch, pos, length]

def __call__(self, x: Float[jax.Array, "batch pos length"]) -> Float[jax.Array, "batch pos length"]:
scale: Float[jax.Array, "batch pos 1"] = self.hook_scale(
jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)
)
x = self.hook_normalized(x / scale) # [batch, pos, length]
return x * self.w
def __call__(self, x: Float[jax.Array, "batch pos length"]) -> tuple[Float[jax.Array, "batch pos length"], Self]:
scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps))
x, self.hook_normalized = self.hook_normalized(x / scale) # [batch, pos, length]
return x * self.w, self


class RMSNormPre(nnx.Module):
@@ -149,8 +151,7 @@ def __init__(self, cfg: HookedTransformerConfig):
self.hook_scale = HookPoint() # [batch, pos]
self.hook_normalized = HookPoint() # [batch, pos, length]

def __call__(self, x: Float[jax.Array, "batch pos length"]) -> Float[jax.Array, "batch pos length"]:
scale: Float[jax.Array, "batch pos 1"] = self.hook_scale(
jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)
)
return self.hook_normalized(x / scale) # [batch, pos, length]
def __call__(self, x: Float[jax.Array, "batch pos length"]) -> tuple[Float[jax.Array, "batch pos length"], Self]:
scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps))
x, self.hook_normalized = self.hook_normalized(x / scale) # [batch, pos, length]
return x, self
32 changes: 18 additions & 14 deletions src/xlens/components/mlp.py
@@ -3,7 +3,7 @@
This module contains all the component :class:`MLP`.
"""

from typing import Callable, Optional, Union
from typing import Callable, Optional, Self, Union

import flax.nnx as nnx
import jax
@@ -61,19 +61,20 @@ def __init__(self, cfg: HookedTransformerConfig):
self.hook_mid = None
self.ln = None

def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> Float[jax.Array, "batch pos d_model"]:
def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> tuple[Float[jax.Array, "batch pos d_model"], Self]:
# There's no fused `addmm` here. May cause performance issues.
pre_act = self.hook_pre(x @ self.W_in + self.b_in)
pre_act, self.hook_pre = self.hook_pre(x @ self.W_in + self.b_in)

if self.cfg.is_layer_norm_activation():
assert (
self.ln is not None and self.hook_mid is not None
), "LayerNorm and HookPoint must be set for layer norm activation"
mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
post_act = self.hook_post(self.ln(mid_act))
mid_act, self.hook_mid = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
mid_act, self.ln = self.ln(mid_act)
post_act, self.hook_post = self.hook_post(mid_act)
else:
post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp]
return post_act @ self.W_out + self.b_out
post_act, self.hook_post = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp]
return post_act @ self.W_out + self.b_out, self


class GatedMLP(nnx.Module):
@@ -140,18 +141,21 @@ def __init__(self, cfg: HookedTransformerConfig):
self.hook_mid = None
self.ln = None

def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> Float[jax.Array, "batch pos d_model"]:
def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> tuple[Float[jax.Array, "batch pos d_model"], Self]:
# Technically, all these einsums could be done with a single matmul, but this is more readable.
pre_act = self.hook_pre(x @ self.W_gate)
pre_act, self.hook_pre = self.hook_pre(x @ self.W_gate)

if self.cfg.is_layer_norm_activation() and self.hook_mid is not None and self.ln is not None:
assert (
self.ln is not None and self.hook_mid is not None
), "LayerNorm and HookPoint must be set for layer norm activation"
mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
post_act = self.hook_post(self.ln(mid_act))
mid_act, self.hook_mid = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp]
mid_act, self.ln = self.ln(mid_act)
post_act, self.hook_post = self.hook_post(mid_act)
else:
pre_linear = self.hook_pre_linear(x @ self.W_in)
post_act = self.hook_post((self.act_fn(pre_act) * pre_linear) + self.b_in) # [batch, pos, d_mlp]
pre_linear, self.hook_pre_linear = self.hook_pre_linear(x @ self.W_in)
post_act, self.hook_post = self.hook_post(
(self.act_fn(pre_act) * pre_linear) + self.b_in
) # [batch, pos, d_mlp]

return post_act @ self.W_out + self.b_out
return post_act @ self.W_out + self.b_out, self
45 changes: 27 additions & 18 deletions src/xlens/components/transformer_block.py
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Self, Union
from typing import Any, Callable, Optional, Self, Union

import flax.nnx as nnx
import jax
@@ -19,8 +19,8 @@ class TransformerBlock(nnx.Module):

layer_id: Optional[int]

ln1: Callable[[Float[jax.Array, "batch pos d_model"]], Float[jax.Array, "batch pos d_model"]]
ln2: Optional[Callable[[Float[jax.Array, "batch pos d_model"]], Float[jax.Array, "batch pos d_model"]]]
ln1: Callable[[Float[jax.Array, "batch pos d_model"]], tuple[Float[jax.Array, "batch pos d_model"], Any]]
ln2: Optional[Callable[[Float[jax.Array, "batch pos d_model"]], tuple[Float[jax.Array, "batch pos d_model"], Any]]]
attn: Attention
mlp: Optional[MLP | GatedMLP]

@@ -55,7 +55,7 @@ def __init__(self, cfg: HookedTransformerConfig, block_index: int):
# We need to make this a lambda so we can call it on the config, just like the others
def normalization_layer(cfg: HookedTransformerConfig):
def identity(x: jax.Array):
return x
return x, identity

return identity
else:
@@ -105,7 +105,7 @@ def __call__(
Returns:
Float[jax.Array, "batch pos d_model"]: Our resulting tensor
"""
resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model]
resid_pre, self.hook_resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model]

attn_in = resid_pre

@@ -117,35 +117,43 @@ def __call__(
# queries, keys and values, independently.
# Then take the layer norm of these inputs, and pass these to the attention module.

query_input, self.ln1 = self.ln1(query_input)
key_input, self.ln1 = self.ln1(key_input)
value_input, self.ln1 = self.ln1(value_input)

attn_out, self.attn = self.attn(
query_input=self.ln1(query_input),
key_input=self.ln1(key_input),
value_input=self.ln1(value_input),
query_input=query_input,
key_input=key_input,
value_input=value_input,
attention_mask=attention_mask,
)

attn_out = self.hook_attn_out(attn_out) # [batch, pos, d_model]
attn_out, self.hook_attn_out = self.hook_attn_out(attn_out) # [batch, pos, d_model]

if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
assert (
self.mlp is not None and self.ln2 is not None and self.hook_resid_mid is not None
), "MLP, LayerNorm2 and hook_resid_mid must be defined if attn_only is False"
resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model]
mlp_in = self.hook_mlp_in(resid_mid)
normalized_resid_mid = self.ln2(mlp_in)
resid_mid, self.hook_resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model]
mlp_in, self.hook_mlp_in = self.hook_mlp_in(resid_mid)
normalized_resid_mid, self.ln2 = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_mid)
resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model]
resid_post, self.hook_resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model]

elif self.cfg.parallel_attn_mlp:
# Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
# In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
assert (
self.mlp is not None and self.ln2 is not None
), "MLP and LayerNorm2 must be defined if parallel_attn_mlp is True"
normalized_resid_pre_2 = self.ln2(self.hook_mlp_in(resid_pre))
mlp_in, self.hook_mlp_in = self.hook_mlp_in(resid_pre)
normalized_resid_pre_2, self.ln2 = self.ln2(mlp_in)
mlp_out = self.apply_mlp(normalized_resid_pre_2)
resid_post = self.hook_resid_post(resid_pre + attn_out + mlp_out) # [batch, pos, d_model]
resid_post, self.hook_resid_post = self.hook_resid_post(
resid_pre + attn_out + mlp_out
) # [batch, pos, d_model]
else:
resid_post = self.hook_resid_post(resid_pre + attn_out) # [batch, pos, d_model]
resid_post, self.hook_resid_post = self.hook_resid_post(resid_pre + attn_out) # [batch, pos, d_model]

return resid_post, self

Expand All @@ -158,5 +166,6 @@ def apply_mlp(
Float[jax.Array, "batch pos d_model"]: Our resulting tensor
"""
assert self.mlp is not None, "MLP must be defined if apply_mlp is called"
mlp_out = self.mlp(normalized_resid) # [batch, pos, d_model]
return self.hook_mlp_out(mlp_out)
mlp_out, self.mlp = self.mlp(normalized_resid) # [batch, pos, d_model]
mlp_out, self.hook_mlp_out = self.hook_mlp_out(mlp_out)
return mlp_out
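
Note on the normalization_layer hunk above: because TransformerBlock now calls every normalization layer as x, self.ln1 = self.ln1(x), even the no-op case has to return a second element, which is why the inner identity function returns itself. A tiny self-contained illustration (plain Python, not xlens code):

def identity(x):
    return x, identity  # the function itself stands in for the "updated module"


x, ln1 = identity(3.0)  # ln1 is reassigned exactly as a real LayerNorm instance would be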
15 changes: 9 additions & 6 deletions src/xlens/hooked_transformer.py
@@ -18,6 +18,7 @@
Unembed,
)
from xlens.hooks import with_cache, with_hooks
from xlens.hooks.utilities import retrieve_cache
from xlens.pretrained.convert import get_pretrained_model_config, get_pretrained_weights
from xlens.utilities.functional import functional
from xlens.utils import load_pretrained_weights
@@ -100,11 +101,12 @@ def __call__(
is not computed automatically. Defaults to None.
"""

tokens = self.hook_tokens(input_ids) # [batch, pos]
embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
tokens, self.hook_tokens = self.hook_tokens(input_ids) # [batch, pos]
embed, self.embed = self.embed(tokens) # [batch, pos, d_model]
embed, self.hook_embed = self.hook_embed(embed) # [batch, pos, d_model]
self._check_kv_cache_consistency() # Check that the KV cache is consistent
past_kv_pos_offset = self.blocks[0].attn.past_kv_cache.length
pos_embed = self.hook_pos_embed(
pos_embed, self.hook_pos_embed = self.hook_pos_embed(
self.pos_embed(tokens, past_kv_pos_offset, attention_mask)
) # [batch, pos, d_model]
residual = embed + pos_embed
@@ -120,7 +122,7 @@

if self.cfg.normalization_type is not None:
assert self.ln_final is not None, "ln_final should be set if normalization_type is set"
residual = self.ln_final(residual) # [batch, pos, d_model]
residual, self.ln_final = self.ln_final(residual) # [batch, pos, d_model]

logits = self.unembed(residual) # [batch, pos, d_vocab]
return logits, self
@@ -146,7 +148,7 @@ def run_with_hooks(
self,
input_ids: Int[jax.Array, "batch pos"],
attention_mask: Optional[jax.Array] = None, # [batch pos]
hooks: list[tuple[str, Callable[[Any], Any]]] = [],
hooks: list[tuple[str, Callable[[Any, Any], tuple[Any, Any]]]] = [],
) -> tuple[Float[jax.Array, "batch pos d_vocab"], Self]:
"""Forward Pass with hooks.
@@ -187,9 +189,10 @@ def run_with_cache(
hook_names: list[str]: A list of strings, where each string is the name of a hook point
"""

model, cache = with_cache(self, hook_names)
model = with_cache(self, hook_names)

out, model = model(input_ids, attention_mask=attention_mask)
cache = retrieve_cache(model, hook_names)

return out, cache, model

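Taken together, the hooked_transformer.py changes touch the public surface in two ways: hook callables passed to run_with_hooks now receive and return a (value, hook) pair, and run_with_cache collects activations after the forward pass via retrieve_cache instead of receiving a cache from with_cache. A hedged usage sketch follows; it assumes model is an already-constructed xlens HookedTransformer, that the hook callable's argument order is (value, hook), that run_with_cache accepts hook_names as described in its docstring, and that the hook name "blocks.0.hook_resid_post" is illustrative.

import jax.numpy as jnp

input_ids = jnp.array([[0, 1, 2, 3]])  # toy token ids


# Hook callables now match Callable[[Any, Any], tuple[Any, Any]]:
# they take the activation plus the hook and return both.
def double_residual(value, hook):
    return value * 2, hook


logits, model = model.run_with_hooks(
    input_ids,
    hooks=[("blocks.0.hook_resid_post", double_residual)],
)

# with_cache now returns only the wrapped model; the cache is gathered
# afterwards by retrieve_cache, so the caller still gets (out, cache, model).
logits, cache, model = model.run_with_cache(
    input_ids,
    hook_names=["blocks.0.hook_resid_post"],
)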
(Diffs for the remaining three changed files were not loaded on this page.)
