cap the size of the core writer task rather than the number of batches
dlwh committed Oct 11, 2024
1 parent 51f9bf1 commit c8d230c
Showing 3 changed files with 84 additions and 46 deletions.
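In essence, the commit replaces a fixed cap on the number of buffered batches (`num_batches_per_flush = 256`) with a cap on the buffered bytes (`target_size_per_flush = "512MB"`), so the writer's memory footprint no longer depends on how large individual batches happen to be. A minimal sketch of the two policies (hypothetical names, not the actual implementation):

```python
# Old policy: flush after a fixed number of batches, regardless of their size.
def should_flush_by_count(num_buffered_batches: int, cap: int = 256) -> bool:
    return num_buffered_batches >= cap

# New policy: flush once buffered bytes cross a target, regardless of batch count.
def should_flush_by_bytes(buffered_bytes: int, target: int = 512_000_000) -> bool:
    return buffered_bytes >= target
```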
118 changes: 74 additions & 44 deletions src/levanter/store/cache.py
@@ -16,6 +16,7 @@

import deepdiff
import fsspec.core
import humanfriendly
import jax
import pyarrow as pa
import ray
@@ -77,12 +78,16 @@ class CacheOptions:
process. Lower values will use less memory but take somewhat longer to build the cache."""

# the below options don't actually impact the cache's result, but do impact construction
num_batches_per_flush = 256
"""The number of batches to process before flushing the cache to disk. This is used to control the memory usage of
the cache building process. Lower values will use less memory but may take somewhat longer to build the cache."""
target_size_per_flush: int | str = "512MB"
"""The number of bytes to buffer before flushing to disk. This is used to control the memory usage of the cache
building process. Lower values will use less memory but could take somewhat longer to build the cache."""
prefetch_per_group: int = 4
"""The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time"""

@property
def target_bytes_per_flush(self):
return humanfriendly.parse_size(self.target_size_per_flush)

@staticmethod
def default():
return CacheOptions()
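`target_size_per_flush` accepts either raw bytes or a human-readable string; `humanfriendly.parse_size` does the conversion in the new `target_bytes_per_flush` property. A quick illustration of its semantics (decimal vs. binary suffixes):

```python
import humanfriendly

humanfriendly.parse_size("512MB")   # -> 512000000 (decimal megabytes)
humanfriendly.parse_size("512MiB")  # -> 536870912 (binary mebibytes)
humanfriendly.parse_size("1GB")     # -> 1000000000
```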
@@ -684,6 +689,10 @@ def write_batch(self, shard_name: str, batch: BatchResult):
def finish(self):
# if successful, write the ledger
logger.info("Finished writing cache")
# check that all shards are finished
if set(self._ledger.shard_rows.keys()) != set(self._ledger.finished_shards):
raise ValueError("Not all shards are finished")

self._ledger.is_finished = True
self._ledger._serialize_and_commit(self.cache_dir)
if self._on_write:
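The new guard in `finish()` compares the set of shards that contributed rows against the set marked finished; a mismatch means some shard stopped early without signaling completion. A toy illustration of the invariant (hypothetical data):

```python
shard_rows = {"shard_0": 100, "shard_1": 80}   # shards that wrote rows
finished_shards = ["shard_0"]                  # shards that signaled completion

if set(shard_rows.keys()) != set(finished_shards):
    raise ValueError("Not all shards are finished")  # fires: shard_1 never finished
```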
@@ -755,8 +764,13 @@ def __init__(
self.logger.info("Cache already finished. Nothing to do.")
return
self._cache_writer = _core_writer_task.options(
name=f"writer::{path_for_name}", scheduling_strategy="SPREAD"
).remote(current_actor_handle(), cache_dir, split, self._ledger, source, processor, force_flush)
name=f"writer::{path_for_name}",
scheduling_strategy="SPREAD",
# memory needed for the writer is twice the options' target size per flush
# (the factor of two comes from needing to concatenate prepared batches into the accumulator)
# TODO: measure.
memory=2 * self._options.target_bytes_per_flush,
).remote(current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush)
except Exception:
# Ray behaves poorly if the constructor of an actor fails, so we catch and log here
# this also propagates to the finished promise, so we can handle it there
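Ray's `memory` option is a logical reservation used for scheduling: the task is only placed on a node whose accounted memory can cover the request. A minimal sketch of the pattern (toy task, not the writer itself):

```python
import ray

@ray.remote
def build_cache_chunk(data):
    # ... buffer up to roughly the target byte size, then flush ...
    return len(data)

# Reserve 1 GiB so the scheduler doesn't co-locate too many writers per node.
ref = build_cache_chunk.options(memory=1024**3, scheduling_strategy="SPREAD").remote([1, 2, 3])
ray.get(ref)
```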
@@ -767,14 +781,6 @@ def current_ledger(self):
raise self._finished_promise.exception()
return self._ledger

def other_failed(self, error: ExceptionInfo):
"""Callback method for when a shard worker has failed."""
self._writer_exception(None, error)

def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo):
self.logger.error(f"Child {child} failed with exception", exc_info=exception.restore())
self._writer_exception(None, exception)

def is_finished(self):
if self.failed():
return False
@@ -872,13 +878,6 @@ def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheO
# a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache.
# The reader tasks are given a group of shards, which are implicitly concatenated together.

# This is still much slower than I would like but I haven't figured out why yet.
# TODO:
# - [ ] Profile the tokenization process more (see TIME comments)
# - [ ] Try Ray's autoscaling actorpool if the issue is tokenization isn't fast enough
# https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py
# - [ ] More observability into what's queued and how long work items take


@dataclass
class _Batch:
@@ -907,14 +906,12 @@ class _ShardFinished:
"""

_TIME_BETWEEN_WRITES = 20.0 # seconds
_MIN_WRITE_BATCHES = 100


@ray.remote(num_cpus=1)
def _core_writer_task(
parent,
cache_dir,
split,
initial_ledger: CacheLedger,
source: ShardedDataSource,
processor,
@@ -948,11 +945,11 @@ def on_write(ledger):
options = initial_ledger.metadata.options
num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names))

processor_pool = _mk_processor_pool(split, processor, 0, num_groups * 4)
processor_pool = _mk_processor_pool(processor, 0, num_groups * 4)

interleave: RayPrefetchQueue = RayPrefetchQueue(
lambda: _make_interleave(name, source, initial_ledger, processor_pool),
512,
64,
producer_options={"num_cpus": 1, "name": f"{name}::interleave"},
)

@@ -962,7 +959,8 @@ def on_write(ledger):
flush_time = Stopwatch()
flush_amortized_time = Stopwatch()

batches: list = []
current_prepared_batch: Optional[PyTree[PreparedBatch]] = None
current_shard_rows: dict[str, int] = {}
time_of_last_write = time.time()
batches_total = 0.0
flush_thread = None
@@ -975,22 +973,36 @@ def on_write(ledger):
time_since_last_write = cur_time - time_of_last_write
remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write

if len(batches) > 0:
if current_prepared_batch is not None:
with flush_amortized_time: # 6e-4
if remaining_time <= 0 or len(batches) >= options.num_batches_per_flush or force_flush:
current_byte_size = sum(
b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0]
)
should_flush = (
force_flush
or remaining_time <= 0
or (current_byte_size >= options.target_bytes_per_flush)
)
if should_flush:
with flush_time: # 0.613s
shard_rows, payloads = _fetch_batches(batches)
if flush_thread is not None:
flush_thread.join()

batches = []
flush_thread = ExceptionTrackingThread(
target=_write_batches,
args=(sharded_cache_writer, shard_rows, payloads, finished_shards_last_flush),
args=(
sharded_cache_writer,
current_shard_rows,
current_prepared_batch,
finished_shards_last_flush,
),
)
flush_thread.start()

current_prepared_batch = None
current_shard_rows = {}
finished_shards_last_flush = []

time_of_last_write = time.time()
continue
else:
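Condensing the loop above: a flush now fires on any of three triggers — a forced flush, the 20-second timer (`_TIME_BETWEEN_WRITES`), or the accumulated `PreparedBatch` pytree crossing the byte target. A hedged restatement (simplified, not the exact code):

```python
import time
import jax

_TIME_BETWEEN_WRITES = 20.0  # seconds

def should_flush(current_prepared_batch, time_of_last_write, target_bytes, force_flush):
    if current_prepared_batch is None:
        return False
    # Sum the byte sizes of every leaf in the accumulated pytree.
    byte_size = sum(b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0])
    stale = (time.time() - time_of_last_write) >= _TIME_BETWEEN_WRITES
    return force_flush or stale or byte_size >= target_bytes
```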
@@ -1005,18 +1017,30 @@ def on_write(ledger):

with append_time:
match message:
case _Batch(shard, _, payload):
case _Batch(shard, row_indices, payload):
batches_total += 1
batches.append((shard, payload))
this_prepared_batch = ray.get(payload)
if current_prepared_batch is None:
# TODO: actually check row indices
current_shard_rows = {shard: len(row_indices)}
current_prepared_batch = this_prepared_batch
else:
current_shard_rows[shard] = current_shard_rows.get(shard, 0) + len(row_indices)
current_prepared_batch = _concat_prepared_batches(
current_prepared_batch, this_prepared_batch
)
del this_prepared_batch

if force_flush:
shard_rows, payloads = _fetch_batches(batches)
del batches
_write_batches(
sharded_cache_writer, shard_rows, payloads, finished_shards_last_flush
sharded_cache_writer,
current_shard_rows,
current_prepared_batch,
finished_shards_last_flush,
)
batches = []
finished_shards_last_flush = []
current_prepared_batch = None
current_shard_rows = {}

case _ShardFinished(shard, total_rows):
finished_shards_last_flush.append((shard, total_rows))
@@ -1038,23 +1062,29 @@ def on_write(ledger):
raise e

# force a flush
if len(batches) > 0:
shard_row_totals, payloads_for_batches = _fetch_batches(batches)
del batches
if current_prepared_batch is not None or finished_shards_last_flush:
if flush_thread is not None:
flush_thread.join()
_write_batches(sharded_cache_writer, shard_row_totals, payloads_for_batches, finished_shards_last_flush)
_write_batches(
sharded_cache_writer, current_shard_rows, current_prepared_batch, finished_shards_last_flush
)

sharded_cache_writer.finish()

out = sharded_cache_writer.get_ledger()
return out


def _write_batches(writer: ShardedCacheWriter, shard_totals, batches, finished_shards):
def _concat_prepared_batches(
current_prepared_batch: PyTree[PreparedBatch], this_prepared_batch: PyTree[PreparedBatch]
):
return jax.tree.map(lambda *bs: PreparedBatch.concat(bs), current_prepared_batch, this_prepared_batch)
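`_concat_prepared_batches` relies on `jax.tree.map` zipping corresponding leaves of the two pytrees, so each field's `PreparedBatch`es are concatenated pairwise. A toy analogue with plain arrays (hypothetical field names):

```python
import jax
import numpy as np

a = {"input_ids": np.array([1, 2]), "loss_mask": np.array([1, 1])}
b = {"input_ids": np.array([3]), "loss_mask": np.array([0])}

# tree.map with two trees applies the function leaf-by-leaf, in parallel.
merged = jax.tree.map(lambda x, y: np.concatenate([x, y]), a, b)
# merged == {"input_ids": array([1, 2, 3]), "loss_mask": array([1, 1, 0])}
```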


def _write_batches(writer: ShardedCacheWriter, shard_totals, batch: Optional[PyTree[PreparedBatch]], finished_shards):
# concatenate the payloads
final_payload = jax.tree.map(lambda *bs: PreparedBatch.concat(bs), *batches)
writer.write_prepared_batch(shard_totals, final_payload)
if batch is not None:
writer.write_prepared_batch(shard_totals, batch)

for shard, total_rows in finished_shards:
writer.finish_shard(shard, total_rows)
Expand Down Expand Up @@ -1215,7 +1245,7 @@ def generator():
yield from _interleave_shards(readers, first_group_to_start)


def _mk_processor_pool(split, processor, min_size, max_size):
def _mk_processor_pool(processor, min_size, max_size):
import hashlib

metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest()
4 changes: 4 additions & 0 deletions src/levanter/store/jagged_array.py
@@ -29,6 +29,10 @@ class PreparedBatch:
offsets: np.ndarray
shapes: Optional[np.ndarray]

@property
def byte_size(self):
return self.data.nbytes + self.offsets.nbytes + (self.shapes.nbytes if self.shapes is not None else 0)
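`byte_size` simply sums the `nbytes` of the backing arrays; this is the quantity the flush trigger and the Ray `memory` reservation are budgeted against. A rough sanity check with made-up shapes:

```python
import numpy as np

data = np.zeros(1000, dtype=np.int32)    # 4000 bytes
offsets = np.zeros(11, dtype=np.int64)   # 88 bytes
shapes = None                            # absent for fixed-shape data

byte_size = data.nbytes + offsets.nbytes + (shapes.nbytes if shapes is not None else 0)
assert byte_size == 4088
```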

def astype(self, dtype):
return PreparedBatch(self.data.astype(dtype), self.offsets, self.shapes)

8 changes: 6 additions & 2 deletions tests/test_new_cache.py
@@ -97,8 +97,9 @@ def metadata(self) -> Dict[str, Any]:


class SimpleShardSource(ShardedDataSource[list[int]]):
def __init__(self, num_shards: int = 4):
def __init__(self, num_shards: int = 4, rows_per_shard: int = 10):
self._num_shards = num_shards
self._rows_per_shard = rows_per_shard

@property
def shard_names(self) -> Sequence[str]:
@@ -107,7 +108,7 @@ def shard_names(self) -> Sequence[str]:
def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]:
# parse the shard name to get the shard number
shard_num = int(shard_name.split("_")[1])
return ([shard_num * 10 + i] * 10 for i in range(row, 10))
return ([shard_num * 10 + i] * 10 for i in range(row, self._rows_per_shard))
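With the new `rows_per_shard` parameter, each shard yields that many rows instead of a hard-coded 10. For instance (values follow directly from the generator above):

```python
src = SimpleShardSource(num_shards=4, rows_per_shard=3)
rows = list(src.open_shard_at_row("shard_1", 0))
# shard_num == 1, so rows == [[10]*10, [11]*10, [12]*10]
```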


def test_serial_cache_writer():
Expand Down Expand Up @@ -465,6 +466,9 @@ def test_sharded_cache_writer():
for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size):
writer.write_batch(shard_name, processor(ex))

for shard_name in source.shard_names:
writer.finish_shard(shard_name, source._rows_per_shard)

store = writer.finish()

data_path = store.path
