From 441c0f4128fc8cdb67a9a86a8c494ebaccb58ce4 Mon Sep 17 00:00:00 2001
From: Dest1n1
Date: Tue, 5 Nov 2024 16:29:39 +0800
Subject: [PATCH] chore: add strict type annotation

---
 pyproject.toml                              |  9 +++++++--
 src/xlens/components/attention.py           |  4 ++--
 src/xlens/components/transformer_block.py   |  2 +-
 src/xlens/hooked_transformer.py             |  4 ++--
 src/xlens/hooks/utilities.py                |  2 +-
 src/xlens/pretrained/convert.py             |  4 +++-
 src/xlens/pretrained/converters/gpt2.py     |  2 +-
 src/xlens/pretrained/converters/llama.py    |  4 ++--
 src/xlens/pretrained/converters/mistral.py  |  2 +-
 src/xlens/pretrained/converters/neox.py     |  2 +-
 src/xlens/pretrained/converters/qwen2.py    |  2 +-
 src/xlens/pretrained/model_converter.py     | 20 +++++++++----------
 src/xlens/utils.py                          | 15 ++++++++------
 .../computation/test_gpt2_computation.py    |  2 +-
 .../computation/test_llama_computation.py   |  2 +-
 .../computation/test_mistral_computation.py |  2 +-
 .../computation/test_pythia_computation.py  |  2 +-
 .../computation/test_qwen2_computation.py   |  4 +++-
 tests/unit/test_utils.py                    |  5 +++--
 tests/unit/test_with_cache.py               |  3 ++-
 20 files changed, 53 insertions(+), 39 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 86c4654..a8e91e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -132,5 +132,10 @@ docstring-code-line-length = "dynamic"
 
 
 [tool.pyright]
-typeCheckingMode = "standard"
-reportUnknownMemberType = false
\ No newline at end of file
+typeCheckingMode = "strict"
+reportUnknownMemberType = false
+reportUntypedFunctionDecorator = false
+reportUnknownArgumentType = false
+reportUnknownVariableType = false
+reportMissingTypeStubs = false
+reportConstantRedefinition = false
\ No newline at end of file
diff --git a/src/xlens/components/attention.py b/src/xlens/components/attention.py
index 2ca2c2f..1c0edad 100644
--- a/src/xlens/components/attention.py
+++ b/src/xlens/components/attention.py
@@ -323,7 +323,7 @@ def calculate_sin_cos_rotary(
         rotary_dim: int,
         n_ctx: int,
         base: int = 10000,
-        dtype=jnp.float32,
+        dtype: jnp.dtype = jnp.float32,  # type: ignore
         use_NTK_by_parts_rope: bool = False,
         NTK_by_parts_factor: float = 8.0,
         NTK_by_parts_low_freq_factor: float = 1.0,
@@ -368,7 +368,7 @@ def calculate_sin_cos_rotary(
     def apply_rotary(
         self,
         x: Float[jax.Array, "batch pos head_index d_head"],
-        past_kv_pos_offset=0,
+        past_kv_pos_offset: int = 0,
         attention_mask: Optional[jnp.ndarray] = None,
         rotary_dim: int = 64,
     ) -> jnp.ndarray:
diff --git a/src/xlens/components/transformer_block.py b/src/xlens/components/transformer_block.py
index 5f85984..00c0b9e 100644
--- a/src/xlens/components/transformer_block.py
+++ b/src/xlens/components/transformer_block.py
@@ -52,7 +52,7 @@ def __init__(self, cfg: HookedTransformerConfig, block_index: int):
         elif cfg.normalization_type is None:
             # This should just be the identity.
             # We need to make this a lambda so we can call it on the config, just like the others
-            def normalization_layer(cfg):
+            def normalization_layer(cfg: HookedTransformerConfig):
                 def identity(x: jax.Array):
                     return x
 
diff --git a/src/xlens/hooked_transformer.py b/src/xlens/hooked_transformer.py
index 4aebf17..62543d0 100644
--- a/src/xlens/hooked_transformer.py
+++ b/src/xlens/hooked_transformer.py
@@ -90,7 +90,7 @@ def __call__(
         pos_embed = self.hook_pos_embed(self.pos_embed(tokens, 0, attention_mask))  # [batch, pos, d_model]
         residual = embed + pos_embed
 
-        for i, block in list(zip(range(self.cfg.n_layers), self.blocks)):
+        for _, block in list(zip(range(self.cfg.n_layers), self.blocks)):
            # Note that each block includes skip connections, so we don't need
            # residual + block(residual)
            residual = block(
@@ -157,7 +157,7 @@ def run_with_cache(
         return out, cache
 
     @classmethod
-    def from_pretrained(cls, model_name: str, hf_model=None) -> "HookedTransformer":
+    def from_pretrained(cls, model_name: str, hf_model: Any = None) -> "HookedTransformer":
         """Load a pretrained model.
 
         Args:
diff --git a/src/xlens/hooks/utilities.py b/src/xlens/hooks/utilities.py
index ec710f1..b59acad 100644
--- a/src/xlens/hooks/utilities.py
+++ b/src/xlens/hooks/utilities.py
@@ -42,7 +42,7 @@ def with_cache(tree: U, hook_names: list[str] = []) -> tuple[U, dict[str, Any]]:
     cache = {}
 
     def hook_fn(name: str):
-        def _hook_fn(x):
+        def _hook_fn(x: Any):
             cache[name] = x
             return x
 
diff --git a/src/xlens/pretrained/convert.py b/src/xlens/pretrained/convert.py
index dfac70f..b2b7185 100644
--- a/src/xlens/pretrained/convert.py
+++ b/src/xlens/pretrained/convert.py
@@ -3,6 +3,8 @@
 This module contains functions for loading pretrained models from the Hugging Face Hub.
 """
 
+from typing import Any
+
 import jax
 
 from xlens.config import HookedTransformerConfig
@@ -30,5 +32,5 @@ def get_pretrained_model_config(model_name: str) -> HookedTransformerConfig:
     return converter.get_pretrained_model_config(model_name)
 
 
-def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model=None) -> dict[str, jax.Array]:
+def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model: Any = None) -> dict[str, jax.Array]:
     return converter.get_pretrained_weights(cfg, model_name, hf_model=hf_model)
diff --git a/src/xlens/pretrained/converters/gpt2.py b/src/xlens/pretrained/converters/gpt2.py
index 8b7ed7a..97d95b8 100644
--- a/src/xlens/pretrained/converters/gpt2.py
+++ b/src/xlens/pretrained/converters/gpt2.py
@@ -41,7 +41,7 @@ def convert_hf_weights(
         hf_weights = {f"transformer.{k}": v for k, v in hf_weights.items()} | {
             "lm_head.weight": hf_weights["wte.weight"]
         }
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["transformer.wte.weight"]
         assert state_dict["embed.W_E"].shape == (cfg.d_vocab, cfg.d_model)
diff --git a/src/xlens/pretrained/converters/llama.py b/src/xlens/pretrained/converters/llama.py
index e46e4a8..2a07fa3 100644
--- a/src/xlens/pretrained/converters/llama.py
+++ b/src/xlens/pretrained/converters/llama.py
@@ -63,7 +63,7 @@ def __init__(self):
 
     def convert_hf_model_config(self, hf_cfg: Any) -> HookedTransformerConfig:
         if hasattr(hf_cfg, "rope_scaling") and hf_cfg.rope_scaling is not None:
-            ntk_cfg = {
+            ntk_cfg: dict[str, Any] = {
                 "use_NTK_by_parts_rope": True,
                 "NTK_by_parts_low_freq_factor": hf_cfg.rope_scaling["low_freq_factor"],
                 "NTK_by_parts_high_freq_factor": hf_cfg.rope_scaling["high_freq_factor"],
@@ -102,7 +102,7 @@ def convert_hf_weights(
         hf_weights = {f"model.{k}": v for k, v in hf_weights.items()}
         if "lm_head.weight" not in hf_weights:
             hf_weights = {**hf_weights, "lm_head.weight": hf_weights["model.embed_tokens.weight"]}
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["model.embed_tokens.weight"]
 
diff --git a/src/xlens/pretrained/converters/mistral.py b/src/xlens/pretrained/converters/mistral.py
index d303bcf..59c07eb 100644
--- a/src/xlens/pretrained/converters/mistral.py
+++ b/src/xlens/pretrained/converters/mistral.py
@@ -56,7 +56,7 @@ def convert_hf_weights(
         hf_weights = {f"model.{k}": v for k, v in hf_weights.items()}
         if "lm_head.weight" not in hf_weights:
             hf_weights = {**hf_weights, "lm_head.weight": hf_weights["model.embed_tokens.weight"]}
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["model.embed_tokens.weight"]
 
diff --git a/src/xlens/pretrained/converters/neox.py b/src/xlens/pretrained/converters/neox.py
index 237af37..aedc288 100644
--- a/src/xlens/pretrained/converters/neox.py
+++ b/src/xlens/pretrained/converters/neox.py
@@ -93,7 +93,7 @@ def convert_hf_weights(
         if "embed_out.weight" not in hf_weights:
             hf_weights = {**hf_weights, "embed_out.weight": hf_weights["gpt_neox.embed_in.weight"]}
 
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["gpt_neox.embed_in.weight"]
 
diff --git a/src/xlens/pretrained/converters/qwen2.py b/src/xlens/pretrained/converters/qwen2.py
index 26aaa06..ee0c83c 100644
--- a/src/xlens/pretrained/converters/qwen2.py
+++ b/src/xlens/pretrained/converters/qwen2.py
@@ -73,7 +73,7 @@ def convert_hf_weights(
         hf_weights = {f"model.{k}": v for k, v in hf_weights.items()}
         if "lm_head.weight" not in hf_weights:
             hf_weights = {**hf_weights, "lm_head.weight": hf_weights["model.embed_tokens.weight"]}
-        state_dict = {}
+        state_dict: dict[str, jax.Array] = {}
 
         state_dict["embed.W_E"] = hf_weights["model.embed_tokens.weight"]
 
diff --git a/src/xlens/pretrained/model_converter.py b/src/xlens/pretrained/model_converter.py
index 0594de7..2e06b01 100644
--- a/src/xlens/pretrained/model_converter.py
+++ b/src/xlens/pretrained/model_converter.py
@@ -2,7 +2,7 @@
 import logging
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Any, Optional
 
 import jax
 import jax.numpy as jnp
@@ -21,7 +21,7 @@ def can_convert(self, model_name_or_path: str) -> bool:
         pass
 
     @abstractmethod
-    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
+    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs: Any) -> HookedTransformerConfig:
         """Get the model configuration for the given model name.
 
         Args:
@@ -35,7 +35,7 @@ def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> Hook
 
     @abstractmethod
     def get_pretrained_weights(
-        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs
+        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
     ) -> dict[str, jax.Array]:
         """Get the pretrained weights for the given model.
 
@@ -71,7 +71,7 @@ def can_convert(self, model_name_or_path: str) -> bool:
         if os.path.isdir(model_name_or_path):
             if os.path.exists(os.path.join(model_name_or_path, "config.json")):
                 hf_cfg = AutoConfig.from_pretrained(model_name_or_path, token=True)
-                architecture = hf_cfg.architectures[0]
+                architecture: str = hf_cfg.architectures[0]
                 return architecture == self.model_architecture
             else:
                 return False
@@ -90,7 +90,7 @@ def convert_hf_model_config(self, hf_cfg: AutoConfig) -> HookedTransformerConfig
         """
         pass
 
-    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
+    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs: Any) -> HookedTransformerConfig:
         model_name_or_path = (
             model_name_or_path if os.path.isdir(model_name_or_path) else self.rev_alias_map[model_name_or_path]
         )
@@ -110,7 +110,7 @@ def convert_hf_weights(
         pass
 
     def get_pretrained_weights(
-        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs
+        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
     ) -> dict[str, jax.Array]:
         if os.path.isdir(model_name_or_path):
             if os.path.isfile(os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)):
@@ -128,7 +128,7 @@ def get_pretrained_weights(
                 from transformers import AutoModelForCausalLM
 
                 hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path, token=True, **kwargs)
-                params = {k: jnp.array(v) for k, v in flatten_dict(hf_model.state_dict()).items()}
+                params: dict[str, jax.Array] = {k: jnp.array(v) for k, v in flatten_dict(hf_model.state_dict()).items()}
             else:
                 params = safe_load_file(resolved_archive_file)
             return self.convert_hf_weights(params, cfg)
@@ -175,14 +175,14 @@ def model_architectures(self) -> list[str]:
     def can_convert(self, model_name_or_path: str) -> bool:
         if os.path.isdir(model_name_or_path):
             if os.path.exists(os.path.join(model_name_or_path, "config.json")):
-                architecture = AutoConfig.from_pretrained(model_name_or_path, token=True).architectures[0]
+                architecture: Any = AutoConfig.from_pretrained(model_name_or_path, token=True).architectures[0]
                 return architecture in self.model_architectures
             else:
                 return False
         else:
             return model_name_or_path in self.rev_alias_map
 
-    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> HookedTransformerConfig:
+    def get_pretrained_model_config(self, model_name_or_path: str, **kwargs: Any) -> HookedTransformerConfig:
         if os.path.isdir(model_name_or_path):
             hf_cfg = AutoConfig.from_pretrained(model_name_or_path, token=True)
             architecture = hf_cfg.architectures[0]
@@ -194,7 +194,7 @@ def get_pretrained_model_config(self, model_name_or_path: str, **kwargs) -> Hook
         return self.name_converter_map[model_name_or_path].get_pretrained_model_config(model_name_or_path, **kwargs)
 
     def get_pretrained_weights(
-        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs
+        self, cfg: HookedTransformerConfig, model_name_or_path: str, **kwargs: Any
     ) -> dict[str, jax.Array]:
         if cfg.original_architecture in self.architecture_converter_map:
             return self.architecture_converter_map[cfg.original_architecture].get_pretrained_weights(
diff --git a/src/xlens/utils.py b/src/xlens/utils.py
index 5e0ef2c..c101bae 100644
--- a/src/xlens/utils.py
+++ b/src/xlens/utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Hashable, TypeVar
+from typing import Any, Hashable, TypeVar, cast
 
 import jax
 
@@ -6,7 +6,7 @@
 U = TypeVar("U")
 
 
-def transformer_lens_compatible_path_str(key_path: jax.tree_util.KeyPath) -> str:
+def transformer_lens_compatible_path_str(key_path: tuple[Hashable, ...]) -> str:
     def _transform_key_entry(entry: Hashable) -> str:
         if isinstance(entry, jax.tree_util.SequenceKey):
             return str(entry.idx)
@@ -40,8 +40,9 @@ def get_nested_component(
         tree,
         is_leaf=None if component_type is None else lambda x: isinstance(x, component_type),
     )
+    flattened = cast(list[tuple[tuple[Hashable, ...], Any]], flattened)
 
-    def filter_path(key_path: jax.tree_util.KeyPath):
+    def filter_path(key_path: tuple[Hashable, ...]):
         return path in [jax.tree_util.keystr(key_path)] + (
             [transformer_lens_compatible_path_str(key_path)] if transformer_lens_compatible else []
         )
@@ -72,8 +73,9 @@ def set_nested_component(
         tree,
         is_leaf=None if component_type is None else lambda x: isinstance(x, component_type),
     )
+    flattened = cast(list[tuple[tuple[Hashable, ...], Any]], flattened)
 
-    def filter_path(key_path: jax.tree_util.KeyPath):
+    def filter_path(key_path: tuple[Hashable, ...]):
         return path in [jax.tree_util.keystr(key_path)] + (
             [transformer_lens_compatible_path_str(key_path)] if transformer_lens_compatible else []
         )
@@ -97,6 +99,7 @@ def load_pretrained_weights(
     """
 
     flattened, tree_def = jax.tree_util.tree_flatten_with_path(model)
+    flattened = cast(list[tuple[tuple[Hashable, ...], Any]], flattened)
 
     res = [
         pretrained_weights.get(transformer_lens_compatible_path_str(key_path), x)
@@ -108,8 +111,8 @@
     return jax.tree_util.tree_unflatten(tree_def, res)
 
 
-def flatten_dict(d, parent_key="", sep="."):
-    items = []
+def flatten_dict(d: dict[str, Any], parent_key: str = "", sep: str = ".") -> dict[str, Any]:
+    items: list[tuple[str, Any]] = []
     for k, v in d.items():
         new_key = parent_key + sep + k if parent_key else k
         if isinstance(v, dict):
diff --git a/tests/acceptance/computation/test_gpt2_computation.py b/tests/acceptance/computation/test_gpt2_computation.py
index 05ad0c1..99117eb 100644
--- a/tests/acceptance/computation/test_gpt2_computation.py
+++ b/tests/acceptance/computation/test_gpt2_computation.py
@@ -18,7 +18,7 @@ def test_gpt2_computation():
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
     hf_logits = hf_model(hf_input).logits
 
     del hf_model
diff --git a/tests/acceptance/computation/test_llama_computation.py b/tests/acceptance/computation/test_llama_computation.py
index 2694591..520adab 100644
--- a/tests/acceptance/computation/test_llama_computation.py
+++ b/tests/acceptance/computation/test_llama_computation.py
@@ -18,7 +18,7 @@ def test_llama_computation():
     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input)
     hf_logits = hf_output.logits
 
diff --git a/tests/acceptance/computation/test_mistral_computation.py b/tests/acceptance/computation/test_mistral_computation.py
index 2217814..0293882 100644
--- a/tests/acceptance/computation/test_mistral_computation.py
+++ b/tests/acceptance/computation/test_mistral_computation.py
@@ -20,7 +20,7 @@ def test_mistral_computation():
     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input, output_hidden_states=True)
     hf_logits = hf_output.logits
     hf_hidden_states = hf_output.hidden_states
diff --git a/tests/acceptance/computation/test_pythia_computation.py b/tests/acceptance/computation/test_pythia_computation.py
index 13bc9bd..219ebd8 100644
--- a/tests/acceptance/computation/test_pythia_computation.py
+++ b/tests/acceptance/computation/test_pythia_computation.py
@@ -20,7 +20,7 @@ def test_pythia_computation():
     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: torch.Tensor = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input)
     hf_logits = hf_output.logits
 
diff --git a/tests/acceptance/computation/test_qwen2_computation.py b/tests/acceptance/computation/test_qwen2_computation.py
index 63fc29e..06f0cc8 100644
--- a/tests/acceptance/computation/test_qwen2_computation.py
+++ b/tests/acceptance/computation/test_qwen2_computation.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import jax
 import jax.numpy as jnp
 import pytest
@@ -20,7 +22,7 @@ def test_qwen2_computation():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
     hf_model.eval()
 
-    hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+    hf_input: Any = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
     hf_output = hf_model(hf_input)
     hf_logits = hf_output.logits
 
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 1cceda9..55d202b 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,4 +1,5 @@
 import equinox as eqx
+import jax
 
 from xlens import HookPoint, get_nested_component, set_nested_component
 
@@ -6,14 +7,14 @@
 class ModuleA(eqx.Module):
     hook_point: HookPoint
 
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         return self.hook_point(x)
 
 
 class ModuleB(eqx.Module):
     module_as: list[ModuleA]
 
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         for module_a in self.module_as:
             x = module_a(x)
         return x
diff --git a/tests/unit/test_with_cache.py b/tests/unit/test_with_cache.py
index 85f3f5a..bb08890 100644
--- a/tests/unit/test_with_cache.py
+++ b/tests/unit/test_with_cache.py
@@ -1,4 +1,5 @@
 import equinox as eqx
+import jax
 import jax.numpy as jnp
 
 from xlens import HookPoint, with_cache
@@ -7,7 +8,7 @@
 class ModuleA(eqx.Module):
     hook_mid: HookPoint
 
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         return self.hook_mid(x * 2) * 2
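
Note on the pattern: nearly every hunk above is one of three strict-mode fixes.
Parameters that previously defaulted to an implicit Any get an explicit annotation
(**kwargs: Any, hf_model: Any = None), empty containers get annotated at their
binding site (state_dict: dict[str, jax.Array] = {}), and pytree flattenings whose
element type pyright cannot infer get a cast. A minimal self-contained sketch of
the second fix, modeled on the flatten_dict change in src/xlens/utils.py; this is
illustrative only, and the recursion details beyond the visible hunk are assumed:

    from typing import Any

    def flatten_dict(d: dict[str, Any], parent_key: str = "", sep: str = ".") -> dict[str, Any]:
        # Under pyright's strict mode an empty literal has no inferable element
        # type, so the accumulator must be annotated at its binding site.
        items: list[tuple[str, Any]] = []
        for k, v in d.items():
            new_key = parent_key + sep + k if parent_key else k
            if isinstance(v, dict):
                # Nested dicts are flattened recursively into dot-separated keys.
                items.extend(flatten_dict(v, new_key, sep=sep).items())
            else:
                items.append((new_key, v))
        return dict(items)

    # {"a": {"b": 1}, "c": 2} flattens to {"a.b": 1, "c": 2}, the same
    # dot-separated form as the "model.embed_tokens.weight" keys used above.
    assert flatten_dict({"a": {"b": 1}, "c": 2}) == {"a.b": 1, "c": 2}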