diff --git a/config/data/pile_mixture.yaml b/config/data/pile_mixture.yaml index ff75b8941..38b545f66 100644 --- a/config/data/pile_mixture.yaml +++ b/config/data/pile_mixture.yaml @@ -1,5 +1,8 @@ cache_dir: "gs://levanter-data/tokenized/pile-domains/" tokenizer: "EleutherAI/gpt-neox-20b" +cache_options: + batch_size: 32 + num_shard_groups: 16 configs: arxiv: train_urls: @@ -11,11 +14,11 @@ configs: - gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst validation_urls: - gs://levanter-data/pile-domains/books2/val.jsonl.zst - books3: - train_urls: - - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst - validation_urls: - - gs://levanter-data/pile-domains/books3/val.jsonl.zst +# books3: +# train_urls: +# - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst +# validation_urls: +# - gs://levanter-data/pile-domains/books3/val.jsonl.zst dm_math: train_urls: - gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst @@ -115,7 +118,7 @@ train_weights: # these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf pile_cc: 0.1811 pubmed_central: 0.1440 - books3: 0.1207 +# books3: 0.1207 owt2: 0.1001 arxiv: 0.0896 github: 0.0759 diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 3c1f77494..573015852 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -1,8 +1,13 @@ +import logging from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Generic, Iterable, Mapping, Sequence, TypeVar, Union import numpy as np import pyarrow as pa +import ray + +from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase +from levanter.utils.ray_utils import RefBox T = TypeVar("T") @@ -143,12 +148,12 @@ def rec(dataset): source, transforms, batch_transform = rec(dataset) - batch_size = batch_transform.batch_size if batch_transform is not None else 1024 + # batch_size = batch_transform.batch_size if batch_transform is not None else 1024 cpus = batch_transform.num_cpus if batch_transform is not None else 1 gpus = batch_transform.num_gpus if batch_transform is not None else 0 resources = batch_transform.resources if batch_transform is not None else {} - return source, _CompositeBatchProcessor(transforms, batch_size, cpus, gpus, resources) + return source, _CompositeBatchProcessor(transforms, cpus, gpus, resources) class _CompositeBatchProcessor(BatchProcessor): @@ -157,7 +162,6 @@ def __init__(self, transforms, num_cpus, num_gpus, resources): self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources - self._batch_size = batch_size @property def batch_size(self): @@ -230,3 +234,65 @@ def to_hf_batched(x): return x return {b.field(i).name: to_hf_batched(b.column(i).to_numpy(zero_copy_only=False)) for i in range(b.num_columns)} + + +@ray.remote(num_cpus=0) +class BatchProcessorPool: + def __init__(self, processor: BatchProcessor, min_size: int = 1, max_size: int = 10): + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(message)s") + processor_ref = ray.put(processor) + self.actor_pool = AutoScalingActorPool( + lambda: _create_batch_processor_actor(processor, processor_ref), min_size, max_size + ) + + async def process_batch(self, batch_ref: RefBox): + return await self.actor_pool.submit( + lambda a, b: a.process_batch.remote(b), batch_ref.ref, obj_ref=batch_ref.ref + ) + + def num_pending_tasks(self): + return self.actor_pool.num_pending_tasks + + +def _create_batch_processor_actor(processor: BatchProcessor, processor_ref): + cpus = processor.num_cpus + gpus = processor.num_gpus + resources = processor.resources + return _BatchProcessorActor.options( # type: ignore + num_cpus=cpus, num_gpus=gpus, resources=resources, scheduling_strategy="SPREAD" + ).remote(processor_ref) + + +@ray.remote +class _BatchProcessorActor(PoolWorkerBase): + def __init__(self, processor: BatchProcessor): + from levanter.store.tree_store import TreeBatchPreparer + + self.processor = processor + self.preparer = TreeBatchPreparer(processor.output_exemplar) + + def process_batch(self, batch): + result = self.processor(batch) + result = _canonicalize_batch(result) + prepared = self.preparer(result) + return prepared + + +def _canonicalize_batch(batch: Union[dict, list[dict]]) -> list[dict]: + if isinstance(batch, pa.RecordBatch): + batch = dict_from_record_batch(batch) + + if isinstance(batch, dict): + return _to_list_of_dicts(batch) + else: + return batch + + +def _to_list_of_dicts(batch: dict) -> list[dict]: + """ + Convert a batch of dictionaries to a list of dictionaries, suitable for writing to a cache. + """ + keys = list(batch.keys()) + values = list(batch.values()) + num_rows = len(values[0]) + return [{key: values[i][j] for i, key in enumerate(keys)} for j in range(num_rows)] diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 1fa8ec078..12695a20b 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -305,6 +305,7 @@ def build_or_load( override_resources=None, max_length=448, cache_options: CacheOptions = CacheOptions.default(), + split: str = "", ) -> "ProcessedAudioCache": bp = BatchAudioProcessor( processor, @@ -316,12 +317,7 @@ def build_or_load( ) monitors = monitors or [] cache = build_or_load_cache( - cache_dir, - source, - bp, - await_finished=await_finished, - monitors=monitors, - options=cache_options, + cache_dir, source, bp, await_finished=await_finished, monitors=monitors, options=cache_options, split=split ) if cache.is_finished: logger.info(f"Cache {cache_dir} is complete.") diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index dfd16f844..f2a3b8497 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -735,6 +735,7 @@ def build_or_load_cache( monitors=monitors, await_finished=False, options=self.cache_options, + split=split, ) diff --git a/src/levanter/main/cache_dataset.py b/src/levanter/main/cache_dataset.py index caccc567c..92471e997 100644 --- a/src/levanter/main/cache_dataset.py +++ b/src/levanter/main/cache_dataset.py @@ -48,6 +48,7 @@ def main(args: RayCachedLMDatasetConfig): processor=batch_tokenizer, await_finished=False, monitors=monitors, + split=split, ) cache.await_finished() diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index eae9f8402..c0bda78f9 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -12,25 +12,35 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core +import jax import pyarrow as pa import ray from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem +from jaxtyping import PyTree from ray.actor import ActorHandle -from ray.remote_function import RemoteFunction from levanter.data.dataset import AsyncDataset from levanter.store._prefetch_actor import QueueEmpty, RayPrefetchQueue from levanter.utils.py_utils import Stopwatch -from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch +from ..data._preprocessor import BatchProcessor, BatchProcessorPool, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource -from ..utils.ray_utils import ExceptionInfo, SnitchRecipient, current_actor_handle, log_failures_to, ser_exc_info +from ..utils.ray_utils import ( + ExceptionInfo, + RefBox, + SnitchRecipient, + current_actor_handle, + log_failures_to, + ser_exc_info, +) +from ..utils.thread_utils import ExceptionTrackingThread +from .jagged_array import PreparedBatch from .tree_store import TreeStore @@ -43,7 +53,7 @@ LEDGER_FILE_NAME = "shard_ledger.json" DEFAULT_LOG_LEVEL = pylogging.INFO -LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" @dataclass_json @@ -66,6 +76,13 @@ class CacheOptions: """The batch size to use when processing the data. This is used to control the memory usage of the cache building process. Lower values will use less memory but take somewhat longer to build the cache.""" + # the below options don't actually impact the cache's result, but do impact construction + num_batches_per_flush = 256 + """The number of batches to process before flushing the cache to disk. This is used to control the memory usage of + the cache building process. Lower values will use less memory but may take somewhat longer to build the cache.""" + prefetch_per_group: int = 4 + """The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time""" + @staticmethod def default(): return CacheOptions() @@ -95,6 +112,7 @@ def build_or_load_cache( monitors: Optional[Sequence["MetricsMonitor"]] = None, options: CacheOptions = CacheOptions.default(), force_flush: bool = False, + split: str = "test", ) -> "TreeCache[U]": """ Produces a sharded cache of the dataset using Ray for distributed processing. The cache can be any path @@ -134,6 +152,7 @@ def build_or_load_cache( processor=processor, options=options, force_flush=force_flush, + split=split, ) if cache.is_finished: @@ -297,6 +316,7 @@ def build_or_load( processor: BatchProcessor[T, U], options: Optional["CacheOptions"] = None, force_flush: bool = False, + split: str = "test", ) -> "TreeCache[U]": if options is None: options = CacheOptions.default() @@ -310,6 +330,7 @@ def build_or_load( processor=processor, options=options, force_flush=force_flush, + split=split, ) return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) @@ -585,9 +606,9 @@ def write_batch(self, batch: BatchResult): if isinstance(batch, pa.RecordBatch): raise NotImplementedError("Only non-RecordBatch batches are supported for now") - batch = _canonicalize_batch(batch) # type: ignore + cbatch = _canonicalize_batch(batch) # type: ignore - self._tree_store.extend(batch) + self._tree_store.extend(cbatch) class ShardedCacheWriter: @@ -612,7 +633,6 @@ def __init__( self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore self._tree_store.trim_to_size(self._ledger.total_num_rows) - self._items_ready_to_write: list = [] @property def ledger(self): @@ -627,7 +647,6 @@ def is_finished(self): return self._ledger.is_finished def finish_shard(self, shard_name: str, num_rows: int): - self.flush() current_rows = self._ledger.shard_rows.get(shard_name, 0) if current_rows != num_rows: raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") @@ -635,6 +654,21 @@ def finish_shard(self, shard_name: str, num_rows: int): self._ledger.finished_shards.append(shard_name) self._ledger._serialize_and_commit(self.cache_dir) + def write_prepared_batch(self, shard_counts: Mapping[str, int], batch: PyTree[PreparedBatch]): + if self.is_finished: + raise RuntimeError("Cannot write to a finished cache") + self._tree_store.extend_with_batch(batch) + + for shard, num_rows in shard_counts.items(): + self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows + + total_rows = self._ledger.total_num_rows + sum(shard_counts.values()) + self._ledger.total_num_rows = total_rows + self._ledger._serialize_and_commit(self.cache_dir) + + if self._on_write: + self._on_write(self._ledger) + def write_batch(self, shard_name: str, batch: BatchResult): if self.is_finished: raise RuntimeError("Cannot write to a finished cache") @@ -643,15 +677,11 @@ def write_batch(self, shard_name: str, batch: BatchResult): raise NotImplementedError("Only non-RecordBatch batches are supported for now") batch = _canonicalize_batch(batch) # type: ignore + prepared = self._tree_store.batch_preparer(batch) - self._items_ready_to_write.append((shard_name, batch)) - - def flush(self): - self._attempt_to_write_batches() + return self.write_prepared_batch({shard_name: len(batch)}, prepared) def finish(self): - self.flush() - # if successful, write the ledger logger.info("Finished writing cache") self._ledger.is_finished = True @@ -661,59 +691,28 @@ def finish(self): return self._tree_store - def _attempt_to_write_batches(self): - if self._ledger.is_finished: - return - - if not self._items_ready_to_write: - return - - updated_shards = self._write_available_batches() - - logger.debug(f"Updated shards: {updated_shards}") - - did_write = len(updated_shards) > 0 - - if did_write: - - for shard, num_rows in updated_shards.items(): - self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows - - total_rows = self._ledger.total_num_rows + sum(updated_shards.values()) - self._ledger.total_num_rows = total_rows - self._ledger._serialize_and_commit(self.cache_dir) - - if self._on_write: - self._on_write(self._ledger) - - def _write_available_batches(self): - ready = self._items_ready_to_write - self._items_ready_to_write = [] - - if len(ready) == 0: - return {} - - to_write = [] - written_by_shard = {} - for shard, batch in ready: - to_write.extend(batch) - written_by_shard[shard] = written_by_shard.get(shard, 0) + len(batch) - - self._tree_store.extend(to_write) - return written_by_shard - def _serialize_json_and_commit(path, obj): # just to be paranoid, we write to a temp file and then rename it # TODO: probably we could do better here - with fsspec.open(f"{path}.tmp", "w") as file: - file.write(obj.to_json()) - # now copy the old file to a backup fs: AbstractFileSystem = fsspec.core.url_to_fs(path)[0] fs.mkdirs(os.path.dirname(path), exist_ok=True) if fs.exists(path): + # copy the old file to a backup fs.copy(path, f"{path}.bak") - fs.rename(f"{path}.tmp", path) + + for i in range(10): + with fsspec.open(f"{path}.tmp", "w") as file: + file.write(obj.to_json()) + + try: + fs.rename(f"{path}.tmp", path) + break + except FileNotFoundError: + # this happens for some reason sometimes. It makes no sense. + # FileNotFoundError: b/levanter-data/o/scratch%2Fdlwh%2Fpile-YYY%2Fpubmed_abs%2Ftrain%2Fshard_ledger.json.tmp/rewriteTo/b/levanter-data/o/scratch%2Fdlwh%2Fpile-YYY%2Fpubmed_abs%2Ftrain%2Fshard_ledger.json + logger.exception(f"Failed to rename {path}.tmp to {path}") + pass @ray.remote(num_cpus=0.1) # keep this small b/c it doesn't do a lot @@ -728,6 +727,7 @@ def __init__( self, cache_dir: str, name: str, + split: str, # to workaround https://github.com/ray-project/ray/issues/44083 source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: CacheOptions, @@ -754,9 +754,9 @@ def __init__( if self._ledger.is_finished: self.logger.info("Cache already finished. Nothing to do.") return - self._cache_writer = _core_writer_task.remote( - current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush - ) + self._cache_writer = _core_writer_task.options( + name=f"writer::{path_for_name}", scheduling_strategy="SPREAD" + ).remote(current_actor_handle(), cache_dir, split, self._ledger, source, processor, force_flush) except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -849,13 +849,14 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) -def _get_builder_actor(cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False): - name = f"lev_cache_manager::{cache_dir}" +def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False): + name = f"lev_cache_manager::{split}::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote( # type: ignore name=name_for_display, + split=split, cache_dir=cache_dir, source=shard_source, processor=processor, @@ -906,7 +907,6 @@ class _ShardFinished: """ _TIME_BETWEEN_WRITES = 20.0 # seconds -_MAX_WRITE_BATCHES = 1000 _MIN_WRITE_BATCHES = 100 @@ -914,6 +914,7 @@ class _ShardFinished: def _core_writer_task( parent, cache_dir, + split, initial_ledger: CacheLedger, source: ShardedDataSource, processor, @@ -928,24 +929,30 @@ def _core_writer_task( * processing of the batches * writing of the batches to the cache """ - logger.setLevel(DEFAULT_LOG_LEVEL) + pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) logger.info("Starting writer task") name = str(os.path.join(*cache_dir.split("/")[-2:])) # append a small random number to the name to avoid collisions name += f"::{random.randint(0, 1000)}" - def on_write(ledger): - ray.get(parent._notify_updated_ledger.remote(ledger)) - with log_failures_to(parent): - sharded_cache_writer = ray.remote(ShardedCacheWriter).remote( + + def on_write(ledger): + ray.get(parent._notify_updated_ledger.remote(ledger)) + + sharded_cache_writer = ShardedCacheWriter( cache_dir, initial_ledger, processor.output_exemplar, on_write=on_write ) + options = initial_ledger.metadata.options + num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names)) + + processor_pool = _mk_processor_pool(split, processor, 0, num_groups * 4) + interleave: RayPrefetchQueue = RayPrefetchQueue( - lambda: _make_interleave(name, source, initial_ledger, processor), - 4096, + lambda: _make_interleave(name, source, initial_ledger, processor_pool), + 512, producer_options={"num_cpus": 1, "name": f"{name}::interleave"}, ) @@ -955,34 +962,35 @@ def on_write(ledger): flush_time = Stopwatch() flush_amortized_time = Stopwatch() - i = 0 - batches_since_last_write = 0 + batches: list = [] time_of_last_write = time.time() - last_flush_future: Optional[ray.ObjectRef] = None - # start_of_last_flush = time_of_last_write + batches_total = 0.0 + flush_thread = None + finished_shards_last_flush: list = [] - # for i, batch_box in enumerate(interleave): while True: - with total_time: # 0.014 + with total_time: # 0.0051 try: cur_time = time.time() time_since_last_write = cur_time - time_of_last_write remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write - if batches_since_last_write > 0: - with flush_amortized_time: - if remaining_time <= 0 or batches_since_last_write >= _MAX_WRITE_BATCHES or force_flush: - with flush_time: - # TODO: don't block? - if last_flush_future: - ray.get(last_flush_future) - # print( - # f"Flushed {batches_since_last_write} batches in" - # f" {time.time() - start_of_last_flush} seconds" - # ) - last_flush_future = sharded_cache_writer.flush.remote() - # start_of_last_flush = time.time() - batches_since_last_write = 0 + if len(batches) > 0: + with flush_amortized_time: # 6e-4 + if remaining_time <= 0 or len(batches) >= options.num_batches_per_flush or force_flush: + with flush_time: # 0.613s + shard_rows, payloads = _fetch_batches(batches) + if flush_thread is not None: + flush_thread.join() + + batches = [] + flush_thread = ExceptionTrackingThread( + target=_write_batches, + args=(sharded_cache_writer, shard_rows, payloads, finished_shards_last_flush), + ) + flush_thread.start() + + finished_shards_last_flush = [] time_of_last_write = time.time() continue else: @@ -998,18 +1006,26 @@ def on_write(ledger): with append_time: match message: case _Batch(shard, _, payload): - # TODO: ensure indices are what we expect - sharded_cache_writer.write_batch.remote(shard, payload) - batches_since_last_write += 1 - i += 1 + batches_total += 1 + batches.append((shard, payload)) + + if force_flush: + shard_rows, payloads = _fetch_batches(batches) + del batches + _write_batches( + sharded_cache_writer, shard_rows, payloads, finished_shards_last_flush + ) + batches = [] + finished_shards_last_flush = [] + case _ShardFinished(shard, total_rows): - ray.get(sharded_cache_writer.finish_shard.remote(shard, total_rows)) + finished_shards_last_flush.append((shard, total_rows)) case _: raise AssertionError(f"Unexpected message type {type(message)}") - # if i % 1000 == 0: + # if batches_total % 1000 == 0: # print( - # f"Processed {i} batches: {loading_time.average()}s load," + # f"Processed {batches_total} batches: {loading_time.average()}s load," # f" {append_time.average()}s append, {flush_time.average()}s flush blocked, " # f"{flush_amortized_time.average()}s amortized flush, " # f"{total_time.average()}s total" @@ -1021,12 +1037,43 @@ def on_write(ledger): logger.exception("Error while processing batch") raise e - sharded_cache_writer.finish.remote() + # force a flush + if len(batches) > 0: + shard_row_totals, payloads_for_batches = _fetch_batches(batches) + del batches + if flush_thread is not None: + flush_thread.join() + _write_batches(sharded_cache_writer, shard_row_totals, payloads_for_batches, finished_shards_last_flush) + + sharded_cache_writer.finish() - out = sharded_cache_writer.get_ledger.remote() + out = sharded_cache_writer.get_ledger() return out +def _write_batches(writer: ShardedCacheWriter, shard_totals, batches, finished_shards): + # concatenate the payloads + final_payload = jax.tree.map(lambda *bs: PreparedBatch.concat(bs), *batches) + writer.write_prepared_batch(shard_totals, final_payload) + + for shard, total_rows in finished_shards: + writer.finish_shard(shard, total_rows) + + +def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]: + time_in = time.time() + shards_for_batches, payloads_for_batches = zip(*batches) + payloads_for_batches = ray.get(list(payloads_for_batches)) + time_out = time.time() + logger.info(f"Fetched {len(batches)} batches in {time_out - time_in} seconds") + + shard_row_totals: dict[str, int] = {} + for shard, payload in zip(shards_for_batches, payloads_for_batches): + shard_row_totals[shard] = shard_row_totals.get(shard, 0) + jax.tree.leaves(payload)[0].num_rows + + return shard_row_totals, payloads_for_batches + + def _interleave_shards(readers: Sequence[RayPrefetchQueue], first_index: int) -> Iterator[T]: # _Message """ Interleaves the results of multiple iterators. To support resume, @@ -1110,7 +1157,7 @@ def _impute_total_rows_committed_and_check_invariants(self): return total_committed, all_finished -def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor: BatchProcessor): +def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor_pool: ActorHandle): """ Given a list of ShardStatus objects and sources, creates an interleaving generator that reads from shards and tokenizes them in parallel. @@ -1134,9 +1181,6 @@ def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: Cache logger.warning(f"Starting cache build with {len(statuses)} shards, in {len(groups)} groups") - process_task = _mk_process_task(processor) - processor_ref = ray.put(processor) - def _make_generator_fn(group: _ShardGroup): def generator(): pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) @@ -1144,7 +1188,8 @@ def generator(): match message: case _Batch(): # processed = ray.put(process_task(ray.get(message.payload))) - processed = process_task.remote(processor_ref, message.payload) + # processed = process_task.remote(processor_ref, message.payload) + processed = processor_pool.process_batch.remote(RefBox(message.payload)) yield dataclasses.replace(message, payload=processed) case _ShardFinished(): yield message @@ -1156,7 +1201,9 @@ def generator(): generator_fns = [_make_generator_fn(group) for group in groups] readers = [ - RayPrefetchQueue(fn, 128, producer_options=dict(name=name, scheduling_strategy="SPREAD")) + RayPrefetchQueue( + fn, options.prefetch_per_group, producer_options=dict(num_cpus=0, name=name, scheduling_strategy="SPREAD") + ) for name, fn in zip(group_names, generator_fns) ] @@ -1169,6 +1216,20 @@ def generator(): yield from _interleave_shards(readers, first_group_to_start) +def _mk_processor_pool(split, processor, min_size, max_size): + import hashlib + + metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest() + processor_pool_name = f"processor_pool::{metadata_hash}" + processor_pool = BatchProcessorPool.options( # type: ignore + name=processor_pool_name, get_if_exists=True, lifetime="detached" + ).remote( # type: ignore + processor, min_size, max_size + ) + + return processor_pool + + def _check_current_shard_progress(statuses): unfinished_shards: list[_ShardStatus] = [] shards_with_progress: dict[str, int] = {} @@ -1237,32 +1298,6 @@ def _shard_reader_generator( yield _ShardFinished(status.shard_name, row_idx) -def _mk_process_task(processor: BatchProcessor[T, U]) -> RemoteFunction: - """ - Returns a Ray remote function that processes a batch of data. Basically it takes the resources from - the processor and wraps its call - """ - # processor_ref = ray.put(processor) - # exemplar = { - # "input_ids": np.random.randint(0, 100, size=(4096,)) - # } - - @ray.remote(num_cpus=processor.num_cpus, num_gpus=processor.num_gpus, resources=processor.resources) - def process_task(processor, batch_payload): - try: - result = processor(batch_payload) # TIME: 0.03 seconds - result = _canonicalize_batch(result) # type: ignore - logger.debug("Finished processing batch") - return result - except Exception as e: - logger.exception("Error while processing batch") - raise e - finally: - pass - - return process_task - - def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: if isinstance(batch, pa.RecordBatch): batch = dict_from_record_batch(batch) diff --git a/src/levanter/store/jagged_array.py b/src/levanter/store/jagged_array.py index b236641c9..d6674774f 100644 --- a/src/levanter/store/jagged_array.py +++ b/src/levanter/store/jagged_array.py @@ -4,7 +4,6 @@ from typing import Optional, Sequence import fsspec.core -import jax import jax.experimental.array_serialization.serialization as ser import jax.numpy as jnp import numpy as np @@ -20,6 +19,60 @@ DEFAULT_WRITE_CHUNK_SIZE = DEFAULT_CHUNK_SIZE * 512 +@dataclass +class PreparedBatch: + """ + A batch of data that has been prepared for storage in a jagged array. + """ + + data: np.ndarray + offsets: np.ndarray + shapes: Optional[np.ndarray] + + def astype(self, dtype): + return PreparedBatch(self.data.astype(dtype), self.offsets, self.shapes) + + @property + def num_rows(self): + return len(self.offsets) + + @staticmethod + def from_batch(items: Sequence[np.ndarray], item_rank: Optional[int] = None) -> "PreparedBatch": + data, offsets, shapes = _prepare_batch(items, item_rank) + return PreparedBatch(data, offsets, shapes) + + @staticmethod + def concat(batches: Sequence["PreparedBatch"]) -> "PreparedBatch": + data = np.concatenate([batch.data for batch in batches]) + shapes = np.concatenate([batch.shapes for batch in batches]) if batches[0].shapes is not None else None + # offsets have to be adjusted by adding the previous offset + totals = np.cumsum([0] + [batch.data.size for batch in batches]) + offsets = np.concatenate([batch.offsets + total for batch, total in zip(batches, totals)]) + + return PreparedBatch(data, offsets, shapes) + + +def _prepare_batch(arrays, item_rank): + if item_rank is None: + item_rank = arrays[0].ndim + + if item_rank != 1: + shapes = np.array([data.shape[:-1] for data in arrays], dtype=np.int64) + else: + + shapes = None + + # check shapes + for data in arrays: + if data.ndim != item_rank: + raise ValueError(f"Expected data to have rank {item_rank}, but got {data.ndim}") + + offsets = np.array([data.size for data in arrays], dtype=np.int64) + offsets = np.cumsum(offsets) + data = np.concatenate([data.reshape(-1) for data in arrays]) + return data, offsets, shapes + + @dataclass class JaggedArrayStore: """ @@ -113,10 +166,10 @@ async def data_size_async(self): self._cached_data_size = result return result - async def append_async(self, data: jax.Array): + async def append_async(self, data: np.ndarray): await self.extend_async([data]) - def append(self, data: jax.Array): + def append(self, data: np.ndarray): self.extend([data]) async def trim_to_size_async(self, size: int): @@ -190,13 +243,21 @@ def trim_to_size(self, size: int): self._cached_num_rows = size self._cached_data_size = new_max - async def extend_async(self, arrays: Sequence[jax.Array]): - data, new_offsets, shapes = self._prepare_batch(arrays) + async def extend_async(self, arrays: Sequence[np.ndarray] | PreparedBatch): + if isinstance(arrays, PreparedBatch): + prepared = arrays + else: + prepared = PreparedBatch.from_batch(arrays, self.item_rank) + data = prepared.data + new_offsets = prepared.offsets + shapes = prepared.shapes num_rows = await self.num_rows_async() - num_added = len(arrays) + num_added = len(new_offsets) current_data_size = self.data_size + new_offsets = new_offsets + current_data_size + # Write to resized arrays concurrently, adjusting offsets explicitly write_tasks = [ self.data[current_data_size : current_data_size + len(data)].write(data), @@ -207,19 +268,33 @@ async def extend_async(self, arrays: Sequence[jax.Array]): await asyncio.gather(*write_tasks) # Update num_rows - await self.offsets[0].write(num_rows + len(arrays)) + await self.offsets[0].write(num_rows + num_added) if self._cache_metadata: - self._cached_num_rows = num_rows + len(arrays) + self._cached_num_rows = num_rows + num_added self._cached_data_size = current_data_size + len(data) - def extend(self, arrays: Sequence[jax.Array]): - data, new_offsets, shapes = self._prepare_batch(arrays) + def extend(self, arrays: Sequence[np.ndarray] | PreparedBatch): + if isinstance(arrays, PreparedBatch): + prepared = arrays + else: + prepared = PreparedBatch.from_batch(arrays, self.item_rank) + + data = prepared.data + new_offsets = prepared.offsets + shapes = prepared.shapes + + if shapes is None and self.item_rank != 1: + raise ValueError("Shapes must be provided for non-vector data") + elif shapes is not None and shapes.shape[1] != self.item_rank - 1: + raise ValueError(f"Shapes must have {self.item_rank-1} dimensions, but got {shapes.shape[1]}") num_rows = self.num_rows - num_added = len(arrays) + num_added = len(new_offsets) current_data_size = self.data_size + new_offsets = new_offsets + current_data_size + write_tasks = [ self.data[current_data_size : current_data_size + len(data)].write(data), self.offsets[num_rows + 1 : num_rows + num_added + 1].write(new_offsets), @@ -231,28 +306,12 @@ def extend(self, arrays: Sequence[jax.Array]): for task in write_tasks: task.result() - self.offsets[0].write(num_rows + len(arrays)).result() + self.offsets[0].write(num_rows + num_added).result() if self._cache_metadata: - self._cached_num_rows = num_rows + len(arrays) + self._cached_num_rows = num_rows + num_added self._cached_data_size = current_data_size + len(data) - def _prepare_batch(self, arrays): - if self.shapes is not None: - for data in arrays: - if data.ndim != self.item_rank: - raise ValueError(f"Expected data to have rank {self.item_rank}, got {data.ndim}") - shapes = np.array([data.shape[:-1] for data in arrays], dtype=np.int64) - else: - for data in arrays: - if data.ndim > 1: - raise ValueError(f"Expected data to have rank 1, got {data.ndim}") - shapes = None - new_offsets = np.array([data.size for data in arrays], dtype=np.int64) - new_offsets = np.cumsum(new_offsets) + self.data_size - data = np.concatenate([data.reshape(-1) for data in arrays]) - return data, new_offsets, shapes - async def reload_async(self) -> "JaggedArrayStore": """ Calls `resolve` on the underlying tensorstore objects, updating size information @@ -309,7 +368,7 @@ async def get_item_async(self, item): else: raise e - async def get_batch(self, indices: Sequence[int]) -> Sequence[jax.Array]: + async def get_batch(self, indices: Sequence[int]) -> Sequence[np.ndarray]: # get indices with ts.Batch(): all_indices_futs = [self._bounds_for_rows_async(indices[i], indices[i] + 1) for i in range(len(indices))] @@ -334,7 +393,7 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[jax.Array]: return data - def get_batch_sync(self, indices: Sequence[int]) -> Sequence[jax.Array]: + def get_batch_sync(self, indices: Sequence[int]) -> Sequence[np.ndarray]: all_indices = self._bounds_for_rows_batch(indices) with ts.Batch(): diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py index cd29e5a4c..03355a8d2 100644 --- a/src/levanter/store/tree_store.py +++ b/src/levanter/store/tree_store.py @@ -10,7 +10,7 @@ from haliax.jax_utils import is_jax_array_like -from .jagged_array import JaggedArrayStore +from .jagged_array import JaggedArrayStore, PreparedBatch T = TypeVar("T", bound=PyTree) @@ -49,6 +49,10 @@ def __init__(self, tree, path: str, mode: str): self.mode = mode self.tree = tree + @property + def batch_preparer(self): + return TreeBatchPreparer(jtu.tree_map(lambda writer: 9, self.tree, is_leaf=heuristic_is_leaf)) + @staticmethod def open(exemplar: T, path: str, *, mode="a", cache_metadata: bool = False) -> "TreeStore": """ @@ -64,7 +68,6 @@ def extend(self, batch: Sequence[T]): """ Append a batch of data to the store. """ - # TODO: I do wish zarr supported async jtu.tree_map( lambda writer, *xs: writer.extend([np.asarray(x) for x in xs]), self.tree, @@ -80,7 +83,7 @@ def extend_with_batch(self, batch: T): For instance, HF's BatchEncoding is a dict of lists of numpy arrays. """ jtu.tree_map( - lambda writer, xs: writer.extend([np.asarray(x) for x in xs]), + lambda writer, xs: writer.extend(xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs]), self.tree, batch, is_leaf=heuristic_is_leaf_batched, @@ -94,7 +97,9 @@ async def extend_with_batch_async(self, batch: T): For instance, HF's BatchEncoding is a dict of lists of numpy arrays. """ futures = jtu.tree_map( - lambda writer, xs: writer.extend_async([np.asarray(x) for x in xs]), + lambda writer, xs: writer.extend_async( + xs if isinstance(xs, PreparedBatch) else [np.asarray(x) for x in xs] + ), self.tree, batch, is_leaf=heuristic_is_leaf_batched, @@ -198,46 +203,14 @@ def _render_path_elem(x): return str(x) -# class TokenSeqDataset: -# """ -# A dataset of sequences of tokens of fixed length, materialized from a collection of JaggedArrayStores, -# which have typically much longer sequences. This class takes consecutive sequences of tokens from the builders -# and slices/concats them to form the dataset. -# """ -# -# def __init__( -# self, token_arrays: Sequence[JaggedArrayStore], token_counts: Sequence[int], seq_len: int, pad_token: int -# ): -# self.token_arrays = token_arrays -# -# def _round_to_nearest_multiple(x, y): -# return x + y - x % y -# -# token_counts_padded = np.array([_round_to_nearest_multiple(x, seq_len) for x in token_counts]) -# seq_counts = token_counts_padded // seq_len -# self.seq_counts_cumsum = np.concatenate([np.asarray([0]), np.cumsum(seq_counts)]) -# -# self.seq_len = seq_len -# self.pad_token = pad_token -# -# def __len__(self): -# return self.seq_counts_cumsum[-1] -# -# def __getitem__(self, seq_id): -# return asyncio.run(self.get_item_async(seq_id)) -# -# async def get_item_async(self, seq_id): -# # TODO: accept slices and such? -# shard_id = np.searchsorted(self.seq_counts_cumsum, seq_id, side="right") - 1 -# shard_start = self.seq_counts_cumsum[shard_id] -# shard_end = self.seq_counts_cumsum[shard_id + 1] -# shard_seq_id = seq_id - shard_start -# -# shard_seq_start = shard_seq_id * self.seq_len -# shard_seq_end = min((shard_seq_id + 1) * self.seq_len, self.token_arrays[shard_id].data_size) -# -# shard_seq = await self.token_arrays[shard_id].data[shard_seq_start:shard_seq_end].read() -# pad_len = self.seq_len - (shard_seq_end - shard_seq_start) -# padded_seq = np.concatenate([shard_seq, np.full(pad_len, self.pad_token, dtype=shard_seq.dtype)]) -# -# return padded_seq +class TreeBatchPreparer(Generic[T]): + def __init__(self, exemplar: T): + self.exemplar = exemplar + + def __call__(self, batch: List[T]) -> PyTree: + return jtu.tree_map( + lambda _, *xs: PreparedBatch.from_batch([np.asarray(x) for x in xs]), + self.exemplar, + *batch, + is_leaf=heuristic_is_leaf, + ) diff --git a/src/levanter/utils/actor_pool.py b/src/levanter/utils/actor_pool.py new file mode 100644 index 000000000..51ba2ccec --- /dev/null +++ b/src/levanter/utils/actor_pool.py @@ -0,0 +1,224 @@ +import asyncio +import logging +from abc import ABC +from typing import Any, Callable, Dict, List, Optional, TypeVar + +import ray + + +V = TypeVar("V") +R = TypeVar("R") + +logger = logging.getLogger(__name__) + +# Copilot-Adapted from: +# https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py + + +class AutoScalingActorPool: + """Utility class to operate on a dynamically scaling pool of actors.""" + + def __init__( + self, + create_actor_fn: Callable[[], "ray.actor.ActorHandle"], + min_size: int = 1, + max_size: int = 10, + ): + if max_size < min_size: + raise ValueError("max_size must be greater than or equal to min_size.") + self._create_actor_fn = create_actor_fn + self._min_size = min_size + self._max_size = max_size + + self._idle_actors: List[ray.actor.ActorHandle] = [] + self._busy_actors: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {} + self._pending_actors: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {} + + self._actor_locations: Dict[ray.actor.ActorHandle, str] = {} + self._tasks_waiting_for_actor: list[asyncio.Future] = [] + self._next_task_id = 0 + + self._scale_up(self._min_size) + + @property + def num_pending_tasks(self): + return len(self._tasks_waiting_for_actor) + + def _scale_up(self, num_actors: int): + for _ in range(num_actors): + try: + actor = self._create_actor_fn() + ready_ref = actor.get_location.remote() + self._pending_actors[ready_ref] = actor + + async def wait_for_ready(actor, ready_ref): + loc = await ready_ref + # pending -> floating + if ready_ref not in self._pending_actors: + logger.info("Actor was cancelled before it was ready.") + return + del self._pending_actors[ready_ref] + self._assert_is_floating(actor) + self._actor_locations[actor] = loc + self._maybe_start_pending_task(actor) # floating -> {idle, busy} + + asyncio.ensure_future(wait_for_ready(actor, ready_ref)) + + except Exception as e: + logger.error("Failed to create actor.", exc_info=e) + + def _scale_down(self, num_actors: int): + for _ in range(num_actors): + if self._pending_actors: + actor = self._pending_actors.popitem()[1] + # let it die through gc + # ray.kill(actor) + elif self._idle_actors: + actor = self._idle_actors.pop() + del self._actor_locations[actor] + # let it die through gc + # ray.kill(actor) + else: + break + + def _adjust_pool_size(self): + num_pending_tasks = self.num_pending_tasks + num_idle_actors = len(self._idle_actors) + num_busy_actors = len(self._busy_actors) + num_pending_actors = len(self._pending_actors) + + num_nonworking_actors = num_idle_actors + num_pending_actors + total_actors = num_nonworking_actors + num_busy_actors + + # TODO: better autoscale logic + if ( + num_pending_actors == 0 + and num_pending_tasks > 0 + and num_idle_actors == 0 + and total_actors < self._max_size + ): + logger.info( + f"Scaling up due to {num_pending_tasks} pending tasks. Current pool size: {total_actors}. Max size:" + f" {self._max_size}" + ) + self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks)) + elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size: + return # never scal edown. too many issues + logger.info(f"Scaling down due to no pending tasks. Current pool size: {total_actors}") + self._scale_down(num_nonworking_actors - self._min_size) + + def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]: + """Get the location of the given object reference.""" + try: + locs = ray.experimental.get_object_locations([obj_ref]) + nodes = locs[obj_ref]["node_ids"] + if nodes: + return nodes[0] + except Exception as e: + logger.error(f"Failed to get object location: {e}") + return None + + def _pick_actor(self, obj_ref: Optional[ray.ObjectRef] = None) -> Optional[ray.actor.ActorHandle]: + """Pick an actor based on locality and busyness.""" + # idle -> floating + if not self._idle_actors: + return None + + if obj_ref: + preferred_loc = self._get_object_location(obj_ref) + else: + preferred_loc = None + + def penalty_key(actor): + """Returns the key that should be minimized for the best actor.""" + requires_remote_fetch = self._actor_locations[actor] != preferred_loc + return requires_remote_fetch + + actor = min(self._idle_actors, key=penalty_key) + actor = self._idle_actors.pop(self._idle_actors.index(actor)) + return actor + + def submit(self, fn: Callable[["ray.actor.ActorHandle", V], R], value: V, obj_ref: Optional[ray.ObjectRef] = None): + actor = self._pick_actor(obj_ref) + if actor: + return self._assign_task_to_actor(actor, fn, value) + else: + actor_future: asyncio.Future = asyncio.Future() + self._tasks_waiting_for_actor.append(actor_future) + f = asyncio.ensure_future(self._enqueue_pending_task(fn, obj_ref, value, actor_future)) + self._adjust_pool_size() + return f + + def _assign_task_to_actor(self, actor, fn, value): + # floating -> busy + ray_future = fn(actor, value) + self._busy_actors[ray_future] = actor + self._adjust_pool_size() + + # return ray_future + return asyncio.ensure_future(self._wrap_ray_future(ray_future)) + + async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future): + actor = await actor_future + return await self._assign_task_to_actor(actor, fn, value) + + def _assert_is_floating(self, actor): + assert actor not in self._idle_actors + assert actor not in self._busy_actors + assert actor not in self._pending_actors + + def _maybe_start_pending_task(self, actor): + self._assert_is_floating(actor) + if self._tasks_waiting_for_actor: + # floating -> busy (inside the _enqueue_pending_task coroutine) + actor_future = self._tasks_waiting_for_actor.pop(0) + actor_future.set_result(actor) + assigned = True + else: + # floating -> idle + self._idle_actors.append(actor) + self._adjust_pool_size() + assigned = False + return assigned + + async def _wrap_ray_future(self, ray_future): + await asyncio.wait([ray_future]) + self._on_task_done(ray_future) + return await ray_future + + def _on_task_done(self, ray_future): + actor = self._busy_actors.pop(ray_future) + self._maybe_start_pending_task(actor) + + async def map( + self, + fn: Callable[["ray.actor.ActorHandle", V], Any], + values: List[V], + obj_refs: Optional[List[Optional[ray.ObjectRef]]] = None, + ) -> List[Any]: + if obj_refs is None: + obj_refs = [None] * len(values) + + tasks = [self.submit(fn, v, obj_ref) for v, obj_ref in zip(values, obj_refs)] + return await asyncio.gather(*tasks) + + def has_free(self): + return bool(self._idle_actors) + + def has_free_or_pending_actors(self): + return bool(self._idle_actors) or bool(self._pending_actors) + + def pop_idle(self): + if self._idle_actors: + return self._idle_actors.pop() + return None + + def push(self, actor: "ray.actor.ActorHandle"): + location = ray.get(actor.get_location.remote()) + self._actor_locations[actor] = location + self._maybe_start_pending_task(actor) + + +class PoolWorkerBase(ABC): + def get_location(self) -> str: + return ray.get_runtime_context().get_node_id() diff --git a/src/levanter/utils/thread_utils.py b/src/levanter/utils/thread_utils.py index fad60ad31..401ac94c5 100644 --- a/src/levanter/utils/thread_utils.py +++ b/src/levanter/utils/thread_utils.py @@ -72,3 +72,27 @@ def close(self): self.loop.call_soon_threadsafe(self.loop.stop) self.thread.join() self.loop.close() + + +class ExceptionTrackingThread(threading.Thread): + """A thread that will store exceptions that occur in the target function and + re-raise them in the main thread.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._exception = None + + def run(self): + try: + super().run() + except Exception as e: + self._exception = e + + def join(self, *args, **kwargs): + super().join(*args, **kwargs) + if self._exception: + raise self._exception + + def check_raise(self): + if self._exception: + raise self._exception diff --git a/tests/test_actor_pool.py b/tests/test_actor_pool.py new file mode 100644 index 000000000..08686eb30 --- /dev/null +++ b/tests/test_actor_pool.py @@ -0,0 +1,167 @@ +import asyncio +import time + +import pytest +import ray + +from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase +from levanter.utils.py_utils import logical_cpu_core_count + + +@ray.remote +class TestActor(PoolWorkerBase): + def __init__(self): + self.node_id = ray.get_runtime_context().get_node_id() + + def get_node_id(self): + return self.node_id + + def double(self, v): + return 2 * v + + +@ray.remote +class BlockerActor(PoolWorkerBase): + def __init__(self): + self.node_id = ray.get_runtime_context().get_node_id() + self.unblocked = False + self.unblock_event = asyncio.Event() + + def get_node_id(self): + return self.node_id + + async def block(self): + if not self.unblocked: + await self.unblock_event.wait() + + async def unblock(self): + self.unblocked = True + self.unblock_event.set() + + +@ray.remote +class BlockingTestActor(PoolWorkerBase): + def __init__(self, blocker): + self.node_id = ray.get_runtime_context().get_node_id() + self.blocker = blocker + + def get_node_id(self): + return self.node_id + + def double(self, v, bypass_blocker=False): + if not bypass_blocker: + ray.get(self.blocker.block.remote()) + return 2 * v + + +# Helper function to create a TestActor +def create_test_actor(): + return TestActor.remote() + + +def create_test_actor_blocker(blocker_handle): + return BlockingTestActor.remote(blocker_handle) + + +def setup_module(module): + ray.init( + "local", num_cpus=max(2 * logical_cpu_core_count(), 8), ignore_reinit_error=True + ) # 2x cpu count is faster on my m1 + + +def teardown_module(module): + ray.shutdown() + + +@pytest.mark.asyncio +async def test_basic_submit(): + pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4) + results = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)] + results = [await r for r in results] + + assert results == [0, 2, 4, 6] + + +@pytest.mark.asyncio +async def test_basic_submit_no_idle(): + pool = AutoScalingActorPool(create_test_actor, min_size=0, max_size=4) + results = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)] + results = [await r for r in results] + + assert results == [0, 2, 4, 6] + + +@pytest.mark.asyncio +async def test_basic_functionality(): + pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4) + results = list(await pool.map(lambda a, v: a.double.remote(v), [1, 2, 3, 4])) + assert results == [2, 4, 6, 8] + + +@pytest.mark.asyncio +async def test_scaling_up(): + blocker = BlockerActor.remote() + pool = AutoScalingActorPool(lambda: create_test_actor_blocker(blocker), min_size=1, max_size=4) + f1 = pool.submit(lambda a, v: a.double.remote(v), 1) + f2 = pool.submit(lambda a, v: a.double.remote(v), 2) + f3 = pool.submit(lambda a, v: a.double.remote(v, True), 3) + f4 = pool.submit(lambda a, v: a.double.remote(v, True), 4) + + shield_f2 = asyncio.shield(f2) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(shield_f2, timeout=0.1) + + assert (await asyncio.gather(f3, f4)) == [6, 8] + + await blocker.unblock.remote() + # assert (await asyncio.gather(f1, f2)) == [2, 4] + assert (await f1) == 2 + assert (await f2) == 4 + + +@pytest.mark.asyncio +async def test_scaling_down(): + pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4) + await pool.submit(lambda a, v: a.double.remote(v), 1) + await pool.submit(lambda a, v: a.double.remote(v), 2) + await pool.submit(lambda a, v: a.double.remote(v), 3) + await pool.submit(lambda a, v: a.double.remote(v), 4) + results = await asyncio.gather( + pool.submit(lambda a, v: a.double.remote(v), 1), + pool.submit(lambda a, v: a.double.remote(v), 2), + pool.submit(lambda a, v: a.double.remote(v), 3), + pool.submit(lambda a, v: a.double.remote(v), 4), + ) + assert results == [2, 4, 6, 8] + assert len(pool._idle_actors) == 1 + assert len(pool._busy_actors) == 0 + + +@pytest.mark.asyncio +async def test_push_pop_idle(): + pool = AutoScalingActorPool(create_test_actor, min_size=1, max_size=4) + await pool.submit(lambda a, v: a.double.remote(v), 1) + actor = pool.pop_idle() + assert actor is not None + pool.push(actor) + assert len(pool._idle_actors) == 1 + + +@pytest.mark.asyncio +async def test_submit_with_no_idle_actors(): + blocker = BlockerActor.remote() + pool = AutoScalingActorPool(lambda: create_test_actor_blocker(blocker), min_size=1, max_size=4) + futs = [pool.submit(lambda a, v: a.double.remote(v), i) for i in range(4)] + f5 = pool.submit(lambda a, v: a.double.remote(v), 5) + await _sleep_until(lambda: pool.num_pending_tasks == 1, timeout=10) + await blocker.unblock.remote() + await asyncio.gather(*futs) + assert (await f5) == 10 + + +async def _sleep_until(condition, timeout=5, message="Condition not met within timeout"): + start = time.time() + while not condition(): + if time.time() - start > timeout: + pytest.fail(message) + await asyncio.sleep(0.1) diff --git a/tests/test_jagged_array.py b/tests/test_jagged_array.py index c89a2c625..4a450bae7 100644 --- a/tests/test_jagged_array.py +++ b/tests/test_jagged_array.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from levanter.store.jagged_array import JaggedArrayStore +from levanter.store.jagged_array import JaggedArrayStore, PreparedBatch class TestJaggedArrayStore: @@ -50,6 +50,75 @@ def test_extend_with_multiple(self, cache_metadata): result2 = builder[1] assert jnp.all(result2 == data2) + @pytest.mark.parametrize("cache_metadata", [True, False]) + def test_extend_with_prepared_batch(self, cache_metadata): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32, cache_metadata=cache_metadata) + + data1 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) + data2 = np.array([[5.0]], dtype=jnp.float32) + prepared = PreparedBatch.from_batch([data1, data2]) + + builder.extend(prepared) + + assert len(builder) == 2 + + result1 = builder[0] + assert jnp.all(result1 == data1) + + result2 = builder[1] + assert jnp.all(result2 == data2) + + # extendd with more data + data3 = jnp.array([[6.0, 7.0], [8.0, 9.0]]) + data4 = jnp.array([[10.0]]) + prepared2 = PreparedBatch.from_batch([data3, data4]) + + builder.extend(prepared2) + + assert len(builder) == 4 + + result3 = builder[2] + assert jnp.all(result3 == data3) + + result4 = builder[3] + assert jnp.all(result4 == data4) + + @pytest.mark.asyncio + @pytest.mark.parametrize("cache_metadata", [True, False]) + async def test_extend_with_prepared_batch_async(self, cache_metadata): + with tempfile.TemporaryDirectory() as tmpdir: + builder = JaggedArrayStore.open(tmpdir, item_rank=2, dtype=jnp.float32, cache_metadata=cache_metadata) + + data1 = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32) + data2 = np.array([[5.0]], dtype=jnp.float32) + prepared = PreparedBatch.from_batch([data1, data2]) + + await builder.extend_async(prepared) + + assert len(builder) == 2 + + result1 = builder[0] + assert jnp.all(result1 == data1) + + result2 = builder[1] + assert jnp.all(result2 == data2) + + # extendd with more data + data3 = jnp.array([[6.0, 7.0], [8.0, 9.0]]) + data4 = jnp.array([[10.0]]) + prepared2 = PreparedBatch.from_batch([data3, data4]) + + await builder.extend_async(prepared2) + + assert len(builder) == 4 + + result3 = builder[2] + assert jnp.all(result3 == data3) + + result4 = builder[3] + assert jnp.all(result4 == data4) + def test_append_error(self): with tempfile.TemporaryDirectory() as tmpdir: builder = JaggedArrayStore.open(tmpdir, item_rank=1, dtype=jnp.float32) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index af6fa885f..275c6a236 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -1,6 +1,5 @@ import asyncio import copy -import logging import os import tempfile from typing import Any, Dict, Iterator, Sequence @@ -23,7 +22,6 @@ build_or_load_cache, ) from levanter.utils.py_utils import logical_cpu_core_count -from levanter.utils.ray_utils import ExceptionInfo, SnitchRecipient class TestProcessor(BatchProcessor[Sequence[int], dict[str, np.ndarray]]): @@ -135,62 +133,6 @@ def test_serial_cache_writer(): np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) -def crappy_du(path): - import os - - total = 0 - for root, dirs, files in os.walk(path): - for f in files: - total += os.path.getsize(os.path.join(root, f)) - return total - - -@ray.remote -class PretendParent(SnitchRecipient): - def __init__(self): - self.logger = logging.getLogger("SnitchRecipient") - self.failure_received = asyncio.Event() - self.exception_info = None - self._finished_shards = set() - self._finished = False - self._ledger = None - self._desired_next_item = None - - def _child_failed(self, child: ray.actor.ActorHandle, exception: ExceptionInfo): - try: - self.logger.error(f"Child {child} failed with exception {exception}") - self.exception_info = exception - self.failure_received.set() - except Exception as e: - self.logger.error(f"Error in _child_failed: {e}") - - def shard_failed(self, shard_name, exc_info): - self.exception_info = exc_info - self.failure_received.set() - - async def wait_for_failure(self): - await self.failure_received.wait() - return self.exception_info - - def shard_finished(self, shard_name): - self._finished_shards.add(shard_name) - - def get_finished_shards(self): - return self._finished_shards - - def _notify_updated_ledger(self, ledger): - if ledger.is_finished: - self._finished = True - - self._ledger = ledger - - def _finalize(self): - self._finished = True - - def is_finished(self): - return self._finished - - @pytest.mark.ray def test_full_end_to_end_cache(): td = tempfile.TemporaryDirectory() diff --git a/tests/test_tree_store.py b/tests/test_tree_store.py index 66131ca48..a0d089576 100644 --- a/tests/test_tree_store.py +++ b/tests/test_tree_store.py @@ -254,6 +254,34 @@ def test_reading_from_written(): pytest.fail("Unexpected index") +def test_using_prepared_batches(): + with tempfile.TemporaryDirectory() as tmpdir: + exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)} + builder = TreeStore.open(exemplar, tmpdir, mode="w") + preparer = builder.batch_preparer + + batch = [ + {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}, + {"a": np.array([5.0, 6.0]), "b": np.array([7.0, 8.0])}, + ] + batch = preparer(batch) + builder.extend_with_batch(batch) + + del builder + + builder2 = TreeStore.open(exemplar, tmpdir, mode="r") + + for i, result in enumerate(builder2): + if i == 0: + assert np.all(result["a"] == np.array([1.0, 2.0])) + assert np.all(result["b"] == np.array([3.0, 4.0])) + elif i == 1: + assert np.all(result["a"] == np.array([5.0, 6.0])) + assert np.all(result["b"] == np.array([7.0, 8.0])) + else: + pytest.fail("Unexpected index") + + def test_resolve_changed_cache_size(): with tempfile.TemporaryDirectory() as tmpdir: exemplar = {"a": np.array([0], dtype=np.float64), "b": np.array([0], dtype=np.float64)}