-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add RabbitMQConnectorGateway and related components
Introduced the RabbitMQConnectorGateway for enhanced RabbitMQ integration. This includes new tests, and necessary Config model changes. The `publish`, `consume`, and queue management methods were added for the RabbitMQ interaction. A new `batched` helper was also added for batching messages during publishing. Configuration parsing logic and necessary dependencies were also updated. 4
- Loading branch information
Showing
15 changed files
with
513 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import abc | ||
|
||
from typing import Protocol, Callable, Coroutine | ||
|
||
from aio_pika import Channel, Message | ||
from aio_pika.abc import ( | ||
AbstractIncomingMessage, | ||
ExchangeType, | ||
AbstractRobustConnection, | ||
DeliveryMode, | ||
AbstractExchange, | ||
) | ||
from aio_pika.pool import Pool | ||
|
||
|
||
class RabbitMQConnectorGateway(Protocol): | ||
@abc.abstractmethod | ||
def channel_pool(self) -> Pool[Channel]: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def publish( | ||
self, *message_bodies: dict, routing_key: str, exchange_name: str | None = None, | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def consume( | ||
self, queue_name: str, callback: Callable[[AbstractIncomingMessage], Coroutine], | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def create_exchange( | ||
self, name: str, type_: ExchangeType = ExchangeType.TOPIC, | ||
) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def create_queue(self, name: str) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def bid_queue( | ||
self, queue_name: str, routing_key: str, exchange_name: str, | ||
) -> bool: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def delete_exchange(self, name: str) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def delete_queue(self, name: str) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def shutdown(self) -> None: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def _get_exchange( | ||
self, channel: Channel, exchange_name: str | None, | ||
) -> AbstractExchange: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def _get_connection(self) -> AbstractRobustConnection: | ||
raise NotImplementedError | ||
|
||
@abc.abstractmethod | ||
async def _get_channel(self) -> Channel: | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
@abc.abstractmethod | ||
def _serialize_message( | ||
body: dict, delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT, | ||
) -> Message: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class RabbitMQConnectorConfig: | ||
host: str | ||
port: str | ||
username: str = "guest" | ||
password: str = "guest" | ||
connection_pool_max_size: int = 3 | ||
channel_pool_max_size: int = 15 | ||
default_exchange_name: str = "decision_making_app" | ||
|
||
def full_url(self) -> str: | ||
return "amqp://{}:{}@{}:{}".format( | ||
self.username, self.password, | ||
self.host, self.port, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import asyncio | ||
from typing import Callable, Coroutine | ||
|
||
from aio_pika import Channel, connect_robust, Connection, ExchangeType, Message | ||
from aio_pika.abc import ( | ||
AbstractExchange, | ||
AbstractIncomingMessage, | ||
AbstractRobustConnection, | ||
DeliveryMode, | ||
) | ||
from aio_pika.pool import Pool | ||
import orjson | ||
|
||
from .iterables import batched | ||
from src.app.infrastructure.config.models import RabbitMQConnectorConfig | ||
from src.app.application.rmq_connector.interfaces import RabbitMQConnectorGateway | ||
|
||
|
||
class RabbitMQConnectorGatewayImpl(RabbitMQConnectorGateway): | ||
BATCH_SIZE: int = 200 | ||
|
||
def __init__(self, settings: RabbitMQConnectorConfig): | ||
self._settings = settings | ||
self._connection_pool: Pool[Connection] = Pool( | ||
self._get_connection, | ||
max_size=settings.connection_pool_max_size, | ||
) | ||
self._channel_pool: Pool[Channel] = Pool( | ||
self._get_channel, | ||
max_size=settings.channel_pool_max_size, | ||
) | ||
self._default_exchange_name = settings.default_exchange_name | ||
|
||
@property | ||
def channel_pool(self) -> Pool[Channel]: # type: ignore | ||
return self._channel_pool | ||
|
||
async def publish( | ||
self, | ||
*message_bodies: dict, | ||
routing_key: str, | ||
exchange_name: str | None = None, | ||
) -> None: | ||
async with self._channel_pool.acquire() as channel: | ||
exchange = await self._get_exchange( | ||
channel, | ||
exchange_name=exchange_name or self._default_exchange_name, | ||
) | ||
|
||
messages = (self._serialize_message(m) for m in message_bodies) | ||
for messages_batch in batched( | ||
messages, | ||
RabbitMQConnectorGatewayImpl.BATCH_SIZE, | ||
): | ||
tasks = [ | ||
exchange.publish(message=m, routing_key=routing_key) | ||
for m in messages_batch | ||
] | ||
await asyncio.gather(*tasks) | ||
|
||
async def consume( | ||
self, | ||
queue_name: str, | ||
callback: Callable[[AbstractIncomingMessage], Coroutine], | ||
) -> None: | ||
async with self._channel_pool.acquire() as channel: | ||
queue = await channel.get_queue(queue_name, ensure=True) | ||
await queue.consume(callback, exclusive=True) # type: ignore | ||
|
||
async def create_exchange( | ||
self, | ||
name: str, | ||
type_: ExchangeType = ExchangeType.TOPIC, | ||
) -> None: | ||
async with self._channel_pool.acquire() as channel: | ||
await channel.declare_exchange(name, type_) | ||
|
||
async def create_queue(self, name: str) -> None: | ||
async with self._channel_pool.acquire() as channel: | ||
await channel.declare_queue(name) | ||
|
||
async def bid_queue( | ||
self, | ||
queue_name: str, | ||
routing_key: str, | ||
exchange_name: str, | ||
) -> bool: | ||
async with self._channel_pool.acquire() as channel: | ||
queue = await channel.get_queue(queue_name) | ||
exchange = await self._get_exchange(channel, exchange_name=exchange_name) | ||
result = await queue.bind(exchange=exchange, routing_key=routing_key) | ||
return result.name == "Queue.BindOk" | ||
|
||
async def delete_exchange(self, name: str) -> None: | ||
async with self._channel_pool.acquire() as channel: | ||
await channel.exchange_delete(name) | ||
|
||
async def delete_queue(self, name: str) -> None: | ||
async with self._channel_pool.acquire() as channel: | ||
await channel.queue_delete(name) | ||
|
||
async def shutdown(self) -> None: | ||
await self._channel_pool.close() | ||
await self._connection_pool.close() | ||
|
||
@staticmethod | ||
def _serialize_message( | ||
body: dict, | ||
delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT, | ||
) -> Message: | ||
return Message(body=orjson.dumps(body), delivery_mode=delivery_mode) | ||
|
||
async def _get_exchange( | ||
self, | ||
channel: Channel, | ||
exchange_name: str | None, | ||
) -> AbstractExchange: | ||
exchange_name = exchange_name or self._default_exchange_name | ||
return await channel.get_exchange(name=exchange_name, ensure=True) | ||
|
||
async def _get_connection(self) -> AbstractRobustConnection: | ||
return await connect_robust(self._settings.full_url()) | ||
|
||
async def _get_channel(self) -> Channel: | ||
async with self._connection_pool.acquire() as connection: | ||
return await connection.channel() # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from itertools import islice | ||
from typing import Iterable | ||
|
||
|
||
def batched(iterable: Iterable, batch_size: int): | ||
error_msg = "batch_size must be at least one" | ||
if batch_size < 1: | ||
raise ValueError(error_msg) | ||
it = iter(iterable) | ||
while batch := tuple(islice(it, batch_size)): | ||
yield batch |
Empty file.
58 changes: 58 additions & 0 deletions
58
tests/integration/infrastructure/rmq_connector/conftest.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Generator, AsyncGenerator | ||
from uuid import uuid4 | ||
|
||
from testcontainers.rabbitmq import RabbitMqContainer # type: ignore[import-untyped] | ||
|
||
import pytest | ||
import pytest_asyncio | ||
|
||
from src.app.infrastructure.config.models import RabbitMQConnectorConfig | ||
from src.app.infrastructure.rmq_connector.gateway import RabbitMQConnectorGatewayImpl | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def rmq_container() -> Generator[RabbitMqContainer, None, None]: | ||
rabbitmq_container = RabbitMqContainer( | ||
username="guest", | ||
password="guest", | ||
) | ||
try: | ||
rabbitmq_container.start() | ||
yield rabbitmq_container | ||
finally: | ||
rabbitmq_container.stop() | ||
|
||
|
||
@pytest_asyncio.fixture() | ||
async def rmq_connector_gateway( | ||
rmq_container: RabbitMqContainer, | ||
) -> AsyncGenerator[RabbitMQConnectorGatewayImpl, None]: | ||
config = RabbitMQConnectorConfig( | ||
host=rmq_container.get_container_host_ip(), | ||
port=rmq_container.get_exposed_port(5672), | ||
) | ||
gateway = RabbitMQConnectorGatewayImpl(config) | ||
yield gateway | ||
await gateway.shutdown() | ||
|
||
|
||
@pytest.fixture() | ||
async def queue(rmq_connector_gateway: RabbitMQConnectorGatewayImpl): | ||
queue_name = str(uuid4()) | ||
await rmq_connector_gateway.create_queue(queue_name) | ||
|
||
async with rmq_connector_gateway.channel_pool.acquire() as channel: | ||
yield await channel.get_queue(name=queue_name, ensure=True) | ||
|
||
await rmq_connector_gateway.delete_queue(queue_name) | ||
|
||
|
||
@pytest.fixture() | ||
async def exchange(rmq_connector_gateway: RabbitMQConnectorGatewayImpl): | ||
exchange_name = str(uuid4()) | ||
await rmq_connector_gateway.create_exchange(exchange_name) | ||
|
||
async with rmq_connector_gateway.channel_pool.acquire() as channel: | ||
yield await channel.get_exchange(name=exchange_name, ensure=True) | ||
|
||
await rmq_connector_gateway.delete_exchange(exchange_name) |
Oops, something went wrong.