Skip to content

Commit

Permalink
Add RabbitMQConnectorGateway and related components
Browse files Browse the repository at this point in the history
Introduced the RabbitMQConnectorGateway for enhanced RabbitMQ integration. This includes new tests, and necessary Config model changes. The `publish`, `consume`, and queue management methods were added for the RabbitMQ interaction. A new `batched` helper was also added for batching messages during publishing. Configuration parsing logic and necessary dependencies were also updated.

4
  • Loading branch information
lazorikv committed Jan 17, 2024
1 parent 5de9dc6 commit b840093
Show file tree
Hide file tree
Showing 15 changed files with 513 additions and 46 deletions.
9 changes: 9 additions & 0 deletions config/local.dist.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ echo = true
access_key = decision_making_app_admin
secret_key = decision_making_app_admin
bucket_name = decision_making_app

[rmq_connector]
host = decision_making_app.rabbitmq
port = 5672
username = decision_making_app_admin
password = decision_making_app_admin
connection_pool_max_size = 3
channel_pool_max_size = 15
default_exchange_name = decision_making_app
122 changes: 78 additions & 44 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ aio-pika = "^9.3.0"
aioboto3 = "12.1.0"
types-aioboto3-lite = { extras = ["essential"], version = "^12.1.0" }
types-aiobotocore-lite = { extras = ["s3"], version = "^2.9.0" }
pika = "^1.3.2"


[tool.poetry.group.test]
Expand All @@ -32,7 +33,7 @@ optional = true
pytest = "^7.4.2"
pytest-asyncio = "^0.21.1"
pytest-order = "^1.2.0"
testcontainers = { extras = ["postgresql"], version = "^3.7.1" }
testcontainers = {extras = ["postgresql", "rabbitmq"], version = "^3.7.1"}
testcontainers-minio = "^0.0.1rc1" # for now its a separate package, in future add as an extra


Expand Down
Empty file.
80 changes: 80 additions & 0 deletions src/app/application/rmq_connector/interfaces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import abc

from typing import Protocol, Callable, Coroutine

from aio_pika import Channel, Message
from aio_pika.abc import (
AbstractIncomingMessage,
ExchangeType,
AbstractRobustConnection,
DeliveryMode,
AbstractExchange,
)
from aio_pika.pool import Pool


class RabbitMQConnectorGateway(Protocol):
@abc.abstractmethod
def channel_pool(self) -> Pool[Channel]:
raise NotImplementedError

@abc.abstractmethod
async def publish(
self, *message_bodies: dict, routing_key: str, exchange_name: str | None = None,
) -> None:
raise NotImplementedError

@abc.abstractmethod
async def consume(
self, queue_name: str, callback: Callable[[AbstractIncomingMessage], Coroutine],
) -> None:
raise NotImplementedError

@abc.abstractmethod
async def create_exchange(
self, name: str, type_: ExchangeType = ExchangeType.TOPIC,
) -> None:
raise NotImplementedError

@abc.abstractmethod
async def create_queue(self, name: str) -> None:
raise NotImplementedError

@abc.abstractmethod
async def bid_queue(
self, queue_name: str, routing_key: str, exchange_name: str,
) -> bool:
raise NotImplementedError

@abc.abstractmethod
async def delete_exchange(self, name: str) -> None:
raise NotImplementedError

@abc.abstractmethod
async def delete_queue(self, name: str) -> None:
raise NotImplementedError

@abc.abstractmethod
async def shutdown(self) -> None:
raise NotImplementedError

@abc.abstractmethod
async def _get_exchange(
self, channel: Channel, exchange_name: str | None,
) -> AbstractExchange:
raise NotImplementedError

@abc.abstractmethod
async def _get_connection(self) -> AbstractRobustConnection:
raise NotImplementedError

@abc.abstractmethod
async def _get_channel(self) -> Channel:
raise NotImplementedError

@staticmethod
@abc.abstractmethod
def _serialize_message(
body: dict, delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT,
) -> Message:
raise NotImplementedError
2 changes: 2 additions & 0 deletions src/app/infrastructure/config/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .database import DatabaseConfig
from .main import Config
from .object_storage import ObjectStorageConfig
from .rmq_connector import RabbitMQConnectorConfig


__all__ = [
"AppConfig",
"Config",
"DatabaseConfig",
"ObjectStorageConfig",
"RabbitMQConnectorConfig",
]
2 changes: 2 additions & 0 deletions src/app/infrastructure/config/models/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .application import AppConfig
from .database import DatabaseConfig
from .object_storage import ObjectStorageConfig
from .rmq_connector import RabbitMQConnectorConfig


@dataclass
class Config:
app_config: AppConfig
db_config: DatabaseConfig
storage_config: ObjectStorageConfig
rmq_connector_config: RabbitMQConnectorConfig
18 changes: 18 additions & 0 deletions src/app/infrastructure/config/models/rmq_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass


@dataclass
class RabbitMQConnectorConfig:
host: str
port: str
username: str = "guest"
password: str = "guest"
connection_pool_max_size: int = 3
channel_pool_max_size: int = 15
default_exchange_name: str = "decision_making_app"

def full_url(self) -> str:
return "amqp://{}:{}@{}:{}".format(
self.username, self.password,
self.host, self.port,
)
18 changes: 17 additions & 1 deletion src/app/infrastructure/config/parsers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Config,
DatabaseConfig,
ObjectStorageConfig,
RabbitMQConnectorConfig,
)


Expand All @@ -22,6 +23,7 @@ def load_config(path: str | None = None) -> Config:
application_data = parser["application"]
database_data = parser["database"]
object_storage_data = parser["object_storage"]
rmq_connector_data = parser["rmq_connector"]

application_config = AppConfig(
host=application_data.get("host"),
Expand All @@ -41,5 +43,19 @@ def load_config(path: str | None = None) -> Config:
secret_key=object_storage_data.get("secret_key"),
bucket_name=object_storage_data.get("bucket_name"),
)
rmq_connector_config = RabbitMQConnectorConfig(
host=rmq_connector_data.get("host"),
port=rmq_connector_data.get("port"),
username=rmq_connector_data.get("username"),
password=rmq_connector_data.get("password"),
connection_pool_max_size=rmq_connector_data.getint("connection_pool_max_size"),
channel_pool_max_size=rmq_connector_data.getint("channel_pool_max_size"),
default_exchange_name=rmq_connector_data.get("default_exchange_name"),
)

return Config(application_config, database_config, object_storage_config)
return Config(
application_config,
database_config,
object_storage_config,
rmq_connector_config,
)
Empty file.
126 changes: 126 additions & 0 deletions src/app/infrastructure/rmq_connector/gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import asyncio
from typing import Callable, Coroutine

from aio_pika import Channel, connect_robust, Connection, ExchangeType, Message
from aio_pika.abc import (
AbstractExchange,
AbstractIncomingMessage,
AbstractRobustConnection,
DeliveryMode,
)
from aio_pika.pool import Pool
import orjson

from .iterables import batched
from src.app.infrastructure.config.models import RabbitMQConnectorConfig
from src.app.application.rmq_connector.interfaces import RabbitMQConnectorGateway


class RabbitMQConnectorGatewayImpl(RabbitMQConnectorGateway):
BATCH_SIZE: int = 200

def __init__(self, settings: RabbitMQConnectorConfig):
self._settings = settings
self._connection_pool: Pool[Connection] = Pool(
self._get_connection,
max_size=settings.connection_pool_max_size,
)
self._channel_pool: Pool[Channel] = Pool(
self._get_channel,
max_size=settings.channel_pool_max_size,
)
self._default_exchange_name = settings.default_exchange_name

@property
def channel_pool(self) -> Pool[Channel]: # type: ignore
return self._channel_pool

async def publish(
self,
*message_bodies: dict,
routing_key: str,
exchange_name: str | None = None,
) -> None:
async with self._channel_pool.acquire() as channel:
exchange = await self._get_exchange(
channel,
exchange_name=exchange_name or self._default_exchange_name,
)

messages = (self._serialize_message(m) for m in message_bodies)
for messages_batch in batched(
messages,
RabbitMQConnectorGatewayImpl.BATCH_SIZE,
):
tasks = [
exchange.publish(message=m, routing_key=routing_key)
for m in messages_batch
]
await asyncio.gather(*tasks)

async def consume(
self,
queue_name: str,
callback: Callable[[AbstractIncomingMessage], Coroutine],
) -> None:
async with self._channel_pool.acquire() as channel:
queue = await channel.get_queue(queue_name, ensure=True)
await queue.consume(callback, exclusive=True) # type: ignore

async def create_exchange(
self,
name: str,
type_: ExchangeType = ExchangeType.TOPIC,
) -> None:
async with self._channel_pool.acquire() as channel:
await channel.declare_exchange(name, type_)

async def create_queue(self, name: str) -> None:
async with self._channel_pool.acquire() as channel:
await channel.declare_queue(name)

async def bid_queue(
self,
queue_name: str,
routing_key: str,
exchange_name: str,
) -> bool:
async with self._channel_pool.acquire() as channel:
queue = await channel.get_queue(queue_name)
exchange = await self._get_exchange(channel, exchange_name=exchange_name)
result = await queue.bind(exchange=exchange, routing_key=routing_key)
return result.name == "Queue.BindOk"

async def delete_exchange(self, name: str) -> None:
async with self._channel_pool.acquire() as channel:
await channel.exchange_delete(name)

async def delete_queue(self, name: str) -> None:
async with self._channel_pool.acquire() as channel:
await channel.queue_delete(name)

async def shutdown(self) -> None:
await self._channel_pool.close()
await self._connection_pool.close()

@staticmethod
def _serialize_message(
body: dict,
delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT,
) -> Message:
return Message(body=orjson.dumps(body), delivery_mode=delivery_mode)

async def _get_exchange(
self,
channel: Channel,
exchange_name: str | None,
) -> AbstractExchange:
exchange_name = exchange_name or self._default_exchange_name
return await channel.get_exchange(name=exchange_name, ensure=True)

async def _get_connection(self) -> AbstractRobustConnection:
return await connect_robust(self._settings.full_url())

async def _get_channel(self) -> Channel:
async with self._connection_pool.acquire() as connection:
return await connection.channel() # type: ignore
11 changes: 11 additions & 0 deletions src/app/infrastructure/rmq_connector/iterables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from itertools import islice
from typing import Iterable


def batched(iterable: Iterable, batch_size: int):
error_msg = "batch_size must be at least one"
if batch_size < 1:
raise ValueError(error_msg)
it = iter(iterable)
while batch := tuple(islice(it, batch_size)):
yield batch
Empty file.
58 changes: 58 additions & 0 deletions tests/integration/infrastructure/rmq_connector/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Generator, AsyncGenerator
from uuid import uuid4

from testcontainers.rabbitmq import RabbitMqContainer # type: ignore[import-untyped]

import pytest
import pytest_asyncio

from src.app.infrastructure.config.models import RabbitMQConnectorConfig
from src.app.infrastructure.rmq_connector.gateway import RabbitMQConnectorGatewayImpl


@pytest.fixture(scope="session")
def rmq_container() -> Generator[RabbitMqContainer, None, None]:
rabbitmq_container = RabbitMqContainer(
username="guest",
password="guest",
)
try:
rabbitmq_container.start()
yield rabbitmq_container
finally:
rabbitmq_container.stop()


@pytest_asyncio.fixture()
async def rmq_connector_gateway(
rmq_container: RabbitMqContainer,
) -> AsyncGenerator[RabbitMQConnectorGatewayImpl, None]:
config = RabbitMQConnectorConfig(
host=rmq_container.get_container_host_ip(),
port=rmq_container.get_exposed_port(5672),
)
gateway = RabbitMQConnectorGatewayImpl(config)
yield gateway
await gateway.shutdown()


@pytest.fixture()
async def queue(rmq_connector_gateway: RabbitMQConnectorGatewayImpl):
queue_name = str(uuid4())
await rmq_connector_gateway.create_queue(queue_name)

async with rmq_connector_gateway.channel_pool.acquire() as channel:
yield await channel.get_queue(name=queue_name, ensure=True)

await rmq_connector_gateway.delete_queue(queue_name)


@pytest.fixture()
async def exchange(rmq_connector_gateway: RabbitMQConnectorGatewayImpl):
exchange_name = str(uuid4())
await rmq_connector_gateway.create_exchange(exchange_name)

async with rmq_connector_gateway.channel_pool.acquire() as channel:
yield await channel.get_exchange(name=exchange_name, ensure=True)

await rmq_connector_gateway.delete_exchange(exchange_name)
Loading

0 comments on commit b840093

Please sign in to comment.