From ab7dc65c267aba4ae39d2ea51fc64c05131b4221 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 8 Jan 2025 16:11:53 -0800 Subject: [PATCH] fix serialization of partitioned states in optax (#852) --- src/levanter/tensorstore_serialization.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 82471e5df..fc9155cd1 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -95,7 +95,7 @@ def path_from_key_path(key_path): def _sharding_from_leaf(leaf, axis_mapping, mesh) -> Optional[jax.sharding.Sharding]: if is_named_array(leaf): - if leaf.array is None: + if not is_jax_array_like(leaf.array): return None return hax.partitioning.sharding_for_axis(leaf.axes, axis_mapping, mesh) elif hasattr(leaf, "sharding") and getattr(leaf, "sharding") is not None: @@ -140,11 +140,11 @@ def tree_deserialize_leaves_tensorstore( manager = array_ser.GlobalAsyncCheckpointManager() shardings: PyTree[Optional[Sharding]] = jtu.tree_map( - partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=is_named_array + partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=_is_named_or_none ) # TODO: support ShapeDtypeStructs that are not NamedArrays - leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=is_named_array) + leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=_is_named_or_none) paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths) paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None) @@ -157,6 +157,8 @@ def tree_deserialize_leaves_tensorstore( real_leaves = [x for x in shardings_leaves if x is not None] real_paths = [paths[i] for i in real_indices] + assert len(real_leaves) == len(real_paths), f"{len(real_leaves)} != {len(real_paths)}" + deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths) # now we need to recreate the original structure