From 28d1ec8c5eb5103c524ad4db580d657733dbbd30 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Fri, 17 Jan 2025 13:58:38 -0800 Subject: [PATCH 1/8] [jobs] CANCELLING is not terminal This discrepancy caused issues, such as jobs getting stuck as CANCELLING when the job controller process crashes during cleanup. --- sky/jobs/state.py | 7 +++---- sky/jobs/utils.py | 9 +++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 420a3adbd82..9bb1db4528c 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -230,12 +230,12 @@ class ManagedJobStatus(enum.Enum): # RECOVERING: The cluster is preempted, and the controller process is # recovering the cluster (relaunching/failover). RECOVERING = 'RECOVERING' - # Terminal statuses - # SUCCEEDED: The job is finished successfully. - SUCCEEDED = 'SUCCEEDED' # CANCELLING: The job is requested to be cancelled by the user, and the # controller is cleaning up the cluster. CANCELLING = 'CANCELLING' + # Terminal statuses + # SUCCEEDED: The job is finished successfully. + SUCCEEDED = 'SUCCEEDED' # CANCELLED: The job is cancelled by the user. When the managed job is in # CANCELLED status, the cluster has been cleaned up. CANCELLED = 'CANCELLED' @@ -281,7 +281,6 @@ def terminal_statuses(cls) -> List['ManagedJobStatus']: cls.FAILED_PRECHECKS, cls.FAILED_NO_RESOURCE, cls.FAILED_CONTROLLER, - cls.CANCELLING, cls.CANCELLED, ] diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index dcedaa590a5..78e43ef9d24 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -694,10 +694,7 @@ def stream_logs(job_id: Optional[int], if job_status is None: with ux_utils.print_exception_no_traceback(): raise ValueError(f'Job {job_id} not found.') - # We shouldn't count CANCELLING as terminal here, the controller is - # still cleaning up. - if (job_status.is_terminal() and job_status != - managed_job_state.ManagedJobStatus.CANCELLING): + if job_status.is_terminal(): # Don't keep waiting. If the log file is not created by this # point, it never will be. This job may have been submitted # using an old version that did not create the log file, so this @@ -729,6 +726,10 @@ def stream_logs(job_id: Optional[int], print(end='', flush=True) # Check if the job if finished. + # TODO(cooperc): The controller can still be + # cleaning up if job is in a terminal status + # (e.g. SUCCEEDED). We want to follow those logs + # too. Use DONE instead? job_status = managed_job_state.get_status(job_id) assert job_status is not None, (job_id, job_name) if job_status.is_terminal(): From e9af1caab15ad114bfc31fdd3b93d5160d6794bb Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Tue, 21 Jan 2025 13:58:32 -0800 Subject: [PATCH 2/8] revamp nonterminal status checking --- sky/jobs/controller.py | 5 +- sky/jobs/state.py | 88 ++++++++++++++ sky/jobs/utils.py | 256 +++++++++++++++++++++++----------------- sky/skylet/constants.py | 2 +- sky/skylet/events.py | 2 +- 5 files changed, 243 insertions(+), 110 deletions(-) diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 553c3d946d3..3dc718e3c65 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -1,4 +1,7 @@ -"""Controller: handles the life cycle of a managed job.""" +"""Controller: handles the life cycle of a managed job. + +TODO(cooperc): Document lifecycle, and multiprocess layout. 
+""" import argparse import multiprocessing import os diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 9bb1db4528c..555bdf658b3 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -554,6 +554,53 @@ def set_failed( logger.info(failure_reason) +def set_failed_controller( + job_id: int, + failure_reason: str, + callback_func: Optional[CallbackType] = None, +): + """Set the status to FAILED_CONTROLLER for all tasks of the job. + + Unlike set_failed(), this will override any existing status, including + terminal ones. This should only be used when the controller process has died + abnormally. + + For jobs that already have an end_time set, we preserve that time instead of + overwriting it with the current time. + + Args: + job_id: The job id. + failure_reason: The failure reason. + """ + now = time.time() + fields_to_set: Dict[str, Any] = { + 'status': ManagedJobStatus.FAILED_CONTROLLER.value, + 'failure_reason': failure_reason, + } + with db_utils.safe_cursor(_DB_PATH) as cursor: + previous_status = cursor.execute( + 'SELECT status FROM spot WHERE spot_job_id=(?)', + (job_id,)).fetchone()[0] + previous_status = ManagedJobStatus(previous_status) + if previous_status == ManagedJobStatus.RECOVERING: + # If the job is recovering, we should set the last_recovered_at to + # the current time, so that the end_at - last_recovered_at will not + # affect the job duration calculation. + fields_to_set['last_recovered_at'] = now + + set_str = ', '.join(f'{k}=(?)' for k in fields_to_set) + + cursor.execute( + f"""\ + UPDATE spot SET + end_at = COALESCE(end_at, ?), + {set_str} + WHERE spot_job_id=(?)""", (now, *fields_to_set.values(), job_id)) + if callback_func: + callback_func('FAILED') + logger.info(failure_reason) + + def set_cancelling(job_id: int, callback_func: CallbackType): """Set tasks in the job as cancelling, if they are in non-terminal states. @@ -676,6 +723,47 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]: return jobs +def get_jobs_to_check(job_id: Optional[int] = None) -> List[int]: + """Get jobs that need controller process checking. + + Returns: + - For jobs with schedule state: jobs that have schedule state not DONE + - For legacy jobs (no schedule state): jobs that are in non-terminal status + + Args: + job_id: Optional job ID to check. If None, checks all jobs. + """ + job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)' + job_value = () if job_id is None else (job_id,) + + statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses())) + field_values = [ + status.value for status in ManagedJobStatus.terminal_statuses() + ] + + # Get jobs that are either: + # 1. Have schedule state that is not DONE, or + # 2. Have no schedule state (legacy) AND are in non-terminal status + with db_utils.safe_cursor(_DB_PATH) as cursor: + rows = cursor.execute( + f"""\ + SELECT DISTINCT spot.spot_job_id + FROM spot + LEFT OUTER JOIN job_info + ON spot.spot_job_id=job_info.spot_job_id + WHERE ( + (job_info.schedule_state IS NOT NULL AND + job_info.schedule_state IS NOT ?) 
+ OR + (job_info.schedule_state IS NULL AND status NOT IN ({statuses})) + ) + {job_filter} + ORDER BY spot.spot_job_id DESC""", + [ManagedJobScheduleState.DONE.value, *field_values, *job_value + ]).fetchall() + return [row[0] for row in rows if row[0] is not None] + + def get_all_job_ids_by_name(name: Optional[str]) -> List[int]: """Get all job ids by name.""" name_filter = '' diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 78e43ef9d24..2d4e8aba9dd 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -159,7 +159,7 @@ def _controller_process_alive(pid: int, job_id: int) -> bool: return False -def update_managed_job_status(job_id: Optional[int] = None): +def update_managed_jobs_statuses(job_id: Optional[int] = None): """Update managed job status if the controller process failed abnormally. Check the status of the controller process. If it is not running, it must @@ -168,125 +168,167 @@ def update_managed_job_status(job_id: Optional[int] = None): when above happens, which could be not accurate based on the frequency this function is called. - Note: we expect that job_id, if provided, refers to a nonterminal job. + Note: we expect that job_id, if provided, refers to a nonterminal job or a + job that has not completed its cleanup (schedule state not DONE). """ - if job_id is None: - # Warning: it's totally possible for the managed job to transition to - # a terminal status during the course of this function. The set_failed() - # called below will not update the state for jobs that already have a - # terminal status, so it should be fine. - job_ids = managed_job_state.get_nonterminal_job_ids_by_name(None) - else: - job_ids = [job_id] - for job_id_ in job_ids: + def _cleanup_job_clusters(job_id: int) -> Optional[str]: + """Clean up clusters for a job. Returns error message if any. - failure_reason = None - - tasks = managed_job_state.get_managed_jobs(job_id_) - schedule_state = tasks[0]['schedule_state'] - if schedule_state is None: - # Backwards compatibility: this job was submitted when ray was still - # used for managing the parallelism of job controllers. - # TODO(cooperc): Remove before 0.11.0. - controller_status = job_lib.get_status(job_id_) - if controller_status is None or controller_status.is_terminal(): - logger.error(f'Controller process for legacy job {job_id_} is ' - 'in an unexpected state.') - failure_reason = 'Legacy job is in an unexpected state' - - # Continue to mark the job as failed. - else: - # Still running. - continue - else: - pid = tasks[0]['controller_pid'] - if pid is None: - if schedule_state in ( - managed_job_state.ManagedJobScheduleState.INACTIVE, - managed_job_state.ManagedJobScheduleState.WAITING): - # Job has not been scheduled yet. - continue - elif (schedule_state == - managed_job_state.ManagedJobScheduleState.LAUNCHING): - # This should only be the case for a very short period of - # time between marking the job as submitted and writing the - # launched controller process pid back to the database (see - # scheduler.maybe_schedule_next_jobs). - # TODO(cooperc): Find a way to detect if we get stuck in - # this state. - logger.info(f'Job {job_id_} is in LAUNCHING state, ' - 'but controller process hasn\'t started yet.') - continue - # All other statuses are unexpected. Proceed to mark as failed. 
- logger.error(f'Expected to find a controller pid for state ' - f'{schedule_state.value} but found none.') - failure_reason = ('No controller pid set for ' - f'{schedule_state.value}') - else: - logger.debug(f'Checking controller pid {pid}') - if _controller_process_alive(pid, job_id_): - # The controller is still running. - continue - # Otherwise, proceed to mark the job as failed. - logger.error(f'Controller process for {job_id_} seems to be ' - 'dead.') - failure_reason = 'Controller process is dead' - - logger.error(f'Controller process for job {job_id_} has exited ' - 'abnormally. Setting the job status to FAILED_CONTROLLER.') + This function should not throw any exception. If it fails, it will + capture the error message, and log/return it. + """ + error_msg = None + tasks = managed_job_state.get_managed_jobs(job_id) for task in tasks: task_name = task['job_name'] - # Tear down the abnormal cluster to avoid resource leakage. - cluster_name = generate_managed_job_cluster_name(task_name, job_id_) + cluster_name = generate_managed_job_cluster_name(task_name, job_id) handle = global_user_state.get_handle_from_cluster_name( cluster_name) - # If the cluster exists, terminate it. if handle is not None: - terminate_cluster(cluster_name) + try: + terminate_cluster(cluster_name) + except Exception as e: # pylint: disable=broad-except + error_msg = (f'Failed to terminate cluster {cluster_name}: ' + f'{str(e)}') + logger.exception(error_msg, exc_info=e) + return error_msg + + # For backwards compatible jobs + # TODO(cooperc): Remove before 0.11.0. + def _handle_legacy_job(job_id: int): + controller_status = job_lib.get_status(job_id) + if controller_status is None or controller_status.is_terminal(): + logger.error(f'Controller process for legacy job {job_id} is ' + 'in an unexpected state.') + + cleanup_error = _cleanup_job_clusters(job_id) + if cleanup_error: + # Unconditionally set the job to failed_controller if the + # cleanup fails. + managed_job_state.set_failed_controller( + job_id, + failure_reason= + f'Legacy controller process for {job_id} exited ' + f'abnormally, and cleanup failed: {cleanup_error}. For ' + f'more details, run: sky jobs logs --controller {job_id}') + return + + # It's possible for the job to have transitioned to + # another terminal state while between when we checked its + # state and now. In that case, set_failed won't do + # anything, which is fine. + managed_job_state.set_failed( + job_id, + task_id=None, + failure_type=managed_job_state.ManagedJobStatus. + FAILED_CONTROLLER, + failure_reason= + 'Legacy controller process has exited abnormally. ' + f'For more details, run: sky jobs logs --controller {job_id}') + + # Get jobs that need checking (non-terminal or not DONE) + job_ids = managed_job_state.get_jobs_to_check(job_id) + if not job_ids: + # job_id is already terminal, or if job_id is None, there are no jobs + # that need to be checked. + return + + for job_id in job_ids: + tasks = managed_job_state.get_managed_jobs(job_id) + # Note: controller_pid and schedule_state are in the job_info table + # which is joined to the spot table, so all tasks with the same job_id + # will have the same value for these columns. This is what lets us just + # take tasks[0]['controller_pid'] and tasks[0]['schedule_state']. + schedule_state = tasks[0]['schedule_state'] + + # Backwards compatibility: this job was submitted when ray was still + # used for managing the parallelism of job controllers. + # TODO(cooperc): Remove before 0.11.0. 
+ if (schedule_state is None or schedule_state is + managed_job_state.ManagedJobScheduleState.INVALID): + _handle_legacy_job(job_id) + continue + + # For jobs with schedule state: + pid = tasks[0]['controller_pid'] + if pid is None: + if schedule_state in ( + managed_job_state.ManagedJobScheduleState.INACTIVE, + managed_job_state.ManagedJobScheduleState.WAITING): + # For these states, the controller hasn't been started yet. + # This is expected. + continue + + if (schedule_state == + managed_job_state.ManagedJobScheduleState.LAUNCHING): + # This is unlikely but technically possible. There's a brief + # period between marking job as scheduled (LAUNCHING) and + # actually launching the controller process and writing the pid + # back to the table. + # TODO(cooperc): Find a way to detect if we get stuck in this + # state. + logger.info(f'Job {job_id} is in {schedule_state.value} state, ' + 'but controller process hasn\'t started yet.') + continue + + logger.error(f'Expected to find a controller pid for state ' + f'{schedule_state.value} but found none.') + failure_reason = f'No controller pid set for {schedule_state.value}' + else: + logger.debug(f'Checking controller pid {pid}') + if _controller_process_alive(pid, job_id): + # The controller is still running, so this job is fine. + continue + + # Double check job is not already DONE before marking as failed, to + # avoid the race where the controller marked itself as DONE and + # exited between the state check and the pid check. Since the job + # controller process will mark itself DONE _before_ exiting, if it + # has exited and it's still not DONE now, it is abnormal. + if (managed_job_state.get_job_schedule_state(job_id) == + managed_job_state.ManagedJobScheduleState.DONE): + # Never mind, the job is DONE now. This is fine. + continue + + logger.error(f'Controller process for {job_id} seems to be dead.') + failure_reason = 'Controller process is dead' + + # At this point, either pid is None or process is dead. # The controller process for this managed job is not running: it must # have exited abnormally, and we should set the job status to # FAILED_CONTROLLER. - # The `set_failed` will only update the task's status if the - # status is non-terminal. - managed_job_state.set_failed( - job_id_, - task_id=None, - failure_type=managed_job_state.ManagedJobStatus.FAILED_CONTROLLER, + logger.error(f'Controller process for job {job_id} has exited ' + 'abnormally. Setting the job status to FAILED_CONTROLLER.') + + # Cleanup clusters and capture any errors. + cleanup_error = _cleanup_job_clusters(job_id) + cleanup_error_msg = '' + if cleanup_error: + cleanup_error_msg = f'Also, cleanup failed: {cleanup_error}. ' + + # Set all tasks to FAILED_CONTROLLER, regardless of current status. + # This may change a job from SUCCEEDED or another terminal state to + # FAILED_CONTROLLER. This is what we want - we are sure that this + # controller process crashed, so we want to capture that even if the + # underlying job succeeded. + # Note: 2+ invocations of update_managed_jobs_statuses could be running + # at the same time, so this could override the FAILED_CONTROLLER status + # set by another invocation of update_managed_jobs_statuses. That should + # be okay. The only difference could be that one process failed to clean + # up the cluster while the other succeeds. No matter which + # failure_reason ends up in the database, the outcome is acceptable. + # We assume that no other code path outside the controller process will + # update the job status. 
+ managed_job_state.set_failed_controller( + job_id, failure_reason= - f'Controller process has exited abnormally ({failure_reason}). For ' - f'more details, run: sky jobs logs --controller {job_id_}') - scheduler.job_done(job_id_, idempotent=True) - - # Some jobs may be in a terminal status, but are not yet DONE. For instance, - # they may be still cleaning up resources, etc. Such jobs won't be captured - # by the above check, which only looks at nonterminal jobs. So, check the - # controller liveness of all jobs that should have live controller - # processes. - for job_info in managed_job_state.get_schedule_live_jobs(job_id): - if not job_info['controller_pid']: - # Technically, a job with no controller process but in LAUNCHING - # schedule state can happen very briefly after the job is set to - # LAUNCHING but before the controller process is actually spawned. - # However, if we observe any state other than LAUNCHING, something - # is clearly wrong. - if (job_info['schedule_state'] != - managed_job_state.ManagedJobScheduleState.LAUNCHING): - logger.error( - f'Missing controller PID for {job_info["job_id"]}. ' - 'Setting to DONE.') - scheduler.job_done(job_info['job_id']) - else: - logger.info(f'LAUNCHING job {job_info["job_id"]} has no ' - 'controller process yet. Skipping.') + f'Controller process has exited abnormally ({failure_reason}). ' + f'{cleanup_error_msg}' + f'For more details, run: sky jobs logs --controller {job_id}') - elif not _controller_process_alive(job_info['controller_pid'], - job_info['job_id']): - logger.error( - f'Controller process for job {job_info["job_id"]} is not ' - 'alive. Marking the job as DONE.') - scheduler.job_done(job_info['job_id']) + scheduler.job_done(job_id, idempotent=True) def get_job_timestamp(backend: 'backends.CloudVmRayBackend', cluster_name: str, @@ -382,7 +424,7 @@ def cancel_jobs_by_id(job_ids: Optional[List[int]]) -> str: f'{job_status.value}. Skipped.') continue - update_managed_job_status(job_id) + update_managed_jobs_statuses(job_id) # Send the signal to the jobs controller. signal_file = pathlib.Path(SIGNAL_FILE_PREFIX.format(job_id)) diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index 91a1710ac02..881f86367b5 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -86,7 +86,7 @@ # cluster yaml is updated. # # TODO(zongheng,zhanghao): make the upgrading of skylet automatic? -SKYLET_VERSION = '10' +SKYLET_VERSION = '11' # The version of the lib files that skylet/jobs use. Whenever there is an API # change for the job_lib or log_lib, we need to bump this version, so that the # user can be notified to update their SkyPilot version on the remote cluster. 
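
For context on the checks above: update_managed_jobs_statuses relies on _controller_process_alive(pid, job_id), whose body is outside this diff. A rough sketch of what such a liveness check can look like, assuming psutil is available and that the controller process's command line carries the job id (both are assumptions, not taken from this patch):

    import psutil

    def _controller_process_alive(pid: int, job_id: int) -> bool:
        """Illustrative sketch only: is `pid` a live controller for `job_id`?"""
        try:
            proc = psutil.Process(pid)
            # Guard against PID reuse: the process must still look like the
            # controller for this particular job.
            cmdline = ' '.join(proc.cmdline())
            return proc.is_running() and f'--job-id {job_id}' in cmdline
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            return False

If a check like this reports the controller as dead, the logic above cleans up the job's clusters and overrides the job status to FAILED_CONTROLLER.
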
diff --git a/sky/skylet/events.py b/sky/skylet/events.py index ea7892ad654..e909a5e8f23 100644 --- a/sky/skylet/events.py +++ b/sky/skylet/events.py @@ -74,7 +74,7 @@ class ManagedJobEvent(SkyletEvent): EVENT_INTERVAL_SECONDS = 300 def _run(self): - managed_job_utils.update_managed_job_status() + managed_job_utils.update_managed_jobs_statuses() managed_job_scheduler.maybe_schedule_next_jobs() From dfb405d7bd99a359f3b643dea62cb1b3322b3287 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Tue, 21 Jan 2025 14:14:00 -0800 Subject: [PATCH 3/8] lint --- sky/jobs/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 2d4e8aba9dd..d6af513d909 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -223,9 +223,9 @@ def _handle_legacy_job(job_id: int): task_id=None, failure_type=managed_job_state.ManagedJobStatus. FAILED_CONTROLLER, - failure_reason= - 'Legacy controller process has exited abnormally. ' - f'For more details, run: sky jobs logs --controller {job_id}') + failure_reason=( + 'Legacy controller process has exited abnormally. For ' + f'more details, run: sky jobs logs --controller {job_id}')) # Get jobs that need checking (non-terminal or not DONE) job_ids = managed_job_state.get_jobs_to_check(job_id) From 6f5694161da3d015a0880bc42ce34d4188c17cce Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Wed, 22 Jan 2025 16:21:13 -0800 Subject: [PATCH 4/8] fix stream_logs_by_id --- sky/jobs/utils.py | 61 +++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index d6af513d909..be19d3231ee 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -466,36 +466,24 @@ def cancel_job_by_name(job_name: str) -> str: def stream_logs_by_id(job_id: int, follow: bool = True) -> str: """Stream logs by job id.""" - controller_status = job_lib.get_status(job_id) - status_msg = ux_utils.spinner_message( - 'Waiting for controller process to be RUNNING') + '{status_str}' - status_display = rich_utils.safe_status(status_msg.format(status_str='')) + + def should_keep_logging(status: managed_job_state.ManagedJobStatus) -> bool: + # If we see CANCELLING, just exit - we could miss some job logs but the + # job will be terminated momentarily anyway so we don't really care. 
+ return (not status.is_terminal() and + status is not managed_job_state.ManagedJobStatus.CANCELLING) + + msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str='', job_id=job_id) + status_display = rich_utils.safe_status(msg) num_tasks = managed_job_state.get_num_tasks(job_id) with status_display: - prev_msg = None - while (controller_status != job_lib.JobStatus.RUNNING and - (controller_status is None or - not controller_status.is_terminal())): - status_str = 'None' - if controller_status is not None: - status_str = controller_status.value - msg = status_msg.format(status_str=f' (status: {status_str})') - if msg != prev_msg: - status_display.update(msg) - prev_msg = msg - time.sleep(_LOG_STREAM_CHECK_CONTROLLER_GAP_SECONDS) - controller_status = job_lib.get_status(job_id) - - msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str='', job_id=job_id) - status_display.update(msg) prev_msg = msg - managed_job_status = managed_job_state.get_status(job_id) - while managed_job_status is None: + while (managed_job_status := + managed_job_state.get_status(job_id)) is None: time.sleep(1) - managed_job_status = managed_job_state.get_status(job_id) - if managed_job_status.is_terminal(): + if not should_keep_logging(managed_job_status): job_msg = '' if managed_job_status.is_failed(): job_msg = ('\nFailure reason: ' @@ -522,10 +510,12 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: task_id, managed_job_status = ( managed_job_state.get_latest_task_id_status(job_id)) - # task_id and managed_job_status can be None if the controller process - # just started and the managed job status has not set to PENDING yet. - while (managed_job_status is None or - not managed_job_status.is_terminal()): + # We wait for managed_job_status to be not None above. Once we see that + # it's not None, we don't expect it to every become None again. + assert managed_job_status is not None, (job_id, task_id, + managed_job_status) + + while should_keep_logging(managed_job_status): handle = None if task_id is not None: task_name = managed_job_state.get_task_name(job_id, task_id) @@ -555,8 +545,11 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) task_id, managed_job_status = ( managed_job_state.get_latest_task_id_status(job_id)) + assert managed_job_status is not None, (job_id, task_id, + managed_job_status) continue - assert managed_job_status is not None + assert (managed_job_status is + managed_job_state.ManagedJobStatus.RUNNING) assert isinstance(handle, backends.CloudVmRayResourceHandle), handle status_display.stop() returncode = backend.tail_logs(handle, @@ -610,6 +603,8 @@ def is_managed_job_status_updated( managed_job_status := managed_job_state.get_status(job_id)): time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) + assert managed_job_status is not None, ( + job_id, managed_job_status) continue if task_id == num_tasks - 1: @@ -635,6 +630,8 @@ def is_managed_job_status_updated( if original_task_id != task_id: break time.sleep(JOB_STATUS_CHECK_GAP_SECONDS) + assert managed_job_status is not None, (job_id, task_id, + managed_job_status) continue # The job can be cancelled by the user or the controller (when @@ -650,7 +647,7 @@ def is_managed_job_status_updated( # state. 
managed_job_status = managed_job_state.get_status(job_id) assert managed_job_status is not None, job_id - if managed_job_status.is_terminal(): + if not should_keep_logging(managed_job_status): break logger.info(f'{colorama.Fore.YELLOW}The job cluster is preempted ' f'or failed.{colorama.Style.RESET_ALL}') @@ -665,6 +662,8 @@ def is_managed_job_status_updated( # managed job state is updated. time.sleep(3 * JOB_STATUS_CHECK_GAP_SECONDS) managed_job_status = managed_job_state.get_status(job_id) + assert managed_job_status is not None, (job_id, managed_job_status) + should_keep_logging(managed_job_status) # The managed_job_status may not be in terminal status yet, since the # controller has not updated the managed job state yet. We wait for a while, @@ -672,7 +671,7 @@ def is_managed_job_status_updated( wait_seconds = 0 managed_job_status = managed_job_state.get_status(job_id) assert managed_job_status is not None, job_id - while (not managed_job_status.is_terminal() and follow and + while (should_keep_logging(managed_job_status) and follow and wait_seconds < _FINAL_JOB_STATUS_WAIT_TIMEOUT_SECONDS): time.sleep(1) wait_seconds += 1 From 06864e20339cc9a537e87d6914a774fadb39b6e5 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Wed, 22 Jan 2025 18:09:21 -0800 Subject: [PATCH 5/8] remove set_failed_controller --- sky/jobs/state.py | 88 +++++++++++++++++------------------------------ sky/jobs/utils.py | 15 +++++--- 2 files changed, 42 insertions(+), 61 deletions(-) diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 555bdf658b3..cc363320480 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -511,8 +511,12 @@ def set_failed( failure_reason: str, callback_func: Optional[CallbackType] = None, end_time: Optional[float] = None, + override_terminal: bool = False, ): - """Set an entire job or task to failed, if they are in non-terminal states. + """Set an entire job or task to failed. + + By default, don't override tasks that are already terminal (that is, for + which end_at is already set). Args: job_id: The job id. @@ -521,12 +525,13 @@ def set_failed( failure_type: The failure type. One of ManagedJobStatus.FAILED_*. failure_reason: The failure reason. end_time: The end time. If None, the current time will be used. + override_terminal: If True, override the current status even if end_at + is already set. """ assert failure_type.is_failed(), failure_type end_time = time.time() if end_time is None else end_time - fields_to_set = { - 'end_at': end_time, + fields_to_set: Dict[str, Any] = { 'status': failure_type.value, 'failure_reason': failure_reason, } @@ -541,61 +546,30 @@ def set_failed( # affect the job duration calculation. fields_to_set['last_recovered_at'] = end_time set_str = ', '.join(f'{k}=(?)' for k in fields_to_set) - task_str = '' if task_id is None else f' AND task_id={task_id}' - - cursor.execute( - f"""\ - UPDATE spot SET - {set_str} - WHERE spot_job_id=(?){task_str} AND end_at IS null""", - (*list(fields_to_set.values()), job_id)) - if callback_func: - callback_func('FAILED') - logger.info(failure_reason) - - -def set_failed_controller( - job_id: int, - failure_reason: str, - callback_func: Optional[CallbackType] = None, -): - """Set the status to FAILED_CONTROLLER for all tasks of the job. - - Unlike set_failed(), this will override any existing status, including - terminal ones. This should only be used when the controller process has died - abnormally. 
- - For jobs that already have an end_time set, we preserve that time instead of - overwriting it with the current time. - - Args: - job_id: The job id. - failure_reason: The failure reason. - """ - now = time.time() - fields_to_set: Dict[str, Any] = { - 'status': ManagedJobStatus.FAILED_CONTROLLER.value, - 'failure_reason': failure_reason, - } - with db_utils.safe_cursor(_DB_PATH) as cursor: - previous_status = cursor.execute( - 'SELECT status FROM spot WHERE spot_job_id=(?)', - (job_id,)).fetchone()[0] - previous_status = ManagedJobStatus(previous_status) - if previous_status == ManagedJobStatus.RECOVERING: - # If the job is recovering, we should set the last_recovered_at to - # the current time, so that the end_at - last_recovered_at will not - # affect the job duration calculation. - fields_to_set['last_recovered_at'] = now - - set_str = ', '.join(f'{k}=(?)' for k in fields_to_set) + task_query_str = '' if task_id is None else 'AND task_id=(?)' + task_value = [] if task_id is None else [ + task_id, + ] - cursor.execute( - f"""\ - UPDATE spot SET - end_at = COALESCE(end_at, ?), - {set_str} - WHERE spot_job_id=(?)""", (now, *fields_to_set.values(), job_id)) + if override_terminal: + # Use COALESCE for end_at to avoid overriding the existing end_at if + # it's already set. + cursor.execute( + f"""\ + UPDATE spot SET + end_at = COALESCE(end_at, ?), + {set_str} + WHERE spot_job_id=(?) {task_query_str}""", + (end_time, *list(fields_to_set.values()), job_id, *task_value)) + else: + # Only set if end_at is null. + cursor.execute( + f"""\ + UPDATE spot SET + end_at = (?), + {set_str} + WHERE spot_job_id=(?) {task_query_str} AND end_at IS null""", + (end_time, *list(fields_to_set.values()), job_id, *task_value)) if callback_func: callback_func('FAILED') logger.info(failure_reason) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index be19d3231ee..04ab70534e3 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -206,12 +206,16 @@ def _handle_legacy_job(job_id: int): if cleanup_error: # Unconditionally set the job to failed_controller if the # cleanup fails. - managed_job_state.set_failed_controller( + managed_job_state.set_failed( job_id, + task_id=None, + failure_type=managed_job_state.ManagedJobStatus. + FAILED_CONTROLLER, failure_reason= f'Legacy controller process for {job_id} exited ' f'abnormally, and cleanup failed: {cleanup_error}. For ' - f'more details, run: sky jobs logs --controller {job_id}') + f'more details, run: sky jobs logs --controller {job_id}', + override_terminal=True) return # It's possible for the job to have transitioned to @@ -321,12 +325,15 @@ def _handle_legacy_job(job_id: int): # failure_reason ends up in the database, the outcome is acceptable. # We assume that no other code path outside the controller process will # update the job status. - managed_job_state.set_failed_controller( + managed_job_state.set_failed( job_id, + task_id=None, + failure_type=managed_job_state.ManagedJobStatus.FAILED_CONTROLLER, failure_reason= f'Controller process has exited abnormally ({failure_reason}). 
' f'{cleanup_error_msg}' - f'For more details, run: sky jobs logs --controller {job_id}') + f'For more details, run: sky jobs logs --controller {job_id}', + override_terminal=True) scheduler.job_done(job_id, idempotent=True) From 838d6a7b29db5b4e964189ca4f88c8644fa41f4c Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Wed, 22 Jan 2025 18:15:29 -0800 Subject: [PATCH 6/8] address PR comments --- sky/jobs/state.py | 15 +++++++++------ sky/jobs/utils.py | 10 +++++----- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sky/jobs/state.py b/sky/jobs/state.py index cc363320480..f38aa941325 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -697,7 +697,7 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]: return jobs -def get_jobs_to_check(job_id: Optional[int] = None) -> List[int]: +def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]: """Get jobs that need controller process checking. Returns: @@ -710,8 +710,9 @@ def get_jobs_to_check(job_id: Optional[int] = None) -> List[int]: job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)' job_value = () if job_id is None else (job_id,) - statuses = ', '.join(['?'] * len(ManagedJobStatus.terminal_statuses())) - field_values = [ + status_filter_str = ', '.join(['?'] * + len(ManagedJobStatus.terminal_statuses())) + terminal_status_values = [ status.value for status in ManagedJobStatus.terminal_statuses() ] @@ -729,11 +730,13 @@ def get_jobs_to_check(job_id: Optional[int] = None) -> List[int]: (job_info.schedule_state IS NOT NULL AND job_info.schedule_state IS NOT ?) OR - (job_info.schedule_state IS NULL AND status NOT IN ({statuses})) + (job_info.schedule_state IS NULL AND + status NOT IN ({status_filter_str})) ) {job_filter} - ORDER BY spot.spot_job_id DESC""", - [ManagedJobScheduleState.DONE.value, *field_values, *job_value + ORDER BY spot.spot_job_id DESC""", [ + ManagedJobScheduleState.DONE.value, *terminal_status_values, + *job_value ]).fetchall() return [row[0] for row in rows if row[0] is not None] diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 04ab70534e3..38cad31ac3b 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -212,9 +212,9 @@ def _handle_legacy_job(job_id: int): failure_type=managed_job_state.ManagedJobStatus. FAILED_CONTROLLER, failure_reason= - f'Legacy controller process for {job_id} exited ' - f'abnormally, and cleanup failed: {cleanup_error}. For ' - f'more details, run: sky jobs logs --controller {job_id}', + 'Legacy controller process has exited abnormally, and ' + f'cleanup failed: {cleanup_error}. For more details, run: ' + f'sky jobs logs --controller {job_id}', override_terminal=True) return @@ -232,7 +232,7 @@ def _handle_legacy_job(job_id: int): f'more details, run: sky jobs logs --controller {job_id}')) # Get jobs that need checking (non-terminal or not DONE) - job_ids = managed_job_state.get_jobs_to_check(job_id) + job_ids = managed_job_state.get_jobs_to_check_status(job_id) if not job_ids: # job_id is already terminal, or if job_id is None, there are no jobs # that need to be checked. @@ -249,7 +249,7 @@ def _handle_legacy_job(job_id: int): # Backwards compatibility: this job was submitted when ray was still # used for managing the parallelism of job controllers. # TODO(cooperc): Remove before 0.11.0. 
- if (schedule_state is None or schedule_state is + if (schedule_state is managed_job_state.ManagedJobScheduleState.INVALID): _handle_legacy_job(job_id) continue From 9a8e4c8466f9e9fa84bf8439cf34a17436ff29d1 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Thu, 23 Jan 2025 13:08:24 -0800 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Zhanghao Wu --- sky/jobs/state.py | 2 +- sky/jobs/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sky/jobs/state.py b/sky/jobs/state.py index f38aa941325..cdb34499aa2 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -562,7 +562,7 @@ def set_failed( WHERE spot_job_id=(?) {task_query_str}""", (end_time, *list(fields_to_set.values()), job_id, *task_value)) else: - # Only set if end_at is null. + # Only set if end_at is null, i.e. the previous state is not terminal. cursor.execute( f"""\ UPDATE spot SET diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 38cad31ac3b..87de56622ca 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -190,7 +190,7 @@ def _cleanup_job_clusters(job_id: int) -> Optional[str]: terminate_cluster(cluster_name) except Exception as e: # pylint: disable=broad-except error_msg = (f'Failed to terminate cluster {cluster_name}: ' - f'{str(e)}') + f'{common_utils.format_exception(e, use_bracket=True)}') logger.exception(error_msg, exc_info=e) return error_msg From 5b8a0a4b3e0ff2a9531ffb52b41a6f27aa269dc7 Mon Sep 17 00:00:00 2001 From: Christopher Cooper Date: Thu, 23 Jan 2025 13:14:49 -0800 Subject: [PATCH 8/8] address PR review --- sky/jobs/state.py | 11 ++++++----- sky/jobs/utils.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/sky/jobs/state.py b/sky/jobs/state.py index cdb34499aa2..07f62c3c82d 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -562,7 +562,8 @@ def set_failed( WHERE spot_job_id=(?) {task_query_str}""", (end_time, *list(fields_to_set.values()), job_id, *task_value)) else: - # Only set if end_at is null, i.e. the previous state is not terminal. + # Only set if end_at is null, i.e. the previous status is not + # terminal. cursor.execute( f"""\ UPDATE spot SET @@ -700,12 +701,12 @@ def get_schedule_live_jobs(job_id: Optional[int]) -> List[Dict[str, Any]]: def get_jobs_to_check_status(job_id: Optional[int] = None) -> List[int]: """Get jobs that need controller process checking. - Returns: - - For jobs with schedule state: jobs that have schedule state not DONE - - For legacy jobs (no schedule state): jobs that are in non-terminal status - Args: job_id: Optional job ID to check. If None, checks all jobs. 
+ + Returns a list of job_ids, including the following: + - For jobs with schedule state: jobs that have schedule state not DONE + - For legacy jobs (no schedule state): jobs that are in non-terminal status """ job_filter = '' if job_id is None else 'AND spot.spot_job_id=(?)' job_value = () if job_id is None else (job_id,) diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 87de56622ca..4cd4df12de3 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -189,8 +189,9 @@ def _cleanup_job_clusters(job_id: int) -> Optional[str]: try: terminate_cluster(cluster_name) except Exception as e: # pylint: disable=broad-except - error_msg = (f'Failed to terminate cluster {cluster_name}: ' - f'{common_utils.format_exception(e, use_bracket=True)}') + error_msg = ( + f'Failed to terminate cluster {cluster_name}: ' + f'{common_utils.format_exception(e, use_bracket=True)}') logger.exception(error_msg, exc_info=e) return error_msg @@ -478,7 +479,7 @@ def should_keep_logging(status: managed_job_state.ManagedJobStatus) -> bool: # If we see CANCELLING, just exit - we could miss some job logs but the # job will be terminated momentarily anyway so we don't really care. return (not status.is_terminal() and - status is not managed_job_state.ManagedJobStatus.CANCELLING) + status != managed_job_state.ManagedJobStatus.CANCELLING) msg = _JOB_WAITING_STATUS_MESSAGE.format(status_str='', job_id=job_id) status_display = rich_utils.safe_status(msg) @@ -555,7 +556,7 @@ def should_keep_logging(status: managed_job_state.ManagedJobStatus) -> bool: assert managed_job_status is not None, (job_id, task_id, managed_job_status) continue - assert (managed_job_status is + assert (managed_job_status == managed_job_state.ManagedJobStatus.RUNNING) assert isinstance(handle, backends.CloudVmRayResourceHandle), handle status_display.stop() @@ -670,7 +671,6 @@ def is_managed_job_status_updated( time.sleep(3 * JOB_STATUS_CHECK_GAP_SECONDS) managed_job_status = managed_job_state.get_status(job_id) assert managed_job_status is not None, (job_id, managed_job_status) - should_keep_logging(managed_job_status) # The managed_job_status may not be in terminal status yet, since the # controller has not updated the managed job state yet. We wait for a while,
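
To make the selection rule in get_jobs_to_check_status concrete, here is a small self-contained illustration of the same LEFT OUTER JOIN predicate against a toy schema (table layout and state names are simplified stand-ins, not the real spot.db schema):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE spot (spot_job_id INT, status TEXT)')
    conn.execute('CREATE TABLE job_info (spot_job_id INT, schedule_state TEXT)')
    rows = [
        (1, 'DONE', 'SUCCEEDED'),          # finished and cleaned up -> skip
        (2, 'ALIVE', 'RUNNING'),           # schedule state not DONE -> check
        (3, 'DONE', 'FAILED_CONTROLLER'),  # already handled -> skip
        (4, None, 'RUNNING'),              # legacy, non-terminal -> check
        (5, None, 'SUCCEEDED'),            # legacy, terminal -> skip
    ]
    for job_id, state, status in rows:
        conn.execute('INSERT INTO spot VALUES (?, ?)', (job_id, status))
        conn.execute('INSERT INTO job_info VALUES (?, ?)', (job_id, state))
    terminal = ('SUCCEEDED', 'FAILED_CONTROLLER', 'CANCELLED')  # abbreviated
    print(conn.execute(
        """SELECT DISTINCT spot.spot_job_id
           FROM spot LEFT OUTER JOIN job_info
             ON spot.spot_job_id = job_info.spot_job_id
           WHERE ((job_info.schedule_state IS NOT NULL AND
                   job_info.schedule_state IS NOT ?)
                  OR (job_info.schedule_state IS NULL AND
                      status NOT IN (?, ?, ?)))
           ORDER BY spot.spot_job_id DESC""",
        ('DONE', *terminal)).fetchall())
    # Prints [(4,), (2,)]: the legacy non-terminal job and the not-yet-DONE job.
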
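Similarly, the override_terminal path in set_failed uses SQLite's COALESCE so that a forced FAILED_CONTROLLER overwrites the status while keeping an end_at that was already recorded. A tiny standalone demonstration of that UPDATE pattern (toy table, not the real spot schema):

    import sqlite3

    conn = sqlite3.connect(':memory:')
    conn.execute('CREATE TABLE spot (spot_job_id INT, status TEXT, end_at REAL)')
    # A job that already reached a terminal status with a recorded end time.
    conn.execute("INSERT INTO spot VALUES (1, 'SUCCEEDED', 1000.0)")
    # override_terminal=True: force the status, but COALESCE keeps the old end_at.
    conn.execute(
        'UPDATE spot SET end_at = COALESCE(end_at, ?), status = (?) '
        'WHERE spot_job_id = (?)', (2000.0, 'FAILED_CONTROLLER', 1))
    print(conn.execute('SELECT status, end_at FROM spot').fetchone())
    # Prints ('FAILED_CONTROLLER', 1000.0): status overridden, end time preserved.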