diff --git a/poetry.lock b/poetry.lock
index 5b7f65c..ca1699b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
 
 [[package]]
 name = "aio-pika"
@@ -37,6 +37,18 @@ pamqp = "3.2.1"
 setuptools = {version = "*", markers = "python_version < \"3.8\""}
 yarl = "*"
 
+[[package]]
+name = "aiostream"
+version = "0.4.5"
+description = "Generator-based operators for asynchronous iteration"
+category = "main"
+optional = false
+python-versions = "*"
+files = [
+    {file = "aiostream-0.4.5-py3-none-any.whl", hash = "sha256:25b7c2d9c83570d78c0ef5a20e949b7d0b8ea3b0b0a4f22c49d3f721105a6057"},
+    {file = "aiostream-0.4.5.tar.gz", hash = "sha256:3ecbf87085230fbcd9605c32ca20c4fb41af02c71d076eab246ea22e35947d88"},
+]
+
 [[package]]
 name = "anyio"
 version = "3.6.2"
@@ -1765,4 +1777,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.7"
-content-hash = "08ce77e6e8381786d82bd61e4660375e3823e64a246dc49e6e44f2f170c223ca"
+content-hash = "7576855cadb51e428d252b410c6957a60c58064554faf843b820837d052c0e71"
diff --git a/pyproject.toml b/pyproject.toml
index f991ec0..23887d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ keywords = ["taskiq", "tasks", "distributed", "async", "aio-pika"]
 python = "^3.7"
 taskiq = "^0"
 aio-pika = "^8.1.0"
+aiostream = "^0.4.5"
 
 [tool.poetry.dev-dependencies]
 pytest = "^7.0"
diff --git a/taskiq_aio_pika/broker.py b/taskiq_aio_pika/broker.py
index e18719e..11e2c51 100644
--- a/taskiq_aio_pika/broker.py
+++ b/taskiq_aio_pika/broker.py
@@ -1,10 +1,11 @@
 import asyncio
 from datetime import timedelta
 from logging import getLogger
-from typing import Any, AsyncGenerator, Callable, Dict, Optional, TypeVar
+from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, TypeVar
 
 from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust
 from aio_pika.abc import AbstractChannel, AbstractQueue, AbstractRobustConnection
+from aiostream import stream
 from taskiq.abc.broker import AsyncBroker
 from taskiq.abc.result_backend import AsyncResultBackend
 from taskiq.message import BrokerMessage
@@ -50,7 +51,6 @@ def __init__(  # noqa: WPS211
         delay_queue_name: Optional[str] = None,
         declare_exchange: bool = True,
         declare_queues: bool = True,
-        routing_key: str = "#",
         exchange_type: ExchangeType = ExchangeType.TOPIC,
         max_priority: Optional[int] = None,
         **connection_kwargs: Any,
@@ -75,7 +75,6 @@ def __init__(  # noqa: WPS211
             if it doesn't exist.
         :param declare_queues: whether you want to declare queues even on
             client side. May be useful for message persistance.
-        :param routing_key: that used to bind that queue to the exchange.
         :param exchange_type: type of the exchange.
             Used only if `declare_exchange` is True.
         :param max_priority: maximum priority value for messages.
@@ -93,8 +92,8 @@ def __init__(  # noqa: WPS211
         self._declare_exchange = declare_exchange
         self._declare_queues = declare_queues
         self._queue_name = queue_name
-        self._routing_key = routing_key
         self._max_priority = max_priority
+        self._queue_name_list: List[str] = []
 
         self._dead_letter_queue_name = f"{queue_name}.dead_letter"
         if dead_letter_queue_name:
@@ -135,10 +134,20 @@ async def startup(self) -> None:  # noqa: WPS217
         if self._declare_queues:
             await self.declare_queues(self.write_channel)
 
+    def add_queue(self, queue_name: str) -> "AioPikaBroker":
+        """
+        Register an additional queue to declare and consume from.
+
+        :param queue_name: name of the queue to register.
+        :return: current broker instance, to allow chaining.
+        """
+        self._queue_name_list.append(queue_name)
+        return self
+
     async def declare_queues(
         self,
         channel: AbstractChannel,
-    ) -> AbstractQueue:
+    ) -> List[AbstractQueue]:
         """
         This function is used to declare queues.
 
@@ -150,6 +159,16 @@ async def declare_queues(
-        :param channel: channel to used for declaration.
-        :return: main queue instance.
+        :param channel: channel used for declaration.
+        :return: list of declared queues.
         """
+        queue_list: List[AbstractQueue] = []
+
+        async def bind_queue(queue_name: str) -> AbstractQueue:
+            queue = await channel.declare_queue(
+                queue_name,
+                arguments=args,
+            )
+            await queue.bind(exchange=self._exchange_name, routing_key=queue_name)
+            return queue
+
         await channel.declare_queue(
             self._dead_letter_queue_name,
         )
@@ -159,10 +178,7 @@ async def declare_queues(
         }
         if self._max_priority is not None:
             args["x-max-priority"] = self._max_priority
-        queue = await channel.declare_queue(
-            self._queue_name,
-            arguments=args,
-        )
+        queue_list.append(await bind_queue(self._queue_name))
         await channel.declare_queue(
             self._delay_queue_name,
             arguments={
@@ -170,8 +186,11 @@ async def declare_queues(
                 "x-dead-letter-routing-key": self._queue_name,
             },
         )
-        await queue.bind(exchange=self._exchange_name, routing_key=self._routing_key)
-        return queue
+        if self._queue_name_list:
+            for _queue_name in self._queue_name_list:
+                queue_list.append(await bind_queue(_queue_name))
+
+        return queue_list
 
     async def kick(self, message: BrokerMessage) -> None:
         """
@@ -201,12 +220,14 @@ async def kick(self, message: BrokerMessage) -> None:
             priority=priority,
         )
         delay = parse_val(int, message.labels.get("delay"))
+        # BrokerMessage does not declare queue_name, hence the type ignore.
+        routing_key_name = message.queue_name or self._queue_name  # type: ignore
         if delay is None:
             exchange = await self.write_channel.get_exchange(
                 self._exchange_name,
                 ensure=False,
             )
-            await exchange.publish(rmq_msg, routing_key=message.task_name)
+            await exchange.publish(rmq_msg, routing_key=routing_key_name)
         else:
             rmq_msg.expiration = timedelta(seconds=delay)
             await self.write_channel.default_exchange.publish(
@@ -227,11 +248,19 @@ async def listen(self) -> AsyncGenerator[bytes, None]:
         if self.read_channel is None:
             raise ValueError("Call startup before starting listening.")
         await self.read_channel.set_qos(prefetch_count=self._qos)
-        queue = await self.declare_queues(self.read_channel)
-        async with queue.iterator() as iterator:
-            async for message in iterator:
-                async with message.process():
-                    yield message.body
+        queue_list = await self.declare_queues(self.read_channel)
+
+        async def body(queue: AbstractQueue) -> AsyncGenerator[bytes, None]:
+            async with queue.iterator() as iterator:
+                async for message in iterator:
+                    async with message.process():
+                        yield message.body
+
+        combine = stream.merge(*[body(queue) for queue in queue_list])
+
+        async with combine.stream() as streamer:
+            async for message_body in streamer:
+                yield message_body
 
     async def shutdown(self) -> None:
         """Close all connections on shutdown."""
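Note for reviewers: a minimal usage sketch of the API this diff introduces. The connection URL and the queue names "orders" and "emails" are placeholders; only add_queue(), startup(), and listen() come from the change above, and a running RabbitMQ instance is assumed.

import asyncio

from taskiq_aio_pika import AioPikaBroker

# add_queue() returns the broker itself, so extra queues can be chained.
broker = (
    AioPikaBroker("amqp://guest:guest@localhost:5672/")
    .add_queue("orders")
    .add_queue("emails")
)


async def main() -> None:
    await broker.startup()
    # listen() now declares the default queue plus every queue registered
    # via add_queue() and yields message bodies from all of them.
    async for message_body in broker.listen():
        print(message_body)


asyncio.run(main())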
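The listen() rewrite leans on aiostream's stream.merge(), which fans in several async generators and yields items as they become available. A self-contained sketch of that pattern follows, with toy generators standing in for the per-queue iterators:

import asyncio
from typing import AsyncGenerator

from aiostream import stream


async def ticker(prefix: str) -> AsyncGenerator[str, None]:
    # Toy generator standing in for one queue's message iterator.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield f"{prefix}-{i}"


async def main() -> None:
    # merge() interleaves items from all sources as they arrive,
    # the same way listen() combines one body() generator per queue.
    combined = stream.merge(ticker("orders"), ticker("emails"))
    async with combined.stream() as streamer:
        async for item in streamer:
            print(item)


asyncio.run(main())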