From dbdc2e49cd9cc373baec123f4e1ee0cb58531346 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 07:49:41 -0800 Subject: [PATCH] fix resume bookkeeping logic and shard-to-group assignment --- src/levanter/store/cache.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index cd59aef4c..51679921d 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -992,6 +992,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ) copy_refs[group] = last_ref + if group == first_group: + # this is the first group, so it's already in the cache and we don't need to + # increment the data offset tree etc. + continue + # update the offset information: data offsets and total rows this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( @@ -1045,14 +1050,15 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) return {shard_name: [shard_name] for shard_name in source.shard_names} shard_names = source.shard_names - num_shards_per_group = len(shard_names) // num_groups + num_shards_per_group = (len(shard_names) + num_groups - 1) // num_groups # if we have a remainder, we'll just add it to the last group out_groups = { f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) for i in range(num_groups) } - if len(shard_names) % num_shards_per_group != 0: - out_groups[f"group_{num_groups - 1}"].extend(shard_names[num_groups * num_shards_per_group :]) + + # make sure we got all the shards + assert sum(len(shards) for shards in out_groups.values()) == len(shard_names) return out_groups # type: ignore