From d34836fccff78034f33a8b6bbea21f32a150aca4 Mon Sep 17 00:00:00 2001 From: Enrique Gonzalez Paredes Date: Fri, 19 Jan 2024 11:02:36 +0100 Subject: [PATCH] Return actual subclasses instead of generic aliases when concretizing and other bug fixes --- src/gt4py/eve/datamodels/core.py | 57 +++------ src/gt4py/eve/extended_typing.py | 4 +- tests/eve_tests/unit_tests/test_datamodels.py | 117 ++++++++++-------- 3 files changed, 83 insertions(+), 95 deletions(-) diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 5532eb4c95..8f1b9e6554 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -883,17 +883,6 @@ def _substitute_typevars( return type_params_map[type_hint], True elif getattr(type_hint, "__parameters__", []): return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True - # TODO(egparedes): WIP fix for partial specialization - # # Type hint is a generic model: replace all the concretized type vars - # noqa: e800 replaced = False - # noqa: e800 new_args = [] - # noqa: e800 for tp in type_hint.__parameters__: - # noqa: e800 if tp in type_params_map: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 replaced = True - # noqa: e800 else: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 return type_hint[tuple(new_args)], replaced else: return type_hint, False @@ -971,14 +960,6 @@ def __pretty__( return __pretty__ -def _is_concrete_data_model( - cls: Type, type_args: Tuple[Type] -) -> typing.TypeGuard[Type[DataModelT]]: - return hasattr(cls, "__bound_type_params__") and all( - a == b for a, b in zip(cls.__bound_type_params__.values(), type_args) - ) - - def _make_data_model_class_getitem() -> classmethod: def __class_getitem__( cls: Type[GenericDataModelT], args: Union[Type, Tuple[Type]] @@ -988,24 +969,15 @@ def __class_getitem__( See :class:`GenericDataModelAlias` for further information. """ type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,) - concrete_cls: Type[DataModelT] = ( - cls if _is_concrete_data_model(cls, type_args) else concretize(cls, *type_args) - ) - res = xtyping.StdGenericAliasType(concrete_cls, type_args) - if sys.version_info < (3, 9): - # in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias) - # does not copy all required `__dict__` entries, so do it manually - for k, v in concrete_cls.__dict__.items(): - if k not in res.__dict__: - res.__dict__[k] = v - return res + concrete_cls: Type[DataModelT] = concretize(cls, *type_args) + return concrete_cls return classmethod(__class_getitem__) def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]: - # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal. - # + # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code + # as a tree traversal. if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation): assert not xtyping.get_args(type_annotation) assert isinstance(type_annotation, type) @@ -1326,11 +1298,7 @@ def _make_concrete_with_cache( # Replace field definitions with the new actual types for generic fields type_params_map = dict(zip(datamodel_cls.__parameters__, type_args)) model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR) - new_annotations = { - # TODO(egparedes): ? - # noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]", - # noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]", - } + new_annotations = {} new_field_c_attrs = {} for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items(): @@ -1358,18 +1326,21 @@ def _make_concrete_with_cache( class_name = f"{datamodel_cls.__name__}__{'_'.join(arg_names)}" - bound_type_params = { - tp_var.__name__: type_params_map[tp_var] for tp_var in datamodel_cls.__parameters__ - } - namespace = { "__annotations__": new_annotations, "__module__": module if module else datamodel_cls.__module__, - "__bound_type_params__": bound_type_params, # TODO(havogt) is this useful information? **new_field_c_attrs, } - concrete_cls = type(class_name, (datamodel_cls,), namespace) + + # Update the tuple of generic parameters in the new class, in case + # this is a partial concretization + assert hasattr(concrete_cls, "__parameters__") + concrete_cls.__parameters__ = tuple( + type_params_map[tp_var] + for tp_var in datamodel_cls.__parameters__ + if isinstance(type_params_map[tp_var], typing.TypeVar) + ) assert concrete_cls.__module__ == module or not module if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 4697fb7e8f..d0be9660bd 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -643,9 +643,11 @@ def get_partial_type_hints( resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras` obj, globalns=globalns, localns=localns, include_extras=include_extras ) - hints.update(resolved_hints) + hints[name] = resolved_hints[name] except NameError as error: if isinstance(hint, str): + # This conversion could be probably skipped after the fix applied in bpo-41370. + # Check: https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344 hints[name] = ForwardRef(hint) elif isinstance(hint, (ForwardRef, _typing.ForwardRef)): hints[name] = hint diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 8fa9e02cb6..b39e30bbde 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -15,6 +15,7 @@ from __future__ import annotations import enum +import numbers import types import typing from typing import Set # noqa: F401 # imported but unused (used in exec() context) @@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"): PartialGenericModel__int(value=["1"]) - def test_partial_specialization(self): - class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]): + def test_partial_concretization(self): + class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]): value: List[Tuple[T, U]] - PartialGenericModel(value=[]) - PartialGenericModel(value=[("value", 3)]) - PartialGenericModel(value=[(1, "value")]) - PartialGenericModel(value=[(-1.0, "value")]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=1) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=(1, 2)) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[()]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[(1,)]) + BaseGenericModel(value=[]) + BaseGenericModel(value=[("value", 3)]) + BaseGenericModel(value=[(1, "value")]) + BaseGenericModel(value=[(-1.0, "value")]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=1) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[()]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[(1,)]) + + assert len(BaseGenericModel.__parameters__) == 2 + + PartiallyConcretizedGenericModel = BaseGenericModel[int, U] + + assert len(PartiallyConcretizedGenericModel.__parameters__) == 1 + + PartiallyConcretizedGenericModel(value=[]) + PartiallyConcretizedGenericModel(value=[(1, 2)]) + PartiallyConcretizedGenericModel(value=[(1, "value")]) + PartiallyConcretizedGenericModel(value=[(1, (11, 12))]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=[1.0]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=["1"]) - print(f"{PartialGenericModel.__parameters__=}") - print(f"{hasattr(PartialGenericModel ,'__args__')=}") + FullySpecializedGenericModel = PartiallyConcretizedGenericModel[str] - PartiallySpecializedGenericModel = PartialGenericModel[int, U] - print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}") - print(f"{PartiallySpecializedGenericModel.__parameters__=}") - print(f"{PartiallySpecializedGenericModel.__args__=}") + assert len(FullySpecializedGenericModel.__parameters__) == 0 - PartiallySpecializedGenericModel(value=[]) - PartiallySpecializedGenericModel(value=[(1, 2)]) - PartiallySpecializedGenericModel(value=[(1, "value")]) - PartiallySpecializedGenericModel(value=[(1, (11, 12))]) + FullySpecializedGenericModel(value=[]) + FullySpecializedGenericModel(value=[(1, "value")]) + with pytest.raises(TypeError, match=".value'"): + FullySpecializedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + FullySpecializedGenericModel(value=(1, 2)) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=1) + FullySpecializedGenericModel(value=[1.0]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=(1, 2)) + FullySpecializedGenericModel(value=["1"]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=[1.0]) + FullySpecializedGenericModel(value=1) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=["1"]) - - # TODO(egparedes): after fixing partial nested datamodel specialization - # noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str] - # noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}") - - # noqa: e800 FullySpecializedGenericModel(value=[]) - # noqa: e800 FullySpecializedGenericModel(value=[(1, "value")]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=(1, 2)) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[1.0]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=["1"]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, 2)]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, (11, 12))]) + FullySpecializedGenericModel(value=[(1, 2)]) + with pytest.raises(TypeError, match=".value'"): + FullySpecializedGenericModel(value=[(1, (11, 12))]) + + def test_partial_concretization_with_typevar(self): + class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): + a: T + values: List[T] + + B = TypeVar("B", bound=numbers.Number) + PartiallyConcretizedGenericModel = PartialGenericModel[B] + + PartiallyConcretizedGenericModel(a=1, values=[2, 3]) + PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j]) + + with pytest.raises(TypeError, match=".a'"): + PartiallyConcretizedGenericModel(a="1", values=[2, 3]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=[1, "2"]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=(1, 2)) # Reuse sample_type_data from test_field_type_hint @pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA)