Skip to content

Commit

Permalink
Fix race condition on restart by restoring default signal handler
Browse files Browse the repository at this point in the history
  • Loading branch information
user committed Dec 21, 2024
1 parent 5152622 commit 0dfee63
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 36 deletions.
56 changes: 24 additions & 32 deletions avtdl/avtdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
import argparse
import asyncio
import logging
import signal
from asyncio import AbstractEventLoop
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

from avtdl.core import webui
from avtdl.core.chain import Chain
Expand Down Expand Up @@ -88,42 +87,35 @@ async def install_exception_handler() -> None:
loop.slow_callback_duration = 100


def get_sigterm_handler(ctx: RuntimeContext) -> Callable:
def handler(sig, frame):
logging.debug(f'signal {sig} received, initiating termination')
ctx.controller.terminate_after(0, TerminatedAction.EXIT)
return handler


async def run(config_path: Path, host: Optional[str], port: Optional[int]) -> None:
await install_exception_handler()
while True:
config = load_config(config_path)
ctx = RuntimeContext.create()
signal.signal(signal.SIGINT, get_sigterm_handler(ctx))
signal.signal(signal.SIGTERM, get_sigterm_handler(ctx))
settings, actors, chains = parse_config(config, ctx)
config_sancheck(actors, chains)

if host is not None:
settings.host = host
if port is not None:
settings.port = port

controller = ctx.controller
for runnable in actors.values():
_ = controller.create_task(runnable.run(), name=f'{runnable!r}.{hash(runnable)}')
_ = controller.create_task(webui.run(config_path, config, ctx, settings, actors, chains), name='webui')

action = await controller.run_until_termination()
if action == TerminatedAction.EXIT:
logging.info('terminating...')
break
elif action == TerminatedAction.RESTART:
logging.info('restarting...')
continue
else:
assert False, f'Unknown action: {action}'
with ctx:
settings, actors, chains = parse_config(config, ctx)
config_sancheck(actors, chains)

if host is not None:
settings.host = host
if port is not None:
settings.port = port

controller = ctx.controller
for runnable in actors.values():
_ = controller.create_task(runnable.run(), name=f'{runnable!r}.{hash(runnable)}')
_ = controller.create_task(webui.run(config_path, config, ctx, settings, actors, chains), name='webui')

action = await controller.run_until_termination()
if action == TerminatedAction.EXIT:
logging.info('terminating...')
break
elif action == TerminatedAction.RESTART:
logging.info('restarting...')
continue
else:
assert False, f'Unknown action: {action}'


def make_docs(output: Path) -> None:
Expand Down
25 changes: 21 additions & 4 deletions avtdl/core/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import datetime
import json
import logging
import signal
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass
from enum import Enum
from hashlib import sha1
from textwrap import shorten
Expand Down Expand Up @@ -301,10 +301,27 @@ async def run_until_termination(self) -> TerminatedAction:
return self.terminated_action


@dataclass
class RuntimeContext:
bus: MessageBus
controller: TasksController
def __init__(self, bus: MessageBus, controller: TasksController):
self.bus: MessageBus = bus
self.controller: TasksController = controller
self._sigint_handler = signal.getsignal(signal.SIGINT)
self._sigterm_handler = signal.getsignal(signal.SIGINT)

def _get_handler(self) -> Callable:
controller = self.controller
def handler(sig, frame):
controller.terminate_after(0, TerminatedAction.EXIT)
return handler

def __enter__(self):
signal.signal(signal.SIGINT, self._get_handler())
signal.signal(signal.SIGTERM, self._get_handler())

def __exit__(self, exc_type, exc_val, exc_tb):
signal.signal(signal.SIGINT, self._sigint_handler)
signal.signal(signal.SIGTERM, self._sigterm_handler)


@classmethod
def create(cls) -> 'RuntimeContext':
Expand Down

0 comments on commit 0dfee63

Please sign in to comment.