Skip to content

Commit

Permalink
A proper job scheduling lib to handle background tasks (#85)
Browse files Browse the repository at this point in the history
* Proper job scheduling lib to handle background tasks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* min diff

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
cmyui and pre-commit-ci[bot] authored Apr 20, 2024
1 parent 1585295 commit 6a8c376
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 37 deletions.
6 changes: 3 additions & 3 deletions app/api/direct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import logging
from typing import Any
from typing import Optional
Expand All @@ -16,6 +15,7 @@
import app.state
import app.usecases
import config
from app import job_scheduling
from app.adapters import amplitude
from app.constants.ranked_status import RankedStatus
from app.models.user import User
Expand Down Expand Up @@ -114,7 +114,7 @@ async def osu_direct(
)

if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="osudirect_search",
user_id=str(user.id),
Expand Down Expand Up @@ -170,7 +170,7 @@ async def beatmap_card(
json_data = result["data"] if USING_CHIMU else result

if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="osudirect_card_view",
user_id=str(user.id),
Expand Down
1 change: 0 additions & 1 deletion app/api/lastfm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import logging
import time

Expand Down
4 changes: 2 additions & 2 deletions app/api/rate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import logging
import time
from typing import Optional
Expand All @@ -11,6 +10,7 @@
import app.state
import app.usecases
import config
from app import job_scheduling
from app.adapters import amplitude
from app.models.beatmap import Beatmap
from app.models.user import User
Expand Down Expand Up @@ -66,7 +66,7 @@ async def rate_map(
beatmap.rating = new_rating

if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="rated_beatmap",
user_id=str(user.id),
Expand Down
4 changes: 2 additions & 2 deletions app/api/replays.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import logging

from fastapi import Depends
Expand All @@ -12,6 +11,7 @@
import app.usecases
import app.utils
import config
from app import job_scheduling
from app.adapters import amplitude
from app.constants.mode import Mode
from app.models.score import Score
Expand Down Expand Up @@ -47,7 +47,7 @@ async def get_replay(
await app.usecases.user.increment_replays_watched(db_score["userid"], mode)

if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="watched_replay",
user_id=str(user.id),
Expand Down
6 changes: 3 additions & 3 deletions app/api/score_sub.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import dataclasses
import hashlib
import logging
Expand Down Expand Up @@ -29,6 +28,7 @@
import app.usecases
import app.utils
import config
from app import job_scheduling
from app.adapters import amplitude
from app.constants.mode import Mode
from app.constants.ranked_status import RankedStatus
Expand Down Expand Up @@ -398,7 +398,7 @@ async def submit_score(
device_id = hashlib.sha1(login_disk_id.encode()).hexdigest()

if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="score_submission",
user_id=str(user.id),
Expand Down Expand Up @@ -453,7 +453,7 @@ async def submit_score(
# fire amplitude events for each
for achievement in new_achievements:
if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="achievement_unlocked",
user_id=str(score.user_id),
Expand Down
4 changes: 2 additions & 2 deletions app/api/screenshots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import logging
import random
import string
Expand All @@ -16,6 +15,7 @@
import app.state
import app.utils
import config
from app import job_scheduling
from app.adapters import amplitude
from app.adapters import s3
from app.models.user import User
Expand Down Expand Up @@ -95,7 +95,7 @@ async def upload_screenshot(
await s3.upload(content, file_name, "screenshots")

if config.AMPLITUDE_API_KEY:
asyncio.create_task(
job_scheduling.schedule_job(
amplitude.track(
event_name="upload_screenshot",
user_id=str(user.id),
Expand Down
4 changes: 3 additions & 1 deletion app/init_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import pprint
Expand All @@ -21,6 +20,7 @@
import app.state
import app.usecases
import config
from app import job_scheduling

ctx_stack = contextlib.AsyncExitStack()

Expand Down Expand Up @@ -88,6 +88,8 @@ async def on_startup() -> None:

@asgi_app.on_event("shutdown")
async def on_shutdown() -> None:
await job_scheduling.await_running_jobs(timeout=7.5)

await app.state.services.database.disconnect()

await app.state.services.redis.close()
Expand Down
79 changes: 79 additions & 0 deletions app/job_scheduling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

import asyncio
import sys
from typing import Any
from typing import Coroutine
from typing import Generator
from typing import TypeVar
from typing import Union

T = TypeVar("T")

ACTIVE_TASKS: set[asyncio.Task[Any]] = set()


def schedule_job(
coro: Union[
Generator[Any, None, T],
Coroutine[Any, Any, T],
],
) -> asyncio.Task[T]:
"""\
Run a coroutine to run in the background.
This function is a wrapper around `asyncio.create_task` that adds the task
to a set of active tasks. This set is used to provide handling of any
exceptions that occur as well as to wait for all tasks to complete before
shutting down the application.
"""
task = asyncio.create_task(coro)
task.add_done_callback(_handle_task_exception)
_register_task(task)
return task


def _register_task(task: asyncio.Task[Any]) -> None:
ACTIVE_TASKS.add(task)


def _unregister_task(task: asyncio.Task[Any]) -> None:
ACTIVE_TASKS.remove(task)


def _handle_task_exception(task: asyncio.Task[Any]) -> None:
_unregister_task(task)

if task.cancelled():
return None

try:
exception = task.exception()
except asyncio.InvalidStateError:
pass
else:
if exception is not None:
sys.excepthook(
type(exception),
exception,
exception.__traceback__,
)


async def await_running_jobs(
timeout: float,
) -> tuple[set[asyncio.Task[Any]], set[asyncio.Task[Any]]]:
"""\
Await all tasks to complete, or until the timeout is reached.
Returns a tuple of done and pending tasks.
"""
if not ACTIVE_TASKS:
return set(), set()

done, pending = await asyncio.wait(
ACTIVE_TASKS,
timeout=timeout,
return_when=asyncio.ALL_COMPLETED,
)
return done, pending
6 changes: 2 additions & 4 deletions app/usecases/discord.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import asyncio
import logging
import traceback
from typing import Optional

from tenacity import retry
Expand All @@ -11,6 +9,7 @@

import app.state
import config
from app import job_scheduling
from app.models.user import User
from app.reliability import retry_if_exception_network_related

Expand Down Expand Up @@ -198,8 +197,7 @@ def schedule_hook(hook: Optional[str], embed: Embed):
if not hook:
return

loop = asyncio.get_event_loop()
loop.create_task(wrap_hook(hook, embed))
job_scheduling.schedule_job(wrap_hook(hook, embed))

logging.debug("Scheduled the performing of a discord webhook!")

Expand Down
5 changes: 1 addition & 4 deletions app/usecases/password.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from __future__ import annotations

import asyncio

import bcrypt

CACHE: dict[str, str] = {}
Expand All @@ -11,8 +9,7 @@ async def verify_password(plain_password: str, hashed_password: str) -> bool:
if hashed_password in CACHE:
return CACHE[hashed_password] == plain_password

result = await asyncio.to_thread(
bcrypt.checkpw,
result = bcrypt.checkpw(
plain_password.encode(),
hashed_password.encode(),
)
Expand Down
6 changes: 4 additions & 2 deletions app/usecases/score.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import asyncio
import hashlib
import logging
from typing import Optional

import app.state
import app.usecases
import app.utils
from app import job_scheduling
from app.models.achievement import Achievement
from app.models.beatmap import Beatmap
from app.models.score import Score
Expand Down Expand Up @@ -123,7 +123,9 @@ async def handle_first_place(
)

msg = f"[{score.mode.relax_str}] User {user.embed} has submitted a #1 place on {beatmap.embed} +{score.mods!r} ({score.pp:.2f}pp)"
await app.utils.send_announcement_as_side_effect(msg)
await job_scheduling.schedule_job(
app.utils.send_message_to_channel("#announce", msg),
)


OSU_VERSION = 2021_11_03
Expand Down
14 changes: 1 addition & 13 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

import asyncio
import logging

from tenacity import retry
from tenacity import wait_exponential
from tenacity.stop import stop_after_attempt
Expand All @@ -22,7 +19,7 @@ def make_safe(username: str) -> str:
stop=stop_after_attempt(10),
reraise=True,
)
async def channel_message(channel: str, message: str) -> None:
async def send_message_to_channel(channel: str, message: str) -> None:
response = await app.state.services.http_client.get(
f"{config.BANCHO_SERVICE_URL}/api/v1/fokabotMessage",
params={
Expand All @@ -35,15 +32,6 @@ async def channel_message(channel: str, message: str) -> None:
response.raise_for_status()


async def send_announcement_as_side_effect(message: str) -> None:
try:
asyncio.create_task(channel_message("#announce", message))
except asyncio.TimeoutError:
logging.warning(
"Failed to send message to #announce, bancho-service is likely down",
)


async def check_online(user_id: int) -> bool:
key = f"bancho:tokens:ids:{user_id}"
return await app.state.services.redis.exists(key)
Expand Down

0 comments on commit 6a8c376

Please sign in to comment.