From 5ebf8ce14eec52376fb3c785095397e59822f665 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Tue, 29 Oct 2024 21:50:55 -0700 Subject: [PATCH 01/19] add multislice support in ray (#771) --- infra/cluster/job-cluster.yaml | 45 ++++++- infra/launch_on_ray.py | 3 +- src/levanter/infra/cli_helpers.py | 5 + src/levanter/infra/ray_tpu.py | 201 +++++++++++++++++++++++++++++- 4 files changed, 244 insertions(+), 10 deletions(-) diff --git a/infra/cluster/job-cluster.yaml b/infra/cluster/job-cluster.yaml index cf8703d54..cff7d4884 100644 --- a/infra/cluster/job-cluster.yaml +++ b/infra/cluster/job-cluster.yaml @@ -14,8 +14,8 @@ cluster_name: levanter-cluster # Configure GCP provider: type: gcp - region: us-central2 - availability_zone: us-central2-b + region: us-west4 + availability_zone: us-west4-a project_id: hai-gcp-models # Maximum Workers (excluding Head Node) @@ -126,6 +126,45 @@ available_node_types: schedulingConfig: preemptible: true + tpu_slice_v5e_16: + min_workers: 0 + max_workers: 1024 + resources: { "CPU": 120, "TPU": 4 } + + node_config: + acceleratorType: v5litepod-16 + runtimeVersion: tpu-ubuntu2204-base + + # [IMPORTANT] Configure all TPU Workers to be Preemptible! + schedulingConfig: + preemptible: true + + tpu_slice_v5e_64: + min_workers: 0 + max_workers: 1024 + resources: { "CPU": 120, "TPU": 4 } + + node_config: + acceleratorType: v5litepod-64 + runtimeVersion: tpu-ubuntu2204-base + + # [IMPORTANT] Configure all TPU Workers to be Preemptible! + schedulingConfig: + preemptible: true + + tpu_slice_v5e_256: + min_workers: 0 + max_workers: 1024 + resources: { "CPU": 120, "TPU": 4 } + + node_config: + acceleratorType: v5litepod-256 + runtimeVersion: tpu-ubuntu2204-base + + # [IMPORTANT] Configure all TPU Workers to be Preemptible! + schedulingConfig: + preemptible: true + docker: image: "ghcr.io/stanford-crfm/levanter-cluster:latest" container_name: "ray_docker" @@ -140,7 +179,7 @@ docker: - -v "/var/run/docker.sock:/var/run/docker.sock" initialization_commands: - - yes | gcloud auth configure-docker us-central2-docker.pkg.dev + - yes | gcloud auth configure-docker us-west4-docker.pkg.dev - "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true" - which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f) # always run this because ray doesn't run with sudo diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py index fa5e81f27..90f2c586a 100755 --- a/infra/launch_on_ray.py +++ b/infra/launch_on_ray.py @@ -27,7 +27,7 @@ def main(): cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"]) cli.add_arg(parser, config, ["--tpu_type"], required=True) # TODO: bring node_count to Ray - # cli.add_arg(parser, config, ["--node_count"], default=1, type=int) + cli.add_arg(parser, config, ["--node_count"], default=1, type=int) cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true") cli.add_arg(parser, config, ["--retries"], default=10, type=int) cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str) @@ -122,6 +122,7 @@ def main(): env=env, name="levanter", retries=retries, + node_count=args.node_count, ) address = args.address or os.getenv("RAY_ADDRESS") diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py index b92b6efb5..58413ef2b 100644 --- a/src/levanter/infra/cli_helpers.py +++ b/src/levanter/infra/cli_helpers.py @@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante "/tmp:/tmp", ] + # optionally add multislice env vars (if set by ray runtime env vars) + for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]: + v = shlex.quote(str(v)) + docker_command.extend(["-e", v]) + for k, v in env.items(): v = shlex.quote(str(v)) k = shlex.quote(str(k)) diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 2dc554808..57f484770 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -3,6 +3,7 @@ import logging import multiprocessing import os +import socket import subprocess import tempfile import time @@ -104,7 +105,83 @@ def do_run(remote_fn) -> _TpuRunResult: return do_run.remote(remote_fn) -def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts): +def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef: + """ + Run a remote function on multiple TPU slices. + + Args: + remote_fn: A remote function that takes no arguments + tpu_type: The type of TPU to run on, e.g. "v4-32" + num_slices: The number of slices to run + + Returns: + A Ray ObjectRef that represents the result of the function + """ + + @ray.remote(resources={f"TPU-{tpu_type}-head": 1}) + class MultisliceActor: + def __init__(self): + self.pod_name = ray.util.accelerators.tpu.get_current_pod_name() + self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() + self.ip = socket.gethostbyname(socket.gethostname()) + + def get_slice_info(self): + return self.pod_name, self.num_hosts, self.ip + + def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult: + port = 8081 + mxla_env = { + "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}", + "MEGASCALE_NUM_SLICES": str(num_slices), + "MEGASCALE_PORT": f"{port}", + "MEGASCALE_SLICE_ID": str(slice_id), + } + + remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env) + + info = _TpuInfo(tpu_name, "ACTIVE", "TPU") + futures = [remote_fn.remote() for _ in range(self.num_hosts)] + try: + out = ray.get(futures) + logger.info("TPU job finished") + return TpuSuccess(info, out) + except RayError as e: + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + return _handle_ray_error(info, e) + except Exception as e: + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + return TpuFailed(info, e) + + actors = [MultisliceActor.remote() for _ in range(num_slices)] # type: ignore + futures = [actor.get_slice_info.remote() for actor in actors] + try: + logger.info("Getting slice infos...") + # also act as a sync step + slice_infos = ray.get(futures) + logger.info(f"TPU slice infos {slice_infos}") + except RayError as e: + logger.exception(e) + for actor in actors: + try: + ray.cancel(actor) + except Exception: + logger.exception("Failed to kill actor after primary failure") + return futures + + coordinator_ip = slice_infos[0][2] + + return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)] + + +def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): """ Redecorate a remote function to run on a TPU pod. @@ -120,7 +197,11 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts): tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8 - remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host}) + + remote_fn = remote_fn.options( + runtime_env=runtime_env, + resources={tpu_name: 1, "TPU": num_tpus_per_host}, + ) logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host") return remote_fn, tpu_name @@ -193,11 +274,107 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re raise RuntimeError("Failed too many times") from problem +def run_on_pod_multislice_resumable( + remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10 +): + """ + Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached. + + Args: + remote_fn: A remote function that takes no arguments + tpu_type: The type of TPU to run on, e.g. "v4-32" + num_slices: The number of slices to run + max_retries_preemption: The maximum number of times to retry if the job is preempted + max_retries_failure: The maximum number of times to retry if the job fails + + Returns: + The result of the function (not an ObjectRef) + + """ + num_failures = 0 + num_preemptions = 0 + attempt = 0 + problem: Exception | None = None + + while num_failures < max_retries_failure and num_preemptions < max_retries_preemption: + logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}") + attempt += 1 + problem = None + futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices) + try: + outs = ray.get(futures) + except ray.exceptions.RayTaskError as e: + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + problem = e + if "preempted" in str(e).lower(): + num_preemptions += 1 + logger.warning(f"Preempted {num_preemptions} times, {e}") + else: + num_failures += 1 + logger.warning(f"Failed {num_failures} times", exc_info=e) + continue + except Exception as e: + for f in futures: + try: + ray.cancel(f) + except Exception: + logger.exception("Failed to kill job after primary failure") + problem = e + num_failures += 1 + if num_failures >= max_retries_failure: + logger.exception("Failed too many times", exc_info=e) + raise e + else: + logger.warning(f"Failed {num_failures} times", exc_info=e) + continue + + if all(isinstance(out, TpuSuccess) for out in outs): + results = [out.result for out in outs] + logger.info("Success") + return results + elif any(isinstance(out, TpuPreempted) for out in outs): + out = None + for o in outs: + if isinstance(o, TpuPreempted): + out = o + assert out is not None + problem = out.error + num_preemptions += 1 + logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem) + elif any(isinstance(out, TpuFailed) for out in outs): + num_preemptions += 1 + logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times") + elif any(isinstance(out, TpuRunError) for out in outs): + out = None + for o in outs: + if isinstance(o, TpuRunError): + out = o + assert out is not None + problem = out.error + num_preemptions += 1 + problem = out.error + num_failures += 1 + logger.warning(f"Failed {num_failures} times", exc_info=problem) + else: + raise RuntimeError(f"Unexpected result: {out}") + + if num_preemptions >= max_retries_preemption: + raise RuntimeError("Preempted too many times") from problem + elif num_failures >= max_retries_failure: + raise RuntimeError("Failed too many times") from problem + + def _run_command(*args, **kwargs): return subprocess.check_call(args, **kwargs) -def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10): +def run_docker_on_pod( + image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10 +): env = _massage_env(env) docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name) @@ -210,9 +387,18 @@ def run_docker(): logger.exception("Failed to run docker command") raise e - run_on_pod_resumable( - ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000 - ) + if num_slices == 1: + run_on_pod_resumable( + ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000 + ) + else: + run_on_pod_multislice_resumable( + ray.remote(run_docker), + tpu_type=tpu_type, + num_slices=num_slices, + max_retries_failure=retries, + max_retries_preemption=10000, + ) def _kill_old_container(name): @@ -351,6 +537,7 @@ class RunDockerOnPodConfig: env: dict = dataclasses.field(default_factory=dict) name: str = "levanter" retries: int = 10 + node_count: int = 1 def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None): @@ -419,6 +606,8 @@ def main(args: RunDockerOnPodConfig): tpu_type=args.tpu_type, env=args.env, name=args.name, + retries=args.retries, + num_slices=args.node_count, ) From 8e60ba9d66cb3f4c1e8ceeff84d2dae97ff9de48 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 09:53:22 -0700 Subject: [PATCH 02/19] almost there --- src/levanter/store/cache.py | 846 +++++++++++++++++-------------- src/levanter/store/tree_store.py | 3 + tests/test_new_cache.py | 116 +---- 3 files changed, 462 insertions(+), 503 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 45265c994..62752fc55 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -3,6 +3,7 @@ import copy import dataclasses import logging as pylogging +import operator import os import pprint import random @@ -12,24 +13,26 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core import humanfriendly import jax +import numpy as np import pyarrow as pa import ray +import tensorstore as ts from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from jaxtyping import PyTree from ray.actor import ActorHandle +from tqdm_loggable.auto import tqdm +from levanter.data import batched from levanter.data.dataset import AsyncDataset -from levanter.store._prefetch_actor import QueueEmpty, RayPrefetchQueue -from levanter.utils.py_utils import Stopwatch -from ..data._preprocessor import BatchProcessor, BatchProcessorPool, BatchResult, dict_from_record_batch +from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource from ..utils.ray_utils import ( @@ -40,8 +43,7 @@ log_failures_to, ser_exc_info, ) -from ..utils.thread_utils import ExceptionTrackingThread -from .jagged_array import PreparedBatch +from .jagged_array import JaggedArrayStore, PreparedBatch from .tree_store import TreeStore @@ -69,20 +71,13 @@ class CacheOptions: """ num_shard_groups: Optional[int] = 128 - """Number of groups to divide the shards into. This is used to parallelize the cache building process without - overloading Ray. If None, all shards will be in their own group.""" - shard_order_randomization_key: Optional[int] = 0 - """A key used to randomize the order of the shards before building and grouping.""" - batch_size: int = 128 - """The batch size to use when processing the data. This is used to control the memory usage of the cache building - process. Lower values will use less memory but take somewhat longer to build the cache.""" # the below options don't actually impact the cache's result, but do impact construction target_size_per_flush: int | str = "512MB" """The number of bytes to buffer before flushing to disk. This is used to control the memory usage of the cache building process. Lower values will use less memory but could take somewhat longer to build the cache.""" - prefetch_per_group: int = 4 - """The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time""" + + batch_size: int = 128 @property def target_bytes_per_flush(self): @@ -99,14 +94,14 @@ def no_fanciness(batch_size: Optional[int] = None): """ if batch_size is None: batch_size = 128 - return CacheOptions(num_shard_groups=None, shard_order_randomization_key=None, batch_size=batch_size) + return CacheOptions(num_shard_groups=None, batch_size=batch_size) @staticmethod def one_group(): """ For testing, disables all the fancy features of the cache. This makes it easier to predict the behavior """ - return CacheOptions(num_shard_groups=1, shard_order_randomization_key=None, batch_size=128) + return CacheOptions(num_shard_groups=1, batch_size=128) def build_or_load_cache( @@ -116,7 +111,6 @@ def build_or_load_cache( await_finished: bool = True, monitors: Optional[Sequence["MetricsMonitor"]] = None, options: CacheOptions = CacheOptions.default(), - force_flush: bool = False, split: str = "test", ) -> "TreeCache[U]": """ @@ -144,8 +138,6 @@ def build_or_load_cache( options: Configuration for the cache. This is used to configure a few parts of the cache creation process - force_flush: for testing, forces the cache to flush after every batch. This is useful for testing. - Returns: (TreeCache) A TreeCache object that can be used to read the cache. @@ -156,7 +148,6 @@ def build_or_load_cache( shard_source=input_shards, processor=processor, options=options, - force_flush=force_flush, split=split, ) @@ -320,12 +311,11 @@ def build_or_load( shard_source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: Optional["CacheOptions"] = None, - force_flush: bool = False, split: str = "test", ) -> "TreeCache[U]": if options is None: options = CacheOptions.default() - metadata = CacheMetadata(options=options, preprocessor_metadata=processor.metadata) + metadata = CacheMetadata(preprocessor_metadata=processor.metadata) try: return TreeCache.load(cache_dir, processor.output_exemplar, metadata) except FileNotFoundError: @@ -334,7 +324,6 @@ def build_or_load( shard_source=shard_source, processor=processor, options=options, - force_flush=force_flush, split=split, ) return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) @@ -489,13 +478,11 @@ class CacheLedger: is_finished: bool = False finished_shards: List[str] = dataclasses.field(default_factory=list) field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) - metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata(CacheOptions(), {})) + metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata({})) @staticmethod - def load_or_initialize( - cache_dir: str, source: ShardedDataSource, processor: BatchProcessor, config: "CacheOptions" - ): - metadata = CacheMetadata(options=config, preprocessor_metadata=processor.metadata) + def load_or_initialize(cache_dir: str, source: ShardedDataSource, processor: BatchProcessor): + metadata = CacheMetadata(preprocessor_metadata=processor.metadata) try: return CacheLedger.load(cache_dir, metadata) except FileNotFoundError: @@ -531,7 +518,6 @@ def _serialize_and_commit(self, cache_dir): @dataclass_json @dataclass(frozen=True) class CacheMetadata: - options: CacheOptions = CacheOptions.default() preprocessor_metadata: Optional[dict[str, Any]] = None def compare_to(self, other: "CacheMetadata") -> deepdiff.DeepDiff: @@ -711,11 +697,10 @@ def _serialize_json_and_commit(path, obj): fs.copy(path, f"{path}.bak") for i in range(10): - with fsspec.open(f"{path}.tmp", "w") as file: - file.write(obj.to_json()) try: - fs.rename(f"{path}.tmp", path) + with fsspec.open(path, "w") as file: + file.write(obj.to_json()) break except FileNotFoundError: # this happens for some reason sometimes. It makes no sense. @@ -740,7 +725,6 @@ def __init__( source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: CacheOptions, - force_flush: bool, ): pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") @@ -751,7 +735,7 @@ def __init__( self._options = options self._updated_ledger_condition = asyncio.Condition() # used to subscribe to metrics updates - self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor, options) + self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor) if self._ledger.is_finished: self._finished_promise.set_result(None) @@ -770,7 +754,7 @@ def __init__( # (we get twice from we need to concatenate prepared batches into the accumulator) # TODO: measure. memory=2 * self._options.target_bytes_per_flush, - ).remote(current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush) + ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -827,6 +811,9 @@ def _writer_exception(self, shard_name, exc_info: ExceptionInfo): pass self._do_notify() + def _child_failed(self, child: ray.actor.ActorHandle | str | None, exception: ExceptionInfo): + self._writer_exception(str(child), exception) + def _notify_updated_ledger(self, ledger: CacheLedger): """ Called by the cache writer when it has updated the ledger. @@ -855,7 +842,7 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) -def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False): +def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): name = f"lev_cache_manager::{split}::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" @@ -867,27 +854,13 @@ def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheO source=shard_source, processor=processor, options=options, - force_flush=force_flush, ) ##### # Core implementation starts below. ##### -# The main idea is to have a bunch of reader tasks that read batches, dispatch tokenization tasks, producing -# a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache. -# The reader tasks are given a group of shards, which are implicitly concatenated together. - - -@dataclass -class _Batch: - """ - A batch of data that has either been read or tokenized. - """ - - shard_name: str - row_indices: List[int] - payload: ray.ObjectRef +# The main idea is to tokenize each shard group in parallel, and then write the results to the cache in order. @dataclass @@ -898,14 +871,7 @@ class _ShardFinished: shard_name: str total_rows: int - - -_Message = _Batch | _ShardFinished -""" -A message that can be sent from a reader task to the writer task. -""" - -_TIME_BETWEEN_WRITES = 20.0 # seconds + path_to_shard: str @ray.remote(num_cpus=1) @@ -914,17 +880,15 @@ def _core_writer_task( cache_dir, initial_ledger: CacheLedger, source: ShardedDataSource, + options: CacheOptions, processor, - force_flush: bool, ): """ This is the main task that processes the data and writes it to the cache. - It chains together: - * 1 generator per shard group - * interleaving of the generators - * processing of the batches - * writing of the batches to the cache + It receives "finished shards" messages from the reader tasks, and copies the data from temporary files + to the cache directory. + """ pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) logger.info("Starting writer task") @@ -933,400 +897,489 @@ def _core_writer_task( # append a small random number to the name to avoid collisions name += f"::{random.randint(0, 1000)}" - with log_failures_to(parent): - - def on_write(ledger): - ray.get(parent._notify_updated_ledger.remote(ledger)) + # We want to make sure it's there + initial_ledger._serialize_and_commit(cache_dir) - sharded_cache_writer = ShardedCacheWriter( - cache_dir, initial_ledger, processor.output_exemplar, on_write=on_write - ) + # we want to do the following: + # 1. write the 0th shard group to the output cache directly, updating metrics as we go + # 2. in the background, start processing other shard groups to temporary caches + # 3. once (1) is done, we start copying the temporary caches to the output cache (in order) + # for now we're going to punt on (1) + with log_failures_to(parent): + temporary_cache_path = os.path.join(cache_dir, "___temp") - options = initial_ledger.metadata.options - num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names)) + paths: dict[str, str] = {} + ledgers: dict[str, CacheLedger | None] = {} + already_finished_paths: list[str] = [] + refs: dict[str, ray.ObjectRef] = {} - processor_pool = _mk_processor_pool(processor, 0, num_groups * 4) + shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) - interleave: RayPrefetchQueue = RayPrefetchQueue( - lambda: _make_interleave(name, source, initial_ledger, processor_pool), - 64, - producer_options={"num_cpus": 1, "name": f"{name}::interleave"}, + logger.info( + f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) - total_time = Stopwatch() - loading_time = Stopwatch() - append_time = Stopwatch() - flush_time = Stopwatch() - flush_amortized_time = Stopwatch() - - current_prepared_batch: Optional[PyTree[PreparedBatch]] = None - current_shard_rows: dict[str, int] = {} - time_of_last_write = time.time() - batches_total = 0.0 - flush_thread = None - finished_shards_last_flush: list = [] - - while True: - with total_time: # 0.0051 - try: - cur_time = time.time() - time_since_last_write = cur_time - time_of_last_write - remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write - - if current_prepared_batch is not None: - with flush_amortized_time: # 6e-4 - current_byte_size = sum( - b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0] - ) - should_flush = ( - force_flush - or remaining_time <= 0 - or (current_byte_size >= options.target_bytes_per_flush) - ) - if should_flush: - with flush_time: # 0.613s - if flush_thread is not None: - flush_thread.join() - - flush_thread = ExceptionTrackingThread( - target=_write_batches, - args=( - sharded_cache_writer, - current_shard_rows, - current_prepared_batch, - finished_shards_last_flush, - ), - ) - flush_thread.start() - - current_prepared_batch = None - current_shard_rows = {} - finished_shards_last_flush = [] - - time_of_last_write = time.time() - continue - else: - remaining_time = _TIME_BETWEEN_WRITES - - with loading_time: - try: - message = interleave.get_next(timeout=max(remaining_time, 0.1)) - except QueueEmpty: - logger.info("Writer running ahead of reader.") - continue - - with append_time: - match message: - case _Batch(shard, row_indices, payload): - batches_total += 1 - this_prepared_batch = ray.get(payload) - if current_prepared_batch is None: - # TODO: actually check row indices - current_shard_rows = {shard: len(row_indices)} - current_prepared_batch = this_prepared_batch - else: - current_shard_rows[shard] = current_shard_rows.get(shard, 0) + len(row_indices) - current_prepared_batch = _concat_prepared_batches( - current_prepared_batch, this_prepared_batch - ) - del this_prepared_batch - - if force_flush: - _write_batches( - sharded_cache_writer, - current_shard_rows, - current_prepared_batch, - finished_shards_last_flush, - ) - finished_shards_last_flush = [] - current_prepared_batch = None - current_shard_rows = {} - - case _ShardFinished(shard, total_rows): - finished_shards_last_flush.append((shard, total_rows)) - case _: - raise AssertionError(f"Unexpected message type {type(message)}") - - # if batches_total % 1000 == 0: - # print( - # f"Processed {batches_total} batches: {loading_time.average()}s load," - # f" {append_time.average()}s append, {flush_time.average()}s flush blocked, " - # f"{flush_amortized_time.average()}s amortized flush, " - # f"{total_time.average()}s total" - # ) - except StopIteration: - logger.info("Finished all shards") - break - except Exception as e: - logger.exception("Error while processing batch") - raise e - - # force a flush - if current_prepared_batch is not None or finished_shards_last_flush: - if flush_thread is not None: - flush_thread.join() - _write_batches( - sharded_cache_writer, current_shard_rows, current_prepared_batch, finished_shards_last_flush + unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" + pbar = tqdm(total=len(shard_groups), desc="Tokenizing", unit=unit) + + processor_ref = ray.put(processor) + source_ref = ray.put(source) + + for group_name, shards in shard_groups.items(): + path = os.path.join(temporary_cache_path, group_name) + paths[group_name] = path + + ledger = _try_load(path) + ledgers[group_name] = ledger + + if ledger is not None: + already_finished_paths.append(path) + pbar.update(1) + continue + + ref = ( + ray.remote(_tokenize_one_shard_group) + .options( # type: ignore + num_cpus=processor.num_cpus, + num_gpus=processor.num_gpus, + resources=processor.resources, + memory=3 * 1024 * 1024 * 1024, # made this up + name=f"tokenize::{temporary_cache_path}::{group_name}", + retry_exceptions=True, + max_retries=10, + ) + .remote(os.path.join(temporary_cache_path, group_name), source_ref, shards, processor_ref, options) ) - sharded_cache_writer.finish() + refs[group_name] = ref + + # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) + # This logic is a bit hairy thanks to resumes. + # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these + # separately. We also need to update the ledger as we go. + # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. + # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. + # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. + + # * When we load the permanent cache, we have already written some number of groups to it. + # (We check this invariant with an assert) + # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. + # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" + # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate + # the data offsets for where the group goes in the permanent cache. This is just a running sum of the + # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree + # of integers, one for each array. + # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache + # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. + # * We also need to update the ledger with the total number of rows + permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) + # initialize the data offset tree + data_offset_tree = jax.tree_map(lambda x: 0, permanent_cache.tree) + total_rows_from_caches = 0 + + copy_refs: dict[str, ray.ObjectRef] = {} + last_ref: ray.ObjectRef | None = None + + for group in shard_groups: + # first make sure it's either done this run or already done + if refs.get(group) is not None: + this_ledger = ray.get(refs[group]) + ledgers[group] = ledger + else: + this_ledger = ledgers[group] + + assert this_ledger is not None + # see if we already copied this group, meaning all the shards are in the permanent cache + shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) + if shards_copied == len(shard_groups[group]): + assert initial_ledger.total_num_rows >= total_rows_from_caches + elif shards_copied > 0: + # In theory we can handle this, but it's a bit tricky, so we're going to punt for now + raise RuntimeError("Some shards were copied but not all. This should never happen.") + else: + # we need to copy this group + ref_to_send = None if last_ref is None else RefBox(last_ref) + last_ref = _copy_cache.remote( + cache_dir, + paths[group], + processor_ref, + data_offset_tree, + ref_to_send, + total_rows_from_caches, + parent, + ) + copy_refs[group] = last_ref - out = sharded_cache_writer.get_ledger() - return out + # update the data offset tree + this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) + data_offset_tree = jax.tree.map( + operator.add, data_offset_tree, jax.tree_map(lambda x: x.data_size, this_cache.tree) + ) + total_rows_from_caches += this_ledger.total_num_rows + if last_ref is not None: + ledger = ray.get(last_ref) + else: + ledger = initial_ledger -def _concat_prepared_batches( - current_prepared_batch: PyTree[PreparedBatch], this_prepared_batch: PyTree[PreparedBatch] -): - return jax.tree.map(lambda *bs: PreparedBatch.concat(bs), current_prepared_batch, this_prepared_batch) + ledger.is_finished = True + parent._notify_updated_ledger.remote(ledger) -def _write_batches(writer: ShardedCacheWriter, shard_totals, batch: Optional[PyTree[PreparedBatch]], finished_shards): - # concatenate the payloads - if batch is not None: - writer.write_prepared_batch(shard_totals, batch) +def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: + if num_groups is None or num_groups >= len(source.shard_names): + return {shard_name: [shard_name] for shard_name in source.shard_names} - for shard, total_rows in finished_shards: - writer.finish_shard(shard, total_rows) + shard_names = source.shard_names + num_shards_per_group = len(shard_names) // num_groups + # if we have a remainder, we'll just add it to the last group + out_groups = { + f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) + for i in range(num_groups) + } + if len(shard_names) % num_shards_per_group != 0: + out_groups[f"group_{num_groups - 1}"].extend(shard_names[num_groups * num_shards_per_group :]) + return out_groups # type: ignore -def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]: - shards_for_batches, payloads_for_batches = zip(*batches) - payloads_for_batches = ray.get(list(payloads_for_batches)) - shard_row_totals: dict[str, int] = {} - for shard, payload in zip(shards_for_batches, payloads_for_batches): - shard_row_totals[shard] = shard_row_totals.get(shard, 0) + jax.tree.leaves(payload)[0].num_rows +def _merge_ledgers(dest, source): + dest.total_num_rows += source.total_num_rows + for shard, rows in source.shard_rows.items(): + current_value = dest.shard_rows.get(shard, 0) + assert current_value == 0, f"Shard {shard} already has {current_value} rows" + dest.shard_rows[shard] = rows - return shard_row_totals, payloads_for_batches + dest.finished_shards.extend(source.finished_shards) + dest.field_counts.update(source.field_counts) + return dest -def _interleave_shards(readers: Sequence[RayPrefetchQueue], first_index: int) -> Iterator[T]: # _Message +@ray.remote(num_cpus=4, memory=4 * 1024 * 1024 * 1024) +def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: RefBox, rows_so_far, parent): """ - Interleaves the results of multiple iterators. To support resume, - we need to be able to start from not the "first" iterator. + Copies the data from one cache to another, appending it to the end of the destination cache. + Once the copy is done and the last_ref is set, the data is "unlocked" in the destination cache by updating the + offsets[0] of the destination cache to the total number of rows in the cache. Args: - readers: A list of iterators - first_index: The index of the first iterator to start from. We use this to support resuming. + dest_path: The path to the destination cache. + source_path: The path to the source cache. + processor: The processor used to create the cache. + data_offset_tree: The data offset tree for the destination cache. + last_ref: The ref to wait on before updating the ledger. + rows_so_far: The total number of rows in the destination cache before this copy. + + Returns: + + """ + with log_failures_to(parent): + asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far)) + print("done copying", flush=True) + if last_ref is not None: + ray.wait([last_ref.ref], fetch_local=False) + print("done waiting", flush=True) + permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + dest_ledger = CacheLedger.load(dest_path) + source_ledger = CacheLedger.load(source_path) + + new_num_rows = source_ledger.total_num_rows + rows_so_far + + futures = jax.tree.leaves(jax.tree.map(lambda x: x.offsets[0].write(new_num_rows), permanent_cache.tree)) + for future in futures: + future.result() + + print("wrote rows", flush=True) + + _merge_ledgers(dest_ledger, source_ledger) + dest_ledger._serialize_and_commit(dest_path) + assert not dest_ledger.is_finished + parent._notify_updated_ledger.remote(dest_ledger) + print("done", flush=True) + return dest_ledger + + +async def _extend_cache_with_other_cache( + dest_path: str, source_path: str, processor: BatchProcessor, data_offset_tree: PyTree[int], row_offset +) -> int: """ + Copies the data from one cache to another, appending it to the end of the destination cache. - finished: set[int] = set() - total = 0 - while len(finished) < len(readers): - for i in range(first_index, len(readers)): - reader = readers[i] - if i not in finished: - try: - message = reader.get_next() - total += 1 - yield message - except StopIteration: - finished.add(i) - except Exception as e: - logger.exception(f"Error while processing group {i}") - raise e + Returns: + The number of rows in the source cache. + """ + logger.info(f"Copying data from {source_path} to {dest_path}.") + dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True) + + source_num_rows = await source.async_len() + + async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): + """Copies **just the data array** from one shard to the permanent cache at a given offset.""" + # TODO: it'd be good if we just didn't expose the full data array (but only the used part) + data_size = source_array.data_size + data = source_array.data[0:data_size] + print(f"starting to write data. {data.read().result()=}", flush=True) + print(f"{row_offset=}", flush=True) + futures: list[ts.Future] = [] + + # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) + async with ts.Transaction() as txn: + dest = dest_array.data + out_end = data_offset + data_size + write_future = dest.with_transaction(txn)[data_offset:out_end].write(data) + futures.append(write_future) + + if source_array.shapes is not None: + source_shapes = source_array.shapes[0:source_num_rows] + async with ts.Transaction() as txn: + dest = dest_array.shapes + out_end = row_offset + source_num_rows + shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes) + futures.append(shape_future) + print("done writing shapes", flush=True) + + source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]] + source_offsets = _virtual_offset(source_offsets, data_offset) + + async with ts.Transaction() as txn: + dest = dest_array.offsets + out_end = row_offset + 1 + source_num_rows + offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets) + + print("hi", flush=True) + print(f"done writing offsets {source_offsets.domain}", flush=True) + print(f"done writing offsets {dest[row_offset+1:out_end].read().result()}", flush=True) + + futures.append(offset_future) + + out = await asyncio.gather(*futures) + print("done writing", flush=True) + return out + + futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) - first_index = 0 + await asyncio.gather(*jax.tree.leaves(futures)) + logger.info(f"Finished copying data from {source_path} to {dest_path}.") - logger.info(f"Finished all shards, got {total} batches") + return source_num_rows -def _assign_shards_to_groups(shards: Sequence[_ShardStatus], num_groups: int) -> list["_ShardGroup"]: +def _virtual_offset(base: ts.TensorStore, offset_amount): """ - Assigns shards to groups in a round-robin fashion. + This function creates a new tensorstore that is a virtual offset of another tensorstore. + That is, it's y[i] = x[i] + offset_amount. """ - groups: list[list] = [[] for _ in range(num_groups)] - for i, shard in enumerate(shards): - groups[i % num_groups].append(shard) - return [_ShardGroup(group) for group in groups] + async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.VirtualChunkedReadParameters): + array[...] = (await base[domain].read()) + offset_amount -def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: - prng = random.Random(seed) - shuffled = list(shards) - prng.shuffle(shuffled) - return shuffled + return ts.virtual_chunked(do_read, dtype=base.dtype, domain=base.domain, shape=base.shape) -class _ShardGroup: - """ - Given a group of shards and a list of statuses, implicitly concatenates the shards and reads from them. +async def _copy_data_from_one_shard_to_permanent_memory( + dest_path: str, + source_path: str, + processor: BatchProcessor, + data_offset_tree: PyTree[int], +): + """Copies from one tree store to the permanent cache at a given offset (for each leaf)""" + logger.info(f"Copying data from {source_path} to {dest_path}.") + dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True) - This class mostly exists for resuming: we want to be able to start from the last shard we were working on. - """ + def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): + # TODO: it'd be good if we just didn't expose the full data array (but only the used part) + data = source_array.data[0 : source_array.data_size] + # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) + with ts.Transaction() as txn: + dest = dest_array.data + out_end = data_offset + source_array.data_size + write_future = dest.with_transaction(txn)[data_offset:out_end].write(data) - def __init__(self, group: list[_ShardStatus]): - self.shards = group - self.total_rows_committed, _all_finished = self._impute_total_rows_committed_and_check_invariants() - - def _impute_total_rows_committed_and_check_invariants(self): - # we also want to ensure that we haven't started any shards until we've finished the previous ones - total_committed = 0 - last_shard_name = None - last_was_finished = True - all_finished = True - - for status in self.shards: - shard_name = status.shard_name - if not last_was_finished and status.num_rows_committed > 0: - raise ValueError( - f"Shard {shard_name} has rows committed but previous shard in group {last_shard_name} " - "is not finished. Something about the cache configuration has changed: either the " - "number/order of shards, the shard shuffle random seed, or the number of groups." - ) - total_committed += status.num_rows_committed - if not status.is_finished: - all_finished = False - last_was_finished = status.is_finished - last_shard_name = shard_name + return write_future - return total_committed, all_finished + futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) + await asyncio.gather(*jax.tree.leaves(futures)) + logger.info(f"Finished copying data from {source_path} to {dest_path}.") + return -def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor_pool: ActorHandle): - """ - Given a list of ShardStatus objects and sources, creates an interleaving generator - that reads from shards and tokenizes them in parallel. - We use ShardStatus objects to track the progress of each shard. If we're preempted, we can resume - from the last shard we were working on. This function starts each shard at the last committed row - and starts interleaving from the next shard (i.e. the one with the fewest rows that isn't finished). - """ - logger.setLevel(DEFAULT_LOG_LEVEL) - statuses = _get_shard_statuses(initial_ledger, source) +def _tokenize_one_shard_group( + temporary_cache_path: str, + source: ShardedDataSource, + shards: list[str], + processor: BatchProcessor, + options: CacheOptions, +) -> CacheLedger: + # ray breaks if this is top level + import humanfriendly - options = initial_ledger.metadata.options + logger = pylogging.getLogger("tokenize") + pylogging.basicConfig(level=pylogging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - unfinished_shards = _check_current_shard_progress(statuses) + # restrict shards to the ones we're supposed to process + # this is a bit hacky but when there are a lot of shards (e.g. SlimPajama 122K), + # we encounter significant overhead just parsing the shard names from the json + source = _RestrictedShardedDataSource(source, shards) - if not unfinished_shards: - logger.info("All shards finished. Nothing to do.") - return + ledger = CacheLedger.load_or_initialize(temporary_cache_path, source, processor) - group_names, groups = _randomize_and_group_shards(name, options, statuses) + if ledger.is_finished: + logger.info("Shard group already processed.") + return ledger - logger.warning(f"Starting cache build with {len(statuses)} shards, in {len(groups)} groups") + writer = ShardGroupCacheWriter(temporary_cache_path, ledger, shards, processor.output_exemplar) - def _make_generator_fn(group: _ShardGroup): - def generator(): - pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) - for message in _shard_reader_generator(source, group, options.batch_size): - match message: - case _Batch(): - # processed = ray.put(process_task(ray.get(message.payload))) - # processed = process_task.remote(processor_ref, message.payload) - processed = processor_pool.process_batch.remote(RefBox(message.payload)) - yield dataclasses.replace(message, payload=processed) - case _ShardFinished(): - yield message - case _: - raise AssertionError(f"Unexpected message type {type(message)}") + total_rows = ledger.total_num_rows + found_shard_with_rows = False - return generator + for shard_name in shards: + if shard_name in ledger.finished_shards: + logger.info(f"Shard {shard_name} already processed.") + continue - generator_fns = [_make_generator_fn(group) for group in groups] + logger.debug(f"Processing {shard_name}.") - readers = [ - RayPrefetchQueue( - fn, - options.prefetch_per_group, - producer_options=dict(num_cpus=0.1, name=name, scheduling_strategy="SPREAD"), - ) - for name, fn in zip(group_names, generator_fns) - ] + rows_this_shard = ledger.shard_rows.get(shard_name, 0) - # then figure out the first shard to start from. This is the first unfinished shard with the minimum number of rows - first_group_to_start = min( - range(len(groups)), - key=lambda i: groups[i].total_rows_committed, - ) + if found_shard_with_rows and rows_this_shard != 0: + raise ValueError("Found more than one shard with rows to process.") - yield from _interleave_shards(readers, first_group_to_start) + if rows_this_shard != 0: + found_shard_with_rows = True + shard_iterator = source.open_shard_at_row(shard_name, rows_this_shard) -def _mk_processor_pool(processor, min_size, max_size): - import hashlib + prepared_batch: PyTree[PreparedBatch] | None = None + this_batch_size = 0 - metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest() - processor_pool_name = f"processor_pool::{metadata_hash}" - processor_pool = BatchProcessorPool.options( # type: ignore - name=processor_pool_name, get_if_exists=True, lifetime="detached" - ).remote( # type: ignore - processor, min_size, max_size - ) + for batch in batched(shard_iterator, options.batch_size): + tokenized = processor(batch) + tokenized = _canonicalize_batch(tokenized) # type: ignore + this_prepared = writer._tree_store.batch_preparer(tokenized) + + this_batch_size += len(batch) + rows_this_shard += len(batch) - ray.get(processor_pool.ensure_max_at_least.remote(max_size)) + if prepared_batch is None: + prepared_batch = this_prepared + else: + prepared_batch = jax.tree.map( + lambda *trees: PreparedBatch.concat(trees), prepared_batch, this_prepared + ) - return processor_pool + batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) + if batch_byte_size > options.target_bytes_per_flush: + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + nice_bytes = humanfriendly.format_size(batch_byte_size) + logger.debug( + f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" + ) + this_batch_size = 0 + prepared_batch = None + + if prepared_batch is not None: + batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) + nice_bytes = humanfriendly.format_size(batch_byte_size) + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + logger.debug( + f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" + ) + this_batch_size = 0 + prepared_batch = None -def _check_current_shard_progress(statuses): - unfinished_shards: list[_ShardStatus] = [] - shards_with_progress: dict[str, int] = {} - for status in statuses: - if not status.is_finished: - unfinished_shards.append(status) - if status.num_rows_committed > 0: - shards_with_progress[status.shard_name] = status.num_rows_committed - if unfinished_shards and shards_with_progress: - formatted = ", ".join(f"{k}: {v}" for k, v in shards_with_progress.items()) - logger.info(f"Resuming from shards with progress: {formatted}") - return unfinished_shards + total_rows += rows_this_shard + writer.finish_shard(shard_name, rows_this_shard) -def _randomize_and_group_shards(name, options, statuses): - if options.shard_order_randomization_key is not None: - seed = options.shard_order_randomization_key - logger.info(f"Randomizing shard order with seed {seed}") - statuses = _randomize_shards(statuses, seed) + writer.finish() - num_groups = min( - options.num_shard_groups if options.num_shard_groups is not None else len(statuses), len(statuses) - ) - if num_groups == 1: - group_names = [f"generator::{name}::all_shards"] - elif len(statuses) == num_groups: - group_names = [f"generator::{name}::{status.shard_name}" for status in statuses] - else: - group_names = [f"generator::{name}::group_{i}" for i in range(num_groups)] + logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") - groups = _assign_shards_to_groups(statuses, num_groups) - return group_names, groups + return writer.ledger -def _shard_reader_generator( - shard_source: ShardedDataSource[T], group: _ShardGroup, batch_size: int -) -> Iterator[_Message]: +class ShardGroupCacheWriter: """ - Given a group of shards, implicitly concatenates the shards and reads from them. + Similar to SerialCacheWriter, but tracks shard metadata for one shard. """ - for status in group.shards: - if status.is_finished: - logger.info(f"Skipping finished shard {status.shard_name}") - continue - start_row = status.num_rows_committed - logger.info(f"Opening shard {status.shard_name} at row {start_row}") - shard_iter = shard_source.open_shard_at_row(status.shard_name, start_row) - batch = [] - batch_idxes = [] - row_idx = start_row - for row in shard_iter: - batch.append(row) - batch_idxes.append(row_idx) - row_idx += 1 + def __init__(self, cache_dir: str, initial_ledger: CacheLedger, shards: list[str], exemplar: T): + self.cache_dir = cache_dir - if len(batch) == batch_size: - yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) - batch = [] - batch_idxes = [] + self._ledger = copy.deepcopy(initial_ledger) + self.shards = shards - if len(batch) > 0: - yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore + self._tree_store.trim_to_size(self._ledger.total_num_rows) - logger.info(f"Finished generating shard {status.shard_name} with {row_idx} rows") - yield _ShardFinished(status.shard_name, row_idx) + @property + def ledger(self): + return self._ledger + + # we have both versions b/c we need this one for actors + def get_ledger(self): + return self._ledger + + @property + def is_finished(self): + return self._ledger.is_finished + + def finish_shard(self, shard_name: str, num_rows: int): + if shard_name not in self.shards: + raise ValueError(f"Shard {shard_name} not in tracked shards") + + current_rows = self._ledger.shard_rows.get(shard_name, 0) + if current_rows != num_rows: + raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") + + self._ledger.finished_shards.append(shard_name) + self._ledger._serialize_and_commit(self.cache_dir) + + def write_prepared_batch(self, shard_name: str, row_count: int, batch: PyTree[PreparedBatch]): + if self.is_finished: + raise RuntimeError("Cannot write to a finished cache") + self._tree_store.extend_with_batch(batch) + + if shard_name not in self.shards: + raise ValueError(f"Shard {shard_name} not in tracked shards") + self._ledger.shard_rows[shard_name] += row_count + self._ledger.total_num_rows += row_count + + self._ledger._serialize_and_commit(self.cache_dir) + + def finish(self): + if len(self._ledger.finished_shards) != len(self.shards): + raise ValueError("Not all shards are finished") + + self._ledger.is_finished = True + self._ledger._serialize_and_commit(self.cache_dir) + # ensure all tracked shards are finished + + return self._tree_store + + +class _RestrictedShardedDataSource(ShardedDataSource): + def __init__(self, source: ShardedDataSource, shards: list[str]): + self._source = source + self._shards = shards + + @property + def shard_names(self): + return self._shards + + def open_shard_at_row(self, shard_name, row): + return self._source.open_shard_at_row(shard_name, row) + + +def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: + prng = random.Random(seed) + shuffled = list(shards) + prng.shuffle(shuffled) + return shuffled def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: @@ -1360,8 +1413,13 @@ def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics: ) -def _get_shard_statuses(ledger: CacheLedger, source: ShardedDataSource): - return [ - _ShardStatus(name, ledger.shard_rows.get(name, 0), name in ledger.finished_shards) - for name in source.shard_names - ] +def _try_load(path): + try: + ledger = CacheLedger.load(path) + if ledger.is_finished: + return ledger + else: + logger.debug(f"Cache exists but is not finished at {path}.") + return None + except FileNotFoundError: + return None diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py index 03355a8d2..83d6c88b0 100644 --- a/src/levanter/store/tree_store.py +++ b/src/levanter/store/tree_store.py @@ -172,6 +172,9 @@ def get_batch_sync(self, indices) -> List[T]: return out + async def async_len(self) -> int: + return await jax.tree.leaves(self.tree)[0].num_rows_async() + def _construct_builder_tree(exemplar, path, mode, cache_metadata): def open_builder(tree_path, item): diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index c1eb73670..82bf045c7 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -1,6 +1,4 @@ import asyncio -import copy -import os import tempfile from typing import Any, Dict, Iterator, Sequence @@ -10,17 +8,7 @@ from levanter.data import BatchProcessor, ShardedDataSource, batched from levanter.data.sharded_datasource import TextUrlDataSource -from levanter.store.cache import ( - LEDGER_FILE_NAME, - CacheLedger, - CacheOptions, - SerialCacheWriter, - ShardedCacheWriter, - TreeStore, - _get_builder_actor, - _serialize_json_and_commit, - build_or_load_cache, -) +from levanter.store.cache import CacheOptions, SerialCacheWriter, TreeStore, _get_builder_actor, build_or_load_cache from levanter.utils.py_utils import logical_cpu_core_count @@ -146,7 +134,7 @@ def test_full_end_to_end_cache(): options=CacheOptions.no_fanciness(8), ) - expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=2), 8) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=2)) all_data = ray_ds[:] @@ -162,15 +150,14 @@ def test_full_end_to_end_cache_with_groups(): SimpleShardSource(num_shards=5), TestProcessor(), await_finished=True, - options=CacheOptions(num_shard_groups=2, batch_size=8, shard_order_randomization_key=None), + options=CacheOptions(num_shard_groups=2, batch_size=8), ) - expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=5), 8) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=5)) all_data = ray_ds[:] - # check_datasets_equal(all_data, expected) - assert len(all_data) == len(list(expected)) + check_datasets_equal(all_data, expected) @pytest.mark.ray @@ -295,7 +282,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # now block until the cache is done cache.await_finished(timeout=30) - expected = process_interleave(processor, SlowShardSource(), 16) + expected = simple_process(processor, SlowShardSource()) check_datasets_equal(list(cache[:]), expected) @@ -334,9 +321,8 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: SlowShardSource(), TestProcessor(), await_finished=False, - force_flush=True, options=CacheOptions.no_fanciness(5), - ) # we need force_flush to ensure the cache is written to disk + ) # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] @@ -364,7 +350,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: cache.await_finished(timeout=10) -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") @pytest.mark.ray def test_shard_cache_crashes_if_processor_throws(): class ThrowingProcessor(SimpleProcessor): @@ -398,7 +383,6 @@ def test_shard_cache_fails_with_multiple_shards_with_the_same_name(): build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") @pytest.mark.ray @pytest.mark.asyncio async def test_shard_cache_fails_gracefully_with_unknown_file_type_async(): @@ -451,89 +435,3 @@ def test_shard_cache_fails_gracefully_with_unknown_file_type(): cache.await_finished(timeout=10) del cache - - -def test_sharded_cache_writer(): - with tempfile.TemporaryDirectory() as tmpdir: - source = SimpleShardSource(num_shards=4) - processor = SimpleProcessor() - ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(8)) - - exemplar = {"data": np.array([0], dtype=np.int64)} - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size): - writer.write_batch(shard_name, processor(ex)) - - for shard_name in source.shard_names: - writer.finish_shard(shard_name, source._rows_per_shard) - - store = writer.finish() - - data_path = store.path - - del store - - builder = TreeStore.open(exemplar, data_path, mode="r") - - assert len(builder) == 40 - - for i, x in enumerate(builder): - np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) - - # check totals for the ledger - ledger = writer.ledger - assert ledger.total_num_rows == 40 - assert ledger.is_finished - - for shard_name in source.shard_names: - assert ledger.shard_rows[shard_name] == 10 - - -def test_sharded_cache_writer_trims_on_resume(): - with tempfile.TemporaryDirectory() as tmpdir: - source = SimpleShardSource(num_shards=4) - processor = SimpleProcessor() - - exemplar = {"data": np.array([0], dtype=np.int64)} - - ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(batch_size=8)) - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), 8): - writer.write_batch(shard_name, processor(ex)) - - for shard_name in source.shard_names: - writer.finish_shard(shard_name, 10) - - writer.finish() - - # now deliberately truncate the ledger a bit - ledger = copy.deepcopy(writer.ledger) - assert ledger.total_num_rows == 40 - assert ledger.is_finished - ledger.total_num_rows = 24 - ledger.shard_rows["shard_0"] = 8 - ledger.shard_rows["shard_1"] = 8 - ledger.shard_rows["shard_2"] = 8 - ledger.shard_rows["shard_3"] = 0 - ledger.is_finished = False - - _serialize_json_and_commit(os.path.join(tmpdir, LEDGER_FILE_NAME), ledger) - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - - # ensure it got truncated - assert writer.ledger.total_num_rows == 24 - assert writer.ledger.is_finished is False - assert writer.ledger.shard_rows["shard_0"] == 8 - assert writer.ledger.shard_rows["shard_1"] == 8 - assert writer.ledger.shard_rows["shard_2"] == 8 - assert writer.ledger.shard_rows["shard_3"] == 0 - - new_store = writer._tree_store - new_data = new_store[:] - - assert len(new_data) == 24 From f742ba75d4216097d2931d2cdc44a169539220ff Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 09:54:53 -0700 Subject: [PATCH 03/19] crash the test for now --- tests/test_new_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index 82bf045c7..c61c66105 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -326,7 +326,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] - first_10 = list(await cache.get_batch(range(0, 10))) + first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=10.0)) for i, x in enumerate(first_10): np.testing.assert_array_equal(x["test"], np.array([i] * 10)) @@ -339,7 +339,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # now ensure we can get the next 10 elements, which will be # [{"test": np.array([i] * 10)} for i in range(10, 20)] - batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10) + batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10.0) for i, x in enumerate(batch): np.testing.assert_array_equal(x["test"], np.array([i + 10] * 10)) From e9c03a0e9f1516d7e707e86cbe367e497d7e8f23 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 11:56:30 -0700 Subject: [PATCH 04/19] cleanup temp filesqq --- src/levanter/data/text.py | 6 +- src/levanter/store/cache.py | 119 +++++------------------------ src/levanter/utils/fsspec_utils.py | 28 ++++++- 3 files changed, 49 insertions(+), 104 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 70c1fe4b3..de6980430 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -35,7 +35,7 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore -from levanter.utils.fsspec_utils import fsspec_expand_glob +from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -508,7 +508,7 @@ def urls_for_split(self, split): else: raise ValueError(f"Unknown split {split}") - urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)] + urls = [globbed for url in urls for globbed in expand_glob(url)] return urls @@ -625,7 +625,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): import levanter.data - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + validation_urls = [url for url_pat in config.validation_urls for url in expand_glob(url_pat)] dataset = levanter.data.datasource_from_jsonl(validation_urls) input_field = config.input_field diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 62752fc55..4e8244882 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -13,7 +13,7 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core @@ -27,6 +27,7 @@ from fsspec import AbstractFileSystem from jaxtyping import PyTree from ray.actor import ActorHandle +from ray.runtime_env import RuntimeEnv from tqdm_loggable.auto import tqdm from levanter.data import batched @@ -35,6 +36,7 @@ from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource +from ..utils.fsspec_utils import async_remove from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -538,13 +540,6 @@ def empty(): return CacheMetadata() -@dataclass -class _ShardStatus: - shard_name: str - num_rows_committed: int - is_finished: bool - - class SerialCacheWriter(AbstractContextManager): """ Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray. @@ -602,91 +597,6 @@ def write_batch(self, batch: BatchResult): self._tree_store.extend(cbatch) -class ShardedCacheWriter: - """ - Similar to SerialCacheWriter, but tracks shard metadata. - - Similar to _OrderedCacheWriter, it also supports resuming, and it - groups together batches before writing (at some interval) in order to improve performance. - """ - - def __init__( - self, - cache_dir: str, - initial_ledger: CacheLedger, - exemplar: T, - on_write: Optional[Callable[[CacheLedger], None]] = None, - ): - self.cache_dir = cache_dir - self._on_write = on_write - - self._ledger = copy.deepcopy(initial_ledger) - - self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore - self._tree_store.trim_to_size(self._ledger.total_num_rows) - - @property - def ledger(self): - return self._ledger - - # we have both versions b/c we need this one for actors - def get_ledger(self): - return self._ledger - - @property - def is_finished(self): - return self._ledger.is_finished - - def finish_shard(self, shard_name: str, num_rows: int): - current_rows = self._ledger.shard_rows.get(shard_name, 0) - if current_rows != num_rows: - raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") - - self._ledger.finished_shards.append(shard_name) - self._ledger._serialize_and_commit(self.cache_dir) - - def write_prepared_batch(self, shard_counts: Mapping[str, int], batch: PyTree[PreparedBatch]): - if self.is_finished: - raise RuntimeError("Cannot write to a finished cache") - self._tree_store.extend_with_batch(batch) - - for shard, num_rows in shard_counts.items(): - self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows - - total_rows = self._ledger.total_num_rows + sum(shard_counts.values()) - self._ledger.total_num_rows = total_rows - self._ledger._serialize_and_commit(self.cache_dir) - - if self._on_write: - self._on_write(self._ledger) - - def write_batch(self, shard_name: str, batch: BatchResult): - if self.is_finished: - raise RuntimeError("Cannot write to a finished cache") - - if isinstance(batch, pa.RecordBatch): - raise NotImplementedError("Only non-RecordBatch batches are supported for now") - - batch = _canonicalize_batch(batch) # type: ignore - prepared = self._tree_store.batch_preparer(batch) - - return self.write_prepared_batch({shard_name: len(batch)}, prepared) - - def finish(self): - # if successful, write the ledger - logger.info("Finished writing cache") - # check that all shards are finished - if set(self._ledger.shard_rows.keys()) != set(self._ledger.finished_shards): - raise ValueError("Not all shards are finished") - - self._ledger.is_finished = True - self._ledger._serialize_and_commit(self.cache_dir) - if self._on_write: - self._on_write(self._ledger) - - return self._tree_store - - def _serialize_json_and_commit(path, obj): # just to be paranoid, we write to a temp file and then rename it # TODO: probably we could do better here @@ -709,7 +619,9 @@ def _serialize_json_and_commit(path, obj): pass -@ray.remote(num_cpus=0.1) # keep this small b/c it doesn't do a lot +@ray.remote( + num_cpus=0.1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}) +) # keep this small b/c it doesn't do a lot class _TreeStoreCacheBuilder(SnitchRecipient): """ Actor that coordinates the building of a cache. It spins up a bunch of workers to read from each shard @@ -1013,7 +925,7 @@ def _core_writer_task( # update the data offset tree this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( - operator.add, data_offset_tree, jax.tree_map(lambda x: x.data_size, this_cache.tree) + operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) ) total_rows_from_caches += this_ledger.total_num_rows @@ -1025,6 +937,16 @@ def _core_writer_task( ledger.is_finished = True parent._notify_updated_ledger.remote(ledger) + # clean up the temporary caches + async def cleanup(): + futures = [] + for path in already_finished_paths: + futures.append(async_remove(path, recursive=True)) + + await asyncio.gather(*futures) + + asyncio.run(cleanup()) + def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: if num_groups is None or num_groups >= len(source.shard_names): @@ -1075,10 +997,8 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R """ with log_failures_to(parent): asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far)) - print("done copying", flush=True) if last_ref is not None: ray.wait([last_ref.ref], fetch_local=False) - print("done waiting", flush=True) permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) dest_ledger = CacheLedger.load(dest_path) source_ledger = CacheLedger.load(source_path) @@ -1089,13 +1009,12 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R for future in futures: future.result() - print("wrote rows", flush=True) - _merge_ledgers(dest_ledger, source_ledger) dest_ledger._serialize_and_commit(dest_path) assert not dest_ledger.is_finished + parent._notify_updated_ledger.remote(dest_ledger) - print("done", flush=True) + return dest_ledger diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index 64870443d..cc03c174b 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,5 +1,10 @@ +import asyncio + import braceexpand import fsspec +from fsspec.asyn import AsyncFileSystem + +from levanter.utils.thread_utils import _executor, blocking_wait def exists(url, **kwargs) -> bool: @@ -14,7 +19,7 @@ def mkdirs(path): fs.makedirs(path, exist_ok=True) -def fsspec_expand_glob(url): +def expand_glob(url): expanded_urls = braceexpand.braceexpand(url) for expanded_url in expanded_urls: if "*" in expanded_url: @@ -28,3 +33,24 @@ def fsspec_expand_glob(url): yield from [f"{protocol}://{path}" for path in globbed] else: yield expanded_url + + +def remove(url, *, recursive=False, **kwargs): + """Remove a file from a remote filesystem.""" + # TODO: better to use a STS deletion policy or job for this one. + fs, path = fsspec.core.url_to_fs(url, **kwargs) + + if isinstance(fs, AsyncFileSystem): + blocking_wait(fs._rm(path, recursive=recursive)) + else: + fs.rm(path, recursive=recursive) + + +async def async_remove(url, *, recursive=False, **kwargs): + """Remove a file from a remote filesystem.""" + fs, path = fsspec.core.url_to_fs(url, **kwargs) + + if isinstance(fs, AsyncFileSystem): + return await fs._rm(path, recursive=recursive) + else: + return await asyncio.wrap_future(_executor.submit(fs.rm, path, recursive=recursive)) From 20cae0c5d5ca5f7fb3d91e4a8c4a82b0539c3d99 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 11:59:26 -0700 Subject: [PATCH 05/19] more cleanup --- src/levanter/store/cache.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 4e8244882..aee2bc3d0 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -1038,8 +1038,6 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra # TODO: it'd be good if we just didn't expose the full data array (but only the used part) data_size = source_array.data_size data = source_array.data[0:data_size] - print(f"starting to write data. {data.read().result()=}", flush=True) - print(f"{row_offset=}", flush=True) futures: list[ts.Future] = [] # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) @@ -1056,7 +1054,6 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra out_end = row_offset + source_num_rows shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes) futures.append(shape_future) - print("done writing shapes", flush=True) source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]] source_offsets = _virtual_offset(source_offsets, data_offset) @@ -1066,14 +1063,9 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra out_end = row_offset + 1 + source_num_rows offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets) - print("hi", flush=True) - print(f"done writing offsets {source_offsets.domain}", flush=True) - print(f"done writing offsets {dest[row_offset+1:out_end].read().result()}", flush=True) - futures.append(offset_future) out = await asyncio.gather(*futures) - print("done writing", flush=True) return out futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) From cb85654799076b3db1b603cbe0c89294613d7c72 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 16:47:58 -0700 Subject: [PATCH 06/19] ok, we are incremental! --- src/levanter/store/cache.py | 137 +++++++++++++++++++++++++++++------- tests/test_new_cache.py | 4 +- 2 files changed, 114 insertions(+), 27 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index aee2bc3d0..f01e3b881 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -13,11 +13,10 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core -import humanfriendly import jax import numpy as np import pyarrow as pa @@ -37,6 +36,7 @@ from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource from ..utils.fsspec_utils import async_remove +from ..utils.fsspec_utils import exists as fsspec_exists from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -83,6 +83,10 @@ class CacheOptions: @property def target_bytes_per_flush(self): + if isinstance(self.target_size_per_flush, int): + return self.target_size_per_flush + import humanfriendly + return humanfriendly.parse_size(self.target_size_per_flush) @staticmethod @@ -667,6 +671,9 @@ def __init__( # TODO: measure. memory=2 * self._options.target_bytes_per_flush, ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) + + self._tokenize_pbar = tqdm(total=len(source.shard_names), desc="Tokenizing", unit="shard") + except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -753,6 +760,19 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) + def _report_progress(self, report: "_ProgressReport"): + import humanfriendly + + self._tokenize_pbar.update(report.total_shards_completed) + mb_str = humanfriendly.format_size(report.total_bytes) + self._tokenize_pbar.set_postfix( + { + "rows": report.total_rows, + "shards": report.total_shards_completed, + "mb": mb_str, + } + ) + def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): name = f"lev_cache_manager::{split}::{cache_dir}" @@ -816,14 +836,22 @@ def _core_writer_task( # 1. write the 0th shard group to the output cache directly, updating metrics as we go # 2. in the background, start processing other shard groups to temporary caches # 3. once (1) is done, we start copying the temporary caches to the output cache (in order) - # for now we're going to punt on (1) + + # We notify the parent actor of progress and updates to the ledger. + # We special-case the 0'th ledger because we commit it to the output cache directly. + def report_fn(report: _ProgressReport, ledger: CacheLedger): + parent._report_progress.remote(report) + + def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): + parent._report_progress.remote(report) + parent._notify_updated_ledger.remote(ledger) + with log_failures_to(parent): temporary_cache_path = os.path.join(cache_dir, "___temp") paths: dict[str, str] = {} ledgers: dict[str, CacheLedger | None] = {} - already_finished_paths: list[str] = [] - refs: dict[str, ray.ObjectRef] = {} + write_refs: dict[str, ray.ObjectRef] = {} shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) @@ -832,23 +860,32 @@ def _core_writer_task( ) unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" - pbar = tqdm(total=len(shard_groups), desc="Tokenizing", unit=unit) processor_ref = ray.put(processor) source_ref = ray.put(source) + # We treat the first group specially: we tokenize it directly to the output cache (since it comes first) + # This enables us to expose data quickly + first_group = next(iter(shard_groups), None) + for group_name, shards in shard_groups.items(): - path = os.path.join(temporary_cache_path, group_name) - paths[group_name] = path + if group_name == first_group: + group_out_path = cache_dir + else: + group_out_path = os.path.join(temporary_cache_path, group_name) - ledger = _try_load(path) + paths[group_name] = group_out_path + + ledger = _try_load(group_out_path) ledgers[group_name] = ledger if ledger is not None: - already_finished_paths.append(path) - pbar.update(1) + if group_name == first_group: + parent._notify_updated_ledger.remote(ledger) continue + report_fn = report_fn_first_group if group_name == first_group else report_fn + ref = ( ray.remote(_tokenize_one_shard_group) .options( # type: ignore @@ -860,10 +897,10 @@ def _core_writer_task( retry_exceptions=True, max_retries=10, ) - .remote(os.path.join(temporary_cache_path, group_name), source_ref, shards, processor_ref, options) + .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn, parent) ) - refs[group_name] = ref + write_refs[group_name] = ref # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) # This logic is a bit hairy thanks to resumes. @@ -891,11 +928,13 @@ def _core_writer_task( copy_refs: dict[str, ray.ObjectRef] = {} last_ref: ray.ObjectRef | None = None + copying_pbar = tqdm(total=len(shard_groups), desc="Copying", unit=unit, leave=False, position=1) for group in shard_groups: # first make sure it's either done this run or already done - if refs.get(group) is not None: - this_ledger = ray.get(refs[group]) + if write_refs.get(group) is not None: + this_ledger = ray.get(write_refs[group]) + ledgers[group] = ledger else: this_ledger = ledgers[group] @@ -903,8 +942,10 @@ def _core_writer_task( assert this_ledger is not None # see if we already copied this group, meaning all the shards are in the permanent cache shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) - if shards_copied == len(shard_groups[group]): + if shards_copied == len(shard_groups[group]) or group == first_group: assert initial_ledger.total_num_rows >= total_rows_from_caches + copying_pbar.update(1) + elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now raise RuntimeError("Some shards were copied but not all. This should never happen.") @@ -922,30 +963,49 @@ def _core_writer_task( ) copy_refs[group] = last_ref - # update the data offset tree + # update the offset information: data offsets and total rows this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) ) total_rows_from_caches += this_ledger.total_num_rows + # this little bit is totally unnecessary but nice logging + for group in shard_groups: + if group == first_group: + continue + + if copy_refs.get(group) is not None: + ray.wait([copy_refs[group]], fetch_local=False) + copying_pbar.update(1) + + # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: ledger = ray.get(last_ref) else: ledger = initial_ledger ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) parent._notify_updated_ledger.remote(ledger) # clean up the temporary caches - async def cleanup(): - futures = [] - for path in already_finished_paths: + _clean_up_temp_caches(paths, first_group) + + +def _clean_up_temp_caches(paths, first_group): + async def cleanup(): + futures = [] + for group, path in paths.items(): + if group == first_group: + continue + + if fsspec_exists(path): futures.append(async_remove(path, recursive=True)) - await asyncio.gather(*futures) + await asyncio.gather(*futures) - asyncio.run(cleanup()) + asyncio.run(cleanup()) def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: @@ -1117,12 +1177,22 @@ def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore return +@dataclass +class _ProgressReport: + total_rows: int + total_bytes: float + total_shards_completed: int + # TODO: other counts + + def _tokenize_one_shard_group( temporary_cache_path: str, source: ShardedDataSource, shards: list[str], processor: BatchProcessor, options: CacheOptions, + report_fn: Callable[[_ProgressReport, CacheLedger], None], + force_unfinalized: bool, ) -> CacheLedger: # ray breaks if this is top level import humanfriendly @@ -1146,9 +1216,13 @@ def _tokenize_one_shard_group( total_rows = ledger.total_num_rows found_shard_with_rows = False + report = _ProgressReport(total_rows, 0, 0) + for shard_name in shards: if shard_name in ledger.finished_shards: logger.info(f"Shard {shard_name} already processed.") + report.total_shards_completed += 1 + report_fn(report, ledger) continue logger.debug(f"Processing {shard_name}.") @@ -1185,16 +1259,27 @@ def _tokenize_one_shard_group( if batch_byte_size > options.target_bytes_per_flush: writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + report.total_rows += this_batch_size + report.total_bytes += batch_byte_size + + report_fn(report, writer.ledger) + nice_bytes = humanfriendly.format_size(batch_byte_size) logger.debug( f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" ) + # print(f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})", flush=True) this_batch_size = 0 prepared_batch = None if prepared_batch is not None: batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) nice_bytes = humanfriendly.format_size(batch_byte_size) + + report.total_rows += this_batch_size + report.total_bytes += batch_byte_size + report_fn(report, writer.ledger) + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) logger.debug( f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" @@ -1202,11 +1287,13 @@ def _tokenize_one_shard_group( this_batch_size = 0 prepared_batch = None - total_rows += rows_this_shard - + report.total_shards_completed += 1 writer.finish_shard(shard_name, rows_this_shard) - writer.finish() + report_fn(report, writer.ledger) + + if not force_unfinalized: + writer.finish() logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index c61c66105..0edd1fb15 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -321,12 +321,12 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: SlowShardSource(), TestProcessor(), await_finished=False, - options=CacheOptions.no_fanciness(5), + options=CacheOptions(target_size_per_flush=1, batch_size=1), ) # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] - first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=10.0)) + first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=30.0)) for i, x in enumerate(first_10): np.testing.assert_array_equal(x["test"], np.array([i] * 10)) From 47441c0fd12b91eada4e4133a515f007a074de16 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 2 Nov 2024 22:59:13 -0700 Subject: [PATCH 07/19] a bit worried the bookkeeping isn't quite right on resume, but we're almost there. --- config/gpt2_small_fast_pile.yaml | 2 +- src/levanter/store/cache.py | 122 ++++++++++++++++++----------- src/levanter/utils/fsspec_utils.py | 11 +-- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index 3a21732a7..291213d75 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -1,4 +1,4 @@ -data: !include data/pile_source_old.yaml +data: !include data/pile_mixture.yaml model: type: gpt2 hidden_dim: 768 diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index f01e3b881..cd59aef4c 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -35,8 +35,8 @@ from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource -from ..utils.fsspec_utils import async_remove from ..utils.fsspec_utils import exists as fsspec_exists +from ..utils.fsspec_utils import remove as fsspec_remove from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -672,7 +672,13 @@ def __init__( memory=2 * self._options.target_bytes_per_flush, ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) - self._tokenize_pbar = tqdm(total=len(source.shard_names), desc="Tokenizing", unit="shard") + self._tokenize_pbar = tqdm( + total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard" + ) + self._copy_pbar = tqdm(total=len(source.shard_names), desc=f"{path_for_name}: copying", unit="shard") + self._report_totals = _ProgressReport(0, 0, 0) + self._copy_report_totals = _ProgressReport(0, 0, 0) + self._last_update = time.time() except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here @@ -763,15 +769,39 @@ async def _do_notify_async(): def _report_progress(self, report: "_ProgressReport"): import humanfriendly - self._tokenize_pbar.update(report.total_shards_completed) - mb_str = humanfriendly.format_size(report.total_bytes) - self._tokenize_pbar.set_postfix( - { - "rows": report.total_rows, - "shards": report.total_shards_completed, - "mb": mb_str, - } - ) + self._tokenize_pbar.update(report.new_shards) + self._report_totals.new_shards += report.new_shards + self._report_totals.new_rows += report.new_rows + self._report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + + mb_str = humanfriendly.format_size(self._report_totals.new_bytes) + self._tokenize_pbar.set_postfix( + { + "rows": self._report_totals.new_rows, + "shards": self._report_totals.new_shards, + "size": mb_str, + } + ) + + def _report_copy_progress(self, report: "_ProgressReport"): + # TODO: log bytes copied + self._copy_pbar.update(report.new_shards) + self._copy_report_totals.new_shards += report.new_shards + self._copy_report_totals.new_rows += report.new_rows + self._copy_report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + self._copy_pbar.set_postfix( + { + "shards": report.new_shards, + "rows": report.new_rows, + # "size": humanfriendly.format_size(report.new_bytes), + } + ) def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): @@ -855,12 +885,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) - logger.info( + logger.debug( f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) - unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" - processor_ref = ray.put(processor) source_ref = ray.put(source) @@ -928,7 +956,6 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): copy_refs: dict[str, ray.ObjectRef] = {} last_ref: ray.ObjectRef | None = None - copying_pbar = tqdm(total=len(shard_groups), desc="Copying", unit=unit, leave=False, position=1) for group in shard_groups: # first make sure it's either done this run or already done @@ -944,7 +971,9 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) if shards_copied == len(shard_groups[group]) or group == first_group: assert initial_ledger.total_num_rows >= total_rows_from_caches - copying_pbar.update(1) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=shards_copied, new_rows=initial_ledger.total_num_rows) + ) elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now @@ -976,8 +1005,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): continue if copy_refs.get(group) is not None: - ray.wait([copy_refs[group]], fetch_local=False) - copying_pbar.update(1) + ledger = ray.get(copy_refs[group]) + ledgers[group] = ledger + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) + ) # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: @@ -989,23 +1021,23 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ledger._serialize_and_commit(cache_dir) parent._notify_updated_ledger.remote(ledger) - # clean up the temporary caches _clean_up_temp_caches(paths, first_group) def _clean_up_temp_caches(paths, first_group): - async def cleanup(): - futures = [] - for group, path in paths.items(): - if group == first_group: - continue - - if fsspec_exists(path): - futures.append(async_remove(path, recursive=True)) - - await asyncio.gather(*futures) + for group, path in paths.items(): + if group == first_group: + continue - asyncio.run(cleanup()) + if fsspec_exists(path): + for i in range(10): + # this is crashy for some reason + try: + fsspec_remove(path, recursive=True) + break + except Exception: + logger.exception(f"Failed to remove {path} on attempt {i}") + time.sleep(1) def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: @@ -1074,6 +1106,9 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R assert not dest_ledger.is_finished parent._notify_updated_ledger.remote(dest_ledger) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows) + ) return dest_ledger @@ -1179,9 +1214,9 @@ def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore @dataclass class _ProgressReport: - total_rows: int - total_bytes: float - total_shards_completed: int + new_rows: int = 0 + new_bytes: float = 0 + new_shards: int = 0 # TODO: other counts @@ -1216,13 +1251,13 @@ def _tokenize_one_shard_group( total_rows = ledger.total_num_rows found_shard_with_rows = False - report = _ProgressReport(total_rows, 0, 0) + if total_rows > 0: + report_fn(_ProgressReport(new_rows=total_rows), ledger) for shard_name in shards: if shard_name in ledger.finished_shards: logger.info(f"Shard {shard_name} already processed.") - report.total_shards_completed += 1 - report_fn(report, ledger) + report_fn(_ProgressReport(new_shards=1), ledger) continue logger.debug(f"Processing {shard_name}.") @@ -1247,6 +1282,7 @@ def _tokenize_one_shard_group( this_batch_size += len(batch) rows_this_shard += len(batch) + total_rows += len(batch) if prepared_batch is None: prepared_batch = this_prepared @@ -1259,10 +1295,7 @@ def _tokenize_one_shard_group( if batch_byte_size > options.target_bytes_per_flush: writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) - report.total_rows += this_batch_size - report.total_bytes += batch_byte_size - - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) nice_bytes = humanfriendly.format_size(batch_byte_size) logger.debug( @@ -1276,9 +1309,7 @@ def _tokenize_one_shard_group( batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) nice_bytes = humanfriendly.format_size(batch_byte_size) - report.total_rows += this_batch_size - report.total_bytes += batch_byte_size - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) logger.debug( @@ -1287,15 +1318,14 @@ def _tokenize_one_shard_group( this_batch_size = 0 prepared_batch = None - report.total_shards_completed += 1 writer.finish_shard(shard_name, rows_this_shard) - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_shards=1), writer.ledger) if not force_unfinalized: writer.finish() - logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") + logger.debug(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") return writer.ledger diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index cc03c174b..c8d3931fe 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,11 +1,7 @@ -import asyncio - import braceexpand import fsspec from fsspec.asyn import AsyncFileSystem -from levanter.utils.thread_utils import _executor, blocking_wait - def exists(url, **kwargs) -> bool: """Check if a file exists on a remote filesystem.""" @@ -40,10 +36,7 @@ def remove(url, *, recursive=False, **kwargs): # TODO: better to use a STS deletion policy or job for this one. fs, path = fsspec.core.url_to_fs(url, **kwargs) - if isinstance(fs, AsyncFileSystem): - blocking_wait(fs._rm(path, recursive=recursive)) - else: - fs.rm(path, recursive=recursive) + fs.rm(path, recursive=recursive) async def async_remove(url, *, recursive=False, **kwargs): @@ -53,4 +46,4 @@ async def async_remove(url, *, recursive=False, **kwargs): if isinstance(fs, AsyncFileSystem): return await fs._rm(path, recursive=recursive) else: - return await asyncio.wrap_future(_executor.submit(fs.rm, path, recursive=recursive)) + fs.rm(path, recursive=recursive) From dbdc2e49cd9cc373baec123f4e1ee0cb58531346 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 07:49:41 -0800 Subject: [PATCH 08/19] fix resume bookkeeping logic --- src/levanter/store/cache.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index cd59aef4c..51679921d 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -992,6 +992,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ) copy_refs[group] = last_ref + if group == first_group: + # this is the first group, so it's already in the cache and we don't need to + # increment the data offset tree etc. + continue + # update the offset information: data offsets and total rows this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( @@ -1045,14 +1050,15 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) return {shard_name: [shard_name] for shard_name in source.shard_names} shard_names = source.shard_names - num_shards_per_group = len(shard_names) // num_groups + num_shards_per_group = (len(shard_names) + num_groups - 1) // num_groups # if we have a remainder, we'll just add it to the last group out_groups = { f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) for i in range(num_groups) } - if len(shard_names) % num_shards_per_group != 0: - out_groups[f"group_{num_groups - 1}"].extend(shard_names[num_groups * num_shards_per_group :]) + + # make sure we got all the shards + assert sum(len(shards) for shards in out_groups.values()) == len(shard_names) return out_groups # type: ignore From fb44f035df1abc18211a1979c183558f550a8cf0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 10:02:42 -0800 Subject: [PATCH 09/19] wip --- src/levanter/store/cache.py | 229 ++++++++++++++++++++---------------- 1 file changed, 130 insertions(+), 99 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 51679921d..b1b69e2c3 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -670,7 +670,7 @@ def __init__( # (we get twice from we need to concatenate prepared batches into the accumulator) # TODO: measure. memory=2 * self._options.target_bytes_per_flush, - ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) + ).remote(current_actor_handle(), cache_dir, source, options, processor) self._tokenize_pbar = tqdm( total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard" @@ -787,7 +787,6 @@ def _report_progress(self, report: "_ProgressReport"): ) def _report_copy_progress(self, report: "_ProgressReport"): - # TODO: log bytes copied self._copy_pbar.update(report.new_shards) self._copy_report_totals.new_shards += report.new_shards self._copy_report_totals.new_rows += report.new_rows @@ -840,7 +839,6 @@ class _ShardFinished: def _core_writer_task( parent, cache_dir, - initial_ledger: CacheLedger, source: ShardedDataSource, options: CacheOptions, processor, @@ -859,9 +857,6 @@ def _core_writer_task( # append a small random number to the name to avoid collisions name += f"::{random.randint(0, 1000)}" - # We want to make sure it's there - initial_ledger._serialize_and_commit(cache_dir) - # we want to do the following: # 1. write the 0th shard group to the output cache directly, updating metrics as we go # 2. in the background, start processing other shard groups to temporary caches @@ -879,8 +874,8 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): with log_failures_to(parent): temporary_cache_path = os.path.join(cache_dir, "___temp") - paths: dict[str, str] = {} - ledgers: dict[str, CacheLedger | None] = {} + group_cache_paths: dict[str, str] = {} + group_ledgers: dict[str, CacheLedger | None] = {} write_refs: dict[str, ray.ObjectRef] = {} shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) @@ -902,10 +897,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): else: group_out_path = os.path.join(temporary_cache_path, group_name) - paths[group_name] = group_out_path + group_cache_paths[group_name] = group_out_path ledger = _try_load(group_out_path) - ledgers[group_name] = ledger + group_ledgers[group_name] = ledger if ledger is not None: if group_name == first_group: @@ -931,109 +926,145 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): write_refs[group_name] = ref # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) - # This logic is a bit hairy thanks to resumes. - # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these - # separately. We also need to update the ledger as we go. - # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. - # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. - # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. - - # * When we load the permanent cache, we have already written some number of groups to it. - # (We check this invariant with an assert) - # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. - # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" - # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate - # the data offsets for where the group goes in the permanent cache. This is just a running sum of the - # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree - # of integers, one for each array. - # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache - # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. - # * We also need to update the ledger with the total number of rows - permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) - # initialize the data offset tree - data_offset_tree = jax.tree_map(lambda x: 0, permanent_cache.tree) - total_rows_from_caches = 0 - - copy_refs: dict[str, ray.ObjectRef] = {} - last_ref: ray.ObjectRef | None = None - - for group in shard_groups: - # first make sure it's either done this run or already done - if write_refs.get(group) is not None: - this_ledger = ray.get(write_refs[group]) - - ledgers[group] = ledger - else: - this_ledger = ledgers[group] - - assert this_ledger is not None - # see if we already copied this group, meaning all the shards are in the permanent cache - shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) - if shards_copied == len(shard_groups[group]) or group == first_group: - assert initial_ledger.total_num_rows >= total_rows_from_caches - parent._report_copy_progress.remote( - _ProgressReport(new_shards=shards_copied, new_rows=initial_ledger.total_num_rows) - ) - elif shards_copied > 0: - # In theory we can handle this, but it's a bit tricky, so we're going to punt for now - raise RuntimeError("Some shards were copied but not all. This should never happen.") - else: - # we need to copy this group - ref_to_send = None if last_ref is None else RefBox(last_ref) - last_ref = _copy_cache.remote( - cache_dir, - paths[group], - processor_ref, - data_offset_tree, - ref_to_send, - total_rows_from_caches, - parent, - ) - copy_refs[group] = last_ref + ledger = _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, + group_cache_paths, processor, processor_ref) - if group == first_group: - # this is the first group, so it's already in the cache and we don't need to - # increment the data offset tree etc. - continue + ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) + parent._notify_updated_ledger.remote(ledger) - # update the offset information: data offsets and total rows - this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) - data_offset_tree = jax.tree.map( - operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) - ) - total_rows_from_caches += this_ledger.total_num_rows + temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir} + _clean_up_temp_caches(temporary_cache_paths) - # this little bit is totally unnecessary but nice logging - for group in shard_groups: - if group == first_group: - continue - if copy_refs.get(group) is not None: - ledger = ray.get(copy_refs[group]) - ledgers[group] = ledger - parent._report_copy_progress.remote( - _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) - ) +def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, group_cache_paths, processor, + processor_ref): + """ + Copy the temporary caches to the output cache, in order. (essentially concatenating them) - # refs form a linked list implicitly, so we can just wait on the last one - if last_ref is not None: - ledger = ray.get(last_ref) + Args: + parent: the parent actor handle (_TreeStoreCacheBuilder) + cache_dir: the output cache directory + shard_groups: a dict mapping group names to lists of shard names + first_group: the privileged group that is written directly to the output cache + write_refs: a dict mapping group names to ray.ObjectRefs of the cache building tasks + group_ledgers: a dict mapping group names to the ledgers for the groups. Mutated in place. + group_cache_paths: a dict mapping group names to the paths of the temporary caches + processor: the processor object + processor_ref: a ray.ObjectRef of the processor object + """ + # This logic is a bit hairy thanks to resumes. + # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these + # separately. We also need to update the ledger as we go. + # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. + # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. + # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. + + # * When we load the permanent cache, we have already written some number of groups to it. In + # particular, we have written the 0'th group to the permanent cache. + # * We enforce that we only commit a whole group to the ledger at a time. + # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. + # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" + # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate + # the data offsets for where the group goes in the permanent cache. This is just a running sum of the + # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree + # of integers, one for each array. + # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache + # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. + # * We also need to update the ledger with the total number of rows + + # reload the ledger for the first group, which will be the sink for the other groups + assert first_group in write_refs + + group_ledgers[first_group] = ray.get(write_refs[first_group]) + overall_ledger = group_ledgers[first_group] + + # initialize the data offset tree + permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) + data_offset_tree = jax.tree_map(lambda x: x.data_size, permanent_cache.tree) + total_rows_from_caches = overall_ledger.total_num_rows + copy_refs: dict[str, ray.ObjectRef] = {} + last_ref: ray.ObjectRef | None = None + + found_one_to_copy = False + + for group in shard_groups: + # first make sure it's either done this run or already done + if write_refs.get(group) is not None: + this_ledger = ray.get(write_refs[group]) + group_ledgers[group] = this_ledger else: - ledger = initial_ledger + this_ledger = group_ledgers[group] - ledger.is_finished = True - ledger._serialize_and_commit(cache_dir) - parent._notify_updated_ledger.remote(ledger) + if group == first_group: + # this is the first group, so it's already in the cache and we don't need to + # increment the data offset tree etc. + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows) + ) + continue - _clean_up_temp_caches(paths, first_group) + assert this_ledger is not None + # see if we already copied this group, meaning all the shards are in the permanent cache + shards_copied = sum(1 if shard in overall_ledger.finished_shards else 0 for shard in shard_groups[group]) + + if found_one_to_copy and shards_copied > 0: + raise RuntimeError("A previous group was copied, but this group was not. This should never happen.") + elif shards_copied == len(shard_groups[group]): + assert overall_ledger.total_num_rows >= total_rows_from_caches, f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" + continue # nothing to do + elif shards_copied > 0: + # In theory we can handle this, but it's a bit tricky, so we're going to punt for now + raise RuntimeError("Some shards were copied but not all. This should never happen.") + + found_one_to_copy = True + # we need to copy this group + + # we can't "commit" the group to the ledger (or the number of rows) + # until we've updated the ledger for all previous groups, so we block on the last ref + ref_to_send = None if last_ref is None else RefBox(last_ref) + + last_ref = _copy_cache.remote( + cache_dir, + group_cache_paths[group], + processor_ref, + data_offset_tree, + ref_to_send, + total_rows_from_caches, + parent, + ) + copy_refs[group] = last_ref + # update the offset information: data offsets and total rows + this_cache = TreeStore.open(processor.output_exemplar, group_cache_paths[group], mode="r", cache_metadata=True) + data_offset_tree = jax.tree.map( + operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) + ) + total_rows_from_caches += this_ledger.total_num_rows -def _clean_up_temp_caches(paths, first_group): - for group, path in paths.items(): + # this little bit is totally unnecessary but nice logging + for group in shard_groups: if group == first_group: continue + if copy_refs.get(group) is not None: + ledger = ray.get(copy_refs[group]) + group_ledgers[group] = ledger + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) + ) + + # refs form a linked list implicitly, so we can just wait on the last one + if last_ref is not None: + ledger = ray.get(last_ref) + else: + ledger = overall_ledger + return ledger + + +def _clean_up_temp_caches(paths): + for path in paths: if fsspec_exists(path): for i in range(10): # this is crashy for some reason From 98e017093b054bf9ec4019f434f1460b1894d7a1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:15:35 -0800 Subject: [PATCH 10/19] reorg r --- src/levanter/store/cache.py | 7 ++----- tests/test_new_cache.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index b1b69e2c3..affe274eb 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -330,7 +330,6 @@ def build_or_load( shard_source=shard_source, processor=processor, options=options, - split=split, ) return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) @@ -637,7 +636,6 @@ def __init__( self, cache_dir: str, name: str, - split: str, # to workaround https://github.com/ray-project/ray/issues/44083 source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: CacheOptions, @@ -803,14 +801,13 @@ def _report_copy_progress(self, report: "_ProgressReport"): ) -def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): - name = f"lev_cache_manager::{split}::{cache_dir}" +def _get_builder_actor(cache_dir, shard_source, processor, options=CacheOptions.default()): + name = f"lev_cache_manager::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote( # type: ignore name=name_for_display, - split=split, cache_dir=cache_dir, source=shard_source, processor=processor, diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index 0edd1fb15..086de48e1 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -128,13 +128,13 @@ def test_full_end_to_end_cache(): with td as tmpdir: ray_ds = build_or_load_cache( tmpdir, - SimpleShardSource(num_shards=2), + SimpleShardSource(num_shards=15), TestProcessor(), await_finished=True, - options=CacheOptions.no_fanciness(8), + options=CacheOptions(num_shard_groups=3, batch_size=8), ) - expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=2)) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=15)) all_data = ray_ds[:] @@ -191,7 +191,6 @@ class _CustomException(Exception): @pytest.mark.ray -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") def test_cache_recover_from_crash(): class CrashingShardSource(ShardedDataSource[list[int]]): def __init__(self, crash_point: int): @@ -205,7 +204,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # parse the shard name to get the shard number shard_num = int(shard_name.split("_")[1]) for i in range(10): - if shard_num * 10 + i == self.crash_point: + if i == self.crash_point: raise _CustomException(f"Crashing at {shard_num} {i} {self.crash_point}") if i >= row: yield [shard_num * 10 + i] * 10 @@ -213,7 +212,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as tmpdir2: source = CrashingShardSource(4) with pytest.raises(_CustomException): - build_or_load_cache(tmpdir, source, TestProcessor()) + build_or_load_cache(tmpdir, source, TestProcessor(), CacheOptions(target_size_per_flush=1)) # kill the broker actor so that we can test recovery ray.kill( @@ -231,11 +230,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: ) # testing this doesn't throw - source = CrashingShardSource(1000) + source = CrashingShardSource(100000) reader1 = build_or_load_cache(tmpdir, source, TestProcessor(), await_finished=True) # compare to the original with no crash - reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), await_finished=True) + reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(num_shards=4), TestProcessor(), await_finished=True) check_datasets_equal(reader1, reader2) From bebf4038c503443fdace611b23eb9418fab3e4c0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:17:58 -0800 Subject: [PATCH 11/19] fix hf data loading for datasets>=3.1.0 --- src/levanter/data/sharded_datasource.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 186a0d9dd..90803df3e 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) + shard = map( + lambda t: t[1], + dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards), + ) else: shard = dataset From ad0c3573a558a7afa236425f260bf53d02ba09ad Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:18:27 -0800 Subject: [PATCH 12/19] go ahead and bump datasets --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 19fb077bf..0831605cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets>=2.18,<4.0", + "datasets>=3.1.0,<4.0", "gcsfs>=2024.2,<2024.10", "braceexpand>=0.1.7", "jmp>=0.0.3", From c823b7d5d6d8085a60a137cbf2409ceeed5b34fb Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:40:55 -0800 Subject: [PATCH 13/19] Fix hf datasets for new version (#784) --- pyproject.toml | 2 +- src/levanter/data/sharded_datasource.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 19fb077bf..0831605cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets>=2.18,<4.0", + "datasets>=3.1.0,<4.0", "gcsfs>=2024.2,<2024.10", "braceexpand>=0.1.7", "jmp>=0.0.3", diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 186a0d9dd..90803df3e 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) + shard = map( + lambda t: t[1], + dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards), + ) else: shard = dataset From fa00824eb129d279139d7814f6403b38b46a605a Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 4 Nov 2024 20:57:59 -0800 Subject: [PATCH 14/19] wip --- src/levanter/store/cache.py | 49 +++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index affe274eb..794af8e23 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -742,11 +742,18 @@ def _notify_updated_ledger(self, ledger: CacheLedger): Called by the cache writer when it has updated the ledger. """ was_finished = self._ledger.is_finished - self._ledger = ledger + # ensure the ledger is "monotonic" meaning that we only expect it to grow + if ledger.total_num_rows < self._ledger.total_num_rows: + raise RuntimeError(f"Ledger went backwards: {ledger.total_num_rows} < {self._ledger.total_num_rows}") + + for shard, rows in ledger.shard_rows.items(): + if rows < self._ledger.shard_rows.get(shard, 0): + raise RuntimeError(f"Shard {shard} went backwards: {rows} < {self._ledger.shard_rows.get(shard, 0)}") if was_finished: raise RuntimeError("Ledger was already finished") + self._ledger = ledger if self._ledger.is_finished: logger.info(f"Finalizing cache {self._cache_dir}...") # guard against invalid state errors @@ -767,7 +774,8 @@ async def _do_notify_async(): def _report_progress(self, report: "_ProgressReport"): import humanfriendly - self._tokenize_pbar.update(report.new_shards) + if report.new_shards > 0: + self._tokenize_pbar.update(report.new_shards) self._report_totals.new_shards += report.new_shards self._report_totals.new_rows += report.new_rows self._report_totals.new_bytes += report.new_bytes @@ -866,7 +874,7 @@ def report_fn(report: _ProgressReport, ledger: CacheLedger): def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): parent._report_progress.remote(report) - parent._notify_updated_ledger.remote(ledger) + ray.get(parent._notify_updated_ledger.remote(ledger)) with log_failures_to(parent): temporary_cache_path = os.path.join(cache_dir, "___temp") @@ -877,6 +885,9 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) + for name, group in shard_groups.items(): + assert len(group) > 0 + logger.debug( f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) @@ -901,10 +912,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): if ledger is not None: if group_name == first_group: - parent._notify_updated_ledger.remote(ledger) + ray.get(parent._notify_updated_ledger.remote(ledger)) continue - report_fn = report_fn_first_group if group_name == first_group else report_fn + report_fn_to_use = report_fn_first_group if group_name == first_group else report_fn ref = ( ray.remote(_tokenize_one_shard_group) @@ -917,7 +928,7 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): retry_exceptions=True, max_retries=10, ) - .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn, parent) + .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn_to_use, parent) ) write_refs[group_name] = ref @@ -929,7 +940,7 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ledger.is_finished = True ledger._serialize_and_commit(cache_dir) - parent._notify_updated_ledger.remote(ledger) + ray.get(parent._notify_updated_ledger.remote(ledger)) temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir} _clean_up_temp_caches(temporary_cache_paths) @@ -1078,12 +1089,16 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) return {shard_name: [shard_name] for shard_name in source.shard_names} shard_names = source.shard_names - num_shards_per_group = (len(shard_names) + num_groups - 1) // num_groups - # if we have a remainder, we'll just add it to the last group - out_groups = { - f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) - for i in range(num_groups) - } + num_shards_per_group = (len(shard_names)) // num_groups + num_groups_with_extra = len(shard_names) % num_groups + + # if we have a remainder, we want to distribute the extra shards evenly + out_groups: dict[str, list[str]] = {} + start = 0 + for i in range(num_groups): + num_shards = num_shards_per_group + (1 if i < num_groups_with_extra else 0) + out_groups[f"group_{i}"] = list(shard_names[start : start + num_shards]) + start += num_shards # make sure we got all the shards assert sum(len(shards) for shards in out_groups.values()) == len(shard_names) @@ -1099,7 +1114,9 @@ def _merge_ledgers(dest, source): dest.shard_rows[shard] = rows dest.finished_shards.extend(source.finished_shards) - dest.field_counts.update(source.field_counts) + for field, count in source.field_counts.items(): + dest.field_counts[field] = dest.field_counts.get(field, 0) + count + return dest @@ -1126,7 +1143,6 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R if last_ref is not None: ray.wait([last_ref.ref], fetch_local=False) permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) - dest_ledger = CacheLedger.load(dest_path) source_ledger = CacheLedger.load(source_path) new_num_rows = source_ledger.total_num_rows + rows_so_far @@ -1135,11 +1151,12 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R for future in futures: future.result() + dest_ledger = CacheLedger.load(dest_path) _merge_ledgers(dest_ledger, source_ledger) dest_ledger._serialize_and_commit(dest_path) assert not dest_ledger.is_finished - parent._notify_updated_ledger.remote(dest_ledger) + ray.get(parent._notify_updated_ledger.remote(dest_ledger)) parent._report_copy_progress.remote( _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows) ) From 91383f31b1ac17c77823b1e8df1646641edc812e Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 4 Nov 2024 23:05:43 -0800 Subject: [PATCH 15/19] fix wandb key (#785) turns out ray doesn't merge things when you use .options, which... what. --- src/levanter/infra/ray_tpu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 57f484770..b04648079 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -11,6 +11,7 @@ from typing import Callable, Optional, Sequence import draccus +import mergedeep import ray from ray._private.accelerators import TPUAcceleratorManager from ray.dashboard.modules.job.sdk import JobSubmissionClient @@ -198,10 +199,15 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8 + # ray doesn't merge the runtime envs properly, so we have to do it ourselves + # we need to do a deep merge + runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE) + remote_fn = remote_fn.options( runtime_env=runtime_env, resources={tpu_name: 1, "TPU": num_tpus_per_host}, ) + logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host") return remote_fn, tpu_name From 45d3e70f3a61fadf693da75fd322072947caf3f1 Mon Sep 17 00:00:00 2001 From: William Held Date: Tue, 5 Nov 2024 12:21:15 -0500 Subject: [PATCH 16/19] Fix Llama 3 Tests (#782) HuggingFace seems to have changed a few things around in what info they expect to be stored in the config leading the Llama 3 roundtrip tests to hit errors. AFAICT, the Torch tests aren't running in CI so this just fixes the regression! ![image](https://github.com/user-attachments/assets/eec1f2a3-ceb9-443a-911e-ea6476fa91bf) --- src/levanter/models/rotary.py | 1 + tests/test_llama3.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/rotary.py b/src/levanter/models/rotary.py index 07657e5ff..55bbf3fcb 100644 --- a/src/levanter/models/rotary.py +++ b/src/levanter/models/rotary.py @@ -157,6 +157,7 @@ def to_hf_config(self) -> tuple[float, dict]: "low_freq_factor": self.low_freq_factor, "high_freq_factor": self.high_freq_factor, "original_max_position_embeddings": self.original_max_position_embeddings, + "rope_type": "llama3", } diff --git a/tests/test_llama3.py b/tests/test_llama3.py index 2fae326d1..653ba723c 100644 --- a/tests/test_llama3.py +++ b/tests/test_llama3.py @@ -26,9 +26,10 @@ def get_config(vocab_size=1000): "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, + "head_dim": 64, "initializer_range": 0.02, "intermediate_size": 14336, - "max_position_embeddings": 8192, + "max_position_embeddings": 131072, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, @@ -55,6 +56,7 @@ def get_config(vocab_size=1000): llama3_8b_config.hidden_size = 16 llama3_8b_config.intermediate_size = 64 llama3_8b_config.num_attention_heads = 4 + llama3_8b_config.head_dim = 4 llama3_8b_config.num_hidden_layers = 4 llama3_8b_config.num_key_value_heads = 2 llama3_8b_config.max_position_embeddings = 128 From b51a3802dc89a8eb9965e0a8bde4bf6bf2b4fec5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 5 Nov 2024 11:03:55 -0800 Subject: [PATCH 17/19] was updating too many times --- src/levanter/store/cache.py | 47 ++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 794af8e23..d49f4553b 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -933,10 +933,17 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): write_refs[group_name] = ref - # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) - - ledger = _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, - group_cache_paths, processor, processor_ref) + ledger = _start_copies( + parent, + cache_dir, + shard_groups, + first_group, + write_refs, + group_ledgers, + group_cache_paths, + processor, + processor_ref, + ) ledger.is_finished = True ledger._serialize_and_commit(cache_dir) @@ -946,8 +953,17 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): _clean_up_temp_caches(temporary_cache_paths) -def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, group_cache_paths, processor, - processor_ref): +def _start_copies( + parent, + cache_dir, + shard_groups, + first_group, + write_refs, + group_ledgers, + group_cache_paths, + processor, + processor_ref, +): """ Copy the temporary caches to the output cache, in order. (essentially concatenating them) @@ -961,6 +977,9 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou group_cache_paths: a dict mapping group names to the paths of the temporary caches processor: the processor object processor_ref: a ray.ObjectRef of the processor object + + Returns: + The final ledger """ # This logic is a bit hairy thanks to resumes. # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these @@ -1020,7 +1039,9 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou if found_one_to_copy and shards_copied > 0: raise RuntimeError("A previous group was copied, but this group was not. This should never happen.") elif shards_copied == len(shard_groups[group]): - assert overall_ledger.total_num_rows >= total_rows_from_caches, f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" + assert ( + overall_ledger.total_num_rows >= total_rows_from_caches + ), f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" continue # nothing to do elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now @@ -1051,18 +1072,6 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou ) total_rows_from_caches += this_ledger.total_num_rows - # this little bit is totally unnecessary but nice logging - for group in shard_groups: - if group == first_group: - continue - - if copy_refs.get(group) is not None: - ledger = ray.get(copy_refs[group]) - group_ledgers[group] = ledger - parent._report_copy_progress.remote( - _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) - ) - # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: ledger = ray.get(last_ref) From f53c99180915050988090a57b90e1d4a8b8a4d77 Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Tue, 5 Nov 2024 19:26:51 -0800 Subject: [PATCH 18/19] Internal eval fixes (#788) Allowing internal supervised eval to work without separate eval set --------- Co-authored-by: David Hall --- config/gpt2_small_fast_supervised.yaml | 1 + src/levanter/data/text.py | 2 +- src/levanter/main/train_lm.py | 8 ++++---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml index d71e1267e..93675366d 100644 --- a/config/gpt2_small_fast_supervised.yaml +++ b/config/gpt2_small_fast_supervised.yaml @@ -15,6 +15,7 @@ data: supervised_data: validation_urls: - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz" + - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz" cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/" input_field: "input" output_field: "output" diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 70c1fe4b3..f2bea44b2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -631,7 +631,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain input_field = config.input_field output_field = config.output_field - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)} + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index fe5e5dd35..79095d601 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -160,13 +160,13 @@ def main(config: TrainLmConfig): levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) + max_eval_examples_per_ds = config.trainer.max_eval_batches + if max_eval_examples_per_ds is not None: + max_eval_examples_per_ds *= config.trainer.eval_batch_size + if len(tagged_eval_datasets) == 0: logger.warning("No evaluation datasets provided.") else: - max_eval_examples_per_ds = config.trainer.max_eval_batches - if max_eval_examples_per_ds is not None: - max_eval_examples_per_ds *= config.trainer.eval_batch_size - causal_datasets = [ (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags) for ds, tags in tagged_eval_datasets From 20ff94c80479df25805a5c616c526077efc3620f Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 5 Nov 2024 20:54:40 -0800 Subject: [PATCH 19/19] fix empty shards --- src/levanter/store/cache.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index d49f4553b..5e8657fc0 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -883,6 +883,14 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): group_ledgers: dict[str, CacheLedger | None] = {} write_refs: dict[str, ray.ObjectRef] = {} + if len(source.shard_names) == 0: + logger.info("No shards to process. Writing empty ledger.") + ledger = CacheLedger.load_or_initialize(cache_dir, source, processor) + ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) + ray.get(parent._notify_updated_ledger.remote(ledger)) + return + shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) for name, group in shard_groups.items():