Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support arbitrary field dependents #340

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ dependencies = [
"hatch-mypyc>=0.13.0",
"mypy>=0.991",
"mypy_extensions >=0.4.2",
"pydantic!=2.10.0",
"pydantic!=2.10.0", # typing error in v2.10 prevents mypyc from working
"types-attrs",
"msgspec",
]
Expand Down
62 changes: 44 additions & 18 deletions src/psygnal/_evented_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,14 @@ def _get_defaults(
def _get_config(cls: pydantic.BaseModel) -> "ConfigDict":
return cls.model_config

def _get_fields(cls: pydantic.BaseModel) -> dict[str, pydantic.fields.FieldInfo]:
return cls.model_fields
def _get_fields(
cls: pydantic.BaseModel,
) -> dict[str, pydantic.fields.FieldInfo]:
comp_fields = {
name: pydantic.fields.FieldInfo(annotation=f.return_type, frozen=False)
for name, f in cls.model_computed_fields.items()
Comment on lines +132 to +133
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see test with pydantic 1.x, Should be here: if hasattr(cls, "model_computed_fields")?

}
return {**cls.model_fields, **comp_fields}

def _model_dump(obj: pydantic.BaseModel) -> dict:
return obj.model_dump()
Expand Down Expand Up @@ -353,9 +359,11 @@ class Config:
f"{prop!r} is not."
)
for field in fields:
if field not in model_fields:
if field not in model_fields and not hasattr(cls, field):
warnings.warn(
f"Unrecognized field dependency: {field!r}", stacklevel=2
f"property {prop!r} cannot depend on unrecognized attribute "
f"name: {field!r}",
stacklevel=2,
)
deps.setdefault(field, set()).add(prop)
if model_config.get(GUESS_PROPERTY_DEPENDENCIES, False):
Expand Down Expand Up @@ -470,6 +478,7 @@ class Config:
_changes_queue: dict[str, Any] = PrivateAttr(default_factory=dict)
_primary_changes: set[str] = PrivateAttr(default_factory=set)
_delay_check_semaphore: int = PrivateAttr(0)
_names_that_need_emission: set[str] = PrivateAttr(default_factory=set)

if PYDANTIC_V1:

Expand All @@ -484,6 +493,9 @@ def __init__(_model_self_, **data: Any) -> None:
# but if we don't use `ClassVar`, then the `dataclass_transform` decorator
# will add _events: SignalGroup to the __init__ signature, for *all* user models
_model_self_._events = Group(_model_self_) # type: ignore [misc]
_model_self_._names_that_need_emission = set(_model_self_._events) | set(
_model_self_.__field_dependents__
)

# expose the private SignalGroup publicly
@property
Expand Down Expand Up @@ -565,15 +577,23 @@ def _check_if_values_changed_and_emit_if_needed(self) -> None:
# do not run whole machinery if there is no need
return
to_emit = []
must_continue = False
for name in self._primary_changes:
# primary changes should contains only fields
# that are changed directly by assignment
Comment on lines 582 to 583
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not fully understand. When we need to emit a signal. This comment still statement that self._primary_changes are only fields changed by direct assignment.
And documentation of https://docs.pydantic.dev/2.0/usage/computed_fields/ do not mention the assigment of this property.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i wasn't entirely clear on the _primary_changes, and _changes_queue attributes here. So maybe I did it wrong. we need to emit a signal here when a non-field attribute changes, provided that some other evented field lists that attribute as one of its "dependents". So, in the example provided by the original requester:

from psygnal import EventedModel
from pydantic import computed_field, PrivateAttr

class MyModel(EventedModel):
    _items_dict: dict[str, int] = PrivateAttr(default_factory=dict)

    @computed_field
    def item_names(self) -> list[int]:
        return list(self._items_dict.keys())

    def add_item(self, name: str, value: int) -> None:
        # this next assignment should trigger the emission of `self.events.item_names`
        # because `_items_dict` was listed as a field_dependency of `item_names`
        self._items_dict = {**self._items_dict, name: value}

    class Config:
        field_dependencies = {"item_names": ["_items_dict"]}

if I did that incorrectly (or stupidly) with the existing _primary_changes/_changes_queue pattern let me know

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this might be the first case where the "primary change" is to a non-evented attribute (_items_dict)... and so it didn't quite fit into the existing pattern. It's the first case where a _primary_change isn't actually able to itself be evented

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm.

Maybe we should add a dummy emiter (callback that do nothing, and raises an exception on connection) for private attributes that are listed in field_dependencies? Then we will move the complexity from setattr to constructor call?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's worth trying, lemme see how it goes

old_value = self._changes_queue[name]
new_value = getattr(self, name)

if not _check_field_equality(type(self), name, new_value, old_value):
to_emit.append((name, new_value))
if name in self._events:
to_emit.append((name, new_value))
else:
# An attribute is changing that is not in the SignalGroup
# if it has field dependents, we must still continue
# to check the _changes_queue
must_continue = name in self.__field_dependents__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
must_continue = name in self.__field_dependents__
must_continue = must_continue or name in self.__field_dependents__

As it may happen in the middle of for loop

self._changes_queue.pop(name)
if not to_emit:
if not to_emit and not must_continue:
# If no direct changes was made then we can skip whole machinery
self._changes_queue.clear()
self._primary_changes.clear()
Expand All @@ -595,7 +615,7 @@ def __setattr__(self, name: str, value: Any) -> None:
if (
name == "_events"
or not hasattr(self, "_events") # can happen on init
or name not in self._events
or name not in self._names_that_need_emission
):
# fallback to default behavior
return self._super_setattr_(name, value)
Expand Down Expand Up @@ -646,18 +666,24 @@ def _setattr_impl(self, name: str, value: Any) -> None:
# note that ALL signals will have sat least one listener simply by nature of
# being in the `self._events` SignalGroup.
group = self._events
signal_instance: SignalInstance = group[name]
deps_with_callbacks = {
dep_name
for dep_name in self.__field_dependents__.get(name, ())
if len(group[dep_name])
}
if (
len(signal_instance) < 1 # the signal itself has no listeners
and not deps_with_callbacks # no dependent properties with listeners
and not len(group._psygnal_relay) # no listeners on the SignalGroup
):
if name in group:
signal_instance: SignalInstance = group[name]
deps_with_callbacks = {
dep_name
for dep_name in self.__field_dependents__.get(name, ())
if len(group[dep_name])
}
if (
len(signal_instance) < 1 # the signal itself has no listeners
and not deps_with_callbacks # no dependent properties with listeners
and not len(group._psygnal_relay) # no listeners on the SignalGroup
):
return self._super_setattr_(name, value)
elif name in self.__field_dependents__:
deps_with_callbacks = self.__field_dependents__[name]
else:
return self._super_setattr_(name, value)

self._primary_changes.add(name)
if name not in self._changes_queue:
self._changes_queue[name] = getattr(self, name, object())
Expand Down
48 changes: 47 additions & 1 deletion tests/test_evented_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ class Config:


def test_unrecognized_property_dependencies():
with pytest.warns(UserWarning, match="Unrecognized field dependency: 'b'"):
with pytest.warns(UserWarning, match="cannot depend on unrecognized attribute"):

class M(EventedModel):
x: int
Expand Down Expand Up @@ -983,3 +983,49 @@ def c(self, val: Sequence[int]) -> None:
m.a = 5
mock_a.assert_called_with(5)
mock_c.assert_called_with([5, 20])


@pytest.mark.skipif(not PYDANTIC_V2, reason="computed_field added in v2")
def test_private_field_dependents():
from pydantic import PrivateAttr, computed_field

from psygnal import EventedModel

class MyModel(EventedModel):
_items_dict: dict[str, int] = PrivateAttr(default_factory=dict)

@computed_field # type: ignore [prop-decorator]
@property
def item_names(self) -> list[str]:
return list(self._items_dict)

@computed_field # type: ignore [prop-decorator]
@property
def item_sum(self) -> int:
return sum(self._items_dict.values())

def add_item(self, name: str, value: int) -> None:
if name in self._items_dict:
raise ValueError(f"Name {name} already exists!")
self._items_dict = {**self._items_dict, name: value}

# Ideally the following would work
model_config = { # type: ignore [typeddict-unknown-key]
"field_dependencies": {
"item_names": ["_items_dict"],
"item_sum": ["_items_dict"],
}
}

m = MyModel()
item_sum_mock = Mock()
item_names_mock = Mock()
m.events.item_sum.connect(item_sum_mock)
m.events.item_names.connect(item_names_mock)
m.add_item("a", 1)
item_sum_mock.assert_called_with(1)
item_names_mock.assert_called_with(["a"])
item_sum_mock.reset_mock()
m.add_item("b", 2)
item_sum_mock.assert_called_with(3)
item_names_mock.assert_called_with(["a", "b"])
Loading