
[train] refactor WorkerGroup state #50181

Merged: 22 commits into ray-project:master on Feb 7, 2025

Conversation

matthewdeng (Contributor)

Overview

This PR refactors Ray Train's WorkerGroup state management to make state transitions (start, shutdown) atomic and to make the code more structured.


Key Changes

  1. Split worker group state into separate components:

    • WorkerGroupContext: Stores configuration used to start a worker group
    • WorkerGroupState: Stores runtime state of an active worker group
    • WorkerGroupPollStatus: Stores polling results from workers
  2. Introduced a builder pattern for worker group state management:

    • Added WorkerGroupStateBuilder to handle incremental state construction during WorkerGroup.start() (see the sketch after this list)
    • Improved error handling during worker group startup
  3. Moved worker status and polling logic to dedicated modules:

    • Created new state.py module for state management classes
    • Created new poll.py module for polling-related classes
  4. Renamed WorkerGroupStatus to WorkerGroupPollStatus to make it more explicit

    • TODO: Do this for WorkerStatus as well → WorkerPollStatus
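
A minimal sketch of how the builder flow described above might look (names follow the PR description, but the bodies are simplified placeholders rather than the actual Ray Train implementation):

```python
# Illustrative sketch only: incremental state construction during WorkerGroup.start().
from time import monotonic as time_monotonic


class WorkerGroupState:
    """Runtime state of an active worker group (simplified)."""

    def __init__(self, start_time, workers):
        self.start_time = start_time
        self.workers = workers


class WorkerGroupStateBuilder:
    """Accumulates state incrementally while the worker group starts up."""

    def __init__(self):
        self.start_time = None
        self.workers = None

    def with_start_time(self, start_time):
        self.start_time = start_time
        return self

    def with_workers(self, workers):
        self.workers = workers
        return self

    def build(self) -> WorkerGroupState:
        # Only hand off a fully formed state object, making the transition atomic.
        assert self.start_time is not None and self.workers is not None
        return WorkerGroupState(self.start_time, self.workers)

    def shutdown(self):
        # Tear down anything partially constructed if startup fails midway.
        self.workers = None


# Roughly what WorkerGroup.start() does with the builder:
builder = WorkerGroupStateBuilder()
builder.with_start_time(time_monotonic()).with_workers(workers=[])
state = builder.build()  # the worker group is now "active"
```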

Benefits

  • Clearer separation of concerns between configuration, runtime state, and polling status
  • Cleaner state management during worker group lifecycle
  • Improved code organization and maintainability
  • Better error handling during worker group startup and shutdown


Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@hongpeng-guo self-assigned this on Feb 3, 2025
self._latest_start_time = time_monotonic()
self._worker_group_state_builder.with_start_time(time_monotonic())
self._worker_group_state = self._worker_group_state_builder.build()
self._worker_group_state_builder = None
Contributor:

Why do we clear this self._worker_group_state_builder here? Do we still need it to clear the state when shutting down the worker group?

Contributor Author (matthewdeng):

That is a great question. It is a bit duplicative right now and I want to clean it up more... I'm thinking that WorkerGroup.shutdown should only ever touch the worker_group_state, and that worker_group_state_builder creation/teardown logic should all be contained within WorkerGroup.create.

@@ -425,69 +348,52 @@ def shutdown(self, patience_s: float = 5.0):
with invoke_context_managers(
[callback.on_worker_group_shutdown for callback in self._callbacks]
):
if self._workers:
if self._worker_group_state_builder:
Contributor (hongpeng-guo), Feb 3, 2025:

Echoing the previous comments: why do we need to check self._worker_group_state_builder here?

Comment on lines 203 to 207
worker_group_context = WorkerGroupContext(
    num_workers=num_workers,
    resources_per_worker=resources_per_worker,
)
self._worker_group_context = worker_group_context
Contributor Author (matthewdeng):

Note: This isn't used right now.

node_id_to_workers = collections.defaultdict(list)
# Launch the training function on each worker.
# This task should start a worker thread and return immediately.
ray_get_safe([worker.actor.run_train_fn.remote(train_fn) for worker in workers])
Contributor Author (matthewdeng):

Do we need to try/catch this?

Contributor:

By the way, ray_get_safe doesn't seem to be needed anymore; the original issue has been resolved (see the linked issue and PR).

I think we can delete ray_get_safe from the Ray Train codebase and just use ray.get() for now.

For the original question: if any worker raises, the error will surface here. Wrapping this in try/except would mostly only catch bugs in ray.get() or ray_get_safe itself, so I think it's fine not to use try/except here.
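
For reference, a minimal sketch of the suggested simplification (assuming ray_get_safe really is no longer needed; the worker/actor attribute names come from the snippet above, everything else is illustrative):

```python
# Illustrative sketch: launch the training function on each worker and wait with
# plain ray.get(). If any worker task raises, ray.get() re-raises the error here.
import ray


def launch_train_fn(workers, train_fn):
    refs = [worker.actor.run_train_fn.remote(train_fn) for worker in workers]
    ray.get(refs)
```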

def __len__(self) -> int:
    return len(self._workers)

#####################################################################################
# Utility Methods
Contributor Author (matthewdeng), Feb 5, 2025:

Not sure whether these methods should live at the existing WorkerGroup level or on WorkerGroupState.

Essentially wondering if we should split between an "inactive" and "active" WorkerGroup, with specific methods for each. That would make it clearer to the caller whether it is handling the active or inactive state, and avoid branching logic in the worker group layer that checks whether it's active everywhere.

Something like:

  • WorkerGroup → InactiveWorkerGroup / WorkerGroupFactory
    • create() → ActiveWorkerGroup
  • WorkerGroupState → ActiveWorkerGroup
    • poll()
    • shutdown()
    • __len__()
    • ...

Contributor:

I am a little confused why WorkerGroup maps to an InactiveWorkerGroup and WorkerGroupState maps to an ActiveWorkerGroup. If I understand correctly, WorkerGroupState contains static information about a worker group. Could you explain more?

Contributor Author (matthewdeng):

The idea is that the Controller would start with the InactiveWorkerGroup (maybe WorkerGroupFactory is a better name). Calling start() would return an ActiveWorkerGroup which contains all the state that is now captured in WorkerGroupState, plus additional methods that can be called on an active WorkerGroup, e.g. poll(), shutdown(), execute(), ....
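
A rough sketch of this proposed split (class names come from the discussion above; the bodies are hypothetical placeholders, not what this PR implements):

```python
# Hypothetical sketch of the factory/active split discussed in this thread.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class WorkerGroupContext:
    num_workers: int
    resources_per_worker: Dict[str, float]


class WorkerGroupFactory:
    """Inactive side: holds only configuration; nothing is running yet."""

    def __init__(self, context: WorkerGroupContext):
        self._context = context

    def start(self) -> "ActiveWorkerGroup":
        workers: List[Any] = []  # placeholder for actually launching workers
        return ActiveWorkerGroup(workers)


class ActiveWorkerGroup:
    """Active side: owns runtime state and the methods that require live workers."""

    def __init__(self, workers: List[Any]):
        self._workers = workers

    def poll(self):
        ...

    def shutdown(self):
        self._workers = []

    def __len__(self) -> int:
        return len(self._workers)
```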

Contributor:

I think that makes sense. I originally tried doing something like this but couldn't figure out a good way.

That way, we don't need to worry about calling inappropriate methods depending on the active status of the worker group.

Comment on lines +639 to +640
@staticmethod
def _sort_workers_by_node_id_and_gpu_id(
Contributor Author (matthewdeng):

Made these static because they're just generic utilities that don't need/modify any WorkerGroup state


def get_workers(self) -> List[Worker]:
    return self._workers
    # TODO: Access workers through WorkerGroupState instead?
Contributor:

I think either way works; I slightly prefer to get it from WorkerGroupState.

Contributor Author (matthewdeng):

Yeah, that's the same feeling I had, which motivated #50181 (comment): basically have the caller work directly with the "active" WorkerGroup rather than needing to support the inactive case within this method, or implementing logic in the caller to check whether it's active.

Contributor (hongpeng-guo) left a comment:

Nice! Overall LGTM; left a few comments.

Contributor (justinvyu) left a comment:

Here's my initial high level pass.

Comment on lines +106 to +110
def with_sync_actor(
    self, sync_actor: SynchronizationActor
) -> "WorkerGroupStateBuilder":
    self.sync_actor = sync_actor
    return self
Contributor:

I was thinking the sync_actor could be moved to be initialized/shut down by a SyncActorCallback. The sync actor is only used by the checkpoint module at the moment, so we could package it with that module rather than have it be a generic worker group concept.

I can see how a sync actor would be helpful as a util for users and future Ray Train features though.

The question is, what state should be managed by the WorkerGroup vs auxiliary callbacks?
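
A rough illustration of the callback-owned lifecycle being floated here (on_worker_group_shutdown appears elsewhere in this PR's diff; on_worker_group_start and the actor body are hypothetical):

```python
# Hypothetical sketch: the sync actor's lifecycle owned by a callback instead of
# the worker group itself.
import ray


@ray.remote
class SynchronizationActor:
    """Placeholder stand-in for the real Ray Train synchronization actor."""

    def ping(self):
        return "ok"


class SyncActorCallback:
    def __init__(self):
        self._sync_actor = None

    def on_worker_group_start(self, worker_group):
        # Create the synchronization actor when the worker group becomes active.
        self._sync_actor = SynchronizationActor.remote()

    def on_worker_group_shutdown(self, worker_group):
        # Tear the actor down together with the worker group.
        if self._sync_actor is not None:
            ray.kill(self._sync_actor)
            self._sync_actor = None
```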

Contributor Author (matthewdeng):

Good question, I had similar thoughts in the other direction for the Datasets sharding logic... right now I haven't seen any overwhelming evidence for one answer or another.

I do agree that it would be helpful as a generic utility.

Contributor:

I prefer the data callback and checkpointing logic to be separated from the controller and worker group as much as possible. I found the backend_executor pretty hard to work with previously, since everything was being handled together (checkpoint logic, backend setup, ray data logic).

Contributor Author (matthewdeng):

Ah agreed! We should definitely strive to make all of these as modular and pluggable as possible (which they aren't yet). I just wasn't sure if the callbacks are the right interface.

Contributor (hongpeng-guo), Feb 7, 2025:

It seems the sync_actor is only used by the ray.train.report function. If it becomes a callback, I think the report function will need to subscribe to this callback and use its broadcast function inside ray.train.report. I'm not sure we should build an important API like report on top of a pluggable component; I feel a callback should be a self-contained module. We can discuss best practices offline.

Comment on lines 38 to 39
def _shutdown_sync_actor(self, sync_actor: SynchronizationActor):
    ray.kill(sync_actor)
Contributor:

For example, I find this sync actor lifecycle management to be a bit out of place here. Worker group should focus on its workers, not some random utility actor.

Contributor Author (matthewdeng):

That's a good point, I was thinking it was there to allow us to do some sort of WorkerGroup synchronization.

Comment on lines 182 to 185
except Exception as e:
    if not self.has_started():
        worker_group_state_builder.shutdown()
    raise e
Contributor:

Why is this conditional on whether it has started? Is it even possible for has_started to be True here?

Contributor Author (matthewdeng):

My train of thought for this was that once has_started is True, then there is a WorkerGroupState which should be acted on, and the Builder should basically be discarded and not operated on anymore.

Right now in the code it should not return True here. This could also be changed to assert not self.has_started().


List[Union[WorkerGroupCallback, WorkerCallback, TrainContextCallback]]
] = None,
placement_strategy: str = "PACK",
checkpoint: Optional[Checkpoint] = None,
Contributor:

TODO: In a followup PR, remove checkpoint from worker group constructor and move it to CheckpointManager, and populate worker context similar to how the dataset shards are passed to the workers.

python/ray/train/v2/tests/test_controller.py: outdated review comments (resolved)
Comment on lines 21 to +22
def before_init_train_context(
    self, worker_group: "WorkerGroup"
    self, workers: List["Worker"]
Contributor:

Why change this to workers?

Contributor Author (matthewdeng):

This is because the WorkerGroup itself is still in the middle of its own creation stage.

Calling WorkerGroup.execute would fail here because the WorkerGroup is not yet "active" at this step, so I wanted to constrain this callback hook to what is "ready", which is the workers.

Contributor:

Got it, this callback is the only one that runs into that problem.

num_workers=num_workers,
resources_per_worker=resources_per_worker,
placement_strategy=placement_strategy,
checkpoint=latest_checkpoint,
Contributor:

I am a bit concerned about putting checkpoint here. All the other fields are static, i.e., they won't change during a worker group's lifetime, but latest_checkpoint will be updated every time we submit a new checkpoint. Maybe we can migrate this field to another component in a follow-up PR.

Contributor (hongpeng-guo) left a comment:

Nice, added a few comments!

@matthewdeng marked this pull request as ready for review on February 7, 2025, 18:48
Contributor (justinvyu) left a comment:

Nice!

@@ -283,16 +284,22 @@ def _restart_worker_group(
)
placement_strategy = self._scaling_policy.scaling_config.placement_strategy

worker_group_context = WorkerGroupContext(
Contributor:

Do you need to add run attempt ID as part of this context?

Contributor Author (matthewdeng):

I opted not to put it in yet since we don't do anything with it yet. I will add it when it's used (probably in the upcoming state management PR).

Comment on lines +88 to +90
@classmethod
def set_start_failure(cls, start_failure):
    cls._start_failure = start_failure
Contributor:

curious why this was needed?

Contributor Author (matthewdeng):

I needed to do this because the WorkerGroup instance is instantiated internally within the controller and changes over time. In test_worker_group_start_failure I had to inject this failure logic somehow, so I did it by setting this class attribute and monkeypatching it.
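
A rough illustration of this kind of class-level failure injection (the WorkerGroup below is a toy stand-in, not the real class, and the test body is simplified from test_worker_group_start_failure):

```python
# Illustrative sketch of class-level failure injection with pytest's monkeypatch.
import pytest


class WorkerGroup:
    """Toy stand-in; only the failure-injection hook is shown."""

    _start_failure = None

    @classmethod
    def set_start_failure(cls, start_failure):
        cls._start_failure = start_failure

    def start(self):
        # The real start() would launch workers; here we only honor the hook.
        if self._start_failure:
            raise self._start_failure


def test_worker_group_start_failure(monkeypatch):
    # Inject the failure on the class, since the controller constructs
    # WorkerGroup instances internally and they change over time.
    monkeypatch.setattr(WorkerGroup, "_start_failure", RuntimeError("boom"))
    with pytest.raises(RuntimeError, match="boom"):
        WorkerGroup().start()
```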

python/ray/train/v2/tests/test_worker_group.py: outdated review comments (resolved)
@matthewdeng enabled auto-merge (squash) on February 7, 2025, 21:44
@github-actions bot added the "go" label (add ONLY when ready to merge, run all tests) on Feb 7, 2025
@matthewdeng merged commit b4279f3 into ray-project:master on Feb 7, 2025
7 checks passed
justinvyu added a commit that referenced this pull request Feb 10, 2025
#50181 updated the internal
`WorkerGroup`, which impacted the `ScalingPolicy` and `FailurePolicy`
input APIs. Note that these are all internal-facing developer APIs.

This PR restores parity to the information available to the
`ScalingPolicy` in the `make_decision_for_running_worker_group` method.
In particular, this PR exposes the `WorkerGroupState`, which contains
the latest worker group's `start_time`.

---------

Signed-off-by: Justin Yu <[email protected]>