diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 5ff9dfd8..21ed61f3 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -32,14 +32,22 @@ """ from collections.abc import Mapping, Sequence -from dataclasses import Field, fields, is_dataclass -from typing import Union, get_args, get_origin +from dataclasses import fields, is_dataclass +from typing import NamedTuple, Union, get_args, get_origin from arraycontext.container import is_array_container_type # {{{ dataclass containers +class _Field(NamedTuple): + """Small lookalike for :class:`dataclasses.Field`.""" + + init: bool + name: str + type: type + + def is_array_type(tp: type) -> bool: from arraycontext import Array return tp is Array or is_array_container_type(tp) @@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type: assert is_dataclass(cls) - def is_array_field(f: Field, field_type: type) -> bool: + def is_array_field(f: _Field) -> bool: + field_type = f.type + # NOTE: unions of array containers are treated separately to handle # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as # they can work seamlessly with arithmetic and traversal. @@ -96,10 +106,8 @@ def is_array_field(f: Field, field_type: type) -> bool: f"Field '{f.name}' union contains non-array container " "arguments. All arguments must be array containers.") - if isinstance(field_type, str): - raise TypeError( - f"String annotation on field '{f.name}' not supported. " - "(this may be due to 'from __future__ import annotations')") + # NOTE: this should never happen due to using `inspect.get_annotations` + assert not isinstance(field_type, str) if __debug__: if not f.init: @@ -127,36 +135,52 @@ def is_array_field(f: Field, field_type: type) -> bool: return is_array_type(field_type) + from pytools import partition + + array_fields = _get_annotated_fields(cls) + array_fields, non_array_fields = partition(is_array_field, array_fields) + + if not array_fields: + raise ValueError(f"'{cls}' must have fields with array container type " + "in order to use the 'dataclass_array_container' decorator") + + return _inject_dataclass_serialization(cls, array_fields, non_array_fields) + + +def _get_annotated_fields(cls: type) -> Sequence[_Field]: + """Get a list of fields in the class *cls* with evaluated types. + + If any of the fields in *cls* have type annotations that are strings, e.g. + from using ``from __future__ import annotations``, this function evaluates + them using :func:`inspect.get_annotations`. Note that this requires the class + to live in a module that is importable. + + :return: a list of fields. + """ + from inspect import get_annotations - array_fields: list[Field] = [] - non_array_fields: list[Field] = [] + result = [] cls_ann: Mapping[str, type] | None = None for field in fields(cls): field_type_or_str = field.type if isinstance(field_type_or_str, str): if cls_ann is None: cls_ann = get_annotations(cls, eval_str=True) + field_type = cls_ann[field.name] else: field_type = field_type_or_str - if is_array_field(field, field_type): - array_fields.append(field) - else: - non_array_fields.append(field) - - if not array_fields: - raise ValueError(f"'{cls}' must have fields with array container type " - "in order to use the 'dataclass_array_container' decorator") + result.append(_Field(init=field.init, name=field.name, type=field_type)) - return _inject_dataclass_serialization(cls, array_fields, non_array_fields) + return result def _inject_dataclass_serialization( cls: type, - array_fields: Sequence[Field], - non_array_fields: Sequence[Field]) -> type: + array_fields: Sequence[_Field], + non_array_fields: Sequence[_Field]) -> type: """Implements :func:`~arraycontext.serialize_container` and :func:`~arraycontext.deserialize_container` for the given dataclass *cls*.