Skip to content

Commit

Permalink
Loading jobs differently to make them easier to track
Browse files Browse the repository at this point in the history
  • Loading branch information
robinharms committed Jul 1, 2024
1 parent 6576128 commit d456a6d
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 157 deletions.
69 changes: 1 addition & 68 deletions envelope/deferred_jobs/jobs.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
from __future__ import annotations

from datetime import datetime
from logging import getLogger
from typing import TYPE_CHECKING

from django.conf import settings
from django.contrib.auth import get_user_model
from django.db import transaction
from django.utils.translation import activate

from envelope import Error
from envelope import WS_INCOMING
from envelope.core.message import ErrorMessage
from envelope.signals import connection_closed
from envelope.signals import connection_created
from envelope.utils import get_envelope
from envelope.utils import get_error_type
from envelope.utils import update_connection_status
from envelope.utils import websocket_send_error

if TYPE_CHECKING:
from envelope.deferred_jobs.message import DeferredJob

User = get_user_model()

Expand All @@ -39,25 +29,6 @@ def _set_lang(lang=None):
activate(lang)


def handle_failure(job, connection, exc_type, exc_value, traceback):
"""
Failure callbacks are functions that accept job, connection, type, value and traceback arguments.
type, value and traceback values returned by sys.exc_info(),
which is the exception raised when executing your job.
See RQs docs
"""
mm = job.kwargs.get("mm", {})
if mm:
message_id = mm.get("id", None)
consumer_name = mm.get("consumer_name", None)
if message_id and consumer_name:
# FIXME: exc value might not be safe here
err = get_error_type(Error.JOB)(mm=mm, msg=str(exc_value))
websocket_send_error(err, channel_name=consumer_name)
return err # For testing, has no effect


def create_connection_status_on_websocket_connect(
*,
user_pk: int,
Expand Down Expand Up @@ -118,41 +89,3 @@ def mark_connection_action(
action_at: datetime,
):
update_connection_status(user_pk, channel_name=consumer_name, last_action=action_at)


def default_job(
data: dict, mm: dict, t: str, *, enqueued_at: datetime = None, **kwargs
):
env_name = mm.get("env", WS_INCOMING)
envelope = get_envelope(env_name)
# We won't handle key error here. Message name should've been checked when it was received.
message = envelope.registry[t](mm=mm, data=data)
run_job(message, enqueued_at=enqueued_at)


def run_job(message: DeferredJob, *, enqueued_at: datetime = None, update_conn=True):
message.on_worker = True
if message.mm.language:
# Otherwise skip lang?
activate(message.mm.language)
try:
if message.atomic:
with transaction.atomic(durable=True):
message.run_job()
else:
message.run_job()
except ErrorMessage as err: # Catchable, nice errors
if err.mm.id is None:
err.mm.id = message.mm.id
if err.mm.consumer_name is None:
err.mm.consumer_name = message.mm.consumer_name
if err.mm.consumer_name:
websocket_send_error(err)
else:
# Everything went fine
if update_conn and message.mm.user_pk and message.mm.consumer_name:
update_connection_status(
user_pk=message.mm.user_pk,
channel_name=message.mm.consumer_name,
last_action=enqueued_at,
)
69 changes: 61 additions & 8 deletions envelope/deferred_jobs/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@

from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING

from django.db import transaction
from django.utils.functional import cached_property
from django.utils.timezone import now
from django.utils.translation import activate
from django_rq import get_queue
from pydantic import BaseModel
from rq import Queue

from envelope import DEFAULT_QUEUE_NAME
from envelope import Error
from envelope.core.message import ErrorMessage
from envelope.core.message import Message
from envelope.utils import get_error_type
from envelope.utils import update_connection_status
from envelope.utils import websocket_send_error

if TYPE_CHECKING:
from django.db.models import Model
Expand All @@ -30,10 +36,8 @@ class DeferredJob(Message, ABC):
None # Job exec timeout in seconds if you want to override default
)
queue: str = DEFAULT_QUEUE_NAME # Queue name
# connection: None | Redis = None
atomic: bool = True
on_worker: bool = False
job: callable | str = "envelope.deferred_jobs.jobs.default_job"
should_run: bool = True # Mark as false to abort run

async def pre_queue(self, **kwargs):
Expand All @@ -42,11 +46,60 @@ async def pre_queue(self, **kwargs):
It's a good idea to avoid using this if it's not needed.
"""

@property
def on_failure(self) -> callable | None:
from envelope.deferred_jobs.jobs import handle_failure
@staticmethod
def handle_failure(job, connection, exc_type, exc_value, traceback):
"""
Failure callbacks are functions that accept job, connection, type, value and traceback arguments.
type, value and traceback values returned by sys.exc_info(),
which is the exception raised when executing your job.
return handle_failure
See RQs docs
"""
mm = job.kwargs.get("mm", {})
if mm:
if consumer_name := mm.get("consumer_name", None):
# FIXME: exc value might not be safe here
err = get_error_type(Error.JOB)(mm=mm, msg=str(exc_value))
websocket_send_error(err, channel_name=consumer_name)
return err # For testing, has no effect

@classmethod
def init_job(
cls,
data: dict,
mm: dict,
t: str,
*,
enqueued_at: datetime = None,
update_conn: bool = True,
**kwargs,
):
message = cls(mm=mm, data=data)
message.on_worker = True
if message.mm.language:
# Otherwise skip lang?
activate(message.mm.language)
try:
if message.atomic:
with transaction.atomic(durable=True):
message.run_job()
else:
message.run_job()
except ErrorMessage as err: # Catchable, nice errors
if err.mm.id is None:
err.mm.id = message.mm.id
if err.mm.consumer_name is None:
err.mm.consumer_name = message.mm.consumer_name
if err.mm.consumer_name:
websocket_send_error(err)
else:
# Everything went fine
if update_conn and message.mm.user_pk and message.mm.consumer_name:
update_connection_status(
user_pk=message.mm.user_pk,
channel_name=message.mm.consumer_name,
last_action=enqueued_at,
)

def enqueue(self, queue: Queue | None = None, **kwargs):
if queue is None:
Expand All @@ -55,7 +108,7 @@ def enqueue(self, queue: Queue | None = None, **kwargs):
data = {}
if self.data:
data = self.data.dict()
kwargs.setdefault("on_failure", self.on_failure)
kwargs.setdefault("on_failure", self.handle_failure)
if self.job_timeout:
kwargs["job_timeout"] = self.job_timeout
if self.ttl:
Expand All @@ -65,7 +118,7 @@ def enqueue(self, queue: Queue | None = None, **kwargs):
"To call enqueue on DeferredJob messages, env must be present in message meta."
)
return queue.enqueue(
self.job,
self.init_job,
t=self.name,
mm=self.mm.dict(),
data=data,
Expand Down
84 changes: 70 additions & 14 deletions envelope/deferred_jobs/tests/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from unittest.mock import patch

from channels.layers import get_channel_layer
from django.contrib.auth import get_user_model
from django.test import TestCase
from django_rq import get_queue
from fakeredis import FakeStrictRedis
from pydantic import BaseModel
from rq import SimpleWorker

from envelope import WS_INCOMING
Expand All @@ -22,6 +26,20 @@ def run_job(self):
Connection.objects.create(user=self.user, channel_name="abc")


class BadJob(DeferredJob):
name = "bad_job"

def run_job(self):
return 1 / 0


class NeverFoundJob(DeferredJob):
name = "bad_at_looking"

def run_job(self):
raise NotFoundError.from_message(self, model="something", value="1")


class DummyContextAction(ContextAction):
name = "dummy_context_action"
permission = None
Expand All @@ -37,20 +55,9 @@ class DeferredJobTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.user = User.objects.create(username="runner")
cls.msg_reg = get_message_registry(WS_INCOMING)
cls.msg_reg[DummyJob.name] = DummyJob

@classmethod
def tearDownClass(cls):
super().tearDownClass()
cls.msg_reg.pop(DummyJob.name)

def _mk_msg(self, **kwargs):
msg = DummyJob(**kwargs)
return msg

def test_enqueue_via_queue(self):
msg = self._mk_msg(mm={"user_pk": self.user.pk})
msg = DummyJob(mm={"user_pk": self.user.pk})
connection = FakeStrictRedis()
queue = get_queue(connection=connection)
queue.enqueue(msg.run_job)
Expand All @@ -59,7 +66,7 @@ def test_enqueue_via_queue(self):
self.assertTrue(Connection.objects.filter(channel_name="abc").exists())

def test_enqueue_via_msg(self):
msg = self._mk_msg(mm={"user_pk": self.user.pk, "env": WS_INCOMING})
msg = DummyJob(mm={"user_pk": self.user.pk, "env": WS_INCOMING})
connection = FakeStrictRedis()
queue = get_queue(connection=connection)
msg.enqueue(
Expand All @@ -69,6 +76,55 @@ def test_enqueue_via_msg(self):
self.assertTrue(worker.work(burst=True))
self.assertTrue(Connection.objects.filter(channel_name="abc").exists())

def test_error_handling(self):
msg = BadJob(
mm={"user_pk": self.user.pk, "env": WS_INCOMING, "consumer_name": "abc"}
)
connection = FakeStrictRedis()
queue = get_queue(connection=connection)
msg.enqueue(queue)
worker = SimpleWorker([queue], connection=connection)
channel_layer = get_channel_layer()
with patch.object(channel_layer, "send") as mock_send:
with self.captureOnCommitCallbacks(execute=True):
self.assertTrue(worker.work(burst=True))
self.assertTrue(mock_send.called)
self.assertEqual(
{
"t": "error.job",
"p": {"msg": "division by zero"},
"i": None,
"s": "f",
"type": "ws.error.send",
},
mock_send.call_args[0][1],
)

def test_job_raises_catchable_error(self):
msg = NeverFoundJob(
mm={"user_pk": self.user.pk, "env": WS_INCOMING, "consumer_name": "abc"},
)
connection = FakeStrictRedis()
queue = get_queue(connection=connection)
msg.enqueue(queue)
worker = SimpleWorker([queue], connection=connection)
channel_layer = get_channel_layer()
with patch.object(channel_layer, "send") as mock_send:
with self.captureOnCommitCallbacks(execute=True):
self.assertTrue(worker.work(burst=True))
self.assertTrue(mock_send.called)
self.assertEqual(
{
"t": "error.job",
"p": {"key": "pk", "model": "something", "value": "1"},
"i": None,
"s": "f",
"t": "error.not_found",
"type": "ws.error.send",
},
mock_send.call_args[0][1],
)


class DummyContextActionTests(TestCase):
@classmethod
Expand Down Expand Up @@ -146,7 +202,7 @@ def test_perm_raises(self):
{
"key": "pk",
"model": "envelope.connection",
"value": "1",
"value": str(self.conn.pk),
"permission": "hard to come by",
},
cm.exception.data.dict(),
Expand Down
Loading

0 comments on commit d456a6d

Please sign in to comment.