Skip to content

Commit

Permalink
Make new tokenization ~67% faster (#744)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Sep 25, 2024
1 parent 91be677 commit cd82fb3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
2 changes: 1 addition & 1 deletion scripts/launch_gpt2_small_fast_tpu.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Launches the "gpt_small_fast" model on a TPU node

python infra/launch.py --foreground --tpu_name levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \
python infra/launch.py --foreground --tpu_name $(whoami)-levanter-itest-32 --zone us-central2-b --tpu_type v4-32 --preemptible -- \
python -m levanter.main.train_lm \
--config_path config/gpt2_small_fast.yaml \
--trainer.checkpointer.base_path gs://levanter-checkpoints/gpt-itest/ --trainer.checkpointer.save_interval 30m $*
45 changes: 37 additions & 8 deletions src/levanter/store/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ def __init__(
# double check that we're not finished by committing the ledger
self._attempt_to_write_batches()

if not self._ledger.is_finished:
self._actual_writer_thread = threading.Thread(target=self._write_loop, daemon=True)
self._stop_loop = threading.Event()
self._actual_writer_thread.start()

def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box):
with log_failures_to(self._parent):
if self._failed:
Expand All @@ -286,7 +291,6 @@ def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box
# we need to keep track of the order of the batches so that we can write them out in order
self._total_queue_length += len(batch_result)
self._batch_queue.append_to_group(shard_name, shard_batch_idx, batch_result)
self._attempt_to_write_batches()
next_missing_item = self._batch_queue.next_missing_item_index()

overwhelmed = self.is_overwhelmed()
Expand All @@ -303,6 +307,7 @@ def batch_finished(self, shard_name: str, shard_batch_idx: int, batch_result_box
def shard_failed(self, shard_name: str, batch_id: int, exc_info: ExceptionInfo):
with log_failures_to(self._parent):
self._failed = True
self._stop_loop.set()
logger.error(f"Shard {shard_name} failed at batch {batch_id}", exc_info=exc_info.restore())
self._parent.shard_failed.remote(shard_name, exc_info)

Expand All @@ -314,7 +319,10 @@ def shard_finished_reading(self, shard_name: str, expected_num_rows: int):
logger.debug(
f"Attempting to write batches because {shard_name} finished reading with {expected_num_rows} batches."
)
self._attempt_to_write_batches()
self.flush()

def flush(self):
self._attempt_to_write_batches()

def get_shard_status(self, shard_name: str):
with log_failures_to(self._parent):
Expand All @@ -327,7 +335,7 @@ def get_ledger(self):

def _attempt_to_write_batches(self):
if self._ledger.is_finished:
raise RuntimeError("Trying to write batches after cache is finished")
return

if self._failed:
logger.warning("Not writing batches because of failure.")
Expand Down Expand Up @@ -361,6 +369,22 @@ def _attempt_to_write_batches(self):

ray.wait(futures_to_await + futures_to_await_shards)

def _finish(self):
self._stop_loop.set()
self._actual_writer_thread.join()

def _write_loop(self):
while True:
try:
self._stop_loop.wait(1)
if self._stop_loop.is_set():
break
except TimeoutError:
pass
self._attempt_to_write_batches()
if self._ledger.is_finished:
break

def _dequeue_ready_batches(self):
for shard, batch in self._batch_queue.drain():
logger.debug(f"Writing batch for {shard}")
Expand Down Expand Up @@ -422,6 +446,9 @@ def is_overwhelmed(self) -> bool:
max_queue_size = self._min_items_to_write * 3
return self._total_queue_length > max_queue_size

def __del__(self):
self._finish()


def _to_list_of_dicts(batch: dict) -> List[dict]:
"""
Expand Down Expand Up @@ -940,16 +967,16 @@ async def get_batch(self, indices: Sequence[int] | slice):

return await self.store.get_batch(indices)

async def _wait_for_len(self, needed_len):
async def _wait_for_len(self, needed_len: int):
if self._broker is not None:
while needed_len > await self.current_len():
new_ledger = await self._broker.updated_ledger.remote()
new_ledger: CacheLedger = await self._broker.updated_ledger.remote()

if needed_len <= new_ledger.total_num_rows:
break

if new_ledger.is_finished:
if needed_len >= new_ledger.rows_finished:
if needed_len >= new_ledger.total_num_rows:
raise IndexError(
f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}"
)
Expand All @@ -967,15 +994,17 @@ def _wait_for_len_sync(self, needed_len, timeout: Optional[float] = None):
if cur_time > t_max:
raise TimeoutError(f"Timed out waiting for cache to reach {needed_len}")
try:
new_ledger = ray.get(self._broker.updated_ledger.remote(), timeout=max(t_max - cur_time, 10))
new_ledger: CacheLedger = ray.get(
self._broker.updated_ledger.remote(), timeout=max(t_max - cur_time, 10)
)
except TimeoutError:
continue

if needed_len <= new_ledger.total_num_rows:
break

if new_ledger.is_finished:
if needed_len >= new_ledger.rows_finished:
if needed_len >= new_ledger.total_num_rows:
raise IndexError(
f"Index {needed_len} out of bounds for cache of size {new_ledger.total_num_rows}"
)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_new_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ async def test_batch_finished():
batch_result = [np.array([1, 2, 3])]

await writer.batch_finished.remote(shard_idx, shard_batch_idx, batch_result)
await writer.flush.remote()
shard_status = await writer.get_shard_status.remote("shard1")
assert shard_status.num_rows_committed == 1
finally:
Expand Down Expand Up @@ -307,6 +308,8 @@ async def test_attempt_to_write_batches():
await writer.batch_finished.remote("shard1", 0, shard1_batch)
await writer.batch_finished.remote("shard2", 0, shard2_batch)

await writer.flush.remote()

ledger = await writer.get_ledger.remote()
assert ledger.is_finished is False
assert ledger.total_num_rows == 2 # Assuming each batch has 1 row for simplicity
Expand Down Expand Up @@ -336,6 +339,7 @@ async def test_finalize_cache():
await writer.shard_finished_reading.remote("shard1", 1)
await writer.shard_finished_reading.remote("shard2", 1)
await writer.batch_finished.remote("shard2", 0, shard2_batch)
await writer.flush.remote()

ledger = await writer.get_ledger.remote()
assert ledger.is_finished is False
Expand Down Expand Up @@ -390,6 +394,7 @@ async def test_out_of_order_batches_same_shard():

await writer.batch_finished.remote("shard1", 1, shard1_batch1)
await writer.batch_finished.remote("shard1", 0, shard1_batch0)
await writer.flush.remote()

store = TreeStore.open(exemplar, cache_dir, mode="r")
assert len(store) == 2
Expand Down Expand Up @@ -419,6 +424,7 @@ async def test_out_of_order_batches_different_shards():
await writer.batch_finished.remote("shard1", 1, shard1_batch1)
await writer.batch_finished.remote("shard2", 0, shard2_batch0)
await writer.batch_finished.remote("shard1", 0, shard1_batch0)
await writer.flush.remote()

store = TreeStore.open(exemplar, cache_dir, mode="r")
assert len(store) == 3
Expand Down Expand Up @@ -451,6 +457,7 @@ async def test_batches_different_orders_all_shards():
await writer.batch_finished.remote("shard3", 0, shard3_batch0)
await writer.batch_finished.remote("shard1", 1, shard1_batch1)
await writer.batch_finished.remote("shard1", 0, shard1_batch0)
await writer.flush.remote()

store = TreeStore.open(exemplar, cache_dir, mode="r")
assert len(store) == 4
Expand Down Expand Up @@ -486,6 +493,7 @@ async def test_intermixed_batches_same_and_different_shards():
await writer.batch_finished.remote("shard1", 1, shard1_batch1)
await writer.batch_finished.remote("shard2", 1, shard2_batch1)
await writer.batch_finished.remote("shard1", 0, shard1_batch0)
await writer.flush.remote()

store = TreeStore.open(exemplar, cache_dir, mode="r")
assert len(store) == 5
Expand All @@ -512,6 +520,7 @@ async def test_duplicate_batches_same_shard():
shard1_batch0 = [np.array([1, 2, 3])]

await writer.batch_finished.remote("shard1", 0, shard1_batch0)
await writer.flush.remote()
with pytest.raises(RayTaskError):
await writer.batch_finished.remote("shard1", 0, shard1_batch0) # Duplicate
finally:
Expand Down Expand Up @@ -544,6 +553,7 @@ async def test_mixed_order_batches_multiple_shards():
await writer.batch_finished.remote("shard2", 1, shard2_batch1)
await writer.batch_finished.remote("shard1", 0, shard1_batch0)
await writer.batch_finished.remote("shard3", 1, shard3_batch1)
await writer.flush.remote()

store = TreeStore.open(exemplar, cache_dir, mode="r")
assert len(store) == 6
Expand Down Expand Up @@ -892,10 +902,12 @@ async def test_backpressure_mechanism():
await writer.batch_finished.remote("shard1", 1, shard3_batch)
await writer.batch_finished.remote("shard1", 2, shard3_batch)
await writer.batch_finished.remote("shard1", 3, shard3_batch)
await writer.flush.remote()

# Check if backpressure is signaled
is_overwhelmed = await writer.is_overwhelmed.remote()
assert is_overwhelmed is True
await writer.flush.remote()

for i in range(4):
if (await parent.desired_next_item.remote()) == 0:
Expand All @@ -910,6 +922,12 @@ async def test_backpressure_mechanism():
# Reduce the queue size to relieve backpressure
# Check if backpressure is relieved
is_overwhelmed = await writer.is_overwhelmed.remote()
count = 0
while is_overwhelmed and count < 10:
await writer.flush.remote()
await asyncio.sleep(0.4)
is_overwhelmed = await writer.is_overwhelmed.remote()
count += 1
assert is_overwhelmed is False

for i in range(4):
Expand Down

0 comments on commit cd82fb3

Please sign in to comment.