Skip to content

Commit

Permalink
Consistently clean shutdown (#348)
Browse files Browse the repository at this point in the history
* Proper shutdown behaviors

* Remove debug
  • Loading branch information
JackUrb authored Jan 4, 2021
1 parent 7ede736 commit 297e5d8
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 11 deletions.
44 changes: 37 additions & 7 deletions mephisto/abstractions/blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Iterable,
AsyncIterator,
Callable,
Tuple,
TYPE_CHECKING,
)

Expand All @@ -28,6 +29,7 @@
AgentReturnedError,
AgentDisconnectedError,
AgentTimeoutError,
AgentShutdownError,
)
from mephisto.data_model.constants.assignment_state import AssignmentState

Expand Down Expand Up @@ -124,8 +126,8 @@ def __init__(
self.args = args
self.shared_state = shared_state
self.task_run = task_run
self.running_assignments: Dict[str, "Assignment"] = {}
self.running_units: Dict[str, "Unit"] = {}
self.running_assignments: Dict[str, Tuple["Assignment", List["Agent"]]] = {}
self.running_units: Dict[str, Tuple["Unit", "Agent"]] = {}
self.running_onboardings: Dict[str, "OnboardingAgent"] = {}
self.is_concurrent = False
# TODO(102) populate some kind of local state for tasks that are being run
Expand Down Expand Up @@ -166,7 +168,12 @@ def launch_onboarding(self, onboarding_agent: "OnboardingAgent") -> None:
try:
self.run_onboarding(onboarding_agent)
onboarding_agent.mark_done()
except (AgentReturnedError, AgentTimeoutError, AgentDisconnectedError):
except (
AgentReturnedError,
AgentTimeoutError,
AgentDisconnectedError,
AgentShutdownError,
):
self.cleanup_onboarding(onboarding_agent)
except Exception as e:
print(f"Unhandled exception in onboarding {onboarding_agent}: {repr(e)}")
Expand All @@ -188,10 +195,15 @@ def launch_unit(self, unit: "Unit", agent: "Agent") -> None:
print(f"Unit {unit.db_id} is launching with {agent}")

# At this point we're sure we want to run the unit
self.running_units[unit.db_id] = unit
self.running_units[unit.db_id] = (unit, agent)
try:
self.run_unit(unit, agent)
except (AgentReturnedError, AgentTimeoutError, AgentDisconnectedError):
except (
AgentReturnedError,
AgentTimeoutError,
AgentDisconnectedError,
AgentShutdownError,
):
# A returned Unit can be worked on again by someone else.
if (
unit.get_status() != AssignmentState.EXPIRED
Expand Down Expand Up @@ -221,10 +233,15 @@ def launch_assignment(
print(f"Assignment {assignment.db_id} is launching with {agents}")

# At this point we're sure we want to run the assignment
self.running_assignments[assignment.db_id] = assignment
self.running_assignments[assignment.db_id] = (assignment, agents)
try:
self.run_assignment(assignment, agents)
except (AgentReturnedError, AgentTimeoutError, AgentDisconnectedError) as e:
except (
AgentReturnedError,
AgentTimeoutError,
AgentDisconnectedError,
AgentShutdownError,
) as e:
# TODO(#99) if some operator flag is set for counting complete tasks, launch a
# new assignment copied from the parameters of this one
disconnected_agent_id = e.agent_id
Expand Down Expand Up @@ -269,6 +286,19 @@ def filter_units_for_worker(self, units: List["Unit"], worker: "Worker"):
"""
return units

def shutdown(self):
"""
Updates the status of all agents tracked by this runner to throw a ShutdownException,
ensuring that all the threads exit correctly and we can cleanup properly.
"""
for _unit, agent in self.running_units.values():
agent.shutdown()
for _assignment, agents in self.running_assignments.values():
for agent in agents:
agent.shutdown()
for onboarding_agent in self.running_onboardings.values():
onboarding_agent.shutdown()

# TaskRunners must implement either the unit or assignment versions of the
# run and cleanup functions, depending on if the task is run at the assignment
# level rather than on the the unit level.
Expand Down
23 changes: 23 additions & 0 deletions mephisto/data_model/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AgentReturnedError,
AgentDisconnectedError,
AgentTimeoutError,
AgentShutdownError,
)

from typing import List, Optional, Tuple, Mapping, Dict, Any, TYPE_CHECKING
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
self.task_run_id = row["task_run_id"]
self.task_id = row["task_id"]
self.did_submit = threading.Event()
self.is_shutdown = False

# Deferred loading of related entities
self._worker: Optional["Worker"] = None
Expand Down Expand Up @@ -248,6 +250,8 @@ def act(self, timeout: Optional[int] = None) -> Optional["Packet"]:
self.has_action.wait(timeout)

if len(self.pending_actions) == 0:
if self.is_shutdown:
raise AgentShutdownError(self.db_id)
# various disconnect cases
status = self.get_status()
if status == AgentState.STATUS_DISCONNECT:
Expand Down Expand Up @@ -283,6 +287,14 @@ def get_status(self) -> str:
self.db_status = row["status"]
return self.db_status

def shutdown(self) -> None:
"""
Force the given agent to end any polling threads and throw an AgentShutdownError
from any acts called on it, ensuring tasks using this agent can be cleaned up.
"""
self.has_action.set()
self.is_shutdown = True

# Children classes should implement the following methods

def approve_work(self) -> None:
Expand Down Expand Up @@ -361,6 +373,7 @@ def __init__(
self.task_run_id = row["task_run_id"]
self.task_id = row["task_id"]
self.did_submit = threading.Event()
self.is_shutdown = False

# Deferred loading of related entities
self._worker: Optional["Worker"] = None
Expand Down Expand Up @@ -465,6 +478,8 @@ def act(self, timeout: Optional[int] = None) -> Optional["Packet"]:

if len(self.pending_actions) == 0:
# various disconnect cases
if self.is_shutdown:
raise AgentShutdownError(self.db_id)
status = self.get_status()
if status == AgentState.STATUS_DISCONNECT:
raise AgentDisconnectedError(self.db_id)
Expand Down Expand Up @@ -509,6 +524,14 @@ def mark_done(self) -> None:
]:
self.update_status(AgentState.STATUS_WAITING)

def shutdown(self) -> None:
"""
Force the given agent to end any polling threads and throw an AgentShutdownError
from any acts called on it, ensuring tasks using this agent can be cleaned up.
"""
self.has_action.set()
self.is_shutdown = True

@staticmethod
def new(db: "MephistoDB", worker: Worker, task_run: "TaskRun") -> "OnboardingAgent":
"""
Expand Down
7 changes: 7 additions & 0 deletions mephisto/data_model/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ class AgentReturnedError(AbsentAgentError):

def __init__(self, agent_id):
super().__init__(f"Agent returned task", agent_id)


class AgentShutdownError(AbsentAgentError):
"""Exception for when a task is shutdown but agents are still in a task"""

def __init__(self, agent_id):
super().__init__(f"This agent has been forced to shut down", agent_id)
67 changes: 63 additions & 4 deletions mephisto/operations/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import threading
import shlex
import traceback
import signal

from argparse import ArgumentParser

Expand Down Expand Up @@ -300,13 +301,67 @@ def _track_and_kill_runs(self):
del self._task_runs_tracked[task_run.db_id]
time.sleep(2)

def force_shutdown(self, timeout=5):
"""
Force a best-effort shutdown of everything, letting no individual
shutdown step suspend for more than the timeout before moving on.
Skips waiting for in-flight assignments to rush the shutdown.
** Should only be used in sandbox or test environments. **
"""
self.is_shutdown = True

def end_launchers_and_expire_units():
for tracked_run in self._task_runs_tracked.values():
tracked_run.task_launcher.shutdown()
tracked_run.task_launcher.expire_units()

def end_architects():
for tracked_run in self._task_runs_tracked.values():
tracked_run.architect.shutdown()

def shutdown_supervisor():
if self.supervisor is not None:
self.supervisor.shutdown()

tasks = {
"expire-units": end_launchers_and_expire_units,
"kill-architects": end_architects,
"fire-supervisor": shutdown_supervisor,
}

for tname, t in tasks.items():
shutdown_thread = threading.Thread(target=t, name=f"force-shutdown-{tname}")
shutdown_thread.start()
start_time = time.time()
while time.time() - start_time < timeout and shutdown_thread.is_alive():
time.sleep(0.5)
if not shutdown_thread.is_alive():
# Only join if the shutdown fully completed
shutdown_thread.join()

def shutdown(self, skip_input=True):
logger.info("operator shutting down")
self.is_shutdown = True
for tracked_run in self._task_runs_tracked.values():
logger.info("expiring units")
tracked_run.task_launcher.shutdown()
for run_id, tracked_run in self._task_runs_tracked.items():
logger.info(f"Expiring units for task run {run_id}.")
try:
tracked_run.task_launcher.shutdown()
except (KeyboardInterrupt, SystemExit) as e:
logger.info(
f"Skipping waiting for launcher threads to join on task run {run_id}."
)

def cant_cancel_expirations(self, sig, frame):
logging.warn(
"Ignoring ^C during unit expirations. ^| if you NEED to exit and you will "
"clean up units that hadn't been expired afterwards."
)

old_handler = signal.signal(signal.SIGINT, cant_cancel_expirations)
tracked_run.task_launcher.expire_units()
signal.signal(signal.SIGINT, old_handler)
try:
remaining_runs = self._task_runs_tracked.values()
while len(remaining_runs) > 0:
Expand All @@ -318,7 +373,8 @@ def shutdown(self, skip_input=True):
next_runs.append(tracked_run)
if len(next_runs) > 0:
logger.info(
f"Waiting on {len(remaining_runs)} task runs, Ctrl-C ONCE to FORCE QUIT"
f"Waiting on {len(remaining_runs)} task runs with assignments in-flight "
f"Ctrl-C ONCE to kill running tasks and FORCE QUIT."
)
time.sleep(30)
remaining_runs = next_runs
Expand All @@ -334,6 +390,9 @@ def shutdown(self, skip_input=True):
"Skipping waiting for outstanding task completions, shutting down servers now!"
)
for tracked_run in remaining_runs:
logger.info(
f"Shutting down Architect for task run {tracked_run.task_run.db_id}"
)
tracked_run.architect.shutdown()
finally:
self.supervisor.shutdown()
Expand Down
6 changes: 6 additions & 0 deletions mephisto/operations/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,15 @@ def shutdown(self):
"""Close all of the channels, join threads"""
channels_to_close = list(self.channels.keys())
for channel_id in channels_to_close:
channel_info = self.channels[channel_id]
channel_info.job.task_runner.shutdown()
self.close_channel(channel_id)
if self.sending_thread is not None:
self.sending_thread.join()
for agent_info in self.agents.values():
assign_thread = agent_info.assignment_thread
if assign_thread is not None:
assign_thread.join()

def _send_alive(self, channel_info: ChannelInfo) -> bool:
logger.info("Sending alive")
Expand Down
1 change: 1 addition & 0 deletions mephisto/operations/task_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def shutdown(self) -> None:
"""Clean up running threads for generating assignments and units"""
self.assignment_thread_done = True
self.keep_launching_units = False
self.finished_generators = True
if self.assignments_thread is not None:
self.assignments_thread.join()
self.units_thread.join()

0 comments on commit 297e5d8

Please sign in to comment.