Skip to content

Commit

Permalink
Add Event.isolate()
Browse files Browse the repository at this point in the history
  • Loading branch information
TeamSpen210 committed Sep 14, 2023
1 parent 9fe5037 commit 905bd4a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 11 deletions.
52 changes: 45 additions & 7 deletions src/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
whenever they are modified.
"""
from __future__ import annotations
from typing import Generator, TypeVar, Any, Type, Generic, Callable, Awaitable
from typing_extensions import ParamSpec

from contextlib import contextmanager
from functools import partial
from typing import TypeVar, Any, Type, Generic, Callable, Awaitable
from typing_extensions import ParamSpec
import math

import attrs
import trio
import srctools.logger

__all__ = ['Event', 'ValueChange', 'ObsValue']
LOGGER = srctools.logger.get_logger(__name__)
# TODO: Swap to TypeVarTuple, no kwargs allowed.
ArgT = ParamSpec('ArgT')
ValueT = TypeVar('ValueT')
ValueT_co = TypeVar('ValueT_co', covariant=True)
Expand All @@ -33,6 +36,7 @@ class Event(Generic[ArgT]):
"""Store functions to be called when an event occurs."""
callbacks: list[Callable[ArgT, Awaitable[Any]]]
last_result: tuple[ArgT.args, ArgT.kwargs] | None = attrs.field(init=False)
_override: trio.MemorySendChannel[ArgT.args] | None = attrs.field(repr=False)
_cur_calls: int
name: str
log: bool = attrs.field(repr=False)
Expand All @@ -41,6 +45,7 @@ def __init__(self, name: str='') -> None:
self.name = name or f'<Unnamed {id(self):x}>'
self.callbacks = []
self._cur_calls = 0
self._override = None
self.log = False
self.last_result = None

Expand All @@ -67,8 +72,15 @@ async def __call__(self, /, *args: ArgT.args, **kwargs: ArgT.kwargs) -> None:
This is re-entrant - if called whilst the same event is already being
run, the second will be ignored.
"""
if kwargs:
raise TypeError("No kwargs allowed.")
if self.log:
LOGGER.debug('{}(*{}, **{}) = {}', self.name, args, kwargs, self.callbacks)
LOGGER.debug(
'{}({}) = {}',
self.name,
','.join(map(repr, args)),
self.callbacks,
)

if self._cur_calls and self.last_result is not None:
last_pos, last_kw = self.last_result
Expand All @@ -78,13 +90,16 @@ async def __call__(self, /, *args: ArgT.args, **kwargs: ArgT.kwargs) -> None:
self.last_result = (args, kwargs)
self._cur_calls += 1
try:
async with trio.open_nursery() as nursery:
for func in self.callbacks:
nursery.start_soon(partial(func, *args, **kwargs))
if self._override is not None:
await self._override.send(args)
else:
async with trio.open_nursery() as nursery:
for func in self.callbacks:
nursery.start_soon(partial(func, *args, **kwargs))
finally:
self._cur_calls -= 1

def unregister(self, func: Callable[ArgT, Awaitable[Any]],) -> None:
def unregister(self, func: Callable[ArgT, Awaitable[Any]]) -> None:
"""Remove the given callback.
If it is not registered, raise LookupError.
Expand All @@ -94,6 +109,29 @@ def unregister(self, func: Callable[ArgT, Awaitable[Any]],) -> None:
except ValueError:
raise LookupError(func) from None

@contextmanager
def isolate(self) -> Generator[trio.MemoryReceiveChannel[ArgT.args], None, None]:
"""Temporarily disable all listening callbacks, and redirect to the supplied channel.
This is mainly intended for testing code, to prevent it from affecting other things.
This cannot currently be nested within itself, but isolating different events is fine.
"""
send: trio.MemorySendChannel[ArgT.args]
rec: trio.MemoryReceiveChannel[ArgT.args]

if self._override is not None:
raise ValueError('Event.isolate() does not support nesting with itself!')
# Use an infinite buffer. If the user doesn't read from the channel, or only reads after
# the with statement has exited we want events to just be stored.
send, rec = trio.open_memory_channel(math.inf)
self._override = send
try:
yield rec
finally:
send.close()
assert self._override is send, self._override
self._override = None


@attrs.frozen
class ValueChange(Generic[ValueT_co]):
Expand Down
51 changes: 47 additions & 4 deletions src/test/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest
from unittest.mock import AsyncMock, create_autospec, call

import trio

from event import Event, ValueChange, ObsValue


Expand Down Expand Up @@ -94,15 +96,56 @@ async def test_register_priming() -> None:
func2.assert_awaited_once_with(10)


@no_type_check
async def test_isolate() -> None:
"""Test the isolation context manager."""
event = Event[int]('isolate')
func1 = create_autospec(func_unary, name='func1')
func2 = create_autospec(func_unary, name='func2')
event.register(func1)
await event(4)
func1.assert_awaited_once_with(4)

func1.reset_mock()
rec: trio.MemoryReceiveChannel
with event.isolate() as rec:
with pytest.raises(ValueError): # No nesting.
with event.isolate():
pass

await event(5)
func1.assert_not_awaited()
assert await rec.receive() == (5, )

await event.register_and_prime(func2)
func1.assert_not_awaited()
func2.assert_awaited_once_with(5) # Still passed through.
func2.reset_mock()

await event(48)
await event(36)
for i in range(1024): # Unlimited buffer.
await event(i)

func1.assert_not_awaited()
func2.assert_not_awaited()

assert await rec.receive() == (48, )
assert await rec.receive() == (36, )
for i in range(1024):
assert await rec.receive() == (i, )
# Finished here.
with pytest.raises(trio.EndOfChannel):
await rec.receive()


def test_valuechange() -> None:
"""Check ValueChange() produces the right values."""
with pytest.raises(TypeError):
ValueChange()
ValueChange() # type: ignore
with pytest.raises(TypeError):
ValueChange(1)
ValueChange(1) # type: ignore
with pytest.raises(TypeError):
ValueChange(1, 2, 3)
ValueChange(1, 2, 3) # type: ignore
assert ValueChange(1, 2) == ValueChange(1, 2)
assert ValueChange(old=2, new=3) == ValueChange(2, 3)

Expand Down

0 comments on commit 905bd4a

Please sign in to comment.