From 1578108f8706e45b2b06527429a54b658106c6df Mon Sep 17 00:00:00 2001
From: Kaiyuan Eric Chen
Date: Thu, 9 Jan 2025 11:12:30 -0800
Subject: [PATCH] [Jobs] --sync-down execution log support for managed jobs
 (#4527)

* add support for managed jobs
* support --name to download logs
* format
* fix job_id
* add support for downloading controller logs separately from job logs
* remove the interface for getting complete job_ids, parse table instead
* code formatting
* documentation for sync down logs
* resolve job id to run timestamp via new interface
* make the logs work
* revert managed job yaml
* code formatting and cleanup
* Fix formatting in managed_job.yaml by removing unnecessary whitespace
* Remove unnecessary whitespace in managed_job.yaml for improved formatting
* add sync down smoke tests
* Enhance smoke tests for managed jobs by adding a command to wait for job
  status. This update improves the reliability of the test by ensuring the
  job is in a RUNNING state before proceeding with log synchronization.
* refactor sync down logic to remove the dependency on getting job status
* formatting
* fix rsync bug due to refactor
* clean up printing
* Update sky/jobs/core.py

Co-authored-by: Christopher Cooper

* Refactor job ID retrieval and update CLI option flag

- Changed the method name from `get_job_ids_by_name` to
  `get_all_job_ids_by_name` for clarity in `ManagedJobCodeGen`.
- Updated the CLI option flag for `sync-down` from `-d` to `-s` for better
  alignment with common conventions.

* format & linting

---------

Co-authored-by: Christopher Cooper
---
 sky/backends/cloud_vm_ray_backend.py  | 146 ++++++++++++++++++++++++++
 sky/cli.py                            |  26 +++--
 sky/jobs/__init__.py                  |   2 +
 sky/jobs/core.py                      |  46 ++++++++
 sky/jobs/state.py                     |  27 +++++
 sky/jobs/utils.py                     |   9 ++
 tests/smoke_tests/test_managed_job.py |  23 ++++
 7 files changed, 272 insertions(+), 7 deletions(-)

diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py
index 2316888b44c..128c2acafe5 100644
--- a/sky/backends/cloud_vm_ray_backend.py
+++ b/sky/backends/cloud_vm_ray_backend.py
@@ -3891,6 +3891,152 @@ def tail_managed_job_logs(self,
             stdin=subprocess.DEVNULL,
         )
 
+    def sync_down_managed_job_logs(
+            self,
+            handle: CloudVmRayResourceHandle,
+            job_id: Optional[int] = None,
+            job_name: Optional[str] = None,
+            controller: bool = False,
+            local_dir: str = constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]:
+        """Sync down logs for a managed job.
+
+        Args:
+            handle: The handle to the cluster.
+            job_id: The job ID to sync down logs for.
+            job_name: The job name to sync down logs for.
+            controller: Whether to sync down logs for the controller.
+            local_dir: The local directory to sync down logs to.
+
+        Returns:
+            A dictionary mapping job_id to log path.
+        """
+        # if job_name is not None, job_id should be None
+        assert job_name is None or job_id is None, (job_name, job_id)
+        if job_id is None and job_name is not None:
+            # generate code to get the job_id
+            code = managed_jobs.ManagedJobCodeGen.get_all_job_ids_by_name(
+                job_name=job_name)
+            returncode, run_timestamps, stderr = self.run_on_head(
+                handle,
+                code,
+                stream_logs=False,
+                require_outputs=True,
+                separate_stderr=True)
+            subprocess_utils.handle_returncode(returncode, code,
+                                               'Failed to sync down logs.',
+                                               stderr)
+            job_ids = common_utils.decode_payload(run_timestamps)
+            if not job_ids:
+                logger.info(f'{colorama.Fore.YELLOW}'
+                            'No matching job found'
+                            f'{colorama.Style.RESET_ALL}')
+                return {}
+            elif len(job_ids) > 1:
+                logger.info(
+                    f'{colorama.Fore.YELLOW}'
+                    f'Multiple job IDs found under the name {job_name}. '
+                    'Downloading the latest job logs.'
+                    f'{colorama.Style.RESET_ALL}')
+                job_ids = [job_ids[0]]  # descending order
+        else:
+            job_ids = [job_id]
+
+        # get the run_timestamp
+        # the function takes in [job_id]
+        code = job_lib.JobLibCodeGen.get_run_timestamp_with_globbing(job_ids)
+        returncode, run_timestamps, stderr = self.run_on_head(
+            handle,
+            code,
+            stream_logs=False,
+            require_outputs=True,
+            separate_stderr=True)
+        subprocess_utils.handle_returncode(returncode, code,
+                                           'Failed to sync logs.', stderr)
+        # returns with a dict of {job_id: run_timestamp}
+        run_timestamps = common_utils.decode_payload(run_timestamps)
+        if not run_timestamps:
+            logger.info(f'{colorama.Fore.YELLOW}'
+                        'No matching log directories found'
+                        f'{colorama.Style.RESET_ALL}')
+            return {}
+
+        run_timestamp = list(run_timestamps.values())[0]
+        job_id = list(run_timestamps.keys())[0]
+        local_log_dir = ''
+        if controller:  # download controller logs
+            remote_log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
+                                          run_timestamp)
+            local_log_dir = os.path.expanduser(
+                os.path.join(local_dir, run_timestamp))
+
+            logger.info(f'{colorama.Fore.CYAN}'
+                        f'Job {job_ids} local logs: {local_log_dir}'
+                        f'{colorama.Style.RESET_ALL}')
+
+            runners = handle.get_command_runners()
+
+            def _rsync_down(args) -> None:
+                """Rsync down logs from remote nodes.
+
+                Args:
+                    args: A tuple of (runner, local_log_dir, remote_log_dir)
+                """
+                (runner, local_log_dir, remote_log_dir) = args
+                try:
+                    os.makedirs(local_log_dir, exist_ok=True)
+                    runner.rsync(
+                        source=f'{remote_log_dir}/',
+                        target=local_log_dir,
+                        up=False,
+                        stream_logs=False,
+                    )
+                except exceptions.CommandError as e:
+                    if e.returncode == exceptions.RSYNC_FILE_NOT_FOUND_CODE:
+                        # Raised by rsync_down. Remote log dir may not exist
+                        # since the job can be run on some part of the nodes.
+                        logger.debug(
+                            f'{runner.node_id} does not have the tasks/*.')
+                    else:
+                        raise
+
+            parallel_args = [[runner, *item]
+                             for item in zip([local_log_dir], [remote_log_dir])
+                             for runner in runners]
+            subprocess_utils.run_in_parallel(_rsync_down, parallel_args)
+        else:  # download job logs
+            local_log_dir = os.path.expanduser(
+                os.path.join(local_dir, 'managed_jobs', run_timestamp))
+            os.makedirs(os.path.dirname(local_log_dir), exist_ok=True)
+            log_file = os.path.join(local_log_dir, 'run.log')
+
+            code = managed_jobs.ManagedJobCodeGen.stream_logs(job_name=None,
+                                                              job_id=job_id,
+                                                              follow=False,
+                                                              controller=False)
+
+            # With the stdin=subprocess.DEVNULL, the ctrl-c will not
+            # kill the process, so we need to handle it manually here.
+            if threading.current_thread() is threading.main_thread():
+                signal.signal(signal.SIGINT, backend_utils.interrupt_handler)
+                signal.signal(signal.SIGTSTP, backend_utils.stop_handler)
+
+            # We redirect the output to the log file
+            # and disable the STDOUT and STDERR
+            self.run_on_head(
+                handle,
+                code,
+                log_path=log_file,
+                stream_logs=False,
+                process_stream=False,
+                ssh_mode=command_runner.SshMode.INTERACTIVE,
+                stdin=subprocess.DEVNULL,
+            )
+
+        logger.info(f'{colorama.Fore.CYAN}'
+                    f'Job {job_id} logs: {local_log_dir}'
+                    f'{colorama.Style.RESET_ALL}')
+        return {str(job_id): local_log_dir}
+
     def tail_serve_logs(self, handle: CloudVmRayResourceHandle,
                         service_name: str, target: serve_lib.ServiceComponent,
                         replica_id: Optional[int], follow: bool) -> None:
diff --git a/sky/cli.py b/sky/cli.py
index d00aae9b646..27948f9ec85 100644
--- a/sky/cli.py
+++ b/sky/cli.py
@@ -3933,17 +3933,29 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
     required=False,
     help='Query the latest job logs, restarting the jobs controller if stopped.'
 )
+@click.option('--sync-down',
+              '-s',
+              default=False,
+              is_flag=True,
+              required=False,
+              help='Download logs for all jobs shown in the queue.')
 @click.argument('job_id', required=False, type=int)
 @usage_lib.entrypoint
 def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
-              controller: bool, refresh: bool):
-    """Tail the log of a managed job."""
+              controller: bool, refresh: bool, sync_down: bool):
+    """Tail or sync down the log of a managed job."""
     try:
-        managed_jobs.tail_logs(name=name,
-                               job_id=job_id,
-                               follow=follow,
-                               controller=controller,
-                               refresh=refresh)
+        if sync_down:
+            managed_jobs.sync_down_logs(name=name,
+                                        job_id=job_id,
+                                        controller=controller,
+                                        refresh=refresh)
+        else:
+            managed_jobs.tail_logs(name=name,
+                                   job_id=job_id,
+                                   follow=follow,
+                                   controller=controller,
+                                   refresh=refresh)
     except exceptions.ClusterNotUpError:
         with ux_utils.print_exception_no_traceback():
             raise
diff --git a/sky/jobs/__init__.py b/sky/jobs/__init__.py
index 5688ca7c7a2..5f52a863e36 100644
--- a/sky/jobs/__init__.py
+++ b/sky/jobs/__init__.py
@@ -9,6 +9,7 @@
 from sky.jobs.core import launch
 from sky.jobs.core import queue
 from sky.jobs.core import queue_from_kubernetes_pod
+from sky.jobs.core import sync_down_logs
 from sky.jobs.core import tail_logs
 from sky.jobs.recovery_strategy import DEFAULT_RECOVERY_STRATEGY
 from sky.jobs.recovery_strategy import RECOVERY_STRATEGIES
@@ -37,6 +38,7 @@
     'queue',
     'queue_from_kubernetes_pod',
     'tail_logs',
+    'sync_down_logs',
     # utils
     'ManagedJobCodeGen',
     'format_job_table',
diff --git a/sky/jobs/core.py b/sky/jobs/core.py
index 3718d0ac67c..3cb67daba94 100644
--- a/sky/jobs/core.py
+++ b/sky/jobs/core.py
@@ -427,6 +427,52 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
                                   controller=controller)
 
 
+@usage_lib.entrypoint
+def sync_down_logs(
+        name: Optional[str],
+        job_id: Optional[int],
+        refresh: bool,
+        controller: bool,
+        local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> None:
+    """Sync down logs of managed jobs.
+
+    Please refer to sky.cli.job_logs for documentation.
+
+    Raises:
+        ValueError: invalid arguments.
+        sky.exceptions.ClusterNotUpError: the jobs controller is not up.
+    """
+    # TODO(zhwu): Automatically restart the jobs controller
+    if name is not None and job_id is not None:
+        with ux_utils.print_exception_no_traceback():
+            raise ValueError('Cannot specify both name and job_id.')
+
+    jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
+    job_name_or_id_str = ''
+    if job_id is not None:
+        job_name_or_id_str = str(job_id)
+    elif name is not None:
+        job_name_or_id_str = f'-n {name}'
+    else:
+        job_name_or_id_str = ''
+    handle = _maybe_restart_controller(
+        refresh,
+        stopped_message=(
+            f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
+            f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
+            f'-r --sync-down {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
+        spinner_message='Retrieving job logs')
+
+    backend = backend_utils.get_backend_from_handle(handle)
+    assert isinstance(backend, backends.CloudVmRayBackend), backend
+
+    backend.sync_down_managed_job_logs(handle,
+                                       job_id=job_id,
+                                       job_name=name,
+                                       controller=controller,
+                                       local_dir=local_dir)
+
+
 spot_launch = common_utils.deprecated_function(
     launch,
     name='sky.jobs.launch',
diff --git a/sky/jobs/state.py b/sky/jobs/state.py
index 31dcfcfd5eb..5da807b8bbb 100644
--- a/sky/jobs/state.py
+++ b/sky/jobs/state.py
@@ -564,6 +564,33 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
     return job_ids
 
 
+def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
+    """Get all job ids by name."""
+    name_filter = ''
+    field_values = []
+    if name is not None:
+        # We match the job name from `job_info` for the jobs submitted after
+        # #1982, and from `spot` for the jobs submitted before #1982, whose
+        # job_info is not available.
+        name_filter = ('WHERE (job_info.name=(?) OR '
+                       '(job_info.name IS NULL AND spot.task_name=(?)))')
+        field_values = [name, name]
+
+    # Left outer join is used here instead of join, because the job_info does
+    # not contain the managed jobs submitted before #1982.
+    with db_utils.safe_cursor(_DB_PATH) as cursor:
+        rows = cursor.execute(
+            f"""\
+            SELECT DISTINCT spot.spot_job_id
+            FROM spot
+            LEFT OUTER JOIN job_info
+            ON spot.spot_job_id=job_info.spot_job_id
+            {name_filter}
+            ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
+        job_ids = [row[0] for row in rows if row[0] is not None]
+        return job_ids
+
+
 def _get_all_task_ids_statuses(
         job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
     with db_utils.safe_cursor(_DB_PATH) as cursor:
diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py
index e5bbced997c..b044e31bda6 100644
--- a/sky/jobs/utils.py
+++ b/sky/jobs/utils.py
@@ -855,6 +855,15 @@ def cancel_job_by_name(cls, job_name: str) -> str:
         """)
         return cls._build(code)
 
+    @classmethod
+    def get_all_job_ids_by_name(cls, job_name: str) -> str:
+        code = textwrap.dedent(f"""\
+        from sky.utils import common_utils
+        job_id = managed_job_state.get_all_job_ids_by_name({job_name!r})
+        print(common_utils.encode_payload(job_id), end="", flush=True)
+        """)
+        return cls._build(code)
+
     @classmethod
     def stream_logs(cls,
                     job_name: Optional[str],
diff --git a/tests/smoke_tests/test_managed_job.py b/tests/smoke_tests/test_managed_job.py
index 5c930724523..4a16b469e5a 100644
--- a/tests/smoke_tests/test_managed_job.py
+++ b/tests/smoke_tests/test_managed_job.py
@@ -871,3 +871,26 @@ def test_managed_jobs_inline_env(generic_cloud: str):
         timeout=20 * 60,
     )
     smoke_tests_utils.run_one_test(test)
+
+
+@pytest.mark.managed_jobs
+def test_managed_jobs_logs_sync_down():
+    name = smoke_tests_utils.get_cluster_name()
+    test = smoke_tests_utils.Test(
+        'test-managed-jobs-logs-sync-down',
+        [
+            f'sky jobs launch -n {name} -y examples/managed_job.yaml -d',
+            smoke_tests_utils.
+            get_cmd_wait_until_managed_job_status_contains_matching_job_name(
+                job_name=f'{name}',
+                job_status=[sky.ManagedJobStatus.RUNNING],
+                timeout=300 + smoke_tests_utils.BUMP_UP_SECONDS),
+            f'sky jobs logs --controller 1 --sync-down',
+            f'sky jobs logs 1 --sync-down',
+            f'sky jobs logs --controller --name minimal --sync-down',
+            f'sky jobs logs --name minimal --sync-down',
+        ],
+        f'sky jobs cancel -y -n {name}',
+        timeout=20 * 60,
+    )
+    smoke_tests_utils.run_one_test(test)
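
Usage notes. The new flag is exercised in the smoke test above, e.g.
`sky jobs logs 1 --sync-down` to download a job's task logs and
`sky jobs logs --controller 1 --sync-down` to download its controller logs.
The same behavior is exposed programmatically as `sky.jobs.sync_down_logs`,
added in sky/jobs/core.py. The snippet below is a minimal sketch of calling
that API; the job name and local directory are illustrative placeholders, not
values used anywhere in this patch.

    from sky import jobs as managed_jobs

    # Download the run logs of the managed job named 'my-job' (placeholder
    # name) into ~/my-job-logs/. Exactly one of `name` / `job_id` should be
    # given; passing both raises ValueError.
    managed_jobs.sync_down_logs(
        name='my-job',
        job_id=None,
        refresh=False,      # do not restart the jobs controller if stopped
        controller=False,   # False: task logs; True: jobs controller logs
        local_dir='~/my-job-logs')

With controller=False, sync_down_managed_job_logs above streams the task's
output into run.log under <local_dir>/managed_jobs/<run_timestamp>/; with
controller=True it rsyncs the controller's log directory for that run
timestamp instead.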