Skip to content

Commit

Permalink
fix tree leaf stuff
Browse files: browse the repository at this point in the history
  • Loading branch information
dlwh committed Dec 3, 2024
1 parent 074d0ec commit bf13f12
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def on_step(self, info, force: bool = False):
if not force:
return # don't save checkpoint at step 0 unless forced

if step == self._last_save_step:
if step == self._last_save_step and not force:
# we've already saved a checkpoint at this step
return

Expand Down
20 changes: 16 additions & 4 deletions src/levanter/tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# * Orbax: https://github.com/google/orbax/blob/11d2934ecfff77e86b5e07d0fef02b67eff4511b/orbax/checkpoint/pytree_checkpoint_handler.py#L312
import logging
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional
from typing import Any, Callable, Optional

import equinox
import jax
Expand Down Expand Up @@ -43,11 +44,22 @@ def tree_serialize_leaves_tensorstore(
manager_was_none = False

leaf_key_paths = jax_utils.leaf_key_paths(pytree, is_leaf=is_named_array)
assert len(jax.tree.leaves(leaf_key_paths, is_leaf=is_named_array)) == len(
jax.tree.leaves(pytree, is_leaf=is_named_array)
)

paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths)
paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None)
leaves = jtu.tree_leaves(pytree, is_leaf=lambda x: x is None)
assert len(leaves) == len(paths)

# pair each path with its leaf in a dataclass rather than a tuple: tuples are pytrees, so jax.tree.leaves would flatten them
@dataclass
class Pair:
path: str
leaf: Any

zipped = jax.tree.map(lambda x, y: Pair(x, y), paths, pytree, is_leaf=lambda x: x is None)
paired_leaves = jax.tree.leaves(zipped)
paths = [p.path for p in paired_leaves]
leaves = [p.leaf.array if is_named_array(p.leaf) else p.leaf for p in paired_leaves]

# ok, not all of these are arrays, but we'll deal with that in the async function
def _ensure_is_array(x):
Expand Down
26 changes: 15 additions & 11 deletions src/levanter/utils/jax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,17 +152,21 @@ def leaf_key_paths(
x, prefix=join_key(prefix, p), is_leaf=is_leaf, use_state_dict_keys=use_state_dict_keys
)

out: PyTree[str]

if is_leaf is not None and is_leaf(pytree):
return prefix
out = prefix
elif pytree is None:
out = None
elif isinstance(pytree, dict):
return {k: rec(v, k) for k, v in pytree.items()}
out = {k: rec(v, k) for k, v in pytree.items()}
elif _isnamedtupleinstance(pytree):
d = {k: rec(v, k) for k, v in pytree._asdict().items()}
return pytree.__class__(**d)
out = pytree.__class__(**d)
elif isinstance(pytree, list):
return [rec(v, str(i)) for i, v in enumerate(pytree)]
out = [rec(v, str(i)) for i, v in enumerate(pytree)]
elif isinstance(pytree, tuple):
return tuple(rec(v, str(i)) for i, v in enumerate(pytree))
out = tuple(rec(v, str(i)) for i, v in enumerate(pytree))
elif isinstance(pytree, eqx.Module):
names = []
rec_values = []
Expand All @@ -181,17 +185,17 @@ def leaf_key_paths(

_, tree_def = eqx.tree_flatten_one_level(pytree)
out = jax.tree_util.tree_unflatten(tree_def, rec_values)
return out
# this doesn't work reliably because tree_at doesn't handle None values well
# return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values, is_leaf=lambda x: x is None)
else:
leaves, treedef = jax.tree_util.tree_flatten(pytree, is_leaf=is_leaf)
if len(leaves) == 0:
return None
out = None
elif len(leaves) == 1:
return jax.tree_util.tree_unflatten(treedef, [f"{prefix}"])
out = jax.tree_util.tree_unflatten(treedef, [f"{prefix}"])
else:
return jax.tree_util.tree_unflatten(treedef, [join_key(prefix, str(i)) for i in range(len(leaves))])
out = jax.tree_util.tree_unflatten(treedef, [join_key(prefix, str(i)) for i in range(len(leaves))])

# assert len(jax.tree.leaves(out, is_leaf=is_leaf)) == len(jax.tree.leaves(pytree, is_leaf=is_leaf)), (out, pytree)
return out


def join_key(prefix, k):
Expand Down

0 comments on commit bf13f12

Please sign in to comment.