diff --git a/README.rst b/README.rst
index 2fa5626..3b34aa2 100644
--- a/README.rst
+++ b/README.rst
@@ -171,6 +171,7 @@ to 10, and all ``websocket.send!`` channels to 20:
If you want to enforce a matching order, use an ``OrderedDict`` as the
argument; channels will then be matched in the order the dict provides them.
+.. _encryption
``symmetric_encryption_keys``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -237,6 +238,51 @@ And then in your channels consumer, you can implement the handler:
async def redis_disconnect(self, *args):
# Handle disconnect
+
+
+``serializer_format``
+~~~~~~~~~~~~~~~~~~~~~~
+By default every message sent to redis is encoded using `msgpack `_ (_currently ``msgpack`` is a mandatory dependency of this package, it may become optional in a future release_).
+It is also possible to switch to `JSON `_:
+
+.. code-block:: python
+
+ CHANNEL_LAYERS = {
+ "default": {
+ "BACKEND": "channels_redis.core.RedisChannelLayer",
+ "CONFIG": {
+ "hosts": ["redis://:password@127.0.0.1:6379/0"],
+ "serializer_format": "json",
+ },
+ },
+ }
+
+
+Custom serializer can be defined by:
+
+- extending ``channels_redis.serializers.BaseMessageSerializer``, implementing ``as_bytes `` and ``from_bytes`` methods
+- using any class which accepts generic keyword arguments and provides ``serialize``/``deserialize`` methods
+
+Then it may be registered (or can be overriden) by using ``channels_redis.serializers.registry``:
+
+.. code-block:: python
+
+ from channels_redis.serializers import registry
+
+ class MyFormatSerializer:
+ def serialize(self, message):
+ ...
+ def deserialize(self, message):
+ ...
+
+ registry.register_serializer('myformat', MyFormatSerializer)
+
+**NOTE**: the registry allows to override the serializer class used for a specific format without any particular check nor constraint, thus it is recommended to pay attention with order-of-imports when using third-party serializers which may override a built-in format.
+
+
+Serializers are also responsible for encryption *symmetric_encryption_keys*. When extending ``channels_redis.serializers.BaseMessageSerializer`` encryption is already configured in the base class, unless you override ``serialize``/``deserialize`` methods: in this case you should call ``self.crypter.encrypt`` in serialization and ``self.crypter.decrypt`` in deserialization process. When using full custom serializer expect an optional sequence of keys to be passed via ``symmetric_encryption_keys``.
+
+
Dependencies
------------
diff --git a/channels_redis/core.py b/channels_redis/core.py
index a164059..a230081 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -1,20 +1,17 @@
import asyncio
-import base64
import collections
import functools
-import hashlib
import itertools
import logging
-import random
import time
import uuid
-import msgpack
from redis import asyncio as aioredis
from channels.exceptions import ChannelFull
from channels.layers import BaseChannelLayer
+from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
@@ -115,6 +112,8 @@ def __init__(
capacity=100,
channel_capacity=None,
symmetric_encryption_keys=None,
+ random_prefix_length=12,
+ serializer_format="msgpack",
):
# Store basic information
self.expiry = expiry
@@ -126,6 +125,14 @@ def __init__(
# Configure the host objects
self.hosts = decode_hosts(hosts)
self.ring_size = len(self.hosts)
+ # serialization
+ self._serializer = registry.get_serializer(
+ serializer_format,
+ # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
+ random_prefix_length=random_prefix_length,
+ expiry=self.expiry,
+ symmetric_encryption_keys=symmetric_encryption_keys,
+ )
# Cached redis connection pools and the event loop they are from
self._layers = {}
# Normal channels choose a host index by cycling through the available hosts
@@ -133,8 +140,6 @@ def __init__(
self._send_index_generator = itertools.cycle(range(len(self.hosts)))
# Decide on a unique client prefix to use in ! sections
self.client_prefix = uuid.uuid4().hex
- # Set up any encryption objects
- self._setup_encryption(symmetric_encryption_keys)
# Number of coroutines trying to receive right now
self.receive_count = 0
# The receive lock
@@ -154,24 +159,6 @@ def __init__(
def create_pool(self, index):
return create_pool(self.hosts[index])
- def _setup_encryption(self, symmetric_encryption_keys):
- # See if we can do encryption if they asked
- if symmetric_encryption_keys:
- if isinstance(symmetric_encryption_keys, (str, bytes)):
- raise ValueError(
- "symmetric_encryption_keys must be a list of possible keys"
- )
- try:
- from cryptography.fernet import MultiFernet
- except ImportError:
- raise ValueError(
- "Cannot run with encryption without 'cryptography' installed."
- )
- sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
- self.crypter = MultiFernet(sub_fernets)
- else:
- self.crypter = None
-
### Channel layer API ###
extensions = ["groups", "flush"]
@@ -656,41 +643,19 @@ def serialize(self, message):
"""
Serializes message to a byte string.
"""
- value = msgpack.packb(message, use_bin_type=True)
- if self.crypter:
- value = self.crypter.encrypt(value)
-
- # As we use an sorted set to expire messages we need to guarantee uniqueness, with 12 bytes.
- random_prefix = random.getrandbits(8 * 12).to_bytes(12, "big")
- return random_prefix + value
+ return self._serializer.serialize(message)
def deserialize(self, message):
"""
Deserializes from a byte string.
"""
- # Removes the random prefix
- message = message[12:]
-
- if self.crypter:
- message = self.crypter.decrypt(message, self.expiry + 10)
- return msgpack.unpackb(message, raw=False)
+ return self._serializer.deserialize(message)
### Internal functions ###
def consistent_hash(self, value):
return _consistent_hash(value, self.ring_size)
- def make_fernet(self, key):
- """
- Given a single encryption key, returns a Fernet instance using it.
- """
- from cryptography.fernet import Fernet
-
- if isinstance(key, str):
- key = key.encode("utf8")
- formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
- return Fernet(formatted_key)
-
def __str__(self):
return f"{self.__class__.__name__}(hosts={self.hosts})"
diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py
index 6957b0a..a80e12d 100644
--- a/channels_redis/pubsub.py
+++ b/channels_redis/pubsub.py
@@ -3,9 +3,9 @@
import logging
import uuid
-import msgpack
from redis import asyncio as aioredis
+from .serializers import registry
from .utils import (
_close_redis,
_consistent_hash,
@@ -25,10 +25,21 @@ async def _async_proxy(obj, name, *args, **kwargs):
class RedisPubSubChannelLayer:
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(
+ self,
+ *args,
+ symmetric_encryption_keys=None,
+ serializer_format="msgpack",
+ **kwargs,
+ ) -> None:
self._args = args
self._kwargs = kwargs
self._layers = {}
+ # serialization
+ self._serializer = registry.get_serializer(
+ serializer_format,
+ symmetric_encryption_keys=symmetric_encryption_keys,
+ )
def __getattr__(self, name):
if name in (
@@ -48,13 +59,13 @@ def serialize(self, message):
"""
Serializes message to a byte string.
"""
- return msgpack.packb(message)
+ return self._serializer.serialize(message)
def deserialize(self, message):
"""
Deserializes from a byte string.
"""
- return msgpack.unpackb(message)
+ return self._serializer.deserialize(message)
def _get_layer(self):
loop = asyncio.get_running_loop()
diff --git a/channels_redis/serializers.py b/channels_redis/serializers.py
new file mode 100644
index 0000000..b981797
--- /dev/null
+++ b/channels_redis/serializers.py
@@ -0,0 +1,169 @@
+import abc
+import base64
+import hashlib
+import json
+import random
+
+try:
+ from cryptography.fernet import Fernet, MultiFernet
+except ImportError:
+ MultiFernet = Fernet = None
+
+
+class SerializerDoesNotExist(KeyError):
+ """The requested serializer was not found."""
+
+
+class BaseMessageSerializer(abc.ABC):
+ def __init__(
+ self,
+ symmetric_encryption_keys=None,
+ random_prefix_length=0,
+ expiry=None,
+ ):
+ self.random_prefix_length = random_prefix_length
+ self.expiry = expiry
+ # Set up any encryption objects
+ self._setup_encryption(symmetric_encryption_keys)
+
+ def _setup_encryption(self, symmetric_encryption_keys):
+ # See if we can do encryption if they asked
+ if symmetric_encryption_keys:
+ if isinstance(symmetric_encryption_keys, (str, bytes)):
+ raise ValueError(
+ "symmetric_encryption_keys must be a list of possible keys"
+ )
+ if MultiFernet is None:
+ raise ValueError(
+ "Cannot run with encryption without 'cryptography' installed."
+ )
+ sub_fernets = [self.make_fernet(key) for key in symmetric_encryption_keys]
+ self.crypter = MultiFernet(sub_fernets)
+ else:
+ self.crypter = None
+
+ def make_fernet(self, key):
+ """
+ Given a single encryption key, returns a Fernet instance using it.
+ """
+ if Fernet is None:
+ raise ValueError(
+ "Cannot run with encryption without 'cryptography' installed."
+ )
+
+ if isinstance(key, str):
+ key = key.encode("utf-8")
+ formatted_key = base64.urlsafe_b64encode(hashlib.sha256(key).digest())
+ return Fernet(formatted_key)
+
+ @abc.abstractmethod
+ def as_bytes(self, message, *args, **kwargs):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def from_bytes(self, message, *args, **kwargs):
+ raise NotImplementedError
+
+ def serialize(self, message):
+ """
+ Serializes message to a byte string.
+ """
+ message = self.as_bytes(message)
+ if self.crypter:
+ message = self.crypter.encrypt(message)
+
+ if self.random_prefix_length > 0:
+ # provide random prefix
+ message = (
+ random.getrandbits(8 * self.random_prefix_length).to_bytes(
+ self.random_prefix_length, "big"
+ )
+ + message
+ )
+ return message
+
+ def deserialize(self, message):
+ """
+ Deserializes from a byte string.
+ """
+ if self.random_prefix_length > 0:
+ # Removes the random prefix
+ message = message[self.random_prefix_length :] # noqa: E203
+
+ if self.crypter:
+ ttl = self.expiry if self.expiry is None else self.expiry + 10
+ message = self.crypter.decrypt(message, ttl)
+ return self.from_bytes(message)
+
+
+class MissingSerializer(BaseMessageSerializer):
+ exception = None
+
+ def __init__(self, *args, **kwargs):
+ raise self.exception
+
+
+class JSONSerializer(BaseMessageSerializer):
+ # json module by default always produces str while loads accepts bytes
+ # thus we must force bytes conversion
+ # we use UTF-8 since it is the recommended encoding for interoperability
+ # see https://docs.python.org/3/library/json.html#character-encodings
+ def as_bytes(self, message, *args, **kwargs):
+ message = json.dumps(message, *args, **kwargs)
+ return message.encode("utf-8")
+
+ from_bytes = staticmethod(json.loads)
+
+
+# code ready for a future in which msgpack may become an optional dependency
+try:
+ import msgpack
+except ImportError as exc:
+
+ class MsgPackSerializer(MissingSerializer):
+ exception = exc
+
+else:
+
+ class MsgPackSerializer(BaseMessageSerializer):
+ as_bytes = staticmethod(msgpack.packb)
+ from_bytes = staticmethod(msgpack.unpackb)
+
+
+class SerializersRegistry:
+ """
+ Serializers registry inspired by that of ``django.core.serializers``.
+ """
+
+ def __init__(self):
+ self._registry = {}
+
+ def register_serializer(self, format, serializer_class):
+ """
+ Register a new serializer for given format
+ """
+ assert isinstance(serializer_class, type) and (
+ issubclass(serializer_class, BaseMessageSerializer)
+ or (
+ hasattr(serializer_class, "serialize")
+ and hasattr(serializer_class, "deserialize")
+ )
+ ), """
+ `serializer_class` should be a class which implements `serialize` and `deserialize` method
+ or a subclass of `channels_redis.serializers.BaseMessageSerializer`
+ """
+
+ self._registry[format] = serializer_class
+
+ def get_serializer(self, format, *args, **kwargs):
+ try:
+ serializer_class = self._registry[format]
+ except KeyError:
+ raise SerializerDoesNotExist(format)
+
+ return serializer_class(*args, **kwargs)
+
+
+registry = SerializersRegistry()
+registry.register_serializer("json", JSONSerializer)
+registry.register_serializer("msgpack", MsgPackSerializer)
diff --git a/tests/test_core.py b/tests/test_core.py
index 2752040..e5bda1c 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -636,26 +636,3 @@ def test_receive_buffer_respects_capacity():
assert buff.qsize() == capacity
messages = [buff.get_nowait() for _ in range(capacity)]
assert list(range(9900, 10000)) == messages
-
-
-def test_serialize():
- """
- Test default serialization method
- """
- message = {"a": True, "b": None, "c": {"d": []}}
- channel_layer = RedisChannelLayer()
- serialized = channel_layer.serialize(message)
- assert isinstance(serialized, bytes)
- assert serialized[12:] == b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
-
-
-def test_deserialize():
- """
- Test default deserialization method
- """
- message = b"Q\x0c\xbb?Q\xbc\xe3|D\xfd9\x00\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"
- channel_layer = RedisChannelLayer()
- deserialized = channel_layer.deserialize(message)
-
- assert isinstance(deserialized, dict)
- assert deserialized == {"a": True, "b": None, "c": {"d": []}}
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
new file mode 100644
index 0000000..76b8d29
--- /dev/null
+++ b/tests/test_serializers.py
@@ -0,0 +1,127 @@
+import pytest
+
+from channels_redis.serializers import (
+ JSONSerializer,
+ MsgPackSerializer,
+ SerializerDoesNotExist,
+ SerializersRegistry,
+)
+
+
+@pytest.fixture
+def registry():
+ return SerializersRegistry()
+
+
+class OnlySerialize:
+ def serialize(self, message):
+ return message
+
+
+class OnlyDeserialize:
+ def deserialize(self, message):
+ return message
+
+
+def bad_serializer():
+ pass
+
+
+class NoopSerializer:
+ def serialize(self, message):
+ return message
+
+ def deserialize(self, message):
+ return message
+
+
+@pytest.mark.parametrize(
+ "serializer_class", (OnlyDeserialize, OnlySerialize, bad_serializer)
+)
+def test_refuse_to_register_bad_serializers(registry, serializer_class):
+ with pytest.raises(AssertionError):
+ registry.register_serializer("custom", serializer_class)
+
+
+def test_raise_error_for_unregistered_serializer(registry):
+ with pytest.raises(SerializerDoesNotExist):
+ registry.get_serializer("unexistent")
+
+
+def test_register_custom_serializer(registry):
+ registry.register_serializer("custom", NoopSerializer)
+ serializer = registry.get_serializer("custom")
+ assert serializer.serialize("message") == "message"
+ assert serializer.deserialize("message") == "message"
+
+
+@pytest.mark.parametrize(
+ "serializer_cls,expected",
+ (
+ (MsgPackSerializer, b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"),
+ (JSONSerializer, b'{"a": true, "b": null, "c": {"d": []}}'),
+ ),
+)
+@pytest.mark.parametrize("prefix_length", (8, 12, 0, -1))
+def test_serialize(serializer_cls, expected, prefix_length):
+ """
+ Test default serialization method
+ """
+ message = {"a": True, "b": None, "c": {"d": []}}
+ serializer = serializer_cls(random_prefix_length=prefix_length)
+ serialized = serializer.serialize(message)
+ assert isinstance(serialized, bytes)
+ if prefix_length > 0:
+ assert serialized[prefix_length:] == expected
+ else:
+ assert serialized == expected
+
+
+@pytest.mark.parametrize(
+ "serializer_cls,value",
+ (
+ (MsgPackSerializer, b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"),
+ (JSONSerializer, b'{"a": true, "b": null, "c": {"d": []}}'),
+ ),
+)
+@pytest.mark.parametrize(
+ "prefix_length,prefix",
+ (
+ (8, b"Q\x0c\xbb?Q\xbc\xe3|"),
+ (12, b"Q\x0c\xbb?Q\xbc\xe3|D\xfd9\x00"),
+ (0, b""),
+ (-1, b""),
+ ),
+)
+def test_deserialize(serializer_cls, value, prefix_length, prefix):
+ """
+ Test default deserialization method
+ """
+ message = prefix + value
+ serializer = serializer_cls(random_prefix_length=prefix_length)
+ deserialized = serializer.deserialize(message)
+ assert isinstance(deserialized, dict)
+ assert deserialized == {"a": True, "b": None, "c": {"d": []}}
+
+
+@pytest.mark.parametrize(
+ "serializer_cls,clear_value",
+ (
+ (MsgPackSerializer, b"\x83\xa1a\xc3\xa1b\xc0\xa1c\x81\xa1d\x90"),
+ (JSONSerializer, b'{"a": true, "b": null, "c": {"d": []}}'),
+ ),
+)
+def test_serialization_encrypted(serializer_cls, clear_value):
+ """
+ Test serialization rount-trip with encryption
+ """
+ message = {"a": True, "b": None, "c": {"d": []}}
+ serializer = serializer_cls(
+ symmetric_encryption_keys=["a-test-key"], random_prefix_length=4
+ )
+ serialized = serializer.serialize(message)
+ assert isinstance(serialized, bytes)
+ assert serialized[4:] != clear_value
+ deserialized = serializer.deserialize(serialized)
+ assert isinstance(deserialized, dict)
+ assert deserialized == message