diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index ba684b8e5..1b8b7a632 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -157,7 +157,7 @@ def on_step(self, info, force: bool = False): if not force: return # don't save checkpoint at step 0 unless forced - if step == self._last_save_step: + if step == self._last_save_step and not force: # we've already saved a checkpoint at this step return diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 674a1cfa9..82471e5df 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -2,8 +2,9 @@ # * Orbax: https://github.com/google/orbax/blob/11d2934ecfff77e86b5e07d0fef02b67eff4511b/orbax/checkpoint/pytree_checkpoint_handler.py#L312 import logging import os +from dataclasses import dataclass from functools import partial -from typing import Callable, Optional +from typing import Any, Callable, Optional import equinox import jax @@ -43,11 +44,22 @@ def tree_serialize_leaves_tensorstore( manager_was_none = False leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array) + assert len(jax.tree.leaves(leaf_key_paths, is_leaf=is_named_array)) == len( + jax.tree.leaves(pytree, is_leaf=is_named_array) + ) paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths) - paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None) - leaves = jtu.tree_leaves(pytree, is_leaf=lambda x: x is None) - assert len(leaves) == len(paths) + + # make a dataclass since tuples are pytrees + @dataclass + class Pair: + path: str + leaf: Any + + zipped = jax.tree.map(lambda x, y: Pair(x, y), paths, pytree, is_leaf=lambda x: x is None) + paired_leaves = jax.tree.leaves(zipped) + paths = [p.path for p in paired_leaves] + leaves = [p.leaf.array if is_named_array(p.leaf) else p.leaf for p in paired_leaves] # ok, not all of these are arrays, but we'll deal with that in the async function def _ensure_is_array(x): diff --git a/src/levanter/utils/jax_utils.py b/src/levanter/utils/jax_utils.py index 39fbf438c..734ec930c 100644 --- a/src/levanter/utils/jax_utils.py +++ b/src/levanter/utils/jax_utils.py @@ -152,17 +152,21 @@ def leaf_key_paths( x, prefix=join_key(prefix, p), is_leaf=is_leaf, use_state_dict_keys=use_state_dict_keys ) + out: PyTree[str] + if is_leaf is not None and is_leaf(pytree): - return prefix + out = prefix + elif pytree is None: + out = None elif isinstance(pytree, dict): - return {k: rec(v, k) for k, v in pytree.items()} + out = {k: rec(v, k) for k, v in pytree.items()} elif _isnamedtupleinstance(pytree): d = {k: rec(v, k) for k, v in pytree._asdict().items()} - return pytree.__class__(**d) + out = pytree.__class__(**d) elif isinstance(pytree, list): - return [rec(v, str(i)) for i, v in enumerate(pytree)] + out = [rec(v, str(i)) for i, v in enumerate(pytree)] elif isinstance(pytree, tuple): - return tuple(rec(v, str(i)) for i, v in enumerate(pytree)) + out = tuple(rec(v, str(i)) for i, v in enumerate(pytree)) elif isinstance(pytree, eqx.Module): names = [] rec_values = [] @@ -181,17 +185,17 @@ def leaf_key_paths( _, tree_def = eqx.tree_flatten_one_level(pytree) out = jax.tree_util.tree_unflatten(tree_def, rec_values) - return out - # this doesn't work reliably because tree_at doesn't like none values - # return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None) else: leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf) if len(leaves) == 0: - return None + out = None elif len(leaves) == 1: - return jax.tree_util.tree_unflatten(treedef, [f"{prefix}"]) + out = jax.tree_util.tree_unflatten(treedef, [f"{prefix}"]) else: - return jax.tree_util.tree_unflatten(treedef, [join_key(prefix, str(i)) for i in range(len(leaves))]) + out = jax.tree_util.tree_unflatten(treedef, [join_key(prefix, str(i)) for i in range(len(leaves))]) + + # assert len(jax.tree.leaves(out, is_leaf=is_leaf)) == len(jax.tree.leaves(pytree, is_leaf=is_leaf)), (out, pytree) + return out def join_key(prefix, k):