Skip to content

Commit

Permalink
feat: support coroutine functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Dec 20, 2024
1 parent 6db5219 commit 4fb4ef8
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 19 deletions.
79 changes: 69 additions & 10 deletions src/psygnal/_weak_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import sys
import weakref
from functools import partial
Expand All @@ -19,6 +20,8 @@
from ._mypyc import mypyc_attr

if TYPE_CHECKING:
from collections.abc import Coroutine

import toolz
from typing_extensions import TypeAlias, TypeGuard # py310

Expand All @@ -28,6 +31,9 @@
_T = TypeVar("_T")
_R = TypeVar("_R") # return type of cb

# reference to all background tasks created by Coroutine WeakCallbacks
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def _is_toolz_curry(obj: Any) -> TypeGuard[toolz.curry]:
"""Return True if obj is a toolz.curry object."""
Expand Down Expand Up @@ -124,14 +130,17 @@ def _on_delete(weak_cb):
kwargs = cb.keywords
cb = cb.func

is_coro = asyncio.iscoroutinefunction(cb)

if isinstance(cb, FunctionType):
return (
StrongFunction(cb, max_args, args, kwargs, priority=priority)
if strong_func
else WeakFunction(
if strong_func:
cls = StrongCoroutineFunction if is_coro else StrongFunction
return cls(cb, max_args, args, kwargs, priority=priority)
else:
wcls = WeakCoroutineFunction if is_coro else WeakFunction
return wcls(
cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority
)
)

if isinstance(cb, MethodType):
if getattr(cb, "__name__", None) == "__setitem__":
Expand All @@ -145,7 +154,8 @@ def _on_delete(weak_cb):
return WeakSetitem(
obj, key, max_args, finalize, on_ref_error, priority=priority
)
return WeakMethod(
mcls = WeakCoroutineMethod if is_coro else WeakMethod
return mcls(
cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority
)

Expand Down Expand Up @@ -225,7 +235,7 @@ def __init__(

self.priority: int = priority

def cb(self, args: tuple[Any, ...] = ()) -> None:
def cb(self, args: tuple[Any, ...] = ()) -> Any:
"""Call the callback with `args`. Args will be spread when calling the func."""
raise NotImplementedError()

Expand Down Expand Up @@ -334,6 +344,8 @@ def _cb(_: weakref.ReferenceType) -> None:
class StrongFunction(WeakCallback):
"""Wrapper around a strong function reference."""

_f: Callable

def __init__(
self,
obj: Callable,
Expand All @@ -351,7 +363,7 @@ def __init__(
if args:
self._object_repr = f"{self._object_repr}{(*args,)!r}".replace(")", " ...)")

def cb(self, args: tuple[Any, ...] = ()) -> None:
def cb(self, args: tuple[Any, ...] = ()) -> Any:
if self._max_args is not None:
args = args[: self._max_args]
self._f(*self._args, *args, **self._kwargs)
Expand All @@ -370,6 +382,21 @@ def __setstate__(self, state: dict) -> None:
setattr(self, k, v)


class StrongCoroutineFunction(StrongFunction):
"""Wrapper around a strong coroutine function reference."""

_f: Callable[..., Coroutine]

def cb(self, args: tuple[Any, ...] = ()) -> Coroutine:
if self._max_args is not None:
args = args[: self._max_args]
coroutine = self._f(*self._args, *args, **self._kwargs)
task = asyncio.create_task(coroutine)
_BACKGROUND_TASKS.add(task)
task.add_done_callback(_BACKGROUND_TASKS.discard)
return coroutine


class WeakFunction(WeakCallback):
"""Wrapper around a weak function reference."""

Expand All @@ -391,7 +418,7 @@ def __init__(
if args:
self._object_repr = f"{self._object_repr}{(*args,)!r}".replace(")", " ...)")

def cb(self, args: tuple[Any, ...] = ()) -> None:
def cb(self, args: tuple[Any, ...] = ()) -> Any:
f = self._f()
if f is None:
raise ReferenceError("weakly-referenced object no longer exists")
Expand All @@ -408,6 +435,21 @@ def dereference(self) -> Callable | None:
return f


class WeakCoroutineFunction(WeakFunction):
def cb(self, args: tuple[Any, ...] = ()) -> Coroutine:
f = self._f()
if f is None:
raise ReferenceError("weakly-referenced object no longer exists")
if self._max_args is not None:
args = args[: self._max_args]
coroutine = f(*self._args, *args, **self._kwargs)

task = asyncio.create_task(coroutine)
_BACKGROUND_TASKS.add(task)
task.add_done_callback(_BACKGROUND_TASKS.discard)
return coroutine


class WeakMethod(WeakCallback):
"""Wrapper around a method bound to a weakly-referenced object.
Expand Down Expand Up @@ -442,7 +484,7 @@ def slot_repr(self) -> str:
func_name = getattr(self._func_ref(), "__name__", "<method>")
return f"{self._obj_module}.{obj.__class__.__qualname__}.{func_name}"

def cb(self, args: tuple[Any, ...] = ()) -> None:
def cb(self, args: tuple[Any, ...] = ()) -> Any:
obj = self._obj_ref()
func = self._func_ref()
if obj is None or func is None:
Expand All @@ -463,6 +505,23 @@ def dereference(self) -> MethodType | partial | None:
return method


class WeakCoroutineMethod(WeakMethod):
def cb(self, args: tuple[Any, ...] = ()) -> Coroutine:
obj = self._obj_ref()
func = self._func_ref()
if obj is None or func is None:
raise ReferenceError("weakly-referenced object no longer exists")

if self._max_args is not None:
args = args[: self._max_args]
coroutine = func(obj, *self._args, *args, **self._kwargs)

task = asyncio.create_task(coroutine)
_BACKGROUND_TASKS.add(task)
task.add_done_callback(_BACKGROUND_TASKS.discard)
return coroutine


class WeakBuiltin(WeakCallback):
"""Wrapper around a c-based method on a weakly-referenced object.
Expand Down
62 changes: 53 additions & 9 deletions tests/test_weak_callable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import gc
import re
from functools import partial
Expand All @@ -15,10 +16,13 @@
"type_",
[
"function",
"coroutinefunc",
"toolz_function",
"weak_func",
"weak_coroutinefunc",
"lambda",
"method",
"coroutinemethod",
"partial_method",
"toolz_method",
"setattr",
Expand All @@ -33,37 +37,50 @@ def test_slot_types(type_: str, capsys: Any) -> None:
final_mock = Mock()

class MyObj:
def method(self, x: int) -> None:
def method(self, x: int) -> int:
mock(x)
return x

async def coroutine_method(self, x: int) -> int:
mock(x)
return x

def __setitem__(self, key, value):
mock(value)
return value

def __setattr__(self, __name: str, __value) -> None:
def __setattr__(self, __name: str, __value: Any) -> Any:
if __name == "x":
mock(__value)
return __value

obj = MyObj()
obj: Any = MyObj()

if type_ == "setattr":
cb = weak_callback(setattr, obj, "x", finalize=final_mock)
elif type_ == "setitem":
cb = weak_callback(obj.__setitem__, "x", finalize=final_mock)
elif type_ in {"function", "weak_func"}:

def obj(x: int) -> None:
def obj(x: int) -> int:
mock(x)
return x

cb = weak_callback(obj, strong_func=(type_ == "function"), finalize=final_mock)
elif type_ in {"coroutinefunc", "weak_coroutinefunc"}:

async def obj(x: int) -> int:
mock(x)
return x

cb = weak_callback(
obj, strong_func=(type_ == "coroutinefunc"), finalize=final_mock
)
elif type_ == "toolz_function":
toolz = pytest.importorskip("toolz")

@toolz.curry
def obj(z: int, x: int) -> None:
def obj(z: int, x: int) -> int:
mock(x)
return x

Expand All @@ -72,6 +89,8 @@ def obj(z: int, x: int) -> None:
cb = weak_callback(lambda x: mock(x) and x, finalize=final_mock)
elif type_ == "method":
cb = weak_callback(obj.method, finalize=final_mock)
elif type_ == "coroutinemethod":
cb = weak_callback(obj.coroutine_method, finalize=final_mock)
elif type_ == "partial_method":
cb = weak_callback(partial(obj.method, 2), max_args=0, finalize=final_mock)
elif type_ == "toolz_method":
Expand All @@ -87,30 +106,55 @@ def obj(z: int, x: int) -> None:

assert isinstance(cb, WeakCallback)
assert isinstance(cb.slot_repr(), str)
cb.cb((2,))

if "coroutine" in type_:

async def main() -> None:
await cb.cb((2,))

asyncio.run(main())
else:
cb.cb((2,))
assert cb.dereference() is not None
if type_ == "print":
assert capsys.readouterr().out == "2\n"
return

mock.assert_called_once_with(2)
mock.reset_mock()
result = cb(2)

if "coroutine" in type_:
result: Any = None

async def main() -> None:
nonlocal result
result = await cb(2)

asyncio.run(main())
else:
result = cb(2)
if type_ not in ("setattr", "mock"):
assert result == 2
mock.assert_called_once_with(2)

del obj

if type_ not in ("function", "toolz_function", "lambda", "mock"):
if type_ not in ("function", "coroutinefunc", "toolz_function", "lambda", "mock"):
final_mock.assert_called_once_with(cb)
assert cb.dereference() is None
with pytest.raises(ReferenceError):
cb.cb((2,))
with pytest.raises(ReferenceError):
cb(2)
else:
cb.cb((4,))
if "coroutine" in type_:

async def main() -> None:
await cb.cb((4,))

asyncio.run(main())
else:
cb.cb((4,))
mock.assert_called_with(4)


Expand Down

0 comments on commit 4fb4ef8

Please sign in to comment.