Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Oct 14, 2024
1 parent ab543d6 commit 6395305
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/levanter/tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def tree_deserialize_leaves_tensorstore(
# TODO: support ShapeDtypeStructs that are not NamedArrays
leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=is_named_array)
paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths)
paths = jtu.tree_leaves(paths)
paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None)

shardings_leaves, shardings_structure = jtu.tree_flatten(shardings)

Expand Down

0 comments on commit 6395305

Please sign in to comment.