Fix actor pool in python 3.11, add better scaling down logic #760

Merged · 2 commits · Oct 10, 2024

Changes from 1 commit:
fix ray on python 3.11, add scaling down
dlwh committed Oct 10, 2024
commit db0ab691f12c3b3b3303224c20b62a332582ea7e
3 changes: 3 additions & 0 deletions config/data/openwebtext_source.yaml
```diff
@@ -4,3 +4,6 @@ validation_urls:
   - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
 cache_dir: "gs://levanter-data/tokenized/openwebtext/"
 tokenizer: "gpt2"
+cache_options:
+  batch_size: 1024
+  num_shard_groups: 64
```
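The new `cache_options` block tunes how the tokenizer cache for this dataset gets built. As a rough sketch of what it might deserialize into (the field names come from the YAML; the class shape and defaults are illustrative guesses, not necessarily Levanter's actual definition):

```python
# Hypothetical sketch of a config object for the YAML block above; only the
# field names are taken from the YAML, everything else is an assumption.
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class CacheOptions:
    batch_size: int = 128                    # examples written to the cache per batch
    num_shard_groups: Optional[int] = None   # input shards bucketed into this many groups

# The YAML above would then correspond to:
opts = CacheOptions(batch_size=1024, num_shard_groups=64)
```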
48 changes: 37 additions & 11 deletions src/levanter/utils/actor_pool.py
```diff
@@ -15,6 +15,11 @@
 # https://github.com/ray-project/ray/blob/1bab09bf842edee51c3778be4cfb16f8b900d764/python/ray/data/_internal/execution/operators/actor_pool_map_operator.py


+def _wrap_ray_future(ray_future):
+    # work around https://github.com/ray-project/ray/issues/45895#issuecomment-2165164129
+    return asyncio.wrap_future(ray_future.future())
+
+
 class AutoScalingActorPool:
     """Utility class to operate on a dynamically scaling pool of actors."""
```

```diff
@@ -37,6 +42,7 @@ def __init__(
         self._actor_locations: Dict[ray.actor.ActorHandle, str] = {}
         self._tasks_waiting_for_actor: list[asyncio.Future] = []
         self._next_task_id = 0
+        self._scale_down_task: Optional[asyncio.Task] = None

         self._scale_up(self._min_size)
```

```diff
@@ -45,14 +51,17 @@ def num_pending_tasks(self):
         return len(self._tasks_waiting_for_actor)

     def _scale_up(self, num_actors: int):
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
+
         for _ in range(num_actors):
             try:
                 actor = self._create_actor_fn()
                 ready_ref = actor.get_location.remote()
                 self._pending_actors[ready_ref] = actor

                 async def wait_for_ready(actor, ready_ref):
-                    loc = await ready_ref
+                    loc = await _wrap_ray_future(ready_ref)
                     # pending -> floating
                     if ready_ref not in self._pending_actors:
                         logger.info("Actor was cancelled before it was ready.")
```
```diff
@@ -67,8 +76,8 @@ async def wait_for_ready(actor, ready_ref):
             except Exception as e:
                 logger.error("Failed to create actor.", exc_info=e)

-    def _scale_down(self, num_actors: int):
-        for _ in range(num_actors):
+    def _scale_down(self, target_num_actors: int):
+        while len(self._idle_actors) + len(self._pending_actors) > target_num_actors:
             if self._pending_actors:
                 actor = self._pending_actors.popitem()[1]
                 # let it die through gc
```
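`_scale_down` now takes a target pool size rather than a count of actors to remove, and drains until the non-busy population (idle plus still-pending actors) reaches that target, preferring actors that have not finished starting. The pattern in isolation (simplified, with hypothetical container names):

```python
# Sketch of the drain-to-target pattern; the names here are made up.
def drain_to_target(pending: dict, idle: list, target: int) -> None:
    """Release actors until idle + pending is at most `target`,
    dropping still-starting actors first."""
    while len(idle) + len(pending) > target:
        if pending:
            pending.popitem()  # a still-starting actor; GC reclaims it
        elif idle:
            idle.pop()         # an idle actor
        else:
            break              # only busy actors remain; nothing safe to drop
```

Draining to a target is also idempotent: a second call with the same target is a no-op, whereas a count-based version can over-shrink when triggered more than once for the same idle condition.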
```diff
@@ -102,10 +111,20 @@ def _adjust_pool_size(self):
                 f" {self._max_size}"
             )
             self._scale_up(min(self._max_size - num_busy_actors, num_pending_tasks))
+
+        # Schedule scale down if idle
         elif num_pending_tasks == 0 and num_nonworking_actors > self._min_size:
-            return  # never scale down. too many issues
-            logger.info(f"Scaling down due to no pending tasks. Current pool size: {total_actors}")
-            self._scale_down(num_nonworking_actors - self._min_size)
+            if self._scale_down_task is None or self._scale_down_task.done():
+                self._scale_down_task = asyncio.create_task(self._schedule_scale_down())
+
+    async def _schedule_scale_down(self):
+        try:
+            await asyncio.sleep(10)
+            if self.num_pending_tasks == 0:
+                logger.info("Scaling down due to no pending tasks.")
+                self._scale_down(self._min_size)
+        except asyncio.CancelledError:
+            logger.info("Scale down task was cancelled due to new activity.")

     def _get_object_location(self, obj_ref: ray.ObjectRef) -> Optional[str]:
         """Get the location of the given object reference."""
```
```diff
@@ -153,10 +172,11 @@ def _assign_task_to_actor(self, actor, fn, value):
         # floating -> busy
         ray_future = fn(actor, value)
         self._busy_actors[ray_future] = actor
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
         self._adjust_pool_size()

-        # return ray_future
-        return asyncio.ensure_future(self._wrap_ray_future(ray_future))
+        return asyncio.ensure_future(self._set_up_actor_return_on_finished(ray_future))

     async def _enqueue_pending_task(self, fn, obj_ref, value, actor_future):
         actor = await actor_future
```
```diff
@@ -181,10 +201,11 @@ def _maybe_start_pending_task(self, actor):
             assigned = False
         return assigned

-    async def _wrap_ray_future(self, ray_future):
-        await asyncio.wait([ray_future])
+    async def _set_up_actor_return_on_finished(self, ray_future):
+        future = _wrap_ray_future(ray_future)
+        await asyncio.wait([future])
         self._on_task_done(ray_future)
-        return await ray_future
+        return await future

     def _on_task_done(self, ray_future):
         actor = self._busy_actors.pop(ray_future)
```
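The renamed `_set_up_actor_return_on_finished` keeps the wait-then-await shape on purpose: `asyncio.wait` completes without raising even when the underlying task failed, so the `_on_task_done` bookkeeping (popping the actor out of `_busy_actors`) always runs before the caller sees the result or the exception. Reduced to its essentials, with a hypothetical `on_done` callback:

```python
# Why wait-then-await: bookkeeping must run even when the task failed.
import asyncio
from typing import Callable

async def tracked(fut: asyncio.Future, on_done: Callable[[asyncio.Future], None]):
    await asyncio.wait([fut])  # completes without raising, success or failure
    on_done(fut)               # bookkeeping always runs
    return await fut           # re-raises the task's exception, if any
```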
```diff
@@ -218,6 +239,11 @@ def push(self, actor: "ray.actor.ActorHandle"):
         self._actor_locations[actor] = location
         self._maybe_start_pending_task(actor)

+    def __del__(self):
+        if self._scale_down_task and not self._scale_down_task.done():
+            self._scale_down_task.cancel()
+        # just let ray kill the actors naturally


 class PoolWorkerBase(ABC):
     def get_location(self) -> str:
```