
chore: add strict type annotation
dest1n1s committed Nov 5, 2024 · 1 parent e722955 · commit 441c0f4
Showing 20 changed files with 53 additions and 39 deletions.
9 changes: 7 additions & 2 deletions pyproject.toml
@@ -132,5 +132,10 @@ docstring-code-line-length = "dynamic"
 
 
 [tool.pyright]
-typeCheckingMode = "standard"
-reportUnknownMemberType = false
+typeCheckingMode = "strict"
+reportUnknownMemberType = false
+reportUntypedFunctionDecorator = false
+reportUnknownArgumentType = false
+reportUnknownVariableType = false
+reportMissingTypeStubs = false
+reportConstantRedefinition = false
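
Note: switching pyright from "standard" to "strict" is what drives every annotation added in this commit; strict mode reports any variable or parameter whose type is only partially known. The relaxed rules above stay off because much of the JAX and transformers surface ships without complete type stubs. A minimal sketch of the recurring failure mode, assuming only that pyright strict is enabled:

import jax
import jax.numpy as jnp

# Under strict mode an empty literal has a partially unknown type:
#     state_dict = {}   # error: type of "state_dict" is partially unknown
# An explicit annotation resolves it, which is the repeated change in the
# converter files below:
state_dict: dict[str, jax.Array] = {}
state_dict["embed.W_E"] = jnp.zeros((8, 4))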
4 changes: 2 additions & 2 deletions src/xlens/components/attention.py
@@ -323,7 +323,7 @@ def calculate_sin_cos_rotary(
         rotary_dim: int,
         n_ctx: int,
         base: int = 10000,
-        dtype=jnp.float32,
+        dtype: jnp.dtype = jnp.float32,  # type: ignore
         use_NTK_by_parts_rope: bool = False,
         NTK_by_parts_factor: float = 8.0,
         NTK_by_parts_low_freq_factor: float = 1.0,
@@ -368,7 +368,7 @@ def calculate_sin_cos_rotary(
     def apply_rotary(
         self,
         x: Float[jax.Array, "batch pos head_index d_head"],
-        past_kv_pos_offset=0,
+        past_kv_pos_offset: int = 0,
         attention_mask: Optional[jnp.ndarray] = None,
         rotary_dim: int = 64,
     ) -> jnp.ndarray:
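
Note: the type: ignore on the dtype default is needed because jnp.float32 is a scalar type (a class), not an instance of jnp.dtype, so the annotated default does not type-check even though it is the conventional JAX spelling. A quick illustration (assumption: jnp.dtype aliases numpy's dtype class, as jax.numpy mirrors the numpy namespace):

import jax.numpy as jnp
import numpy as np

print(isinstance(jnp.float32, np.dtype))  # False: a scalar type, not a dtype instance
print(np.dtype(jnp.float32))              # float32: the coerced dtype instance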
2 changes: 1 addition & 1 deletion src/xlens/components/transformer_block.py
@@ -52,7 +52,7 @@ def __init__(self, cfg: HookedTransformerConfig, block_index: int):
         elif cfg.normalization_type is None:
             # This should just be the identity.
             # We need to make this a lambda so we can call it on the config, just like the others
-            def normalization_layer(cfg):
+            def normalization_layer(cfg: HookedTransformerConfig):
                 def identity(x: jax.Array):
                     return x
 
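
Note: the identity fallback keeps the same factory signature as the real normalization layers, so the caller can invoke whichever variant was selected on the config; under strict mode the cfg parameter must be annotated even though it is unused. A self-contained sketch of the pattern (the config stub and the return of identity are assumptions for illustration):

from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp

@dataclass
class HookedTransformerConfig:  # stub for illustration only
    normalization_type: Optional[str] = None

def normalization_layer(cfg: HookedTransformerConfig):
    def identity(x: jax.Array):
        return x

    return identity

ln = normalization_layer(HookedTransformerConfig())
print(ln(jnp.ones(3)))  # [1. 1. 1.]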
4 changes: 2 additions & 2 deletions src/xlens/hooked_transformer.py
@@ -90,7 +90,7 @@ def __call__(
         pos_embed = self.hook_pos_embed(self.pos_embed(tokens, 0, attention_mask))  # [batch, pos, d_model]
         residual = embed + pos_embed
 
-        for i, block in list(zip(range(self.cfg.n_layers), self.blocks)):
+        for _, block in list(zip(range(self.cfg.n_layers), self.blocks)):
             # Note that each block includes skip connections, so we don't need
             # residual + block(residual)
             residual = block(
@@ -157,7 +157,7 @@ def run_with_cache(
         return out, cache
 
     @classmethod
-    def from_pretrained(cls, model_name: str, hf_model=None) -> "HookedTransformer":
+    def from_pretrained(cls, model_name: str, hf_model: Any = None) -> "HookedTransformer":
         """Load a pretrained model.
 
         Args:
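
Note: renaming the loop index to _ satisfies strict mode's unused-variable reporting, and the in-code comment explains why the loop is a bare fold: each block applies its own skip connections internally. The residual-stream threading in miniature (stand-in blocks, shapes assumed):

import jax.numpy as jnp

residual = jnp.zeros((1, 4, 8))  # [batch, pos, d_model]
blocks = [lambda x: x + 1.0, lambda x: x + 1.0]  # stand-ins for TransformerBlock
for _, block in zip(range(len(blocks)), blocks):
    residual = block(residual)  # no residual + block(residual) needed
print(residual[0, 0, 0])  # 2.0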
2 changes: 1 addition & 1 deletion src/xlens/hooks/utilities.py
@@ -42,7 +42,7 @@ def with_cache(tree: U, hook_names: list[str] = []) -> tuple[U, dict[str, Any]]:
     cache = {}
 
     def hook_fn(name: str):
-        def _hook_fn(x):
+        def _hook_fn(x: Any):
             cache[name] = x
             return x
 
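
Note: this closure is the whole caching mechanism; each hooked site gets a function that records its input under the hook's name and passes it through unchanged. Isolated and runnable (simplified; the real with_cache re-wires the HookPoints inside the module tree):

from typing import Any

cache: dict[str, Any] = {}

def hook_fn(name: str):
    def _hook_fn(x: Any):
        cache[name] = x
        return x

    return _hook_fn

f = hook_fn("blocks.0.hook_mid")
print(f(42))   # 42, passed through unchanged
print(cache)   # {'blocks.0.hook_mid': 42}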
4 changes: 3 additions & 1 deletion src/xlens/pretrained/convert.py
@@ -3,6 +3,8 @@
 This module contains functions for loading pretrained models from the Hugging Face Hub.
 """
 
+from typing import Any
+
 import jax
 
 from xlens.config import HookedTransformerConfig
@@ -30,5 +32,5 @@ def get_pretrained_model_config(model_name: str) -> HookedTransformerConfig:
     return converter.get_pretrained_model_config(model_name)
 
 
-def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model=None) -> dict[str, jax.Array]:
+def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model: Any = None) -> dict[str, jax.Array]:
     return converter.get_pretrained_weights(cfg, model_name, hf_model=hf_model)
2 changes: 1 addition & 1 deletion src/xlens/pretrained/converters/gpt2.py
@@ -41,7 +41,7 @@ def convert_hf_weights(
         hf_weights = {f"transformer.{k}": v for k, v in hf_weights.items()} | {
             "lm_head.weight": hf_weights["wte.weight"]
         }
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["transformer.wte.weight"]
         assert state_dict["embed.W_E"].shape == (cfg.d_vocab, cfg.d_model)
4 changes: 2 additions & 2 deletions src/xlens/pretrained/converters/llama.py
@@ -63,7 +63,7 @@ def __init__(self):
 
     def convert_hf_model_config(self, hf_cfg: Any) -> HookedTransformerConfig:
         if hasattr(hf_cfg, "rope_scaling") and hf_cfg.rope_scaling is not None:
-            ntk_cfg = {
+            ntk_cfg: dict[str, Any] = {
                 "use_NTK_by_parts_rope": True,
                 "NTK_by_parts_low_freq_factor": hf_cfg.rope_scaling["low_freq_factor"],
                 "NTK_by_parts_high_freq_factor": hf_cfg.rope_scaling["high_freq_factor"],
@@ -102,7 +102,7 @@ def convert_hf_weights(
         hf_weights = {f"model.{k}": v for k, v in hf_weights.items()}
         if "lm_head.weight" not in hf_weights:
             hf_weights = {**hf_weights, "lm_head.weight": hf_weights["model.embed_tokens.weight"]}
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["model.embed_tokens.weight"]
 
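
Note: ntk_cfg needs dict[str, Any] because the dict mixes bools and floats pulled from hf_cfg.rope_scaling. For orientation, this is roughly how a Llama 3.x rope_scaling block feeds the NTK-by-parts parameters consumed by calculate_sin_cos_rotary in attention.py (the values are illustrative defaults, and the "factor" mapping is assumed from the parameter names rather than shown in this hunk):

from typing import Any

rope_scaling = {  # typical hf_cfg.rope_scaling contents for Llama 3.x
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
}
ntk_cfg: dict[str, Any] = {
    "use_NTK_by_parts_rope": True,
    "NTK_by_parts_low_freq_factor": rope_scaling["low_freq_factor"],
    "NTK_by_parts_high_freq_factor": rope_scaling["high_freq_factor"],
    "NTK_by_parts_factor": rope_scaling["factor"],  # assumed mapping
}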
2 changes: 1 addition & 1 deletion src/xlens/pretrained/converters/mistral.py
@@ -56,7 +56,7 @@ def convert_hf_weights(
         hf_weights = {f"model.{k}": v for k, v in hf_weights.items()}
         if "lm_head.weight" not in hf_weights:
             hf_weights = {**hf_weights, "lm_head.weight": hf_weights["model.embed_tokens.weight"]}
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["model.embed_tokens.weight"]
 
2 changes: 1 addition & 1 deletion src/xlens/pretrained/converters/neox.py
@@ -93,7 +93,7 @@ def convert_hf_weights(
         if "embed_out.weight" not in hf_weights:
             hf_weights = {**hf_weights, "embed_out.weight": hf_weights["gpt_neox.embed_in.weight"]}
 
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["gpt_neox.embed_in.weight"]
 
2 changes: 1 addition & 1 deletion src/xlens/pretrained/converters/qwen2.py
@@ -73,7 +73,7 @@ def convert_hf_weights(
         hf_weights = {f"model.{k}": v for k, v in hf_weights.items()}
         if "lm_head.weight" not in hf_weights:
             hf_weights = {**hf_weights, "lm_head.weight": hf_weights["model.embed_tokens.weight"]}
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["model.embed_tokens.weight"]
 
20 changes: 10 additions & 10 deletions src/xlens/pretrained/model_converter.py
@@ -2,7 +2,7 @@
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Any, Optional
 
 import jax
 import jax.numpy as jnp
@@ -21,7 +21,7 @@ def can_convert(self, model_name_or_path: str) -> bool:
         pass
 
     @abstractmethod
-    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
+    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs: Any) -> HookedTransformerConfig:
         """Get the model configuration for the given model name.
 
         Args:
@@ -35,7 +35,7 @@ def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
 
     @abstractmethod
     def get_pretrained_weights(
-        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs
+        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
     ) -> dict[str, jax.Array]:
         """Get the pretrained weights for the given model.
 
@@ -71,7 +71,7 @@ def can_convert(self, model_name_or_path: str) -> bool:
         if os.path.isdir(model_name_or_path):
             if os.path.exists(os.path.join(model_name_or_path, "config.json")):
                 hf_cfg = AutoConfig.from_pretrained(model_name_or_path, token=True)
-                architecture = hf_cfg.architectures[0]
+                architecture: str = hf_cfg.architectures[0]
                 return architecture == self.model_architecture
             else:
                 return False
@@ -90,7 +90,7 @@ def convert_hf_model_config(self, hf_cfg: AutoConfig) -> HookedTransformerConfig:
         """
         pass
 
-    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
+    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs: Any) -> HookedTransformerConfig:
         model_name_or_path = (
             model_name_or_path if os.path.isdir(model_name_or_path) else self.rev_alias_map[model_name_or_path]
         )
@@ -110,7 +110,7 @@ def convert_hf_weights(
         pass
 
     def get_pretrained_weights(
-        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs
+        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
     ) -> dict[str, jax.Array]:
         if os.path.isdir(model_name_or_path):
             if os.path.isfile(os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)):
@@ -128,7 +128,7 @@ def get_pretrained_weights(
                 from transformers import AutoModelForCausalLM
 
                 hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, token=True, **kwargs)
-                params = {k: jnp.array(v) for k, v in flatten_dict(hf_model.state_dict()).items()}
+                params: dict[str, jax.Array] = {k: jnp.array(v) for k, v in flatten_dict(hf_model.state_dict()).items()}
             else:
                 params = safe_load_file(resolved_archive_file)
         return self.convert_hf_weights(params, cfg)
@@ -175,14 +175,14 @@ def model_architectures(self) -> list[str]:
     def can_convert(self, model_name_or_path: str) -> bool:
         if os.path.isdir(model_name_or_path):
             if os.path.exists(os.path.join(model_name_or_path, "config.json")):
-                architecture = AutoConfig.from_pretrained(model_name_or_path, token=True).architectures[0]
+                architecture: Any = AutoConfig.from_pretrained(model_name_or_path, token=True).architectures[0]
                 return architecture in self.model_architectures
             else:
                 return False
         else:
             return model_name_or_path in self.rev_alias_map
 
-    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
+    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs: Any) -> HookedTransformerConfig:
         if os.path.isdir(model_name_or_path):
             hf_cfg = AutoConfig.from_pretrained(model_name_or_path, token=True)
             architecture = hf_cfg.architectures[0]
@@ -194,7 +194,7 @@ def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
         return self.name_converter_map[model_name_or_path].get_pretrained_model_config(model_name_or_path, **kwargs)
 
     def get_pretrained_weights(
-        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs
+        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
     ) -> dict[str, jax.Array]:
         if cfg.original_architecture in self.architecture_converter_map:
             return self.architecture_converter_map[cfg.original_architecture].get_pretrained_weights(
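
Note: the params annotation in the AutoModelForCausalLM branch pins down the shape of the conversion step: a flat name-to-tensor mapping in which every torch tensor becomes a jax.Array. In isolation (assuming CPU tensors, which jnp.array can consume via the numpy protocol):

import jax
import jax.numpy as jnp
import torch

state = {"wte.weight": torch.zeros(4, 2)}  # stand-in for hf_model.state_dict()
params: dict[str, jax.Array] = {k: jnp.array(v) for k, v in state.items()}
print(params["wte.weight"].shape)  # (4, 2)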
15 changes: 9 additions & 6 deletions src/xlens/utils.py
@@ -1,12 +1,12 @@
-from typing import Any, Hashable, TypeVar
+from typing import Any, Hashable, TypeVar, cast
 
 import jax
 
 T = TypeVar("T")
 U = TypeVar("U")
 
 
-def transformer_lens_compatible_path_str(key_path: jax.tree_util.KeyPath) -> str:
+def transformer_lens_compatible_path_str(key_path: tuple[Hashable, ...]) -> str:
     def _transform_key_entry(entry: Hashable) -> str:
         if isinstance(entry, jax.tree_util.SequenceKey):
             return str(entry.idx)
@@ -40,8 +40,9 @@ def get_nested_component(
         tree,
         is_leaf=None if component_type is None else lambda x: isinstance(x, component_type),
     )
+    flattened = cast(list[tuple[tuple[Hashable, ...], Any]], flattened)
 
-    def filter_path(key_path: jax.tree_util.KeyPath):
+    def filter_path(key_path: tuple[Hashable, ...]):
         return path in [jax.tree_util.keystr(key_path)] + (
             [transformer_lens_compatible_path_str(key_path)] if transformer_lens_compatible else []
         )
@@ -72,8 +73,9 @@ def set_nested_component(
         tree,
         is_leaf=None if component_type is None else lambda x: isinstance(x, component_type),
     )
+    flattened = cast(list[tuple[tuple[Hashable, ...], Any]], flattened)
 
-    def filter_path(key_path: jax.tree_util.KeyPath):
+    def filter_path(key_path: tuple[Hashable, ...]):
         return path in [jax.tree_util.keystr(key_path)] + (
             [transformer_lens_compatible_path_str(key_path)] if transformer_lens_compatible else []
         )
@@ -97,6 +99,7 @@ def load_pretrained_weights(
     """
 
     flattened, tree_def = jax.tree_util.tree_flatten_with_path(model)
+    flattened = cast(list[tuple[tuple[Hashable, ...], Any]], flattened)
 
     res = [
         pretrained_weights.get(transformer_lens_compatible_path_str(key_path), x)
@@ -108,8 +111,8 @@
     return jax.tree_util.tree_unflatten(tree_def, res)
 
 
-def flatten_dict(d, parent_key="", sep="."):
-    items = []
+def flatten_dict(d: dict[str, Any], parent_key: str = "", sep: str = ".") -> dict[str, Any]:
+    items: list[tuple[str, Any]] = []
     for k, v in d.items():
         new_key = parent_key + sep + k if parent_key else k
         if isinstance(v, dict):
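
Note: two themes in this file. First, key paths from jax.tree_util.tree_flatten_with_path are widened to tuple[Hashable, ...] with an explicit cast, since pyright cannot see a precise element type for them. Second, flatten_dict gains a full signature. Because the tail of flatten_dict is collapsed in the diff, here is a runnable reconstruction, with the recursive branch assumed from the visible prefix:

from typing import Any

def flatten_dict(d: dict[str, Any], parent_key: str = "", sep: str = ".") -> dict[str, Any]:
    items: list[tuple[str, Any]] = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())  # assumed branch
        else:
            items.append((new_key, v))
    return dict(items)

print(flatten_dict({"a": {"b": 1, "c": {"d": 2}}}))
# {'a.b': 1, 'a.c.d': 2}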
2 changes: 1 addition & 1 deletion tests/acceptance/computation/test_gpt2_computation.py
@@ -18,7 +18,7 @@ def test_gpt2_computation():
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
     hf_logits = hf_model(hf_input).logits
 
     del hf_model
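
Note: this annotation (repeated in the llama, mistral, and pythia tests below) is needed because indexing a transformers BatchEncoding returns an untyped value, which strict mode would propagate as Unknown into every downstream line; pinning it to torch.Tensor (or Any, as the qwen2 test does) restores a known type. In isolation:

import torch
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# The subscript result is not typed as a tensor, hence the annotation:
hf_input: torch.Tensor = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
print(hf_input.shape)  # a [1, seq_len] batch of token ids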
2 changes: 1 addition & 1 deletion tests/acceptance/computation/test_llama_computation.py
@@ -18,7 +18,7 @@ def test_llama_computation():
     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input)
     hf_logits = hf_output.logits
 
2 changes: 1 addition & 1 deletion tests/acceptance/computation/test_mistral_computation.py
@@ -20,7 +20,7 @@ def test_mistral_computation():
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input, output_hidden_states=True)
     hf_logits = hf_output.logits
     hf_hidden_states = hf_output.hidden_states
2 changes: 1 addition & 1 deletion tests/acceptance/computation/test_pythia_computation.py
@@ -20,7 +20,7 @@ def test_pythia_computation():
     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input)
     hf_logits = hf_output.logits
 
4 changes: 3 additions & 1 deletion tests/acceptance/computation/test_qwen2_computation.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import jax
 import jax.numpy as jnp
 import pytest
@@ -20,7 +22,7 @@ def test_qwen2_computation():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: Any = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input)
     hf_logits = hf_output.logits
 
5 changes: 3 additions & 2 deletions tests/unit/test_utils.py
@@ -1,19 +1,20 @@
 import equinox as eqx
+import jax
 
 from xlens import HookPoint, get_nested_component, set_nested_component
 
 
 class ModuleA(eqx.Module):
     hook_point: HookPoint
 
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         return self.hook_point(x)
 
 
 class ModuleB(eqx.Module):
     module_as: list[ModuleA]
 
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         for module_a in self.module_as:
             x = module_a(x)
         return x
3 changes: 2 additions & 1 deletion tests/unit/test_with_cache.py
@@ -1,4 +1,5 @@
 import equinox as eqx
+import jax
 import jax.numpy as jnp
 
 from xlens import HookPoint, with_cache
@@ -7,7 +8,7 @@
 class ModuleA(eqx.Module):
     hook_mid: HookPoint
 
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         return self.hook_mid(x * 2) * 2
 
 
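
Note: putting the pieces of this diff together, a usage sketch for the module above, based on the with_cache signature shown in src/xlens/hooks/utilities.py (assumptions: HookPoint takes no constructor arguments and cache keys match attribute paths; both are plausible from the tests but not fully shown here):

import jax.numpy as jnp

module = ModuleA(hook_mid=HookPoint())
module, cache = with_cache(module, hook_names=["hook_mid"])
y = module(jnp.ones(2))   # hook_mid sees x * 2 and records it
print(y)                  # [4. 4.]
print(cache["hook_mid"])  # [2. 2.]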
