From 132644564778a018439bf2333846ce6fbf40d923 Mon Sep 17 00:00:00 2001 From: getzze Date: Wed, 6 Mar 2024 15:35:14 +0000 Subject: [PATCH 1/9] add collect_fields --- src/psygnal/_evented_decorator.py | 24 +++++------ src/psygnal/_group_descriptor.py | 67 ++++++++++++++++++++++++++----- tests/test_group_descriptor.py | 62 +++++++++++++++++++++++++++- 3 files changed, 127 insertions(+), 26 deletions(-) diff --git a/src/psygnal/_evented_decorator.py b/src/psygnal/_evented_decorator.py index 236015c8..ac2c9728 100644 --- a/src/psygnal/_evented_decorator.py +++ b/src/psygnal/_evented_decorator.py @@ -1,12 +1,10 @@ +from __future__ import annotations + from typing import ( Any, Callable, - Dict, Literal, - Optional, - Type, TypeVar, - Union, overload, ) @@ -14,11 +12,9 @@ __all__ = ["evented"] -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=type) EqOperator = Callable[[Any, Any], bool] -PSYGNAL_GROUP_NAME = "_psygnal_group_" -_NULL = object() @overload @@ -26,7 +22,7 @@ def evented( cls: T, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = ..., cache_on_instance: bool = ..., ) -> T: ... @@ -34,23 +30,23 @@ def evented( @overload def evented( - cls: "Optional[Literal[None]]" = None, + cls: Literal[None] | None = None, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = ..., cache_on_instance: bool = ..., ) -> Callable[[T], T]: ... def evented( - cls: Optional[T] = None, + cls: T | None = None, *, events_namespace: str = "events", - equality_operators: Optional[Dict[str, EqOperator]] = None, + equality_operators: dict[str, EqOperator] | None = None, warn_on_no_fields: bool = True, cache_on_instance: bool = True, -) -> Union[Callable[[T], T], T]: +) -> Callable[[T], T] | T: """A decorator to add events to a dataclass. See also the documentation for @@ -71,7 +67,7 @@ def evented( The class to decorate. events_namespace : str The name of the namespace to add the events to, by default `"events"` - equality_operators : Optional[Dict[str, Callable]] + equality_operators : dict[str, Callable] | None A dictionary mapping field names to equality operators (a function that takes two values and returns `True` if they are equal). These will be used to determine if a field has changed when setting a new value. By default, this diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index 5a836de5..35990892 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -1,11 +1,11 @@ from __future__ import annotations import contextlib +import copy import operator import sys import warnings import weakref -from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -141,11 +141,25 @@ def connect_setattr( ) -@lru_cache(maxsize=None) def _build_dataclass_signal_group( - cls: type, equality_operators: Iterable[tuple[str, EqOperator]] | None = None + cls: type, + signal_group_class: type[SignalGroup] = SignalGroup, + equality_operators: Iterable[tuple[str, EqOperator]] | None = None, ) -> type[SignalGroup]: - """Build a SignalGroup with events for each field in a dataclass.""" + """Build a SignalGroup with events for each field in a dataclass. + + Parameters + ---------- + cls : type + the dataclass to look for the fields to connect with signals. + signal_group_class: type[SignalGroup] + SignalGroup or a subclass of it, to use as a super class. + Default to SignalGroup + equality_operators: Iterable[tuple[str, EqOperator]] | None + If defined, a mapping of field name and equality operator to use to compare if + each field was modified after being set. + Default to None + """ _equality_operators = dict(equality_operators) if equality_operators else {} signals = {} eq_map = _get_eq_operator_map(cls) @@ -162,7 +176,9 @@ def _build_dataclass_signal_group( # patch in our custom SignalInstance class with maxargs=1 on connect_setattr sig._signal_instance_class = _DataclassFieldSignalInstance - return type(f"{cls.__name__}SignalGroup", (SignalGroup,), signals) + # Create `signal_group_class` subclass with the attached signals + group_name = f"{cls.__name__}{signal_group_class.__name__}" + return type(group_name, (signal_group_class,), signals) def is_evented(obj: object) -> bool: @@ -335,8 +351,6 @@ def __setattr__(self, name: str, value: Any) -> None: field to determine whether to emit an event. If not provided, the default equality operator is `operator.eq`, except for numpy arrays, where `np.array_equal` is used. - signal_group_class : type[SignalGroup], optional - A custom SignalGroup class to use, by default None warn_on_no_fields : bool, optional If `True` (the default), a warning will be emitted if no mutable dataclass-like fields are found on the object. @@ -352,6 +366,13 @@ def __setattr__(self, name: str, value: Any) -> None: events when fields change. If `False`, no `__setattr__` method will be created. (This will prevent signal emission, and assumes you are using a different mechanism to emit signals when fields change.) + signal_group_class : type[SignalGroup] | None, optional + A custom SignalGroup class to use, SignalGroup if None, by default None + collect_fields : bool, optional + Create a signal for each field in the dataclass. If True, the `SignalGroup` + instance will be a subclass of `signal_group_class` (SignalGroup if it is None). + If False, a deepcopy of `signal_group_class` will be used. + Default to True Examples -------- @@ -378,17 +399,21 @@ def __init__( self, *, equality_operators: dict[str, EqOperator] | None = None, - signal_group_class: type[SignalGroup] | None = None, warn_on_no_fields: bool = True, cache_on_instance: bool = True, patch_setattr: bool = True, + signal_group_class: type[SignalGroup] | None = None, + collect_fields: bool = True, ): - self._signal_group = signal_group_class self._name: str | None = None self._eqop = tuple(equality_operators.items()) if equality_operators else None self._warn_on_no_fields = warn_on_no_fields self._cache_on_instance = cache_on_instance self._patch_setattr = patch_setattr + self._signal_group_class = signal_group_class or SignalGroup + self._collect_fields = collect_fields + + self._signal_groups: dict[int, type[SignalGroup]] = {} def __set_name__(self, owner: type, name: str) -> None: """Called when this descriptor is added to class `owner` as attribute `name`.""" @@ -434,13 +459,15 @@ def __get__( if instance is None: return self + signal_group = self._get_signal_group(owner) + # if we haven't yet instantiated a SignalGroup for this instance, # do it now and cache it. Note that we cache it here in addition to # the instance (in case the instance is not modifiable). obj_id = id(instance) if obj_id not in self._instance_map: # cache it - self._instance_map[obj_id] = self._create_group(owner)(instance) + self._instance_map[obj_id] = signal_group(instance) # also *try* to set it on the instance as well, since it will skip all the # __get__ logic in the future, but if it fails, no big deal. if self._name and self._cache_on_instance: @@ -453,13 +480,31 @@ def __get__( return self._instance_map[obj_id] + def _get_signal_group(self, owner: type) -> type[SignalGroup]: + type_id = id(owner) + if type_id not in self._signal_groups: + self._signal_groups[type_id] = self._create_group(owner) + return self._signal_groups[type_id] + def _create_group(self, owner: type) -> type[SignalGroup]: - Group = self._signal_group or _build_dataclass_signal_group(owner, self._eqop) + # Do not collect fields from owner class, copy the SignalGroup + if not self._collect_fields: + Group = copy.deepcopy(self._signal_group_class) + + # Collect fields and create SignalGroup subclass + else: + Group = _build_dataclass_signal_group( + owner, + signal_group_class=self._signal_group_class, + equality_operators=self._eqop, + ) + if self._warn_on_no_fields and not Group._psygnal_signals: warnings.warn( f"No mutable fields found on class {owner}: no events will be " "emitted. (Is this a dataclass, attrs, msgspec, or pydantic model?)", stacklevel=2, ) + self._do_patch_setattr(owner) return Group diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index 1b36f032..bc0b5b96 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -4,7 +4,17 @@ import pytest -from psygnal import SignalGroupDescriptor, _compiled, _group_descriptor +from psygnal import ( + Signal, + SignalGroup, + SignalGroupDescriptor, + _compiled, + _group_descriptor, +) + + +class MyGroup(SignalGroup): + sig = Signal() @pytest.mark.parametrize("type_", ["dataclass", "pydantic", "attrs", "msgspec"]) @@ -213,3 +223,53 @@ class Bar: # when using connect_setattr with maxargs=None # remove this test if/when we change maxargs to default to 1 on SignalInstance assert bar.y == (2, 1) # type: ignore + + +@pytest.mark.parametrize("collect", [True, False]) +@pytest.mark.parametrize("klass", [None, SignalGroup, MyGroup]) +def test_collect_fields(collect, klass) -> None: + signal_class = klass or SignalGroup + + @dataclass + class Foo: + events: ClassVar = SignalGroupDescriptor( + warn_on_no_fields=False, + signal_group_class=klass, + collect_fields=collect, + ) + a: int = 1 + + @dataclass + class Bar(Foo): + b: float = 2.0 + + foo = Foo() + bar = Bar() + + signal_class = klass or SignalGroup + + # Cannot instantiate SignalGroup directly, use a subclass + if not collect and signal_class is SignalGroup: + with pytest.raises(TypeError): + _ = foo.events + with pytest.raises(TypeError): + _ = bar.events + return + + assert issubclass(type(foo.events), signal_class) + + if collect: + assert type(foo.events) is not signal_class + assert "a" in foo.events + assert "a" in bar.events + assert "b" in bar.events + + else: + assert type(foo.events) == signal_class + assert "a" not in foo.events + assert "a" not in bar.events + assert "b" not in bar.events + + if signal_class is MyGroup: + assert "sig" in foo.events + assert "sig" in bar.events From fcee5584343d3ad72c25a808d416ef1d3485cd08 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 11 Mar 2024 09:00:47 -0400 Subject: [PATCH 2/9] make generic, misc suggestions --- src/psygnal/_evented_decorator.py | 2 +- src/psygnal/_group_descriptor.py | 52 +++++++++++++++++++------------ tests/test_group_descriptor.py | 28 ++++++++++------- 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/src/psygnal/_evented_decorator.py b/src/psygnal/_evented_decorator.py index ac2c9728..c638eadd 100644 --- a/src/psygnal/_evented_decorator.py +++ b/src/psygnal/_evented_decorator.py @@ -118,7 +118,7 @@ def _decorate(cls: T) -> T: if any(k.startswith("_psygnal") for k in getattr(cls, "__annotations__", {})): raise TypeError("Fields on an evented class cannot start with '_psygnal'") - descriptor = SignalGroupDescriptor( + descriptor: SignalGroupDescriptor = SignalGroupDescriptor( equality_operators=equality_operators, warn_on_no_fields=warn_on_no_fields, cache_on_instance=cache_on_instance, diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index 35990892..bd3ea615 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -11,6 +11,7 @@ Any, Callable, ClassVar, + Generic, Iterable, Literal, Type, @@ -33,6 +34,8 @@ T = TypeVar("T", bound=Type) S = TypeVar("S") +GroupType = TypeVar("GroupType", bound=SignalGroup) + EqOperator = Callable[[Any, Any], bool] _EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {} @@ -143,9 +146,9 @@ def connect_setattr( def _build_dataclass_signal_group( cls: type, - signal_group_class: type[SignalGroup] = SignalGroup, + signal_group_class: type[GroupType], equality_operators: Iterable[tuple[str, EqOperator]] | None = None, -) -> type[SignalGroup]: +) -> type[GroupType]: """Build a SignalGroup with events for each field in a dataclass. Parameters @@ -302,7 +305,7 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None: return _inner(super_setattr) if super_setattr else _inner -class SignalGroupDescriptor: +class SignalGroupDescriptor(Generic[GroupType]): """Create a [`psygnal.SignalGroup`][] on first instance attribute access. This descriptor is designed to be used as a class attribute on a dataclass-like @@ -395,6 +398,10 @@ class Person: ``` """ + # map of id(obj) -> SignalGroup + # cached here in case the object isn't modifiable + _instance_map: ClassVar[dict[int, SignalGroup]] = {} + def __init__( self, *, @@ -402,18 +409,30 @@ def __init__( warn_on_no_fields: bool = True, cache_on_instance: bool = True, patch_setattr: bool = True, - signal_group_class: type[SignalGroup] | None = None, + signal_group_class: type[GroupType] | None = None, collect_fields: bool = True, ): + grp_cls = signal_group_class or cast(type[GroupType], SignalGroup) + if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)): + raise TypeError( + f"'signal_group_class' must be a subclass of SignalGroup, " + f"not {grp_cls}" + ) + if grp_cls is SignalGroup and collect_fields is False: + raise ValueError( + "Cannot use SignalGroup with collect_fields=False. " + "Use a custom SignalGroup subclass instead." + ) + self._name: str | None = None self._eqop = tuple(equality_operators.items()) if equality_operators else None self._warn_on_no_fields = warn_on_no_fields self._cache_on_instance = cache_on_instance self._patch_setattr = patch_setattr - self._signal_group_class = signal_group_class or SignalGroup - self._collect_fields = collect_fields - self._signal_groups: dict[int, type[SignalGroup]] = {} + self._signal_group_class: type[GroupType] = grp_cls + self._collect_fields = collect_fields + self._signal_groups: dict[int, type[GroupType]] = {} def __set_name__(self, owner: type, name: str) -> None: """Called when this descriptor is added to class `owner` as attribute `name`.""" @@ -442,19 +461,15 @@ def _do_patch_setattr(self, owner: type) -> None: "emitted when fields change." ) from e - # map of id(obj) -> SignalGroup - # cached here in case the object isn't modifiable - _instance_map: ClassVar[dict[int, SignalGroup]] = {} - @overload def __get__(self, instance: None, owner: type) -> SignalGroupDescriptor: ... @overload - def __get__(self, instance: object, owner: type) -> SignalGroup: ... + def __get__(self, instance: object, owner: type) -> GroupType: ... def __get__( self, instance: object, owner: type - ) -> SignalGroup | SignalGroupDescriptor: + ) -> GroupType | SignalGroupDescriptor: """Return a SignalGroup instance for `instance`.""" if instance is None: return self @@ -478,15 +493,15 @@ def __get__( with contextlib.suppress(TypeError): # if it's not weakref-able weakref.finalize(instance, self._instance_map.pop, obj_id, None) - return self._instance_map[obj_id] + return cast("GroupType", self._instance_map[obj_id]) - def _get_signal_group(self, owner: type) -> type[SignalGroup]: + def _get_signal_group(self, owner: type) -> type[GroupType]: type_id = id(owner) if type_id not in self._signal_groups: self._signal_groups[type_id] = self._create_group(owner) return self._signal_groups[type_id] - def _create_group(self, owner: type) -> type[SignalGroup]: + def _create_group(self, owner: type) -> type[GroupType]: # Do not collect fields from owner class, copy the SignalGroup if not self._collect_fields: Group = copy.deepcopy(self._signal_group_class) @@ -494,11 +509,8 @@ def _create_group(self, owner: type) -> type[SignalGroup]: # Collect fields and create SignalGroup subclass else: Group = _build_dataclass_signal_group( - owner, - signal_group_class=self._signal_group_class, - equality_operators=self._eqop, + owner, self._signal_group_class, equality_operators=self._eqop ) - if self._warn_on_no_fields and not Group._psygnal_signals: warnings.warn( f"No mutable fields found on class {owner}: no events will be " diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index bc0b5b96..e59ef3e2 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from dataclasses import dataclass from typing import Any, ClassVar from unittest.mock import Mock, patch @@ -227,17 +228,24 @@ class Bar: @pytest.mark.parametrize("collect", [True, False]) @pytest.mark.parametrize("klass", [None, SignalGroup, MyGroup]) -def test_collect_fields(collect, klass) -> None: +def test_collect_fields(collect: bool, klass: type[SignalGroup] | None) -> None: signal_class = klass or SignalGroup + should_fail_def = signal_class is SignalGroup and collect is False + ctx = pytest.raises(ValueError) if should_fail_def else nullcontext() - @dataclass - class Foo: - events: ClassVar = SignalGroupDescriptor( - warn_on_no_fields=False, - signal_group_class=klass, - collect_fields=collect, - ) - a: int = 1 + with ctx: + + @dataclass + class Foo: + events: ClassVar = SignalGroupDescriptor( + warn_on_no_fields=False, + signal_group_class=klass, + collect_fields=collect, + ) + a: int = 1 + + if should_fail_def: + return @dataclass class Bar(Foo): @@ -246,8 +254,6 @@ class Bar(Foo): foo = Foo() bar = Bar() - signal_class = klass or SignalGroup - # Cannot instantiate SignalGroup directly, use a subclass if not collect and signal_class is SignalGroup: with pytest.raises(TypeError): From 9e6692e7bcb72964f391cf9ffaf7e69c83d3f792 Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 11 Mar 2024 13:12:12 +0000 Subject: [PATCH 3/9] compat py38 --- tests/test_group_descriptor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index e59ef3e2..611ff881 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import nullcontext from dataclasses import dataclass from typing import Any, ClassVar From 6e636e60bd882712ccd3ca61135b9d8d52efb95d Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 11 Mar 2024 13:16:26 +0000 Subject: [PATCH 4/9] cleanup --- tests/test_group_descriptor.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index 611ff881..48780656 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -256,14 +256,6 @@ class Bar(Foo): foo = Foo() bar = Bar() - # Cannot instantiate SignalGroup directly, use a subclass - if not collect and signal_class is SignalGroup: - with pytest.raises(TypeError): - _ = foo.events - with pytest.raises(TypeError): - _ = bar.events - return - assert issubclass(type(foo.events), signal_class) if collect: From c0a34b412dcedba8e4a213986ee5b5e7e7b37ec1 Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 11 Mar 2024 13:34:41 +0000 Subject: [PATCH 5/9] remove future --- tests/test_group_descriptor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_group_descriptor.py b/tests/test_group_descriptor.py index 48780656..d54805c1 100644 --- a/tests/test_group_descriptor.py +++ b/tests/test_group_descriptor.py @@ -1,8 +1,6 @@ -from __future__ import annotations - from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, ClassVar +from typing import Any, ClassVar, Optional, Type from unittest.mock import Mock, patch import pytest @@ -230,7 +228,7 @@ class Bar: @pytest.mark.parametrize("collect", [True, False]) @pytest.mark.parametrize("klass", [None, SignalGroup, MyGroup]) -def test_collect_fields(collect: bool, klass: type[SignalGroup] | None) -> None: +def test_collect_fields(collect: bool, klass: Optional[Type[SignalGroup]]) -> None: signal_class = klass or SignalGroup should_fail_def = signal_class is SignalGroup and collect is False ctx = pytest.raises(ValueError) if should_fail_def else nullcontext() From c49d2464f9eda1944d078f98f09a882cf16cf9e1 Mon Sep 17 00:00:00 2001 From: getzze Date: Mon, 11 Mar 2024 13:38:19 +0000 Subject: [PATCH 6/9] remove future --- src/psygnal/_group_descriptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index bd3ea615..a4f29dad 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -412,7 +412,7 @@ def __init__( signal_group_class: type[GroupType] | None = None, collect_fields: bool = True, ): - grp_cls = signal_group_class or cast(type[GroupType], SignalGroup) + grp_cls = signal_group_class or cast(Type[GroupType], SignalGroup) if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)): raise TypeError( f"'signal_group_class' must be a subclass of SignalGroup, " From 75d88b90c2d22b65618510fdd4276736e18f1870 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 11 Mar 2024 09:47:59 -0400 Subject: [PATCH 7/9] add pragma --- src/psygnal/_group_descriptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index a4f29dad..783b2e60 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -414,7 +414,7 @@ def __init__( ): grp_cls = signal_group_class or cast(Type[GroupType], SignalGroup) if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)): - raise TypeError( + raise TypeError( # pragma: no cover f"'signal_group_class' must be a subclass of SignalGroup, " f"not {grp_cls}" ) From 2751807f355db86e38fe71a6174de8ee21906e27 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 11 Mar 2024 11:34:44 -0400 Subject: [PATCH 8/9] remove generic --- src/psygnal/_group_descriptor.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index 783b2e60..a807a43b 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -11,7 +11,6 @@ Any, Callable, ClassVar, - Generic, Iterable, Literal, Type, @@ -34,7 +33,6 @@ T = TypeVar("T", bound=Type) S = TypeVar("S") -GroupType = TypeVar("GroupType", bound=SignalGroup) EqOperator = Callable[[Any, Any], bool] @@ -146,9 +144,9 @@ def connect_setattr( def _build_dataclass_signal_group( cls: type, - signal_group_class: type[GroupType], + signal_group_class: type[SignalGroup], equality_operators: Iterable[tuple[str, EqOperator]] | None = None, -) -> type[GroupType]: +) -> type[SignalGroup]: """Build a SignalGroup with events for each field in a dataclass. Parameters @@ -305,7 +303,7 @@ def _setattr_and_emit_(self: object, name: str, value: Any) -> None: return _inner(super_setattr) if super_setattr else _inner -class SignalGroupDescriptor(Generic[GroupType]): +class SignalGroupDescriptor: """Create a [`psygnal.SignalGroup`][] on first instance attribute access. This descriptor is designed to be used as a class attribute on a dataclass-like @@ -409,10 +407,10 @@ def __init__( warn_on_no_fields: bool = True, cache_on_instance: bool = True, patch_setattr: bool = True, - signal_group_class: type[GroupType] | None = None, + signal_group_class: type[SignalGroup] | None = None, collect_fields: bool = True, ): - grp_cls = signal_group_class or cast(Type[GroupType], SignalGroup) + grp_cls = signal_group_class or cast(Type[SignalGroup], SignalGroup) if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)): raise TypeError( # pragma: no cover f"'signal_group_class' must be a subclass of SignalGroup, " @@ -430,9 +428,9 @@ def __init__( self._cache_on_instance = cache_on_instance self._patch_setattr = patch_setattr - self._signal_group_class: type[GroupType] = grp_cls + self._signal_group_class: type[SignalGroup] = grp_cls self._collect_fields = collect_fields - self._signal_groups: dict[int, type[GroupType]] = {} + self._signal_groups: dict[int, type[SignalGroup]] = {} def __set_name__(self, owner: type, name: str) -> None: """Called when this descriptor is added to class `owner` as attribute `name`.""" @@ -465,11 +463,11 @@ def _do_patch_setattr(self, owner: type) -> None: def __get__(self, instance: None, owner: type) -> SignalGroupDescriptor: ... @overload - def __get__(self, instance: object, owner: type) -> GroupType: ... + def __get__(self, instance: object, owner: type) -> SignalGroup: ... def __get__( self, instance: object, owner: type - ) -> GroupType | SignalGroupDescriptor: + ) -> SignalGroup | SignalGroupDescriptor: """Return a SignalGroup instance for `instance`.""" if instance is None: return self @@ -493,15 +491,15 @@ def __get__( with contextlib.suppress(TypeError): # if it's not weakref-able weakref.finalize(instance, self._instance_map.pop, obj_id, None) - return cast("GroupType", self._instance_map[obj_id]) + return self._instance_map[obj_id] - def _get_signal_group(self, owner: type) -> type[GroupType]: + def _get_signal_group(self, owner: type) -> type[SignalGroup]: type_id = id(owner) if type_id not in self._signal_groups: self._signal_groups[type_id] = self._create_group(owner) return self._signal_groups[type_id] - def _create_group(self, owner: type) -> type[GroupType]: + def _create_group(self, owner: type) -> type[SignalGroup]: # Do not collect fields from owner class, copy the SignalGroup if not self._collect_fields: Group = copy.deepcopy(self._signal_group_class) From 7db13a79945e658ba17d76946009096dfd424a71 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 11 Mar 2024 11:35:36 -0400 Subject: [PATCH 9/9] remove unused cast --- src/psygnal/_group_descriptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index a807a43b..9132b2dc 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -410,7 +410,7 @@ def __init__( signal_group_class: type[SignalGroup] | None = None, collect_fields: bool = True, ): - grp_cls = signal_group_class or cast(Type[SignalGroup], SignalGroup) + grp_cls = signal_group_class or SignalGroup if not (isinstance(grp_cls, type) and issubclass(grp_cls, SignalGroup)): raise TypeError( # pragma: no cover f"'signal_group_class' must be a subclass of SignalGroup, "