diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 96a0240b..ffa1c322 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,7 +10,7 @@ ## New Features - +- Added a `Receiver.close()` method for closing just a receiver. Also implemented it for all the `Receiver` implementations in this library. ## Bug Fixes diff --git a/src/frequenz/channels/_anycast.py b/src/frequenz/channels/_anycast.py index 1fadbef8..a3d2a846 100644 --- a/src/frequenz/channels/_anycast.py +++ b/src/frequenz/channels/_anycast.py @@ -10,6 +10,8 @@ from collections import deque from typing import Generic, TypeVar +from typing_extensions import override + from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError @@ -320,6 +322,7 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._channel: Anycast[_T] = channel """The channel that this sender belongs to.""" + @override async def send(self, message: _T, /) -> None: """Send a message across the channel. @@ -388,8 +391,12 @@ def __init__(self, channel: Anycast[_T], /) -> None: self._channel: Anycast[_T] = channel """The channel that this receiver belongs to.""" + self._closed: bool = False + """Whether the receiver is closed.""" + self._next: _T | type[_Empty] = _Empty + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -405,6 +412,9 @@ async def ready(self) -> bool: if self._next is not _Empty: return True + if self._closed: + return False + # pylint: disable=protected-access while len(self._channel._deque) == 0: if self._channel._closed: @@ -417,6 +427,7 @@ async def ready(self) -> bool: # pylint: enable=protected-access return True + @override def consume(self) -> _T: """Return the latest message once `ready()` is complete. @@ -431,6 +442,9 @@ def consume(self) -> _T: ): raise ReceiverStoppedError(self) from ChannelClosedError(self._channel) + if self._next is _Empty and self._closed: + raise ReceiverStoppedError(self) + assert ( self._next is not _Empty ), "`consume()` must be preceded by a call to `ready()`" @@ -441,6 +455,14 @@ def consume(self) -> _T: return next_val + @override + def close(self) -> None: + """Close this receiver. + + After closing, the receiver will not be able to receive any more messages. + """ + self._closed = True + def __str__(self) -> str: """Return a string representation of this receiver.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_broadcast.py b/src/frequenz/channels/_broadcast.py index c1017cd7..cd31a9f4 100644 --- a/src/frequenz/channels/_broadcast.py +++ b/src/frequenz/channels/_broadcast.py @@ -11,6 +11,8 @@ from collections import deque from typing import Generic, TypeVar +from typing_extensions import override + from ._exceptions import ChannelClosedError from ._generic import ChannelMessageT from ._receiver import Receiver, ReceiverStoppedError @@ -327,6 +329,7 @@ def __init__(self, channel: Broadcast[_T], /) -> None: self._channel: Broadcast[_T] = channel """The broadcast channel this sender belongs to.""" + @override async def send(self, message: _T, /) -> None: """Send a message to all broadcast receivers. @@ -414,6 +417,9 @@ def __init__( self._q: deque[_T] = deque(maxlen=limit) """The receiver's internal message queue.""" + self._closed: bool = False + """Whether the receiver is closed.""" + def enqueue(self, message: _T, /) -> None: """Put a message into this receiver's queue. @@ -441,6 +447,7 @@ def __len__(self) -> int: """ return len(self._q) + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -462,13 +469,14 @@ async def ready(self) -> bool: # consumed, then we return immediately. # pylint: disable=protected-access while len(self._q) == 0: - if self._channel._closed: + if self._channel._closed or self._closed: return False async with self._channel._recv_cv: await self._channel._recv_cv.wait() return True # pylint: enable=protected-access + @override def consume(self) -> _T: """Return the latest message once `ready` is complete. @@ -481,9 +489,25 @@ def consume(self) -> _T: if not self._q and self._channel._closed: # pylint: disable=protected-access raise ReceiverStoppedError(self) from ChannelClosedError(self._channel) + if self._closed: + raise ReceiverStoppedError(self) + assert self._q, "`consume()` must be preceded by a call to `ready()`" return self._q.popleft() + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + self._closed = True + self._channel._receivers.pop( # pylint: disable=protected-access + hash(self), None + ) + def __str__(self) -> str: """Return a string representation of this receiver.""" return f"{self._channel}:{type(self).__name__}" diff --git a/src/frequenz/channels/_merge.py b/src/frequenz/channels/_merge.py index 3a306dbb..b1f38857 100644 --- a/src/frequenz/channels/_merge.py +++ b/src/frequenz/channels/_merge.py @@ -54,6 +54,8 @@ from collections import deque from typing import Any +from typing_extensions import override + from ._generic import ReceiverMessageT_co from ._receiver import Receiver, ReceiverStoppedError @@ -135,6 +137,7 @@ async def stop(self) -> None: await asyncio.gather(*self._pending, return_exceptions=True) self._pending = set() + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -171,6 +174,7 @@ async def ready(self) -> bool: asyncio.create_task(anext(self._receivers[name]), name=name) ) + @override def consume(self) -> ReceiverMessageT_co: """Return the latest message once `ready` is complete. @@ -187,6 +191,21 @@ def consume(self) -> ReceiverMessageT_co: return self._results.popleft() + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + for task in self._pending: + if not task.done() and task.get_loop().is_running(): + task.cancel() + self._pending = set() + for recv in self._receivers.values(): + recv.close() + def __str__(self) -> str: """Return a string representation of this receiver.""" if len(self._receivers) > 3: diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index 53862a45..b7d5e306 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -157,6 +157,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard, TypeVar, overload +from typing_extensions import override + from ._exceptions import Error from ._generic import MappedMessageT_co, ReceiverMessageT_co @@ -215,6 +217,15 @@ def consume(self) -> ReceiverMessageT_co: ReceiverError: If there is some problem with the receiver. """ + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be available from the receiver. + Once the receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + raise NotImplementedError("close() must be implemented by subclasses") + def __aiter__(self) -> Self: """Get an async iterator over the received messages. @@ -433,6 +444,7 @@ def __init__( ) """The function to apply on the input data.""" + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -448,6 +460,7 @@ async def ready(self) -> bool: # We need a noqa here because the docs have a Raises section but the code doesn't # explicitly raise anything. + @override def consume(self) -> MappedMessageT_co: # noqa: DOC502 """Return a transformed message once `ready()` is complete. @@ -460,6 +473,16 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502 """ return self._mapping_function(self._receiver.consume()) + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + self._receiver.close() + def __str__(self) -> str: """Return a string representation of the mapper.""" return f"{type(self).__name__}:{self._receiver}:{self._mapping_function}" @@ -509,6 +532,7 @@ def __init__( self._recv_closed = False + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -528,6 +552,7 @@ async def ready(self) -> bool: self._recv_closed = True return False + @override def consume(self) -> ReceiverMessageT_co: """Return a transformed message once `ready()` is complete. @@ -547,6 +572,16 @@ def consume(self) -> ReceiverMessageT_co: self._next_message = _SENTINEL return message + @override + def close(self) -> None: + """Close the receiver. + + After calling this method, new messages will not be received. Once the + receiver's buffer is drained, trying to receive a message will raise a + [`ReceiverStoppedError`][frequenz.channels.ReceiverStoppedError]. + """ + self._receiver.close() + def __str__(self) -> str: """Return a string representation of the filter.""" return f"{type(self).__name__}:{self._receiver}:{self._filter_function}" diff --git a/src/frequenz/channels/event.py b/src/frequenz/channels/event.py index 0d599e89..5d1bd425 100644 --- a/src/frequenz/channels/event.py +++ b/src/frequenz/channels/event.py @@ -16,6 +16,8 @@ import asyncio as _asyncio +from typing_extensions import override + from frequenz.channels._receiver import Receiver, ReceiverStoppedError @@ -141,6 +143,7 @@ def set(self) -> None: self._is_set = True self._event.set() + @override async def ready(self) -> bool: """Wait until this receiver is ready. @@ -152,6 +155,7 @@ async def ready(self) -> bool: await self._event.wait() return not self._is_stopped + @override def consume(self) -> None: """Consume the event. @@ -168,6 +172,11 @@ def consume(self) -> None: self._is_set = False self._event.clear() + @override + def close(self) -> None: + """Close this receiver.""" + self.stop() + def __str__(self) -> str: """Return a string representation of this event.""" return f"{type(self).__name__}({self._name!r})" diff --git a/src/frequenz/channels/experimental/_relay_sender.py b/src/frequenz/channels/experimental/_relay_sender.py index 3d78b474..398ba8d5 100644 --- a/src/frequenz/channels/experimental/_relay_sender.py +++ b/src/frequenz/channels/experimental/_relay_sender.py @@ -9,6 +9,8 @@ import typing +from typing_extensions import override + from .._generic import SenderMessageT_contra from .._sender import Sender @@ -46,6 +48,7 @@ def __init__(self, *senders: Sender[SenderMessageT_contra]) -> None: """ self._senders = senders + @override async def send(self, message: SenderMessageT_contra, /) -> None: """Send a message. diff --git a/src/frequenz/channels/file_watcher.py b/src/frequenz/channels/file_watcher.py index a4dc3074..e9ff4ca4 100644 --- a/src/frequenz/channels/file_watcher.py +++ b/src/frequenz/channels/file_watcher.py @@ -25,6 +25,7 @@ from datetime import timedelta from enum import Enum +from typing_extensions import override from watchfiles import Change, awatch from watchfiles.main import FileChange @@ -185,6 +186,7 @@ def __del__(self) -> None: # is stopped. self._stop_event.set() + @override async def ready(self) -> bool: """Wait until the receiver is ready with a message or an error. @@ -212,6 +214,7 @@ async def ready(self) -> bool: return True + @override def consume(self) -> Event: """Return the latest event once `ready` is complete. @@ -229,6 +232,11 @@ def consume(self) -> Event: change, path_str = self._changes.pop() return Event(type=EventType(change), path=pathlib.Path(path_str)) + @override + def close(self) -> None: + """Close this receiver.""" + self._stop_event.set() + def __str__(self) -> str: """Return a string representation of this receiver.""" if len(self._paths) > 3: diff --git a/src/frequenz/channels/timer.py b/src/frequenz/channels/timer.py index 2785feea..998430e9 100644 --- a/src/frequenz/channels/timer.py +++ b/src/frequenz/channels/timer.py @@ -102,6 +102,8 @@ async def main() -> None: import asyncio from datetime import timedelta +from typing_extensions import override + from ._receiver import Receiver, ReceiverStoppedError @@ -644,6 +646,7 @@ def stop(self) -> None: # We need a noqa here because the docs have a Raises section but the documented # exceptions are raised indirectly. + @override async def ready(self) -> bool: # noqa: DOC502 """Wait until the timer `interval` passed. @@ -715,6 +718,7 @@ async def ready(self) -> bool: # noqa: DOC502 return True + @override def consume(self) -> timedelta: """Return the latest drift once `ready()` is complete. @@ -741,6 +745,11 @@ def consume(self) -> timedelta: self._current_drift = None return drift + @override + def close(self) -> None: + """Close the timer.""" + self.stop() + def _now(self) -> int: """Return the current monotonic clock time in microseconds. diff --git a/tests/test_anycast.py b/tests/test_anycast.py index c6db0d9a..918c548c 100644 --- a/tests/test_anycast.py +++ b/tests/test_anycast.py @@ -217,3 +217,30 @@ async def test_anycast_filter() -> None: assert (await receiver.receive()) == 12 assert (await receiver.receive()) == 15 + + +async def test_anycast_close_receiver() -> None: + """Ensure closing a receiver stops the receiver.""" + chan = Anycast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + + await sender.send(1) + + assert (await receiver_1.receive()) == 1 + + receiver_1.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index c8a2e9cf..f995a922 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -232,6 +232,42 @@ async def test_broadcast_map() -> None: assert (await receiver.receive()) is True +async def test_broadcast_map_close_receiver() -> None: + """Ensure closing a map stops the receiver.""" + chan = Broadcast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + plus_100_rx = receiver_1.map(lambda num: num + 100) + + await sender.send(1) + + assert (await plus_100_rx.receive()) == 101 + assert (await receiver_2.receive()) == 1 + + plus_100_rx.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await plus_100_rx.receive() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + await sender.send(3) + + assert (await receiver_2.receive()) == 3 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() + + async def test_broadcast_filter() -> None: """Ensure filter keeps only the messages that pass the filter.""" chan = Broadcast[int](name="input-chan") @@ -249,6 +285,43 @@ async def test_broadcast_filter() -> None: assert (await receiver.receive()) == 15 +async def test_broadcast_filter_close_receiver() -> None: + """Ensure closing a filter stops the receiver.""" + chan = Broadcast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + + gt_10_rx = receiver_1.filter(lambda num: num > 10) + + await sender.send(1) + assert (await receiver_2.receive()) == 1 + + await sender.send(100) + assert (await gt_10_rx.receive()) == 100 + assert (await receiver_2.receive()) == 100 + + gt_10_rx.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await gt_10_rx.receive() + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + await sender.send(3) + assert (await receiver_2.receive()) == 3 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() + + async def test_broadcast_filter_type_guard() -> None: """Ensure filter type guard works.""" chan = Broadcast[int | str](name="input-chan") @@ -320,3 +393,35 @@ class Narrower(Actual): await sender.send(Narrower(10)) assert (await receiver.receive()).value == 10 + + +async def test_broadcast_close_receiver() -> None: + """Ensure closing a receiver stops the receiver.""" + chan = Broadcast[int](name="input-chan") + sender = chan.new_sender() + + receiver_1 = chan.new_receiver() + receiver_2 = chan.new_receiver() + + await sender.send(1) + + assert (await receiver_1.receive()) == 1 + assert (await receiver_2.receive()) == 1 + + receiver_1.close() + + await sender.send(2) + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_1.receive() + + assert (await receiver_2.receive()) == 2 + + await sender.send(3) + + assert (await receiver_2.receive()) == 3 + + receiver_2.close() + + with pytest.raises(ReceiverStoppedError): + _ = await receiver_2.receive() diff --git a/tests/test_event.py b/tests/test_event.py index 950720d0..c9e061c5 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -57,3 +57,51 @@ async def wait_for_event() -> None: assert not event.is_set await event_task + + +async def test_event_close_receiver() -> None: + """Ensure that closing an event stops the receiver.""" + event = Event() + assert not event.is_set + assert not event.is_stopped + + is_ready = False + + async def wait_for_event() -> None: + nonlocal is_ready + await event.ready() + is_ready = True + + event_task = _asyncio.create_task(wait_for_event()) + + await _asyncio.sleep(0) # Yield so the wait_for_event task can run. + + assert not is_ready + assert not event.is_set + assert not event.is_stopped + + event.set() + + await _asyncio.sleep(0) # Yield so the wait_for_event task can run. + assert is_ready + assert event.is_set + assert not event.is_stopped + + event.consume() + assert not event.is_set + assert not event.is_stopped + assert event_task.done() + assert event_task.result() is None + assert not event_task.cancelled() + + event.close() + assert not event.is_set + assert event.is_stopped + + await event.ready() + with _pytest.raises(ReceiverStoppedError): + event.consume() + assert event.is_stopped + assert not event.is_set + + await event_task diff --git a/tests/test_file_watcher_integration.py b/tests/test_file_watcher_integration.py index ef1846e6..7d727aa2 100644 --- a/tests/test_file_watcher_integration.py +++ b/tests/test_file_watcher_integration.py @@ -150,3 +150,40 @@ async def test_file_watcher_exit_iterator(tmp_path: pathlib.Path) -> None: file_watcher.consume() assert number_of_writes == expected_number_of_writes + + +@pytest.mark.integration +async def test_file_watcher_close_receiver(tmp_path: pathlib.Path) -> None: + """Ensure closing the file watcher stops the receiver. + + Args: + tmp_path: A tmp directory to run the file watcher on. Created by pytest. + """ + filename = tmp_path / "test-file" + + number_of_writes = 0 + expected_number_of_writes = 3 + + file_watcher = FileWatcher( + paths=[str(tmp_path)], + force_polling=True, + polling_interval=timedelta(seconds=0.05), + ) + timer = Timer(timedelta(seconds=0.1), SkipMissedAndDrift()) + + async for selected in select(file_watcher, timer): + if selected_from(selected, timer): + filename.write_text(f"{selected.message}") + elif selected_from(selected, file_watcher): + number_of_writes += 1 + if number_of_writes == expected_number_of_writes: + file_watcher.close() + break + + ready = await file_watcher.ready() + assert ready is False + + with pytest.raises(ReceiverStoppedError): + file_watcher.consume() + + assert number_of_writes == expected_number_of_writes diff --git a/tests/test_merge_integration.py b/tests/test_merge_integration.py index 2e3eefcb..f5a61970 100644 --- a/tests/test_merge_integration.py +++ b/tests/test_merge_integration.py @@ -8,6 +8,8 @@ import pytest from frequenz.channels import Anycast, Sender, merge +from frequenz.channels._broadcast import Broadcast +from frequenz.channels._receiver import ReceiverStoppedError @pytest.mark.integration @@ -39,3 +41,44 @@ async def send(ch1: Sender[int], ch2: Sender[int]) -> None: # succession. assert set(results[idx : idx + 2]) == {ctr + 1, ctr + 101} assert results[-1] == 1000 + + +async def test_merge_close_receiver() -> None: + """Ensure merge() closes when a receiver is closed.""" + chan1 = Broadcast[int](name="chan1") + chan2 = Broadcast[int](name="chan2") + + async def send(ch1: Sender[int], ch2: Sender[int]) -> None: + for ctr in range(5): + await ch1.send(ctr + 1) + await ch2.send(ctr + 101) + await chan1.close() + await chan2.close() + + rx1 = chan1.new_receiver() + rx2 = chan2.new_receiver() + closing_merge = merge(rx1, rx2) + prx1 = chan1.new_receiver() + prx2 = chan2.new_receiver() + completing_merge = merge(prx1, prx2) + + senders = asyncio.create_task(send(chan1.new_sender(), chan2.new_sender())) + + results: list[int] = [] + async for item in closing_merge: + results.append(item) + if item == 3: + closing_merge.close() + await senders + assert set(results) == {1, 101, 2, 102, 3, 103} + + with pytest.raises(ReceiverStoppedError): + _ = await rx1.receive() + + with pytest.raises(ReceiverStoppedError): + _ = await rx2.receive() + + comp_results: set[int] = set() + async for item in completing_merge: + comp_results.add(item) + assert comp_results == {1, 101, 2, 102, 3, 103, 4, 104, 5, 105} diff --git a/tests/test_timer.py b/tests/test_timer.py index 1e1dc7ed..55295b29 100644 --- a/tests/test_timer.py +++ b/tests/test_timer.py @@ -13,6 +13,7 @@ import pytest from hypothesis import strategies as st +from frequenz.channels import ReceiverStoppedError from frequenz.channels.timer import ( SkipMissedAndDrift, SkipMissedAndResync, @@ -331,6 +332,21 @@ async def test_timer_construction_wrong_args() -> None: ) +async def test_timer_close_receiver() -> None: + """Test the autostart of a periodic timer.""" + event_loop = asyncio.get_running_loop() + + timer = Timer(timedelta(seconds=1.0), TriggerAllMissed()) + + drift = await timer.receive() + assert drift == pytest.approx(timedelta(seconds=0.0)) + assert event_loop.time() == pytest.approx(1.0) + + timer.close() + with pytest.raises(ReceiverStoppedError): + await timer.receive() + + async def test_timer_autostart() -> None: """Test the autostart of a periodic timer.""" event_loop = asyncio.get_running_loop()