From 297e5d82ce2da33d73fa91b6bc81fd5fda09cd34 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Mon, 4 Jan 2021 17:21:14 -0500 Subject: [PATCH] Consistently clean shutdown (#348) * Proper shutdown behaviors * Remove debug --- mephisto/abstractions/blueprint.py | 44 +++++++++++++++--- mephisto/data_model/agent.py | 23 ++++++++++ mephisto/data_model/exceptions.py | 7 +++ mephisto/operations/operator.py | 67 ++++++++++++++++++++++++++-- mephisto/operations/supervisor.py | 6 +++ mephisto/operations/task_launcher.py | 1 + 6 files changed, 137 insertions(+), 11 deletions(-) diff --git a/mephisto/abstractions/blueprint.py b/mephisto/abstractions/blueprint.py index 63fcbf12f..973b8bc21 100644 --- a/mephisto/abstractions/blueprint.py +++ b/mephisto/abstractions/blueprint.py @@ -18,6 +18,7 @@ Iterable, AsyncIterator, Callable, + Tuple, TYPE_CHECKING, ) @@ -28,6 +29,7 @@ AgentReturnedError, AgentDisconnectedError, AgentTimeoutError, + AgentShutdownError, ) from mephisto.data_model.constants.assignment_state import AssignmentState @@ -124,8 +126,8 @@ def __init__( self.args = args self.shared_state = shared_state self.task_run = task_run - self.running_assignments: Dict[str, "Assignment"] = {} - self.running_units: Dict[str, "Unit"] = {} + self.running_assignments: Dict[str, Tuple["Assignment", List["Agent"]]] = {} + self.running_units: Dict[str, Tuple["Unit", "Agent"]] = {} self.running_onboardings: Dict[str, "OnboardingAgent"] = {} self.is_concurrent = False # TODO(102) populate some kind of local state for tasks that are being run @@ -166,7 +168,12 @@ def launch_onboarding(self, onboarding_agent: "OnboardingAgent") -> None: try: self.run_onboarding(onboarding_agent) onboarding_agent.mark_done() - except (AgentReturnedError, AgentTimeoutError, AgentDisconnectedError): + except ( + AgentReturnedError, + AgentTimeoutError, + AgentDisconnectedError, + AgentShutdownError, + ): self.cleanup_onboarding(onboarding_agent) except Exception as e: print(f"Unhandled exception in onboarding {onboarding_agent}: {repr(e)}") @@ -188,10 +195,15 @@ def launch_unit(self, unit: "Unit", agent: "Agent") -> None: print(f"Unit {unit.db_id} is launching with {agent}") # At this point we're sure we want to run the unit - self.running_units[unit.db_id] = unit + self.running_units[unit.db_id] = (unit, agent) try: self.run_unit(unit, agent) - except (AgentReturnedError, AgentTimeoutError, AgentDisconnectedError): + except ( + AgentReturnedError, + AgentTimeoutError, + AgentDisconnectedError, + AgentShutdownError, + ): # A returned Unit can be worked on again by someone else. if ( unit.get_status() != AssignmentState.EXPIRED @@ -221,10 +233,15 @@ def launch_assignment( print(f"Assignment {assignment.db_id} is launching with {agents}") # At this point we're sure we want to run the assignment - self.running_assignments[assignment.db_id] = assignment + self.running_assignments[assignment.db_id] = (assignment, agents) try: self.run_assignment(assignment, agents) - except (AgentReturnedError, AgentTimeoutError, AgentDisconnectedError) as e: + except ( + AgentReturnedError, + AgentTimeoutError, + AgentDisconnectedError, + AgentShutdownError, + ) as e: # TODO(#99) if some operator flag is set for counting complete tasks, launch a # new assignment copied from the parameters of this one disconnected_agent_id = e.agent_id @@ -269,6 +286,19 @@ def filter_units_for_worker(self, units: List["Unit"], worker: "Worker"): """ return units + def shutdown(self): + """ + Updates the status of all agents tracked by this runner to throw a ShutdownException, + ensuring that all the threads exit correctly and we can cleanup properly. + """ + for _unit, agent in self.running_units.values(): + agent.shutdown() + for _assignment, agents in self.running_assignments.values(): + for agent in agents: + agent.shutdown() + for onboarding_agent in self.running_onboardings.values(): + onboarding_agent.shutdown() + # TaskRunners must implement either the unit or assignment versions of the # run and cleanup functions, depending on if the task is run at the assignment # level rather than on the the unit level. diff --git a/mephisto/data_model/agent.py b/mephisto/data_model/agent.py index 2ba5f95ce..4a6b1c61e 100644 --- a/mephisto/data_model/agent.py +++ b/mephisto/data_model/agent.py @@ -15,6 +15,7 @@ AgentReturnedError, AgentDisconnectedError, AgentTimeoutError, + AgentShutdownError, ) from typing import List, Optional, Tuple, Mapping, Dict, Any, TYPE_CHECKING @@ -59,6 +60,7 @@ def __init__( self.task_run_id = row["task_run_id"] self.task_id = row["task_id"] self.did_submit = threading.Event() + self.is_shutdown = False # Deferred loading of related entities self._worker: Optional["Worker"] = None @@ -248,6 +250,8 @@ def act(self, timeout: Optional[int] = None) -> Optional["Packet"]: self.has_action.wait(timeout) if len(self.pending_actions) == 0: + if self.is_shutdown: + raise AgentShutdownError(self.db_id) # various disconnect cases status = self.get_status() if status == AgentState.STATUS_DISCONNECT: @@ -283,6 +287,14 @@ def get_status(self) -> str: self.db_status = row["status"] return self.db_status + def shutdown(self) -> None: + """ + Force the given agent to end any polling threads and throw an AgentShutdownError + from any acts called on it, ensuring tasks using this agent can be cleaned up. + """ + self.has_action.set() + self.is_shutdown = True + # Children classes should implement the following methods def approve_work(self) -> None: @@ -361,6 +373,7 @@ def __init__( self.task_run_id = row["task_run_id"] self.task_id = row["task_id"] self.did_submit = threading.Event() + self.is_shutdown = False # Deferred loading of related entities self._worker: Optional["Worker"] = None @@ -465,6 +478,8 @@ def act(self, timeout: Optional[int] = None) -> Optional["Packet"]: if len(self.pending_actions) == 0: # various disconnect cases + if self.is_shutdown: + raise AgentShutdownError(self.db_id) status = self.get_status() if status == AgentState.STATUS_DISCONNECT: raise AgentDisconnectedError(self.db_id) @@ -509,6 +524,14 @@ def mark_done(self) -> None: ]: self.update_status(AgentState.STATUS_WAITING) + def shutdown(self) -> None: + """ + Force the given agent to end any polling threads and throw an AgentShutdownError + from any acts called on it, ensuring tasks using this agent can be cleaned up. + """ + self.has_action.set() + self.is_shutdown = True + @staticmethod def new(db: "MephistoDB", worker: Worker, task_run: "TaskRun") -> "OnboardingAgent": """ diff --git a/mephisto/data_model/exceptions.py b/mephisto/data_model/exceptions.py index 26c499cac..190c8166d 100644 --- a/mephisto/data_model/exceptions.py +++ b/mephisto/data_model/exceptions.py @@ -34,3 +34,10 @@ class AgentReturnedError(AbsentAgentError): def __init__(self, agent_id): super().__init__(f"Agent returned task", agent_id) + + +class AgentShutdownError(AbsentAgentError): + """Exception for when a task is shutdown but agents are still in a task""" + + def __init__(self, agent_id): + super().__init__(f"This agent has been forced to shut down", agent_id) diff --git a/mephisto/operations/operator.py b/mephisto/operations/operator.py index c98c923e0..eaec88974 100644 --- a/mephisto/operations/operator.py +++ b/mephisto/operations/operator.py @@ -14,6 +14,7 @@ import threading import shlex import traceback +import signal from argparse import ArgumentParser @@ -300,13 +301,67 @@ def _track_and_kill_runs(self): del self._task_runs_tracked[task_run.db_id] time.sleep(2) + def force_shutdown(self, timeout=5): + """ + Force a best-effort shutdown of everything, letting no individual + shutdown step suspend for more than the timeout before moving on. + + Skips waiting for in-flight assignments to rush the shutdown. + + ** Should only be used in sandbox or test environments. ** + """ + self.is_shutdown = True + + def end_launchers_and_expire_units(): + for tracked_run in self._task_runs_tracked.values(): + tracked_run.task_launcher.shutdown() + tracked_run.task_launcher.expire_units() + + def end_architects(): + for tracked_run in self._task_runs_tracked.values(): + tracked_run.architect.shutdown() + + def shutdown_supervisor(): + if self.supervisor is not None: + self.supervisor.shutdown() + + tasks = { + "expire-units": end_launchers_and_expire_units, + "kill-architects": end_architects, + "fire-supervisor": shutdown_supervisor, + } + + for tname, t in tasks.items(): + shutdown_thread = threading.Thread(target=t, name=f"force-shutdown-{tname}") + shutdown_thread.start() + start_time = time.time() + while time.time() - start_time < timeout and shutdown_thread.is_alive(): + time.sleep(0.5) + if not shutdown_thread.is_alive(): + # Only join if the shutdown fully completed + shutdown_thread.join() + def shutdown(self, skip_input=True): logger.info("operator shutting down") self.is_shutdown = True - for tracked_run in self._task_runs_tracked.values(): - logger.info("expiring units") - tracked_run.task_launcher.shutdown() + for run_id, tracked_run in self._task_runs_tracked.items(): + logger.info(f"Expiring units for task run {run_id}.") + try: + tracked_run.task_launcher.shutdown() + except (KeyboardInterrupt, SystemExit) as e: + logger.info( + f"Skipping waiting for launcher threads to join on task run {run_id}." + ) + + def cant_cancel_expirations(self, sig, frame): + logging.warn( + "Ignoring ^C during unit expirations. ^| if you NEED to exit and you will " + "clean up units that hadn't been expired afterwards." + ) + + old_handler = signal.signal(signal.SIGINT, cant_cancel_expirations) tracked_run.task_launcher.expire_units() + signal.signal(signal.SIGINT, old_handler) try: remaining_runs = self._task_runs_tracked.values() while len(remaining_runs) > 0: @@ -318,7 +373,8 @@ def shutdown(self, skip_input=True): next_runs.append(tracked_run) if len(next_runs) > 0: logger.info( - f"Waiting on {len(remaining_runs)} task runs, Ctrl-C ONCE to FORCE QUIT" + f"Waiting on {len(remaining_runs)} task runs with assignments in-flight " + f"Ctrl-C ONCE to kill running tasks and FORCE QUIT." ) time.sleep(30) remaining_runs = next_runs @@ -334,6 +390,9 @@ def shutdown(self, skip_input=True): "Skipping waiting for outstanding task completions, shutting down servers now!" ) for tracked_run in remaining_runs: + logger.info( + f"Shutting down Architect for task run {tracked_run.task_run.db_id}" + ) tracked_run.architect.shutdown() finally: self.supervisor.shutdown() diff --git a/mephisto/operations/supervisor.py b/mephisto/operations/supervisor.py index 8c78e0f52..29570dd5d 100644 --- a/mephisto/operations/supervisor.py +++ b/mephisto/operations/supervisor.py @@ -202,9 +202,15 @@ def shutdown(self): """Close all of the channels, join threads""" channels_to_close = list(self.channels.keys()) for channel_id in channels_to_close: + channel_info = self.channels[channel_id] + channel_info.job.task_runner.shutdown() self.close_channel(channel_id) if self.sending_thread is not None: self.sending_thread.join() + for agent_info in self.agents.values(): + assign_thread = agent_info.assignment_thread + if assign_thread is not None: + assign_thread.join() def _send_alive(self, channel_info: ChannelInfo) -> bool: logger.info("Sending alive") diff --git a/mephisto/operations/task_launcher.py b/mephisto/operations/task_launcher.py index 8a393e3d7..525c0275d 100644 --- a/mephisto/operations/task_launcher.py +++ b/mephisto/operations/task_launcher.py @@ -215,6 +215,7 @@ def shutdown(self) -> None: """Clean up running threads for generating assignments and units""" self.assignment_thread_done = True self.keep_launching_units = False + self.finished_generators = True if self.assignments_thread is not None: self.assignments_thread.join() self.units_thread.join()