Skip to content

Commit

Permalink
add coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
getzze committed Mar 12, 2024
1 parent 7e848ca commit b05eefd
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 11 deletions.
15 changes: 11 additions & 4 deletions src/psygnal/_group_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,15 @@ def _build_dataclass_signal_group(
# parse arguments
_equality_operators = dict(equality_operators) if equality_operators else {}
eq_map = _get_eq_operator_map(cls)
transform: FieldAliasFunc | None
if callable(signal_aliases):
transform = signal_aliases
_signal_aliases = {}
else:
transform = _identity
transform = None
_signal_aliases = dict(signal_aliases) if signal_aliases else {}
signal_group_sig_names = list(getattr(signal_group_class, "_psygnal_signals", {}))
signal_group_sig_aliases = dict(getattr(signal_group_class, "_psygnal_aliases", {}))

signals = {}
# create a Signal for each field in the dataclass
Expand All @@ -209,8 +211,12 @@ def _build_dataclass_signal_group(
sig_name: str | None
if name in _signal_aliases:
sig_name = _signal_aliases[name]
else:
elif callable(transform):
sig_name = transform(name)
elif name in signal_group_sig_aliases:
sig_name = signal_group_sig_aliases[name]
else:
sig_name = name

# Add the field and signal name to the table of signals, to emit with `setattr`
_signal_aliases[name] = sig_name
Expand All @@ -223,15 +229,16 @@ def _build_dataclass_signal_group(
if sig_name in signals:
key = next((k for k, v in _signal_aliases.items() if v == sig_name), None)
warnings.warn(
f"Signal {sig_name} was already created in {group_name}, "
f"Skip signal {sig_name!r}, was already created in {group_name}, "
f"from field {key}",
UserWarning,
stacklevel=2,
)
continue
if sig_name in signal_group_sig_names:
warnings.warn(
f"Skip signal {sig_name}, was already defined by {signal_group_class}",
f"Skip signal {sig_name!r}, was already defined by "
f"{signal_group_class}",
UserWarning,
stacklevel=2,
)
Expand Down
120 changes: 113 additions & 7 deletions tests/test_group_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,21 @@
],
)
def test_alias_parameters(type_: str) -> None:

class MyGroup(SignalGroup, signal_aliases={"b": None, "bb": None}):
b = Signal(str, str)
bb = Signal(str, str)


foo_options = {"signal_aliases": {"_b": None}}
bar_options = {
"signal_aliases": lambda x: None if x.startswith("_") else f"{x}_changed"
}
baz_options = {"signal_aliases": {"a": "a_changed", "_b": "b_changed"}}
baz2_options = {
"signal_group_class": MyGroup, "signal_aliases": {"aa": "a", "bb": "b"}
}


if type_ == "dataclass":
from dataclasses import dataclass, field
Expand Down Expand Up @@ -57,6 +67,15 @@ def b(self) -> str:
def b(self, value: str):
self._b = value

@dataclass
class Baz2:
events: ClassVar = SignalGroupDescriptor(**baz2_options)
a: int
aa: int
b: str
bb: str


elif type_ == "attrs":
from attrs import define, field

Expand Down Expand Up @@ -86,6 +105,15 @@ def b(self) -> str:
def b(self, value: str):
self._b = value

@define
class Baz2:
events: ClassVar = SignalGroupDescriptor(**baz2_options)
a: int
aa: int
b: str
bb: str


elif type_ == "pydantic":
pytest.importorskip("pydantic", minversion="2")
from pydantic import BaseModel
Expand Down Expand Up @@ -113,6 +141,14 @@ def b(self) -> str:
def b(self, value: str):
self._b = value

class Baz2(BaseModel):
events: ClassVar = SignalGroupDescriptor(**baz2_options)
a: int
aa: int
b: str
bb: str


elif type_ == "msgspec":
msgspec = pytest.importorskip("msgspec")

Expand All @@ -139,10 +175,19 @@ def b(self) -> str:
def b(self, value: str):
self._b = value

class Baz2(msgspec.Struct): # type: ignore
events: ClassVar = SignalGroupDescriptor(**baz2_options)
a: int
aa: int
b: str
bb: str


# Instantiate objects
foo = Foo(a=1, _b="b")
bar = Bar(a=1, _b="b")
baz = Baz(a=1)
baz2 = Baz2(a=1, aa=2, b="b", bb="bb")

# Check signals
assert set(foo.events) == {"a"}
Expand All @@ -166,12 +211,33 @@ def b(self, value: str):
"_b": "b_changed",
}

# with pytest.warns(UserWarning, match=r"Skip signal \'a\', was already created"):
with pytest.warns(UserWarning) as record:
assert set(baz2.events) == {"a", "b", "bb"}
assert len(record) == 2
assert record[0].message.args[0].startswith(
"Skip signal \'a\', was already created"
)
assert record[1].message.args[0].startswith(
"Skip signal \'b\', was already defined"
)
assert hasattr(baz.events, "_psygnal_aliases")
assert baz2.events._psygnal_aliases == {
"a": "a",
"aa": "a",
"b": None,
"bb": "b",
}

mock = Mock()
foo.events.a.connect(mock)
bar.events.a_changed.connect(mock)
baz.events.a_changed.connect(mock)
if not type_.startswith("pydantic"):
baz.events.b_changed.connect(mock)
baz2.events.a.connect(mock)
baz2.events.b.connect(mock)
baz2.events.bb.connect(mock)

# Foo
foo.a = 1
Expand Down Expand Up @@ -200,6 +266,26 @@ def b(self, value: str):
mock.assert_called_once_with(2, 1)
mock.reset_mock()

# Baz2
baz2.a = 1
baz2.aa = 2
mock.assert_not_called()
baz2.a = 2
mock.assert_called_once_with(2, 1)
mock.reset_mock()
baz2.aa = 3
mock.assert_called_once_with(3, 2)
mock.reset_mock()
baz2.b = "b"
mock.assert_not_called()
baz2.b = "c"
mock.assert_not_called()
baz2.bb = "bb"
mock.assert_not_called()
baz2.bb = "bbb"
mock.assert_called_once_with("bbb", "bb")
mock.reset_mock()

# pydantic v1 does not support properties
if type_ == "pydantic_v1":
return
Expand All @@ -211,8 +297,7 @@ def b(self, value: str):
mock.assert_called_once_with("c", "b")


@pytest.mark.parametrize("collect", [False, True])
def test_direct_signal_group(collect) -> None:
def test_direct_signal_group() -> None:

class FooSignalGroup(SignalGroup, signal_aliases={"e": None}):
a = Signal(int, int)
Expand All @@ -224,19 +309,33 @@ class FooSignalGroup(SignalGroup, signal_aliases={"e": None}):
class Foo:
events: ClassVar = SignalGroupDescriptor(
signal_group_class=FooSignalGroup,
collect_fields=collect,
signal_aliases={"b": "b_changed", "c": None, "_c": "c"},
collect_fields=False,
signal_aliases={
"b": "b_changed",
"c": None,
"_c": "c",
"_e": "e",
},
)
a: int
b: float
_c: str
_d: str

def __init__(self, a: int = 1, b: float = 2.0, c: str = "c", d: str = "d"):
_e: int

def __init__(
self,
a: int = 1,
b: float = 2.0,
c: str = "c",
d: str = "d",
_e: int = 5,
):
self.a = a
self.b = b
self.c = c
self.d = d
self._e = _e

@property
def c(self) -> str:
Expand All @@ -262,15 +361,16 @@ def d(self, value: str):
'_c': 'c',
'c': None,
'e': None,
'_e': 'e',
}

mock = Mock()
foo.events.a.connect(mock)
foo.events.b_changed.connect(mock)
foo.events.c.connect(mock)
foo.events.d.connect(mock)

foo.events.e.connect(mock)

foo.events.e.emit("f", "e")
mock.assert_called_once_with("f", "e")
mock.reset_mock()
Expand All @@ -297,3 +397,9 @@ def d(self, value: str):
foo.d = "DD"
mock.assert_called_once_with("dd", "d")
mock.reset_mock()

foo._e = 5
mock.assert_not_called()
foo._e = 6
mock.assert_called_once_with(6, 5)
mock.reset_mock()

0 comments on commit b05eefd

Please sign in to comment.