Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Receiver.close() method #348

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

## New Features

<!-- Here goes the main new features and examples or instructions on how to use them -->
- Added a `Receiver.close()` method for closing just a receiver. Also implemented it for all the `Receiver` implementations in this library.

## Bug Fixes

Expand Down
22 changes: 22 additions & 0 deletions src/frequenz/channels/_anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from collections import deque
from typing import Generic, TypeVar

from typing_extensions import override

from ._exceptions import ChannelClosedError
from ._generic import ChannelMessageT
from ._receiver import Receiver, ReceiverStoppedError
Expand Down Expand Up @@ -320,6 +322,7 @@ def __init__(self, channel: Anycast[_T], /) -> None:
self._channel: Anycast[_T] = channel
"""The channel that this sender belongs to."""

@override
async def send(self, message: _T, /) -> None:
"""Send a message across the channel.

Expand Down Expand Up @@ -388,8 +391,12 @@ def __init__(self, channel: Anycast[_T], /) -> None:
self._channel: Anycast[_T] = channel
"""The channel that this receiver belongs to."""

self._closed: bool = False
"""Whether the receiver is closed."""

self._next: _T | type[_Empty] = _Empty

@override
async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Expand All @@ -405,6 +412,9 @@ async def ready(self) -> bool:
if self._next is not _Empty:
return True

if self._closed:
return False

# pylint: disable=protected-access
while len(self._channel._deque) == 0:
if self._channel._closed:
Expand All @@ -417,6 +427,7 @@ async def ready(self) -> bool:
# pylint: enable=protected-access
return True

@override
def consume(self) -> _T:
"""Return the latest message once `ready()` is complete.

Expand All @@ -431,6 +442,9 @@ def consume(self) -> _T:
):
raise ReceiverStoppedError(self) from ChannelClosedError(self._channel)

if self._next is _Empty and self._closed:
raise ReceiverStoppedError(self)

assert (
self._next is not _Empty
), "`consume()` must be preceded by a call to `ready()`"
Expand All @@ -441,6 +455,14 @@ def consume(self) -> _T:

return next_val

@override
def close(self) -> None:
"""Close this receiver.

After closing, the receiver will not be able to receive any more messages.
"""
self._closed = True
llucax marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self) -> str:
"""Return a string representation of this receiver."""
return f"{self._channel}:{type(self).__name__}"
Expand Down
26 changes: 25 additions & 1 deletion src/frequenz/channels/_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from collections import deque
from typing import Generic, TypeVar

from typing_extensions import override

from ._exceptions import ChannelClosedError
from ._generic import ChannelMessageT
from ._receiver import Receiver, ReceiverStoppedError
Expand Down Expand Up @@ -327,6 +329,7 @@ def __init__(self, channel: Broadcast[_T], /) -> None:
self._channel: Broadcast[_T] = channel
"""The broadcast channel this sender belongs to."""

@override
async def send(self, message: _T, /) -> None:
"""Send a message to all broadcast receivers.

Expand Down Expand Up @@ -414,6 +417,9 @@ def __init__(
self._q: deque[_T] = deque(maxlen=limit)
"""The receiver's internal message queue."""

self._closed: bool = False
"""Whether the receiver is closed."""

def enqueue(self, message: _T, /) -> None:
"""Put a message into this receiver's queue.

Expand Down Expand Up @@ -441,6 +447,7 @@ def __len__(self) -> int:
"""
return len(self._q)

@override
async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Expand All @@ -462,13 +469,14 @@ async def ready(self) -> bool:
# consumed, then we return immediately.
# pylint: disable=protected-access
while len(self._q) == 0:
if self._channel._closed:
if self._channel._closed or self._closed:
return False
async with self._channel._recv_cv:
await self._channel._recv_cv.wait()
return True
# pylint: enable=protected-access

@override
def consume(self) -> _T:
"""Return the latest message once `ready` is complete.

Expand All @@ -481,9 +489,25 @@ def consume(self) -> _T:
if not self._q and self._channel._closed: # pylint: disable=protected-access
raise ReceiverStoppedError(self) from ChannelClosedError(self._channel)

if self._closed:
raise ReceiverStoppedError(self)

assert self._q, "`consume()` must be preceded by a call to `ready()`"
return self._q.popleft()

@override
def close(self) -> None:
"""Close the receiver.

After calling this method, new messages will not be received. Once the
receiver's buffer is drained, trying to receive a message will raise a
[`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError].
"""
self._closed = True
self._channel._receivers.pop( # pylint: disable=protected-access
hash(self), None
)

def __str__(self) -> str:
"""Return a string representation of this receiver."""
return f"{self._channel}:{type(self).__name__}"
Expand Down
19 changes: 19 additions & 0 deletions src/frequenz/channels/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
from collections import deque
from typing import Any

from typing_extensions import override

from ._generic import ReceiverMessageT_co
from ._receiver import Receiver, ReceiverStoppedError

Expand Down Expand Up @@ -135,6 +137,7 @@ async def stop(self) -> None:
await asyncio.gather(*self._pending, return_exceptions=True)
self._pending = set()

@override
async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Expand Down Expand Up @@ -171,6 +174,7 @@ async def ready(self) -> bool:
asyncio.create_task(anext(self._receivers[name]), name=name)
)

@override
def consume(self) -> ReceiverMessageT_co:
"""Return the latest message once `ready` is complete.

Expand All @@ -187,6 +191,21 @@ def consume(self) -> ReceiverMessageT_co:

return self._results.popleft()

@override
def close(self) -> None:
"""Close the receiver.

After calling this method, new messages will not be received. Once the
receiver's buffer is drained, trying to receive a message will raise a
[`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError].
"""
for task in self._pending:
if not task.done() and task.get_loop().is_running():
llucax marked this conversation as resolved.
Show resolved Hide resolved
task.cancel()
self._pending = set()
for recv in self._receivers.values():
recv.close()

def __str__(self) -> str:
"""Return a string representation of this receiver."""
if len(self._receivers) > 3:
Expand Down
35 changes: 35 additions & 0 deletions src/frequenz/channels/_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload

from typing_extensions import override

from ._exceptions import Error
from ._generic import MappedMessageT_co, ReceiverMessageT_co

Expand Down Expand Up @@ -215,6 +217,15 @@ def consume(self) -> ReceiverMessageT_co:
ReceiverError: If there is some problem with the receiver.
"""

def close(self) -> None:
"""Close the receiver.

After calling this method, new messages will not be available from the receiver.
Once the receiver's buffer is drained, trying to receive a message will raise a
[`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError].
"""
raise NotImplementedError("close() must be implemented by subclasses")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we don't forget:


def __aiter__(self) -> Self:
"""Get an async iterator over the received messages.

Expand Down Expand Up @@ -433,6 +444,7 @@ def __init__(
)
"""The function to apply on the input data."""

@override
async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Expand All @@ -448,6 +460,7 @@ async def ready(self) -> bool:

# We need a noqa here because the docs have a Raises section but the code doesn't
# explicitly raise anything.
@override
def consume(self) -> MappedMessageT_co: # noqa: DOC502
"""Return a transformed message once `ready()` is complete.

Expand All @@ -460,6 +473,16 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502
"""
return self._mapping_function(self._receiver.consume())

@override
def close(self) -> None:
"""Close the receiver.

After calling this method, new messages will not be received. Once the
receiver's buffer is drained, trying to receive a message will raise a
[`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError].
"""
self._receiver.close()

def __str__(self) -> str:
"""Return a string representation of the mapper."""
return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}"
Expand Down Expand Up @@ -509,6 +532,7 @@ def __init__(

self._recv_closed = False

@override
async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Expand All @@ -528,6 +552,7 @@ async def ready(self) -> bool:
self._recv_closed = True
return False

@override
def consume(self) -> ReceiverMessageT_co:
"""Return a transformed message once `ready()` is complete.

Expand All @@ -547,6 +572,16 @@ def consume(self) -> ReceiverMessageT_co:
self._next_message = _SENTINEL
return message

@override
def close(self) -> None:
"""Close the receiver.

After calling this method, new messages will not be received. Once the
receiver's buffer is drained, trying to receive a message will raise a
[`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError].
"""
self._receiver.close()

def __str__(self) -> str:
"""Return a string representation of the filter."""
return f"{type(self).__name__}:{self._receiver}:{self._filter_function}"
Expand Down
9 changes: 9 additions & 0 deletions src/frequenz/channels/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import asyncio as _asyncio

from typing_extensions import override

from frequenz.channels._receiver import Receiver, ReceiverStoppedError


Expand Down Expand Up @@ -141,6 +143,7 @@ def set(self) -> None:
self._is_set = True
self._event.set()

@override
async def ready(self) -> bool:
"""Wait until this receiver is ready.

Expand All @@ -152,6 +155,7 @@ async def ready(self) -> bool:
await self._event.wait()
return not self._is_stopped

@override
def consume(self) -> None:
"""Consume the event.

Expand All @@ -168,6 +172,11 @@ def consume(self) -> None:
self._is_set = False
self._event.clear()

@override
def close(self) -> None:
"""Close this receiver."""
self.stop()

def __str__(self) -> str:
"""Return a string representation of this event."""
return f"{type(self).__name__}({self._name!r})"
Expand Down
3 changes: 3 additions & 0 deletions src/frequenz/channels/experimental/_relay_sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import typing

from typing_extensions import override

from .._generic import SenderMessageT_contra
from .._sender import Sender

Expand Down Expand Up @@ -46,6 +48,7 @@ def __init__(self, *senders: Sender[SenderMessageT_contra]) -> None:
"""
self._senders = senders

@override
async def send(self, message: SenderMessageT_contra, /) -> None:
"""Send a message.

Expand Down
8 changes: 8 additions & 0 deletions src/frequenz/channels/file_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from datetime import timedelta
from enum import Enum

from typing_extensions import override
from watchfiles import Change, awatch
from watchfiles.main import FileChange

Expand Down Expand Up @@ -185,6 +186,7 @@ def __del__(self) -> None:
# is stopped.
self._stop_event.set()

@override
async def ready(self) -> bool:
"""Wait until the receiver is ready with a message or an error.

Expand Down Expand Up @@ -212,6 +214,7 @@ async def ready(self) -> bool:

return True

@override
def consume(self) -> Event:
"""Return the latest event once `ready` is complete.

Expand All @@ -229,6 +232,11 @@ def consume(self) -> Event:
change, path_str = self._changes.pop()
return Event(type=EventType(change), path=pathlib.Path(path_str))

@override
def close(self) -> None:
"""Close this receiver."""
self._stop_event.set()
llucax marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self) -> str:
"""Return a string representation of this receiver."""
if len(self._paths) > 3:
Expand Down
Loading
Loading