diff --git a/pyproject.toml b/pyproject.toml index dd7fb69b..897e5bfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ dependencies = [ # https://github.com/pyapp-kit/psygnal/issues/350 "mypy==1.13.0", "mypy_extensions >=0.4.2", - "pydantic!=2.10.0", + "pydantic!=2.10.0", # typing error in v2.10 prevents mypyc from working "types-attrs", "msgspec", ] diff --git a/src/psygnal/_evented_model.py b/src/psygnal/_evented_model.py index 31f90f98..38ebacf2 100644 --- a/src/psygnal/_evented_model.py +++ b/src/psygnal/_evented_model.py @@ -125,8 +125,14 @@ def _get_defaults( def _get_config(cls: pydantic.BaseModel) -> "ConfigDict": return cls.model_config - def _get_fields(cls: pydantic.BaseModel) -> dict[str, pydantic.fields.FieldInfo]: - return cls.model_fields + def _get_fields( + cls: pydantic.BaseModel, + ) -> dict[str, pydantic.fields.FieldInfo]: + comp_fields = { + name: pydantic.fields.FieldInfo(annotation=f.return_type, frozen=False) + for name, f in cls.model_computed_fields.items() + } + return {**cls.model_fields, **comp_fields} def _model_dump(obj: pydantic.BaseModel) -> dict: return obj.model_dump() @@ -353,9 +359,11 @@ class Config: f"{prop!r} is not." ) for field in fields: - if field not in model_fields: + if field not in model_fields and not hasattr(cls, field): warnings.warn( - f"Unrecognized field dependency: {field!r}", stacklevel=2 + f"property {prop!r} cannot depend on unrecognized attribute " + f"name: {field!r}", + stacklevel=2, ) deps.setdefault(field, set()).add(prop) if model_config.get(GUESS_PROPERTY_DEPENDENCIES, False): @@ -470,6 +478,7 @@ class Config: _changes_queue: dict[str, Any] = PrivateAttr(default_factory=dict) _primary_changes: set[str] = PrivateAttr(default_factory=set) _delay_check_semaphore: int = PrivateAttr(0) + _names_that_need_emission: set[str] = PrivateAttr(default_factory=set) if PYDANTIC_V1: @@ -484,6 +493,9 @@ def __init__(_model_self_, **data: Any) -> None: # but if we don't use `ClassVar`, then the `dataclass_transform` decorator # will add _events: SignalGroup to the __init__ signature, for *all* user models _model_self_._events = Group(_model_self_) # type: ignore [misc] + _model_self_._names_that_need_emission = set(_model_self_._events) | set( + _model_self_.__field_dependents__ + ) # expose the private SignalGroup publicly @property @@ -565,15 +577,23 @@ def _check_if_values_changed_and_emit_if_needed(self) -> None: # do not run whole machinery if there is no need return to_emit = [] + must_continue = False for name in self._primary_changes: # primary changes should contains only fields # that are changed directly by assignment old_value = self._changes_queue[name] new_value = getattr(self, name) + if not _check_field_equality(type(self), name, new_value, old_value): - to_emit.append((name, new_value)) + if name in self._events: + to_emit.append((name, new_value)) + else: + # An attribute is changing that is not in the SignalGroup + # if it has field dependents, we must still continue + # to check the _changes_queue + must_continue = name in self.__field_dependents__ self._changes_queue.pop(name) - if not to_emit: + if not to_emit and not must_continue: # If no direct changes was made then we can skip whole machinery self._changes_queue.clear() self._primary_changes.clear() @@ -595,7 +615,7 @@ def __setattr__(self, name: str, value: Any) -> None: if ( name == "_events" or not hasattr(self, "_events") # can happen on init - or name not in self._events + or name not in self._names_that_need_emission ): # fallback to default behavior return self._super_setattr_(name, value) @@ -646,18 +666,24 @@ def _setattr_impl(self, name: str, value: Any) -> None: # note that ALL signals will have sat least one listener simply by nature of # being in the `self._events` SignalGroup. group = self._events - signal_instance: SignalInstance = group[name] - deps_with_callbacks = { - dep_name - for dep_name in self.__field_dependents__.get(name, ()) - if len(group[dep_name]) - } - if ( - len(signal_instance) < 1 # the signal itself has no listeners - and not deps_with_callbacks # no dependent properties with listeners - and not len(group._psygnal_relay) # no listeners on the SignalGroup - ): + if name in group: + signal_instance: SignalInstance = group[name] + deps_with_callbacks = { + dep_name + for dep_name in self.__field_dependents__.get(name, ()) + if len(group[dep_name]) + } + if ( + len(signal_instance) < 1 # the signal itself has no listeners + and not deps_with_callbacks # no dependent properties with listeners + and not len(group._psygnal_relay) # no listeners on the SignalGroup + ): + return self._super_setattr_(name, value) + elif name in self.__field_dependents__: + deps_with_callbacks = self.__field_dependents__[name] + else: return self._super_setattr_(name, value) + self._primary_changes.add(name) if name not in self._changes_queue: self._changes_queue[name] = getattr(self, name, object()) diff --git a/tests/test_evented_model.py b/tests/test_evented_model.py index 693a8b40..0a60e890 100644 --- a/tests/test_evented_model.py +++ b/tests/test_evented_model.py @@ -587,7 +587,7 @@ class Config: def test_unrecognized_property_dependencies(): - with pytest.warns(UserWarning, match="Unrecognized field dependency: 'b'"): + with pytest.warns(UserWarning, match="cannot depend on unrecognized attribute"): class M(EventedModel): x: int @@ -983,3 +983,49 @@ def c(self, val: Sequence[int]) -> None: m.a = 5 mock_a.assert_called_with(5) mock_c.assert_called_with([5, 20]) + + +@pytest.mark.skipif(not PYDANTIC_V2, reason="computed_field added in v2") +def test_private_field_dependents(): + from pydantic import PrivateAttr, computed_field + + from psygnal import EventedModel + + class MyModel(EventedModel): + _items_dict: dict[str, int] = PrivateAttr(default_factory=dict) + + @computed_field # type: ignore [prop-decorator] + @property + def item_names(self) -> list[str]: + return list(self._items_dict) + + @computed_field # type: ignore [prop-decorator] + @property + def item_sum(self) -> int: + return sum(self._items_dict.values()) + + def add_item(self, name: str, value: int) -> None: + if name in self._items_dict: + raise ValueError(f"Name {name} already exists!") + self._items_dict = {**self._items_dict, name: value} + + # Ideally the following would work + model_config = { # type: ignore [typeddict-unknown-key] + "field_dependencies": { + "item_names": ["_items_dict"], + "item_sum": ["_items_dict"], + } + } + + m = MyModel() + item_sum_mock = Mock() + item_names_mock = Mock() + m.events.item_sum.connect(item_sum_mock) + m.events.item_names.connect(item_names_mock) + m.add_item("a", 1) + item_sum_mock.assert_called_with(1) + item_names_mock.assert_called_with(["a"]) + item_sum_mock.reset_mock() + m.add_item("b", 2) + item_sum_mock.assert_called_with(3) + item_names_mock.assert_called_with(["a", "b"])