diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index 3a21732a7..291213d75 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -1,4 +1,4 @@ -data: !include data/pile_source_old.yaml +data: !include data/pile_mixture.yaml model: type: gpt2 hidden_dim: 768 diff --git a/pyproject.toml b/pyproject.toml index 19fb077bf..0831605cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets>=2.18,<4.0", + "datasets>=3.1.0,<4.0", "gcsfs>=2024.2,<2024.10", "braceexpand>=0.1.7", "jmp>=0.0.3", diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 186a0d9dd..90803df3e 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) + shard = map( + lambda t: t[1], + dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards), + ) else: shard = dataset diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 70c1fe4b3..de6980430 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -35,7 +35,7 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore -from levanter.utils.fsspec_utils import fsspec_expand_glob +from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -508,7 +508,7 @@ def urls_for_split(self, split): else: raise ValueError(f"Unknown split {split}") - urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)] + urls = [globbed for url in urls for globbed in expand_glob(url)] return urls @@ -625,7 +625,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): import levanter.data - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + validation_urls = [url for url_pat in config.validation_urls for url in expand_glob(url_pat)] dataset = levanter.data.datasource_from_jsonl(validation_urls) input_field = config.input_field diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 45265c994..d49f4553b 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -3,6 +3,7 @@ import copy import dataclasses import logging as pylogging +import operator import os import pprint import random @@ -12,26 +13,30 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core -import humanfriendly import jax +import numpy as np import pyarrow as pa import ray +import tensorstore as ts from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from 
jaxtyping import PyTree from ray.actor import ActorHandle +from ray.runtime_env import RuntimeEnv +from tqdm_loggable.auto import tqdm +from levanter.data import batched from levanter.data.dataset import AsyncDataset -from levanter.store._prefetch_actor import QueueEmpty, RayPrefetchQueue -from levanter.utils.py_utils import Stopwatch -from ..data._preprocessor import BatchProcessor, BatchProcessorPool, BatchResult, dict_from_record_batch +from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource +from ..utils.fsspec_utils import exists as fsspec_exists +from ..utils.fsspec_utils import remove as fsspec_remove from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -40,8 +45,7 @@ log_failures_to, ser_exc_info, ) -from ..utils.thread_utils import ExceptionTrackingThread -from .jagged_array import PreparedBatch +from .jagged_array import JaggedArrayStore, PreparedBatch from .tree_store import TreeStore @@ -69,23 +73,20 @@ class CacheOptions: """ num_shard_groups: Optional[int] = 128 - """Number of groups to divide the shards into. This is used to parallelize the cache building process without - overloading Ray. If None, all shards will be in their own group.""" - shard_order_randomization_key: Optional[int] = 0 - """A key used to randomize the order of the shards before building and grouping.""" - batch_size: int = 128 - """The batch size to use when processing the data. This is used to control the memory usage of the cache building - process. Lower values will use less memory but take somewhat longer to build the cache.""" # the below options don't actually impact the cache's result, but do impact construction target_size_per_flush: int | str = "512MB" """The number of bytes to buffer before flushing to disk. This is used to control the memory usage of the cache building process. Lower values will use less memory but could take somewhat longer to build the cache.""" - prefetch_per_group: int = 4 - """The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time""" + + batch_size: int = 128 @property def target_bytes_per_flush(self): + if isinstance(self.target_size_per_flush, int): + return self.target_size_per_flush + import humanfriendly + return humanfriendly.parse_size(self.target_size_per_flush) @staticmethod @@ -99,14 +100,14 @@ def no_fanciness(batch_size: Optional[int] = None): """ if batch_size is None: batch_size = 128 - return CacheOptions(num_shard_groups=None, shard_order_randomization_key=None, batch_size=batch_size) + return CacheOptions(num_shard_groups=None, batch_size=batch_size) @staticmethod def one_group(): """ For testing, disables all the fancy features of the cache. This makes it easier to predict the behavior """ - return CacheOptions(num_shard_groups=1, shard_order_randomization_key=None, batch_size=128) + return CacheOptions(num_shard_groups=1, batch_size=128) def build_or_load_cache( @@ -116,7 +117,6 @@ def build_or_load_cache( await_finished: bool = True, monitors: Optional[Sequence["MetricsMonitor"]] = None, options: CacheOptions = CacheOptions.default(), - force_flush: bool = False, split: str = "test", ) -> "TreeCache[U]": """ @@ -144,8 +144,6 @@ def build_or_load_cache( options: Configuration for the cache. 
This is used to configure a few parts of the cache creation process - force_flush: for testing, forces the cache to flush after every batch. This is useful for testing. - Returns: (TreeCache) A TreeCache object that can be used to read the cache. @@ -156,7 +154,6 @@ def build_or_load_cache( shard_source=input_shards, processor=processor, options=options, - force_flush=force_flush, split=split, ) @@ -320,12 +317,11 @@ def build_or_load( shard_source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: Optional["CacheOptions"] = None, - force_flush: bool = False, split: str = "test", ) -> "TreeCache[U]": if options is None: options = CacheOptions.default() - metadata = CacheMetadata(options=options, preprocessor_metadata=processor.metadata) + metadata = CacheMetadata(preprocessor_metadata=processor.metadata) try: return TreeCache.load(cache_dir, processor.output_exemplar, metadata) except FileNotFoundError: @@ -334,8 +330,6 @@ def build_or_load( shard_source=shard_source, processor=processor, options=options, - force_flush=force_flush, - split=split, ) return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) @@ -489,13 +483,11 @@ class CacheLedger: is_finished: bool = False finished_shards: List[str] = dataclasses.field(default_factory=list) field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) - metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata(CacheOptions(), {})) + metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata({})) @staticmethod - def load_or_initialize( - cache_dir: str, source: ShardedDataSource, processor: BatchProcessor, config: "CacheOptions" - ): - metadata = CacheMetadata(options=config, preprocessor_metadata=processor.metadata) + def load_or_initialize(cache_dir: str, source: ShardedDataSource, processor: BatchProcessor): + metadata = CacheMetadata(preprocessor_metadata=processor.metadata) try: return CacheLedger.load(cache_dir, metadata) except FileNotFoundError: @@ -531,7 +523,6 @@ def _serialize_and_commit(self, cache_dir): @dataclass_json @dataclass(frozen=True) class CacheMetadata: - options: CacheOptions = CacheOptions.default() preprocessor_metadata: Optional[dict[str, Any]] = None def compare_to(self, other: "CacheMetadata") -> deepdiff.DeepDiff: @@ -552,13 +543,6 @@ def empty(): return CacheMetadata() -@dataclass -class _ShardStatus: - shard_name: str - num_rows_committed: int - is_finished: bool - - class SerialCacheWriter(AbstractContextManager): """ Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray. @@ -616,91 +600,6 @@ def write_batch(self, batch: BatchResult): self._tree_store.extend(cbatch) -class ShardedCacheWriter: - """ - Similar to SerialCacheWriter, but tracks shard metadata. - - Similar to _OrderedCacheWriter, it also supports resuming, and it - groups together batches before writing (at some interval) in order to improve performance. 
- """ - - def __init__( - self, - cache_dir: str, - initial_ledger: CacheLedger, - exemplar: T, - on_write: Optional[Callable[[CacheLedger], None]] = None, - ): - self.cache_dir = cache_dir - self._on_write = on_write - - self._ledger = copy.deepcopy(initial_ledger) - - self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore - self._tree_store.trim_to_size(self._ledger.total_num_rows) - - @property - def ledger(self): - return self._ledger - - # we have both versions b/c we need this one for actors - def get_ledger(self): - return self._ledger - - @property - def is_finished(self): - return self._ledger.is_finished - - def finish_shard(self, shard_name: str, num_rows: int): - current_rows = self._ledger.shard_rows.get(shard_name, 0) - if current_rows != num_rows: - raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") - - self._ledger.finished_shards.append(shard_name) - self._ledger._serialize_and_commit(self.cache_dir) - - def write_prepared_batch(self, shard_counts: Mapping[str, int], batch: PyTree[PreparedBatch]): - if self.is_finished: - raise RuntimeError("Cannot write to a finished cache") - self._tree_store.extend_with_batch(batch) - - for shard, num_rows in shard_counts.items(): - self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows - - total_rows = self._ledger.total_num_rows + sum(shard_counts.values()) - self._ledger.total_num_rows = total_rows - self._ledger._serialize_and_commit(self.cache_dir) - - if self._on_write: - self._on_write(self._ledger) - - def write_batch(self, shard_name: str, batch: BatchResult): - if self.is_finished: - raise RuntimeError("Cannot write to a finished cache") - - if isinstance(batch, pa.RecordBatch): - raise NotImplementedError("Only non-RecordBatch batches are supported for now") - - batch = _canonicalize_batch(batch) # type: ignore - prepared = self._tree_store.batch_preparer(batch) - - return self.write_prepared_batch({shard_name: len(batch)}, prepared) - - def finish(self): - # if successful, write the ledger - logger.info("Finished writing cache") - # check that all shards are finished - if set(self._ledger.shard_rows.keys()) != set(self._ledger.finished_shards): - raise ValueError("Not all shards are finished") - - self._ledger.is_finished = True - self._ledger._serialize_and_commit(self.cache_dir) - if self._on_write: - self._on_write(self._ledger) - - return self._tree_store - - def _serialize_json_and_commit(path, obj): # just to be paranoid, we write to a temp file and then rename it # TODO: probably we could do better here @@ -711,11 +610,10 @@ def _serialize_json_and_commit(path, obj): fs.copy(path, f"{path}.bak") for i in range(10): - with fsspec.open(f"{path}.tmp", "w") as file: - file.write(obj.to_json()) try: - fs.rename(f"{path}.tmp", path) + with fsspec.open(path, "w") as file: + file.write(obj.to_json()) break except FileNotFoundError: # this happens for some reason sometimes. It makes no sense. @@ -724,7 +622,9 @@ def _serialize_json_and_commit(path, obj): pass -@ray.remote(num_cpus=0.1) # keep this small b/c it doesn't do a lot +@ray.remote( + num_cpus=0.1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}) +) # keep this small b/c it doesn't do a lot class _TreeStoreCacheBuilder(SnitchRecipient): """ Actor that coordinates the building of a cache. 
It spins up a bunch of workers to read from each shard @@ -736,11 +636,9 @@ def __init__( self, cache_dir: str, name: str, - split: str, # to workaround https://github.com/ray-project/ray/issues/44083 source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: CacheOptions, - force_flush: bool, ): pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") @@ -751,7 +649,7 @@ def __init__( self._options = options self._updated_ledger_condition = asyncio.Condition() # used to subscribe to metrics updates - self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor, options) + self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor) if self._ledger.is_finished: self._finished_promise.set_result(None) @@ -770,7 +668,16 @@ def __init__( # (we get twice from we need to concatenate prepared batches into the accumulator) # TODO: measure. memory=2 * self._options.target_bytes_per_flush, - ).remote(current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush) + ).remote(current_actor_handle(), cache_dir, source, options, processor) + + self._tokenize_pbar = tqdm( + total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard" + ) + self._copy_pbar = tqdm(total=len(source.shard_names), desc=f"{path_for_name}: copying", unit="shard") + self._report_totals = _ProgressReport(0, 0, 0) + self._copy_report_totals = _ProgressReport(0, 0, 0) + self._last_update = time.time() + except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -827,16 +734,26 @@ def _writer_exception(self, shard_name, exc_info: ExceptionInfo): pass self._do_notify() + def _child_failed(self, child: ray.actor.ActorHandle | str | None, exception: ExceptionInfo): + self._writer_exception(str(child), exception) + def _notify_updated_ledger(self, ledger: CacheLedger): """ Called by the cache writer when it has updated the ledger. 
""" was_finished = self._ledger.is_finished - self._ledger = ledger + # ensure the ledger is "monotonic" meaning that we only expect it to grow + if ledger.total_num_rows < self._ledger.total_num_rows: + raise RuntimeError(f"Ledger went backwards: {ledger.total_num_rows} < {self._ledger.total_num_rows}") + + for shard, rows in ledger.shard_rows.items(): + if rows < self._ledger.shard_rows.get(shard, 0): + raise RuntimeError(f"Shard {shard} went backwards: {rows} < {self._ledger.shard_rows.get(shard, 0)}") if was_finished: raise RuntimeError("Ledger was already finished") + self._ledger = ledger if self._ledger.is_finished: logger.info(f"Finalizing cache {self._cache_dir}...") # guard against invalid state errors @@ -854,40 +771,62 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) + def _report_progress(self, report: "_ProgressReport"): + import humanfriendly + + if report.new_shards > 0: + self._tokenize_pbar.update(report.new_shards) + self._report_totals.new_shards += report.new_shards + self._report_totals.new_rows += report.new_rows + self._report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + + mb_str = humanfriendly.format_size(self._report_totals.new_bytes) + self._tokenize_pbar.set_postfix( + { + "rows": self._report_totals.new_rows, + "shards": self._report_totals.new_shards, + "size": mb_str, + } + ) + + def _report_copy_progress(self, report: "_ProgressReport"): + self._copy_pbar.update(report.new_shards) + self._copy_report_totals.new_shards += report.new_shards + self._copy_report_totals.new_rows += report.new_rows + self._copy_report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + self._copy_pbar.set_postfix( + { + "shards": report.new_shards, + "rows": report.new_rows, + # "size": humanfriendly.format_size(report.new_bytes), + } + ) -def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False): - name = f"lev_cache_manager::{split}::{cache_dir}" + +def _get_builder_actor(cache_dir, shard_source, processor, options=CacheOptions.default()): + name = f"lev_cache_manager::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote( # type: ignore name=name_for_display, - split=split, cache_dir=cache_dir, source=shard_source, processor=processor, options=options, - force_flush=force_flush, ) ##### # Core implementation starts below. ##### -# The main idea is to have a bunch of reader tasks that read batches, dispatch tokenization tasks, producing -# a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache. -# The reader tasks are given a group of shards, which are implicitly concatenated together. - - -@dataclass -class _Batch: - """ - A batch of data that has either been read or tokenized. - """ - - shard_name: str - row_indices: List[int] - payload: ray.ObjectRef +# The main idea is to tokenize each shard group in parallel, and then write the results to the cache in order. @dataclass @@ -898,33 +837,23 @@ class _ShardFinished: shard_name: str total_rows: int - - -_Message = _Batch | _ShardFinished -""" -A message that can be sent from a reader task to the writer task. 
-""" - -_TIME_BETWEEN_WRITES = 20.0 # seconds + path_to_shard: str @ray.remote(num_cpus=1) def _core_writer_task( parent, cache_dir, - initial_ledger: CacheLedger, source: ShardedDataSource, + options: CacheOptions, processor, - force_flush: bool, ): """ This is the main task that processes the data and writes it to the cache. - It chains together: - * 1 generator per shard group - * interleaving of the generators - * processing of the batches - * writing of the batches to the cache + It receives "finished shards" messages from the reader tasks, and copies the data from temporary files + to the cache directory. + """ pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) logger.info("Starting writer task") @@ -933,400 +862,612 @@ def _core_writer_task( # append a small random number to the name to avoid collisions name += f"::{random.randint(0, 1000)}" + # we want to do the following: + # 1. write the 0th shard group to the output cache directly, updating metrics as we go + # 2. in the background, start processing other shard groups to temporary caches + # 3. once (1) is done, we start copying the temporary caches to the output cache (in order) + + # We notify the parent actor of progress and updates to the ledger. + # We special-case the 0'th ledger because we commit it to the output cache directly. + def report_fn(report: _ProgressReport, ledger: CacheLedger): + parent._report_progress.remote(report) + + def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): + parent._report_progress.remote(report) + ray.get(parent._notify_updated_ledger.remote(ledger)) + with log_failures_to(parent): + temporary_cache_path = os.path.join(cache_dir, "___temp") + + group_cache_paths: dict[str, str] = {} + group_ledgers: dict[str, CacheLedger | None] = {} + write_refs: dict[str, ray.ObjectRef] = {} - def on_write(ledger): - ray.get(parent._notify_updated_ledger.remote(ledger)) + shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) - sharded_cache_writer = ShardedCacheWriter( - cache_dir, initial_ledger, processor.output_exemplar, on_write=on_write + for name, group in shard_groups.items(): + assert len(group) > 0 + + logger.debug( + f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." 
) - options = initial_ledger.metadata.options - num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names)) + processor_ref = ray.put(processor) + source_ref = ray.put(source) + + # We treat the first group specially: we tokenize it directly to the output cache (since it comes first) + # This enables us to expose data quickly + first_group = next(iter(shard_groups), None) - processor_pool = _mk_processor_pool(processor, 0, num_groups * 4) + for group_name, shards in shard_groups.items(): + if group_name == first_group: + group_out_path = cache_dir + else: + group_out_path = os.path.join(temporary_cache_path, group_name) + + group_cache_paths[group_name] = group_out_path + + ledger = _try_load(group_out_path) + group_ledgers[group_name] = ledger + + if ledger is not None: + if group_name == first_group: + ray.get(parent._notify_updated_ledger.remote(ledger)) + continue + + report_fn_to_use = report_fn_first_group if group_name == first_group else report_fn + + ref = ( + ray.remote(_tokenize_one_shard_group) + .options( # type: ignore + num_cpus=processor.num_cpus, + num_gpus=processor.num_gpus, + resources=processor.resources, + memory=3 * 1024 * 1024 * 1024, # made this up + name=f"tokenize::{temporary_cache_path}::{group_name}", + retry_exceptions=True, + max_retries=10, + ) + .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn_to_use, parent) + ) - interleave: RayPrefetchQueue = RayPrefetchQueue( - lambda: _make_interleave(name, source, initial_ledger, processor_pool), - 64, - producer_options={"num_cpus": 1, "name": f"{name}::interleave"}, + write_refs[group_name] = ref + + ledger = _start_copies( + parent, + cache_dir, + shard_groups, + first_group, + write_refs, + group_ledgers, + group_cache_paths, + processor, + processor_ref, ) - total_time = Stopwatch() - loading_time = Stopwatch() - append_time = Stopwatch() - flush_time = Stopwatch() - flush_amortized_time = Stopwatch() - - current_prepared_batch: Optional[PyTree[PreparedBatch]] = None - current_shard_rows: dict[str, int] = {} - time_of_last_write = time.time() - batches_total = 0.0 - flush_thread = None - finished_shards_last_flush: list = [] - - while True: - with total_time: # 0.0051 - try: - cur_time = time.time() - time_since_last_write = cur_time - time_of_last_write - remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write - - if current_prepared_batch is not None: - with flush_amortized_time: # 6e-4 - current_byte_size = sum( - b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0] - ) - should_flush = ( - force_flush - or remaining_time <= 0 - or (current_byte_size >= options.target_bytes_per_flush) - ) - if should_flush: - with flush_time: # 0.613s - if flush_thread is not None: - flush_thread.join() - - flush_thread = ExceptionTrackingThread( - target=_write_batches, - args=( - sharded_cache_writer, - current_shard_rows, - current_prepared_batch, - finished_shards_last_flush, - ), - ) - flush_thread.start() - - current_prepared_batch = None - current_shard_rows = {} - finished_shards_last_flush = [] - - time_of_last_write = time.time() - continue - else: - remaining_time = _TIME_BETWEEN_WRITES - - with loading_time: - try: - message = interleave.get_next(timeout=max(remaining_time, 0.1)) - except QueueEmpty: - logger.info("Writer running ahead of reader.") - continue - - with append_time: - match message: - case _Batch(shard, row_indices, payload): - batches_total += 1 - this_prepared_batch = ray.get(payload) - if current_prepared_batch is None: - 
# TODO: actually check row indices - current_shard_rows = {shard: len(row_indices)} - current_prepared_batch = this_prepared_batch - else: - current_shard_rows[shard] = current_shard_rows.get(shard, 0) + len(row_indices) - current_prepared_batch = _concat_prepared_batches( - current_prepared_batch, this_prepared_batch - ) - del this_prepared_batch - - if force_flush: - _write_batches( - sharded_cache_writer, - current_shard_rows, - current_prepared_batch, - finished_shards_last_flush, - ) - finished_shards_last_flush = [] - current_prepared_batch = None - current_shard_rows = {} - - case _ShardFinished(shard, total_rows): - finished_shards_last_flush.append((shard, total_rows)) - case _: - raise AssertionError(f"Unexpected message type {type(message)}") - - # if batches_total % 1000 == 0: - # print( - # f"Processed {batches_total} batches: {loading_time.average()}s load," - # f" {append_time.average()}s append, {flush_time.average()}s flush blocked, " - # f"{flush_amortized_time.average()}s amortized flush, " - # f"{total_time.average()}s total" - # ) - except StopIteration: - logger.info("Finished all shards") - break - except Exception as e: - logger.exception("Error while processing batch") - raise e + ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) + ray.get(parent._notify_updated_ledger.remote(ledger)) + + temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir} + _clean_up_temp_caches(temporary_cache_paths) + - # force a flush - if current_prepared_batch is not None or finished_shards_last_flush: - if flush_thread is not None: - flush_thread.join() - _write_batches( - sharded_cache_writer, current_shard_rows, current_prepared_batch, finished_shards_last_flush +def _start_copies( + parent, + cache_dir, + shard_groups, + first_group, + write_refs, + group_ledgers, + group_cache_paths, + processor, + processor_ref, +): + """ + Copy the temporary caches to the output cache, in order. (essentially concatenating them) + + Args: + parent: the parent actor handle (_TreeStoreCacheBuilder) + cache_dir: the output cache directory + shard_groups: a dict mapping group names to lists of shard names + first_group: the privileged group that is written directly to the output cache + write_refs: a dict mapping group names to ray.ObjectRefs of the cache building tasks + group_ledgers: a dict mapping group names to the ledgers for the groups. Mutated in place. + group_cache_paths: a dict mapping group names to the paths of the temporary caches + processor: the processor object + processor_ref: a ray.ObjectRef of the processor object + + Returns: + The final ledger + """ + # This logic is a bit hairy thanks to resumes. + # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these + # separately. We also need to update the ledger as we go. + # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. + # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. + # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. + + # * When we load the permanent cache, we have already written some number of groups to it. In + # particular, we have written the 0'th group to the permanent cache. + # * We enforce that we only commit a whole group to the ledger at a time. + # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. 
+ # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" + # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate + # the data offsets for where the group goes in the permanent cache. This is just a running sum of the + # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree + # of integers, one for each array. + # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache + # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. + # * We also need to update the ledger with the total number of rows + + # reload the ledger for the first group, which will be the sink for the other groups + assert first_group in write_refs + + group_ledgers[first_group] = ray.get(write_refs[first_group]) + overall_ledger = group_ledgers[first_group] + + # initialize the data offset tree + permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) + data_offset_tree = jax.tree_map(lambda x: x.data_size, permanent_cache.tree) + total_rows_from_caches = overall_ledger.total_num_rows + copy_refs: dict[str, ray.ObjectRef] = {} + last_ref: ray.ObjectRef | None = None + + found_one_to_copy = False + + for group in shard_groups: + # first make sure it's either done this run or already done + if write_refs.get(group) is not None: + this_ledger = ray.get(write_refs[group]) + group_ledgers[group] = this_ledger + else: + this_ledger = group_ledgers[group] + + if group == first_group: + # this is the first group, so it's already in the cache and we don't need to + # increment the data offset tree etc. + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows) ) + continue - sharded_cache_writer.finish() + assert this_ledger is not None + # see if we already copied this group, meaning all the shards are in the permanent cache + shards_copied = sum(1 if shard in overall_ledger.finished_shards else 0 for shard in shard_groups[group]) + + if found_one_to_copy and shards_copied > 0: + raise RuntimeError("A previous group was copied, but this group was not. This should never happen.") + elif shards_copied == len(shard_groups[group]): + assert ( + overall_ledger.total_num_rows >= total_rows_from_caches + ), f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" + continue # nothing to do + elif shards_copied > 0: + # In theory we can handle this, but it's a bit tricky, so we're going to punt for now + raise RuntimeError("Some shards were copied but not all. 
This should never happen.") + + found_one_to_copy = True + # we need to copy this group + + # we can't "commit" the group to the ledger (or the number of rows) + # until we've updated the ledger for all previous groups, so we block on the last ref + ref_to_send = None if last_ref is None else RefBox(last_ref) + + last_ref = _copy_cache.remote( + cache_dir, + group_cache_paths[group], + processor_ref, + data_offset_tree, + ref_to_send, + total_rows_from_caches, + parent, + ) + copy_refs[group] = last_ref - out = sharded_cache_writer.get_ledger() - return out + # update the offset information: data offsets and total rows + this_cache = TreeStore.open(processor.output_exemplar, group_cache_paths[group], mode="r", cache_metadata=True) + data_offset_tree = jax.tree.map( + operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) + ) + total_rows_from_caches += this_ledger.total_num_rows + + # refs form a linked list implicitly, so we can just wait on the last one + if last_ref is not None: + ledger = ray.get(last_ref) + else: + ledger = overall_ledger + return ledger -def _concat_prepared_batches( - current_prepared_batch: PyTree[PreparedBatch], this_prepared_batch: PyTree[PreparedBatch] -): - return jax.tree.map(lambda *bs: PreparedBatch.concat(bs), current_prepared_batch, this_prepared_batch) +def _clean_up_temp_caches(paths): + for path in paths: + if fsspec_exists(path): + for i in range(10): + # this is crashy for some reason + try: + fsspec_remove(path, recursive=True) + break + except Exception: + logger.exception(f"Failed to remove {path} on attempt {i}") + time.sleep(1) + +def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: + if num_groups is None or num_groups >= len(source.shard_names): + return {shard_name: [shard_name] for shard_name in source.shard_names} -def _write_batches(writer: ShardedCacheWriter, shard_totals, batch: Optional[PyTree[PreparedBatch]], finished_shards): - # concatenate the payloads - if batch is not None: - writer.write_prepared_batch(shard_totals, batch) + shard_names = source.shard_names + num_shards_per_group = (len(shard_names)) // num_groups + num_groups_with_extra = len(shard_names) % num_groups - for shard, total_rows in finished_shards: - writer.finish_shard(shard, total_rows) + # if we have a remainder, we want to distribute the extra shards evenly + out_groups: dict[str, list[str]] = {} + start = 0 + for i in range(num_groups): + num_shards = num_shards_per_group + (1 if i < num_groups_with_extra else 0) + out_groups[f"group_{i}"] = list(shard_names[start : start + num_shards]) + start += num_shards + # make sure we got all the shards + assert sum(len(shards) for shards in out_groups.values()) == len(shard_names) -def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]: - shards_for_batches, payloads_for_batches = zip(*batches) - payloads_for_batches = ray.get(list(payloads_for_batches)) + return out_groups # type: ignore - shard_row_totals: dict[str, int] = {} - for shard, payload in zip(shards_for_batches, payloads_for_batches): - shard_row_totals[shard] = shard_row_totals.get(shard, 0) + jax.tree.leaves(payload)[0].num_rows - return shard_row_totals, payloads_for_batches +def _merge_ledgers(dest, source): + dest.total_num_rows += source.total_num_rows + for shard, rows in source.shard_rows.items(): + current_value = dest.shard_rows.get(shard, 0) + assert current_value == 0, f"Shard {shard} already has {current_value} rows" + dest.shard_rows[shard] 
= rows + dest.finished_shards.extend(source.finished_shards) + for field, count in source.field_counts.items(): + dest.field_counts[field] = dest.field_counts.get(field, 0) + count -def _interleave_shards(readers: Sequence[RayPrefetchQueue], first_index: int) -> Iterator[T]: # _Message + return dest + + +@ray.remote(num_cpus=4, memory=4 * 1024 * 1024 * 1024) +def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: RefBox, rows_so_far, parent): """ - Interleaves the results of multiple iterators. To support resume, - we need to be able to start from not the "first" iterator. + Copies the data from one cache to another, appending it to the end of the destination cache. + Once the copy is done and the last_ref is set, the data is "unlocked" in the destination cache by updating the + offsets[0] of the destination cache to the total number of rows in the cache. Args: - readers: A list of iterators - first_index: The index of the first iterator to start from. We use this to support resuming. - """ + dest_path: The path to the destination cache. + source_path: The path to the source cache. + processor: The processor used to create the cache. + data_offset_tree: The data offset tree for the destination cache. + last_ref: The ref to wait on before updating the ledger. + rows_so_far: The total number of rows in the destination cache before this copy. - finished: set[int] = set() - total = 0 - while len(finished) < len(readers): - for i in range(first_index, len(readers)): - reader = readers[i] - if i not in finished: - try: - message = reader.get_next() - total += 1 - yield message - except StopIteration: - finished.add(i) - except Exception as e: - logger.exception(f"Error while processing group {i}") - raise e + Returns: - first_index = 0 + """ + with log_failures_to(parent): + asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far)) + if last_ref is not None: + ray.wait([last_ref.ref], fetch_local=False) + permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source_ledger = CacheLedger.load(source_path) + + new_num_rows = source_ledger.total_num_rows + rows_so_far + + futures = jax.tree.leaves(jax.tree.map(lambda x: x.offsets[0].write(new_num_rows), permanent_cache.tree)) + for future in futures: + future.result() + + dest_ledger = CacheLedger.load(dest_path) + _merge_ledgers(dest_ledger, source_ledger) + dest_ledger._serialize_and_commit(dest_path) + assert not dest_ledger.is_finished + + ray.get(parent._notify_updated_ledger.remote(dest_ledger)) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows) + ) - logger.info(f"Finished all shards, got {total} batches") + return dest_ledger -def _assign_shards_to_groups(shards: Sequence[_ShardStatus], num_groups: int) -> list["_ShardGroup"]: +async def _extend_cache_with_other_cache( + dest_path: str, source_path: str, processor: BatchProcessor, data_offset_tree: PyTree[int], row_offset +) -> int: """ - Assigns shards to groups in a round-robin fashion. + Copies the data from one cache to another, appending it to the end of the destination cache. + + Returns: + The number of rows in the source cache. 
""" - groups: list[list] = [[] for _ in range(num_groups)] - for i, shard in enumerate(shards): - groups[i % num_groups].append(shard) - return [_ShardGroup(group) for group in groups] + logger.info(f"Copying data from {source_path} to {dest_path}.") + dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True) + + source_num_rows = await source.async_len() + + async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): + """Copies **just the data array** from one shard to the permanent cache at a given offset.""" + # TODO: it'd be good if we just didn't expose the full data array (but only the used part) + data_size = source_array.data_size + data = source_array.data[0:data_size] + futures: list[ts.Future] = [] + + # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) + async with ts.Transaction() as txn: + dest = dest_array.data + out_end = data_offset + data_size + write_future = dest.with_transaction(txn)[data_offset:out_end].write(data) + futures.append(write_future) + + if source_array.shapes is not None: + source_shapes = source_array.shapes[0:source_num_rows] + async with ts.Transaction() as txn: + dest = dest_array.shapes + out_end = row_offset + source_num_rows + shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes) + futures.append(shape_future) + + source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]] + source_offsets = _virtual_offset(source_offsets, data_offset) + + async with ts.Transaction() as txn: + dest = dest_array.offsets + out_end = row_offset + 1 + source_num_rows + offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets) + + futures.append(offset_future) + + out = await asyncio.gather(*futures) + return out + futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) -def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: - prng = random.Random(seed) - shuffled = list(shards) - prng.shuffle(shuffled) - return shuffled + await asyncio.gather(*jax.tree.leaves(futures)) + logger.info(f"Finished copying data from {source_path} to {dest_path}.") + return source_num_rows -class _ShardGroup: - """ - Given a group of shards and a list of statuses, implicitly concatenates the shards and reads from them. - This class mostly exists for resuming: we want to be able to start from the last shard we were working on. +def _virtual_offset(base: ts.TensorStore, offset_amount): + """ + This function creates a new tensorstore that is a virtual offset of another tensorstore. + That is, it's y[i] = x[i] + offset_amount. """ - def __init__(self, group: list[_ShardStatus]): - self.shards = group - self.total_rows_committed, _all_finished = self._impute_total_rows_committed_and_check_invariants() - - def _impute_total_rows_committed_and_check_invariants(self): - # we also want to ensure that we haven't started any shards until we've finished the previous ones - total_committed = 0 - last_shard_name = None - last_was_finished = True - all_finished = True - - for status in self.shards: - shard_name = status.shard_name - if not last_was_finished and status.num_rows_committed > 0: - raise ValueError( - f"Shard {shard_name} has rows committed but previous shard in group {last_shard_name} " - "is not finished. 
Something about the cache configuration has changed: either the " - "number/order of shards, the shard shuffle random seed, or the number of groups." - ) - total_committed += status.num_rows_committed - if not status.is_finished: - all_finished = False - last_was_finished = status.is_finished - last_shard_name = shard_name + async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.VirtualChunkedReadParameters): + array[...] = (await base[domain].read()) + offset_amount - return total_committed, all_finished + return ts.virtual_chunked(do_read, dtype=base.dtype, domain=base.domain, shape=base.shape) -def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor_pool: ActorHandle): - """ - Given a list of ShardStatus objects and sources, creates an interleaving generator - that reads from shards and tokenizes them in parallel. +async def _copy_data_from_one_shard_to_permanent_memory( + dest_path: str, + source_path: str, + processor: BatchProcessor, + data_offset_tree: PyTree[int], +): + """Copies from one tree store to the permanent cache at a given offset (for each leaf)""" + logger.info(f"Copying data from {source_path} to {dest_path}.") + dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True) - We use ShardStatus objects to track the progress of each shard. If we're preempted, we can resume - from the last shard we were working on. This function starts each shard at the last committed row - and starts interleaving from the next shard (i.e. the one with the fewest rows that isn't finished). - """ - logger.setLevel(DEFAULT_LOG_LEVEL) - statuses = _get_shard_statuses(initial_ledger, source) + def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): + # TODO: it'd be good if we just didn't expose the full data array (but only the used part) + data = source_array.data[0 : source_array.data_size] + # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) + with ts.Transaction() as txn: + dest = dest_array.data + out_end = data_offset + source_array.data_size + write_future = dest.with_transaction(txn)[data_offset:out_end].write(data) - options = initial_ledger.metadata.options + return write_future - unfinished_shards = _check_current_shard_progress(statuses) + futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) - if not unfinished_shards: - logger.info("All shards finished. 
Nothing to do.") - return + await asyncio.gather(*jax.tree.leaves(futures)) + logger.info(f"Finished copying data from {source_path} to {dest_path}.") + return - group_names, groups = _randomize_and_group_shards(name, options, statuses) - logger.warning(f"Starting cache build with {len(statuses)} shards, in {len(groups)} groups") +@dataclass +class _ProgressReport: + new_rows: int = 0 + new_bytes: float = 0 + new_shards: int = 0 + # TODO: other counts - def _make_generator_fn(group: _ShardGroup): - def generator(): - pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) - for message in _shard_reader_generator(source, group, options.batch_size): - match message: - case _Batch(): - # processed = ray.put(process_task(ray.get(message.payload))) - # processed = process_task.remote(processor_ref, message.payload) - processed = processor_pool.process_batch.remote(RefBox(message.payload)) - yield dataclasses.replace(message, payload=processed) - case _ShardFinished(): - yield message - case _: - raise AssertionError(f"Unexpected message type {type(message)}") - return generator +def _tokenize_one_shard_group( + temporary_cache_path: str, + source: ShardedDataSource, + shards: list[str], + processor: BatchProcessor, + options: CacheOptions, + report_fn: Callable[[_ProgressReport, CacheLedger], None], + force_unfinalized: bool, +) -> CacheLedger: + # ray breaks if this is top level + import humanfriendly - generator_fns = [_make_generator_fn(group) for group in groups] + logger = pylogging.getLogger("tokenize") + pylogging.basicConfig(level=pylogging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - readers = [ - RayPrefetchQueue( - fn, - options.prefetch_per_group, - producer_options=dict(num_cpus=0.1, name=name, scheduling_strategy="SPREAD"), - ) - for name, fn in zip(group_names, generator_fns) - ] + # restrict shards to the ones we're supposed to process + # this is a bit hacky but when there are a lot of shards (e.g. SlimPajama 122K), + # we encounter significant overhead just parsing the shard names from the json + source = _RestrictedShardedDataSource(source, shards) - # then figure out the first shard to start from. 
This is the first unfinished shard with the minimum number of rows - first_group_to_start = min( - range(len(groups)), - key=lambda i: groups[i].total_rows_committed, - ) + ledger = CacheLedger.load_or_initialize(temporary_cache_path, source, processor) - yield from _interleave_shards(readers, first_group_to_start) + if ledger.is_finished: + logger.info("Shard group already processed.") + return ledger + writer = ShardGroupCacheWriter(temporary_cache_path, ledger, shards, processor.output_exemplar) -def _mk_processor_pool(processor, min_size, max_size): - import hashlib + total_rows = ledger.total_num_rows + found_shard_with_rows = False - metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest() - processor_pool_name = f"processor_pool::{metadata_hash}" - processor_pool = BatchProcessorPool.options( # type: ignore - name=processor_pool_name, get_if_exists=True, lifetime="detached" - ).remote( # type: ignore - processor, min_size, max_size - ) + if total_rows > 0: + report_fn(_ProgressReport(new_rows=total_rows), ledger) - ray.get(processor_pool.ensure_max_at_least.remote(max_size)) + for shard_name in shards: + if shard_name in ledger.finished_shards: + logger.info(f"Shard {shard_name} already processed.") + report_fn(_ProgressReport(new_shards=1), ledger) + continue - return processor_pool + logger.debug(f"Processing {shard_name}.") + rows_this_shard = ledger.shard_rows.get(shard_name, 0) -def _check_current_shard_progress(statuses): - unfinished_shards: list[_ShardStatus] = [] - shards_with_progress: dict[str, int] = {} - for status in statuses: - if not status.is_finished: - unfinished_shards.append(status) - if status.num_rows_committed > 0: - shards_with_progress[status.shard_name] = status.num_rows_committed - if unfinished_shards and shards_with_progress: - formatted = ", ".join(f"{k}: {v}" for k, v in shards_with_progress.items()) - logger.info(f"Resuming from shards with progress: {formatted}") - return unfinished_shards + if found_shard_with_rows and rows_this_shard != 0: + raise ValueError("Found more than one shard with rows to process.") + if rows_this_shard != 0: + found_shard_with_rows = True -def _randomize_and_group_shards(name, options, statuses): - if options.shard_order_randomization_key is not None: - seed = options.shard_order_randomization_key - logger.info(f"Randomizing shard order with seed {seed}") - statuses = _randomize_shards(statuses, seed) + shard_iterator = source.open_shard_at_row(shard_name, rows_this_shard) - num_groups = min( - options.num_shard_groups if options.num_shard_groups is not None else len(statuses), len(statuses) - ) - if num_groups == 1: - group_names = [f"generator::{name}::all_shards"] - elif len(statuses) == num_groups: - group_names = [f"generator::{name}::{status.shard_name}" for status in statuses] - else: - group_names = [f"generator::{name}::group_{i}" for i in range(num_groups)] + prepared_batch: PyTree[PreparedBatch] | None = None + this_batch_size = 0 + + for batch in batched(shard_iterator, options.batch_size): + tokenized = processor(batch) + tokenized = _canonicalize_batch(tokenized) # type: ignore + this_prepared = writer._tree_store.batch_preparer(tokenized) + + this_batch_size += len(batch) + rows_this_shard += len(batch) + total_rows += len(batch) + + if prepared_batch is None: + prepared_batch = this_prepared + else: + prepared_batch = jax.tree.map( + lambda *trees: PreparedBatch.concat(trees), prepared_batch, this_prepared + ) - groups = _assign_shards_to_groups(statuses, num_groups) - return group_names, 
groups + batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) + + if batch_byte_size > options.target_bytes_per_flush: + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) + + nice_bytes = humanfriendly.format_size(batch_byte_size) + logger.debug( + f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" + ) + # print(f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})", flush=True) + this_batch_size = 0 + prepared_batch = None + + if prepared_batch is not None: + batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) + nice_bytes = humanfriendly.format_size(batch_byte_size) + + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) + + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + logger.debug( + f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" + ) + this_batch_size = 0 + prepared_batch = None + writer.finish_shard(shard_name, rows_this_shard) -def _shard_reader_generator( - shard_source: ShardedDataSource[T], group: _ShardGroup, batch_size: int -) -> Iterator[_Message]: + report_fn(_ProgressReport(new_shards=1), writer.ledger) + + if not force_unfinalized: + writer.finish() + + logger.debug(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") + + return writer.ledger + + +class ShardGroupCacheWriter: """ - Given a group of shards, implicitly concatenates the shards and reads from them. + Similar to SerialCacheWriter, but tracks shard metadata for one shard. 
""" - for status in group.shards: - if status.is_finished: - logger.info(f"Skipping finished shard {status.shard_name}") - continue - start_row = status.num_rows_committed - logger.info(f"Opening shard {status.shard_name} at row {start_row}") - shard_iter = shard_source.open_shard_at_row(status.shard_name, start_row) - batch = [] - batch_idxes = [] - row_idx = start_row - for row in shard_iter: - batch.append(row) - batch_idxes.append(row_idx) - row_idx += 1 + def __init__(self, cache_dir: str, initial_ledger: CacheLedger, shards: list[str], exemplar: T): + self.cache_dir = cache_dir + + self._ledger = copy.deepcopy(initial_ledger) + self.shards = shards + + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore + self._tree_store.trim_to_size(self._ledger.total_num_rows) + + @property + def ledger(self): + return self._ledger + + # we have both versions b/c we need this one for actors + def get_ledger(self): + return self._ledger + + @property + def is_finished(self): + return self._ledger.is_finished + + def finish_shard(self, shard_name: str, num_rows: int): + if shard_name not in self.shards: + raise ValueError(f"Shard {shard_name} not in tracked shards") + + current_rows = self._ledger.shard_rows.get(shard_name, 0) + if current_rows != num_rows: + raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") + + self._ledger.finished_shards.append(shard_name) + self._ledger._serialize_and_commit(self.cache_dir) - if len(batch) == batch_size: - yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) - batch = [] - batch_idxes = [] + def write_prepared_batch(self, shard_name: str, row_count: int, batch: PyTree[PreparedBatch]): + if self.is_finished: + raise RuntimeError("Cannot write to a finished cache") + self._tree_store.extend_with_batch(batch) - if len(batch) > 0: - yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) + if shard_name not in self.shards: + raise ValueError(f"Shard {shard_name} not in tracked shards") + self._ledger.shard_rows[shard_name] += row_count + self._ledger.total_num_rows += row_count - logger.info(f"Finished generating shard {status.shard_name} with {row_idx} rows") - yield _ShardFinished(status.shard_name, row_idx) + self._ledger._serialize_and_commit(self.cache_dir) + + def finish(self): + if len(self._ledger.finished_shards) != len(self.shards): + raise ValueError("Not all shards are finished") + + self._ledger.is_finished = True + self._ledger._serialize_and_commit(self.cache_dir) + # ensure all tracked shards are finished + + return self._tree_store + + +class _RestrictedShardedDataSource(ShardedDataSource): + def __init__(self, source: ShardedDataSource, shards: list[str]): + self._source = source + self._shards = shards + + @property + def shard_names(self): + return self._shards + + def open_shard_at_row(self, shard_name, row): + return self._source.open_shard_at_row(shard_name, row) + + +def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: + prng = random.Random(seed) + shuffled = list(shards) + prng.shuffle(shuffled) + return shuffled def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: @@ -1360,8 +1501,13 @@ def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics: ) -def _get_shard_statuses(ledger: CacheLedger, source: ShardedDataSource): - return [ - _ShardStatus(name, ledger.shard_rows.get(name, 0), name in ledger.finished_shards) - for name in source.shard_names - ] +def _try_load(path): + try: + ledger = 
CacheLedger.load(path) + if ledger.is_finished: + return ledger + else: + logger.debug(f"Cache exists but is not finished at {path}.") + return None + except FileNotFoundError: + return None diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py index 03355a8d2..83d6c88b0 100644 --- a/src/levanter/store/tree_store.py +++ b/src/levanter/store/tree_store.py @@ -172,6 +172,9 @@ def get_batch_sync(self, indices) -> List[T]: return out + async def async_len(self) -> int: + return await jax.tree.leaves(self.tree)[0].num_rows_async() + def _construct_builder_tree(exemplar, path, mode, cache_metadata): def open_builder(tree_path, item): diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index 64870443d..c8d3931fe 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,5 +1,6 @@ import braceexpand import fsspec +from fsspec.asyn import AsyncFileSystem def exists(url, **kwargs) -> bool: @@ -14,7 +15,7 @@ def mkdirs(path): fs.makedirs(path, exist_ok=True) -def fsspec_expand_glob(url): +def expand_glob(url): expanded_urls = braceexpand.braceexpand(url) for expanded_url in expanded_urls: if "*" in expanded_url: @@ -28,3 +29,21 @@ def fsspec_expand_glob(url): yield from [f"{protocol}://{path}" for path in globbed] else: yield expanded_url + + +def remove(url, *, recursive=False, **kwargs): + """Remove a file from a remote filesystem.""" + # TODO: better to use a STS deletion policy or job for this one. + fs, path = fsspec.core.url_to_fs(url, **kwargs) + + fs.rm(path, recursive=recursive) + + +async def async_remove(url, *, recursive=False, **kwargs): + """Remove a file from a remote filesystem.""" + fs, path = fsspec.core.url_to_fs(url, **kwargs) + + if isinstance(fs, AsyncFileSystem): + return await fs._rm(path, recursive=recursive) + else: + fs.rm(path, recursive=recursive) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index c1eb73670..086de48e1 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -1,6 +1,4 @@ import asyncio -import copy -import os import tempfile from typing import Any, Dict, Iterator, Sequence @@ -10,17 +8,7 @@ from levanter.data import BatchProcessor, ShardedDataSource, batched from levanter.data.sharded_datasource import TextUrlDataSource -from levanter.store.cache import ( - LEDGER_FILE_NAME, - CacheLedger, - CacheOptions, - SerialCacheWriter, - ShardedCacheWriter, - TreeStore, - _get_builder_actor, - _serialize_json_and_commit, - build_or_load_cache, -) +from levanter.store.cache import CacheOptions, SerialCacheWriter, TreeStore, _get_builder_actor, build_or_load_cache from levanter.utils.py_utils import logical_cpu_core_count @@ -140,13 +128,13 @@ def test_full_end_to_end_cache(): with td as tmpdir: ray_ds = build_or_load_cache( tmpdir, - SimpleShardSource(num_shards=2), + SimpleShardSource(num_shards=15), TestProcessor(), await_finished=True, - options=CacheOptions.no_fanciness(8), + options=CacheOptions(num_shard_groups=3, batch_size=8), ) - expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=2), 8) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=15)) all_data = ray_ds[:] @@ -162,15 +150,14 @@ def test_full_end_to_end_cache_with_groups(): SimpleShardSource(num_shards=5), TestProcessor(), await_finished=True, - options=CacheOptions(num_shard_groups=2, batch_size=8, shard_order_randomization_key=None), + options=CacheOptions(num_shard_groups=2, batch_size=8), ) - expected = 
process_interleave(TestProcessor(), SimpleShardSource(num_shards=5), 8) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=5)) all_data = ray_ds[:] - # check_datasets_equal(all_data, expected) - assert len(all_data) == len(list(expected)) + check_datasets_equal(all_data, expected) @pytest.mark.ray @@ -204,7 +191,6 @@ class _CustomException(Exception): @pytest.mark.ray -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") def test_cache_recover_from_crash(): class CrashingShardSource(ShardedDataSource[list[int]]): def __init__(self, crash_point: int): @@ -218,7 +204,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # parse the shard name to get the shard number shard_num = int(shard_name.split("_")[1]) for i in range(10): - if shard_num * 10 + i == self.crash_point: + if i == self.crash_point: raise _CustomException(f"Crashing at {shard_num} {i} {self.crash_point}") if i >= row: yield [shard_num * 10 + i] * 10 @@ -226,7 +212,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as tmpdir2: source = CrashingShardSource(4) with pytest.raises(_CustomException): - build_or_load_cache(tmpdir, source, TestProcessor()) + build_or_load_cache(tmpdir, source, TestProcessor(), CacheOptions(target_size_per_flush=1)) # kill the broker actor so that we can test recovery ray.kill( @@ -244,11 +230,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: ) # testing this doesn't throw - source = CrashingShardSource(1000) + source = CrashingShardSource(100000) reader1 = build_or_load_cache(tmpdir, source, TestProcessor(), await_finished=True) # compare to the original with no crash - reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), await_finished=True) + reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(num_shards=4), TestProcessor(), await_finished=True) check_datasets_equal(reader1, reader2) @@ -295,7 +281,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # now block until the cache is done cache.await_finished(timeout=30) - expected = process_interleave(processor, SlowShardSource(), 16) + expected = simple_process(processor, SlowShardSource()) check_datasets_equal(list(cache[:]), expected) @@ -334,13 +320,12 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: SlowShardSource(), TestProcessor(), await_finished=False, - force_flush=True, - options=CacheOptions.no_fanciness(5), - ) # we need force_flush to ensure the cache is written to disk + options=CacheOptions(target_size_per_flush=1, batch_size=1), + ) # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] - first_10 = list(await cache.get_batch(range(0, 10))) + first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=30.0)) for i, x in enumerate(first_10): np.testing.assert_array_equal(x["test"], np.array([i] * 10)) @@ -353,7 +338,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # now ensure we can get the next 10 elements, which will be # [{"test": np.array([i] * 10)} for i in range(10, 20)] - batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10) + batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10.0) for i, x in enumerate(batch): np.testing.assert_array_equal(x["test"], np.array([i + 10] * 10)) @@ 
-364,7 +349,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: cache.await_finished(timeout=10) -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") @pytest.mark.ray def test_shard_cache_crashes_if_processor_throws(): class ThrowingProcessor(SimpleProcessor): @@ -398,7 +382,6 @@ def test_shard_cache_fails_with_multiple_shards_with_the_same_name(): build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") @pytest.mark.ray @pytest.mark.asyncio async def test_shard_cache_fails_gracefully_with_unknown_file_type_async(): @@ -451,89 +434,3 @@ def test_shard_cache_fails_gracefully_with_unknown_file_type(): cache.await_finished(timeout=10) del cache - - -def test_sharded_cache_writer(): - with tempfile.TemporaryDirectory() as tmpdir: - source = SimpleShardSource(num_shards=4) - processor = SimpleProcessor() - ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(8)) - - exemplar = {"data": np.array([0], dtype=np.int64)} - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size): - writer.write_batch(shard_name, processor(ex)) - - for shard_name in source.shard_names: - writer.finish_shard(shard_name, source._rows_per_shard) - - store = writer.finish() - - data_path = store.path - - del store - - builder = TreeStore.open(exemplar, data_path, mode="r") - - assert len(builder) == 40 - - for i, x in enumerate(builder): - np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) - - # check totals for the ledger - ledger = writer.ledger - assert ledger.total_num_rows == 40 - assert ledger.is_finished - - for shard_name in source.shard_names: - assert ledger.shard_rows[shard_name] == 10 - - -def test_sharded_cache_writer_trims_on_resume(): - with tempfile.TemporaryDirectory() as tmpdir: - source = SimpleShardSource(num_shards=4) - processor = SimpleProcessor() - - exemplar = {"data": np.array([0], dtype=np.int64)} - - ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(batch_size=8)) - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), 8): - writer.write_batch(shard_name, processor(ex)) - - for shard_name in source.shard_names: - writer.finish_shard(shard_name, 10) - - writer.finish() - - # now deliberately truncate the ledger a bit - ledger = copy.deepcopy(writer.ledger) - assert ledger.total_num_rows == 40 - assert ledger.is_finished - ledger.total_num_rows = 24 - ledger.shard_rows["shard_0"] = 8 - ledger.shard_rows["shard_1"] = 8 - ledger.shard_rows["shard_2"] = 8 - ledger.shard_rows["shard_3"] = 0 - ledger.is_finished = False - - _serialize_json_and_commit(os.path.join(tmpdir, LEDGER_FILE_NAME), ledger) - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - - # ensure it got truncated - assert writer.ledger.total_num_rows == 24 - assert writer.ledger.is_finished is False - assert writer.ledger.shard_rows["shard_0"] == 8 - assert writer.ledger.shard_rows["shard_1"] == 8 - assert writer.ledger.shard_rows["shard_2"] == 8 - assert writer.ledger.shard_rows["shard_3"] == 0 - - new_store = writer._tree_store - new_data = new_store[:] - - assert len(new_data) == 24
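
Usage note (illustrative, not part of the patch): a minimal sketch of how the fsspec helpers introduced in src/levanter/utils/fsspec_utils.py above might be driven; the gs:// paths are hypothetical placeholders.

import asyncio

from levanter.utils.fsspec_utils import async_remove, expand_glob, remove

# expand_glob resolves brace patterns and wildcards into concrete URLs
urls = list(expand_glob("gs://my-bucket/data/part-{00..03}-of-*.jsonl.gz"))

# remove deletes a path; recursive=True also removes directory trees
remove("gs://my-bucket/tmp/stale-cache", recursive=True)

# async_remove awaits the async filesystem API when the underlying fsspec
# implementation supports it, and falls back to the blocking rm otherwise
asyncio.run(async_remove("gs://my-bucket/tmp/stale-cache", recursive=True))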