From c8d230cd62269d1635de679505d56577db93fa40 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 11 Oct 2024 13:57:34 -0700 Subject: [PATCH] cap the size of the core writer task rather than the number of batches --- src/levanter/store/cache.py | 118 ++++++++++++++++++----------- src/levanter/store/jagged_array.py | 4 + tests/test_new_cache.py | 8 +- 3 files changed, 84 insertions(+), 46 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 218bce95c..45265c994 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -16,6 +16,7 @@ import deepdiff import fsspec.core +import humanfriendly import jax import pyarrow as pa import ray @@ -77,12 +78,16 @@ class CacheOptions: process. Lower values will use less memory but take somewhat longer to build the cache.""" # the below options don't actually impact the cache's result, but do impact construction - num_batches_per_flush = 256 - """The number of batches to process before flushing the cache to disk. This is used to control the memory usage of - the cache building process. Lower values will use less memory but may take somewhat longer to build the cache.""" + target_size_per_flush: int | str = "512MB" + """The number of bytes to buffer before flushing to disk. This is used to control the memory usage of the cache + building process. Lower values will use less memory but could take somewhat longer to build the cache.""" prefetch_per_group: int = 4 """The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time""" + @property + def target_bytes_per_flush(self): + return humanfriendly.parse_size(self.target_size_per_flush) + @staticmethod def default(): return CacheOptions() @@ -684,6 +689,10 @@ def write_batch(self, shard_name: str, batch: BatchResult): def finish(self): # if successful, write the ledger logger.info("Finished writing cache") + # check that all shards are finished + if set(self._ledger.shard_rows.keys()) != set(self._ledger.finished_shards): + raise ValueError("Not all shards are finished") + self._ledger.is_finished = True self._ledger._serialize_and_commit(self.cache_dir) if self._on_write: @@ -755,8 +764,13 @@ def __init__( self.logger.info("Cache already finished. Nothing to do.") return self._cache_writer = _core_writer_task.options( - name=f"writer::{path_for_name}", scheduling_strategy="SPREAD" - ).remote(current_actor_handle(), cache_dir, split, self._ledger, source, processor, force_flush) + name=f"writer::{path_for_name}", + scheduling_strategy="SPREAD", + # memory needed for the writer is twice the options' target size per flush + # (we get twice from we need to concatenate prepared batches into the accumulator) + # TODO: measure. 
+ memory=2 * self._options.target_bytes_per_flush, + ).remote(current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush) except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -767,14 +781,6 @@ def current_ledger(self): raise self._finished_promise.exception() return self._ledger - def other_failed(self, error: ExceptionInfo): - """Callback method for when a shard worker has failed.""" - self._writer_exception(None, error) - - def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): - self.logger.error(f"Child {child} failed with exception", exc_info=exception.restore()) - self._writer_exception(None, exception) - def is_finished(self): if self.failed(): return False @@ -872,13 +878,6 @@ def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheO # a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache. # The reader tasks are given a group of shards, which are implicitly concatenated together. -# This is still much slower than I would like but I haven't figured out why yet. -# TODO: -# - [ ] Profile the tokenization process more (see TIME comments) -# - [ ] Try Ray's autoscaling actorpool if the issue is tokenization isn't fast enough -# https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py -# - [ ] More observability into what's queued and how long work items take - @dataclass class _Batch: @@ -907,14 +906,12 @@ class _ShardFinished: """ _TIME_BETWEEN_WRITES = 20.0 # seconds -_MIN_WRITE_BATCHES = 100 @ray.remote(num_cpus=1) def _core_writer_task( parent, cache_dir, - split, initial_ledger: CacheLedger, source: ShardedDataSource, processor, @@ -948,11 +945,11 @@ def on_write(ledger): options = initial_ledger.metadata.options num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names)) - processor_pool = _mk_processor_pool(split, processor, 0, num_groups * 4) + processor_pool = _mk_processor_pool(processor, 0, num_groups * 4) interleave: RayPrefetchQueue = RayPrefetchQueue( lambda: _make_interleave(name, source, initial_ledger, processor_pool), - 512, + 64, producer_options={"num_cpus": 1, "name": f"{name}::interleave"}, ) @@ -962,7 +959,8 @@ def on_write(ledger): flush_time = Stopwatch() flush_amortized_time = Stopwatch() - batches: list = [] + current_prepared_batch: Optional[PyTree[PreparedBatch]] = None + current_shard_rows: dict[str, int] = {} time_of_last_write = time.time() batches_total = 0.0 flush_thread = None @@ -975,22 +973,36 @@ def on_write(ledger): time_since_last_write = cur_time - time_of_last_write remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write - if len(batches) > 0: + if current_prepared_batch is not None: with flush_amortized_time: # 6e-4 - if remaining_time <= 0 or len(batches) >= options.num_batches_per_flush or force_flush: + current_byte_size = sum( + b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0] + ) + should_flush = ( + force_flush + or remaining_time <= 0 + or (current_byte_size >= options.target_bytes_per_flush) + ) + if should_flush: with flush_time: # 0.613s - shard_rows, payloads = _fetch_batches(batches) if flush_thread is not None: flush_thread.join() - batches = [] flush_thread = ExceptionTrackingThread( target=_write_batches, - args=(sharded_cache_writer, 
shard_rows, payloads, finished_shards_last_flush), + args=( + sharded_cache_writer, + current_shard_rows, + current_prepared_batch, + finished_shards_last_flush, + ), ) flush_thread.start() + current_prepared_batch = None + current_shard_rows = {} finished_shards_last_flush = [] + time_of_last_write = time.time() continue else: @@ -1005,18 +1017,30 @@ def on_write(ledger): with append_time: match message: - case _Batch(shard, _, payload): + case _Batch(shard, row_indices, payload): batches_total += 1 - batches.append((shard, payload)) + this_prepared_batch = ray.get(payload) + if current_prepared_batch is None: + # TODO: actually check row indices + current_shard_rows = {shard: len(row_indices)} + current_prepared_batch = this_prepared_batch + else: + current_shard_rows[shard] = current_shard_rows.get(shard, 0) + len(row_indices) + current_prepared_batch = _concat_prepared_batches( + current_prepared_batch, this_prepared_batch + ) + del this_prepared_batch if force_flush: - shard_rows, payloads = _fetch_batches(batches) - del batches _write_batches( - sharded_cache_writer, shard_rows, payloads, finished_shards_last_flush + sharded_cache_writer, + current_shard_rows, + current_prepared_batch, + finished_shards_last_flush, ) - batches = [] finished_shards_last_flush = [] + current_prepared_batch = None + current_shard_rows = {} case _ShardFinished(shard, total_rows): finished_shards_last_flush.append((shard, total_rows)) @@ -1038,12 +1062,12 @@ def on_write(ledger): raise e # force a flush - if len(batches) > 0: - shard_row_totals, payloads_for_batches = _fetch_batches(batches) - del batches + if current_prepared_batch is not None or finished_shards_last_flush: if flush_thread is not None: flush_thread.join() - _write_batches(sharded_cache_writer, shard_row_totals, payloads_for_batches, finished_shards_last_flush) + _write_batches( + sharded_cache_writer, current_shard_rows, current_prepared_batch, finished_shards_last_flush + ) sharded_cache_writer.finish() @@ -1051,10 +1075,16 @@ def on_write(ledger): return out -def _write_batches(writer: ShardedCacheWriter, shard_totals, batches, finished_shards): +def _concat_prepared_batches( + current_prepared_batch: PyTree[PreparedBatch], this_prepared_batch: PyTree[PreparedBatch] +): + return jax.tree.map(lambda *bs: PreparedBatch.concat(bs), current_prepared_batch, this_prepared_batch) + + +def _write_batches(writer: ShardedCacheWriter, shard_totals, batch: Optional[PyTree[PreparedBatch]], finished_shards): # concatenate the payloads - final_payload = jax.tree.map(lambda *bs: PreparedBatch.concat(bs), *batches) - writer.write_prepared_batch(shard_totals, final_payload) + if batch is not None: + writer.write_prepared_batch(shard_totals, batch) for shard, total_rows in finished_shards: writer.finish_shard(shard, total_rows) @@ -1215,7 +1245,7 @@ def generator(): yield from _interleave_shards(readers, first_group_to_start) -def _mk_processor_pool(split, processor, min_size, max_size): +def _mk_processor_pool(processor, min_size, max_size): import hashlib metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest() diff --git a/src/levanter/store/jagged_array.py b/src/levanter/store/jagged_array.py index d6674774f..1013d0e34 100644 --- a/src/levanter/store/jagged_array.py +++ b/src/levanter/store/jagged_array.py @@ -29,6 +29,10 @@ class PreparedBatch: offsets: np.ndarray shapes: Optional[np.ndarray] + @property + def byte_size(self): + return self.data.nbytes + self.offsets.nbytes + (self.shapes.nbytes if self.shapes is not None else 0) + 
def astype(self, dtype): return PreparedBatch(self.data.astype(dtype), self.offsets, self.shapes) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index 275c6a236..35ba8999e 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -97,8 +97,9 @@ def metadata(self) -> Dict[str, Any]: class SimpleShardSource(ShardedDataSource[list[int]]): - def __init__(self, num_shards: int = 4): + def __init__(self, num_shards: int = 4, rows_per_shard: int = 10): self._num_shards = num_shards + self._rows_per_shard = rows_per_shard @property def shard_names(self) -> Sequence[str]: @@ -107,7 +108,7 @@ def shard_names(self) -> Sequence[str]: def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # parse the shard name to get the shard number shard_num = int(shard_name.split("_")[1]) - return ([shard_num * 10 + i] * 10 for i in range(row, 10)) + return ([shard_num * 10 + i] * 10 for i in range(row, self._rows_per_shard)) def test_serial_cache_writer(): @@ -465,6 +466,9 @@ def test_sharded_cache_writer(): for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size): writer.write_batch(shard_name, processor(ex)) + for shard_name in source.shard_names: + writer.finish_shard(shard_name, source._rows_per_shard) + store = writer.finish() data_path = store.path
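
Not part of the patch: a minimal sketch of the size-based flush policy the new writer follows. Plain dicts of numpy arrays stand in for Levanter's `PreparedBatch` pytrees, and `byte_size`, `concat`, `write_loop`, and `flush` are illustrative names rather than Levanter APIs; the real path is `_core_writer_task` in `src/levanter/store/cache.py`.

```python
import time

import humanfriendly
import numpy as np

TARGET_SIZE_PER_FLUSH = "512MB"  # mirrors CacheOptions.target_size_per_flush
TIME_BETWEEN_WRITES = 20.0       # mirrors _TIME_BETWEEN_WRITES (seconds)
TARGET_BYTES = humanfriendly.parse_size(TARGET_SIZE_PER_FLUSH)


def byte_size(batch: dict) -> int:
    # analogous to PreparedBatch.byte_size: total bytes backing the batch
    return sum(arr.nbytes for arr in batch.values())


def concat(a: dict, b: dict) -> dict:
    # analogous to _concat_prepared_batches, which concatenates leaf-wise via jax.tree.map
    return {k: np.concatenate([a[k], b[k]]) for k in a}


def write_loop(batches, flush):
    """Buffer incoming batches; flush once the buffer is big enough or old enough."""
    buffered = None
    time_of_last_write = time.time()
    for batch in batches:
        buffered = batch if buffered is None else concat(buffered, batch)
        stale = (time.time() - time_of_last_write) >= TIME_BETWEEN_WRITES
        if byte_size(buffered) >= TARGET_BYTES or stale:
            flush(buffered)
            buffered = None
            time_of_last_write = time.time()
    if buffered is not None:  # whatever is left at the end still gets written
        flush(buffered)
```

In the patch itself the same check also honors `force_flush`, and the flush runs on a background `ExceptionTrackingThread` so the interleave reader keeps making progress while the previous buffer is written.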
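
Also illustrative only: how the flush target ties into the Ray memory reservation made when launching the writer task. Reserving roughly twice the flush target accounts for the accumulator and a freshly prepared batch being alive at the same time during concatenation. `writer_task` below is a hypothetical stand-in for `_core_writer_task`.

```python
import humanfriendly
import ray


@ray.remote(num_cpus=1)
def writer_task():
    ...  # the real task is _core_writer_task in src/levanter/store/cache.py


target_bytes = humanfriendly.parse_size("512MB")

# mirrors the .options(...) call in the patch: reserve ~2x the flush target
writer_task.options(
    name="writer::demo",
    scheduling_strategy="SPREAD",
    memory=2 * target_bytes,
).remote()
```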