[train] refactor WorkerGroup state #50181
Conversation
self._latest_start_time = time_monotonic()
self._worker_group_state_builder.with_start_time(time_monotonic())
self._worker_group_state = self._worker_group_state_builder.build()
self._worker_group_state_builder = None
Why do we clear `self._worker_group_state_builder` here? Do we still need it to clear the state when shutting down the worker group?
That is a great question. It is a bit duplicative right now and I want to clean it up more. My thinking is that `WorkerGroup.shutdown` should only ever touch the `worker_group_state`, and the `worker_group_state_builder` creation/teardown logic should all be contained within `WorkerGroup.create`.
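A compact sketch of that intent, using trivial stand-in classes rather than the real Ray Train internals: the builder only exists inside `create()`, and `shutdown()` only touches the built state.

```python
from typing import Optional


class WorkerGroupState:
    def shutdown(self) -> None:
        pass  # tear down workers, placement group, sync actor, etc.


class WorkerGroupStateBuilder:
    def build(self) -> WorkerGroupState:
        return WorkerGroupState()


class WorkerGroup:
    def __init__(self):
        self._worker_group_state: Optional[WorkerGroupState] = None

    def create(self) -> None:
        # The builder is created, consumed, and discarded here; it is never
        # stored on self, so shutdown() has nothing builder-related to clear.
        self._worker_group_state = WorkerGroupStateBuilder().build()

    def shutdown(self) -> None:
        # Only the built state is touched on shutdown.
        if self._worker_group_state is not None:
            self._worker_group_state.shutdown()
            self._worker_group_state = None
```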
@@ -425,69 +348,52 @@ def shutdown(self, patience_s: float = 5.0):
with invoke_context_managers(
    [callback.on_worker_group_shutdown for callback in self._callbacks]
):
    if self._workers:
    if self._worker_group_state_builder:
Echoing the previous comments: why do we need to check `self._worker_group_state_builder` here?
worker_group_context = WorkerGroupContext(
    num_workers=num_workers,
    resources_per_worker=resources_per_worker,
)
self._worker_group_context = worker_group_context
Note: This isn't used right now.
node_id_to_workers = collections.defaultdict(list)
# Launch the training function on each worker.
# This task should start a worker thread and return immediately.
ray_get_safe([worker.actor.run_train_fn.remote(train_fn) for worker in workers])
Do we need to try/catch this?
By the way, `ray_get_safe` no longer seems to be needed; the original issue has been resolved (see the issue and PR). I think we can delete `ray_get_safe` from the Ray Train codebase and just use `ray.get()` for now.

For the original question: if any worker raises, we will get the error here. Wrapping this in try/except would mostly only catch a bug in `ray.get()` or `ray_get_safe` itself, so I think it is fine to not use try/except here.
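A hedged sketch of that suggestion: launch the train function with plain `ray.get()` and let worker-side errors propagate. The `TrainWorker` actor here is a stand-in, and the `worker.actor` indirection from the real code is dropped for brevity.

```python
import ray


@ray.remote
class TrainWorker:
    def run_train_fn(self, train_fn):
        # In Ray Train this would start a training thread and return
        # immediately; here we just invoke it so the example is runnable.
        return train_fn()


def launch_train_fn(workers, train_fn):
    # If any worker raises, ray.get() re-raises the error right here, so
    # wrapping this call in try/except is not strictly needed.
    return ray.get([w.run_train_fn.remote(train_fn) for w in workers])


if __name__ == "__main__":
    ray.init()
    workers = [TrainWorker.remote() for _ in range(2)]
    print(launch_train_fn(workers, lambda: "ok"))
    ray.shutdown()
```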
def __len__(self) -> int:
    return len(self._workers)

#####################################################################################
# Utility Methods
Not sure if these methods should exist at the existing `WorkerGroup` level or on `WorkerGroupState`.

Essentially wondering if we should split between an "inactive" and an "active" `WorkerGroup`, with specific methods for each. That can make it clearer to the caller whether it is handling the active or the inactive state, and avoid branching logic in the worker group layer to check whether it's active everywhere.

Something like:
- `WorkerGroup` → `InactiveWorkerGroup` / `WorkerGroupFactory`
  - `create()` → `ActiveWorkerGroup`
- `WorkerGroupState` → `ActiveWorkerGroup`
  - `poll()`
  - `shutdown()`
  - `__len__()`
  - ...
I am a little confused about why `WorkerGroup` maps to an `InactiveWorkerGroup` and `WorkerGroupState` maps to an `ActiveWorkerGroup`. If I understand correctly, `WorkerGroupState` contains the static information of a worker group. Could you explain more?
The idea is that the `Controller` would start with the `InactiveWorkerGroup` (maybe `WorkerGroupFactory` is a better name). Calling `start()` would return an `ActiveWorkerGroup`, which contains all the state that is now captured in `WorkerGroupState`, plus the additional methods that can be called on an active `WorkerGroup`, e.g. `poll()`, `shutdown()`, `execute()`, ...
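A rough sketch of this split; the class and method bodies are illustrative stand-ins, not the actual Ray Train API.

```python
from typing import List


class ActiveWorkerGroup:
    """Wraps the started workers plus what WorkerGroupState captures today."""

    def __init__(self, workers: List[str]):
        self._workers = workers

    def poll(self) -> List[str]:
        return self._workers

    def __len__(self) -> int:
        return len(self._workers)

    def shutdown(self) -> None:
        self._workers = []


class InactiveWorkerGroup:
    """Holds only the configuration; start() hands back an ActiveWorkerGroup."""

    def __init__(self, num_workers: int):
        self._num_workers = num_workers

    def start(self) -> ActiveWorkerGroup:
        # In the real code this would create actors, placement groups, etc.
        return ActiveWorkerGroup([f"worker-{i}" for i in range(self._num_workers)])


# The controller only ever calls poll()/shutdown() on the object returned by
# start(), so there is no "is it active?" branching inside the worker group.
group = InactiveWorkerGroup(num_workers=2).start()
assert len(group) == 2
```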
I think that makes sense. I originally tried doing something like this but couldn't figure out a good way. Then we wouldn't need to worry about calling inappropriate methods depending on the active status of the worker group.
@staticmethod
def _sort_workers_by_node_id_and_gpu_id(
Made these static because they're just generic utilities that don't need or modify any `WorkerGroup` state.
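For illustration, a sketch of what such a static utility could look like; the `Worker` fields used here (`node_id`, `gpu_ids`) are assumptions.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Worker:
    node_id: str
    gpu_ids: List[int] = field(default_factory=list)


class WorkerGroup:
    @staticmethod
    def _sort_workers_by_node_id_and_gpu_id(workers: List[Worker]) -> List[Worker]:
        # Pure function of its inputs: no WorkerGroup state is read or
        # mutated, which is why it can be a @staticmethod.
        return sorted(workers, key=lambda w: (w.node_id, min(w.gpu_ids, default=-1)))


workers = [Worker("node-b", [1]), Worker("node-a", [0]), Worker("node-a", [])]
ordered = WorkerGroup._sort_workers_by_node_id_and_gpu_id(workers)
assert [w.node_id for w in ordered] == ["node-a", "node-a", "node-b"]
```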
def get_workers(self) -> List[Worker]:
    return self._workers
# TODO: Access workers through WorkerGroupState instead?
I think either way works. I slightly prefer to get it from `WorkerGroupState`.
Yeah, that's the same feeling I had, which motivated #50181 (comment). Basically, just have the caller work directly with the "active" `WorkerGroup` rather than needing to support the inactive case within this method, or implementing logic in the caller to check whether it's active.
Nice! Overall LGTM, left a few comments.
Here's my initial high level pass.
def with_sync_actor(
    self, sync_actor: SynchronizationActor
) -> "WorkerGroupStateBuilder":
    self.sync_actor = sync_actor
    return self
I was thinking the `sync_actor` could be initialized and shut down by a `SyncActorCallback`. The sync actor is only used by the checkpoint module at the moment, so we could package it with that module rather than have it be a generic worker group concept. I can see how a sync actor would be helpful as a utility for users and future Ray Train features, though.

The question is: what state should be managed by the `WorkerGroup` vs. auxiliary callbacks?
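A rough sketch of the `SyncActorCallback` idea; the hook names and signatures here are assumptions (the actual callback hooks are invoked differently, e.g. as context managers in the shutdown path shown earlier), and `SynchronizationActor` is stubbed out.

```python
import ray


@ray.remote
class SynchronizationActor:
    """Trivial stand-in for Ray Train's internal synchronization actor."""

    def broadcast(self, value):
        return value


class SyncActorCallback:
    """Owns the sync actor's lifecycle so the worker group does not have to."""

    def __init__(self):
        self._sync_actor = None

    def on_worker_group_start(self, worker_group) -> None:
        # Create the sync actor alongside the worker group.
        self._sync_actor = SynchronizationActor.remote()

    def on_worker_group_shutdown(self, worker_group) -> None:
        # Tear it down with the worker group, keeping WorkerGroup itself
        # focused purely on its workers.
        if self._sync_actor is not None:
            ray.kill(self._sync_actor)
            self._sync_actor = None
```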
Good question. I had similar thoughts in the other direction for the Datasets sharding logic; right now I haven't seen any overwhelming evidence for one answer or the other. I do agree that it would be helpful as a generic utility.
I prefer the data callback and checkpointing logic to be separated from the controller and worker group as much as possible. I found the `backend_executor` pretty hard to work with previously, since everything was handled together (checkpoint logic, backend setup, Ray Data logic).
Ah agreed! We should definitely strive to make all of these as modular and pluggable as possible (which they aren't yet). I just wasn't sure if the callbacks are the right interface.
It seems the `sync_actor` is only used by the `ray.train.report` function. If it becomes a callback, I think the `report` function will need to subscribe to that callback and use its `broadcast` function inside `ray.train.report`. I'm not sure we should build an important API like `report` on top of a pluggable component; I feel a callback should be a self-contained module. We can discuss best practices offline.
def _shutdown_sync_actor(self, sync_actor: SynchronizationActor):
    ray.kill(sync_actor)
For example, I find this sync actor lifecycle management to be a bit out of place here. The worker group should focus on its workers, not some random utility actor.
That's a good point, I was thinking it was there to allow us to do some sort of WorkerGroup synchronization.
except Exception as e:
    if not self.has_started():
        worker_group_state_builder.shutdown()
    raise e
Why is this conditional on it not having started? Is it even possible for `has_started` to be True here?
My train of thought was that once `has_started` is True, there is a `WorkerGroupState` which should be acted on, and the builder should basically be discarded and not operated on anymore. Right now the code should not return True here. This could also be changed to `assert not self.has_started()`.
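A sketch of that assert variant, with simplified stand-ins for the real classes; note that a bare `raise` also preserves the original traceback.

```python
class WorkerGroupStateBuilder:
    """Trivial stand-in for the real builder."""

    def shutdown(self) -> None:
        pass


class WorkerGroup:
    def __init__(self, worker_group_state_builder: WorkerGroupStateBuilder):
        self._worker_group_state_builder = worker_group_state_builder
        self._worker_group_state = None

    def has_started(self) -> bool:
        return self._worker_group_state is not None

    def _start(self) -> None:
        ...  # placeholder for the real startup sequence (actors, placement group, ...)

    def start(self) -> None:
        try:
            self._start()
        except Exception:
            # Once has_started() is True there is a WorkerGroupState to act on
            # and the builder should not be touched anymore; the assert makes
            # that invariant explicit instead of branching on it.
            assert not self.has_started()
            self._worker_group_state_builder.shutdown()
            raise
```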
    List[Union[WorkerGroupCallback, WorkerCallback, TrainContextCallback]]
] = None,
placement_strategy: str = "PACK",
checkpoint: Optional[Checkpoint] = None,
TODO: In a follow-up PR, remove `checkpoint` from the worker group constructor, move it to `CheckpointManager`, and populate the worker context similar to how the dataset shards are passed to the workers.
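A purely illustrative sketch of that follow-up idea; the class names and the hook contract here are assumptions rather than the actual codebase API.

```python
from typing import List, Optional


class Checkpoint:
    """Trivial stand-in for ray.train.Checkpoint."""

    def __init__(self, path: str):
        self.path = path


class CheckpointManager:
    """Hypothetical owner of the latest checkpoint, decoupled from WorkerGroup."""

    def __init__(self):
        self._latest_checkpoint: Optional[Checkpoint] = None

    def register_checkpoint(self, checkpoint: Checkpoint) -> None:
        self._latest_checkpoint = checkpoint

    def before_init_train_context(self, workers: List[object]) -> dict:
        # Hand each worker the latest checkpoint when its train context is
        # initialized, analogous to how dataset shards are distributed.
        return {"checkpoint": [self._latest_checkpoint] * len(workers)}
```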
def before_init_train_context(
    self, worker_group: "WorkerGroup"
    self, workers: List["Worker"]
Why change this to workers?
This is because the `WorkerGroup` itself is still in the middle of its own creation stage. Calling `WorkerGroup.execute` would fail here because the `WorkerGroup` is not yet "active" at this step, so I wanted to constrain this callback hook to what is "ready", which is the workers.
Got it, this callback is the only one that runs into that problem.
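A sketch of the narrowed hook signature, with simplified stand-in types rather than the internal ones.

```python
from typing import List


class Worker:
    """Stand-in for the internal Worker handle."""


class WorkerGroupCallback:
    def before_init_train_context(self, workers: List[Worker]) -> None:
        # Only the already-created workers are exposed; the WorkerGroup itself
        # is still mid-creation, so methods like WorkerGroup.execute() are not
        # safe to call from this hook.
        for _worker in workers:
            pass
```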
num_workers=num_workers,
resources_per_worker=resources_per_worker,
placement_strategy=placement_strategy,
checkpoint=latest_checkpoint,
I am a bit concerned about putting `checkpoint` here. It seems all the other fields are static, i.e., they won't change within a worker group's lifetime. However, `latest_checkpoint` will be updated every time a new checkpoint is submitted. Maybe we can migrate this field to another component in a follow-up PR.
Nice. Added a few comments!
Nice!
@@ -283,16 +284,22 @@ def _restart_worker_group(
)
placement_strategy = self._scaling_policy.scaling_config.placement_strategy

worker_group_context = WorkerGroupContext(
Do you need to add run attempt ID as part of this context?
I opted not to put it in yet since we don't do anything with it yet. I will add it when it's used (probably in the upcoming state management PR).
@classmethod
def set_start_failure(cls, start_failure):
    cls._start_failure = start_failure
Curious why this was needed?
I needed to do this because the `WorkerGroup` instance is instantiated internally within the controller and changes over time. In `test_worker_group_start_failure` I had to inject this failure logic somehow, so I did it by setting this attribute and monkeypatching it.
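A sketch of how such a class-level hook can be used from a test; the test structure below is illustrative, not the actual test file.

```python
import pytest


class WorkerGroup:
    _start_failure = None

    @classmethod
    def set_start_failure(cls, start_failure):
        cls._start_failure = start_failure

    def start(self) -> None:
        # The controller normally creates WorkerGroup instances internally;
        # the class-level attribute lets a test reach every future instance.
        if self._start_failure is not None:
            raise self._start_failure


def test_worker_group_start_failure():
    WorkerGroup.set_start_failure(RuntimeError("injected start failure"))
    try:
        with pytest.raises(RuntimeError, match="injected start failure"):
            WorkerGroup().start()
    finally:
        # Reset the class-level state so other tests are unaffected.
        WorkerGroup.set_start_failure(None)
```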
#50181 updated the internal `WorkerGroup`, which impacted the `ScalingPolicy` and `FailurePolicy` input APIs. Note that these are all internal-facing developer APIs. This PR restores parity to the information available to the `ScalingPolicy` in the `make_decision_for_running_worker_group` method. In particular, this PR exposes the `WorkerGroupState`, which contains the latest worker group's `start_time`.
Overview
This PR refactors Ray Train's `WorkerGroup` state management to make state transitions (start, shutdown) atomic and to make the code more structured.

Key Changes
- Split worker group state into separate components:
  - `WorkerGroupContext`: Stores configuration used to start a worker group
  - `WorkerGroupState`: Stores runtime state of an active worker group
  - `WorkerGroupPollStatus`: Stores polling results from workers
- Introduced a builder pattern for worker group state management (see the sketch after this list):
  - `WorkerGroupStateBuilder` to handle incremental state construction during `WorkerGroup.start()`
- Moved worker status and polling logic to dedicated modules:
  - `state.py` module for state management classes
  - `poll.py` module for polling-related classes
- Renamed `WorkerGroupStatus` to `WorkerGroupPollStatus` to make it more explicit
  - `WorkerStatus` → `WorkerPollStatus` as well
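A minimal sketch of the builder pattern described above; the fields are a simplified subset of what the real `WorkerGroupState` carries, and `with_workers` is an assumed method name for illustration.

```python
from dataclasses import dataclass
from time import monotonic
from typing import List, Optional


@dataclass(frozen=True)
class WorkerGroupState:
    start_time: float
    workers: List[str]


class WorkerGroupStateBuilder:
    def __init__(self):
        self._start_time: Optional[float] = None
        self._workers: Optional[List[str]] = None

    def with_start_time(self, start_time: float) -> "WorkerGroupStateBuilder":
        self._start_time = start_time
        return self

    def with_workers(self, workers: List[str]) -> "WorkerGroupStateBuilder":
        self._workers = workers
        return self

    def build(self) -> WorkerGroupState:
        # The state only comes into existence once every piece is available,
        # which is what makes start() effectively atomic from the outside.
        assert self._start_time is not None and self._workers is not None
        return WorkerGroupState(start_time=self._start_time, workers=self._workers)


state = (
    WorkerGroupStateBuilder()
    .with_start_time(monotonic())
    .with_workers(["worker-0", "worker-1"])
    .build()
)
```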
Benefits
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.