Skip to content

Commit

Permalink
Fix copy-progress reporting, which was updating too many times
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 5, 2024
1 parent fa00824 commit b51a380
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions src/levanter/store/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,10 +933,17 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):

write_refs[group_name] = ref

# Now we start copying the temporary caches to the output cache, in order (essentially concatenating them).

ledger = _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers,
group_cache_paths, processor, processor_ref)
ledger = _start_copies(
parent,
cache_dir,
shard_groups,
first_group,
write_refs,
group_ledgers,
group_cache_paths,
processor,
processor_ref,
)

ledger.is_finished = True
ledger._serialize_and_commit(cache_dir)
Expand All @@ -946,8 +953,17 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger):
_clean_up_temp_caches(temporary_cache_paths)


def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, group_cache_paths, processor,
processor_ref):
def _start_copies(
parent,
cache_dir,
shard_groups,
first_group,
write_refs,
group_ledgers,
group_cache_paths,
processor,
processor_ref,
):
"""
Copy the temporary caches to the output cache, in order (essentially concatenating them).
Expand All @@ -961,6 +977,9 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou
group_cache_paths: a dict mapping group names to the paths of the temporary caches
processor: the processor object
processor_ref: a ray.ObjectRef of the processor object
Returns:
The final ledger
"""
# This logic is a bit hairy thanks to resumes.
# First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these
Expand Down Expand Up @@ -1020,7 +1039,9 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou
if found_one_to_copy and shards_copied > 0:
raise RuntimeError("A previous group was copied, but this group was not. This should never happen.")
elif shards_copied == len(shard_groups[group]):
assert overall_ledger.total_num_rows >= total_rows_from_caches, f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}"
assert (
overall_ledger.total_num_rows >= total_rows_from_caches
), f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}"
continue # nothing to do
elif shards_copied > 0:
# In theory we can handle this, but it's a bit tricky, so we're going to punt for now
Expand Down Expand Up @@ -1051,18 +1072,6 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou
)
total_rows_from_caches += this_ledger.total_num_rows

# this little bit is totally unnecessary but nice logging
for group in shard_groups:
if group == first_group:
continue

if copy_refs.get(group) is not None:
ledger = ray.get(copy_refs[group])
group_ledgers[group] = ledger
parent._report_copy_progress.remote(
_ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows)
)

# refs form a linked list implicitly, so we can just wait on the last one
if last_ref is not None:
ledger = ray.get(last_ref)
Expand Down

0 comments on commit b51a380

Please sign in to comment.