Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support bubbling up of events from evented children on dataclasses #298

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def evented(
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = ...,
cache_on_instance: bool = ...,
connect_child_events: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = ...,
) -> T: ...

Expand All @@ -32,6 +33,7 @@ def evented(
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = ...,
cache_on_instance: bool = ...,
connect_child_events: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = ...,
) -> Callable[[T], T]: ...

Expand All @@ -43,6 +45,7 @@ def evented(
equality_operators: dict[str, EqOperator] | None = None,
warn_on_no_fields: bool = True,
cache_on_instance: bool = True,
connect_child_events: bool = True,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
) -> Callable[[T], T] | T:
"""A decorator to add events to a dataclass.
Expand Down Expand Up @@ -83,6 +86,14 @@ def evented(
access, but means that the owner instance will no longer be pickleable. If
`False`, the SignalGroup instance will *still* be cached, but not on the
instance itself.
connect_child_events : bool, optional
If `True`, will connect events from all fields on the dataclass whose type is
also "evented" (as determined by the `is_evented` function in this module,
which returns True if the class has been decorated with `@evented`, or if it
has a SignalGroupDescriptor) to the group on the parent object. By default
False.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... True?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Which btw I think should be the default... I think it takes a lot of experience with events to even notice that there's no bubbling up by default, let alone grok it. I think most users would expect bubbling to automagically work.)

This is useful for nested evented dataclasses, where you want to monitor events
emitted from arbitrarily deep children on the parent object.
signal_aliases: Mapping[str, str | None] | Callable[[str], str | None] | None
If defined, a mapping between field name and signal name. Field names that are
not `signal_aliases` keys are not aliased (the signal name is the field name).
Expand Down Expand Up @@ -128,6 +139,7 @@ def _decorate(cls: T) -> T:
equality_operators=equality_operators,
warn_on_no_fields=warn_on_no_fields,
cache_on_instance=cache_on_instance,
connect_child_events=connect_child_events,
signal_aliases=signal_aliases,
)
# as a decorator, this will have already been called
Expand Down
68 changes: 60 additions & 8 deletions src/psygnal/_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,30 @@ class EmissionInfo(NamedTuple):
Attributes
----------
signal : SignalInstance
The SignalInstance doing the emitting
args: tuple
The args that were emitted
loc: str | int | None | tuple[int | str, ...]
If the emitter was a `SignalGroup` attribute on another object, this
will be the location of that emitter on the parent (e.g. name of that attribute
or index if the parent was an evented sequence). Otherwise, it will be `None`.
If this is a flattened EmissionInfo, then this will be a tuple of locations.
"""

signal: SignalInstance
args: tuple[Any, ...]
loc: str | int | None | tuple[int | str, ...] = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the earlier example, which had attr_name here, I find that easier to think about than loc. This is especially true if you add dictionaries into the mix — so now we don't know from the type whether this is an index, a key, or an attribute name. What do you think about having three optionals:

  • attr
  • index
  • key

Of course, now I move on to read about "flatten"... 😅

You could make a lightweight "Loc" object that could contain each of the keys, then you could have a sequence of Loc objects? The nice thing about this option is that you could write an accessor that gets you the right object given the parent and the sequence of locs.

I'm also not super keen on the "loc" name. "from"? "emitted_by"? "child"?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about "origin"?

I think "loc" would be fine though...


def flatten(self) -> EmissionInfo:
"""Return the final signal and args that were emitted."""
info = self
path: list[int | str] = []
while info.args and isinstance(info.args[0], EmissionInfo):
if isinstance(info.loc, (str, int)):
path.append(info.loc)
info = info.args[0]
path.append(info.signal.name)
return EmissionInfo(info.signal, info.args, tuple(path))


class SignalRelay(SignalInstance):
Expand Down Expand Up @@ -96,18 +115,22 @@ def _disconnect_relay(self) -> None:
for sig in self._signals.values():
sig.disconnect(self._slot_relay)

def _slot_relay(self, *args: Any) -> None:
def _slot_relay(self, *args: Any, loc: str | int | None = None) -> None:
if emitter := Signal.current_emitter():
info = EmissionInfo(emitter, args)
info = EmissionInfo(emitter, args, loc or emitter.name)
self._run_emit_loop((info,))

def _relay_partial(self, loc: str | int | None) -> _relay_partial:
"""Return special partial that will call _slot_relay with 'loc'."""
return _relay_partial(self, loc)

def connect_direct(
self,
slot: Callable | None = None,
*,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
unique: bool | Literal["raise"] = False,
max_args: int | None = None,
) -> Callable[[Callable], Callable] | Callable:
"""Connect `slot` to be called whenever *any* Signal in this group is emitted.
Expand All @@ -128,7 +151,7 @@ def connect_direct(
If `True`, An additional check will be performed to make sure that types
declared in the slot signature are compatible with the signature
declared by this signal, by default `False`.
unique : bool | str
unique : bool | Literal["raise"]
If `True`, returns without connecting if the slot has already been
connected. If the literal string "raise" is passed to `unique`, then a
`ValueError` will be raised if the slot is already connected.
Expand Down Expand Up @@ -444,7 +467,7 @@ def connect(
thread: threading.Thread | Literal["main", "current"] | None = ...,
check_nargs: bool | None = ...,
check_types: bool | None = ...,
unique: bool | str = ...,
unique: bool | Literal["raise"] = ...,
max_args: int | None = None,
on_ref_error: RefErrorChoice = ...,
priority: int = ...,
Expand All @@ -458,7 +481,7 @@ def connect(
thread: threading.Thread | Literal["main", "current"] | None = ...,
check_nargs: bool | None = ...,
check_types: bool | None = ...,
unique: bool | str = ...,
unique: bool | Literal["raise"] = ...,
max_args: int | None = None,
on_ref_error: RefErrorChoice = ...,
priority: int = ...,
Expand All @@ -471,7 +494,7 @@ def connect(
thread: threading.Thread | Literal["main", "current"] | None = None,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
unique: bool | Literal["raise"] = False,
max_args: int | None = None,
on_ref_error: RefErrorChoice = "warn",
priority: int = 0,
Expand Down Expand Up @@ -504,7 +527,7 @@ def connect_direct(
*,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
unique: bool | Literal["raise"] = False,
max_args: int | None = None,
) -> Callable[[Callable], Callable] | Callable:
return self._psygnal_relay.connect_direct(
Expand Down Expand Up @@ -550,3 +573,32 @@ def _is_uniform(signals: Iterable[Signal]) -> bool:
return False
seen.add(v)
return True


class _relay_partial:
"""Small, single-purpose, mypyc-friendly variant of functools.partial.

Used to call SignalRelay._slot_relay with a specific loc.
__hash__ and __eq__ are implemented to make this object hashable and
comparable to other _relay_partial objects, so that adding it to a set
twice will not create two separate entries.
"""

__slots__ = ("relay", "loc")

def __init__(self, relay: SignalRelay, loc: str | int | None = None) -> None:
self.relay = relay
self.loc = loc

def __call__(self, *args: Any) -> Any:
return self.relay._slot_relay(*args, loc=self.loc)

def __hash__(self) -> int:
return hash((self.relay, self.loc))

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, _relay_partial)
and self.relay == other.relay
and self.loc == other.loc
)
65 changes: 63 additions & 2 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,14 @@ def __setattr__(self, name: str, value: Any) -> None:
instance will be a subclass of `signal_group_class` (SignalGroup if it is None).
If False, a deepcopy of `signal_group_class` will be used.
Default to True
connect_child_events : bool, optional
If `True`, will connect events from all fields on the dataclass whose type is
also "evented" (as determined by the `is_evented` function in this module,
which returns True if the class has been decorated with `@evented`, or if it
has a SignalGroupDescriptor) to the group on the parent object. By default
False.
This is useful for nested evented dataclasses, where you want to monitor events
emitted from arbitrarily deep children on the parent object.
signal_aliases: Mapping[str, str | None] | Callable[[str], str | None] | None
If defined, a mapping between field name and signal name. Field names that are
not `signal_aliases` keys are not aliased (the signal name is the field name).
Expand Down Expand Up @@ -522,6 +530,7 @@ def __init__(
patch_setattr: bool = True,
signal_group_class: type[SignalGroup] | None = None,
collect_fields: bool = True,
connect_child_events: bool = False,
signal_aliases: Mapping[str, str | None] | FieldAliasFunc | None = None,
):
grp_cls = signal_group_class or SignalGroup
Expand All @@ -548,6 +557,8 @@ def __init__(
self._warn_on_no_fields = warn_on_no_fields
self._cache_on_instance = cache_on_instance
self._patch_setattr = patch_setattr
self._connect_child_events = connect_child_events

self._signal_group_class: type[SignalGroup] = grp_cls
self._collect_fields = collect_fields
self._signal_aliases = signal_aliases
Expand Down Expand Up @@ -609,17 +620,22 @@ def __get__(
obj_id = id(instance)
if obj_id not in self._instance_map:
# cache it
self._instance_map[obj_id] = signal_group(instance)
self._instance_map[obj_id] = grp = signal_group(instance)
# also *try* to set it on the instance as well, since it will skip all the
# __get__ logic in the future, but if it fails, no big deal.
if self._name and self._cache_on_instance:
with contextlib.suppress(Exception):
setattr(instance, self._name, self._instance_map[obj_id])
setattr(instance, self._name, grp)

# clean up the cache when the instance is deleted
with contextlib.suppress(TypeError): # if it's not weakref-able
weakref.finalize(instance, self._instance_map.pop, obj_id, None)

# setup nested event emission if requested
if self._connect_child_events:
# TODO: expose "recurse" somehow?
connect_child_events(instance, recurse=True, _group=grp)

return self._instance_map[obj_id]

def _get_signal_group(self, owner: type) -> type[SignalGroup]:
Expand Down Expand Up @@ -655,3 +671,48 @@ def _create_group(self, owner: type) -> type[SignalGroup]:

self._do_patch_setattr(owner, with_aliases=bool(Group._psygnal_aliases))
return Group


def _find_signal_group(obj: object, default_name: str = "events") -> SignalGroup | None:
# look for default "events" name as well
maybe_group = getattr(obj, get_evented_namespace(obj) or default_name, None)
return maybe_group if isinstance(maybe_group, SignalGroup) else None


def connect_child_events(
obj: object, recurse: bool = False, _group: SignalGroup | None = None
) -> None:
"""Connect events from evented children to a parent SignalGroup.

`obj` must be an evented dataclass-style object.
This is useful when you have a tree of objects, and you want to connect all
events from the children to the parent.

Parameters.
----------
obj : object
The object to connect events from. If it is not evented, this function will
do nothing.
recurse : bool, optional
If `True`, will also connect events from all evented children of `obj`, by
default `False`.
_group : SignalGroup, optional
(This is used internally during recursion.)
The SignalGroup to connect to. If not provided, will be found by calling
`get_evented_namespace(obj)`. By default None.
"""
if _group is None and (_group := _find_signal_group(obj)) is None:
return # pragma: no cover # not evented

for loc, attr_type in iter_fields(type(obj), exclude_frozen=True):
if is_evented(attr_type):
child = getattr(obj, loc, None)
if (child_group := _find_signal_group(child)) is not None:
child_group.connect(
_group._psygnal_relay._relay_partial(loc),
check_nargs=False,
check_types=False,
on_ref_error="ignore", # compiled objects are not weakref-able
)
if recurse:
connect_child_events(child, recurse=True, _group=child_group)
8 changes: 4 additions & 4 deletions src/psygnal/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def connect(
thread: threading.Thread | Literal["main", "current"] | None = ...,
check_nargs: bool | None = ...,
check_types: bool | None = ...,
unique: bool | str = ...,
unique: bool | Literal["raise"] = ...,
max_args: int | None = None,
on_ref_error: RefErrorChoice = ...,
priority: int = ...,
Expand All @@ -599,7 +599,7 @@ def connect(
thread: threading.Thread | Literal["main", "current"] | None = ...,
check_nargs: bool | None = ...,
check_types: bool | None = ...,
unique: bool | str = ...,
unique: bool | Literal["raise"] = ...,
max_args: int | None = None,
on_ref_error: RefErrorChoice = ...,
priority: int = ...,
Expand All @@ -612,7 +612,7 @@ def connect(
thread: threading.Thread | Literal["main", "current"] | None = None,
check_nargs: bool | None = None,
check_types: bool | None = None,
unique: bool | str = False,
unique: bool | Literal["raise"] = False,
max_args: int | None = None,
on_ref_error: RefErrorChoice = "warn",
priority: int = 0,
Expand Down Expand Up @@ -693,7 +693,7 @@ def my_function(): ...
If the provided slot fails validation, either due to mismatched positional
argument requirements, or failed type checking.
ValueError
If `unique` is `True` and `slot` has already been connected.
If `unique` is `'raise'` and `slot` has already been connected.
"""
if check_nargs is None:
check_nargs = self._check_nargs_on_connect
Expand Down
2 changes: 2 additions & 0 deletions src/psygnal/containers/_evented_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Iterable,
Iterator,
Mapping,
Expand Down Expand Up @@ -168,6 +169,7 @@ class EventedDict(TypedMutableMapping[_K, _V]):
"""

events: DictEvents # pragma: no cover
_psygnal_group_: ClassVar[str] = "events"

def __init__(
self,
Expand Down
4 changes: 3 additions & 1 deletion src/psygnal/containers/_evented_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Iterable,
Mapping,
MutableSequence,
Expand Down Expand Up @@ -114,6 +115,7 @@ class EventedList(MutableSequence[_T]):
"""

events: ListEvents # pragma: no cover
_psygnal_group_: ClassVar[str] = "events"

def __init__(
self,
Expand Down Expand Up @@ -422,7 +424,7 @@ def _reemit_child_event(self, *args: Any) -> None:
and isinstance(emitter, SignalRelay)
and isinstance(args[0], EmissionInfo)
):
emitter, args = args[0]
emitter, args, *_ = args[0]

self.events.child_event.emit(idx, obj, emitter, args)

Expand Down
4 changes: 3 additions & 1 deletion src/psygnal/containers/_evented_proxy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Callable, Dict, Generic, List, TypeVar
from typing import Any, Callable, ClassVar, Dict, Generic, List, TypeVar
from weakref import finalize

try:
Expand Down Expand Up @@ -67,6 +67,8 @@ class EventedObjectProxy(ObjectProxy, Generic[T]):
An object to wrap
"""

_psygnal_group_: ClassVar[str] = "events"

def __init__(self, target: Any):
super().__init__(target)

Expand Down
2 changes: 2 additions & 0 deletions src/psygnal/containers/_evented_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Final,
Iterable,
Iterator,
Expand Down Expand Up @@ -254,6 +255,7 @@ class EventedSet(_BaseMutableSet[_T]):
"""

events: SetEvents # pragma: no cover
_psygnal_group_: ClassVar[str] = "events"

def __init__(self, iterable: Iterable[_T] = ()):
self.events = self._get_events_class()
Expand Down
Loading
Loading