Skip to content

Commit

Permalink
fix resume bookkeeping logic
Browse files (browse the repository at this point in the history)
  • Loading branch information
dlwh committed Nov 3, 2024
1 parent 47441c0 commit dbdc2e4
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/levanter/store/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
)
copy_refs[group] = last_ref

if group == first_group:
# this is the first group, so it's already in the cache and we don't need to
# increment the data offset tree etc.
continue

# update the offset information: data offsets and total rows
this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True)
data_offset_tree = jax.tree.map(
Expand Down Expand Up @@ -1045,14 +1050,15 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None)
return {shard_name: [shard_name] for shard_name in source.shard_names}

shard_names = source.shard_names
num_shards_per_group = len(shard_names) // num_groups
num_shards_per_group = (len(shard_names) + num_groups - 1) // num_groups
# if we have a remainder, we'll just add it to the last group
out_groups = {
f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group])
for i in range(num_groups)
}
if len(shard_names) % num_shards_per_group != 0:
out_groups[f"group_{num_groups - 1}"].extend(shard_names[num_groups * num_shards_per_group :])

# make sure we got all the shards
assert sum(len(shards) for shards in out_groups.values()) == len(shard_names)

return out_groups # type: ignore

Expand Down

0 comments on commit dbdc2e4

Please sign in to comment.