Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

weaksets in robust channel #529

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,6 @@ class DeclarationResult:
class AbstractTransaction:
state: TransactionState

@property
@abstractmethod
def channel(self) -> "AbstractChannel":
raise NotImplementedError

@abstractmethod
async def select(
self, timeout: TimeoutType = None,
Expand Down Expand Up @@ -514,9 +509,8 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@property
@abstractmethod
def channel(self) -> aiormq.abc.AbstractChannel:
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError

@property
Expand Down
49 changes: 32 additions & 17 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import warnings
from abc import ABC
from types import TracebackType
from typing import Any, AsyncContextManager, Generator, Optional, Type, Union
Expand Down Expand Up @@ -78,7 +79,7 @@ def __init__(
# That's means user closed channel instance explicitly
self._closed: bool = False

self._channel = None
self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number

self.close_callbacks = CallbackCollection(self)
Expand Down Expand Up @@ -119,8 +120,7 @@ async def close(
self._closed = True
await self._channel.close()

@property
def channel(self) -> aiormq.abc.AbstractChannel:
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

if not self.is_initialized or not self._channel:
raise aiormq.exceptions.ChannelInvalidStateError(
Expand All @@ -134,13 +134,20 @@ def channel(self) -> aiormq.abc.AbstractChannel:

return self._channel.channel

@property
def channel(self) -> Optional[aiormq.abc.AbstractChannel]:
warnings.warn("This property is deprecated, do not use this anymore.")
if self._channel is None:
raise aiormq.exceptions.ChannelInvalidStateError
return self._channel.channel

@property
def number(self) -> Optional[int]:
return (
self.channel.number
if self.is_initialized
else self._channel_number
)
if self._channel is None:
return self._channel_number

underlay_channel: UnderlayChannel = self._channel
return underlay_channel.channel.number

def __str__(self) -> str:
return "{}".format(self.number or "Not initialized channel")
Expand Down Expand Up @@ -192,7 +199,8 @@ async def _on_close(self, closing: asyncio.Future) -> None:
self._channel.channel.on_return_callbacks.discard(self._on_return)

async def _on_initialized(self) -> None:
self.channel.on_return_callbacks.add(self._on_return)
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)

def _on_return(self, message: aiormq.abc.DeliveredMessage) -> None:
self.return_callbacks(IncomingMessage(message, no_ack=True))
Expand Down Expand Up @@ -240,8 +248,10 @@ async def declare_exchange(
if auto_delete and durable is None:
durable = False

channel = await self.get_underlay_channel()

exchange = self.EXCHANGE_CLASS(
channel=self.channel,
channel=channel,
name=name,
type=type,
durable=durable,
Expand Down Expand Up @@ -281,7 +291,7 @@ async def get_exchange(
return await self.declare_exchange(name=name, passive=True)
else:
return self.EXCHANGE_CLASS(
channel=self.channel,
channel=await self.get_underlay_channel(),
name=name,
durable=False,
auto_delete=False,
Expand Down Expand Up @@ -321,7 +331,7 @@ async def declare_queue(
"""

queue: AbstractQueue = self.QUEUE_CLASS(
channel=self.channel,
channel=await self.get_underlay_channel(),
name=name,
durable=durable,
exclusive=exclusive,
Expand Down Expand Up @@ -358,7 +368,7 @@ async def get_queue(
return await self.declare_queue(name=name, passive=True)
else:
return self.QUEUE_CLASS(
channel=self.channel,
channel=await self.get_underlay_channel(),
name=name,
durable=False,
exclusive=False,
Expand All @@ -379,7 +389,9 @@ async def set_qos(
warn('Use "global_" instead of "all_channels"', DeprecationWarning)
global_ = all_channels

return await self.channel.basic_qos(
channel = await self.get_underlay_channel()

return await channel.basic_qos(
prefetch_count=prefetch_count,
prefetch_size=prefetch_size,
global_=global_,
Expand All @@ -394,7 +406,8 @@ async def queue_delete(
if_empty: bool = False,
nowait: bool = False,
) -> aiormq.spec.Queue.DeleteOk:
return await self.channel.queue_delete(
channel = await self.get_underlay_channel()
return await channel.queue_delete(
queue=queue_name,
if_unused=if_unused,
if_empty=if_empty,
Expand All @@ -409,7 +422,8 @@ async def exchange_delete(
if_unused: bool = False,
nowait: bool = False,
) -> aiormq.spec.Exchange.DeleteOk:
return await self.channel.exchange_delete(
channel = await self.get_underlay_channel()
return await channel.exchange_delete(
exchange=exchange_name,
if_unused=if_unused,
nowait=nowait,
Expand All @@ -426,7 +440,8 @@ def transaction(self) -> Transaction:
return Transaction(self)

async def flow(self, active: bool = True) -> aiormq.spec.Channel.FlowOk:
return await self.channel.flow(active=active)
channel = await self.get_underlay_channel()
return await channel.flow(active=active)


__all__ = ("Channel",)
2 changes: 1 addition & 1 deletion aio_pika/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ async def ack(self, multiple: bool = False) -> None:

async def reject(self, requeue: bool = False) -> None:
""" When `requeue=True` the message will be returned to queue.
Otherwise message will be dropped.
Otherwise, message will be dropped.

.. note::
This method looks like a blocking-method, but actually it just
Expand Down
15 changes: 10 additions & 5 deletions aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
from collections import defaultdict
from itertools import chain
from typing import Any, DefaultDict, Dict, Optional, Set, Type, Union
from typing import Any, DefaultDict, Dict, MutableSet, Optional, Type, Union
from warnings import warn
from weakref import WeakSet

import aiormq

Expand All @@ -28,8 +29,8 @@ class RobustChannel(Channel, AbstractRobustChannel): # type: ignore
QUEUE_CLASS: Type[Queue] = RobustQueue
EXCHANGE_CLASS: Type[Exchange] = RobustExchange

_exchanges: DefaultDict[str, Set[AbstractRobustExchange]]
_queues: DefaultDict[str, Set[RobustQueue]]
_exchanges: DefaultDict[str, MutableSet[AbstractRobustExchange]]
_queues: DefaultDict[str, MutableSet[RobustQueue]]
default_exchange: RobustExchange

def __init__(
Expand Down Expand Up @@ -57,8 +58,8 @@ def __init__(
on_return_raises=on_return_raises,
)

self._exchanges = defaultdict(set)
self._queues = defaultdict(set)
self._exchanges = defaultdict(WeakSet)
self._queues = defaultdict(WeakSet)
self._prefetch_count: int = 0
self._prefetch_size: int = 0
self._global_qos: bool = False
Expand All @@ -73,6 +74,10 @@ async def __close_callback(self, *_: Any) -> None:

await self.reopen()

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
await self._connection.ready()
return await super().get_underlay_channel()

async def restore(self, connection: aiormq.abc.AbstractConnection) -> None:
async with self.__restore_lock:
self._connection = connection
Expand Down
9 changes: 6 additions & 3 deletions aio_pika/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,25 @@ def channel(self) -> AbstractChannel:
async def select(
self, timeout: TimeoutType = None,
) -> aiormq.spec.Tx.SelectOk:
result = await self.channel.channel.tx_select(timeout=timeout)
channel = await self.channel.get_underlay_channel()
result = await channel.tx_select(timeout=timeout)

self.state = TransactionState.STARTED
return result

async def rollback(
self, timeout: TimeoutType = None,
) -> commands.Tx.RollbackOk:
result = await self.channel.channel.tx_rollback(timeout=timeout)
channel = await self.channel.get_underlay_channel()
result = await channel.tx_rollback(timeout=timeout)
self.state = TransactionState.ROLLED_BACK
return result

async def commit(
self, timeout: TimeoutType = None,
) -> commands.Tx.CommitOk:
result = await self.channel.channel.tx_commit(timeout=timeout)
channel = await self.channel.get_underlay_channel()
result = await channel.tx_commit(timeout=timeout)
self.state = TransactionState.COMMITED
return result

Expand Down
31 changes: 24 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ typing_extensions = [{ version = '*', python = "< 3.8" }]
setuptools = [{ version = '*', python = "< 3.8" }]

[tool.poetry.group.dev.dependencies]
aiomisc = "^16.2"
collective-checkdocs = "^0.2"
coverage = "^6.5.0"
coveralls = "^3.3.1"
Expand All @@ -59,6 +58,7 @@ sphinx = "^5.3.0"
sphinx-autobuild = "^2021.3.14"
timeout-decorator = "^0.5.0"
types-setuptools = "^65.6.0.2"
aiomisc-pytest = "^1.1.1"

[tool.poetry.group.uvloop.dependencies]
uvloop = "^0.17.0"
Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


@pytest.fixture
async def add_cleanup(loop):
async def add_cleanup(event_loop):
entities = []

def payload(func, *args, **kwargs):
Expand All @@ -34,12 +34,12 @@ def payload(func, *args, **kwargs):


@pytest.fixture
async def create_task(loop):
async def create_task(event_loop):
tasks = []

def payload(coroutine):
nonlocal tasks
task = loop.create_task(coroutine)
task = event_loop.create_task(coroutine)
tasks.append(task)
return task

Expand Down Expand Up @@ -91,8 +91,8 @@ def connection_fabric(request):


@pytest.fixture
def create_connection(connection_fabric, loop, amqp_url):
return partial(connection_fabric, amqp_url, loop=loop)
def create_connection(connection_fabric, event_loop, amqp_url):
return partial(connection_fabric, amqp_url, loop=event_loop)


@pytest.fixture
Expand Down
Loading