diff --git a/scripts/launch_gpt2_small_fast_tpu.sh b/scripts/launch_gpt2_small_fast_tpu.sh index 7b2634749..0c09cdcfa 100644 --- a/scripts/launch_gpt2_small_fast_tpu.sh +++ b/scripts/launch_gpt2_small_fast_tpu.sh @@ -1,6 +1,6 @@ # Launches the "gpt_small_fast" model on a TPU node -python infra/launch.py --foreground --tpu_name levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ +python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \ python -m levanter.main.train_lm \ --config_path config/gpt2_small_fast.yaml \ --trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $* diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 608019374..56aa54f99 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -272,6 +272,11 @@ def __init__( # double check that we're not finished by committing the ledger self._attempt_to_write_batches() + if not self._ledger.is_finished: + self._actual_writer_thread = threading.Thread(target=self._write_loop, daemon=True) + self._stop_loop = threading.Event() + self._actual_writer_thread.start() + def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box): with log_failures_to(self._parent): if self._failed: @@ -286,7 +291,6 @@ def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box # we need to keep track of the order of the batches so that we can write them out in order self._total_queue_length += len(batch_result) self._batch_queue.append_to_group(shard_name, shard_batch_idx, batch_result) - self._attempt_to_write_batches() next_missing_item = self._batch_queue.next_missing_item_index() overwhelmed = self.is_overwhelmed() @@ -303,6 +307,7 @@ def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box def shard_failed(self, shard_name: str, batch_id: int, exc_info: ExceptionInfo): with log_failures_to(self._parent): self._failed = True + self._stop_loop.set() logger.error(f"Shard {shard_name} failed at batch {batch_id}", exc_info=exc_info.restore()) self._parent.shard_failed.remote(shard_name, exc_info) @@ -314,7 +319,10 @@ def shard_finished_reading(self, shard_name: str, expected_num_rows: int): logger.debug( f"Attempting to write batches because {shard_name} finished reading with {expected_num_rows} batches." ) - self._attempt_to_write_batches() + self.flush() + + def flush(self): + self._attempt_to_write_batches() def get_shard_status(self, shard_name: str): with log_failures_to(self._parent): @@ -327,7 +335,7 @@ def get_ledger(self): def _attempt_to_write_batches(self): if self._ledger.is_finished: - raise RuntimeError("Trying to write batches after cache is finished") + return if self._failed: logger.warning("Not writing batches because of failure.") @@ -361,6 +369,22 @@ def _attempt_to_write_batches(self): ray.wait(futures_to_await + futures_to_await_shards) + def _finish(self): + self._stop_loop.set() + self._actual_writer_thread.join() + + def _write_loop(self): + while True: + try: + self._stop_loop.wait(1) + if self._stop_loop.is_set(): + break + except TimeoutError: + pass + self._attempt_to_write_batches() + if self._ledger.is_finished: + break + def _dequeue_ready_batches(self): for shard, batch in self._batch_queue.drain(): logger.debug(f"Writing batch for {shard}") @@ -422,6 +446,9 @@ def is_overwhelmed(self) -> bool: max_queue_size = self._min_items_to_write * 3 return self._total_queue_length > max_queue_size + def __del__(self): + self._finish() + def _to_list_of_dicts(batch: dict) -> List[dict]: """ @@ -940,16 +967,16 @@ async def get_batch(self, indices: Sequence[int] | slice): return await self.store.get_batch(indices) - async def _wait_for_len(self, needed_len): + async def _wait_for_len(self, needed_len: int): if self._broker is not None: while needed_len > await self.current_len(): - new_ledger = await self._broker.updated_ledger.remote() + new_ledger: CacheLedger = await self._broker.updated_ledger.remote() if needed_len <= new_ledger.total_num_rows: break if new_ledger.is_finished: - if needed_len >= new_ledger.rows_finished: + if needed_len >= new_ledger.total_num_rows: raise IndexError( f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" ) @@ -967,7 +994,9 @@ def _wait_for_len_sync(self, needed_len, timeout: Optional[float] = None): if cur_time > t_max: raise TimeoutError(f"Timed out waiting for cache to reach {needed_len}") try: - new_ledger = ray.get(self._broker.updated_ledger.remote(), timeout=max(t_max - cur_time, 10)) + new_ledger: CacheLedger = ray.get( + self._broker.updated_ledger.remote(), timeout=max(t_max - cur_time, 10) + ) except TimeoutError: continue @@ -975,7 +1004,7 @@ def _wait_for_len_sync(self, needed_len, timeout: Optional[float] = None): break if new_ledger.is_finished: - if needed_len >= new_ledger.rows_finished: + if needed_len >= new_ledger.total_num_rows: raise IndexError( f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}" ) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index 3302674de..b6132e548 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -216,6 +216,7 @@ async def test_batch_finished(): batch_result = [np.array([1, 2, 3])] await writer.batch_finished.remote(shard_idx, shard_batch_idx, batch_result) + await writer.flush.remote() shard_status = await writer.get_shard_status.remote("shard1") assert shard_status.num_rows_committed == 1 finally: @@ -307,6 +308,8 @@ async def test_attempt_to_write_batches(): await writer.batch_finished.remote("shard1", 0, shard1_batch) await writer.batch_finished.remote("shard2", 0, shard2_batch) + await writer.flush.remote() + ledger = await writer.get_ledger.remote() assert ledger.is_finished is False assert ledger.total_num_rows == 2 # Assuming each batch has 1 row for simplicity @@ -336,6 +339,7 @@ async def test_finalize_cache(): await writer.shard_finished_reading.remote("shard1", 1) await writer.shard_finished_reading.remote("shard2", 1) await writer.batch_finished.remote("shard2", 0, shard2_batch) + await writer.flush.remote() ledger = await writer.get_ledger.remote() assert ledger.is_finished is False @@ -390,6 +394,7 @@ async def test_out_of_order_batches_same_shard(): await writer.batch_finished.remote("shard1", 1, shard1_batch1) await writer.batch_finished.remote("shard1", 0, shard1_batch0) + await writer.flush.remote() store = TreeStore.open(exemplar, cache_dir, mode="r") assert len(store) == 2 @@ -419,6 +424,7 @@ async def test_out_of_order_batches_different_shards(): await writer.batch_finished.remote("shard1", 1, shard1_batch1) await writer.batch_finished.remote("shard2", 0, shard2_batch0) await writer.batch_finished.remote("shard1", 0, shard1_batch0) + await writer.flush.remote() store = TreeStore.open(exemplar, cache_dir, mode="r") assert len(store) == 3 @@ -451,6 +457,7 @@ async def test_batches_different_orders_all_shards(): await writer.batch_finished.remote("shard3", 0, shard3_batch0) await writer.batch_finished.remote("shard1", 1, shard1_batch1) await writer.batch_finished.remote("shard1", 0, shard1_batch0) + await writer.flush.remote() store = TreeStore.open(exemplar, cache_dir, mode="r") assert len(store) == 4 @@ -486,6 +493,7 @@ async def test_intermixed_batches_same_and_different_shards(): await writer.batch_finished.remote("shard1", 1, shard1_batch1) await writer.batch_finished.remote("shard2", 1, shard2_batch1) await writer.batch_finished.remote("shard1", 0, shard1_batch0) + await writer.flush.remote() store = TreeStore.open(exemplar, cache_dir, mode="r") assert len(store) == 5 @@ -512,6 +520,7 @@ async def test_duplicate_batches_same_shard(): shard1_batch0 = [np.array([1, 2, 3])] await writer.batch_finished.remote("shard1", 0, shard1_batch0) + await writer.flush.remote() with pytest.raises(RayTaskError): await writer.batch_finished.remote("shard1", 0, shard1_batch0) # Duplicate finally: @@ -544,6 +553,7 @@ async def test_mixed_order_batches_multiple_shards(): await writer.batch_finished.remote("shard2", 1, shard2_batch1) await writer.batch_finished.remote("shard1", 0, shard1_batch0) await writer.batch_finished.remote("shard3", 1, shard3_batch1) + await writer.flush.remote() store = TreeStore.open(exemplar, cache_dir, mode="r") assert len(store) == 6 @@ -892,10 +902,12 @@ async def test_backpressure_mechanism(): await writer.batch_finished.remote("shard1", 1, shard3_batch) await writer.batch_finished.remote("shard1", 2, shard3_batch) await writer.batch_finished.remote("shard1", 3, shard3_batch) + await writer.flush.remote() # Check if backpressure is signaled is_overwhelmed = await writer.is_overwhelmed.remote() assert is_overwhelmed is True + await writer.flush.remote() for i in range(4): if (await parent.desired_next_item.remote()) == 0: @@ -910,6 +922,12 @@ async def test_backpressure_mechanism(): # Reduce the queue size to relieve backpressure # Check if backpressure is relieved is_overwhelmed = await writer.is_overwhelmed.remote() + count = 0 + while is_overwhelmed and count < 10: + await writer.flush.remote() + await asyncio.sleep(0.4) + is_overwhelmed = await writer.is_overwhelmed.remote() + count += 1 assert is_overwhelmed is False for i in range(4):