Skip to content

Commit

Permalink
add options for a signal aliases
Browse files Browse the repository at this point in the history
adapt to SignalRelay
  • Loading branch information
getzze committed Mar 22, 2024
1 parent f6cebcb commit f5d4fa4
Show file tree
Hide file tree
Showing 5 changed files with 496 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/psygnal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"EventedModel",
"get_evented_namespace",
"is_evented",
"PSYGNAL_METADATA",
"Signal",
"SignalGroup",
"SignalGroupDescriptor",
Expand All @@ -48,6 +49,7 @@
stacklevel=2,
)

from ._dataclass_utils import PSYGNAL_METADATA
from ._evented_decorator import evented
from ._exceptions import EmitLoopError
from ._group import EmissionInfo, SignalGroup
Expand Down
271 changes: 263 additions & 8 deletions src/psygnal/_dataclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,30 @@
import dataclasses
import sys
import types
from typing import TYPE_CHECKING, Any, Iterator, List, Protocol, cast, overload
from dataclasses import dataclass, fields
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterator,
List,
Mapping,
Protocol,
cast,
overload,
)

if TYPE_CHECKING:
from dataclasses import Field

import attrs
import msgspec
from pydantic import BaseModel
from typing_extensions import TypeGuard # py310
from typing_extensions import TypeAlias, TypeGuard # py310

EqOperator: TypeAlias = Callable[[Any, Any], bool]

PSYGNAL_METADATA = "__psygnal_metadata"


class _DataclassParams(Protocol):
Expand All @@ -29,12 +46,11 @@ class AttrsType:
__attrs_attrs__: tuple[attrs.Attribute, ...]


_DATACLASS_PARAMS = "__dataclass_params__"
KW_ONLY = object()
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_PARAMS # type: ignore
from dataclasses import KW_ONLY # py310
_DATACLASS_PARAMS = "__dataclass_params__"
_DATACLASS_FIELDS = "__dataclass_fields__"
with contextlib.suppress(ImportError):
from dataclasses import _DATACLASS_FIELDS # type: ignore


class DataClassType:
Expand Down Expand Up @@ -171,8 +187,8 @@ def iter_fields(
yield field_name, p_field.annotation
else:
for p_field in cls.__fields__.values(): # type: ignore [attr-defined]
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore
yield p_field.name, p_field.outer_type_ # type: ignore
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore [attr-defined]
yield p_field.name, p_field.outer_type_ # type: ignore [attr-defined]
return

if (attrs_fields := getattr(cls, "__attrs_attrs__", None)) is not None:
Expand All @@ -185,3 +201,242 @@ def iter_fields(
type_ = cls.__annotations__.get(m_field, None)
yield m_field, type_
return


@dataclass
class FieldOptions:
name: str
type_: type | None = None
# set KW_ONLY value for compatibility with python < 3.10
_: KW_ONLY = KW_ONLY # type: ignore [valid-type]
alias: str | None = None
skip: bool | None = None
eq: EqOperator | None = None
disable_setattr: bool | None = None


def is_kw_only(f: Field) -> bool:
if hasattr(f, "kw_only"):
return cast(bool, f.kw_only)
# for python < 3.10
if f.name not in ["name", "type_"]:
return True
return False


def sanitize_field_options_dict(d: Mapping) -> dict[str, Any]:
field_options_kws = [f.name for f in fields(FieldOptions) if is_kw_only(f)]
return {k: v for k, v in d.items() if k in field_options_kws}


def get_msgspec_metadata(
cls: type[msgspec.Struct],
m_field: str,
) -> tuple[type | None, dict[str, Any]]:
# Look for type in cls and super classes
type_: type | None = None
for super_cls in cls.__mro__:
if not hasattr(super_cls, "__annotations__"):
continue

Check warning on line 240 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L240

Added line #L240 was not covered by tests
type_ = super_cls.__annotations__.get(m_field, None)
if type_ is not None:
break

msgspec = sys.modules.get("msgspec", None)
if msgspec is None:
return type_, {}

Check warning on line 247 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L247

Added line #L247 was not covered by tests

metadata_list = getattr(type_, "__metadata__", [])

metadata: dict[str, Any] = {}
for meta in metadata_list:
if not isinstance(meta, msgspec.Meta):
continue

Check warning on line 254 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L254

Added line #L254 was not covered by tests
single_meta: dict[str, Any] = getattr(meta, "extra", {}).get(
PSYGNAL_METADATA, {}
)
metadata.update(single_meta)

return type_, metadata


def iter_fields_with_options(
cls: type, exclude_frozen: bool = True
) -> Iterator[FieldOptions]:
"""Iterate over all fields in the class, return a field description.
This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic
models.
Parameters
----------
cls : type
The class to iterate over.
exclude_frozen : bool, optional
If True, frozen fields will be excluded. By default True.
Yields
------
FieldOptions
A dataclass instance with the name, type and metadata of each field.
"""
# Add metadata for dataclasses.dataclass
dclass_fields = getattr(cls, "__dataclass_fields__", None)
if dclass_fields is not None:
"""
Example
-------
from dataclasses import dataclass, field
@dataclass
class Foo:
bar: int = field(metadata={"alias": "bar_alias"})
assert (
Foo.__dataclass_fields__["bar"].metadata ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)
"""
for d_field in dclass_fields.values():
if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined]
metadata = getattr(d_field, "metadata", {}).get(PSYGNAL_METADATA, {})
metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(d_field.name, d_field.type, **metadata)
yield options
return

# Add metadata for pydantic dataclass
if is_pydantic_model(cls):
"""
Example
-------
from typing import Annotated
from pydantic import BaseModel, Field
# Only works with Pydantic v2
class Foo(BaseModel):
bar: Annotated[
str,
{'__psygnal_metadata': {"alias": "bar_alias"}}
] = Field(...)
# Working with Pydantic v2 and partially with v1
# Alternative, using Field `json_schema_extra` keyword argument
class Bar(BaseModel):
bar: str = Field(
json_schema_extra={PSYGNAL_METADATA: {"alias": "bar_alias"}}
)
assert (
Foo.model_fields["bar"].metadata[0] ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)
assert (
Bar.model_fields["bar"].json_schema_extra ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)
"""
if hasattr(cls, "model_fields"):
# Pydantic v2
for field_name, p_field in cls.model_fields.items():
# skip frozen field
if exclude_frozen and p_field.frozen:
continue

Check warning on line 350 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L350

Added line #L350 was not covered by tests
metadata_list = getattr(p_field, "metadata", [])
metadata = {}
for field in metadata_list:
metadata.update(field.get(PSYGNAL_METADATA, {}))
# Compat with using Field `json_schema_extra` keyword argument
if isinstance(getattr(p_field, "json_schema_extra", None), Mapping):
meta_dict = cast(Mapping, p_field.json_schema_extra)
metadata.update(meta_dict.get(PSYGNAL_METADATA, {}))
metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(field_name, p_field.annotation, **metadata)
yield options
return

else:
# Pydantic v1, metadata is not always working
for pv1_field in cls.__fields__.values(): # type: ignore [attr-defined]
# skip frozen field
if exclude_frozen and not pv1_field.field_info.allow_mutation:
continue

Check warning on line 369 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L369

Added line #L369 was not covered by tests
meta_dict = getattr(pv1_field.field_info, "extra", {}).get(
"json_schema_extra", {}
)
metadata = meta_dict.get(PSYGNAL_METADATA, {})

metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(
pv1_field.name,
pv1_field.outer_type_,
**metadata,
)
yield options
return

# Add metadata for attrs dataclass
attrs_fields = getattr(cls, "__attrs_attrs__", None)
if attrs_fields is not None:
"""
Example
-------
from attrs import define, field
@define
class Foo:
bar: int = field(metadata={"alias": "bar_alias"})
assert (
Foo.__attrs_attrs__.bar.metadata ==
{"__psygnal_metadata": {"alias": "bar_alias"}}
)
"""
for a_field in attrs_fields:
metadata = getattr(a_field, "metadata", {}).get(PSYGNAL_METADATA, {})
metadata = sanitize_field_options_dict(metadata)
options = FieldOptions(a_field.name, a_field.type, **metadata)
yield options
return

# Add metadata for attrs dataclass
if is_msgspec_struct(cls):
"""
Example
-------
from typing import Annotated
from msgspec import Meta, Struct
class Foo(Struct):
bar: Annotated[
str,
Meta(extra={"__psygnal_metadata": {"alias": "bar_alias"}))
] = ""
print(Foo.__annotations__["bar"].__metadata__[0].extra)
# {"__psygnal_metadata": {"alias": "bar_alias"}}
"""
for m_field in cls.__struct_fields__:
try:
type_, metadata = get_msgspec_metadata(cls, m_field)
metadata = sanitize_field_options_dict(metadata)
except AttributeError:
msg = f"Cannot parse field metadata for {m_field}: {type_}"

Check warning on line 436 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L435-L436

Added lines #L435 - L436 were not covered by tests
# logger.exception(msg)
print(msg)
type_, metadata = None, {}

Check warning on line 439 in src/psygnal/_dataclass_utils.py

View check run for this annotation

Codecov / codecov/patch

src/psygnal/_dataclass_utils.py#L438-L439

Added lines #L438 - L439 were not covered by tests
options = FieldOptions(m_field, type_, **metadata)
yield options
return
5 changes: 4 additions & 1 deletion src/psygnal/_evented_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from psygnal._group_descriptor import SignalGroupDescriptor

if TYPE_CHECKING:
from psygnal._group_descriptor import EqOperator, FieldAliasFunc
from psygnal._group_descriptor import ( # type: ignore[attr-defined]
EqOperator,
FieldAliasFunc,
)

__all__ = ["evented"]

Expand Down
Loading

0 comments on commit f5d4fa4

Please sign in to comment.