Add an actor pool for batch processing, switch to a thread for writing batches instead of a ray actor/task (#757)

About a 5x speedup. Memory usage isn't well controlled in mixtures yet, and that still needs work.
dlwh authored Oct 10, 2024
1 parent 8bed0aa commit adf4b6d
Showing 14 changed files with 874 additions and 286 deletions.
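The headline change in the commit title, writing batches from a plain background thread rather than a Ray actor/task, can be illustrated with a minimal sketch. This is not the commit's code: BatchWriterThread and write_batch are illustrative names only, and the real implementation lives in the files below.

import queue
import threading


class BatchWriterThread:
    """Sketch: one daemon thread drains a queue of finished batches and
    writes them, so producers never block on I/O and there is no Ray
    actor/task round-trip per batch."""

    def __init__(self, write_batch):
        self._write_batch = write_batch  # callable that persists one batch
        self._queue: queue.Queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def submit(self, batch):
        self._queue.put(batch)  # cheap enqueue; the write happens off-thread

    def close(self):
        self._queue.put(None)  # sentinel: tell the worker to exit
        self._thread.join()

    def _run(self):
        while True:
            batch = self._queue.get()
            if batch is None:
                break
            self._write_batch(batch)

Avoiding a per-batch Ray round-trip is plausibly where much of the quoted 5x speedup comes from, since submitting a write becomes a cheap local enqueue.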
15 changes: 9 additions & 6 deletions config/data/pile_mixture.yaml
@@ -1,5 +1,8 @@
cache_dir: "gs://levanter-data/tokenized/pile-domains/"
tokenizer: "EleutherAI/gpt-neox-20b"
cache_options:
batch_size: 32
num_shard_groups: 16
configs:
arxiv:
train_urls:
@@ -11,11 +14,11 @@ configs:
- gs://levanter-data/pile-domains/books2/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/books2/val.jsonl.zst
books3:
train_urls:
- gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst
validation_urls:
- gs://levanter-data/pile-domains/books3/val.jsonl.zst
# books3:
# train_urls:
# - gs://levanter-data/pile-domains/books3/{00..29}.jsonl.zst
# validation_urls:
# - gs://levanter-data/pile-domains/books3/val.jsonl.zst
dm_math:
train_urls:
- gs://levanter-data/pile-domains/dm_math/{00..29}.jsonl.zst
@@ -115,7 +118,7 @@ train_weights:
# these weights come from the paper https://arxiv.org/pdf/2101.00027.pdf
pile_cc: 0.1811
pubmed_central: 0.1440
books3: 0.1207
# books3: 0.1207
owt2: 0.1001
arxiv: 0.0896
github: 0.0759
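The new cache_options block corresponds to the CacheOptions type used elsewhere in this commit (audio.py below defaults to CacheOptions.default()). A hedged Python equivalent of the YAML above, assuming the field names map one-to-one and that the import path is levanter.store.cache:

from levanter.store.cache import CacheOptions  # import path assumed, not shown in this diff

# same settings as the YAML cache_options block above
options = CacheOptions(batch_size=32, num_shard_groups=16)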
72 changes: 69 additions & 3 deletions src/levanter/data/_preprocessor.py
@@ -1,8 +1,13 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, Iterable, Mapping, Sequence, TypeVar, Union

import numpy as np
import pyarrow as pa
import ray

from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase
from levanter.utils.ray_utils import RefBox


T = TypeVar("T")
@@ -143,12 +148,12 @@ def rec(dataset):

source, transforms, batch_transform = rec(dataset)

batch_size = batch_transform.batch_size if batch_transform is not None else 1024
# batch_size = batch_transform.batch_size if batch_transform is not None else 1024
cpus = batch_transform.num_cpus if batch_transform is not None else 1
gpus = batch_transform.num_gpus if batch_transform is not None else 0
resources = batch_transform.resources if batch_transform is not None else {}

return source, _CompositeBatchProcessor(transforms, batch_size, cpus, gpus, resources)
return source, _CompositeBatchProcessor(transforms, cpus, gpus, resources)


class _CompositeBatchProcessor(BatchProcessor):
@@ -157,7 +162,6 @@ def __init__(self, transforms, num_cpus, num_gpus, resources):
self._num_cpus = num_cpus
self._num_gpus = num_gpus
self._resources = resources
self._batch_size = batch_size

@property
def batch_size(self):
@@ -230,3 +234,65 @@ def to_hf_batched(x):
return x

return {b.field(i).name: to_hf_batched(b.column(i).to_numpy(zero_copy_only=False)) for i in range(b.num_columns)}


@ray.remote(num_cpus=0)
class BatchProcessorPool:
def __init__(self, processor: BatchProcessor, min_size: int = 1, max_size: int = 10):
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(message)s")
processor_ref = ray.put(processor)
self.actor_pool = AutoScalingActorPool(
lambda: _create_batch_processor_actor(processor, processor_ref), min_size, max_size
)

async def process_batch(self, batch_ref: RefBox):
return await self.actor_pool.submit(
lambda a, b: a.process_batch.remote(b), batch_ref.ref, obj_ref=batch_ref.ref
)

def num_pending_tasks(self):
return self.actor_pool.num_pending_tasks


def _create_batch_processor_actor(processor: BatchProcessor, processor_ref):
cpus = processor.num_cpus
gpus = processor.num_gpus
resources = processor.resources
return _BatchProcessorActor.options( # type: ignore
num_cpus=cpus, num_gpus=gpus, resources=resources, scheduling_strategy="SPREAD"
).remote(processor_ref)


@ray.remote
class _BatchProcessorActor(PoolWorkerBase):
def __init__(self, processor: BatchProcessor):
from levanter.store.tree_store import TreeBatchPreparer

self.processor = processor
self.preparer = TreeBatchPreparer(processor.output_exemplar)

def process_batch(self, batch):
result = self.processor(batch)
result = _canonicalize_batch(result)
prepared = self.preparer(result)
return prepared


def _canonicalize_batch(batch: Union[dict, list[dict]]) -> list[dict]:
if isinstance(batch, pa.RecordBatch):
batch = dict_from_record_batch(batch)

if isinstance(batch, dict):
return _to_list_of_dicts(batch)
else:
return batch


def _to_list_of_dicts(batch: dict) -> list[dict]:
"""
    Convert a batch given as a dict of column lists to a list of per-row dictionaries, suitable for writing to a cache.
"""
keys = list(batch.keys())
values = list(batch.values())
num_rows = len(values[0])
return [{key: values[i][j] for i, key in enumerate(keys)} for j in range(num_rows)]
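
A hedged usage sketch for the new BatchProcessorPool defined above: the pool itself is a zero-CPU Ray actor, and process_batch takes a RefBox-wrapped object ref so the autoscaling pool can forward the ref to a worker actor without copying the batch through the pool. Here my_processor and my_batch are placeholders, not names from this commit.

import ray

from levanter.data._preprocessor import BatchProcessorPool
from levanter.utils.ray_utils import RefBox

# my_processor is any BatchProcessor; my_batch is a batch it accepts (both placeholders)
pool = BatchProcessorPool.remote(my_processor, min_size=1, max_size=10)

batch_ref = ray.put(my_batch)  # put the batch in the object store once
prepared = ray.get(pool.process_batch.remote(RefBox(batch_ref)))

Passing RefBox(batch_ref) rather than the raw ref keeps Ray from eagerly dereferencing the argument, which is presumably why process_batch unwraps batch_ref.ref itself.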
8 changes: 2 additions & 6 deletions src/levanter/data/audio.py
@@ -305,6 +305,7 @@ def build_or_load(
override_resources=None,
max_length=448,
cache_options: CacheOptions = CacheOptions.default(),
split: str = "",
) -> "ProcessedAudioCache":
bp = BatchAudioProcessor(
processor,
@@ -316,12 +317,7 @@
)
monitors = monitors or []
cache = build_or_load_cache(
cache_dir,
source,
bp,
await_finished=await_finished,
monitors=monitors,
options=cache_options,
cache_dir, source, bp, await_finished=await_finished, monitors=monitors, options=cache_options, split=split
)
if cache.is_finished:
logger.info(f"Cache {cache_dir} is complete.")
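For reference, a call-site sketch of build_or_load_cache with the new split argument; the argument list mirrors the call above, and the concrete values are placeholders:

cache = build_or_load_cache(
    cache_dir,
    source,
    processor,
    await_finished=True,
    monitors=[],
    options=CacheOptions.default(),
    split="train",  # new in this commit: names the split this cache belongs to
)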
1 change: 1 addition & 0 deletions src/levanter/data/text.py
@@ -735,6 +735,7 @@ def build_or_load_cache(
monitors=monitors,
await_finished=False,
options=self.cache_options,
split=split,
)


1 change: 1 addition & 0 deletions src/levanter/main/cache_dataset.py
@@ -48,6 +48,7 @@ def main(args: RayCachedLMDatasetConfig):
processor=batch_tokenizer,
await_finished=False,
monitors=monitors,
split=split,
)

cache.await_finished()