diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py
index ba684b8e5..1b8b7a632 100644
--- a/src/levanter/checkpoint.py
+++ b/src/levanter/checkpoint.py
@@ -157,7 +157,7 @@ def on_step(self, info, force: bool = False):
             if not force:
                 return  # don't save checkpoint at step 0 unless forced
 
-        if step == self._last_save_step:
+        if step == self._last_save_step and not force:
             # we've already saved a checkpoint at this step
             return
diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py
index 462c1cf2c..82471e5df 100644
--- a/src/levanter/tensorstore_serialization.py
+++ b/src/levanter/tensorstore_serialization.py
@@ -1,11 +1,10 @@
 # References:
 # * Orbax: https://github.com/google/orbax/blob/11d2934ecfff77e86b5e07d0fef02b67eff4511b/orbax/checkpoint/pytree_checkpoint_handler.py#L312
-import asyncio
-import functools
 import logging
 import os
+from dataclasses import dataclass
 from functools import partial
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 
 import equinox
 import jax
@@ -13,12 +12,11 @@
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
-import tensorstore
-from jax.sharding import Mesh
-from tensorstore import TensorStore
+from jax.sharding import Mesh, Sharding
+from jaxtyping import PyTree
 
 import haliax as hax
-import haliax.tree_util as htu
+from haliax.jax_utils import is_jax_array_like
 from haliax.partitioning import ResourceMapping
 from haliax.util import is_named_array
@@ -45,15 +43,23 @@ def tree_serialize_leaves_tensorstore(
     else:
         manager_was_none = False
 
-    leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=_is_named_or_none)
+    leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array)
+    assert len(jax.tree.leaves(leaf_key_paths, is_leaf=is_named_array)) == len(
+        jax.tree.leaves(pytree, is_leaf=is_named_array)
+    )
 
-    def path_from_key_path(key_path):
-        return os.path.join(checkpoint_dir, *key_path.split("."))
+    paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths)
 
-    paths = jtu.tree_map(path_from_key_path, leaf_key_paths, is_leaf=lambda x: x is None)
-    paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None)
-    leaves = jtu.tree_leaves(pytree, is_leaf=lambda x: x is None)
-    assert len(leaves) == len(paths)
+    # make a dataclass since tuples are pytrees
+    @dataclass
+    class Pair:
+        path: str
+        leaf: Any
+
+    zipped = jax.tree.map(lambda x, y: Pair(x, y), paths, pytree, is_leaf=lambda x: x is None)
+    paired_leaves = jax.tree.leaves(zipped)
+    paths = [p.path for p in paired_leaves]
+    leaves = [p.leaf.array if is_named_array(p.leaf) else p.leaf for p in paired_leaves]
 
     # ok, not all of these are arrays, but we'll deal with that in the async function
     def _ensure_is_array(x):
@@ -79,88 +85,40 @@ def _ensure_is_array(x):
         manager.wait_until_finished()
 
 
-def _tensorstore_spec_for(checkpoint_dir, key_path: str):
-    checkpoint_path = os.path.join(checkpoint_dir, *key_path.split("."))
-    ts_spec = array_ser.get_tensorstore_spec(checkpoint_path)
-    return ts_spec
-
+def _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths):
+    def path_from_key_path(key_path):
+        return os.path.join(checkpoint_dir, *key_path.split("."))
 
-async def _serialize_one_leaf(x, spec):
-    if isinstance(x, hax.NamedArray):
-        # we don't need to do anything special for named arrays to serialize, though we will for deserialization.
-        return await _serialize_one_leaf(x.array, spec)
-    elif isinstance(x, jax.Array):
-        if not x.is_fully_addressable:
-            return await array_ser.async_serialize(x, spec)
-        else:
-            return await save_array_to_tensorstore(x, spec)
-    elif isinstance(x, (bool, float, complex, int)):
-        return await save_array_to_tensorstore(np.array(x), spec)
-    elif x is None:
-        return
-    elif isinstance(x, jnp.ndarray):
-        return await save_array_to_tensorstore(x, spec)
-    elif isinstance(x, np.ndarray):
-        return await save_array_to_tensorstore(x, spec)
+    paths = jtu.tree_map(path_from_key_path, leaf_key_paths)
+    return paths
+
+
+def _sharding_from_leaf(leaf, axis_mapping, mesh) -> Optional[jax.sharding.Sharding]:
+    if is_named_array(leaf):
+        if leaf.array is None:
+            return None
+        return hax.partitioning.sharding_for_axis(leaf.axes, axis_mapping, mesh)
+    elif hasattr(leaf, "sharding") and getattr(leaf, "sharding") is not None:
+        return leaf.sharding
+    elif is_jax_array_like(leaf):
+        return _fully_replicated_sharding(mesh)
+    elif isinstance(leaf, (bool, float, complex, int, np.ndarray)):
+        return _fully_replicated_sharding(mesh)
     else:
-        raise TypeError(f"Can't serialize {type(x)}")
-
-
-async def save_array_to_tensorstore(x, spec):
-    if jax.process_index() == 0:
-        if x.dtype == jnp.bfloat16:
-            # Tensorstore uses 'bfloat16', not '
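
Aside (not part of the patch): a minimal standalone sketch of why the new code pairs paths with leaves via a local dataclass instead of a tuple, per the "make a dataclass since tuples are pytrees" comment. A tuple is itself a pytree node, so jax.tree.leaves would flatten through it and decouple each path from its array, whereas an unregistered dataclass instance is treated as a single leaf. The toy dict, paths, and printed values below are illustrative only.

from dataclasses import dataclass
from typing import Any

import jax
import numpy as np


@dataclass
class Pair:  # mirrors the local Pair defined inside tree_serialize_leaves_tensorstore
    path: str
    leaf: Any


pytree = {"b": np.zeros(3), "w": np.ones(2)}   # toy "model" state
paths = {"b": "/ckpt/b", "w": "/ckpt/w"}       # toy on-disk paths, same tree structure

# A tuple is a pytree node, so flattening recurses into it and the pairing is lost:
as_tuples = jax.tree.map(lambda p, a: (p, a), paths, pytree)
print(jax.tree.leaves(as_tuples))
# ['/ckpt/b', array([0., 0., 0.]), '/ckpt/w', array([1., 1.])]

# An unregistered dataclass is opaque to the pytree machinery, so each Pair
# survives flattening as one leaf with its path and array still attached:
as_pairs = jax.tree.map(Pair, paths, pytree)
for pair in jax.tree.leaves(as_pairs):
    print(pair.path, pair.leaf.shape)
# /ckpt/b (3,)
# /ckpt/w (2,)

The same caveat would apply to lists or any other container registered with JAX's pytree machinery; keeping the pair in an opaque leaf is what lets the serialization code walk paths and leaves in lockstep.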