From 09d68b04a382ca84f2cb828f1441b0289d51cfce Mon Sep 17 00:00:00 2001 From: Dest1n1 Date: Tue, 26 Nov 2024 11:28:02 +0800 Subject: [PATCH] refactor: add hook point state --- src/xlens/components/attention.py | 18 ++++---- src/xlens/components/embed.py | 10 +++-- src/xlens/components/layer_norm.py | 53 ++++++++++++----------- src/xlens/components/mlp.py | 32 ++++++++------ src/xlens/components/transformer_block.py | 45 +++++++++++-------- src/xlens/hooked_transformer.py | 15 ++++--- src/xlens/hooks/hook_point.py | 25 +++++++---- src/xlens/hooks/utilities.py | 21 ++++----- tests/unit/test_with_hooks.py | 21 ++++++--- 9 files changed, 139 insertions(+), 101 deletions(-) diff --git a/src/xlens/components/attention.py b/src/xlens/components/attention.py index 7ef6532..ebe5946 100644 --- a/src/xlens/components/attention.py +++ b/src/xlens/components/attention.py @@ -185,10 +185,10 @@ def __call__( if self.cfg.positional_embedding_type == "rotary": assert self.hook_rot_k is not None and self.hook_rot_q is not None, "Rotary hooks must be defined" assert self.cfg.rotary_dim is not None, "Rotary dim must be defined" - q = self.hook_rot_q( + q, self.hook_rot_q = self.hook_rot_q( self.apply_rotary(q, kv_cache_pos_offset, attention_mask, rotary_dim=self.cfg.rotary_dim) ) - k = self.hook_rot_k( + k, self.hook_rot_k = self.hook_rot_k( self.apply_rotary(k, 0, attention_mask, rotary_dim=self.cfg.rotary_dim) ) # keys are cached so no offset @@ -206,13 +206,13 @@ def __call__( attn_scores, kv_cache_pos_offset, attention_mask ) # [batch, head_index, query_pos, key_pos] - attn_scores = self.hook_attn_scores(attn_scores) + attn_scores, self.hook_attn_scores = self.hook_attn_scores(attn_scores) pattern = jax.nn.softmax(attn_scores, axis=-1) pattern = jnp.where(jnp.isnan(pattern), jnp.zeros_like(pattern), pattern) - pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] + pattern, self.hook_pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] - result = self.hook_result( + result, self.hook_result = self.hook_result( einops.einsum( z, self.W_O.value, @@ -250,9 +250,9 @@ def attn_fn( + b ) - q = self.hook_q(attn_fn(query_input, self.W_Q.value, self.b_Q.value)) - k = self.hook_k(attn_fn(key_input, self.W_K.value, self.b_K.value)) - v = self.hook_v(attn_fn(value_input, self.W_V.value, self.b_V.value)) + q, self.hook_q = self.hook_q(attn_fn(query_input, self.W_Q.value, self.b_Q.value)) + k, self.hook_k = self.hook_k(attn_fn(key_input, self.W_K.value, self.b_K.value)) + v, self.hook_v = self.hook_v(attn_fn(value_input, self.W_V.value, self.b_V.value)) return q, k, v @@ -288,7 +288,7 @@ def calculate_z_scores( pattern, "batch head_index query_pos key_pos -> batch head_index query_pos key_pos", ) - z = self.hook_z( + z, self.hook_z = self.hook_z( einops.rearrange( pattern_ @ v_, "batch head_index query_pos d_head -> batch query_pos head_index d_head", diff --git a/src/xlens/components/embed.py b/src/xlens/components/embed.py index 57bc9b0..611d3d6 100644 --- a/src/xlens/components/embed.py +++ b/src/xlens/components/embed.py @@ -3,7 +3,7 @@ This module contains all the component :class:`Embed`. """ -from typing import Optional +from typing import Optional, Self import einops import flax.nnx as nnx @@ -30,13 +30,15 @@ def __init__(self, cfg: HookedTransformerConfig): # Some models (e.g. Bloom) need post embedding layer norm self.ln = LayerNorm(self.cfg) if self.cfg.post_embedding_ln else None - def __call__(self, tokens: Int[jax.Array, "batch pos"]) -> Float[jax.Array, "batch pos d_model"]: + def __call__(self, tokens: Int[jax.Array, "batch pos"]) -> tuple[Float[jax.Array, "batch pos d_model"], Self]: # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d] # B acts as a tensor of indices into the second dimension (so >=0 and Union[ - Float[jax.Array, "batch pos d_model"], - Float[jax.Array, "batch pos head_index d_model"], + ) -> tuple[ + Union[ + Float[jax.Array, "batch pos d_model"], + Float[jax.Array, "batch pos head_index d_model"], + ], + Self, ]: x = x - x.mean(-1, keepdims=True) # [batch, pos, length] - scale: Float[jax.Array, "batch pos 1"] = self.hook_scale( - jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps) - ) + scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)) x = x / scale # [batch, pos, length] - return self.hook_normalized(x * self.w + self.b) + x, self.hook_normalized = self.hook_normalized(x * self.w + self.b) + return x, self class LayerNormPre(nnx.Module): @@ -87,15 +89,17 @@ def __call__( Float[jax.Array, "batch pos d_model"], Float[jax.Array, "batch pos head_index d_model"], ], - ) -> Union[ - Float[jax.Array, "batch pos d_model"], - Float[jax.Array, "batch pos head_index d_model"], + ) -> tuple[ + Union[ + Float[jax.Array, "batch pos d_model"], + Float[jax.Array, "batch pos head_index d_model"], + ], + Self, ]: x = x - x.mean(-1, keepdims=True) # [batch, pos, length] - scale: Float[jax.Array, "batch pos 1"] = self.hook_scale( - jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps) - ) - return self.hook_normalized(x / scale) + scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)) + x, self.hook_normalized = self.hook_normalized(x / scale) + return x, self class RMSNorm(nnx.Module): @@ -126,12 +130,10 @@ def __init__(self, cfg: HookedTransformerConfig, length: Optional[int] = None): self.hook_scale = HookPoint() # [batch, pos, 1] self.hook_normalized = HookPoint() # [batch, pos, length] - def __call__(self, x: Float[jax.Array, "batch pos length"]) -> Float[jax.Array, "batch pos length"]: - scale: Float[jax.Array, "batch pos 1"] = self.hook_scale( - jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps) - ) - x = self.hook_normalized(x / scale) # [batch, pos, length] - return x * self.w + def __call__(self, x: Float[jax.Array, "batch pos length"]) -> tuple[Float[jax.Array, "batch pos length"], Self]: + scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)) + x, self.hook_normalized = self.hook_normalized(x / scale) # [batch, pos, length] + return x * self.w, self class RMSNormPre(nnx.Module): @@ -149,8 +151,7 @@ def __init__(self, cfg: HookedTransformerConfig): self.hook_scale = HookPoint() # [batch, pos] self.hook_normalized = HookPoint() # [batch, pos, length] - def __call__(self, x: Float[jax.Array, "batch pos length"]) -> Float[jax.Array, "batch pos length"]: - scale: Float[jax.Array, "batch pos 1"] = self.hook_scale( - jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps) - ) - return self.hook_normalized(x / scale) # [batch, pos, length] + def __call__(self, x: Float[jax.Array, "batch pos length"]) -> tuple[Float[jax.Array, "batch pos length"], Self]: + scale, self.hook_scale = self.hook_scale(jnp.sqrt((x**2).mean(-1, keepdims=True) + self.cfg.eps)) + x, self.hook_normalized = self.hook_normalized(x / scale) # [batch, pos, length] + return x, self diff --git a/src/xlens/components/mlp.py b/src/xlens/components/mlp.py index ea7e8ac..707530d 100644 --- a/src/xlens/components/mlp.py +++ b/src/xlens/components/mlp.py @@ -3,7 +3,7 @@ This module contains all the component :class:`MLP`. """ -from typing import Callable, Optional, Union +from typing import Callable, Optional, Self, Union import flax.nnx as nnx import jax @@ -61,19 +61,20 @@ def __init__(self, cfg: HookedTransformerConfig): self.hook_mid = None self.ln = None - def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> Float[jax.Array, "batch pos d_model"]: + def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> tuple[Float[jax.Array, "batch pos d_model"], Self]: # There's no fused `addmm` here. May cause performance issues. - pre_act = self.hook_pre(x @ self.W_in + self.b_in) + pre_act, self.hook_pre = self.hook_pre(x @ self.W_in + self.b_in) if self.cfg.is_layer_norm_activation(): assert ( self.ln is not None and self.hook_mid is not None ), "LayerNorm and HookPoint must be set for layer norm activation" - mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] - post_act = self.hook_post(self.ln(mid_act)) + mid_act, self.hook_mid = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] + mid_act, self.ln = self.ln(mid_act) + post_act, self.hook_post = self.hook_post(mid_act) else: - post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] - return post_act @ self.W_out + self.b_out + post_act, self.hook_post = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] + return post_act @ self.W_out + self.b_out, self class GatedMLP(nnx.Module): @@ -140,18 +141,21 @@ def __init__(self, cfg: HookedTransformerConfig): self.hook_mid = None self.ln = None - def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> Float[jax.Array, "batch pos d_model"]: + def __call__(self, x: Float[jax.Array, "batch pos d_model"]) -> tuple[Float[jax.Array, "batch pos d_model"], Self]: # Technically, all these einsums could be done with a single matmul, but this is more readable. - pre_act = self.hook_pre(x @ self.W_gate) + pre_act, self.hook_pre = self.hook_pre(x @ self.W_gate) if self.cfg.is_layer_norm_activation() and self.hook_mid is not None and self.ln is not None: assert ( self.ln is not None and self.hook_mid is not None ), "LayerNorm and HookPoint must be set for layer norm activation" - mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] - post_act = self.hook_post(self.ln(mid_act)) + mid_act, self.hook_mid = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] + mid_act, self.ln = self.ln(mid_act) + post_act, self.hook_post = self.hook_post(mid_act) else: - pre_linear = self.hook_pre_linear(x @ self.W_in) - post_act = self.hook_post((self.act_fn(pre_act) * pre_linear) + self.b_in) # [batch, pos, d_mlp] + pre_linear, self.hook_pre_linear = self.hook_pre_linear(x @ self.W_in) + post_act, self.hook_post = self.hook_post( + (self.act_fn(pre_act) * pre_linear) + self.b_in + ) # [batch, pos, d_mlp] - return post_act @ self.W_out + self.b_out + return post_act @ self.W_out + self.b_out, self diff --git a/src/xlens/components/transformer_block.py b/src/xlens/components/transformer_block.py index dadc6f3..81973ef 100644 --- a/src/xlens/components/transformer_block.py +++ b/src/xlens/components/transformer_block.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Self, Union +from typing import Any, Callable, Optional, Self, Union import flax.nnx as nnx import jax @@ -19,8 +19,8 @@ class TransformerBlock(nnx.Module): layer_id: Optional[int] - ln1: Callable[[Float[jax.Array, "batch pos d_model"]], Float[jax.Array, "batch pos d_model"]] - ln2: Optional[Callable[[Float[jax.Array, "batch pos d_model"]], Float[jax.Array, "batch pos d_model"]]] + ln1: Callable[[Float[jax.Array, "batch pos d_model"]], tuple[Float[jax.Array, "batch pos d_model"], Any]] + ln2: Optional[Callable[[Float[jax.Array, "batch pos d_model"]], tuple[Float[jax.Array, "batch pos d_model"], Any]]] attn: Attention mlp: Optional[MLP | GatedMLP] @@ -55,7 +55,7 @@ def __init__(self, cfg: HookedTransformerConfig, block_index: int): # We need to make this a lambda so we can call it on the config, just like the others def normalization_layer(cfg: HookedTransformerConfig): def identity(x: jax.Array): - return x + return x, identity return identity else: @@ -105,7 +105,7 @@ def __call__( Returns: Float[jax.Array, "batch pos d_model"]: Our resulting tensor """ - resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model] + resid_pre, self.hook_resid_pre = self.hook_resid_pre(resid_pre) # [batch, pos, d_model] attn_in = resid_pre @@ -117,35 +117,43 @@ def __call__( # queries, keys and values, independently. # Then take the layer norm of these inputs, and pass these to the attention module. + query_input, self.ln1 = self.ln1(query_input) + key_input, self.ln1 = self.ln1(key_input) + value_input, self.ln1 = self.ln1(value_input) + attn_out, self.attn = self.attn( - query_input=self.ln1(query_input), - key_input=self.ln1(key_input), - value_input=self.ln1(value_input), + query_input=query_input, + key_input=key_input, + value_input=value_input, attention_mask=attention_mask, ) - attn_out = self.hook_attn_out(attn_out) # [batch, pos, d_model] + attn_out, self.hook_attn_out = self.hook_attn_out(attn_out) # [batch, pos, d_model] if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp: assert ( self.mlp is not None and self.ln2 is not None and self.hook_resid_mid is not None ), "MLP, LayerNorm2 and hook_resid_mid must be defined if attn_only is False" - resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model] - mlp_in = self.hook_mlp_in(resid_mid) - normalized_resid_mid = self.ln2(mlp_in) + resid_mid, self.hook_resid_mid = self.hook_resid_mid(resid_pre + attn_out) # [batch, pos, d_model] + mlp_in, self.hook_mlp_in = self.hook_mlp_in(resid_mid) + normalized_resid_mid, self.ln2 = self.ln2(mlp_in) mlp_out = self.apply_mlp(normalized_resid_mid) - resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model] + resid_post, self.hook_resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model] + elif self.cfg.parallel_attn_mlp: # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used. # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't. assert ( self.mlp is not None and self.ln2 is not None ), "MLP and LayerNorm2 must be defined if parallel_attn_mlp is True" - normalized_resid_pre_2 = self.ln2(self.hook_mlp_in(resid_pre)) + mlp_in, self.hook_mlp_in = self.hook_mlp_in(resid_pre) + normalized_resid_pre_2, self.ln2 = self.ln2(mlp_in) mlp_out = self.apply_mlp(normalized_resid_pre_2) - resid_post = self.hook_resid_post(resid_pre + attn_out + mlp_out) # [batch, pos, d_model] + resid_post, self.hook_resid_post = self.hook_resid_post( + resid_pre + attn_out + mlp_out + ) # [batch, pos, d_model] else: - resid_post = self.hook_resid_post(resid_pre + attn_out) # [batch, pos, d_model] + resid_post, self.hook_resid_post = self.hook_resid_post(resid_pre + attn_out) # [batch, pos, d_model] return resid_post, self @@ -158,5 +166,6 @@ def apply_mlp( Float[jax.Array, "batch pos d_model"]: Our resulting tensor """ assert self.mlp is not None, "MLP must be defined if apply_mlp is called" - mlp_out = self.mlp(normalized_resid) # [batch, pos, d_model] - return self.hook_mlp_out(mlp_out) + mlp_out, self.mlp = self.mlp(normalized_resid) # [batch, pos, d_model] + mlp_out, self.hook_mlp_out = self.hook_mlp_out(mlp_out) + return mlp_out diff --git a/src/xlens/hooked_transformer.py b/src/xlens/hooked_transformer.py index d3e4388..4b06938 100644 --- a/src/xlens/hooked_transformer.py +++ b/src/xlens/hooked_transformer.py @@ -18,6 +18,7 @@ Unembed, ) from xlens.hooks import with_cache, with_hooks +from xlens.hooks.utilities import retrieve_cache from xlens.pretrained.convert import get_pretrained_model_config, get_pretrained_weights from xlens.utilities.functional import functional from xlens.utils import load_pretrained_weights @@ -100,11 +101,12 @@ def __call__( is not computed automatically. Defaults to None. """ - tokens = self.hook_tokens(input_ids) # [batch, pos] - embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] + tokens, self.hook_tokens = self.hook_tokens(input_ids) # [batch, pos] + embed, self.embed = self.embed(tokens) # [batch, pos, d_model] + embed, self.hook_embed = self.hook_embed(embed) # [batch, pos, d_model] self._check_kv_cache_consistency() # Check that the KV cache is consistent past_kv_pos_offset = self.blocks[0].attn.past_kv_cache.length - pos_embed = self.hook_pos_embed( + pos_embed, self.hook_pos_embed = self.hook_pos_embed( self.pos_embed(tokens, past_kv_pos_offset, attention_mask) ) # [batch, pos, d_model] residual = embed + pos_embed @@ -120,7 +122,7 @@ def __call__( if self.cfg.normalization_type is not None: assert self.ln_final is not None, "ln_final should be set if normalization_type is set" - residual = self.ln_final(residual) # [batch, pos, d_model] + residual, self.ln_final = self.ln_final(residual) # [batch, pos, d_model] logits = self.unembed(residual) # [batch, pos, d_vocab] return logits, self @@ -146,7 +148,7 @@ def run_with_hooks( self, input_ids: Int[jax.Array, "batch pos"], attention_mask: Optional[jax.Array] = None, # [batch pos] - hooks: list[tuple[str, Callable[[Any], Any]]] = [], + hooks: list[tuple[str, Callable[[Any, Any], tuple[Any, Any]]]] = [], ) -> tuple[Float[jax.Array, "batch pos d_vocab"], Self]: """Forward Pass with hooks. @@ -187,9 +189,10 @@ def run_with_cache( hook_names: list[str]: A list of strings, where each string is the name of a hook point """ - model, cache = with_cache(self, hook_names) + model = with_cache(self, hook_names) out, model = model(input_ids, attention_mask=attention_mask) + cache = retrieve_cache(model, hook_names) return out, cache, model diff --git a/src/xlens/hooks/hook_point.py b/src/xlens/hooks/hook_point.py index 78e3208..139727f 100644 --- a/src/xlens/hooks/hook_point.py +++ b/src/xlens/hooks/hook_point.py @@ -1,26 +1,33 @@ -from typing import Callable, Generic +from typing import Any, Callable, Generic, Self import flax.nnx as nnx import jax from typing_extensions import TypeVar +from xlens.utilities.functional import functional + T = TypeVar("T", default=jax.Array) class HookPoint(nnx.Module, Generic[T]): - def __init__(self, hooks: list[Callable[[T], T]] = []): + hooks: list[Callable[[T, Any], tuple[T, Any]]] + state: nnx.Variable[Any] + + def __init__(self, hooks: list[Callable[[T, Any], tuple[T, Any]]] = [], state: Any = None): self.hooks = hooks + self.state = nnx.Variable(state) - def __call__(self, x: T) -> T: + @functional + def __call__(self, x: T) -> tuple[T, Self]: for hook in self.hooks: - x = hook(x) - return x + x, self.state.value = hook(x, self.state.value) + return x, self - def append_hook(self, hook: Callable[[T], T]) -> "HookPoint[T]": - return HookPoint(self.hooks + [hook]) + def append_hook(self, hook: Callable[[T, Any], tuple[T, Any]]) -> "HookPoint[T]": + return HookPoint(self.hooks + [hook], self.state) - def prepend_hook(self, hook: Callable[[T], T]) -> "HookPoint[T]": - return HookPoint([hook] + self.hooks) + def prepend_hook(self, hook: Callable[[T, Any], tuple[T, Any]]) -> "HookPoint[T]": + return HookPoint([hook] + self.hooks, self.state) def clear_hooks(self) -> "HookPoint[T]": return HookPoint([]) diff --git a/src/xlens/hooks/utilities.py b/src/xlens/hooks/utilities.py index a709af8..f053023 100644 --- a/src/xlens/hooks/utilities.py +++ b/src/xlens/hooks/utilities.py @@ -12,7 +12,7 @@ @functional -def with_hooks(tree: U, hooks: list[tuple[str, Callable[[Any], Any]]] = []) -> U: +def with_hooks(tree: U, hooks: list[tuple[str, Callable[[Any, Any], tuple[Any, Any]]]] = []) -> U: """Set hooks on a tree of objects. Args: @@ -31,26 +31,27 @@ def with_hooks(tree: U, hooks: list[tuple[str, Callable[[Any], Any]]] = []) -> U return tree -def with_cache(tree: U, hook_names: list[str] = []) -> tuple[U, dict[str, Any]]: +def with_cache(tree: U, hook_names: list[str] = []) -> U: """Set hooks on a tree of objects. - Warning: This is not a pure function. Each time the tree is called, the cache will be updated. - Do JIT outside the full scope of the cache. + This function uses the state of the hook point to cache the values. Args: tree: U: The tree of objects to set hooks on. hook_names: list[str]: A list of strings, where each string is the name of a hook point """ - cache = {} - def hook_fn(name: str): - def _hook_fn(x: Any): - cache[name] = x - return x + def _hook_fn(x: Any, state: Any): + assert state is None, "State to cache should be None" + return x, x return _hook_fn tree = with_hooks(tree, [(name, hook_fn(name)) for name in hook_names]) - return tree, cache + return tree + + +def retrieve_cache(tree: Any, hook_names: list[str] = []) -> dict[str, Any]: + return {name: get_nested_attr(tree, name).state.value for name in hook_names} diff --git a/tests/unit/test_with_hooks.py b/tests/unit/test_with_hooks.py index c0bf4e9..1fb8b8a 100644 --- a/tests/unit/test_with_hooks.py +++ b/tests/unit/test_with_hooks.py @@ -1,8 +1,12 @@ +from typing import Self + import flax.nnx as nnx import jax import jax.numpy as jnp from xlens import HookPoint, with_cache, with_hooks +from xlens.hooks.utilities import retrieve_cache +from xlens.utilities.functional import functional class ModuleA(nnx.Module): @@ -11,13 +15,15 @@ class ModuleA(nnx.Module): def __init__(self, hook_mid: HookPoint): self.hook_mid = hook_mid - def __call__(self, x: jax.Array) -> jax.Array: - return self.hook_mid(x * 2) * 2 + @functional + def __call__(self, x: jax.Array) -> tuple[jax.Array, Self]: + x, self.hook_mid = self.hook_mid(x * 2) + return x * 2, self def test_with_hooks(): a = ModuleA(HookPoint()) - a_with_hooks = with_hooks(a, [("hook_mid", lambda x: x + 1)]) + a_with_hooks = with_hooks(a, [("hook_mid", lambda x, state: (x + 1, state))]) y_with_hooks = a_with_hooks(jnp.array(1.0)) assert jnp.allclose(y_with_hooks, 6.0) y = a(jnp.array(1.0)) @@ -26,9 +32,14 @@ def test_with_hooks(): def test_with_cache(): a = ModuleA(HookPoint()) - a, cache = with_cache(a, ["hook_mid"]) - y = a(jnp.array(1.0)) + a = with_cache(a, ["hook_mid"]) + y, a = a(jnp.array(1.0)) + cache = retrieve_cache(a, ["hook_mid"]) assert jnp.allclose(y, 4.0) assert "hook_mid" in cache assert jnp.allclose(cache["hook_mid"], 2.0) + + +if __name__ == "__main__": + test_with_cache()