Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add collect_fields option to SignalGroupDescriptor, and accept a SignalGroup subclass #291

Merged
merged 11 commits into from
Mar 11, 2024
26 changes: 11 additions & 15 deletions src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,52 @@
from __future__ import annotations

from typing import (
Any,
Callable,
Dict,
Literal,
Optional,
Type,
TypeVar,
Union,
overload,
)

from psygnal._group_descriptor import SignalGroupDescriptor

__all__ = ["evented"]

T = TypeVar("T", bound=Type)
T = TypeVar("T", bound=type)

EqOperator = Callable[[Any, Any], bool]
PSYGNAL_GROUP_NAME = "_psygnal_group_"
_NULL = object()


@overload
def evented(
cls: T,
*,
events_namespace: str = "events",
equality_operators: Optional[Dict[str, EqOperator]] = None,
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = ...,
cache_on_instance: bool = ...,
) -> T: ...


@overload
def evented(
cls: "Optional[Literal[None]]" = None,
cls: Literal[None] | None = None,
*,
events_namespace: str = "events",
equality_operators: Optional[Dict[str, EqOperator]] = None,
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = ...,
cache_on_instance: bool = ...,
) -> Callable[[T], T]: ...


def evented(
cls: Optional[T] = None,
cls: T | None = None,
*,
events_namespace: str = "events",
equality_operators: Optional[Dict[str, EqOperator]] = None,
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = True,
cache_on_instance: bool = True,
) -> Union[Callable[[T], T], T]:
) -> Callable[[T], T] | T:
"""A decorator to add events to a dataclass.

See also the documentation for
Expand All @@ -71,7 +67,7 @@ def evented(
The class to decorate.
events_namespace : str
The name of the namespace to add the events to, by default `"events"`
equality_operators : Optional[Dict[str, Callable]]
equality_operators : dict[str, Callable] | None
A dictionary mapping field names to equality operators (a function that takes
two values and returns `True` if they are equal). These will be used to
determine if a field has changed when setting a new value. By default, this
Expand Down Expand Up @@ -122,7 +118,7 @@ def _decorate(cls: T) -> T:
if any(k.startswith("_psygnal") for k in getattr(cls, "__annotations__", {})):
raise TypeError("Fields on an evented class cannot start with '_psygnal'")

descriptor = SignalGroupDescriptor(
descriptor: SignalGroupDescriptor = SignalGroupDescriptor(
equality_operators=equality_operators,
warn_on_no_fields=warn_on_no_fields,
cache_on_instance=cache_on_instance,
Expand Down
85 changes: 70 additions & 15 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

import contextlib
import copy
import operator
import sys
import warnings
import weakref
from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -34,6 +34,7 @@
T = TypeVar("T", bound=Type)
S = TypeVar("S")


EqOperator = Callable[[Any, Any], bool]
_EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {}
_EQ_OPERATOR_NAME = "__eq_operators__"
Expand Down Expand Up @@ -141,11 +142,25 @@ def connect_setattr(
)


@lru_cache(maxsize=None)
def _build_dataclass_signal_group(
cls: type, equality_operators: Iterable[tuple[str, EqOperator]] | None = None
cls: type,
signal_group_class: type[SignalGroup],
equality_operators: Iterable[tuple[str, EqOperator]] | None = None,
) -> type[SignalGroup]:
"""Build a SignalGroup with events for each field in a dataclass."""
"""Build a SignalGroup with events for each field in a dataclass.

Parameters
----------
cls : type
the dataclass to look for the fields to connect with signals.
signal_group_class: type[SignalGroup]
SignalGroup or a subclass of it, to use as a super class.
Default to SignalGroup
equality_operators: Iterable[tuple[str, EqOperator]] | None
If defined, a mapping of field name and equality operator to use to compare if
each field was modified after being set.
Default to None
"""
_equality_operators = dict(equality_operators) if equality_operators else {}
signals = {}
eq_map = _get_eq_operator_map(cls)
Expand All @@ -162,7 +177,9 @@ def _build_dataclass_signal_group(
# patch in our custom SignalInstance class with maxargs=1 on connect_setattr
sig._signal_instance_class = _DataclassFieldSignalInstance

return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals)
# Create `signal_group_class` subclass with the attached signals
group_name = f"{cls.__name__}{signal_group_class.__name__}"
return type(group_name, (signal_group_class,), signals)


def is_evented(obj: object) -> bool:
Expand Down Expand Up @@ -335,8 +352,6 @@ def __setattr__(self, name: str, value: Any) -> None:
field to determine whether to emit an event. If not provided, the default
equality operator is `operator.eq`, except for numpy arrays, where
`np.array_equal` is used.
signal_group_class : type[SignalGroup], optional
A custom SignalGroup class to use, by default None
warn_on_no_fields : bool, optional
If `True` (the default), a warning will be emitted if no mutable dataclass-like
fields are found on the object.
Expand All @@ -352,6 +367,13 @@ def __setattr__(self, name: str, value: Any) -> None:
events when fields change. If `False`, no `__setattr__` method will be
created. (This will prevent signal emission, and assumes you are using a
different mechanism to emit signals when fields change.)
signal_group_class : type[SignalGroup] | None, optional
A custom SignalGroup class to use, SignalGroup if None, by default None
collect_fields : bool, optional
Create a signal for each field in the dataclass. If True, the `SignalGroup`
instance will be a subclass of `signal_group_class` (SignalGroup if it is None).
If False, a deepcopy of `signal_group_class` will be used.
Default to True

Examples
--------
Expand All @@ -374,22 +396,42 @@ class Person:
```
"""

# map of id(obj) -> SignalGroup
# cached here in case the object isn't modifiable
_instance_map: ClassVar[dict[int, SignalGroup]] = {}

def __init__(
self,
*,
equality_operators: dict[str, EqOperator] | None = None,
signal_group_class: type[SignalGroup] | None = None,
warn_on_no_fields: bool = True,
cache_on_instance: bool = True,
patch_setattr: bool = True,
signal_group_class: type[SignalGroup] | None = None,
collect_fields: bool = True,
):
self._signal_group = signal_group_class
grp_cls = signal_group_class or SignalGroup
if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)):
raise TypeError( # pragma: no cover
f"'signal_group_class' must be a subclass of SignalGroup, "
f"not {grp_cls}"
)
if grp_cls is SignalGroup and collect_fields is False:
raise ValueError(
"Cannot use SignalGroup with collect_fields=False. "
"Use a custom SignalGroup subclass instead."
)

self._name: str | None = None
self._eqop = tuple(equality_operators.items()) if equality_operators else None
self._warn_on_no_fields = warn_on_no_fields
self._cache_on_instance = cache_on_instance
self._patch_setattr = patch_setattr

self._signal_group_class: type[SignalGroup] = grp_cls
self._collect_fields = collect_fields
self._signal_groups: dict[int, type[SignalGroup]] = {}

def __set_name__(self, owner: type, name: str) -> None:
"""Called when this descriptor is added to class `owner` as attribute `name`."""
self._name = name
Expand Down Expand Up @@ -417,10 +459,6 @@ def _do_patch_setattr(self, owner: type) -> None:
"emitted when fields change."
) from e

# map of id(obj) -> SignalGroup
# cached here in case the object isn't modifiable
_instance_map: ClassVar[dict[int, SignalGroup]] = {}

@overload
def __get__(self, instance: None, owner: type) -> SignalGroupDescriptor: ...

Expand All @@ -434,13 +472,15 @@ def __get__(
if instance is None:
return self

signal_group = self._get_signal_group(owner)

# if we haven't yet instantiated a SignalGroup for this instance,
# do it now and cache it. Note that we cache it here in addition to
# the instance (in case the instance is not modifiable).
obj_id = id(instance)
if obj_id not in self._instance_map:
# cache it
self._instance_map[obj_id] = self._create_group(owner)(instance)
self._instance_map[obj_id] = signal_group(instance)
# also *try* to set it on the instance as well, since it will skip all the
# __get__ logic in the future, but if it fails, no big deal.
if self._name and self._cache_on_instance:
Expand All @@ -453,13 +493,28 @@ def __get__(

return self._instance_map[obj_id]

def _get_signal_group(self, owner: type) -> type[SignalGroup]:
type_id = id(owner)
if type_id not in self._signal_groups:
self._signal_groups[type_id] = self._create_group(owner)
return self._signal_groups[type_id]

def _create_group(self, owner: type) -> type[SignalGroup]:
Group = self._signal_group or _build_dataclass_signal_group(owner, self._eqop)
# Do not collect fields from owner class, copy the SignalGroup
if not self._collect_fields:
Group = copy.deepcopy(self._signal_group_class)

# Collect fields and create SignalGroup subclass
else:
Group = _build_dataclass_signal_group(
owner, self._signal_group_class, equality_operators=self._eqop
)
if self._warn_on_no_fields and not Group._psygnal_signals:
warnings.warn(
f"No mutable fields found on class {owner}: no events will be "
"emitted. (Is this a dataclass, attrs, msgspec, or pydantic model?)",
stacklevel=2,
)

self._do_patch_setattr(owner)
return Group
62 changes: 60 additions & 2 deletions tests/test_group_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, ClassVar
from typing import Any, ClassVar, Optional, Type
from unittest.mock import Mock, patch

import pytest

from psygnal import SignalGroupDescriptor, _compiled, _group_descriptor
from psygnal import (
Signal,
SignalGroup,
SignalGroupDescriptor,
_compiled,
_group_descriptor,
)


class MyGroup(SignalGroup):
sig = Signal()


@pytest.mark.parametrize("type_", ["dataclass", "pydantic", "attrs", "msgspec"])
Expand Down Expand Up @@ -213,3 +224,50 @@ class Bar:
# when using connect_setattr with maxargs=None
# remove this test if/when we change maxargs to default to 1 on SignalInstance
assert bar.y == (2, 1) # type: ignore


@pytest.mark.parametrize("collect", [True, False])
@pytest.mark.parametrize("klass", [None, SignalGroup, MyGroup])
def test_collect_fields(collect: bool, klass: Optional[Type[SignalGroup]]) -> None:
signal_class = klass or SignalGroup
should_fail_def = signal_class is SignalGroup and collect is False
ctx = pytest.raises(ValueError) if should_fail_def else nullcontext()

with ctx:

@dataclass
class Foo:
events: ClassVar = SignalGroupDescriptor(
warn_on_no_fields=False,
signal_group_class=klass,
collect_fields=collect,
)
a: int = 1

if should_fail_def:
return

@dataclass
class Bar(Foo):
b: float = 2.0

foo = Foo()
bar = Bar()

assert issubclass(type(foo.events), signal_class)

if collect:
assert type(foo.events) is not signal_class
assert "a" in foo.events
assert "a" in bar.events
assert "b" in bar.events

else:
assert type(foo.events) == signal_class
assert "a" not in foo.events
assert "a" not in bar.events
assert "b" not in bar.events

if signal_class is MyGroup:
assert "sig" in foo.events
assert "sig" in bar.events
Loading