Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RabbitMQConnector #7

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions config/local.dist.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ echo = true
access_key = decision_making_app_admin
secret_key = decision_making_app_admin
bucket_name = decision_making_app

[message_queue]
host = decision_making_app.rabbitmq
port = 5672
username = decision_making_app_admin
password = decision_making_app_admin
connection_pool_max_size = 3
channel_pool_max_size = 15
default_exchange_name = decision_making_app
1,842 changes: 936 additions & 906 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@ aioboto3 = "12.1.0"
types-aioboto3-lite = { extras = ["essential"], version = "^12.1.0" }
types-aiobotocore-lite = { extras = ["s3"], version = "^2.9.0" }


[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.4.2"
pytest-asyncio = "^0.21.1"
pytest-order = "^1.2.0"
testcontainers = { extras = ["postgresql"], version = "^3.7.1" }
testcontainers-minio = "^0.0.1rc1" # for now its a separate package, in future add as an extra
testcontainers = {extras = ["postgresql", "rabbitmq"], version = "^3.7.1"}
testcontainers-minio = "^0.0.1rc1" # for now, it's a separate package


[tool.poetry.group.lint]
Expand Down
29 changes: 28 additions & 1 deletion src/app/application/common/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import io

from typing import Protocol
from typing import Callable, Protocol


class ObjectStorageGateway(Protocol):
Expand All @@ -28,3 +28,30 @@ async def get_file_presigned_url(self, name: str, expires_in: int) -> str:
@abc.abstractmethod
async def get_file_object(self, name: str) -> io.BytesIO:
raise NotImplementedError


class MessageQueueGateway(Protocol):
@abc.abstractmethod
async def publish(
self,
*message_bodies: dict,
routing_key: str,
exchange_name: str | None = None,
) -> None:
raise NotImplementedError

@abc.abstractmethod
async def consume(
self,
queue_name: str,
callback: Callable,
) -> None:
raise NotImplementedError

@abc.abstractmethod
async def create_queue(self, name: str) -> None:
raise NotImplementedError

@abc.abstractmethod
async def delete_queue(self, name: str) -> None:
raise NotImplementedError
2 changes: 2 additions & 0 deletions src/app/infrastructure/config/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from .database import DatabaseConfig
from .main import Config
from .object_storage import ObjectStorageConfig
from .message_queue import MessageQueueConfig


__all__ = [
"AppConfig",
"Config",
"DatabaseConfig",
"ObjectStorageConfig",
"MessageQueueConfig",
]
2 changes: 2 additions & 0 deletions src/app/infrastructure/config/models/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .application import AppConfig
from .database import DatabaseConfig
from .object_storage import ObjectStorageConfig
from .message_queue import MessageQueueConfig


@dataclass
class Config:
app_config: AppConfig
db_config: DatabaseConfig
storage_config: ObjectStorageConfig
message_queue_config: MessageQueueConfig
24 changes: 24 additions & 0 deletions src/app/infrastructure/config/models/message_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from dataclasses import dataclass


@dataclass
class MessageQueueConfig:
host: str
port: str
username: str = "guest"
password: str = "guest"
connector_prefix: str = "amqp"

# options that are not usable in the full url
connection_pool_max_size: int = 3
channel_pool_max_size: int = 15
batch_size: int = 200
default_exchange_name: str = "decision_making_app"

@property
def full_url(self) -> str:
return "{}://{}:{}@{}:{}".format(
self.connector_prefix,
self.username, self.password,
self.host, self.port,
)
18 changes: 17 additions & 1 deletion src/app/infrastructure/config/parsers/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Config,
DatabaseConfig,
ObjectStorageConfig,
MessageQueueConfig,
)


Expand All @@ -22,6 +23,7 @@ def load_config(path: str | None = None) -> Config:
application_data = parser["application"]
database_data = parser["database"]
object_storage_data = parser["object_storage"]
message_queue_data = parser["message_queue"]

application_config = AppConfig(
host=application_data.get("host"),
Expand All @@ -41,5 +43,19 @@ def load_config(path: str | None = None) -> Config:
secret_key=object_storage_data.get("secret_key"),
bucket_name=object_storage_data.get("bucket_name"),
)
message_queue_config = MessageQueueConfig(
host=message_queue_data.get("host"),
port=message_queue_data.get("port"),
username=message_queue_data.get("username"),
password=message_queue_data.get("password"),
connection_pool_max_size=message_queue_data.getint("connection_pool_max_size"),
channel_pool_max_size=message_queue_data.getint("channel_pool_max_size"),
default_exchange_name=message_queue_data.get("default_exchange_name"),
)

return Config(application_config, database_config, object_storage_config)
return Config(
application_config,
database_config,
object_storage_config,
message_queue_config,
)
Empty file.
133 changes: 133 additions & 0 deletions src/app/infrastructure/message_queue/gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import asyncio
from typing import Callable, Coroutine

import orjson
from aio_pika import (
Channel,
Connection,
ExchangeType,
Message,
connect_robust,
)
from aio_pika.abc import (
AbstractExchange,
AbstractIncomingMessage,
AbstractRobustConnection,
DeliveryMode,
)
from aio_pika.pool import Pool

from src.app.infrastructure.config.models import MessageQueueConfig
from src.app.application.common.interfaces import MessageQueueGateway

from .iterables import batched


class MessageQueueGatewayImpl(MessageQueueGateway):
def __init__(
self,
settings: MessageQueueConfig,
batch_size: int | None = None,
):
self._settings = settings
self._batch_size = batch_size if batch_size is not None else settings.batch_size

self.connection_pool: Pool[Connection] = Pool(
self._get_connection,
max_size=settings.connection_pool_max_size,
)
self.channel_pool: Pool[Channel] = Pool(
self._get_channel,
max_size=settings.channel_pool_max_size,
)
self.default_exchange_name = settings.default_exchange_name

async def publish(
self,
*message_bodies: dict,
routing_key: str,
exchange_name: str | None = None,
) -> None:
async with self.channel_pool.acquire() as channel:
exchange = await self._get_exchange(
channel,
exchange_name=exchange_name or self.default_exchange_name,
)

messages = (self._serialize_message(m) for m in message_bodies)
for messages_batch in batched(
messages,
self._batch_size,
):
tasks = [
exchange.publish(message=m, routing_key=routing_key)
for m in messages_batch
]
await asyncio.gather(*tasks)

async def consume(
self,
queue_name: str,
callback: Callable[[AbstractIncomingMessage], Coroutine],
) -> None:
async with self.channel_pool.acquire() as channel:
queue = await channel.get_queue(queue_name, ensure=True)
await queue.consume(callback, exclusive=True) # type: ignore

async def create_exchange(
self,
name: str,
type_: ExchangeType = ExchangeType.TOPIC,
) -> None:
async with self.channel_pool.acquire() as channel:
await channel.declare_exchange(name, type_)

async def create_queue(self, name: str) -> None:
async with self.channel_pool.acquire() as channel:
await channel.declare_queue(name)

async def bid_queue(
self,
queue_name: str,
routing_key: str,
exchange_name: str,
) -> bool:
async with self.channel_pool.acquire() as channel:
queue = await channel.get_queue(queue_name)
exchange = await self._get_exchange(channel, exchange_name=exchange_name)
result = await queue.bind(exchange=exchange, routing_key=routing_key)
return result.name == "Queue.BindOk"

async def delete_exchange(self, name: str) -> None:
async with self.channel_pool.acquire() as channel:
await channel.exchange_delete(name)

async def delete_queue(self, name: str) -> None:
async with self.channel_pool.acquire() as channel:
await channel.queue_delete(name)

async def shutdown(self) -> None:
await self.channel_pool.close()
await self.connection_pool.close()

@staticmethod
def _serialize_message(
body: dict,
delivery_mode: DeliveryMode = DeliveryMode.PERSISTENT,
) -> Message:
return Message(body=orjson.dumps(body), delivery_mode=delivery_mode)

async def _get_exchange(
self,
channel: Channel,
exchange_name: str | None,
) -> AbstractExchange:
exchange_name = exchange_name or self.default_exchange_name
return await channel.get_exchange(name=exchange_name, ensure=True)

async def _get_connection(self) -> AbstractRobustConnection:
return await connect_robust(self._settings.full_url)

async def _get_channel(self) -> Channel:
async with self.connection_pool.acquire() as connection:
return await connection.channel() # type: ignore
11 changes: 11 additions & 0 deletions src/app/infrastructure/message_queue/iterables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from itertools import islice
from typing import Iterable


def batched(iterable: Iterable, batch_size: int):
error_msg: str = "batch_size must be at least one"
if batch_size < 1:
raise ValueError(error_msg)
it = iter(iterable)
while batch := tuple(islice(it, batch_size)):
yield batch
Empty file.
57 changes: 57 additions & 0 deletions tests/integration/infrastructure/message_queue/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Generator, AsyncGenerator
from uuid import uuid4

from testcontainers.rabbitmq import RabbitMqContainer # type: ignore[import-untyped]

import pytest
import pytest_asyncio

from src.app.infrastructure.config.models import MessageQueueConfig
from src.app.infrastructure.message_queue.gateway import MessageQueueGatewayImpl


@pytest.fixture(scope="session")
def rmq_container() -> Generator[RabbitMqContainer, None, None]:
rabbitmq_container = RabbitMqContainer(
username="guest",
password="guest",
)
try:
rabbitmq_container.start()
yield rabbitmq_container
finally:
rabbitmq_container.stop()


@pytest_asyncio.fixture()
async def message_queue_gateway(
rmq_container: RabbitMqContainer,
) -> AsyncGenerator[MessageQueueGatewayImpl, None]:
config = MessageQueueConfig(
host=rmq_container.get_container_host_ip(),
port=rmq_container.get_exposed_port(5672),
)
gateway = MessageQueueGatewayImpl(config)
yield gateway
await gateway.shutdown()

@pytest.fixture()
async def queue(message_queue_gateway: MessageQueueGatewayImpl):
queue_name = str(uuid4())
await message_queue_gateway.create_queue(queue_name)

async with message_queue_gateway.channel_pool.acquire() as channel:
yield await channel.get_queue(name=queue_name, ensure=True)

await message_queue_gateway.delete_queue(queue_name)


@pytest.fixture()
async def exchange(message_queue_gateway: MessageQueueGatewayImpl):
exchange_name = str(uuid4())
await message_queue_gateway.create_exchange(exchange_name)

async with message_queue_gateway.channel_pool.acquire() as channel:
yield await channel.get_exchange(name=exchange_name, ensure=True)

await message_queue_gateway.delete_exchange(exchange_name)
Loading
Loading