diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index e7a5306d4..0e0eec7ae 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -1023,8 +1023,6 @@ def _start_copies( copy_refs: dict[str, ray.ObjectRef] = {} last_ref: ray.ObjectRef | None = None - found_one_to_copy = False - for group in shard_groups: # first make sure it's either done this run or already done if write_refs.get(group) is not None: @@ -1043,20 +1041,21 @@ def _start_copies( assert this_ledger is not None # see if we already copied this group, meaning all the shards are in the permanent cache - shards_copied = sum(1 if shard in overall_ledger.finished_shards else 0 for shard in shard_groups[group]) + shards_copied = [shard for shard in shard_groups[group] if shard in overall_ledger.finished_shards] - if found_one_to_copy and shards_copied > 0: - raise RuntimeError("A previous group was copied, but this group was not. This should never happen.") - elif shards_copied == len(shard_groups[group]): + if len(shards_copied) == len(shard_groups[group]): assert ( overall_ledger.total_num_rows >= total_rows_from_caches ), f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" continue # nothing to do - elif shards_copied > 0: - # In theory we can handle this, but it's a bit tricky, so we're going to punt for now - raise RuntimeError("Some shards were copied but not all. This should never happen.") + elif len(shards_copied) > 0: + # In theory, we can handle this, but it's a bit tricky, so we're going to punt for now + raise RuntimeError( + "Some shards were copied but not all. This should never happen." + f"Specifically the following shards were copied: {shards_copied}" + f"And the following shards were not: {set(shard_groups[group]) - set(shards_copied)}" + ) - found_one_to_copy = True # we need to copy this group # we can't "commit" the group to the ledger (or the number of rows)