From 3d5677e5b68cb0e9b05e1c42b659c1d54e2d90c7 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 16 Nov 2024 20:09:00 -0800 Subject: [PATCH] Use haliax state dict (#805) --- docs/dev/Port-Models.md | 67 +-- pyproject.toml | 4 +- src/levanter/compat/hf_checkpoints.py | 19 +- src/levanter/compat/torch_serialization.py | 474 --------------------- src/levanter/lora.py | 14 +- src/levanter/models/backpack.py | 67 +-- src/levanter/models/gemma.py | 44 +- src/levanter/models/gpt2.py | 63 +-- src/levanter/models/llama.py | 125 +----- src/levanter/models/lm_model.py | 2 +- src/levanter/models/mistral.py | 32 +- src/levanter/models/mpt.py | 67 +-- src/levanter/models/whisper.py | 89 +--- tests/test_backpack.py | 2 +- tests/test_gemma.py | 7 +- tests/test_hf_checkpoints.py | 2 +- tests/test_hf_gpt2_serialize.py | 8 +- tests/test_llama.py | 7 +- tests/test_torch_serialization.py | 37 -- 19 files changed, 115 insertions(+), 1015 deletions(-) delete mode 100644 src/levanter/compat/torch_serialization.py diff --git a/docs/dev/Port-Models.md b/docs/dev/Port-Models.md index 282f51508..f76d0a6d8 100644 --- a/docs/dev/Port-Models.md +++ b/docs/dev/Port-Models.md @@ -96,15 +96,22 @@ We follow the same breakdown in the implementation of Llama in Levanter. #### Note on the Implementation Format - Each class will have its key layers and components defined as attributes and be initialized with a static method `init()`. -- Each class will be extended from Equinox's `Module` class. +- Each class will be extended from Equinox's `Module` class, except for classes with custom serialization logic, which instead inherit +from [haliax.state_dict.ModuleWithStateDictSerialization][]. +- [hax.nn.Linear][] modules can have "articulated" input or output axes, where PyTorch and other libraries typically require +a single input and output axis. For instance, attention modules in Levanter typically have a `Linear` from `Embed` to `(Heads, HeadSize)`. +When serializing these linear modules to state dicts (see the next section), Haliax will automatically flatten them. You should +ensure that `out_first=True` is set on Linear modules if they're going to be loaded as PyTorch Linear modules. ### Serialization to/from State Dicts PyTorch and Hugging Face Transformers use "state dicts" as their preferred serialization format, either as pickles or as the new [safetensors](https://github.com/huggingface/safetensors) format. A state dict is a Python `dict` with string keys and tensor values. The keys of the dict are json-ish "key paths" like `model.blocks.0.mlp.c_proj` and the values are the corresponding parameters for that key path. -You can read more about [PyTorch State Dicts here](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html). +You can read [PyTorch's State Dict docs](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html) +if you want to learn more. -Levanter has machinery for (de)serializing to and from state dicts. Simple cases are handled automatically, but often times custom logic is needed. +[Haliax has machinery for (de)serializing to and from state dicts](https://haliax.readthedocs.io/state-dict/). +Simple cases are handled automatically, but sometimes custom logic is needed. #### Easy Case: Identical Module Structure If your module has exactly the same fields with the same names and same shapes as the Hugging Face state dict (e.g. Gpt2Mlp), you don't need to do anything. @@ -113,41 +120,45 @@ If your module has exactly the same fields with the same names and same shapes a If for some reason you want to use different names from the HF implementation (e.g. because the names from HF aren't clear...), you can extend your class from `StateDictSerializationMixin` and use `_state_dict_key_map` to rename keys. For instance, the `Gpt2Transformer` class has this method: ```python -class Gpt2Transformer(StateDictSerializationMixin, eqx.Module): +from haliax.state_dict import ModuleWithStateDictSerialization + +class Gpt2Transformer(ModuleWithStateDictSerialization): ... def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"blocks": "h"} ``` -This says that the field called `blocks` in this class should be (de)serialized as `h`, because the Hugging Face GPT-2 implementation uses `h`, which is not very clear. You can also "flatten" the submodules of a field by using `None` or even include `.s` in the name if needed. +This says that the field called `blocks` in this class should be (de)serialized as `h`, because the Hugging Face GPT-2 implementation uses `h`, which is not very clear. You can also "flatten" the submodules of a field by using `None`. #### Hard Case: Custom Serialization -If your modules need special logic, you'll need to extend your class from `StateDictSerializationMixin` and overwrite the default function `to_state_dict()` and `from_state_dict()`. It takes in a Hugging Face state dict and returns a Levanter state_dict. - -For implementation, there are a few helper classes from torch_serialization that you can use: -- To add specific prefix to the keys of Hugging Face state_dict, you can use the helper function `apply_prefix()`. The prefix comes from the name of attributes defined at the beginning of your model class. -- To unflatten the linear layers of Hugging Face, you can use the helper function `unflatten_linear_params()`. -- To unstack the transformer blocks of Hugging Face, you can use the helper function `unstack_transformer_blocks()`. +If your modules need special logic, you'll need to extend your class from `ModuleWithStateDictSerialization` and +overwrite the default function `update_state_dict()` and `from_state_dict()`. It takes in and returns a modified +[haliax.state_dict.StateDict][]. As of May 2024, we almost never this in Levanter. -For example, below is the implementation of `from_state_dict()` in `LlamaAttention`. `LLamaAttention`, like most attention layers in Levanter, projects the embeddings (with shape `(Pos, Embed)`) to a tensor of shape `(Pos, Head, HeadSize)`, rather than the convention in Hugging Face Transformers that uses `(Pos, Head * HeadSize)` and then reshapes (because PyTorch Linear doesn't support multiple input or output dimensions). +For implementation, there are a few helper methods from `haliax.state_dict` that you can use: +- To join specific prefix to the keys of Hugging Face state_dict, you can use the helper function `apply_prefix()`. The prefix comes from the name of attributes defined at the beginning of your model class. -This difference means that we have to flatten linear layers when converting to a Hugging Face Transformers-compatible state dict and unflatten them when they are read in. +For example, below is the implementation of `update_state_dict()` in [levanter.models.backpack.BackpackLMHeadModel][]. +In this class, we want to preserve HF compatibility by saving untied output embeddings. (We chose not to implement +non-weight-tied embeddings.) ```python -def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "LlamaAttention": - # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention - d = {} - d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, True)) - - return super().from_state_dict(d, prefix) + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + state_dict = super().update_state_dict(state_dict, prefix=prefix) + # In levanter's implementation, we have a shared embedding matrix for both the word + # embeddings and the sense embeddings + state_dict[apply_prefix(prefix, "backpack.word_embeddings.weight")] = state_dict[ + apply_prefix(prefix, "backpack.gpt2_model.wte.weight") + ] + state_dict[apply_prefix(prefix, "backpack.position_embeddings.weight")] = state_dict[ + apply_prefix(prefix, "backpack.gpt2_model.wpe.weight") + ] + return state_dict ``` -Similarly, to save weights to Hugging Face, you will need to write a class function `to_state_dict()` in each of your model class. +Similarly, to load weights from the state dict, you'll need to implement `from_state_dict`. The correctness of your implementation can be validated through serialization tests, which will be discussed in the next section. @@ -302,13 +313,3 @@ Check out [Training on Your Own Data](../Training-On-Your-Data.md) for more deta If you are interested in profiling the training throughput of your model, good news is that it comes for free with automatic job monitoring in Levanter, powered through Weights & Biases. Once you run a training job, on the corresponding job page on Weights & Biases, you will be able to find a section named "Throughput". It reports metrics like `examples_per_second` and `tokens_per_second` across the training time. - -## Tips for Optimization -1. Avoid upcasting to float32. Levanter uses bfloat16 by default, which is more memory efficient and faster for training. You should avoid upcasting to float32 unless it is necessary for stability or accuracy. -2. For attention, rearrange the heads and position axes to make the computation more efficient. For example, in Llama, we did the following: - -```python -q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size")) -k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size")) -v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size")) -``` diff --git a/pyproject.toml b/pyproject.toml index abca1405d..29e54b9a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,11 +21,11 @@ classifiers = [ "Intended Audience :: Science/Research", ] dependencies = [ - "haliax>=1.4.dev307", + "haliax>=1.4.dev324", "equinox>=0.11.7", "jaxtyping>=0.2.34", "tokenizers>=0.15.2", - "transformers>=4.41.2", + "transformers>=4.41.2,<4.46.0", "optax>=0.1.9", "wandb>=0.17.8", "draccus>=0.9.3", diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index 7a116acae..5822c3fba 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -31,8 +31,8 @@ import haliax from haliax import Axis from haliax.partitioning import ResourceMapping +from haliax.state_dict import from_torch_compatible_state_dict, save_state_dict, to_torch_compatible_state_dict -from levanter.compat.torch_serialization import StateDictSerializationMixin, save_state_dict, to_numpy_state_dict from levanter.logging import silence_transformer_nag from levanter.models.asr_model import ASRMixin from levanter.models.lm_model import LmConfig, LmHeadModel @@ -128,7 +128,7 @@ def hf_checkpoint_converter(cls) -> "HFCheckpointConverter": MConfig = TypeVar("MConfig", bound=HFCompatConfig) -class ModelWithHfSerializationMixin(Generic[MConfig], StateDictSerializationMixin): +class ModelWithHfSerializationMixin(Generic[MConfig]): def get_hf_config(self): return self.config.to_hf_config(self.Vocab.size) @@ -545,9 +545,8 @@ def load_pretrained( ignore_prefix = self.ignore_prefix break - def load_from_state_dict(state_dict): - lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0)) - lev_model = lev_model.from_state_dict(state_dict, prefix=ignore_prefix) + def load_from_state_dict(template, state_dict): + lev_model = from_torch_compatible_state_dict(template, state_dict, prefix=ignore_prefix) # However, this might miss some buffers that don't get persisted in the state dict # (e.g. pytorch buffers with persistent=false), so we have to reinitialize them. We then init the model @@ -574,7 +573,10 @@ def load_from_state_dict(state_dict): if just_use_cpu: cpu_device = jax.local_devices(backend="cpu")[0] with local_cpu_mesh(): - lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)(state_dict) + lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0)) + lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)( + lev_model, state_dict + ) del state_dict # gotta move it to the accelerator now (assuming there is one!) @@ -583,7 +585,8 @@ def load_from_state_dict(state_dict): load_from_state_dict = haliax.named_jit( load_from_state_dict, axis_resources=axis_mapping, out_axis_resources=axis_mapping, donate_args=(True,) ) - lev_model = load_from_state_dict(state_dict) + lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0)) + lev_model = load_from_state_dict(lev_model, state_dict) # all_arrays: list[jax.Array] = get_backend().live_arrays() # total_size = sum(a.size * a.itemsize for a in all_arrays) @@ -695,7 +698,7 @@ def _save_pretrained_local( json.dump(dict_config, f) # Model - state_dict = to_numpy_state_dict(model) + state_dict = to_torch_compatible_state_dict(model) shards, index = _shard_hf_checkpoint(state_dict, max_shard_size, SAFE_TENSORS_MODEL) if index is None: save_state_dict(state_dict, os.path.join(path, SAFE_TENSORS_MODEL)) diff --git a/src/levanter/compat/torch_serialization.py b/src/levanter/compat/torch_serialization.py deleted file mode 100644 index 32ba84554..000000000 --- a/src/levanter/compat/torch_serialization.py +++ /dev/null @@ -1,474 +0,0 @@ -import re -from dataclasses import fields -from typing import Any, Dict, List, Optional, TypeVar, cast, overload - -import equinox as eqx -import jax -import numpy as np -import safetensors.numpy -from jax import numpy as jnp -from jax.experimental.multihost_utils import sync_global_devices -from jax.sharding import Mesh, NamedSharding, PartitionSpec -from jaxtyping import PyTree - -import haliax as hax -import haliax.nn as hnn -import haliax.partitioning -from haliax import NamedArray -from haliax._src.util import index_where -from haliax.jax_utils import is_jax_array_like -from haliax.util import ensure_tuple - -from levanter.utils.jax_utils import leaf_key_paths - - -StateDict = Dict[str, Any] -Tensor = Any - - -@overload -def apply_prefix(prefix: Optional[str], leaf: str) -> str: - ... - - -@overload -def apply_prefix(prefix: Optional[str], leaf: None) -> Optional[str]: - ... - - -@overload -def apply_prefix(prefix: None, leaf: Optional[str]) -> Optional[str]: - ... - - -def apply_prefix(prefix: Optional[str], leaf: Optional[str]) -> Optional[str]: - if prefix is None: - return leaf - elif leaf is None: - return prefix - else: - return f"{prefix}.{leaf}" - - -Mod = TypeVar("Mod", bound=eqx.Module) - - -class StateDictSerializationMixin: - """An eqx.Module that can be serialized to a torch-style state dict.""" - - def to_state_dict(self, prefix: Optional[str] = None) -> StateDict: - return jax_tree_to_state_dict(self, prefix) - - def from_state_dict(self: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod: - return default_eqx_module_from_state_dict(self, state_dict, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - return default_update_state_dict_with_eqx_module(state_dict, self, prefix) - - def _state_dict_key_map(self) -> Dict[str, Optional[str]]: - """Returns a dict mapping eqx.Module keys to torch keys that need to be renamed for serialization""" - return {} - - -def jax_tree_from_state_dict(tree: PyTree, state_dict: StateDict, prefix: Optional[str] = None) -> PyTree: - # TODO: assert compatibility of old and new values (type, shape, etc.) - if isinstance(tree, eqx.Module): - if hasattr(tree, "from_state_dict"): - return tree.from_state_dict(state_dict, prefix) - else: - return default_eqx_module_from_state_dict(tree, state_dict, prefix) - elif isinstance(tree, list): - return [ - jax_tree_from_state_dict(item, state_dict, apply_prefix(prefix, str(i))) for i, item in enumerate(tree) - ] - elif isinstance(tree, dict): - return {k: jax_tree_from_state_dict(v, state_dict, prefix=apply_prefix(prefix, k)) for k, v in tree.items()} - elif isinstance(tree, NamedArray): - # TODO: where's the best place to put this logic for NamedArrays - if prefix is None: - raise ValueError("Cannot extract a leaf value from a torch dict without a prefix") - - array = state_dict[prefix] - - if isinstance(array, np.ndarray): - mesh = haliax.partitioning._get_mesh() - if mesh.devices.size > 1: # this happens with the default mesh - pspec = haliax.partitioning.pspec_for_axis(tree.axes) - sharding = jax.sharding.NamedSharding(mesh, pspec) - array = jax.make_array_from_callback(tree.array.shape, sharding, lambda indices: array[indices]) - else: - array = jnp.array(array) - array = haliax.named(array, tree.axes) - else: - array = haliax.named(array, tree.axes) - array = haliax.auto_sharded(array) - - return array - elif is_jax_array_like(tree): - if prefix is None: - raise ValueError("Cannot extract a leaf value from a state dict without a prefix") - # TODO: add "strict" flag so we can return None in cases where it's just missing - return jnp.array(state_dict[prefix]) - else: - if prefix is None: - return tree - return state_dict.get(prefix, tree) - - -def update_state_dict_with_jax_tree(tree: PyTree, state_dict: StateDict, prefix: Optional[str] = None) -> None: - if isinstance(tree, eqx.Module): - if hasattr(tree, "update_state_dict"): - tree.update_state_dict(state_dict, prefix) - else: - default_update_state_dict_with_eqx_module(state_dict, tree, prefix) - elif isinstance(tree, list): - for i, item in enumerate(tree): - update_state_dict_with_jax_tree(item, state_dict, prefix=apply_prefix(prefix, str(i))) - elif isinstance(tree, dict): - for k, v in tree.items(): - update_state_dict_with_jax_tree(v, state_dict, prefix=apply_prefix(prefix, k)) - elif isinstance(tree, NamedArray): - # TODO: where's the best place to put this logic for NamedArrays - assert prefix is not None - state_dict[prefix] = tree.array - elif is_jax_array_like(tree): - if prefix is not None: - if tree is not None: - state_dict[prefix] = tree # type: ignore - else: - raise ValueError("Cannot update torch dict with a leaf value.") - else: - pass - - -def jax_tree_to_state_dict(tree: PyTree, prefix: Optional[str] = None) -> StateDict: - state_dict: StateDict = {} - update_state_dict_with_jax_tree(tree, state_dict, prefix) - return state_dict - - -def default_eqx_module_from_state_dict(mod: Mod, state_dict: StateDict, prefix: Optional[str] = None) -> Mod: - try: - from haliax.nn.scan import BlockSeq - - if isinstance(mod, BlockSeq): - return block_seq_from_state_dict(mod, state_dict, prefix) - except ImportError: - pass - - key_map: Dict[str, Optional[str]] = getattr(mod, "_state_dict_key_map", lambda: {})() # type: ignore - names = [] - values = [] - for field in fields(mod): - if field.metadata.get("static", False): - continue - key = key_map.get(field.name, field.name) - value = getattr(mod, field.name) - # TODO: might want to add a flag that allows missing keys? - new = jax_tree_from_state_dict(value, state_dict, apply_prefix(prefix, key)) - # Do not try to update parameters that are never defined - if value is None and new is None: - continue - names.append(field.name) - values.append(new) - return eqx.tree_at(lambda m: [getattr(m, name) for name in names], mod, values) - - -def default_eqx_module_to_state_dict(mod: eqx.Module, prefix: Optional[str] = None) -> StateDict: - state_dict: StateDict = {} - default_update_state_dict_with_eqx_module(state_dict, mod, prefix) - return state_dict - - -def default_update_state_dict_with_eqx_module( - state_dict: StateDict, mod: eqx.Module, prefix: Optional[str] = None -) -> StateDict: - try: - from haliax.nn.scan import BlockSeq - - if isinstance(mod, BlockSeq): - return update_block_seq_state_dict(state_dict, mod, prefix) - except ImportError: - pass - - key_map: Dict[str, Optional[str]] = getattr(mod, "_state_dict_key_map", lambda: {})() # type: ignore - for field in fields(mod): - if field.metadata.get("static", False): - continue - key = key_map.get(field.name, field.name) - value = getattr(mod, field.name) - update_state_dict_with_jax_tree(value, state_dict, apply_prefix(prefix, key)) - return state_dict - - -def flatten_linear_layers(prefix: Optional[str], tree: PyTree, out_dims_first_in_dict: Optional[bool]) -> StateDict: - """ - In PyTorch, linear layers are stored as a 2d weight matrix and a 1d bias vector. In Haliax, - linear layers can have arbitrary dimensions, grouped into input and output axes. This function - flattens the linear layers in a state dict into a 2d weight matrix and a 1d bias vector. - - **You should use out_dims_first_in_dict=True if you're using this to convert a PyTorch model to Haliax and the - PyTorch model uses Linear. If the PyTorch model uses Conv1d, use False.** None is probably not what you want, - except in very specific cases. - - :param prefix: prefix to apply to the keys in the state dict - :param tree: - :param out_dims_first_in_dict: if True, the output dimensions will be the first axis in the flattened weight matrix. - If False, the input dimensions will be the first axis. If None, the weight's axes will be left as-is. - This is the default in PyTorch, but not in Haliax. - """ - - ret_dict: StateDict = {} - - def _flatten_linear(layer, prefix): - if not isinstance(layer, hnn.Linear): - return layer - - weight = layer.weight - bias = layer.bias - - if weight.array is not None: - weight = weight.flatten_axes(layer.Out, "__OUT__").flatten_axes(layer.In, "__IN__") - if bias is not None: - bias = bias.flatten_axes(layer.Out, "__OUT__") - - if out_dims_first_in_dict is True: - weight = weight.rearrange((..., "__OUT__", "__IN__")) - elif out_dims_first_in_dict is False: - weight = weight.rearrange((..., "__IN__", "__OUT__")) - else: - pass - - ret_dict[apply_prefix(prefix, "weight")] = weight.array - - if bias is not None: - ret_dict[apply_prefix(prefix, "bias")] = bias.array - - return ret_dict - - tree_prefixes = leaf_key_paths(tree, prefix, is_leaf=lambda x: isinstance(x, hnn.Linear), use_state_dict_keys=True) - jax.tree_util.tree_map(_flatten_linear, tree, tree_prefixes, is_leaf=lambda x: isinstance(x, hnn.Linear)) - return ret_dict - - -def unflatten_linear_layers( - prefix, statedict: StateDict, layer: hnn.Linear, out_dims_first_in_dict: Optional[bool] -) -> StateDict: - """ - In PyTorch, linear layers are stored as a 2d weight matrix and a 1d bias vector. In Haliax, - linear layers can have arbitrary dimensions, grouped into input and output axes. This function - unflattens the linear layers in a state dict into a 2d weight matrix and a 1d bias vector. - - **You should use out_dims_first_in_dict=True if you're using this to convert a PyTorch model to Haliax and the - PyTorch model uses Linear. If the PyTorch model uses Conv1d, use False.** None is probably not what you want, - except in very specific cases. - - :param prefix: prefix to apply to the keys in the state dict - :param statedict: the state dict to source the flattened weights from - :param layer: the exemplar layer to use for unflattening - :param out_dims_first_in_dict: if True, the output dimensions will be the first axis in the flattened weight matrix. - If False, the input dimensions will be the first axis. If None, the weight's axes will be inferred from the linear - :return: - """ - ret_dict: StateDict = {} - - def _unflatten_linear(layer, prefix): - nonlocal out_dims_first_in_dict - - if not isinstance(layer, hnn.Linear): - return layer - - weight = statedict[apply_prefix(prefix, "weight")] - bias = statedict.get(apply_prefix(prefix, "bias"), None) - - Out = ensure_tuple(layer.Out) - In = ensure_tuple(layer.In) - InOut = In + Out - extra_dims = tuple(ax for ax in layer.weight.axes if ax not in InOut) - - if out_dims_first_in_dict is None: - out_dims_first_in_dict = layer.out_first - - if out_dims_first_in_dict: - weight = hax.named(weight, hax.concat_axis_specs(extra_dims, ("__OUT__", "__IN__"))) - else: - weight = hax.named(weight, hax.concat_axis_specs(extra_dims, ("__IN__", "__OUT__"))) - - if layer.out_first: - weight = weight.rearrange((..., "__OUT__", "__IN__")) - else: - weight = weight.rearrange((..., "__IN__", "__OUT__")) - - # now unflatten - weight = weight.unflatten_axis("__OUT__", layer.Out).unflatten_axis("__IN__", layer.In) - - if bias is not None: - bias = hax.named(bias, hax.concat_axis_specs(extra_dims, ("__OUT__",))) - bias = bias.unflatten_axis("__OUT__", layer.Out) - - # tree_structure = jax.tree_structure(layer) - # return jax.tree_unflatten(tree_structure, (weight, bias)) - - ret_dict[apply_prefix(prefix, "weight")] = weight.array - if bias is not None: - ret_dict[apply_prefix(prefix, "bias")] = bias.array - - return ret_dict - - tree_prefixes = leaf_key_paths( - layer, prefix, is_leaf=lambda x: isinstance(x, hnn.Linear), use_state_dict_keys=True - ) - jax.tree_util.tree_map(_unflatten_linear, layer, tree_prefixes, is_leaf=lambda x: isinstance(x, hnn.Linear)) - return ret_dict - - -def unstack_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - """ - Unstack all keys matching prefix in a new state dict, returning a state dict that has all keys matching - prefix unstacked, but otherwise the same. - - Unstacked in this case means roughly "compatible with a torch.nn.Sequential", which means that the - keys are of the form ".0.", ".1.", etc. - :param state_dict: - :param prefix: - :return: - """ - new_dict: StateDict = {} - prefix = apply_prefix(prefix, "") - assert prefix is not None - - for k, v in state_dict.items(): - if k.startswith(prefix) and v is not None: - for i, v_i in enumerate(v): - new_dict[f"{prefix}{i}.{k[len(prefix):]}"] = v_i - else: - new_dict[k] = v - - return new_dict - - -def stack_state_dict(state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - """ - Stack all keys matching prefix in a new state dict, returning a state dict that has all keys matching - prefix stacked, but otherwise the same. - - Stacked in this case means roughly "compatible with a torch.nn.Sequential", which means that the - keys are of the form ".0.", ".1.", etc. - :param state_dict: - :param prefix: - :return: - """ - vectorized_dict: StateDict = {} - - tensors_to_vectorize: Dict[str, List[Optional[Any]]] = {} - escaped = re.escape(prefix or "") - pattern = re.compile(rf"{escaped}\.(\d+)\.(.*)") - - for k, v in state_dict.items(): - match = pattern.match(k) - if match: - block_idx = int(match.group(1)) - block_key = match.group(2) - tensors = tensors_to_vectorize.setdefault(block_key, []) - if len(tensors) <= block_idx: - tensors.extend([None] * (block_idx - len(tensors) + 1)) - assert tensors[block_idx] is None, f"Duplicate key {k}" - tensors[block_idx] = v - else: - vectorized_dict[k] = v - - # now we have to vectorize the tensors - for k, tensors in tensors_to_vectorize.items(): - vectorized_dict[cast(str, apply_prefix(prefix, k))] = jnp.stack(tensors, axis=0) - - return vectorized_dict - - -def block_seq_from_state_dict(seq, state_dict: StateDict, prefix: Optional[str] = None): - out_blocks = [] - for i, block in enumerate(seq.blocks): - my_prefix = apply_prefix(prefix, str(i)) - block = block.from_state_dict(state_dict, my_prefix) - out_blocks.append(block) - - return eqx.tree_at(lambda m: m.blocks, seq, out_blocks) - - -def update_block_seq_state_dict(state_dict: StateDict, seq, prefix: Optional[str] = None): - for i, block in enumerate(seq.blocks): - my_prefix = apply_prefix(prefix, str(i)) - block.update_state_dict(state_dict, my_prefix) - - return state_dict - - -def to_numpy_state_dict(model, prefix: Optional[str] = None) -> StateDict: - """ - Convert a model to a state dict by first creating desharded copies of all parameters that reside in CPU - memory. - - This method is especially useful for saving models distributed across multiple hosts. - """ - - with jax.default_device(jax.local_devices(backend="cpu")[0]): - - def get_to_cpu(arr): - if not is_jax_array_like(arr): - return arr - elif isinstance(arr, np.ndarray): - return arr - elif arr.is_fully_addressable: - r = np.array(arr) - return r - else: - # unfortunately, jax's allgather seems to replicate to every device rather than every host - # which doesn't work for ~7B parameter models on TPU (assuming we also have optimizer state) - # this approach limits us to <64B parameters, but that's good enough for now - # we're going to do something a bit fancy, where we shard the model into a (process, device) mesh, - # then look for some axis along which we can shard the array, and then we'll do an allgather - # via pjit. If we can't find one, we'll just fully replicate since it probably isn't that big. - # TODO: ensure that this mesh arranges devices correctly - # (jax seems to do this internally itself, so we should be fine?) - process_mesh = Mesh(np.array(jax.devices()).reshape((jax.process_count(), -1)), ("process", "device")) - # now we need to find an axis along which we can shard the array. - # for this, we need to find an axis s.t. size(axis) % local_devices == 0 - - try: - axis_to_shard = index_where( - lambda axis_size: axis_size % process_mesh.devices.size == 0, arr.shape - ) - except ValueError: - return np.array(arr) - - shardings = [None if i != axis_to_shard else "device" for i in range(len(arr.shape))] - sharding = NamedSharding(process_mesh, PartitionSpec(*shardings)) - out = jax.device_put(arr, sharding) - return np.array(out) - - # need to make sure the model is on *this machine* and *this machine's CPU* before saving - model = jax.tree_util.tree_map(lambda arr: get_to_cpu(arr), model) - # TODO: it would be nice if safetensors supported an iterator or something so we could do the allgather one at a time - state_dict = model.to_state_dict(prefix=prefix) - return state_dict - - -_GLOBAL_SAVE_COUNT = 0 - - -def save_state_dict(state_dict: StateDict, path): - """ - Save a model's state dict to a file, bringing all tensors to the CPU first and then converting to numpy. - This will save using safetensors format - """ - state_dict = {k: v for k, v in state_dict.items() if v is not None} - # now that we've moved the model to the CPU, we don't need to do this on all processes - if jax.process_index() == 0: - # the "pt" is a lie but it doesn't seem to actually matter and HF demands it - safetensors.numpy.save_file(state_dict, path, metadata={"format": "pt"}) - global _GLOBAL_SAVE_COUNT - sync_global_devices(f"local {_GLOBAL_SAVE_COUNT}") - _GLOBAL_SAVE_COUNT += 1 - - -def _identity_fn(x): - return x diff --git a/src/levanter/lora.py b/src/levanter/lora.py index 83558f75d..1e0f37d67 100644 --- a/src/levanter/lora.py +++ b/src/levanter/lora.py @@ -56,14 +56,14 @@ import haliax.nn as hnn from haliax import Axis from haliax.jax_utils import shaped_rng_split - -from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef, upload_to_hub -from levanter.compat.torch_serialization import ( +from haliax.state_dict import ( + ModuleWithStateDictSerialization, StateDict, - StateDictSerializationMixin, save_state_dict, - to_numpy_state_dict, + to_torch_compatible_state_dict, ) + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef, upload_to_hub from levanter.logging import silence_transformer_nag from levanter.trainer import StepInfo from levanter.utils.cloud_utils import temp_dir_before_upload @@ -153,7 +153,7 @@ def merge(self) -> hax.NamedArray: return hax.dot(self.lora_A.weight, self.lora_B.weight, axis=LORA_R) * self.scale -class LoraLinear(eqx.Module, StateDictSerializationMixin): +class LoraLinear(ModuleWithStateDictSerialization): """ Linear layer with LoRA transform. """ @@ -518,5 +518,5 @@ def lora_state_dict(model: M, prefix: Optional[str] = DEFAULT_DICT_PREFIX) -> St Returns a state dict of the LoRA parameters of the given model without other parameters. This method attempts to return a state dict compatible with PEFT's import method. """ - state_dict = to_numpy_state_dict(filter_lora_params(model), prefix=prefix) + state_dict = to_torch_compatible_state_dict(filter_lora_params(model), prefix=prefix) return {k: v for k, v in state_dict.items() if v is not None} diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py index 4de8accc7..715706f8e 100644 --- a/src/levanter/models/backpack.py +++ b/src/levanter/models/backpack.py @@ -12,15 +12,9 @@ import haliax.nn as hnn from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import named_call +from haliax.state_dict import ModuleWithStateDictSerialization, StateDict, with_prefix from levanter.compat.hf_checkpoints import HFCheckpointConverter, LmWithHfSerializationMixin -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - flatten_linear_layers, - unflatten_linear_layers, -) from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask, materialize_mask from levanter.models.gpt2 import ACT2FN, Gpt2Config, Gpt2Transformer @@ -100,7 +94,7 @@ def from_hf_config(cls, hf_config: PretrainedConfig): ) -class BackpackMlp(eqx.Module, StateDictSerializationMixin): +class BackpackMlp(eqx.Module): c_fc: hnn.Linear # projection from Embed to Intermediate (typically 4x Embed) c_proj: hnn.Linear # projection from Intermediate to Embed act: Callable = eqx.static_field() @@ -131,32 +125,8 @@ def __call__(self, x: NamedArray): x = self.c_proj(x) return x - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "BackpackMlp": - d = {} - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "c_proj"), state_dict, self.c_proj, out_dims_first_in_dict=False - ) - ) - d.update( - unflatten_linear_layers(apply_prefix(prefix, "c_fc"), state_dict, self.c_fc, out_dims_first_in_dict=False) - ) - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix) - - my_dict.update( - flatten_linear_layers(apply_prefix(prefix, "c_proj"), self.c_proj, out_dims_first_in_dict=False) - ) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "c_fc"), self.c_fc, out_dims_first_in_dict=False)) - - state_dict.update(my_dict) - return state_dict - -class WeightsOnlyAttention(StateDictSerializationMixin, eqx.Module): +class WeightsOnlyAttention(ModuleWithStateDictSerialization): """ Changes from Gpt2Attention: 1. No projection; it returns the attention weights @@ -208,25 +178,8 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la attn_weights = self.dropout(attn_weights, key=key) return attn_weights - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "WeightsOnlyAttention": - d = unflatten_linear_layers( - apply_prefix(prefix, "c_attn"), state_dict, self.c_attn, out_dims_first_in_dict=True - ) - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # need to undo the reshape we did in from_state_dict - # reminder that everything is vectorized - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix) - - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "c_attn"), self.c_attn, out_dims_first_in_dict=True)) - - state_dict.update(my_dict) - return state_dict - -class NoMixBlock(StateDictSerializationMixin, eqx.Module): +class NoMixBlock(eqx.Module): ln_1: hnn.LayerNorm ln_2: hnn.LayerNorm mlp: BackpackMlp @@ -266,7 +219,7 @@ def __call__(self, hidden_states: NamedArray, residual: NamedArray, *, key): return hidden_states -class BackpackSenses(StateDictSerializationMixin, eqx.Module): +class BackpackSenses(eqx.Module): dropout: hnn.Dropout block: NoMixBlock ln: hnn.LayerNorm @@ -355,7 +308,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) -class BackpackLMHeadModel(eqx.Module, LmWithHfSerializationMixin): +class BackpackLMHeadModel(LmWithHfSerializationMixin, ModuleWithStateDictSerialization): transformer: Gpt2Transformer embeddings: BackpackGpt2Embeddings sense_net: BackpackSenses @@ -449,10 +402,10 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) state_dict = super().update_state_dict(state_dict, prefix=prefix) # In levanter's implementation, we have a shared embedding matrix for both the word # embeddings and the sense embeddings - state_dict[apply_prefix(prefix, "backpack.word_embeddings.weight")] = state_dict[ - apply_prefix(prefix, "backpack.gpt2_model.wte.weight") + state_dict[with_prefix(prefix, "backpack.word_embeddings.weight")] = state_dict[ + with_prefix(prefix, "backpack.gpt2_model.wte.weight") ] - state_dict[apply_prefix(prefix, "backpack.position_embeddings.weight")] = state_dict[ - apply_prefix(prefix, "backpack.gpt2_model.wpe.weight") + state_dict[with_prefix(prefix, "backpack.position_embeddings.weight")] = state_dict[ + with_prefix(prefix, "backpack.gpt2_model.wpe.weight") ] return state_dict diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py index c38acf5ef..23e2bf6dc 100644 --- a/src/levanter/models/gemma.py +++ b/src/levanter/models/gemma.py @@ -11,15 +11,9 @@ from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split from haliax.nn.scan import Stacked +from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - stack_state_dict, - unstack_state_dict, -) from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import ( # Gemma attention and MLP is identical to LLama @@ -130,8 +124,8 @@ def hf_checkpoint_converter(self) -> HFCheckpointConverter["GemmaConfig"]: # ty # See https://github.com/huggingface/transformers/pull/29402 for more detail. @classmethod def from_hf_config(cls, hf_config: HfConfig): - if hf_config.hidden_activation: - activation_function = hf_config.hidden_activation + if hf_config.hidden_activation is None: + activation_function = "gelu_pytorch_tanh" else: activation_function = "gelu_pytorch_tanh" @@ -231,7 +225,7 @@ def __call__(self, x: NamedArray) -> NamedArray: return out.astype(dtype) -class GemmaDecoderLayer(StateDictSerializationMixin, eqx.Module): +class GemmaDecoderLayer(ModuleWithStateDictSerialization): config: GemmaConfig = eqx.static_field() self_attn: LlamaAttention mlp: LlamaMlp @@ -272,7 +266,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, return output -class GemmaTransformer(StateDictSerializationMixin, eqx.Module): +class GemmaTransformer(ModuleWithStateDictSerialization): config: GemmaConfig = eqx.static_field() layers: BlockFoldable[GemmaDecoderLayer] norm: GemmaRMSNorm @@ -301,25 +295,8 @@ def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask return x - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - if isinstance(self.layers, Stacked): - state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers")) - - out = super().from_state_dict(state_dict, prefix=prefix) - return out - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_state_dict: StateDict = {} - super().update_state_dict(my_state_dict, prefix=prefix) - - if isinstance(self.layers, Stacked): - stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers")) - state_dict.update(stacked_dict) - return state_dict - - -class GemmaLMHeadModel(eqx.Module, LmHeadModel[GemmaConfig], StateDictSerializationMixin): +class GemmaLMHeadModel(LmHeadModel[GemmaConfig], ModuleWithStateDictSerialization): transformer: GemmaTransformer # Gemma ties the weights of the embedding matrix and LM head. Rather than @@ -376,12 +353,3 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]": def _state_dict_key_map(self) -> Dict[str, Optional[str]]: """Map from Levanter model names to HF.""" return {"transformer": "model", "embeddings": None} - - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - return super().from_state_dict(state_dict, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix=prefix) - state_dict.update(my_dict) - return state_dict diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index 28e878193..1d2fe5892 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -14,17 +14,9 @@ from haliax import Axis, NamedArray from haliax.jax_utils import named_call, shaped_rng_split from haliax.nn.scan import Stacked +from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - flatten_linear_layers, - stack_state_dict, - unflatten_linear_layers, - unstack_state_dict, -) from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.lm_model import LmConfig @@ -160,7 +152,7 @@ def __call__(self, x: NamedArray, *, key=None): return x -class Gpt2Attention(StateDictSerializationMixin, eqx.Module): +class Gpt2Attention(eqx.Module): config: Gpt2Config = eqx.static_field() c_attn: hnn.Linear # input projection from [embed] -> [(q, k, v), heads, head_dim] @@ -218,31 +210,8 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la return attn_output - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "Gpt2Attention": - # our c_attn is [embed] -> [3, heads, head_dim] and hf's is the flattened [embed] -> [3 * heads * head_dim] - # and our c_proj is [heads, head_dim] -> [embed] and hf's is the flattened [heads * head_dim] -> [embed] - # so we need to reshape the one in the dict before forwarding to the linear - # keep in mind that everything is vectorized in our implementation, so there's a leading num_layers dim - d = {} - d.update(unflatten_linear_layers(apply_prefix(prefix, "c_attn"), state_dict, self.c_attn, None)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "c_proj"), state_dict, self.c_proj, None)) - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # need to undo the reshape we did in from_state_dict - # reminder that everything is vectorized - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix) - - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "c_attn"), self.c_attn, None)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "c_proj"), self.c_proj, None)) - - state_dict.update(my_dict) - return state_dict - - -class Gpt2Block(StateDictSerializationMixin, eqx.Module): +class Gpt2Block(eqx.Module): ln_1: hnn.LayerNorm attn: Gpt2Attention ln_2: hnn.LayerNorm @@ -276,7 +245,7 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la return x -class Gpt2Transformer(StateDictSerializationMixin, eqx.Module): +class Gpt2Transformer(ModuleWithStateDictSerialization): config: Gpt2Config = eqx.static_field() blocks: Stacked[Gpt2Block] ln_f: hnn.LayerNorm @@ -303,28 +272,8 @@ def __call__(self, x: NamedArray, attn_mask: Optional[AttentionMask | NamedArray def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"blocks": "h"} - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # We use a vectorized set of blocks, meaning that we have 1 GptBlock, - # whereas in hf we have numlayers GptBlocks. So we need to build one GptBlock from numlayers GptBlocks. - # the individual blocks are named h.0.FOO, h.1.FOO, etc. - # we want to vectorize them to h.FOO, h.FOO, etc. - stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "h")) - out = super().from_state_dict(stacked, prefix=prefix) - return out - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # this method needs to "devectorize" the blocks, so that we have a list of blocks h.0.FOO, h.1.FOO, etc. - # first just do the normal thing with our own dict, which we'll post-process - my_state_dict: StateDict = {} - super().update_state_dict(my_state_dict, prefix) - - stacked_dict = unstack_state_dict(my_state_dict, apply_prefix(prefix, "h")) - state_dict.update(stacked_dict) - - return state_dict - -class Gpt2Embeddings(StateDictSerializationMixin, eqx.Module): +class Gpt2Embeddings(ModuleWithStateDictSerialization, eqx.Module): Vocab: Axis = eqx.static_field() config: Gpt2Config = eqx.static_field() @@ -367,7 +316,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_token_embeddings) -class Gpt2LMHeadModel(eqx.Module, LmWithHfSerializationMixin[Gpt2Config]): +class Gpt2LMHeadModel(LmWithHfSerializationMixin[Gpt2Config]): transformer: Gpt2Transformer embeddings: Gpt2Embeddings diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 85861da6a..6b04ec540 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -12,17 +12,9 @@ from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split from haliax.nn.scan import Stacked +from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - flatten_linear_layers, - stack_state_dict, - unflatten_linear_layers, - unstack_state_dict, -) from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention from levanter.models.gpt2 import ACT2FN @@ -180,7 +172,7 @@ def flops_per_token(self, vocab_size: int): ) -class LlamaMlp(eqx.Module, StateDictSerializationMixin): +class LlamaMlp(eqx.Module): """Multi-layer Perceptron In comparison with GPT2, LlamaMlp adds an up-proj that multiplies with activated gate_proj, before down-proj. @@ -213,46 +205,8 @@ def __call__(self, x: NamedArray, *, key=None) -> NamedArray: outputs = self.down_proj(hidden_states, key=k_down) return outputs - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp - d = {} - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "gate_proj"), state_dict, self.gate_proj, out_dims_first_in_dict=True - ) - ) - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "up_proj"), state_dict, self.up_proj, out_dims_first_in_dict=True - ) - ) - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "down_proj"), state_dict, self.down_proj, out_dims_first_in_dict=True - ) - ) - - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix=prefix) - - my_dict.update( - flatten_linear_layers(apply_prefix(prefix, "gate_proj"), self.gate_proj, out_dims_first_in_dict=True) - ) - my_dict.update( - flatten_linear_layers(apply_prefix(prefix, "up_proj"), self.up_proj, out_dims_first_in_dict=True) - ) - my_dict.update( - flatten_linear_layers(apply_prefix(prefix, "down_proj"), self.down_proj, out_dims_first_in_dict=True) - ) - - state_dict.update(my_dict) - return state_dict - -class LlamaAttention(StateDictSerializationMixin, eqx.Module): +class LlamaAttention(eqx.Module): config: LlamaConfig = eqx.static_field() q_proj: hnn.Linear # projection from Embed to query k_proj: hnn.Linear # projection from Embed to key @@ -316,29 +270,6 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, attn_output = self.o_proj(attn_output, key=key_o) return attn_output - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention - d = {} - d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, True)) - - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # flatten the linear layers of LlamaAttention to match the shape of HF state_dict - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix) - - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "o_proj"), self.o_proj, True)) - - state_dict.update(my_dict) - return state_dict - class LlamaRMSNorm(eqx.Module): """ @@ -384,7 +315,7 @@ def __call__(self, x: NamedArray) -> NamedArray: return out.astype(in_dtype) -class LlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): +class LlamaDecoderLayer(eqx.Module): config: LlamaConfig = eqx.static_field() self_attn: LlamaAttention mlp: LlamaMlp @@ -425,7 +356,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, return output -class LlamaTransformer(StateDictSerializationMixin, eqx.Module): +class LlamaTransformer(eqx.Module): config: LlamaConfig = eqx.static_field() layers: BlockFoldable[LlamaDecoderLayer] norm: LlamaRMSNorm @@ -454,27 +385,8 @@ def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask return x - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - if isinstance(self.layers, Stacked): - state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers")) - - out = super().from_state_dict(state_dict, prefix=prefix) - return out - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_state_dict: StateDict = {} - super().update_state_dict(my_state_dict, prefix=prefix) - - if isinstance(self.layers, Stacked): - stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers")) - state_dict.update(stacked_dict) - else: - state_dict.update(my_state_dict) - - return state_dict - -class LlamaEmbedding(StateDictSerializationMixin, eqx.Module): +class LlamaEmbedding(ModuleWithStateDictSerialization, eqx.Module): """Similar to GPT2 Embedding, except that: - Llama doesn't have position embedding in the Embedding layer. - Llama doesn't use dropout. @@ -504,7 +416,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) -class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin): +class LlamaLMHeadModel(ModuleWithStateDictSerialization, LmHeadModel[LlamaConfig]): transformer: LlamaTransformer embeddings: LlamaEmbedding lm_head: Optional[hnn.Linear] @@ -595,26 +507,3 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]": def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"transformer": "model", "embeddings": None} - - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp - d = state_dict.copy() - if self.lm_head is not None: - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True - ) - ) - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix=prefix) - - if self.lm_head is not None: - my_dict.update( - flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True) - ) - - state_dict.update(my_dict) - return state_dict diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 1a82aa7be..63cb2d4e3 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -85,7 +85,7 @@ def build(self, Vocab: Axis, *, key: PRNGKey) -> "LmT": return self.model_type.init(Vocab, self, key=key) # type: ignore -class LmHeadModel(Generic[LmConfigT], abc.ABC): +class LmHeadModel(eqx.Module, Generic[LmConfigT]): """ Superclass for models with a language modeling head. """ diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py index 764e18aea..b9f19ef41 100644 --- a/src/levanter/models/mistral.py +++ b/src/levanter/models/mistral.py @@ -2,22 +2,15 @@ from dataclasses import dataclass from typing import Dict, Optional, Type, Union -import equinox as eqx import jax.random as jrandom import haliax as hax import haliax.nn as hnn from haliax import Axis, NamedArray from haliax.jax_utils import maybe_rng_split +from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - flatten_linear_layers, - unflatten_linear_layers, -) from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionBackend, AttentionMask from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaTransformer @@ -150,7 +143,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]: ) -class MistralLMHeadModel(eqx.Module, LmHeadModel[MistralConfig], StateDictSerializationMixin): +class MistralLMHeadModel(ModuleWithStateDictSerialization, LmHeadModel[MistralConfig]): transformer: LlamaTransformer embeddings: LlamaEmbedding lm_head: hnn.Linear @@ -210,24 +203,3 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]": def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"transformer": "model", "embeddings": None} - - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # unflatten the linear layers of HF state_dict to match the shape of MistralMlp - d = state_dict.copy() - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True - ) - ) - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix=prefix) - - my_dict.update( - flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True) - ) - - state_dict.update(my_dict) - return state_dict diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index 97b61f1dc..e77e967d7 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -15,18 +15,10 @@ from haliax import Axis, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split from haliax.nn.scan import Stacked +from haliax.state_dict import ModuleWithStateDictSerialization import levanter.models.attention from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, LmWithHfSerializationMixin -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - flatten_linear_layers, - stack_state_dict, - unflatten_linear_layers, - unstack_state_dict, -) from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmConfig @@ -212,7 +204,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]: ) -class MptMlp(eqx.Module, StateDictSerializationMixin): +class MptMlp(eqx.Module): up_proj: hnn.Linear # projection from Embed to Intermediate (typically 4x Embed) down_proj: hnn.Linear # projection from Intermediate to Embed @@ -233,7 +225,7 @@ def __call__(self, hidden_states: NamedArray, *, key): # Attention is the same as GPT-2 Attention, modulo alibi -class MptAttention(StateDictSerializationMixin, eqx.Module): +class MptAttention(eqx.Module): Wqkv: hnn.Linear # input projection from [embed] -> [(q, k, v), heads, head_dim] out_proj: hnn.Linear # output projection from [heads, head_dim] -> [embed] @@ -298,35 +290,6 @@ def __call__( return attn_output - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # our c_attn is [embed] -> [3, heads, head_dim] and hf's is the flattened [embed] -> [3 * heads * head_dim] - # and our c_proj is [heads, head_dim] -> [embed] and hf's is the flattened [heads * head_dim] -> [embed] - # so we need to reshape the one in the dict before forwarding to the linear - # keep in mind that everything is vectorized in our implementation, so there's a leading num_layers dim - - d = unflatten_linear_layers(apply_prefix(prefix, "Wqkv"), state_dict, self.Wqkv, out_dims_first_in_dict=True) - d.update( - unflatten_linear_layers( - apply_prefix(prefix, "out_proj"), state_dict, self.out_proj, out_dims_first_in_dict=True - ) - ) - - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # need to undo the reshape we did in from_state_dict - # reminder that everything is vectorized - state_dict.update(flatten_linear_layers(apply_prefix(prefix, "Wqkv"), self.Wqkv, out_dims_first_in_dict=True)) - state_dict.update( - flatten_linear_layers(apply_prefix(prefix, "out_proj"), self.out_proj, out_dims_first_in_dict=True) - ) - return state_dict - - -# Block is broadly similar to GPT-2 Block, with the following changes: -# * fancy layer norm type (we ignore this) -# pdrop seems to be off so we won't use it - class MptBlock(eqx.Module): norm_1: eqx.Module @@ -362,7 +325,7 @@ def __call__( return hidden_states -class MptTransformer(StateDictSerializationMixin, eqx.Module): +class MptTransformer(eqx.Module): config: MptConfig = eqx.static_field() blocks: Stacked[MptBlock] norm_f: hnn.LayerNorm @@ -396,28 +359,8 @@ def __call__( return hidden_states - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # We use a vectorized set of blocks, meaning that we have 1 GptBlock, - # whereas in hf we have numlayers GptBlocks. So we need to build one GptBlock from numlayers GptBlocks. - # the individual blocks are named h.0.FOO, h.1.FOO, etc. - # we want to vectorize them to h.FOO, h.FOO, etc. - stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "blocks")) - out = super().from_state_dict(stacked, prefix=prefix) - return out - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # this method needs to "devectorize" the blocks, so that we have a list of blocks h.0.FOO, h.1.FOO, etc. - # first just do the normal thing with our own dict, which we'll post-process - my_state_dict: StateDict = {} - super().update_state_dict(my_state_dict, prefix) - - stacked_dict = unstack_state_dict(my_state_dict, apply_prefix(prefix, "blocks")) - state_dict.update(stacked_dict) - - return state_dict - -class MptLmHeadModel(eqx.Module, LmWithHfSerializationMixin): +class MptLmHeadModel(LmWithHfSerializationMixin, ModuleWithStateDictSerialization): wte: hnn.Embedding transformer: MptTransformer _config: MptConfig = eqx.static_field() diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py index ad1db0ab6..7239626f7 100644 --- a/src/levanter/models/whisper.py +++ b/src/levanter/models/whisper.py @@ -14,17 +14,9 @@ from haliax import Axis, NamedArray from haliax.jax_utils import named_call, shaped_rng_split from haliax.nn.scan import Stacked +from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, ModelWithHfSerializationMixin -from levanter.compat.torch_serialization import ( - StateDict, - StateDictSerializationMixin, - apply_prefix, - flatten_linear_layers, - stack_state_dict, - unflatten_linear_layers, - unstack_state_dict, -) from levanter.logging import silence_transformer_nag from levanter.models.asr_model import ASRConfig, ASRMixin from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention @@ -128,7 +120,7 @@ def from_hf_config(cls, hf_config: HfConfig): ) -class WhisperMlp(eqx.Module, StateDictSerializationMixin): +class WhisperMlp(eqx.Module): fc1: hnn.Linear # projection from Embed to Intermediate (typically 4x Embed) fc2: hnn.Linear # projection from Intermediate to Embed act: Callable = eqx.static_field() @@ -152,29 +144,8 @@ def __call__(self, x: NamedArray, *, key=None): x = self.fc2(x, key=k2) return x - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - d = {} - d.update( - unflatten_linear_layers(apply_prefix(prefix, "fc1"), state_dict, self.fc1, out_dims_first_in_dict=True) - ) - d.update( - unflatten_linear_layers(apply_prefix(prefix, "fc2"), state_dict, self.fc2, out_dims_first_in_dict=True) - ) - - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix=prefix) - - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "fc1"), self.fc1, out_dims_first_in_dict=True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "fc2"), self.fc2, out_dims_first_in_dict=True)) - state_dict.update(my_dict) - return state_dict - - -class WhisperAttention(StateDictSerializationMixin, eqx.Module): +class WhisperAttention(eqx.Module): config: WhisperConfig = eqx.static_field() q_proj: hnn.Linear # input projection from [embed] -> [q, heads, head_dim] @@ -229,34 +200,8 @@ def __call__(self, x: NamedArray, xa: Optional[NamedArray] = None, mask: Optiona return attn_output - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention - d = {} - d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True)) - d.update(unflatten_linear_layers(apply_prefix(prefix, "out_proj"), state_dict, self.out_proj, True)) - - return super().from_state_dict(d, prefix) - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - # flatten the linear layers of LlamaAttention to match the shape of HF state_dict - my_dict: StateDict = {} - super().update_state_dict(my_dict, prefix) - - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "q_proj"), self.q_proj, True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "k_proj"), self.k_proj, True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "v_proj"), self.v_proj, True)) - my_dict.update(flatten_linear_layers(apply_prefix(prefix, "out_proj"), self.out_proj, True)) - state_dict.update(my_dict) - return state_dict - - -class WhisperLayer( - eqx.Module, - StateDictSerializationMixin, -): +class WhisperLayer(ModuleWithStateDictSerialization, eqx.Module): self_attn: WhisperAttention attn_ln: hnn.LayerNorm @@ -314,7 +259,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]: } -class WhisperTransformer(eqx.Module, StateDictSerializationMixin): +class WhisperTransformer(ModuleWithStateDictSerialization): layers: Stacked[WhisperLayer] Layer: Axis layer_norm: hnn.LayerNorm @@ -349,22 +294,8 @@ def __call__( return x - def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): - stacked = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers")) - out = super().from_state_dict(stacked, prefix=prefix) - return out - - def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: - my_state_dict: StateDict = {} - super().update_state_dict(my_state_dict, prefix=prefix) - - stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers")) - state_dict.update(stacked_dict) - - return state_dict - -class WhisperEncoder(eqx.Module, StateDictSerializationMixin): +class WhisperEncoder(ModuleWithStateDictSerialization): config: WhisperConfig = eqx.static_field() conv1: hnn.Conv conv2: hnn.Conv @@ -381,8 +312,10 @@ def init(cls, config: WhisperConfig, *, key) -> "WhisperEncoder": conv1 = hnn.Conv.init(Len, config.Mels, Mid, kernel_size=3, padding=1, key=k_conv1) conv2 = hnn.Conv.init(Len, Mid, config.Embed, kernel_size=3, stride=2, padding=1, key=k_conv2) if isinstance(config.activation_function, str): - activation_fn = ACT2FN[config.activation_function] - act = activation_fn # type: ignore + act = ACT2FN[config.activation_function] # type: ignore + else: + act = config.activation_function + transformer = WhisperTransformer.init( config.EncoderLayer, config.EncoderHeads, @@ -456,7 +389,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]: return {"token_embeddings": "embed_tokens", "position_embeddings": "embed_positions"} -class WhisperDecoder(eqx.Module, StateDictSerializationMixin): +class WhisperDecoder(ModuleWithStateDictSerialization): transformer: WhisperTransformer embeddings: WhisperDecoderEmbeddings diff --git a/tests/test_backpack.py b/tests/test_backpack.py index 28865dba8..24717b864 100644 --- a/tests/test_backpack.py +++ b/tests/test_backpack.py @@ -93,7 +93,7 @@ def test_backpack_nano_compare(): Vocab = haliax.Axis("vocab", vocab_size) lev_model = BackpackLMHeadModel.init(Vocab, lev_config, key=PRNGKey(0)) - lev_model = lev_model.from_state_dict(loaded_checkpoint) + lev_model = haliax.state_dict.from_torch_compatible_state_dict(lev_model, loaded_checkpoint) lev_model = inference_mode(lev_model, True) hax_input = haliax.named(input, lev_config.Pos) diff --git a/tests/test_gemma.py b/tests/test_gemma.py index 64a3149fe..cf1f91258 100644 --- a/tests/test_gemma.py +++ b/tests/test_gemma.py @@ -44,7 +44,6 @@ def test_gemma_config(): # See https://github.com/huggingface/transformers/pull/29402 for more info. assert gemma_config.activation_function == "gelu_new" # gelu_new is a closer match to gelu_pytorch_tanh assert new_hf_config.hidden_activation == "gelu_pytorch_tanh" - assert new_hf_config.hidden_act == "gelu_pytorch_tanh" # assert the content in new_hf_config is the same as hf_config for k in new_hf_config.__dict__.keys(): @@ -95,7 +94,7 @@ def test_gemma_decoder_layer(num_kv_heads): key = random.PRNGKey(0) gemma_decoder_layer = GemmaDecoderLayer.init(config=gemma_config, key=key) - state = gemma_decoder_layer.to_state_dict() + state = hax.state_dict.to_torch_compatible_state_dict(gemma_decoder_layer) state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} hf_decoder_layer = HFGemmaDecoderLayer(gemma_config.to_hf_config(32000), layer_idx=0) hf_decoder_layer.load_state_dict(state, strict=True) @@ -264,7 +263,7 @@ def test_gemma_attention(use_flash, num_kv_heads): attention = LlamaAttention.init(config=config, key=random.PRNGKey(0)) # type: ignore - state = attention.to_state_dict() + state = hax.state_dict.to_torch_compatible_state_dict(attention) state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} hf_attention = HFGemmaAttention(config.to_hf_config(32000)) hf_attention.load_state_dict(state, strict=True) @@ -293,7 +292,7 @@ def test_gemma_mlp(): config = _get_gemma_config() mlp = LlamaMlp.init(config.Embed, config.Mlp, config.activation_function, key=random.PRNGKey(0)) - state = mlp.to_state_dict() + state = hax.state_dict.to_torch_compatible_state_dict(mlp) state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} hf_mlp = HFGemmaMLP(config.to_hf_config(32000)) hf_mlp.load_state_dict(state, strict=True) diff --git a/tests/test_hf_checkpoints.py b/tests/test_hf_checkpoints.py index 976b6bac4..088daae63 100644 --- a/tests/test_hf_checkpoints.py +++ b/tests/test_hf_checkpoints.py @@ -48,7 +48,7 @@ def test_save_backpack_model_with_code(): Vocab = converter.Vocab lev_model = BackpackLMHeadModel.init(Vocab, lev_config, key=PRNGKey(0)) - lev_model = lev_model.from_state_dict(loaded_checkpoint) + lev_model = haliax.state_dict.from_torch_compatible_state_dict(lev_model, loaded_checkpoint) lev_model = inference_mode(lev_model, True) with tempfile.TemporaryDirectory() as tmpdir: diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index a0002b1c1..24d87ce0b 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -146,7 +146,7 @@ def compute_loss(model: LmHeadModel, input_ids): state_dict = torch_model.transformer.state_dict(keep_vars=True) state_dict = {k: v.grad for k, v in state_dict.items()} - jax_grad_dict = jax_grad.to_state_dict() + jax_grad_dict = hax.state_dict.to_torch_compatible_state_dict(jax_grad) for jax_key, jax_g in jax_grad_dict.items(): if jax_key not in state_dict: @@ -176,7 +176,7 @@ def compute_loss(model: LmHeadModel, input_ids): updates, state = jax_optimizer.update(updates=jax_grad, state=state, params=model) new_model = equinox.apply_updates(model, updates) - new_model_dict = new_model.to_state_dict() + new_model_dict = hax.state_dict.to_torch_compatible_state_dict(new_model) state_dict = torch_model.transformer.state_dict(keep_vars=True) # now compare new params @@ -205,8 +205,8 @@ def test_hf_save_to_fs_spec(): loaded_model = converter.load_pretrained(Gpt2LMHeadModel, ref=f"{tmpdir}/test") - simple_dict = simple_model.to_state_dict() - loaded_dict = loaded_model.to_state_dict() + simple_dict = hax.state_dict.to_torch_compatible_state_dict(simple_model) + loaded_dict = hax.state_dict.to_torch_compatible_state_dict(loaded_model) assert simple_dict.keys() == loaded_dict.keys() diff --git a/tests/test_llama.py b/tests/test_llama.py index 2d2b6506f..87576205d 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -146,7 +146,7 @@ def test_llama_attention(use_flash, num_kv_heads): attention = LlamaAttention.init(config=config, key=random.PRNGKey(0)) - state = attention.to_state_dict() + state = hax.state_dict.to_torch_compatible_state_dict(attention) state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} hf_attention = HFLlamaAttention(config.to_hf_config(32000)) hf_attention.load_state_dict(state, strict=True) @@ -206,7 +206,7 @@ def test_llama_decoder_layer(num_kv_heads): key = random.PRNGKey(0) llama_decoder_layer = LlamaDecoderLayer.init(config=llama_config, key=key) - state = llama_decoder_layer.to_state_dict() + state = hax.state_dict.to_torch_compatible_state_dict(llama_decoder_layer) state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()} hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000), layer_idx=0) hf_decoder_layer.load_state_dict(state, strict=True) @@ -387,4 +387,5 @@ def test_state_dict_consistency(scan_layers, num_kv_heads): model = LlamaLMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(0)) hf_config = config.to_hf_config(Vocab.size) hf_model = LlamaForCausalLM(hf_config) - assert set(hf_model.state_dict().keys()) == set(model.to_state_dict().keys()) + levanter_state_dict = hax.state_dict.to_torch_compatible_state_dict(model) + assert set(hf_model.state_dict().keys()) == set(levanter_state_dict.keys()) diff --git a/tests/test_torch_serialization.py b/tests/test_torch_serialization.py index dab1c5f26..e69de29bb 100644 --- a/tests/test_torch_serialization.py +++ b/tests/test_torch_serialization.py @@ -1,37 +0,0 @@ -import jax -import pytest - -import haliax as hax - -from levanter.compat.torch_serialization import ( - flatten_linear_layers, - jax_tree_from_state_dict, - unflatten_linear_layers, -) - - -@pytest.mark.parametrize("out_dims_first", [True, False]) -def test_unflatten_linear_layers(out_dims_first: bool): - H = hax.Axis("H", 10) - W = hax.Axis("W", 20) - D = hax.Axis("D", 30) - B = hax.Axis("B", 40) - linear = hax.nn.Linear.init((H, W), (D, B), key=jax.random.PRNGKey(0), use_bias=True, out_first=False) - - assert linear.weight.axes == (H, W, D, B) - - # first flatten the weight matrix - flat = flatten_linear_layers(None, linear, out_dims_first_in_dict=out_dims_first) - if out_dims_first: - assert flat["weight"].shape == (D.size * B.size, H.size * W.size) - else: - assert flat["weight"].shape == (H.size * W.size, D.size * B.size) - assert flat["bias"].shape == (D.size * B.size,) - assert flat["weight"].dtype == flat["bias"].dtype == linear.weight.dtype - - # now unflatten it - unflat_dict = unflatten_linear_layers(None, flat, linear, out_dims_first_in_dict=out_dims_first) - new_linear = jax_tree_from_state_dict(linear, unflat_dict) - - assert new_linear.weight.axes == (H, W, D, B) - assert new_linear.bias.axes == (D, B)