Use haliax state dict (#805)
dlwh authored Nov 17, 2024
1 parent f8ab21a commit 3d5677e
Showing 19 changed files with 115 additions and 1,015 deletions.
67 changes: 34 additions & 33 deletions docs/dev/Port-Models.md
@@ -96,15 +96,22 @@ We follow the same breakdown in the implementation of Llama in Levanter.

#### Note on the Implementation Format
- Each class will have its key layers and components defined as attributes and be initialized with a static method `init()`.
- Each class will be extended from Equinox's `Module` class.
- Each class will be extended from Equinox's `Module` class, except for classes with custom serialization logic, which instead inherit
from [haliax.state_dict.ModuleWithStateDictSerialization][].
- [hax.nn.Linear][] modules can have "articulated" input or output axes, where PyTorch and other libraries typically require
a single input and output axis. For instance, attention modules in Levanter typically have a `Linear` from `Embed` to `(Heads, HeadSize)`.
When serializing these linear modules to state dicts (see the next section), Haliax will automatically flatten them. You should
ensure that `out_first=True` is set on Linear modules if they're going to be loaded as PyTorch Linear modules.
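
For example, here is a minimal sketch of such an articulated projection. The axis names, sizes, and the `q_proj` variable are illustrative only, not taken from a specific Levanter model:

```python
import haliax as hax
import jax.random as jrandom

# Illustrative axes; real models define these from their config.
Pos = hax.Axis("position", 16)
Embed = hax.Axis("embed", 512)
Heads = hax.Axis("heads", 8)
HeadSize = hax.Axis("head_size", 64)

# A Linear from Embed to (Heads, HeadSize). out_first=True keeps the stored weight layout
# compatible with PyTorch's nn.Linear when Haliax flattens it into a state dict.
q_proj = hax.nn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=jrandom.PRNGKey(0), out_first=True)

x = hax.ones((Pos, Embed))
q = q_proj(x)  # NamedArray with axes (position, heads, head_size)
```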

### Serialization to/from State Dicts

PyTorch and Hugging Face Transformers use "state dicts" as their preferred serialization format, either as pickles or as the new [safetensors](https://github.com/huggingface/safetensors) format.
A state dict is a Python `dict` with string keys and tensor values. The keys of the dict are json-ish "key paths" like `model.blocks.0.mlp.c_proj` and the values are the corresponding parameters for that key path.
You can read more about [PyTorch State Dicts here](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html).
You can read [PyTorch's State Dict docs](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)
if you want to learn more.
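
To make "key paths" concrete, here is a toy state dict; the keys and shapes are made up for illustration:

```python
import numpy as np

# A hypothetical state dict: dotted key paths mapping to weight tensors
# (in practice the values are torch tensors or arrays loaded from a pickle or safetensors file).
state_dict = {
    "model.blocks.0.mlp.c_proj.weight": np.zeros((256, 64), dtype=np.float32),
    "model.blocks.0.mlp.c_proj.bias": np.zeros((64,), dtype=np.float32),
    "model.blocks.1.mlp.c_proj.weight": np.zeros((256, 64), dtype=np.float32),
    "model.blocks.1.mlp.c_proj.bias": np.zeros((64,), dtype=np.float32),
}
```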

Levanter has machinery for (de)serializing to and from state dicts. Simple cases are handled automatically, but often times custom logic is needed.
[Haliax has machinery for (de)serializing to and from state dicts](https://haliax.readthedocs.io/state-dict/).
Simple cases are handled automatically, but sometimes custom logic is needed.

#### Easy Case: Identical Module Structure
If your module has exactly the same fields with the same names and same shapes as the Hugging Face state dict (e.g. Gpt2Mlp), you don't need to do anything.
@@ -113,41 +120,45 @@ If your module has exactly the same fields with the same names and same shapes as the Hugging Face state dict (e.g. Gpt2Mlp), you don't need to do anything.
If for some reason you want to use different names from the HF implementation (e.g. because the names from HF aren't clear...), you can extend your class from `ModuleWithStateDictSerialization` and use `_state_dict_key_map` to rename keys. For instance, the `Gpt2Transformer` class has this method:

```python
class Gpt2Transformer(StateDictSerializationMixin, eqx.Module):
from haliax.state_dict import ModuleWithStateDictSerialization

class Gpt2Transformer(ModuleWithStateDictSerialization):
    ...

    def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
        return {"blocks": "h"}
```

This says that the field called `blocks` in this class should be (de)serialized as `h`, because the Hugging Face GPT-2 implementation uses `h`, which is not very clear. You can also "flatten" the submodules of a field by using `None` or even include `.s` in the name if needed.
This says that the field called `blocks` in this class should be (de)serialized as `h`, because the Hugging Face GPT-2 implementation uses `h`, which is not very clear. You can also "flatten" the submodules of a field by using `None`.
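
As an illustrative sketch of the `None` case (the class and field names below are hypothetical, not from Levanter), mapping a field to `None` splices its children's keys directly into the parent's namespace:

```python
from typing import Dict, Optional

from haliax.state_dict import ModuleWithStateDictSerialization


class MyTransformer(ModuleWithStateDictSerialization):
    ...

    def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
        # Hypothetical mapping: "blocks" serializes as "h" (giving keys like "h.0. ..."),
        # while "inner" is flattened away, so "inner.weight" is written as just "weight".
        return {"blocks": "h", "inner": None}
```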

#### Hard Case: Custom Serialization

If your modules need special logic, you'll need to extend your class from `StateDictSerializationMixin` and overwrite the default function `to_state_dict()` and `from_state_dict()`. It takes in a Hugging Face state dict and returns a Levanter state_dict.

For implementation, there are a few helper classes from torch_serialization that you can use:
- To add specific prefix to the keys of Hugging Face state_dict, you can use the helper function `apply_prefix()`. The prefix comes from the name of attributes defined at the beginning of your model class.
- To unflatten the linear layers of Hugging Face, you can use the helper function `unflatten_linear_params()`.
- To unstack the transformer blocks of Hugging Face, you can use the helper function `unstack_transformer_blocks()`.
If your modules need special logic, you'll need to extend your class from `ModuleWithStateDictSerialization` and
override the default functions `update_state_dict()` and `from_state_dict()`. Each takes in and returns a modified
[haliax.state_dict.StateDict][]. As of May 2024, we almost never need this in Levanter.

For example, below is the implementation of `from_state_dict()` in `LlamaAttention`. `LLamaAttention`, like most attention layers in Levanter, projects the embeddings (with shape `(Pos, Embed)`) to a tensor of shape `(Pos, Head, HeadSize)`, rather than the convention in Hugging Face Transformers that uses `(Pos, Head * HeadSize)` and then reshapes (because PyTorch Linear doesn't support multiple input or output dimensions).
For implementation, there are a few helper functions from `haliax.state_dict` that you can use:
- To join a specific prefix to the keys of a Hugging Face state_dict, you can use the helper function `apply_prefix()`. The prefix comes from the names of the attributes defined at the beginning of your model class.
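
For instance, `apply_prefix` joins key path segments with a dot; this is a sketch of the expected behavior (a `None` prefix should leave the key unchanged):

```python
from haliax.state_dict import apply_prefix

print(apply_prefix("transformer.h.0", "mlp.c_proj.weight"))  # transformer.h.0.mlp.c_proj.weight
print(apply_prefix(None, "wte.weight"))                      # wte.weight
```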

This difference means that we have to flatten linear layers when converting to a Hugging Face Transformers-compatible state dict and unflatten them when they are read in.
For example, below is the implementation of `update_state_dict()` in [levanter.models.backpack.BackpackLMHeadModel][].
In this class, we want to preserve HF compatibility by also saving the shared embedding matrix under the separate
(untied) keys that the HF checkpoint format expects. (We chose not to implement non-weight-tied embeddings.)

```python
def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "LlamaAttention":
    # unflatten the linear layers of HF state_dict to match the shape of LlamaAttention
    d = {}
    d.update(unflatten_linear_layers(apply_prefix(prefix, "q_proj"), state_dict, self.q_proj, True))
    d.update(unflatten_linear_layers(apply_prefix(prefix, "k_proj"), state_dict, self.k_proj, True))
    d.update(unflatten_linear_layers(apply_prefix(prefix, "v_proj"), state_dict, self.v_proj, True))
    d.update(unflatten_linear_layers(apply_prefix(prefix, "o_proj"), state_dict, self.o_proj, True))

    return super().from_state_dict(d, prefix)

def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
    state_dict = super().update_state_dict(state_dict, prefix=prefix)
    # In levanter's implementation, we have a shared embedding matrix for both the word
    # embeddings and the sense embeddings
    state_dict[apply_prefix(prefix, "backpack.word_embeddings.weight")] = state_dict[
        apply_prefix(prefix, "backpack.gpt2_model.wte.weight")
    ]
    state_dict[apply_prefix(prefix, "backpack.position_embeddings.weight")] = state_dict[
        apply_prefix(prefix, "backpack.gpt2_model.wpe.weight")
    ]
    return state_dict
```

Similarly, to save weights to Hugging Face, you will need to write a class function `to_state_dict()` in each of your model class.
Similarly, to load weights from the state dict, you'll need to implement `from_state_dict`.
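
For symmetry, here is a rough sketch of what that `from_state_dict` could look like for the same class. This is hypothetical and simplified, not the actual Levanter implementation:

```python
def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> "BackpackLMHeadModel":
    # Hypothetical sketch: copy the untied HF word-embedding entry onto the key that the
    # shared Levanter embedding expects, then let the default machinery load everything else.
    state_dict = dict(state_dict)
    state_dict[apply_prefix(prefix, "backpack.gpt2_model.wte.weight")] = state_dict[
        apply_prefix(prefix, "backpack.word_embeddings.weight")
    ]
    return super().from_state_dict(state_dict, prefix=prefix)
```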

The correctness of your implementation can be validated through serialization tests, which will be discussed in the next section.

Expand Down Expand Up @@ -302,13 +313,3 @@ Check out [Training on Your Own Data](../Training-On-Your-Data.md) for more deta
If you are interested in profiling the training throughput of your model, the good news is that it comes for free with automatic job monitoring in Levanter, powered by Weights & Biases.

Once you run a training job, on the corresponding job page on Weights & Biases, you will find a section named "Throughput". It reports metrics like `examples_per_second` and `tokens_per_second` over the course of training.

## Tips for Optimization
1. Avoid upcasting to float32. Levanter uses bfloat16 by default, which is more memory efficient and faster for training. You should avoid upcasting to float32 unless it is necessary for stability or accuracy.
2. For attention, rearrange the heads and position axes to make the computation more efficient. For example, in Llama, we did the following:

```python
q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size"))
k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size"))
```
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -21,11 +21,11 @@ classifiers = [
"Intended Audience :: Science/Research",
]
dependencies = [
"haliax>=1.4.dev307",
"haliax>=1.4.dev324",
"equinox>=0.11.7",
"jaxtyping>=0.2.34",
"tokenizers>=0.15.2",
"transformers>=4.41.2",
"transformers>=4.41.2,<4.46.0",
"optax>=0.1.9",
"wandb>=0.17.8",
"draccus>=0.9.3",
19 changes: 11 additions & 8 deletions src/levanter/compat/hf_checkpoints.py
@@ -31,8 +31,8 @@
import haliax
from haliax import Axis
from haliax.partitioning import ResourceMapping
from haliax.state_dict import from_torch_compatible_state_dict, save_state_dict, to_torch_compatible_state_dict

from levanter.compat.torch_serialization import StateDictSerializationMixin, save_state_dict, to_numpy_state_dict
from levanter.logging import silence_transformer_nag
from levanter.models.asr_model import ASRMixin
from levanter.models.lm_model import LmConfig, LmHeadModel
@@ -128,7 +128,7 @@ def hf_checkpoint_converter(cls) -> "HFCheckpointConverter":
MConfig = TypeVar("MConfig", bound=HFCompatConfig)


class ModelWithHfSerializationMixin(Generic[MConfig], StateDictSerializationMixin):
class ModelWithHfSerializationMixin(Generic[MConfig]):
    def get_hf_config(self):
        return self.config.to_hf_config(self.Vocab.size)

@@ -545,9 +545,8 @@ def load_pretrained(
ignore_prefix = self.ignore_prefix
break

def load_from_state_dict(state_dict):
lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
lev_model = lev_model.from_state_dict(state_dict, prefix=ignore_prefix)
def load_from_state_dict(template, state_dict):
lev_model = from_torch_compatible_state_dict(template, state_dict, prefix=ignore_prefix)

# However, this might miss some buffers that don't get persisted in the state dict
# (e.g. pytorch buffers with persistent=false), so we have to reinitialize them. We then init the model
@@ -574,7 +573,10 @@ def load_from_state_dict(state_dict):
if just_use_cpu:
cpu_device = jax.local_devices(backend="cpu")[0]
with local_cpu_mesh():
lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)(state_dict)
lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
lev_model = eqx.filter_jit(load_from_state_dict, donate="all", device=cpu_device)(
lev_model, state_dict
)

del state_dict
# gotta move it to the accelerator now (assuming there is one!)
@@ -583,7 +585,8 @@ def load_from_state_dict(state_dict):
load_from_state_dict = haliax.named_jit(
load_from_state_dict, axis_resources=axis_mapping, out_axis_resources=axis_mapping, donate_args=(True,)
)
lev_model = load_from_state_dict(state_dict)
lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
lev_model = load_from_state_dict(lev_model, state_dict)

# all_arrays: list[jax.Array] = get_backend().live_arrays()
# total_size = sum(a.size * a.itemsize for a in all_arrays)
@@ -695,7 +698,7 @@ def _save_pretrained_local(
json.dump(dict_config, f)

# Model
state_dict = to_numpy_state_dict(model)
state_dict = to_torch_compatible_state_dict(model)
shards, index = _shard_hf_checkpoint(state_dict, max_shard_size, SAFE_TENSORS_MODEL)
if index is None:
save_state_dict(state_dict, os.path.join(path, SAFE_TENSORS_MODEL))
