From a0d302dbdba278a2168695ca4b8072e4c8ee75b7 Mon Sep 17 00:00:00 2001 From: Robin Harms Oredsson Date: Thu, 7 Mar 2024 22:29:52 +0100 Subject: [PATCH] Allow override of layer and envelope for pubsub channels. Closes #2 --- CHANGELOG.md | 3 +- envelope/channels/messages.py | 4 ++- envelope/channels/models.py | 30 +++++++++++++++++-- .../deferred_jobs/tests/test_async_signals.py | 1 - 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a100a1d..537798b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,11 @@ # Changelog - ## dev * RecheckSubscriptionsSchema subscriptions changed from set to list to fix common serialization problems. +* PubSub and context channels accepts arguments `envelope_name` and `layer_name` in case + they need to be overridden. (#2) ## 1.0.0 (2024-03-07) diff --git a/envelope/channels/messages.py b/envelope/channels/messages.py index bc31d84..42fa3be 100644 --- a/envelope/channels/messages.py +++ b/envelope/channels/messages.py @@ -47,7 +47,7 @@ def get_channel( ) -> ContextChannel: ch = get_context_channel(channel_type) # This may cause errors right? - return ch(pk, consumer_name) + return ch(pk, consumer_channel=consumer_name) @add_message(WS_INCOMING) @@ -69,6 +69,8 @@ def get_app_state(self, channel: ContextChannel) -> list | None: return list(app_state) async def pre_queue(self, consumer: WebsocketConsumer, **kwargs) -> Subscribed: + if self.mm.consumer_name is None: + self.mm.consumer_name = consumer.channel_name channel = self.get_channel( self.data.channel_type, self.data.pk, self.mm.consumer_name ) diff --git a/envelope/channels/models.py b/envelope/channels/models.py index c3c888f..6b5fd7b 100644 --- a/envelope/channels/models.py +++ b/envelope/channels/models.py @@ -48,8 +48,15 @@ def channel_name(self) -> str: def __init__( self, consumer_channel: str | None = None, + *, + envelope_name: str | None = None, + layer_name: str | None = None, ): self.consumer_channel = consumer_channel + if envelope_name: + self.envelope_name = envelope_name + if layer_name: + self.layer_name = layer_name async def subscribe(self): if not self.consumer_channel: # pragma: no coverage @@ -98,9 +105,16 @@ def __init__( self, pk: int, consumer_channel: str | None = None, + *, + envelope_name: str | None = None, + layer_name: str | None = None, ): self.pk = pk - super().__init__(consumer_channel) + super().__init__( + consumer_channel=consumer_channel, + envelope_name=envelope_name, + layer_name=layer_name, + ) @property def channel_name(self) -> str: @@ -126,10 +140,20 @@ def permission(self) -> str | None: @classmethod def from_instance( - cls, instance: models.Model, consumer_channel: str | None = None + cls, + instance: models.Model, + consumer_channel: str | None = None, + *, + envelope_name: str | None = None, + layer_name: str | None = None, ) -> ContextChannel: assert isinstance(instance, cls.model), f"Instance must be a {cls.model}" - inst = cls(instance.pk, consumer_channel) + inst = cls( + instance.pk, + consumer_channel=consumer_channel, + envelope_name=envelope_name, + layer_name=layer_name, + ) # Set context straight away to avoid lookup inst.context = instance return inst diff --git a/envelope/deferred_jobs/tests/test_async_signals.py b/envelope/deferred_jobs/tests/test_async_signals.py index a208e58..f398713 100644 --- a/envelope/deferred_jobs/tests/test_async_signals.py +++ b/envelope/deferred_jobs/tests/test_async_signals.py @@ -115,7 +115,6 @@ async def test_dispatch_disconnection_job_no_user(self): async def test_maybe_update_connection(self): self.mock_consumer.user_pk = self.user.pk self.mock_consumer.connection_update_interval = 10 - self.user = self.user with patch( "django_rq.queues.get_redis_connection",