From d8412936fb5e51d00dced828e8d9c6c0859287c4 Mon Sep 17 00:00:00 2001 From: Thomas Steinacher Date: Thu, 15 Feb 2024 15:44:26 +0800 Subject: [PATCH] Split out fork executor (#319) * Move out task IDs set * Move execution code into separate file * Executor base class * Prepare execution helper --- tasktiger/executor.py | 364 ++++++++++++++++++++++++++++++++++++++++ tasktiger/worker.py | 374 ++++-------------------------------------- 2 files changed, 393 insertions(+), 345 deletions(-) create mode 100644 tasktiger/executor.py diff --git a/tasktiger/executor.py b/tasktiger/executor.py new file mode 100644 index 00000000..95e4ea4c --- /dev/null +++ b/tasktiger/executor.py @@ -0,0 +1,364 @@ +import errno +import fcntl +import os +import random +import select +import signal +import socket +import sys +import threading +import time +import traceback +from contextlib import ExitStack +from typing import ( + TYPE_CHECKING, + Any, + Collection, + ContextManager, + Dict, + List, + Optional, +) + +from redis.exceptions import LockError +from redis.lock import Lock +from structlog.stdlib import BoundLogger + +from ._internal import ( + g, + g_fork_lock, + serialize_func_name, + serialize_retry_method, +) +from .exceptions import RetryException +from .redis_semaphore import Semaphore +from .runner import get_runner_class +from .task import Task +from .timeouts import JobTimeoutException + +if TYPE_CHECKING: + from .worker import Worker + + +def sigchld_handler(*args: Any) -> None: + # Nothing to do here. This is just a dummy handler that we set up to catch + # the child process exiting. + pass + + +class WorkerContextManagerStack(ExitStack): + def __init__(self, context_managers: List[ContextManager]) -> None: + super(WorkerContextManagerStack, self).__init__() + + for mgr in context_managers: + self.enter_context(mgr) + + +class Executor: + def __init__(self, worker: "Worker"): + self.tiger = worker.tiger + self.worker = worker + self.connection = worker.connection + self.config = worker.config + + def execute( + self, + queue: str, + tasks: List[Task], + log: BoundLogger, + locks: Collection[Lock], + queue_lock: Optional[Semaphore], + ) -> bool: + """ + Executes the given tasks. Returns a boolean indicating whether + the tasks were executed successfully. + + Args: + queue: Name of the task queue. + tasks: List of tasks to execute, + log: Logger. + locks: List of task locks to renew periodically. + queue_lock: Optional queue lock to renew periodically for max + workers per queue. + """ + raise NotImplementedError + + +class ForkExecutor(Executor): + def execute( + self, + queue: str, + tasks: List[Task], + log: BoundLogger, + locks: Collection[Lock], + queue_lock: Optional[Semaphore], + ) -> bool: + task_func = tasks[0].func + serialized_task_func = tasks[0].serialized_func + + all_task_ids = {task.id for task in tasks} + with g_fork_lock: + child_pid = os.fork() + + if child_pid == 0: + # Child process + log = log.bind(child_pid=os.getpid()) + assert isinstance(log, BoundLogger) + + # Disconnect the Redis connection inherited from the main process. + # Note that this doesn't disconnect the socket in the main process. + self.connection.connection_pool.disconnect() + + random.seed() + + # Ignore Ctrl+C in the child so we don't abort the job -- the main + # process already takes care of a graceful shutdown. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Run the tasks. + success = self._execute_forked(tasks, log) + + # Wait for any threads that might be running in the child, just + # like sys.exit() would. Note we don't call sys.exit() directly + # because it would perform additional cleanup (e.g. calling atexit + # handlers twice). See also: https://bugs.python.org/issue18966 + threading._shutdown() # type: ignore[attr-defined] + + os._exit(int(not success)) + else: + # Main process + log = log.bind(child_pid=child_pid) + assert isinstance(log, BoundLogger) + for task in tasks: + log.info( + "processing", + func=serialized_task_func, + task_id=task.id, + params={"args": task.args, "kwargs": task.kwargs}, + ) + + # Attach a signal handler to SIGCHLD (sent when the child process + # exits) so we can capture it. + signal.signal(signal.SIGCHLD, sigchld_handler) + + # Since newer Python versions retry interrupted system calls we can't + # rely on the fact that select() is interrupted with EINTR. Instead, + # we'll set up a wake-up file descriptor below. + + # Create a new pipe and apply the non-blocking flag (required for + # set_wakeup_fd). + pipe_r, pipe_w = os.pipe() + + opened_fd = os.fdopen(pipe_r) + flags = fcntl.fcntl(pipe_r, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(pipe_r, fcntl.F_SETFL, flags) + + flags = fcntl.fcntl(pipe_w, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(pipe_w, fcntl.F_SETFL, flags) + + # A byte will be written to pipe_w if a signal occurs (and can be + # read from pipe_r). + old_wakeup_fd = signal.set_wakeup_fd(pipe_w) + + def check_child_exit() -> Optional[int]: + """ + Do a non-blocking check to see if the child process exited. + Returns None if the process is still running, or the exit code + value of the child process. + """ + try: + pid, return_code = os.waitpid(child_pid, os.WNOHANG) + if pid != 0: # The child process is done. + return return_code + except OSError as e: + # Of course EINTR can happen if the child process exits + # while we're checking whether it exited. In this case it + # should be safe to retry. + if e.errno == errno.EINTR: + return check_child_exit() + else: + raise + return None + + hard_timeouts = self.worker.get_hard_timeouts(task_func, tasks) + time_started = time.time() + + # Upper bound for when we expect the child processes to finish. + # Since the hard timeout doesn't cover any processing overhead, + # we're adding an extra buffer of ACTIVE_TASK_UPDATE_TIMEOUT + # (which is the same time we use to determine if a task has + # expired). + timeout_at = ( + time_started + + sum(hard_timeouts) + + self.config["ACTIVE_TASK_UPDATE_TIMEOUT"] + ) + + # Wait for the child to exit and perform a periodic heartbeat. + # We check for the child twice in this loop so that we avoid + # unnecessary waiting if the child exited just before entering + # the while loop or while renewing heartbeat/locks. + while True: + return_code = check_child_exit() + if return_code is not None: + break + + # Wait until the timeout or a signal / child exit occurs. + try: + # If observed the following behavior will be seen + # in the pipe when the parent process receives a + # SIGTERM while a task is running in a child process: + # Linux: + # - 0 when parent receives SIGTERM + # - select() exits with EINTR when child exit + # triggers signal, so the signal in the + # pipe is never seen since check_child_exit() + # will see the child is gone + # + # macOS: + # - 15 (SIGTERM) when parent receives SIGTERM + # - 20 (SIGCHLD) when child exits + results = select.select( + [pipe_r], + [], + [], + self.config["ACTIVE_TASK_UPDATE_TIMER"], + ) + + if results[0]: + # Purge pipe so select will pause on next call + try: + # Behavior of a would be blocking read() + # Linux: + # Python 2.7 Raises IOError + # Python 3.x returns empty string + # + # macOS: + # Returns empty string + opened_fd.read(1) + except IOError: + pass + + except select.error as e: + if e.args[0] != errno.EINTR: + raise + + return_code = check_child_exit() + if return_code is not None: + break + + now = time.time() + if now > timeout_at: + log.error("hard timeout elapsed in parent process") + os.kill(child_pid, signal.SIGKILL) + pid, return_code = os.waitpid(child_pid, 0) + log.error("child killed", return_code=return_code) + execution = { + "time_started": time_started, + "time_failed": now, + "exception_name": serialize_func_name( + JobTimeoutException + ), + "success": False, + "host": socket.gethostname(), + } + self.worker.store_task_execution(tasks, execution) + break + + try: + self.worker.heartbeat(queue, all_task_ids) + for lock in locks: + try: + lock.reacquire() + except LockError: + log.warning( + "could not reacquire lock", lock=lock.name + ) + if queue_lock: + acquired, current_locks = queue_lock.renew() + if not acquired: + log.debug("queue lock renew failure") + except OSError as e: + # EINTR happens if the task completed. Since we're just + # renewing locks/heartbeat it's okay if we get interrupted. + if e.errno != errno.EINTR: + raise + + # Restore signals / clean up + signal.signal(signal.SIGCHLD, signal.SIG_DFL) + signal.set_wakeup_fd(old_wakeup_fd) + opened_fd.close() + os.close(pipe_w) + + success = return_code == 0 + return success + + def _execute_forked(self, tasks: List[Task], log: BoundLogger) -> bool: + """ + Executes the tasks in the forked process. Multiple tasks can be passed + for batch processing. However, they must all use the same function and + will share the execution entry. + """ + success = False + + execution: Dict[str, Any] = {} + + assert len(tasks) + task_func = tasks[0].serialized_func + assert all([task_func == task.serialized_func for task in tasks[1:]]) + + execution["time_started"] = time.time() + + try: + func = tasks[0].func + + runner_class = get_runner_class(log, tasks) + runner = runner_class(self.tiger) + + is_batch_func = getattr(func, "_task_batch", False) + g["tiger"] = self.tiger + g["current_task_is_batch"] = is_batch_func + + hard_timeouts = self.worker.get_hard_timeouts(func, tasks) + + with WorkerContextManagerStack( + self.config["CHILD_CONTEXT_MANAGERS"] + ): + if is_batch_func: + # Batch process if the task supports it. + g["current_tasks"] = tasks + runner.run_batch_tasks(tasks, hard_timeouts[0]) + else: + # Process sequentially. + for task, hard_timeout in zip(tasks, hard_timeouts): + g["current_tasks"] = [task] + runner.run_single_task(task, hard_timeout) + + except RetryException as exc: + execution["retry"] = True + if exc.method: + execution["retry_method"] = serialize_retry_method(exc.method) + execution["log_error"] = exc.log_error + execution["exception_name"] = serialize_func_name(exc.__class__) + exc_info = exc.exc_info or sys.exc_info() + except (JobTimeoutException, Exception) as exc: + execution["exception_name"] = serialize_func_name(exc.__class__) + exc_info = sys.exc_info() + else: + success = True + + if not success: + execution["time_failed"] = time.time() + if self.worker.store_tracebacks: + # Currently we only log failed task executions to Redis. + execution["traceback"] = "".join( + traceback.format_exception(*exc_info) + ) + execution["success"] = success + execution["host"] = socket.gethostname() + + self.worker.store_task_execution(tasks, execution) + + return success diff --git a/tasktiger/worker.py b/tasktiger/worker.py index 63fa8fd7..89b1314a 100644 --- a/tasktiger/worker.py +++ b/tasktiger/worker.py @@ -1,24 +1,15 @@ -import errno -import fcntl import hashlib import json import os import random -import select import signal -import socket -import sys -import threading import time -import traceback import uuid from collections import OrderedDict -from contextlib import ExitStack from typing import ( TYPE_CHECKING, Any, Collection, - ContextManager, Dict, List, Literal, @@ -30,7 +21,6 @@ from redis.client import PubSub from redis.exceptions import LockError -from redis.lock import Lock from structlog.stdlib import BoundLogger from ._internal import ( @@ -39,20 +29,13 @@ QUEUED, SCHEDULED, dotted_parts, - g, - g_fork_lock, gen_unique_id, import_attribute, queue_matches, - serialize_func_name, serialize_retry_method, ) -from .exceptions import ( - RetryException, - StopRetry, - TaskImportError, - TaskNotFound, -) +from .exceptions import StopRetry, TaskImportError, TaskNotFound +from .executor import ForkExecutor from .redis_semaphore import Semaphore from .runner import get_runner_class from .stats import StatsThread @@ -68,20 +51,6 @@ __all__ = ["Worker"] -def sigchld_handler(*args: Any) -> None: - # Nothing to do here. This is just a dummy handler that we set up to catch - # the child process exiting. - pass - - -class WorkerContextManagerStack(ExitStack): - def __init__(self, context_managers: List[ContextManager]) -> None: - super(WorkerContextManagerStack, self).__init__() - - for mgr in context_managers: - self.enter_context(mgr) - - class Worker: def __init__( self, @@ -109,6 +78,7 @@ def __init__( self._last_task_check = 0.0 self.stats_thread: Optional[StatsThread] = None self.id = str(uuid.uuid4()) + self.executor = ForkExecutor(self) if queues: self.only_queues = set(queues) @@ -397,7 +367,7 @@ def _worker_queue_expired_tasks(self) -> None: "failed to release lock queue_expired_tasks on full batch" ) - def _get_hard_timeouts(self, func: Any, tasks: List[Task]) -> List[float]: + def get_hard_timeouts(self, func: Any, tasks: List[Task]) -> List[float]: is_batch_func = getattr(func, "_task_batch", False) if is_batch_func: task_timeouts = [ @@ -419,74 +389,6 @@ def _get_hard_timeouts(self, func: Any, tasks: List[Task]) -> List[float]: for task in tasks ] - def _execute_forked(self, tasks: List[Task], log: BoundLogger) -> bool: - """ - Executes the tasks in the forked process. Multiple tasks can be passed - for batch processing. However, they must all use the same function and - will share the execution entry. - """ - success = False - - execution: Dict[str, Any] = {} - - assert len(tasks) - task_func = tasks[0].serialized_func - assert all([task_func == task.serialized_func for task in tasks[1:]]) - - execution["time_started"] = time.time() - - try: - func = tasks[0].func - - runner_class = get_runner_class(log, tasks) - runner = runner_class(self.tiger) - - is_batch_func = getattr(func, "_task_batch", False) - g["tiger"] = self.tiger - g["current_task_is_batch"] = is_batch_func - - hard_timeouts = self._get_hard_timeouts(func, tasks) - - with WorkerContextManagerStack( - self.config["CHILD_CONTEXT_MANAGERS"] - ): - if is_batch_func: - # Batch process if the task supports it. - g["current_tasks"] = tasks - runner.run_batch_tasks(tasks, hard_timeouts[0]) - else: - # Process sequentially. - for task, hard_timeout in zip(tasks, hard_timeouts): - g["current_tasks"] = [task] - runner.run_single_task(task, hard_timeout) - - except RetryException as exc: - execution["retry"] = True - if exc.method: - execution["retry_method"] = serialize_retry_method(exc.method) - execution["log_error"] = exc.log_error - execution["exception_name"] = serialize_func_name(exc.__class__) - exc_info = exc.exc_info or sys.exc_info() - except (JobTimeoutException, Exception) as exc: - execution["exception_name"] = serialize_func_name(exc.__class__) - exc_info = sys.exc_info() - else: - success = True - - if not success: - execution["time_failed"] = time.time() - if self.store_tracebacks: - # Currently we only log failed task executions to Redis. - execution["traceback"] = "".join( - traceback.format_exception(*exc_info) - ) - execution["success"] = success - execution["host"] = socket.gethostname() - - self._store_task_execution(tasks, execution) - - return success - def _get_queue_batch_size(self, queue: str) -> int: """Get queue batch size.""" @@ -538,7 +440,7 @@ def _get_queue_lock( return queue_lock, False - def _heartbeat(self, queue: str, task_ids: Collection[str]) -> None: + def heartbeat(self, queue: str, task_ids: Collection[str]) -> None: """ Updates the heartbeat for the given task IDs to prevent them from timing out and being requeued. @@ -547,233 +449,6 @@ def _heartbeat(self, queue: str, task_ids: Collection[str]) -> None: mapping = {task_id: now for task_id in task_ids} self.connection.zadd(self._key(ACTIVE, queue), mapping) # type: ignore[arg-type] - def _execute( - self, - queue: str, - tasks: List[Task], - log: BoundLogger, - locks: Collection[Lock], - queue_lock: Optional[Semaphore], - all_task_ids: Set[str], - ) -> bool: - """ - Executes the given tasks. Returns a boolean indicating whether - the tasks were executed successfully. - """ - - # The tasks must use the same function. - assert len(tasks) - serialized_task_func = tasks[0].serialized_func - task_func = tasks[0].func - assert all( - [ - serialized_task_func == task.serialized_func - for task in tasks[1:] - ] - ) - - # Before executing periodic tasks, queue them for the next period. - if serialized_task_func in self.tiger.periodic_task_funcs: - tasks[0]._queue_for_next_period() - - with g_fork_lock: - child_pid = os.fork() - - if child_pid == 0: - # Child process - log = log.bind(child_pid=os.getpid()) - assert isinstance(log, BoundLogger) - - # Disconnect the Redis connection inherited from the main process. - # Note that this doesn't disconnect the socket in the main process. - self.connection.connection_pool.disconnect() - - random.seed() - - # Ignore Ctrl+C in the child so we don't abort the job -- the main - # process already takes care of a graceful shutdown. - signal.signal(signal.SIGINT, signal.SIG_IGN) - - # Run the tasks. - success = self._execute_forked(tasks, log) - - # Wait for any threads that might be running in the child, just - # like sys.exit() would. Note we don't call sys.exit() directly - # because it would perform additional cleanup (e.g. calling atexit - # handlers twice). See also: https://bugs.python.org/issue18966 - threading._shutdown() # type: ignore[attr-defined] - - os._exit(int(not success)) - else: - # Main process - log = log.bind(child_pid=child_pid) - assert isinstance(log, BoundLogger) - for task in tasks: - log.info( - "processing", - func=serialized_task_func, - task_id=task.id, - params={"args": task.args, "kwargs": task.kwargs}, - ) - - # Attach a signal handler to SIGCHLD (sent when the child process - # exits) so we can capture it. - signal.signal(signal.SIGCHLD, sigchld_handler) - - # Since newer Python versions retry interrupted system calls we can't - # rely on the fact that select() is interrupted with EINTR. Instead, - # we'll set up a wake-up file descriptor below. - - # Create a new pipe and apply the non-blocking flag (required for - # set_wakeup_fd). - pipe_r, pipe_w = os.pipe() - - opened_fd = os.fdopen(pipe_r) - flags = fcntl.fcntl(pipe_r, fcntl.F_GETFL, 0) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(pipe_r, fcntl.F_SETFL, flags) - - flags = fcntl.fcntl(pipe_w, fcntl.F_GETFL, 0) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(pipe_w, fcntl.F_SETFL, flags) - - # A byte will be written to pipe_w if a signal occurs (and can be - # read from pipe_r). - old_wakeup_fd = signal.set_wakeup_fd(pipe_w) - - def check_child_exit() -> Optional[int]: - """ - Do a non-blocking check to see if the child process exited. - Returns None if the process is still running, or the exit code - value of the child process. - """ - try: - pid, return_code = os.waitpid(child_pid, os.WNOHANG) - if pid != 0: # The child process is done. - return return_code - except OSError as e: - # Of course EINTR can happen if the child process exits - # while we're checking whether it exited. In this case it - # should be safe to retry. - if e.errno == errno.EINTR: - return check_child_exit() - else: - raise - return None - - hard_timeouts = self._get_hard_timeouts(task_func, tasks) - time_started = time.time() - - # Upper bound for when we expect the child processes to finish. - # Since the hard timeout doesn't cover any processing overhead, - # we're adding an extra buffer of ACTIVE_TASK_UPDATE_TIMEOUT - # (which is the same time we use to determine if a task has - # expired). - timeout_at = ( - time_started - + sum(hard_timeouts) - + self.config["ACTIVE_TASK_UPDATE_TIMEOUT"] - ) - - # Wait for the child to exit and perform a periodic heartbeat. - # We check for the child twice in this loop so that we avoid - # unnecessary waiting if the child exited just before entering - # the while loop or while renewing heartbeat/locks. - while True: - return_code = check_child_exit() - if return_code is not None: - break - - # Wait until the timeout or a signal / child exit occurs. - try: - # If observed the following behavior will be seen - # in the pipe when the parent process receives a - # SIGTERM while a task is running in a child process: - # Linux: - # - 0 when parent receives SIGTERM - # - select() exits with EINTR when child exit - # triggers signal, so the signal in the - # pipe is never seen since check_child_exit() - # will see the child is gone - # - # macOS: - # - 15 (SIGTERM) when parent receives SIGTERM - # - 20 (SIGCHLD) when child exits - results = select.select( - [pipe_r], - [], - [], - self.config["ACTIVE_TASK_UPDATE_TIMER"], - ) - - if results[0]: - # Purge pipe so select will pause on next call - try: - # Behavior of a would be blocking read() - # Linux: - # Python 2.7 Raises IOError - # Python 3.x returns empty string - # - # macOS: - # Returns empty string - opened_fd.read(1) - except IOError: - pass - - except select.error as e: - if e.args[0] != errno.EINTR: - raise - - return_code = check_child_exit() - if return_code is not None: - break - - now = time.time() - if now > timeout_at: - log.error("hard timeout elapsed in parent process") - os.kill(child_pid, signal.SIGKILL) - pid, return_code = os.waitpid(child_pid, 0) - log.error("child killed", return_code=return_code) - execution = { - "time_started": time_started, - "time_failed": now, - "exception_name": serialize_func_name( - JobTimeoutException - ), - "success": False, - "host": socket.gethostname(), - } - self._store_task_execution(tasks, execution) - break - - try: - self._heartbeat(queue, all_task_ids) - for lock in locks: - try: - lock.reacquire() - except LockError: - log.warning( - "could not reacquire lock", lock=lock.name - ) - if queue_lock: - acquired, current_locks = queue_lock.renew() - if not acquired: - log.debug("queue lock renew failure") - except OSError as e: - # EINTR happens if the task completed. Since we're just - # renewing locks/heartbeat it's okay if we get interrupted. - if e.errno != errno.EINTR: - raise - - # Restore signals / clean up - signal.signal(signal.SIGCHLD, signal.SIG_DFL) - signal.set_wakeup_fd(old_wakeup_fd) - opened_fd.close() - os.close(pipe_w) - - success = return_code == 0 - return success - def _process_queue_message( self, message_queue: str, @@ -847,9 +522,6 @@ def _process_queue_tasks( else: tasks.append(task) - # List of task IDs that exist and we will update the heartbeat on. - valid_task_ids = {task.id for task in tasks} - # Group by task func tasks_by_func: Dict[str, List[Task]] = OrderedDict() for task in tasks: @@ -862,7 +534,7 @@ def _process_queue_tasks( # Execute tasks for each task func for tasks in tasks_by_func.values(): success, processed_tasks = self._execute_task_group( - queue, tasks, valid_task_ids, queue_lock + queue, tasks, queue_lock ) processed_count = processed_count + len(processed_tasks) log.debug( @@ -942,16 +614,30 @@ def _process_from_queue(self, queue: str) -> Tuple[List[str], int]: return task_ids, processed_count + def _prepare_execution(self, tasks: List[Task]) -> None: + # The tasks must use the same function. + assert len(tasks) + serialized_task_func = tasks[0].serialized_func + assert all( + [ + serialized_task_func == task.serialized_func + for task in tasks[1:] + ] + ) + + # Before executing periodic tasks, queue them for the next period. + if serialized_task_func in self.tiger.periodic_task_funcs: + tasks[0]._queue_for_next_period() + def _execute_task_group( self, queue: str, tasks: List[Task], - all_task_ids: Set[str], queue_lock: Optional[Semaphore], ) -> Tuple[bool, List[Task]]: """ - Executes the given tasks in the queue. Updates the heartbeat for task - IDs passed in all_task_ids. This internal method is only meant to be + Executes the given tasks in the queue as long as they are not locked, + and updates their heartbeats. This internal method is only meant to be called from within _process_from_queue. """ log: BoundLogger = self.log.bind(queue=queue) @@ -995,9 +681,6 @@ def _execute_task_group( when=when, mode="min", ) - # Make sure to remove it from this list so we don't - # re-add to the ACTIVE queue by updating the heartbeat. - all_task_ids.remove(task.id) continue lock_ids.add(lock_id) @@ -1010,8 +693,11 @@ def _execute_task_group( if self.stats_thread: self.stats_thread.report_task_start() - success = self._execute( - queue, ready_tasks, log, locks, queue_lock, all_task_ids + + self._prepare_execution(ready_tasks) + + success = self.executor.execute( + queue, ready_tasks, log, locks, queue_lock ) if self.stats_thread: self.stats_thread.report_task_end() @@ -1236,9 +922,7 @@ def _retrieve_queues(self, key: str) -> Set[str]: return set(self.connection.sscan_iter(key, match=match, count=100000)) - def _store_task_execution( - self, tasks: List[Task], execution: Dict - ) -> None: + def store_task_execution(self, tasks: List[Task], execution: Dict) -> None: serialized_execution = json.dumps(execution) for task in tasks: