diff --git a/README.rst b/README.rst index 2fa5626..3b34aa2 100644 --- a/README.rst +++ b/README.rst @@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20: If you want to enforce a matching order, use an ``OrderedDict`` as the argument; channels will then be matched in the order the dict provides them. +.. _encryption ``symmetric_encryption_keys`` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -237,6 +238,51 @@ And then in your channels consumer, you can implement the handler: async def redis_disconnect(self, *args): # Handle disconnect + + +``serializer_format`` +~~~~~~~~~~~~~~~~~~~~~~ +By default every message sent to redis is encoded using `msgpack `_ (_currently ``msgpack`` is a mandatory dependency of this package, it may become optional in a future release_). +It is also possible to switch to `JSON `_: + +.. code-block:: python + + CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels_redis.core.RedisChannelLayer", + "CONFIG": { + "hosts": ["redis://:password@127.0.0.1:6379/0"], + "serializer_format": "json", + }, + }, + } + + +Custom serializer can be defined by: + +- extending ``channels_redis.serializers.BaseMessageSerializer``, implementing ``as_bytes `` and ``from_bytes`` methods +- using any class which accepts generic keyword arguments and provides ``serialize``/``deserialize`` methods + +Then it may be registered (or can be overriden) by using ``channels_redis.serializers.registry``: + +.. code-block:: python + + from channels_redis.serializers import registry + + class MyFormatSerializer: + def serialize(self, message): + ... + def deserialize(self, message): + ... + + registry.register_serializer('myformat', MyFormatSerializer) + +**NOTE**: the registry allows to override the serializer class used for a specific format without any particular check nor constraint, thus it is recommended to pay attention with order-of-imports when using third-party serializers which may override a built-in format. + + +Serializers are also responsible for encryption *symmetric_encryption_keys*. When extending ``channels_redis.serializers.BaseMessageSerializer`` encryption is already configured in the base class, unless you override ``serialize``/``deserialize`` methods: in this case you should call ``self.crypter.encrypt`` in serialization and ``self.crypter.decrypt`` in deserialization process. When using full custom serializer expect an optional sequence of keys to be passed via ``symmetric_encryption_keys``. + + Dependencies ------------ diff --git a/channels_redis/core.py b/channels_redis/core.py index a164059..a230081 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -1,20 +1,17 @@ import asyncio -import base64 import collections import functools -import hashlib import itertools import logging -import random import time import uuid -import msgpack from redis import asyncio as aioredis from channels.exceptions import ChannelFull from channels.layers import BaseChannelLayer +from .serializers import registry from .utils import ( _close_redis, _consistent_hash, @@ -115,6 +112,8 @@ def __init__( capacity=100, channel_capacity=None, symmetric_encryption_keys=None, + random_prefix_length=12, + serializer_format="msgpack", ): # Store basic information self.expiry = expiry @@ -126,6 +125,14 @@ def __init__( # Configure the host objects self.hosts = decode_hosts(hosts) self.ring_size = len(self.hosts) + # serialization + self._serializer = registry.get_serializer( + serializer_format, + # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes. + random_prefix_length=random_prefix_length, + expiry=self.expiry, + symmetric_encryption_keys=symmetric_encryption_keys, + ) # Cached redis connection pools and the event loop they are from self._layers = {} # Normal channels choose a host index by cycling through the available hosts @@ -133,8 +140,6 @@ def __init__( self._send_index_generator = itertools.cycle(range(len(self.hosts))) # Decide on a unique client prefix to use in ! sections self.client_prefix = uuid.uuid4().hex - # Set up any encryption objects - self._setup_encryption(symmetric_encryption_keys) # Number of coroutines trying to receive right now self.receive_count = 0 # The receive lock @@ -154,24 +159,6 @@ def __init__( def create_pool(self, index): return create_pool(self.hosts[index]) - def _setup_encryption(self, symmetric_encryption_keys): - # See if we can do encryption if they asked - if symmetric_encryption_keys: - if isinstance(symmetric_encryption_keys, (str, bytes)): - raise ValueError( - "symmetric_encryption_keys must be a list of possible keys" - ) - try: - from cryptography.fernet import MultiFernet - except ImportError: - raise ValueError( - "Cannot run with encryption without 'cryptography' installed." - ) - sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys] - self.crypter = MultiFernet(sub_fernets) - else: - self.crypter = None - ### Channel layer API ### extensions = ["groups", "flush"] @@ -656,41 +643,19 @@ def serialize(self, message): """ Serializes message to a byte string. """ - value = msgpack.packb(message, use_bin_type=True) - if self.crypter: - value = self.crypter.encrypt(value) - - # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes. - random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big") - return random_prefix + value + return self._serializer.serialize(message) def deserialize(self, message): """ Deserializes from a byte string. """ - # Removes the random prefix - message = message[12:] - - if self.crypter: - message = self.crypter.decrypt(message, self.expiry + 10) - return msgpack.unpackb(message, raw=False) + return self._serializer.deserialize(message) ### Internal functions ### def consistent_hash(self, value): return _consistent_hash(value, self.ring_size) - def make_fernet(self, key): - """ - Given a single encryption key, returns a Fernet instance using it. - """ - from cryptography.fernet import Fernet - - if isinstance(key, str): - key = key.encode("utf8") - formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest()) - return Fernet(formatted_key) - def __str__(self): return f"{self.__class__.__name__}(hosts={self.hosts})" diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index 6957b0a..a80e12d 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -3,9 +3,9 @@ import logging import uuid -import msgpack from redis import asyncio as aioredis +from .serializers import registry from .utils import ( _close_redis, _consistent_hash, @@ -25,10 +25,21 @@ async def _async_proxy(obj, name, *args, **kwargs): class RedisPubSubChannelLayer: - def __init__(self, *args, **kwargs) -> None: + def __init__( + self, + *args, + symmetric_encryption_keys=None, + serializer_format="msgpack", + **kwargs, + ) -> None: self._args = args self._kwargs = kwargs self._layers = {} + # serialization + self._serializer = registry.get_serializer( + serializer_format, + symmetric_encryption_keys=symmetric_encryption_keys, + ) def __getattr__(self, name): if name in ( @@ -48,13 +59,13 @@ def serialize(self, message): """ Serializes message to a byte string. """ - return msgpack.packb(message) + return self._serializer.serialize(message) def deserialize(self, message): """ Deserializes from a byte string. """ - return msgpack.unpackb(message) + return self._serializer.deserialize(message) def _get_layer(self): loop = asyncio.get_running_loop() diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py new file mode 100644 index 0000000..b981797 --- /dev/null +++ b/channels_redis/serializers.py @@ -0,0 +1,169 @@ +import abc +import base64 +import hashlib +import json +import random + +try: + from cryptography.fernet import Fernet, MultiFernet +except ImportError: + MultiFernet = Fernet = None + + +class SerializerDoesNotExist(KeyError): + """The requested serializer was not found.""" + + +class BaseMessageSerializer(abc.ABC): + def __init__( + self, + symmetric_encryption_keys=None, + random_prefix_length=0, + expiry=None, + ): + self.random_prefix_length = random_prefix_length + self.expiry = expiry + # Set up any encryption objects + self._setup_encryption(symmetric_encryption_keys) + + def _setup_encryption(self, symmetric_encryption_keys): + # See if we can do encryption if they asked + if symmetric_encryption_keys: + if isinstance(symmetric_encryption_keys, (str, bytes)): + raise ValueError( + "symmetric_encryption_keys must be a list of possible keys" + ) + if MultiFernet is None: + raise ValueError( + "Cannot run with encryption without 'cryptography' installed." + ) + sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys] + self.crypter = MultiFernet(sub_fernets) + else: + self.crypter = None + + def make_fernet(self, key): + """ + Given a single encryption key, returns a Fernet instance using it. + """ + if Fernet is None: + raise ValueError( + "Cannot run with encryption without 'cryptography' installed." + ) + + if isinstance(key, str): + key = key.encode("utf-8") + formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest()) + return Fernet(formatted_key) + + @abc.abstractmethod + def as_bytes(self, message, *args, **kwargs): + raise NotImplementedError + + @abc.abstractmethod + def from_bytes(self, message, *args, **kwargs): + raise NotImplementedError + + def serialize(self, message): + """ + Serializes message to a byte string. + """ + message = self.as_bytes(message) + if self.crypter: + message = self.crypter.encrypt(message) + + if self.random_prefix_length > 0: + # provide random prefix + message = ( + random.getrandbits(8 * self.random_prefix_length).to_bytes( + self.random_prefix_length, "big" + ) + + message + ) + return message + + def deserialize(self, message): + """ + Deserializes from a byte string. + """ + if self.random_prefix_length > 0: + # Removes the random prefix + message = message[self.random_prefix_length :] # noqa: E203 + + if self.crypter: + ttl = self.expiry if self.expiry is None else self.expiry + 10 + message = self.crypter.decrypt(message, ttl) + return self.from_bytes(message) + + +class MissingSerializer(BaseMessageSerializer): + exception = None + + def __init__(self, *args, **kwargs): + raise self.exception + + +class JSONSerializer(BaseMessageSerializer): + # json module by default always produces str while loads accepts bytes + # thus we must force bytes conversion + # we use UTF-8 since it is the recommended encoding for interoperability + # see https://docs.python.org/3/library/json.html#character-encodings + def as_bytes(self, message, *args, **kwargs): + message = json.dumps(message, *args, **kwargs) + return message.encode("utf-8") + + from_bytes = staticmethod(json.loads) + + +# code ready for a future in which msgpack may become an optional dependency +try: + import msgpack +except ImportError as exc: + + class MsgPackSerializer(MissingSerializer): + exception = exc + +else: + + class MsgPackSerializer(BaseMessageSerializer): + as_bytes = staticmethod(msgpack.packb) + from_bytes = staticmethod(msgpack.unpackb) + + +class SerializersRegistry: + """ + Serializers registry inspired by that of ``django.core.serializers``. + """ + + def __init__(self): + self._registry = {} + + def register_serializer(self, format, serializer_class): + """ + Register a new serializer for given format + """ + assert isinstance(serializer_class, type) and ( + issubclass(serializer_class, BaseMessageSerializer) + or ( + hasattr(serializer_class, "serialize") + and hasattr(serializer_class, "deserialize") + ) + ), """ + `serializer_class` should be a class which implements `serialize` and `deserialize` method + or a subclass of `channels_redis.serializers.BaseMessageSerializer` + """ + + self._registry[format] = serializer_class + + def get_serializer(self, format, *args, **kwargs): + try: + serializer_class = self._registry[format] + except KeyError: + raise SerializerDoesNotExist(format) + + return serializer_class(*args, **kwargs) + + +registry = SerializersRegistry() +registry.register_serializer("json", JSONSerializer) +registry.register_serializer("msgpack", MsgPackSerializer) diff --git a/tests/test_core.py b/tests/test_core.py index 2752040..e5bda1c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -636,26 +636,3 @@ def test_receive_buffer_respects_capacity(): assert buff.qsize() == capacity messages = [buff.get_nowait() for _ in range(capacity)] assert list(range(9900, 10000)) == messages - - -def test_serialize(): - """ - Test default serialization method - """ - message = {"a": True, "b": None, "c": {"d": []}} - channel_layer = RedisChannelLayer() - serialized = channel_layer.serialize(message) - assert isinstance(serialized, bytes) - assert serialized[12:] == b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90" - - -def test_deserialize(): - """ - Test default deserialization method - """ - message = b"Q\x0c\xbb?Q\xbc\xe3|D\xfd9\x00\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90" - channel_layer = RedisChannelLayer() - deserialized = channel_layer.deserialize(message) - - assert isinstance(deserialized, dict) - assert deserialized == {"a": True, "b": None, "c": {"d": []}} diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 0000000..76b8d29 --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,127 @@ +import pytest + +from channels_redis.serializers import ( + JSONSerializer, + MsgPackSerializer, + SerializerDoesNotExist, + SerializersRegistry, +) + + +@pytest.fixture +def registry(): + return SerializersRegistry() + + +class OnlySerialize: + def serialize(self, message): + return message + + +class OnlyDeserialize: + def deserialize(self, message): + return message + + +def bad_serializer(): + pass + + +class NoopSerializer: + def serialize(self, message): + return message + + def deserialize(self, message): + return message + + +@pytest.mark.parametrize( + "serializer_class", (OnlyDeserialize, OnlySerialize, bad_serializer) +) +def test_refuse_to_register_bad_serializers(registry, serializer_class): + with pytest.raises(AssertionError): + registry.register_serializer("custom", serializer_class) + + +def test_raise_error_for_unregistered_serializer(registry): + with pytest.raises(SerializerDoesNotExist): + registry.get_serializer("unexistent") + + +def test_register_custom_serializer(registry): + registry.register_serializer("custom", NoopSerializer) + serializer = registry.get_serializer("custom") + assert serializer.serialize("message") == "message" + assert serializer.deserialize("message") == "message" + + +@pytest.mark.parametrize( + "serializer_cls,expected", + ( + (MsgPackSerializer, b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"), + (JSONSerializer, b'{"a": true, "b": null, "c": {"d": []}}'), + ), +) +@pytest.mark.parametrize("prefix_length", (8, 12, 0, -1)) +def test_serialize(serializer_cls, expected, prefix_length): + """ + Test default serialization method + """ + message = {"a": True, "b": None, "c": {"d": []}} + serializer = serializer_cls(random_prefix_length=prefix_length) + serialized = serializer.serialize(message) + assert isinstance(serialized, bytes) + if prefix_length > 0: + assert serialized[prefix_length:] == expected + else: + assert serialized == expected + + +@pytest.mark.parametrize( + "serializer_cls,value", + ( + (MsgPackSerializer, b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"), + (JSONSerializer, b'{"a": true, "b": null, "c": {"d": []}}'), + ), +) +@pytest.mark.parametrize( + "prefix_length,prefix", + ( + (8, b"Q\x0c\xbb?Q\xbc\xe3|"), + (12, b"Q\x0c\xbb?Q\xbc\xe3|D\xfd9\x00"), + (0, b""), + (-1, b""), + ), +) +def test_deserialize(serializer_cls, value, prefix_length, prefix): + """ + Test default deserialization method + """ + message = prefix + value + serializer = serializer_cls(random_prefix_length=prefix_length) + deserialized = serializer.deserialize(message) + assert isinstance(deserialized, dict) + assert deserialized == {"a": True, "b": None, "c": {"d": []}} + + +@pytest.mark.parametrize( + "serializer_cls,clear_value", + ( + (MsgPackSerializer, b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"), + (JSONSerializer, b'{"a": true, "b": null, "c": {"d": []}}'), + ), +) +def test_serialization_encrypted(serializer_cls, clear_value): + """ + Test serialization rount-trip with encryption + """ + message = {"a": True, "b": None, "c": {"d": []}} + serializer = serializer_cls( + symmetric_encryption_keys=["a-test-key"], random_prefix_length=4 + ) + serialized = serializer.serialize(message) + assert isinstance(serialized, bytes) + assert serialized[4:] != clear_value + deserialized = serializer.deserialize(serialized) + assert isinstance(deserialized, dict) + assert deserialized == message