Skip to content

Commit

Permalink
Return actual subclasses instead of generic aliases when concretizing…
Browse files Browse the repository at this point in the history
… and other bug fixes
  • Loading branch information
egparedes committed Jan 19, 2024
1 parent bcc4f0e commit d34836f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 95 deletions.
57 changes: 14 additions & 43 deletions src/gt4py/eve/datamodels/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,17 +883,6 @@ def _substitute_typevars(
return type_params_map[type_hint], True
elif getattr(type_hint, "__parameters__", []):
return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True
# TODO(egparedes): WIP fix for partial specialization
# # Type hint is a generic model: replace all the concretized type vars
# noqa: e800 replaced = False
# noqa: e800 new_args = []
# noqa: e800 for tp in type_hint.__parameters__:
# noqa: e800 if tp in type_params_map:
# noqa: e800 new_args.append(type_params_map[tp])
# noqa: e800 replaced = True
# noqa: e800 else:
# noqa: e800 new_args.append(type_params_map[tp])
# noqa: e800 return type_hint[tuple(new_args)], replaced
else:
return type_hint, False

Expand Down Expand Up @@ -971,14 +960,6 @@ def __pretty__(
return __pretty__


def _is_concrete_data_model(
cls: Type, type_args: Tuple[Type]
) -> typing.TypeGuard[Type[DataModelT]]:
return hasattr(cls, "__bound_type_params__") and all(
a == b for a, b in zip(cls.__bound_type_params__.values(), type_args)
)


def _make_data_model_class_getitem() -> classmethod:
def __class_getitem__(
cls: Type[GenericDataModelT], args: Union[Type, Tuple[Type]]
Expand All @@ -988,24 +969,15 @@ def __class_getitem__(
See :class:`GenericDataModelAlias` for further information.
"""
type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,)
concrete_cls: Type[DataModelT] = (
cls if _is_concrete_data_model(cls, type_args) else concretize(cls, *type_args)
)
res = xtyping.StdGenericAliasType(concrete_cls, type_args)
if sys.version_info < (3, 9):
# in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias)
# does not copy all required `__dict__` entries, so do it manually
for k, v in concrete_cls.__dict__.items():
if k not in res.__dict__:
res.__dict__[k] = v
return res
concrete_cls: Type[DataModelT] = concretize(cls, *type_args)
return concrete_cls

return classmethod(__class_getitem__)


def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]:
# TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal.
#
# TODO(egparedes): if a "typing tree" structure is implemented, refactor this code
# as a tree traversal.
if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation):
assert not xtyping.get_args(type_annotation)
assert isinstance(type_annotation, type)
Expand Down Expand Up @@ -1326,11 +1298,7 @@ def _make_concrete_with_cache(
# Replace field definitions with the new actual types for generic fields
type_params_map = dict(zip(datamodel_cls.__parameters__, type_args))
model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR)
new_annotations = {
# TODO(egparedes): ?
# noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]",
# noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]",
}
new_annotations = {}

new_field_c_attrs = {}
for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items():
Expand Down Expand Up @@ -1358,18 +1326,21 @@ def _make_concrete_with_cache(

class_name = f"{datamodel_cls.__name__}__{'_'.join(arg_names)}"

bound_type_params = {
tp_var.__name__: type_params_map[tp_var] for tp_var in datamodel_cls.__parameters__
}

namespace = {
"__annotations__": new_annotations,
"__module__": module if module else datamodel_cls.__module__,
"__bound_type_params__": bound_type_params, # TODO(havogt) is this useful information?
**new_field_c_attrs,
}

concrete_cls = type(class_name, (datamodel_cls,), namespace)

# Update the tuple of generic parameters in the new class, in case
# this is a partial concretization
assert hasattr(concrete_cls, "__parameters__")
concrete_cls.__parameters__ = tuple(
type_params_map[tp_var]
for tp_var in datamodel_cls.__parameters__
if isinstance(type_params_map[tp_var], typing.TypeVar)
)
assert concrete_cls.__module__ == module or not module

if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__:
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/eve/extended_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,11 @@ def get_partial_type_hints(
resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras`
obj, globalns=globalns, localns=localns, include_extras=include_extras
)
hints.update(resolved_hints)
hints[name] = resolved_hints[name]
except NameError as error:
if isinstance(hint, str):
# This conversion could be probably skipped after the fix applied in bpo-41370.
# Check: https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344
hints[name] = ForwardRef(hint)
elif isinstance(hint, (ForwardRef, _typing.ForwardRef)):
hints[name] = hint
Expand Down
117 changes: 66 additions & 51 deletions tests/eve_tests/unit_tests/test_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import enum
import numbers
import types
import typing
from typing import Set # noqa: F401 # imported but unused (used in exec() context)
Expand Down Expand Up @@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]):
with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"):
PartialGenericModel__int(value=["1"])

def test_partial_specialization(self):
class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]):
def test_partial_concretization(self):
class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]):
value: List[Tuple[T, U]]

PartialGenericModel(value=[])
PartialGenericModel(value=[("value", 3)])
PartialGenericModel(value=[(1, "value")])
PartialGenericModel(value=[(-1.0, "value")])
with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
PartialGenericModel(value=1)
with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
PartialGenericModel(value=(1, 2))
with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
PartialGenericModel(value=[()])
with pytest.raises(TypeError, match="'PartialGenericModel.value'"):
PartialGenericModel(value=[(1,)])
BaseGenericModel(value=[])
BaseGenericModel(value=[("value", 3)])
BaseGenericModel(value=[(1, "value")])
BaseGenericModel(value=[(-1.0, "value")])
with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
BaseGenericModel(value=1)
with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
BaseGenericModel(value=(1, 2))
with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
BaseGenericModel(value=[()])
with pytest.raises(TypeError, match="'BaseGenericModel.value'"):
BaseGenericModel(value=[(1,)])

assert len(BaseGenericModel.__parameters__) == 2

PartiallyConcretizedGenericModel = BaseGenericModel[int, U]

assert len(PartiallyConcretizedGenericModel.__parameters__) == 1

PartiallyConcretizedGenericModel(value=[])
PartiallyConcretizedGenericModel(value=[(1, 2)])
PartiallyConcretizedGenericModel(value=[(1, "value")])
PartiallyConcretizedGenericModel(value=[(1, (11, 12))])
with pytest.raises(TypeError, match=".value'"):
PartiallyConcretizedGenericModel(value=1)
with pytest.raises(TypeError, match=".value'"):
PartiallyConcretizedGenericModel(value=(1, 2))
with pytest.raises(TypeError, match=".value'"):
PartiallyConcretizedGenericModel(value=[1.0])
with pytest.raises(TypeError, match=".value'"):
PartiallyConcretizedGenericModel(value=["1"])

print(f"{PartialGenericModel.__parameters__=}")
print(f"{hasattr(PartialGenericModel ,'__args__')=}")
FullySpecializedGenericModel = PartiallyConcretizedGenericModel[str]

PartiallySpecializedGenericModel = PartialGenericModel[int, U]
print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}")
print(f"{PartiallySpecializedGenericModel.__parameters__=}")
print(f"{PartiallySpecializedGenericModel.__args__=}")
assert len(FullySpecializedGenericModel.__parameters__) == 0

PartiallySpecializedGenericModel(value=[])
PartiallySpecializedGenericModel(value=[(1, 2)])
PartiallySpecializedGenericModel(value=[(1, "value")])
PartiallySpecializedGenericModel(value=[(1, (11, 12))])
FullySpecializedGenericModel(value=[])
FullySpecializedGenericModel(value=[(1, "value")])
with pytest.raises(TypeError, match=".value'"):
FullySpecializedGenericModel(value=1)
with pytest.raises(TypeError, match=".value'"):
FullySpecializedGenericModel(value=(1, 2))
with pytest.raises(TypeError, match=".value'"):
PartiallySpecializedGenericModel(value=1)
FullySpecializedGenericModel(value=[1.0])
with pytest.raises(TypeError, match=".value'"):
PartiallySpecializedGenericModel(value=(1, 2))
FullySpecializedGenericModel(value=["1"])
with pytest.raises(TypeError, match=".value'"):
PartiallySpecializedGenericModel(value=[1.0])
FullySpecializedGenericModel(value=1)
with pytest.raises(TypeError, match=".value'"):
PartiallySpecializedGenericModel(value=["1"])

# TODO(egparedes): after fixing partial nested datamodel specialization
# noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str]
# noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}")
# noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}")
# noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}")

# noqa: e800 FullySpecializedGenericModel(value=[])
# noqa: e800 FullySpecializedGenericModel(value=[(1, "value")])
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=1)
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=(1, 2))
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=[1.0])
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=["1"])
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=1)
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=[(1, 2)])
# noqa: e800 with pytest.raises(TypeError, match=".value'"):
# noqa: e800 FullySpecializedGenericModel(value=[(1, (11, 12))])
FullySpecializedGenericModel(value=[(1, 2)])
with pytest.raises(TypeError, match=".value'"):
FullySpecializedGenericModel(value=[(1, (11, 12))])

def test_partial_concretization_with_typevar(self):
class PartialGenericModel(datamodels.GenericDataModel, Generic[T]):
a: T
values: List[T]

B = TypeVar("B", bound=numbers.Number)
PartiallyConcretizedGenericModel = PartialGenericModel[B]

PartiallyConcretizedGenericModel(a=1, values=[2, 3])
PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j])

with pytest.raises(TypeError, match=".a'"):
PartiallyConcretizedGenericModel(a="1", values=[2, 3])
with pytest.raises(TypeError, match=".values'"):
PartiallyConcretizedGenericModel(a=1, values=[1, "2"])
with pytest.raises(TypeError, match=".values'"):
PartiallyConcretizedGenericModel(a=1, values=(1, 2))

# Reuse sample_type_data from test_field_type_hint
@pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA)
Expand Down

0 comments on commit d34836f

Please sign in to comment.