From db0ab691f12c3b3b3303224c20b62a332582ea7e Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 10 Oct 2024 16:17:38 -0700 Subject: [PATCH] fix ray on python 3.11, add scaling down --- config/data/openwebtext_source.yaml | 3 ++ src/levanter/utils/actor_pool.py | 48 ++++++++++++++++++++++------- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/config/data/openwebtext_source.yaml b/config/data/openwebtext_source.yaml index 764ee0b9e..6daa695c0 100644 --- a/config/data/openwebtext_source.yaml +++ b/config/data/openwebtext_source.yaml @@ -4,3 +4,6 @@ validation_urls: - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" cache_dir: "gs://levanter-data/tokenized/openwebtext/" tokenizer: "gpt2" +cache_options: + batch_size: 1024 + num_shard_groups: 64 diff --git a/src/levanter/utils/actor_pool.py b/src/levanter/utils/actor_pool.py index 51ba2ccec..76c3ca8fb 100644 --- a/src/levanter/utils/actor_pool.py +++ b/src/levanter/utils/actor_pool.py @@ -15,6 +15,11 @@ # https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py +def _wrap_ray_future(ray_future): + # work around https://github.com/ray-project/ray/issues/45895#issuecomment-2165164129 + return asyncio.wrap_future(ray_future.future()) + + class AutoScalingActorPool: """Utility class to operate on a dynamically scaling pool of actors.""" @@ -37,6 +42,7 @@ def __init__( self._actor_locations: Dict[ray.actor.ActorHandle, str] = {} self._tasks_waiting_for_actor: list[asyncio.Future] = [] self._next_task_id = 0 + self._scale_down_task: Optional[asyncio.Task] = None self._scale_up(self._min_size) @@ -45,6 +51,9 @@ def num_pending_tasks(self): return len(self._tasks_waiting_for_actor) def _scale_up(self, num_actors: int): + if self._scale_down_task and not self._scale_down_task.done(): + self._scale_down_task.cancel() + for _ in range(num_actors): try: actor = self._create_actor_fn() @@ -52,7 +61,7 @@ def _scale_up(self, num_actors: int): self._pending_actors[ready_ref] = actor async def wait_for_ready(actor, ready_ref): - loc = await ready_ref + loc = await _wrap_ray_future(ready_ref) # pending -> floating if ready_ref not in self._pending_actors: logger.info("Actor was cancelled before it was ready.") @@ -67,8 +76,8 @@ async def wait_for_ready(actor, ready_ref): except Exception as e: logger.error("Failed to create actor.", exc_info=e) - def _scale_down(self, num_actors: int): - for _ in range(num_actors): + def _scale_down(self, target_num_actors: int): + while len(self._idle_actors) + len(self._pending_actors) > target_num_actors: if self._pending_actors: actor = self._pending_actors.popitem()[1] # let it die through gc @@ -102,10 +111,20 @@ def _adjust_pool_size(self): f" {self._max_size}" ) self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks)) + + # Schedule scale down if idle elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size: - return # never scal edown. too many issues - logger.info(f"Scaling down due to no pending tasks. Current pool size: {total_actors}") - self._scale_down(num_nonworking_actors - self._min_size) + if self._scale_down_task is None or self._scale_down_task.done(): + self._scale_down_task = asyncio.create_task(self._schedule_scale_down()) + + async def _schedule_scale_down(self): + try: + await asyncio.sleep(10) + if self.num_pending_tasks == 0: + logger.info("Scaling down due to no pending tasks.") + self._scale_down(self._min_size) + except asyncio.CancelledError: + logger.info("Scale down task was cancelled due to new activity.") def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]: """Get the location of the given object reference.""" @@ -153,10 +172,11 @@ def _assign_task_to_actor(self, actor, fn, value): # floating -> busy ray_future = fn(actor, value) self._busy_actors[ray_future] = actor + if self._scale_down_task and not self._scale_down_task.done(): + self._scale_down_task.cancel() self._adjust_pool_size() - # return ray_future - return asyncio.ensure_future(self._wrap_ray_future(ray_future)) + return asyncio.ensure_future(self._set_up_actor_return_on_finished(ray_future)) async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future): actor = await actor_future @@ -181,10 +201,11 @@ def _maybe_start_pending_task(self, actor): assigned = False return assigned - async def _wrap_ray_future(self, ray_future): - await asyncio.wait([ray_future]) + async def _set_up_actor_return_on_finished(self, ray_future): + future = _wrap_ray_future(ray_future) + await asyncio.wait([future]) self._on_task_done(ray_future) - return await ray_future + return await future def _on_task_done(self, ray_future): actor = self._busy_actors.pop(ray_future) @@ -218,6 +239,11 @@ def push(self, actor: "ray.actor.ActorHandle"): self._actor_locations[actor] = location self._maybe_start_pending_task(actor) + def __del__(self): + if self._scale_down_task and not self._scale_down_task.done(): + self._scale_down_task.cancel() + # just let ray kill the actors naturally + class PoolWorkerBase(ABC): def get_location(self) -> str: