Skip to content

Commit

Permalink
remove a bunch of old unused stuff (#832)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Dec 6, 2024
1 parent 091f1cd commit 19b5f93
Show file tree
Hide file tree
Showing 13 changed files with 0 additions and 1,388 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ dependencies = [
"dataclasses-json~=0.6.4",
"ray[default]>=2.34.0",
"pydantic<3",
"rich~=13.0",
"filelock~=3.13",
"async-lru~=2.0",
"tqdm-loggable>=0.2",
Expand Down
53 changes: 0 additions & 53 deletions src/levanter/data/_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generic, Iterable, Mapping, Sequence, TypeVar, Union

import numpy as np
import pyarrow as pa
import ray

from levanter.utils.actor_pool import AutoScalingActorPool, PoolWorkerBase
from levanter.utils.ray_utils import RefBox


T = TypeVar("T")
Expand Down Expand Up @@ -236,54 +231,6 @@ def to_hf_batched(x):
return {b.field(i).name: to_hf_batched(b.column(i).to_numpy(zero_copy_only=False)) for i in range(b.num_columns)}


@ray.remote(num_cpus=0.1)  # keep this low b/c it doesn't do much
class BatchProcessorPool:
    """Ray actor that owns an auto-scaling pool of batch-processor workers.

    Batches submitted via :meth:`process_batch` are distributed across worker
    actors; the pool grows/shrinks between ``min_size`` and ``max_size``.
    """

    def __init__(self, processor: BatchProcessor, min_size: int = 1, max_size: int = 10):
        logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(message)s")
        # put the processor into the object store once so every worker created by
        # the pool factory shares the same copy instead of re-serializing it
        processor_ref = ray.put(processor)
        self.actor_pool = AutoScalingActorPool(
            lambda: _create_batch_processor_actor(processor, processor_ref), min_size, max_size
        )

    async def process_batch(self, batch_ref: RefBox):
        """Run one batch on some pool worker and return the prepared result.

        ``batch_ref`` is a RefBox-wrapped ObjectRef; the raw ref is both passed to
        the worker and handed to the pool as ``obj_ref``.
        # NOTE(review): presumably obj_ref lets the pool track/locate the pending
        # object — confirm against AutoScalingActorPool.submit
        """
        return await self.actor_pool.submit(
            lambda a, b: a.process_batch.remote(b), batch_ref.ref, obj_ref=batch_ref.ref
        )

    def num_pending_tasks(self):
        # number of submitted-but-unfinished batches in the underlying pool
        return self.actor_pool.num_pending_tasks

    def resize_pool(self, *, min_size: int | None = None, max_size: int | None = None):
        # delegate directly to the pool's own resize logic
        self.actor_pool.resize_pool(min_size=min_size, max_size=max_size)

    def ensure_max_at_least(self, size: int):
        # grow (never shrink) the pool's maximum size to at least ``size``
        self.actor_pool.resize_pool(max_size=max(size, self.actor_pool.get_max_size()))


def _create_batch_processor_actor(processor: BatchProcessor, processor_ref):
    """Spawn one ``_BatchProcessorActor`` sized by the processor's resource requests."""
    # build the actor options from the processor's declared resource needs;
    # SPREAD keeps workers distributed across nodes rather than packed together
    actor_options = dict(
        num_cpus=processor.num_cpus,
        num_gpus=processor.num_gpus,
        resources=processor.resources,
        scheduling_strategy="SPREAD",
    )
    return _BatchProcessorActor.options(**actor_options).remote(processor_ref)  # type: ignore


@ray.remote
class _BatchProcessorActor(PoolWorkerBase):
    """Pool worker: applies a BatchProcessor and prepares its output for the tree store."""

    def __init__(self, processor: BatchProcessor):
        # imported lazily to avoid a module-level dependency cycle with the store package
        from levanter.store.tree_store import TreeBatchPreparer

        self.processor = processor
        self.preparer = TreeBatchPreparer(processor.output_exemplar)

    def process_batch(self, batch):
        """Run the processor on ``batch``, canonicalize, and return the prepared result."""
        canonical = _canonicalize_batch(self.processor(batch))
        return self.preparer(canonical)


def _canonicalize_batch(batch: Union[dict, list[dict]]) -> list[dict]:
if isinstance(batch, pa.RecordBatch):
batch = dict_from_record_batch(batch)
Expand Down
87 changes: 0 additions & 87 deletions src/levanter/data/metrics_monitor.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import dataclasses
import logging as pylogging
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol, Union

import jax
from dataclasses_json import dataclass_json
from rich.progress import (
BarColumn,
Progress,
TaskID,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)

import levanter.tracker

Expand All @@ -35,53 +25,6 @@ def __call__(self, metrics: InProgressCacheMetrics):
...


class RichMetricsMonitor(MetricsMonitor):
    """MetricsMonitor that renders cache-build progress with a rich progress bar."""

    progress: Optional[Progress]  # type: ignore
    task: Optional[TaskID]

    def __init__(self, num_shards, **kwargs):
        """kwargs are passed to rich.progress.Progress"""
        self.kwargs = kwargs
        self.progress: Optional[Progress] = None
        self.task = None
        self.num_shards = num_shards

    def __call__(self, metrics: InProgressCacheMetrics):
        # lazily create the progress display on the first metrics update, since
        # the column layout depends on which field counts the metrics carry
        if self.progress is None:
            self._init_progress(metrics)

        self.progress.update(self.task, completed=metrics.shards_finished, **dataclasses.asdict(metrics))  # type: ignore
        self.progress.refresh()  # type: ignore
        if metrics.is_finished:
            self.progress.stop()  # type: ignore

    def _init_progress(self, metrics):
        # one extra text column per tracked field count, between the fixed columns
        per_field = [
            TextColumn(f"| {{task.fields[field_counts][{name}]}} {name}", justify="center")
            for name in metrics.field_counts
        ]
        all_columns = [
            BarColumn(),
            TaskProgressColumn(),
            TextColumn("| {task.fields[rows_finished]} docs", justify="center"),
            *per_field,
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        ]

        self.progress = Progress(*all_columns, **self.kwargs)
        self.task = self.progress.add_task(
            "Shards", total=self.num_shards, completed=metrics.shards_finished, **dataclasses.asdict(metrics)
        )
        self.progress.start()


class LoggingMetricsMonitor(MetricsMonitor):
last_metrics: Optional[InProgressCacheMetrics]
last_time: Optional[float]
Expand Down Expand Up @@ -109,16 +52,6 @@ def __call__(self, metrics: InProgressCacheMetrics):
if metrics.is_finished:
to_log[f"{self.prefix}/finished"] = 1

# estimate the rate of progress
# if self.last_metrics is not None:
# assert self.last_time is not None
# elapsed = time.time() - self.last_time
# to_log[f"{self.prefix}/shards_per_s"] = (metrics.shards_finished - self.last_metrics.shards_finished) / elapsed
# to_log[f"{self.prefix}/rows_per_s"] = (metrics.rows_finished - self.last_metrics.rows_finished) / elapsed
#
# for field, count in metrics.field_counts.items():
# to_log[f"{self.prefix}/{field}_per_s"] = (count - self.last_metrics.field_counts[field]) / elapsed

self.last_metrics = metrics
self.last_time = time.time()

Expand Down Expand Up @@ -153,23 +86,3 @@ def __call__(self, metrics: InProgressCacheMetrics):

if metrics.is_finished:
self.logger.info("Cache creation finished")


class WaitTimeReportingThread(threading.Thread):
    """Background thread that periodically reports cumulative wait time.

    Invokes ``report(total_seconds)`` once per ``interval`` — skipping the very
    first interval, while the total is still zero — until :meth:`shutdown` is
    called.
    """

    def __init__(self, report, interval=60):
        super().__init__()
        self.report = report
        self.interval = interval
        self.shutdown_event = threading.Event()

    def run(self):
        waited = 0
        # Event.wait returns True as soon as shutdown_event is set, ending the loop;
        # a False return means the full interval elapsed with no shutdown request.
        while not self.shutdown_event.wait(self.interval):
            if waited > 0:
                self.report(waited)
            waited += self.interval

    def shutdown(self):
        self.shutdown_event.set()
Empty file removed src/levanter/data/shard_cache.py
Empty file.
58 changes: 0 additions & 58 deletions src/levanter/mesh.py

This file was deleted.

114 changes: 0 additions & 114 deletions src/levanter/models/longformer.py

This file was deleted.

Loading

0 comments on commit 19b5f93

Please sign in to comment.