From 13debf7a3af363987b08209c6e398f71a95e617c Mon Sep 17 00:00:00 2001
From: olzhasar
Date: Sun, 29 Sep 2024 03:42:30 +0500
Subject: [PATCH] Handle multiple messages in send and group_send

---
 channels_redis/core.py | 179 +++++++++++++++++++++++++++++++++--------
 tests/test_core.py     |  54 +++++++++++++
 2 files changed, 199 insertions(+), 34 deletions(-)

diff --git a/channels_redis/core.py b/channels_redis/core.py
index a164059..6ed3865 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -180,22 +180,40 @@ async def send(self, channel, message):
         """
         Send a message onto a (general or specific) channel.
         """
+        await self.send_bulk(channel, (message,))
+
+    async def send_bulk(self, channel, messages):
+        """
+        Send multiple messages in bulk onto a (general or specific) channel.
+        The `messages` argument should be a sequence of dicts.
+        """
+
         # Typecheck
-        assert isinstance(message, dict), "message is not a dict"
         assert self.valid_channel_name(channel), "Channel name not valid"
-        # Make sure the message does not contain reserved keys
-        assert "__asgi_channel__" not in message
+
         # If it's a process-local channel, strip off local part and stick full name in message
         channel_non_local_name = channel
-        if "!" in channel:
-            message = dict(message.items())
-            message["__asgi_channel__"] = channel
+        process_local = "!" in channel
+        if process_local:
             channel_non_local_name = self.non_local_name(channel)
+
+        now = time.time()
+        mapping = {}
+        for message in messages:
+            assert isinstance(message, dict), "message is not a dict"
+            # Make sure the message does not contain reserved keys
+            assert "__asgi_channel__" not in message
+            if process_local:
+                message = dict(message.items())
+                message["__asgi_channel__"] = channel
+
+            mapping[self.serialize(message)] = now
+
         # Write out message into expiring key (avoids big items in list)
         channel_key = self.prefix + channel_non_local_name
         # Pick a connection to the right server - consistent for specific
         # channels, random for general channels
-        if "!" in channel:
+        if process_local:
             index = self.consistent_hash(channel)
         else:
             index = next(self._send_index_generator)
@@ -207,13 +225,13 @@ async def send(self, channel, message):
 
         # Check the length of the list before send
         # This can allow the list to leak slightly over capacity, but that's fine.
-        if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
-            channel
-        ):
+        current_length = await connection.zcount(channel_key, "-inf", "+inf")
+
+        if current_length + len(messages) > self.get_capacity(channel):
             raise ChannelFull()
         # Push onto the list then set it to expire in case it's not consumed
-        await connection.zadd(channel_key, {self.serialize(message): time.time()})
+        await connection.zadd(channel_key, mapping)
         await connection.expire(channel_key, int(self.expiry))
 
     def _backup_channel_name(self, channel):
         """
@@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
         connection = self.connection(self.consistent_hash(group))
         await connection.zrem(key, channel)
 
-    async def group_send(self, group, message):
-        """
-        Sends a message to the entire group.
- """ + async def _get_group_connection_and_channels(self, group): assert self.valid_group_name(group), "Group name not valid" # Retrieve list of all channel names key = self._group_key(group) @@ -532,11 +547,36 @@ async def group_send(self, group, message): channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)] + return connection, channel_names + + async def _exec_group_lua_script( + self, conn_idx, group, channel_redis_keys, channel_names, script, args + ): + # channel_keys does not contain a single redis key more than once + connection = self.connection(conn_idx) + channels_over_capacity = await connection.eval( + script, len(channel_redis_keys), *channel_redis_keys, *args + ) + if channels_over_capacity > 0: + logger.info( + "%s of %s channels over capacity in group %s", + channels_over_capacity, + len(channel_names), + group, + ) + + async def group_send(self, group, message): + """ + Sends a message to the entire group. + """ + + connection, channel_names = await self._get_group_connection_and_channels(group) + ( connection_to_channel_keys, channel_keys_to_message, channel_keys_to_capacity, - ) = self._map_channel_keys_to_connection(channel_names, message) + ) = self._map_channel_keys_to_connection(channel_names, (message,)) for connection_index, channel_redis_keys in connection_to_channel_keys.items(): # Discard old messages based on expiry @@ -569,7 +609,7 @@ async def group_send(self, group, message): # We need to filter the messages to keep those related to the connection args = [ - channel_keys_to_message[channel_key] + channel_keys_to_message[channel_key][0] for channel_key in channel_redis_keys ] @@ -581,20 +621,88 @@ async def group_send(self, group, message): args += [time.time(), self.expiry] - # channel_keys does not contain a single redis key more than once - connection = self.connection(connection_index) - channels_over_capacity = await connection.eval( - group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args + await self._exec_group_lua_script( + connection_index, + group, + channel_redis_keys, + channel_names, + group_send_lua, + args, ) - if channels_over_capacity > 0: - logger.info( - "%s of %s channels over capacity in group %s", - channels_over_capacity, - len(channel_names), - group, + + async def group_send_bulk(self, group, messages): + """ + Sends multiple messages in bulk to the entire group. + The `messages` argument should be an iterable of dicts. + """ + + connection, channel_names = await self._get_group_connection_and_channels(group) + + ( + connection_to_channel_keys, + channel_keys_to_message, + channel_keys_to_capacity, + ) = self._map_channel_keys_to_connection(channel_names, messages) + + for connection_index, channel_redis_keys in connection_to_channel_keys.items(): + # Discard old messages based on expiry + pipe = connection.pipeline() + for key in channel_redis_keys: + pipe.zremrangebyscore( + key, min=0, max=int(time.time()) - int(self.expiry) ) + await pipe.execute() + + # Create a LUA script specific for this connection. + # Make sure to use the message list specific to this channel, it is + # stored in channel_to_message dict and each message contains the + # __asgi_channel__ key. 
+
+            group_send_lua = """
+                local over_capacity = 0
+                local num_messages = tonumber(ARGV[#ARGV - 2])
+                local current_time = ARGV[#ARGV - 1]
+                local expiry = ARGV[#ARGV]
+                for i=1,#KEYS do
+                    if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
+                        local messages = {}
+                        for j=num_messages * (i - 1) + 1, num_messages * i do
+                            table.insert(messages, current_time)
+                            table.insert(messages, ARGV[j])
+                        end
+                        redis.call('ZADD', KEYS[i], unpack(messages))
+                        redis.call('EXPIRE', KEYS[i], expiry)
+                    else
+                        over_capacity = over_capacity + 1
+                    end
+                end
+                return over_capacity
+            """
+
+            # We need to filter the messages to keep those related to the connection
+            args = []
+
+            for channel_key in channel_redis_keys:
+                args += channel_keys_to_message[channel_key]
+
+            # We need to send the capacity for each channel
+            args += [
+                channel_keys_to_capacity[channel_key]
+                for channel_key in channel_redis_keys
+            ]
 
-    def _map_channel_keys_to_connection(self, channel_names, message):
+            args += [len(messages), time.time(), self.expiry]
+
+            await self._exec_group_lua_script(
+                connection_index,
+                group,
+                channel_redis_keys,
+                channel_names,
+                group_send_lua,
+                args,
+            )
+
+    def _map_channel_keys_to_connection(self, channel_names, messages):
         """
         For a list of channel names, GET
@@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
         # Connection dict keyed by index to list of redis keys mapped on that index
         connection_to_channel_keys = collections.defaultdict(list)
         # Message dict maps redis key to the message that needs to be send on that key
-        channel_key_to_message = dict()
+        channel_key_to_message = collections.defaultdict(list)
         # Channel key mapped to its capacity
         channel_key_to_capacity = dict()
 
@@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
             # Have we come across the same redis key?
             if channel_key not in channel_key_to_message:
                 # If not, fill the corresponding dicts
-                message = dict(message.items())
-                message["__asgi_channel__"] = [channel]
-                channel_key_to_message[channel_key] = message
+                for message in messages:
+                    message = dict(message.items())
+                    message["__asgi_channel__"] = [channel]
+                    channel_key_to_message[channel_key].append(message)
                 channel_key_to_capacity[channel_key] = self.get_capacity(channel)
                 idx = self.consistent_hash(channel_non_local_name)
                 connection_to_channel_keys[idx].append(channel_key)
             else:
                 # Yes, Append the channel in message dict
-                channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
+                for message in channel_key_to_message[channel_key]:
+                    message["__asgi_channel__"].append(channel)
 
         # Now that we know what message needs to be send on a redis key we serialize it
         for key, value in channel_key_to_message.items():
             # Serialize the message stored for each redis key
-            channel_key_to_message[key] = self.serialize(value)
+            for idx, message in enumerate(value):
+                channel_key_to_message[key][idx] = self.serialize(message)
 
         return (
             connection_to_channel_keys,
diff --git a/tests/test_core.py b/tests/test_core.py
index 2752040..a8e5bef 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1,4 +1,5 @@
 import asyncio
+import collections
 import random
 
 import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
     async_to_sync(channel_layer.flush)()
 
 
+@pytest.mark.asyncio
+async def test_send_multiple(channel_layer):
+    messages = [
+        {"type": "test.message.1"},
+        {"type": "test.message.2"},
+        {"type": "test.message.3"},
+    ]
+
+    await channel_layer.send_bulk("test-channel-1", messages)
+
+    expected = {"test.message.1", "test.message.2", "test.message.3"}
+    received = set()
+    for _ in range(3):
+        msg = await channel_layer.receive("test-channel-1")
+        received.add(msg["type"])
+
+    assert received == expected
+
+
 @pytest.mark.asyncio
 async def test_send_capacity(channel_layer):
     """
@@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
     await channel_layer.flush()
 
 
+@pytest.mark.asyncio
+async def test_groups_multiple(channel_layer):
+    """
+    Tests sending multiple messages in bulk to a group.
+    """
+    channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
+    channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
+    channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
+    await channel_layer.group_add("test-group", channel_name1)
+    await channel_layer.group_add("test-group", channel_name2)
+    await channel_layer.group_add("test-group", channel_name3)
+
+    messages = [
+        {"type": "message.1"},
+        {"type": "message.2"},
+        {"type": "message.3"},
+    ]
+
+    expected = {msg["type"] for msg in messages}
+
+    await channel_layer.group_send_bulk("test-group", messages)
+
+    received = collections.defaultdict(set)
+
+    for channel_name in (channel_name1, channel_name2, channel_name3):
+        async with async_timeout.timeout(1):
+            for _ in range(len(messages)):
+                received[channel_name].add(
+                    (await channel_layer.receive(channel_name))["type"]
+                )
+
+        assert received[channel_name] == expected
+
+
 @pytest.mark.asyncio
 async def test_groups_channel_full(channel_layer):
     """
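
A quick way to exercise the new bulk API by hand (a minimal sketch, assuming this
patch is applied and a Redis server is reachable on localhost:6379; the layer
configuration, channel, group, and message payloads below are illustrative only):

    import asyncio

    from channels_redis.core import RedisChannelLayer


    async def main():
        # Illustrative local setup; adjust hosts for your environment.
        layer = RedisChannelLayer(hosts=[("localhost", 6379)])

        # send_bulk performs a single capacity check and a single ZADD for all
        # messages, instead of one round trip per send() call.
        await layer.send_bulk(
            "test-channel-1",
            [{"type": "test.message", "seq": i} for i in range(3)],
        )
        for _ in range(3):
            print(await layer.receive("test-channel-1"))

        # group_send_bulk delivers every message to every channel in the group,
        # using one Lua script invocation per Redis connection.
        channel = await layer.new_channel()
        await layer.group_add("test-group", channel)
        await layer.group_send_bulk(
            "test-group",
            [{"type": "broadcast", "seq": i} for i in range(2)],
        )
        for _ in range(2):
            print(await layer.receive(channel))

        await layer.flush()


    asyncio.run(main())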