From c95d9c6b4ddee1d7874cbec8282bd36a83075e0b Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Tue, 3 Nov 2020 17:38:26 +0100 Subject: [PATCH 01/17] Change Node concept --- setup.cfg | 2 +- src/eve/concepts.py | 123 ++++++++++++++++++++++++++------------------ src/eve/visitors.py | 4 +- 3 files changed, 75 insertions(+), 54 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8d13861..f00b3ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = numpy>=1.17 packaging>=20.0 pybind11>=2.5 - pydantic>=1.5 + pydantic>=1.7.2 typing_extensions>=3.4 typing_inspect>=0.6.0 xxhash>=1.4.4 diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 811c09c..74c384e 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -21,8 +21,10 @@ import collections.abc import functools +import types import pydantic +import pydantic.generics from . import type_definitions, utils from .type_definitions import NOTHING, IntEnum, Str, StrEnum @@ -37,7 +39,6 @@ Optional, Set, Tuple, - Type, TypedDict, TypeVar, Union, @@ -46,13 +47,6 @@ # -- Fields -- -class ImplFieldMetadataDict(TypedDict, total=False): - info: pydantic.fields.FieldInfo - - -NodeImplFieldMetadataDict = Dict[str, ImplFieldMetadataDict] - - class FieldKind(StrEnum): INPUT = "input" OUTPUT = "output" @@ -106,6 +100,8 @@ def field( # -- Models -- class BaseModelConfig: extra = "forbid" + underscore_attrs_are_private = True + # TODO(egparedes): class attributes with '_attrs_' substring seems to break sphinx-autodoc class FrozenModelConfig(BaseModelConfig): @@ -124,7 +120,7 @@ class Config(FrozenModelConfig): # -- Nodes -- _EVE_NODE_INTERNAL_SUFFIX = "__" -_EVE_NODE_IMPL_SUFFIX = "_" +_EVE_NODE_ANNOTATION_SUFFIX = "_" AnyNode = TypeVar("AnyNode", bound="BaseNode") ValueNode = Union[bool, bytes, int, float, str, IntEnum, StrEnum] @@ -146,18 +142,14 @@ def __new__(mcls, name, bases, namespace, **kwargs): # Postprocess created class: # Add metadata class members - impl_fields_metadata = {} children_metadata = {} for name, model_field in cls.__fields__.items(): - if name.endswith(_EVE_NODE_IMPL_SUFFIX): - impl_fields_metadata[name] = {"definition": model_field} - elif not name.endswith(_EVE_NODE_INTERNAL_SUFFIX): + if not name.endswith(_EVE_NODE_INTERNAL_SUFFIX): children_metadata[name] = { "definition": model_field, **model_field.field_info.extra.get(_EVE_METADATA_KEY, {}), } - cls.__node_impl_fields__ = impl_fields_metadata cls.__node_children__ = children_metadata return cls @@ -166,7 +158,7 @@ def __new__(mcls, name, bases, namespace, **kwargs): class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): """Base class representing an IR node. - It is currently implemented as a pydantic Model with some extra features. + A node is currently implemented as a pydantic Model with some extra features. Field values should be either: @@ -174,60 +166,89 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): * enum.Enum types * other :class:`Node` subclasses * other :class:`pydantic.BaseModel` subclasses - * supported collections (:class:`List`, :class:`Dict`, :class:`Set`) - of any of the previous items - - Field naming scheme: - - * Field names starting with "_" are ignored by pydantic and Eve. They - will not be considered as `fields` and thus none of the pydantic - features will work (type coercion, validators, etc.). - * Field names ending with "__" are reserved for internal Eve use and - should NOT be defined by regular users. All pydantic features will - work on these fields anyway but they will be invisible for Eve users. - * Field names ending with "_" are considered implementation fields - not children nodes. They are intended to be defined by users when needed, - typically to cache derived, non-essential information on the node. + * supported collections (:class:`Tuple`, :class:`List`, :class:`Dict`, :class:`Set`) + of any of the previous items + + Class members naming scheme: + + * Member names starting with ``_`` are transformed into `private attributes` + by pydantic and thus ignored by Eve. Since none of the pydantic features + will work on them (type coercion, validators, etc.), it is not recommended + to define new pydantic private attributes in the nodes. + * Member names ending with ``_`` are considered node data annotations, not children + nodes. They are intended to be used by the user, typically to cache derived, + non-essential information on the node, and they can be assigned directly + without a explicit definition in the class body (which will consequently + trigger an error). + * Member names ending with ``__`` are reserved for internal Eve use and + should NOT be defined by regular users. All pydantic features will + work on these fields anyway but they will be invisible in Eve nodes. """ - __node_impl_fields__: ClassVar[NodeImplFieldMetadataDict] __node_children__: ClassVar[NodeChildrenMetadataDict] - # Node fields - #: Unique node-id (implementation field) - id_: Optional[Str] = None + # Node private attributes + #: Unique node-id + __node_id__: Optional[str] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr + default_factory=utils.UIDGenerator.sequential_id + ) - @pydantic.validator("id_", pre=True, always=True) - def _id_validator(cls: Type[AnyNode], v: Optional[str]) -> str: # type: ignore # validators are classmethods - if v is None: - v = utils.UIDGenerator.sequential_id(prefix=cls.__qualname__) - if not isinstance(v, str): - raise TypeError(f"id_ is not an 'str' instance ({type(v)})") - return v - - def iter_impl_fields(self) -> Generator[Tuple[str, Any], None, None]: - for name, _ in self.__fields__.items(): - if name.endswith(_EVE_NODE_IMPL_SUFFIX) and not name.endswith( - _EVE_NODE_INTERNAL_SUFFIX - ): - yield name, getattr(self, name) + #: Node analysis annotations + __node_annotations__: Optional[Str] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr + default_factory=types.SimpleNamespace + ) def iter_children(self) -> Generator[Tuple[str, Any], None, None]: for name, _ in self.__fields__.items(): - if not ( - name.endswith(_EVE_NODE_IMPL_SUFFIX) or name.endswith(_EVE_NODE_INTERNAL_SUFFIX) - ): + if not name.endswith(_EVE_NODE_INTERNAL_SUFFIX): yield name, getattr(self, name) + def iter_children_names(self) -> Generator[str, None, None]: + for name, _ in self.iter_children(): + yield name + def iter_children_values(self) -> Generator[Any, None, None]: for _, node in self.iter_children(): yield node + @property + def data_annotations(self) -> Dict[str, Any]: + return self.__node_annotations__.__dict__ + + @property + def private_attrs_names(self) -> Tuple[str, ...]: + return self.__slots__ + + def __getattr__(self, name: str) -> Any: + return type(self)._get_attr_owner(self, name).__getattribute__(name) + + def __setattr__(self, name: str, value: Any) -> None: + type(self)._get_attr_owner(self, name).__setattr__(name, value) + + def __delattr__(self, name: str) -> None: + type(self)._get_attr_owner(self, name).__delattr__(name) + + @staticmethod + def _get_attr_owner(instance: BaseNode, name: str) -> Any: + attr_caller = super(BaseNode, instance) + if name.endswith(_EVE_NODE_ANNOTATION_SUFFIX) and not name.endswith( + _EVE_NODE_INTERNAL_SUFFIX + ): + attr_caller = attr_caller.__getattribute__("__node_annotations__") + + return attr_caller + class Config(BaseModelConfig): pass +class GenericNode(BaseNode, pydantic.generics.GenericModel): + """Base generic node class.""" + + pass + + class Node(BaseNode): """Default public name for a base node class.""" @@ -235,7 +256,7 @@ class Node(BaseNode): class FrozenNode(Node): - """Default public name for an inmutable base node class.""" + """Default public name for an immutable base node class.""" class Config(FrozenModelConfig): pass diff --git a/src/eve/visitors.py b/src/eve/visitors.py index dfd8ee7..066b81c 100644 --- a/src/eve/visitors.py +++ b/src/eve/visitors.py @@ -254,14 +254,14 @@ def del_op(container: MutableSet, idx: int) -> None: del_op = operator.delitem elif isinstance(node, (collections.abc.Sequence, collections.abc.Set)): - # Inmutable sequence or set: create a new container instance with the new values + # Immutable sequence or set: create a new container instance with the new values tmp_items = [self.visit(value, **kwargs) for value in node] result = node.__class__( # type: ignore [value for value in tmp_items if value is not concepts.NOTHING] ) elif isinstance(node, collections.abc.Mapping): - # Inmutable mapping: create a new mapping instance with the new values + # Immutable mapping: create a new mapping instance with the new values tmp_items = {key: self.visit(value, **kwargs) for key, value in node.items()} result = node.__class__( # type: ignore { From 6248dbfed9235e669dc468ef5d78b4951b0d0801 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 13:04:50 +0100 Subject: [PATCH 02/17] Fix test_concepts --- src/eve/concepts.py | 40 ++++--- tests/tests_eve/common.py | 67 ++++++----- tests/tests_eve/unit_tests/test_concepts.py | 116 +++++++++++--------- 3 files changed, 125 insertions(+), 98 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 74c384e..fd4e56b 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -101,7 +101,7 @@ def field( class BaseModelConfig: extra = "forbid" underscore_attrs_are_private = True - # TODO(egparedes): class attributes with '_attrs_' substring seems to break sphinx-autodoc + # TODO(egparedes): setting 'underscore_attrs_are_private' to True breaks sphinx-autodoc class FrozenModelConfig(BaseModelConfig): @@ -119,15 +119,28 @@ class Config(FrozenModelConfig): # -- Nodes -- -_EVE_NODE_INTERNAL_SUFFIX = "__" -_EVE_NODE_ANNOTATION_SUFFIX = "_" - AnyNode = TypeVar("AnyNode", bound="BaseNode") ValueNode = Union[bool, bytes, int, float, str, IntEnum, StrEnum] LeafNode = Union[AnyNode, ValueNode] TreeNode = Union[AnyNode, Union[List[LeafNode], Dict[Any, LeafNode], Set[LeafNode]]] +def _is_data_annotation_name(name: str) -> bool: + return name.endswith("_") and not name.startswith("_") + + +def _is_child_field_name(name: str) -> bool: + return not name.startswith("_") and not name.endswith("_") + + +def _is_internal_field_name(name: str) -> bool: + return name.endswith("__") and not name.startswith("_") + + +def _is_private_attr_name(name: str) -> bool: + return name.startswith("_") + + class NodeMetaclass(pydantic.main.ModelMetaclass): """Custom metaclass for Node classes. @@ -144,7 +157,7 @@ def __new__(mcls, name, bases, namespace, **kwargs): # Add metadata class members children_metadata = {} for name, model_field in cls.__fields__.items(): - if not name.endswith(_EVE_NODE_INTERNAL_SUFFIX): + if _is_child_field_name(name): children_metadata[name] = { "definition": model_field, **model_field.field_info.extra.get(_EVE_METADATA_KEY, {}), @@ -201,7 +214,7 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): def iter_children(self) -> Generator[Tuple[str, Any], None, None]: for name, _ in self.__fields__.items(): - if not name.endswith(_EVE_NODE_INTERNAL_SUFFIX): + if _is_child_field_name(name): yield name, getattr(self, name) def iter_children_names(self) -> Generator[str, None, None]: @@ -217,8 +230,8 @@ def data_annotations(self) -> Dict[str, Any]: return self.__node_annotations__.__dict__ @property - def private_attrs_names(self) -> Tuple[str, ...]: - return self.__slots__ + def private_attrs_names(self) -> Set[str]: + return set(self.__slots__) - {"__doc__"} def __getattr__(self, name: str) -> Any: return type(self)._get_attr_owner(self, name).__getattribute__(name) @@ -231,13 +244,10 @@ def __delattr__(self, name: str) -> None: @staticmethod def _get_attr_owner(instance: BaseNode, name: str) -> Any: - attr_caller = super(BaseNode, instance) - if name.endswith(_EVE_NODE_ANNOTATION_SUFFIX) and not name.endswith( - _EVE_NODE_INTERNAL_SUFFIX - ): - attr_caller = attr_caller.__getattribute__("__node_annotations__") - - return attr_caller + if _is_data_annotation_name(name): + return super(BaseNode, instance).__getattribute__("__node_annotations__") + else: + return super(BaseNode, instance) class Config(BaseModelConfig): pass diff --git a/tests/tests_eve/common.py b/tests/tests_eve/common.py index 9f8da4d..2d92b06 100644 --- a/tests/tests_eve/common.py +++ b/tests/tests_eve/common.py @@ -177,8 +177,8 @@ class LocationNode(Node): class SimpleNode(Node): - bool_value: Bool int_value: Int + bool_value: Bool float_value: Float str_value: Str bytes_value: Bytes @@ -187,17 +187,11 @@ class SimpleNode(Node): class SimpleNodeWithOptionals(Node): - int_value: Optional[Int] + int_value: Int float_value: Optional[Float] str_value: Optional[Str] -class SimpleNodeWithImplMembers(Node): - value_impl_: Int - int_value: Int - another_value_impl_: Int - - class SimpleNodeWithLoc(Node): int_value: Int float_value: Float @@ -206,6 +200,7 @@ class SimpleNodeWithLoc(Node): class SimpleNodeWithCollections(Node): + int_value: Int int_list: List[Int] str_set: Set[Str] str_to_int_dict: Dict[Str, Int] @@ -213,6 +208,7 @@ class SimpleNodeWithCollections(Node): class SimpleNodeWithAbstractCollections(Node): + int_value: Int int_sequence: Sequence[Int] str_set: Set[Str] str_to_int_mapping: Mapping[Str, Int] @@ -220,6 +216,7 @@ class SimpleNodeWithAbstractCollections(Node): class CompoundNode(Node): + int_value: Int location: LocationNode simple: SimpleNode simple_loc: SimpleNodeWithLoc @@ -228,8 +225,8 @@ class CompoundNode(Node): class FrozenSimpleNode(FrozenNode): - bool_value: Bool int_value: Int + bool_value: Bool float_value: Float str_value: Str bytes_value: Bytes @@ -258,8 +255,8 @@ def make_location_node(fixed: bool = False) -> LocationNode: def make_simple_node(fixed: bool = False) -> SimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_int() + bool_value = factories.make_bool() float_value = factories.make_float() str_value = factories.make_str() bytes_value = factories.make_str().encode() @@ -267,8 +264,8 @@ def make_simple_node(fixed: bool = False) -> SimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return SimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, @@ -285,15 +282,6 @@ def make_simple_node_with_optionals(fixed: bool = False) -> SimpleNodeWithOption return SimpleNodeWithOptionals(int_value=int_value, float_value=float_value) -def make_simple_node_with_impl_members(fixed: bool = False) -> SimpleNodeWithImplMembers: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - - return SimpleNodeWithImplMembers( - value_impl_=int_value, int_value=int_value, another_value_impl_=int_value - ) - - def make_simple_node_with_loc(fixed: bool = False) -> SimpleNodeWithLoc: factories = Factories if fixed else RandomFactories int_value = factories.make_int() @@ -308,13 +296,18 @@ def make_simple_node_with_loc(fixed: bool = False) -> SimpleNodeWithLoc: def make_simple_node_with_collections(fixed: bool = False) -> SimpleNodeWithCollections: factories = Factories if fixed else RandomFactories + int_value = factories.make_int() int_list = factories.make_collection(int, length=3) str_set = factories.make_collection(str, set, length=3) str_to_int_dict = factories.make_mapping(key_type=str, value_type=int, length=3) loc = make_source_location(fixed) return SimpleNodeWithCollections( - int_list=int_list, str_set=str_set, str_to_int_dict=str_to_int_dict, loc=loc + int_value=int_value, + int_list=int_list, + str_set=str_set, + str_to_int_dict=str_to_int_dict, + loc=loc, ) @@ -322,17 +315,23 @@ def make_simple_node_with_abstractcollections( fixed: bool = False, ) -> SimpleNodeWithAbstractCollections: factories = Factories if fixed else RandomFactories + int_value = factories.make_int() int_sequence = factories.make_collection(int, collection_type=tuple, length=3) str_set = factories.make_collection(str, set, length=3) str_to_int_mapping = factories.make_mapping(key_type=str, value_type=int, length=3) return SimpleNodeWithAbstractCollections( - int_sequence=int_sequence, str_set=str_set, str_to_int_mapping=str_to_int_mapping + int_value=int_value, + int_sequence=int_sequence, + str_set=str_set, + str_to_int_mapping=str_to_int_mapping, ) def make_compound_node(fixed: bool = False) -> CompoundNode: + factories = Factories if fixed else RandomFactories return CompoundNode( + int_value=factories.make_int(), location=make_location_node(), simple=make_simple_node(), simple_loc=make_simple_node_with_loc(), @@ -343,8 +342,8 @@ def make_compound_node(fixed: bool = False) -> CompoundNode: def make_frozen_simple_node(fixed: bool = False) -> FrozenSimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_int() + bool_value = factories.make_bool() float_value = factories.make_float() str_value = factories.make_str() bytes_value = factories.make_str().encode() @@ -352,8 +351,8 @@ def make_frozen_simple_node(fixed: bool = False) -> FrozenSimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return FrozenSimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, @@ -369,8 +368,8 @@ def make_invalid_location_node(fixed: bool = False) -> LocationNode: def make_invalid_at_int_simple_node(fixed: bool = False) -> SimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_float() + bool_value = factories.make_bool() float_value = factories.make_float() bytes_value = factories.make_str().encode() str_value = factories.make_str() @@ -378,8 +377,8 @@ def make_invalid_at_int_simple_node(fixed: bool = False) -> SimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return SimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, @@ -390,8 +389,8 @@ def make_invalid_at_int_simple_node(fixed: bool = False) -> SimpleNode: def make_invalid_at_float_simple_node(fixed: bool = False) -> SimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_int() + bool_value = factories.make_bool() float_value = factories.make_int() str_value = factories.make_str() bytes_value = factories.make_str().encode() @@ -399,8 +398,8 @@ def make_invalid_at_float_simple_node(fixed: bool = False) -> SimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return SimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, @@ -411,8 +410,8 @@ def make_invalid_at_float_simple_node(fixed: bool = False) -> SimpleNode: def make_invalid_at_str_simple_node(fixed: bool = False) -> SimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_int() + bool_value = factories.make_bool() float_value = factories.make_float() str_value = factories.make_float() bytes_value = factories.make_str().encode() @@ -420,8 +419,8 @@ def make_invalid_at_str_simple_node(fixed: bool = False) -> SimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return SimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, @@ -432,8 +431,8 @@ def make_invalid_at_str_simple_node(fixed: bool = False) -> SimpleNode: def make_invalid_at_bytes_simple_node(fixed: bool = False) -> SimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_int() + bool_value = factories.make_bool() float_value = factories.make_float() str_value = factories.make_float() bytes_value = [1, "2", (3, 4)] @@ -441,8 +440,8 @@ def make_invalid_at_bytes_simple_node(fixed: bool = False) -> SimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return SimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, @@ -453,8 +452,8 @@ def make_invalid_at_bytes_simple_node(fixed: bool = False) -> SimpleNode: def make_invalid_at_enum_simple_node(fixed: bool = False) -> SimpleNode: factories = Factories if fixed else RandomFactories - bool_value = factories.make_bool() int_value = factories.make_int() + bool_value = factories.make_bool() float_value = factories.make_float() str_value = factories.make_float() bytes_value = factories.make_str().encode() @@ -462,8 +461,8 @@ def make_invalid_at_enum_simple_node(fixed: bool = False) -> SimpleNode: str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) return SimpleNode( - bool_value=bool_value, int_value=int_value, + bool_value=bool_value, float_value=float_value, str_value=str_value, bytes_value=bytes_value, diff --git a/tests/tests_eve/unit_tests/test_concepts.py b/tests/tests_eve/unit_tests/test_concepts.py index d77ab9e..4570497 100644 --- a/tests/tests_eve/unit_tests/test_concepts.py +++ b/tests/tests_eve/unit_tests/test_concepts.py @@ -18,84 +18,102 @@ import pydantic import pytest +import eve + from .. import common +@pytest.fixture( + params=[ + "valid_annotation_", + "v_", + "v3212_32_", + "VV_VV_00_", + "invalid_annotation__", + "_invalid_annotation__", + "__invalid_annotation__", + "_", + ] +) +def annotation_name(request): + yield request.param + + +@pytest.fixture(params=["data", 0, 1.1, None, [1, 2, 3], {"a": 1, 1: "a"}, lambda x: x]) +def annotation_value(request): + yield request.param + + class TestNode: def test_validation(self, invalid_sample_node_maker): with pytest.raises(pydantic.ValidationError): invalid_sample_node_maker() def test_mutability(self, sample_node): - sample_node.id_ = None + if "value" in sample_node.__fields__: + sample_node.int_value = 123456 def test_inmutability(self, frozen_sample_node): with pytest.raises(TypeError): - frozen_sample_node.id_ = None - - def test_unique_id(self, sample_node_maker): - node_a = sample_node_maker() - node_b = sample_node_maker() - node_c = sample_node_maker() - - assert node_a.id_ != node_b.id_ != node_c.id_ + frozen_sample_node.int_value = 123456 - def test_custom_id(self, source_location, sample_node_maker): - custom_id = "my_custom_id" - my_node = common.LocationNode(id_=custom_id, loc=source_location) - other_node = sample_node_maker() - - assert my_node.id_ == custom_id - assert my_node.id_ != other_node.id_ - - with pytest.raises(pydantic.ValidationError, match="id_"): - common.LocationNode(id_=32, loc=source_location) - - def test_impl_fields(self, sample_node): - impl_names = set(name for name, _ in sample_node.iter_impl_fields()) - - assert all(name.endswith("_") and not name.endswith("__") for name in impl_names) - assert ( - set( - name - for name in sample_node.__fields__.keys() - if name.endswith("_") and not name.endswith("__") - ) - == impl_names + def test_private_attrs(self, sample_node): + assert all( + eve.concepts._is_private_attr_name(name) for name in sample_node.private_attrs_names ) + assert sample_node.private_attrs_names >= {"__node_id__", "__node_annotations__"} def test_children(self, sample_node): - impl_field_names = set(name for name, _ in sample_node.iter_impl_fields()) - children_names = set(name for name, _ in sample_node.iter_children()) - public_names = impl_field_names | children_names + children_names = set(name for name in sample_node.iter_children_names()) field_names = set(sample_node.__fields__.keys()) - assert not any(name.endswith("__") for name in children_names) - assert not any(name.endswith("_") for name in children_names) - - assert public_names <= field_names - assert all(name.endswith("_") for name in field_names - public_names) - + assert all(eve.concepts._is_child_field_name(name) for name in children_names) + assert children_names <= field_names assert all( - node1 is node2 - for (name, node1), node2 in zip( - sample_node.iter_children(), sample_node.iter_children_values() - ) + eve.concepts._is_internal_field_name(name) for name in field_names - children_names ) - def test_node_metadata(self, sample_node): assert all( - name in sample_node.__node_impl_fields__ for name, _ in sample_node.iter_impl_fields() + name1 is name2 + for (name1, _), name2 in zip( + sample_node.iter_children(), sample_node.iter_children_names() + ) ) assert all( - isinstance(metadata, dict) - and isinstance(metadata["definition"], pydantic.fields.ModelField) - for metadata in sample_node.__node_impl_fields__.values() + node1 is node2 + for (_, node1), node2 in zip( + sample_node.iter_children(), sample_node.iter_children_values() + ) ) + def test_node_annotations(self, sample_node, annotation_name, annotation_value): + if eve.concepts._is_data_annotation_name(annotation_name): + setattr(sample_node, annotation_name, annotation_value) + assert getattr(sample_node, annotation_name) == annotation_value + else: + with pytest.raises(ValueError, match=f'has no field "{annotation_name}"'): + setattr(sample_node, annotation_name, annotation_value) + + def test_node_metadata(self, sample_node): assert all(name in sample_node.__node_children__ for name, _ in sample_node.iter_children()) assert all( isinstance(metadata, dict) and isinstance(metadata["definition"], pydantic.fields.ModelField) for metadata in sample_node.__node_children__.values() ) + + def test_unique_id(self, sample_node_maker): + node_a = sample_node_maker() + node_b = sample_node_maker() + node_c = sample_node_maker() + + assert node_a.__node_id__ != node_b.__node_id__ != node_c.__node_id__ + + def test_custom_id(self, source_location, sample_node_maker): + custom_id = "my_custom_id" + my_node = common.LocationNode(loc=source_location) + other_node = sample_node_maker() + my_node.__node_id__ = custom_id + + assert my_node.__node_id__ == custom_id + assert my_node.__node_id__ != other_node.__node_id__ From 39272a61c5e1e927ce577b7fc0209fc619fc02a7 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 14:47:52 +0100 Subject: [PATCH 03/17] Fix codegen and visitors --- src/eve/codegen.py | 23 ++++++----------------- src/eve/visitors.py | 1 - 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/eve/codegen.py b/src/eve/codegen.py index f2ad8bc..ff051f4 100644 --- a/src/eve/codegen.py +++ b/src/eve/codegen.py @@ -447,12 +447,11 @@ class TemplatedGenerator(NodeVisitor): The following keys are passed to template instances at rendering: - * ``**node_fields``: all the node children and implementation fields by name. - * ``_impl``: a ``dict`` instance with the results of visiting all - the node implementation fields. - * ``_children``: a ``dict`` instance with the results of visiting all - the node children. - * ``_this_node``: the actual node instance (before visiting children). + * ``**node_fields``: the results of visiting all the node children directly + available by name. + * ``_children``: the results of visiting all the node children collected + in a single ``dict`` instance. + * ``_this_node``: the actual node instance. * ``_this_generator``: the current generator instance. * ``_this_module``: the generator's module instance. * ``**kwargs``: the keyword arguments received by the visiting method. @@ -525,11 +524,7 @@ def generic_visit(self, node: TreeNode, **kwargs: Any) -> Union[str, Collection[ template, _ = self.get_template(node) if template: result = self.render_template( - template, - node, - self.transform_children(node, **kwargs), - self.transform_impl_fields(node, **kwargs), - **kwargs, + template, node, self.transform_children(node, **kwargs), **kwargs, ) elif isinstance(node, (collections.abc.Sequence, collections.abc.Set)) and not isinstance( node, type_definitions.ATOMIC_COLLECTION_TYPES @@ -560,16 +555,13 @@ def render_template( template: Template, node: Node, transformed_children: Mapping[str, Any], - transformed_impl_fields: Mapping[str, Any], **kwargs: Any, ) -> str: """Render a template using node instance data (see class documentation).""" return template.render( **transformed_children, - **transformed_impl_fields, _children=transformed_children, - _impl=transformed_impl_fields, _this_node=node, _this_generator=self, _this_module=sys.modules[type(self).__module__], @@ -578,6 +570,3 @@ def render_template( def transform_children(self, node: Node, **kwargs: Any) -> Dict[str, Any]: return {key: self.visit(value, **kwargs) for key, value in node.iter_children()} - - def transform_impl_fields(self, node: Node, **kwargs: Any) -> Dict[str, Any]: - return {key: self.visit(value, **kwargs) for key, value in node.iter_impl_fields()} diff --git a/src/eve/visitors.py b/src/eve/visitors.py index 066b81c..5b4e7f0 100644 --- a/src/eve/visitors.py +++ b/src/eve/visitors.py @@ -160,7 +160,6 @@ def generic_visit(self, node: concepts.TreeNode, **kwargs: Any) -> Any: key: self.visit(value, **kwargs) for key, value in node.iter_children() } result = node.__class__( # type: ignore - **{key: value for key, value in node.iter_impl_fields()}, **{key: value for key, value in tmp_items.items() if value is not NOTHING}, ) From 84625cc54dea99616a597a22abb6659250bda3ed Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 16:05:55 +0100 Subject: [PATCH 04/17] Fixes --- src/eve/concepts.py | 38 ++++++++++++--------- tests/tests_eve/unit_tests/test_concepts.py | 17 +++++++++ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index fd4e56b..82e057e 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -126,11 +126,11 @@ class Config(FrozenModelConfig): def _is_data_annotation_name(name: str) -> bool: - return name.endswith("_") and not name.startswith("_") + return name.endswith("_") and not name.endswith("__") and not name.startswith("_") def _is_child_field_name(name: str) -> bool: - return not name.startswith("_") and not name.endswith("_") + return not name.endswith("_") and not name.startswith("_") def _is_internal_field_name(name: str) -> bool: @@ -157,6 +157,11 @@ def __new__(mcls, name, bases, namespace, **kwargs): # Add metadata class members children_metadata = {} for name, model_field in cls.__fields__.items(): + assert not _is_private_attr_name(name) + if _is_data_annotation_name(name): + raise TypeError(f"Invalid field name ('{name}') looks like a data annotation.") + if _is_internal_field_name(name): + raise TypeError(f"Invalid field name ('{name}') looks like an Eve internal field.") if _is_child_field_name(name): children_metadata[name] = { "definition": model_field, @@ -182,20 +187,21 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): * supported collections (:class:`Tuple`, :class:`List`, :class:`Dict`, :class:`Set`) of any of the previous items - Class members naming scheme: - - * Member names starting with ``_`` are transformed into `private attributes` - by pydantic and thus ignored by Eve. Since none of the pydantic features - will work on them (type coercion, validators, etc.), it is not recommended - to define new pydantic private attributes in the nodes. - * Member names ending with ``_`` are considered node data annotations, not children - nodes. They are intended to be used by the user, typically to cache derived, - non-essential information on the node, and they can be assigned directly - without a explicit definition in the class body (which will consequently - trigger an error). - * Member names ending with ``__`` are reserved for internal Eve use and - should NOT be defined by regular users. All pydantic features will - work on these fields anyway but they will be invisible in Eve nodes. + Naming semantics for Node members: + + * Member names starting with ``_`` (e.g. ``_private_attr``) are transformed into + `private attributes` by pydantic and thus ignored by Eve. Since none of the + pydantic features will work on them (type coercion, validators, etc.), it is + not recommended to define new pydantic private attributes in the nodes. + * Member names ending with ``_`` (and not starting with ``_``, e.g. ``my_data_``) + are considered node data annotations, not children nodes. They are intended + to be used by the user, typically to cache derived, non-essential information + on the node, and they can be assigned directly without a explicit definition + in the class body (which will consequently trigger an error). + * Member names ending with ``__`` (and not starting with ``_``, e.g. ``internal__``) + are reserved for internal Eve use and should NOT be defined by regular users. + All pydantic features will work on these fields anyway but they will be + not visible visible in Eve nodes. """ diff --git a/tests/tests_eve/unit_tests/test_concepts.py b/tests/tests_eve/unit_tests/test_concepts.py index 4570497..1288540 100644 --- a/tests/tests_eve/unit_tests/test_concepts.py +++ b/tests/tests_eve/unit_tests/test_concepts.py @@ -57,6 +57,23 @@ def test_inmutability(self, frozen_sample_node): with pytest.raises(TypeError): frozen_sample_node.int_value = 123456 + def test_field_naming(self): + class NodeWithPrivateAttrs(eve.concepts.BaseNode): + _private_int: int = 0 + int_value: int = 1 + + assert len(NodeWithPrivateAttrs().__fields__) == 1 + + with pytest.raises(TypeError, match="data annotation"): + + class NodeWithDataAnnotationNamedFields(eve.concepts.BaseNode): + int_value_: int + + with pytest.raises(TypeError, match="internal field"): + + class NodeWithInternalNamedFields(eve.concepts.BaseNode): + int_value__: int + def test_private_attrs(self, sample_node): assert all( eve.concepts._is_private_attr_name(name) for name in sample_node.private_attrs_names From 66f409188642db85cafe71402d9dc3013a72ffa0 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 16:06:25 +0100 Subject: [PATCH 05/17] Adapt unstructured node definitions to new naming conventions --- src/gt_frontend/gtscript_ast.py | 2 +- src/gt_frontend/gtscript_to_gtir.py | 8 ++++---- src/gtc/unstructured/usid.py | 6 ++---- src/gtc/unstructured/usid_codegen.py | 4 ++-- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/gt_frontend/gtscript_ast.py b/src/gt_frontend/gtscript_ast.py index ca90502..31f74c7 100644 --- a/src/gt_frontend/gtscript_ast.py +++ b/src/gt_frontend/gtscript_ast.py @@ -161,7 +161,7 @@ class Pass(Statement): class Argument(GTScriptASTNode): name: str - type_: Union[Symbol, Union[SubscriptMultiple, SubscriptSingle]] + arg_type: Union[Symbol, Union[SubscriptMultiple, SubscriptSingle]] # is_keyword: bool diff --git a/src/gt_frontend/gtscript_to_gtir.py b/src/gt_frontend/gtscript_to_gtir.py index 6ab8c96..a37dc45 100644 --- a/src/gt_frontend/gtscript_to_gtir.py +++ b/src/gt_frontend/gtscript_to_gtir.py @@ -456,7 +456,7 @@ def visit_Stencil(self, node: Stencil, **kwargs) -> gtir.Stencil: @staticmethod def _transform_field_type(name, field_type): assert issubclass(field_type, Field) or issubclass(field_type, TemporaryField) - *location_types, vtype, = field_type.args + (*location_types, vtype,) = field_type.args assert isinstance(vtype, common.DataType) @@ -490,9 +490,9 @@ def visit_Computation(self, node: Computation) -> gtir.Computation: # parse temporary fields temporary_field_decls = [] - for name, type_ in self.symbol_table.types.items(): - if issubclass(type_, TemporaryField): - temporary_field_decls.append(self._transform_field_type(name, type_)) + for name, arg_type in self.symbol_table.types.items(): + if issubclass(arg_type, TemporaryField): + temporary_field_decls.append(self._transform_field_type(name, arg_type)) return gtir.Computation( name=node.name, diff --git a/src/gtc/unstructured/usid.py b/src/gtc/unstructured/usid.py index 4570096..174ebbe 100644 --- a/src/gtc/unstructured/usid.py +++ b/src/gtc/unstructured/usid.py @@ -158,13 +158,11 @@ def __eq__(self, other): class SidCompositeNeighborTableEntry(Node): connectivity: Str - connectivity_deref_: Optional[ - Connectivity - ] # TODO temporary workaround for symbol tbl reference + connectivity_deref: Optional[Connectivity] # TODO temporary workaround for symbol tbl reference @property def tag_name(self): - return self.connectivity_deref_.neighbor_tbl_tag + return self.connectivity_deref.neighbor_tbl_tag class Config(eve.concepts.FrozenModelConfig): pass diff --git a/src/gtc/unstructured/usid_codegen.py b/src/gtc/unstructured/usid_codegen.py index db3525a..b46d5de 100644 --- a/src/gtc/unstructured/usid_codegen.py +++ b/src/gtc/unstructured/usid_codegen.py @@ -41,7 +41,7 @@ class SymbolTblHelper(NodeTranslator): def visit_SidCompositeNeighborTableEntry(self, node: SidCompositeNeighborTableEntry, **kwargs): connectivity_deref = kwargs["symbol_tbl_conn"][node.connectivity] return SidCompositeNeighborTableEntry( - connectivity=node.connectivity, connectivity_deref_=connectivity_deref + connectivity=node.connectivity, connectivity_deref=connectivity_deref ) def visit_Kernel(self, node: Kernel, **kwargs): @@ -114,7 +114,7 @@ def location_type_from_dimensions(self, dimensions): ) SidCompositeNeighborTableEntry = as_fmt( - "gridtools::next::connectivity::neighbor_table({_this_node.connectivity_deref_.name})" + "gridtools::next::connectivity::neighbor_table({_this_node.connectivity_deref.name})" ) SidCompositeEntry = as_fmt("{name}") From 5481482a86316582b338f7057d5336e9f79e9322 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 16:32:30 +0100 Subject: [PATCH 06/17] Improve data annotation accessors --- src/eve/concepts.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 82e057e..3b052b5 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -240,20 +240,31 @@ def private_attrs_names(self) -> Set[str]: return set(self.__slots__) - {"__doc__"} def __getattr__(self, name: str) -> Any: - return type(self)._get_attr_owner(self, name).__getattribute__(name) + if _is_data_annotation_name(name): + try: + return super().__getattribute__("__node_annotations__").__getattribute__(name) + except AttributeError as e: + raise AttributeError(f"Invalid data annotation name: '{name}'") from e + else: + return super().__getattribute__(name) def __setattr__(self, name: str, value: Any) -> None: - type(self)._get_attr_owner(self, name).__setattr__(name, value) + if _is_data_annotation_name(name): + try: + super().__getattribute__("__node_annotations__").__setattr__(name, value) + except AttributeError as e: + raise AttributeError(f"Invalid data annotation name: '{name}'") from e + else: + super().__setattr__(name, value) def __delattr__(self, name: str) -> None: - type(self)._get_attr_owner(self, name).__delattr__(name) - - @staticmethod - def _get_attr_owner(instance: BaseNode, name: str) -> Any: if _is_data_annotation_name(name): - return super(BaseNode, instance).__getattribute__("__node_annotations__") + try: + super().__getattribute__("__node_annotations__").__delattr__(name) + except AttributeError as e: + raise AttributeError(f"Invalid data annotation name: '{name}'") from e else: - return super(BaseNode, instance) + super().__delattr__(name) class Config(BaseModelConfig): pass From 3490b2f4fe8a2cd9fd8f058da08c3372dd3b5d7f Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 16:51:48 +0100 Subject: [PATCH 07/17] Partial fix for unstructured passes and tests --- src/gt_frontend/gtscript_to_gtir.py | 2 +- src/gtc/unstructured/gtir_to_nir.py | 2 +- .../unstructured/nir_passes/field_dependency_graph.py | 9 ++++----- src/gtc/unstructured/nir_to_usid.py | 2 +- tests/tests_gtc/test_nir_field_dependency_graph.py | 8 ++++---- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/gt_frontend/gtscript_to_gtir.py b/src/gt_frontend/gtscript_to_gtir.py index a37dc45..71eca84 100644 --- a/src/gt_frontend/gtscript_to_gtir.py +++ b/src/gt_frontend/gtscript_to_gtir.py @@ -456,7 +456,7 @@ def visit_Stencil(self, node: Stencil, **kwargs) -> gtir.Stencil: @staticmethod def _transform_field_type(name, field_type): assert issubclass(field_type, Field) or issubclass(field_type, TemporaryField) - (*location_types, vtype,) = field_type.args + *location_types, vtype = field_type.args assert isinstance(vtype, common.DataType) diff --git a/src/gtc/unstructured/gtir_to_nir.py b/src/gtc/unstructured/gtir_to_nir.py index 4e8e408..615ee94 100644 --- a/src/gtc/unstructured/gtir_to_nir.py +++ b/src/gtc/unstructured/gtir_to_nir.py @@ -123,7 +123,7 @@ def visit_NeighborReduce(self, node: gtir.NeighborReduce, *, last_block, **kwarg kwargs["location_comprehensions"] = loc_comprehension body_location = node.neighbors.chain.elements[-1] - reduce_var_name = "local" + str(node.id_) + reduce_var_name = "local" + str(node.__node_id__) last_block.declarations.append( nir.LocalVar( name=reduce_var_name, diff --git a/src/gtc/unstructured/nir_passes/field_dependency_graph.py b/src/gtc/unstructured/nir_passes/field_dependency_graph.py index 6d6da5a..654e21a 100644 --- a/src/gtc/unstructured/nir_passes/field_dependency_graph.py +++ b/src/gtc/unstructured/nir_passes/field_dependency_graph.py @@ -47,8 +47,7 @@ def __init__(self, **kwargs): @classmethod def generate(cls, loops, **kwargs): - """Runs the visitor, returns graph. - """ + """Runs the visitor, returns graph.""" instance = cls() for loop in loops: instance.visit(loop, **kwargs) @@ -62,9 +61,9 @@ def visit_FieldAccess(self, node: FieldAccess, **kwargs): self.graph.add_edge(source, kwargs["current_write"], extent=node.extent) def visit_AssignStmt(self, node: AssignStmt, **kwargs): - self.graph.add_node(node.left.id_) # make IR nodes hashable? - self.visit(node.right, current_write=node.left.id_) - self.last_write_access[node.left.name] = node.left.id_ + self.graph.add_node(node.left.__node_id__) # make IR nodes hashable? + self.visit(node.right, current_write=node.left.__node_id__) + self.last_write_access[node.left.name] = node.left.__node_id__ def generate_dependency_graph(loops: List[HorizontalLoop]) -> nx.DiGraph: diff --git a/src/gtc/unstructured/nir_to_usid.py b/src/gtc/unstructured/nir_to_usid.py index bad81d7..6cb3ad6 100644 --- a/src/gtc/unstructured/nir_to_usid.py +++ b/src/gtc/unstructured/nir_to_usid.py @@ -177,7 +177,7 @@ def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs): usid.SidComposite(name=str(chain), entries=v, location=chain) ) # TODO _conn via property - kernel_name = "kernel_" + node.id_ + kernel_name = "kernel_" + node.__node_id__ kernel = usid.Kernel( ast=self.visit( node.stmt, diff --git a/tests/tests_gtc/test_nir_field_dependency_graph.py b/tests/tests_gtc/test_nir_field_dependency_graph.py index 4d3a0eb..5f657dd 100644 --- a/tests/tests_gtc/test_nir_field_dependency_graph.py +++ b/tests/tests_gtc/test_nir_field_dependency_graph.py @@ -45,8 +45,8 @@ def test_dependent_assignment(self): result = generate_dependency_graph(loops) assert len(result.nodes()) == 2 - assert result.has_edge(write0.id_, write1.id_) - assert result[write0.id_][write1.id_]["extent"] is False + assert result.has_edge(write0.__node_id__, write1.__node_id__) + assert result[write0.__node_id__][write1.__node_id__]["extent"] is False def test_dependent_assignment_with_extent(self): loop0, write0 = make_horizontal_loop_with_init("write0") @@ -56,5 +56,5 @@ def test_dependent_assignment_with_extent(self): result = generate_dependency_graph(loops) assert len(result.nodes()) == 2 - assert result.has_edge(write0.id_, write1.id_) - assert result[write0.id_][write1.id_]["extent"] is True + assert result.has_edge(write0.__node_id__, write1.__node_id__) + assert result[write0.__node_id__][write1.__node_id__]["extent"] is True From 4200017066df63559c4635a88a582d889b20c226 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 16:56:55 +0100 Subject: [PATCH 08/17] Improve BaseNode docstring --- src/eve/concepts.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 3b052b5..91b8b4f 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -189,10 +189,10 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): Naming semantics for Node members: - * Member names starting with ``_`` (e.g. ``_private_attr``) are transformed into - `private attributes` by pydantic and thus ignored by Eve. Since none of the - pydantic features will work on them (type coercion, validators, etc.), it is - not recommended to define new pydantic private attributes in the nodes. + * Member names starting with ``_`` (e.g. ``_private`` or ``__private__``) are + transformed into `private attributes` by pydantic and thus ignored by Eve. Since + none of the pydantic features will work on them (type coercion, validators, etc.), + it is not recommended to define new pydantic private attributes in the nodes. * Member names ending with ``_`` (and not starting with ``_``, e.g. ``my_data_``) are considered node data annotations, not children nodes. They are intended to be used by the user, typically to cache derived, non-essential information From 1b04e4e4980a915799262e39721330a7cc7298bf Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Wed, 4 Nov 2020 17:00:24 +0100 Subject: [PATCH 09/17] Fix bug in merge_horizontal_loops pass --- src/gtc/unstructured/nir_passes/merge_horizontal_loops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py b/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py index 75107ee..d7121f6 100644 --- a/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py +++ b/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py @@ -129,7 +129,7 @@ def visit_VerticalLoop( def merge_horizontal_loops( root: nir.VerticalLoop, merge_candidates: List[List[nir.HorizontalLoop]] ): - return MergeHorizontalLoops().apply(root, merge_candidates) + return MergeHorizontalLoops.apply(root, merge_candidates) def find_and_merge_horizontal_loops(root: Node): From 13cd4b0a69060eee7c3675a791465249d2bbff0f Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Thu, 5 Nov 2020 12:34:49 +0100 Subject: [PATCH 10/17] Expand docstrings in concepts --- src/eve/concepts.py | 112 ++++++++++++++++++++++++++++++++------------ 1 file changed, 82 insertions(+), 30 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 91b8b4f..2bc0f5b 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -99,23 +99,31 @@ def field( # -- Models -- class BaseModelConfig: + """Base Eve configuration for mutable pydantic classes.""" + extra = "forbid" underscore_attrs_are_private = True # TODO(egparedes): setting 'underscore_attrs_are_private' to True breaks sphinx-autodoc class FrozenModelConfig(BaseModelConfig): + """Base Eve configuration for immutable pydantic classes.""" + allow_mutation = False class Model(pydantic.BaseModel): + """Base public class for models that are not IR Nodes.""" + class Config(BaseModelConfig): - pass + ... class FrozenModel(pydantic.BaseModel): + """Base public class for immutable models that are not IR Nodes.""" + class Config(FrozenModelConfig): - pass + ... # -- Nodes -- @@ -144,7 +152,7 @@ def _is_private_attr_name(name: str) -> bool: class NodeMetaclass(pydantic.main.ModelMetaclass): """Custom metaclass for Node classes. - Customize the creation of Node classes adding Eve specific attributes. + Customize the creation of new Node classes adding Eve specific attributes. """ @@ -178,30 +186,66 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): A node is currently implemented as a pydantic Model with some extra features. - Field values should be either: - - * builtin types: `bool`, `bytes`, `int`, `float`, `str` - * enum.Enum types - * other :class:`Node` subclasses - * other :class:`pydantic.BaseModel` subclasses - * supported collections (:class:`Tuple`, :class:`List`, :class:`Dict`, :class:`Set`) - of any of the previous items - - Naming semantics for Node members: - - * Member names starting with ``_`` (e.g. ``_private`` or ``__private__``) are - transformed into `private attributes` by pydantic and thus ignored by Eve. Since - none of the pydantic features will work on them (type coercion, validators, etc.), - it is not recommended to define new pydantic private attributes in the nodes. - * Member names ending with ``_`` (and not starting with ``_``, e.g. ``my_data_``) - are considered node data annotations, not children nodes. They are intended - to be used by the user, typically to cache derived, non-essential information - on the node, and they can be assigned directly without a explicit definition - in the class body (which will consequently trigger an error). - * Member names ending with ``__`` (and not starting with ``_``, e.g. ``internal__``) - are reserved for internal Eve use and should NOT be defined by regular users. - All pydantic features will work on these fields anyway but they will be - not visible visible in Eve nodes. + The public fields of a node encode the IR information and are considered + as `children` when iterating a tree. Besides children fields, a node class + can define private instance attributes, which are regular Python attributes, + without explicit validation or serialization, currently implemented by + pydantic using ``__slots__``. + + Accepted field values are: + + * builtin types: `bool`, `bytes`, `int`, `float`, `str`. + * enum.Enum types. + * other :class:`Node` subclasses. + * other :class:`pydantic.BaseModel` subclasses. + * supported collections (:class:`Tuple`, :class:`List`, :class:`Dict`, :class:`Set`). + of any of the previous items. + + Node members follow a specific naming scheme to distinguish between the + different kinds of members: + + * Member names ending with ``_`` and not starting with ``_``, (e.g. ``my_data_``) + are considered node data annotations, not children nodes. They are + intended to be used by the user, typically to cache derived, + non-essential information on the node, and they can be assigned directly + without a explicit definition in the class body (which will consequently + trigger an error). They are stored in the default ``__node_annotations__`` + private instance attribute. + * Member names starting with ``_`` (e.g. ``_private`` or ``__private__``) + are transformed into `private instance attributes` by pydantic and thus + ignored by Eve. Since none of the pydantic features will work on them + (type coercion, validators, etc.), it is not recommended for users to + define new pydantic private attributes in the nodes and use node data + annotations instead. + * Member names ending with ``__`` and not starting with ``_`` (e.g. ``internal__``) + are reserved for internal Eve use and should NOT be defined by + regular users. All pydantic features will work on these fields + anyway but they will be not visible visible in Eve nodes. + + + A default set of private attributes is defined in :class:`BaseNode` and + therefore available on all node subclasses: + + Attributes: + __node_id__: unique id of the node instance. + __node_annotations__: container for arbitrary data annotations. + + Additionally, node classes comes with the following utilities provided + by pydantic for simple serialization purposes: + + :meth:`dict()` + returns a dictionary of the model's fields and values; cf. exporting models + :meth:`json()` + returns a JSON string representation dict(); cf. exporting models + :meth:`copy()` + returns a copy (by default, shallow copy) of the model; cf. exporting models + :meth:`schema()` + returns a dictionary representing the model as JSON Schema; cf. Schema + :meth:`schema_json()` + returns a JSON string representation of schema(); cf. Schema + + Pydantic provides even more helper methods, but they are too `pydantic-specific` + and thus it is recommended to avoid using them in stable Eve code. """ @@ -213,33 +257,39 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): default_factory=utils.UIDGenerator.sequential_id ) - #: Node analysis annotations + #: Node data annotations __node_annotations__: Optional[Str] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr default_factory=types.SimpleNamespace ) def iter_children(self) -> Generator[Tuple[str, Any], None, None]: + """Iterate through all public field (name, value) pairs.""" for name, _ in self.__fields__.items(): if _is_child_field_name(name): yield name, getattr(self, name) def iter_children_names(self) -> Generator[str, None, None]: + """Iterate through all public field names.""" for name, _ in self.iter_children(): yield name def iter_children_values(self) -> Generator[Any, None, None]: + """Iterate through all public field values.""" for _, node in self.iter_children(): yield node @property def data_annotations(self) -> Dict[str, Any]: + """Node data annotations dict.""" return self.__node_annotations__.__dict__ @property def private_attrs_names(self) -> Set[str]: + """Names of all the private instance attributes.""" return set(self.__slots__) - {"__doc__"} def __getattr__(self, name: str) -> Any: + """Access node data annotations or regular instance data.""" if _is_data_annotation_name(name): try: return super().__getattribute__("__node_annotations__").__getattribute__(name) @@ -249,6 +299,7 @@ def __getattr__(self, name: str) -> Any: return super().__getattribute__(name) def __setattr__(self, name: str, value: Any) -> None: + """Set node data annotations or regular instance data.""" if _is_data_annotation_name(name): try: super().__getattribute__("__node_annotations__").__setattr__(name, value) @@ -258,6 +309,7 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) def __delattr__(self, name: str) -> None: + """Delete node data annotations or regular instance data.""" if _is_data_annotation_name(name): try: super().__getattribute__("__node_annotations__").__delattr__(name) @@ -267,7 +319,7 @@ def __delattr__(self, name: str) -> None: super().__delattr__(name) class Config(BaseModelConfig): - pass + ... class GenericNode(BaseNode, pydantic.generics.GenericModel): @@ -286,7 +338,7 @@ class FrozenNode(Node): """Default public name for an immutable base node class.""" class Config(FrozenModelConfig): - pass + ... KeyValue = Tuple[Union[int, str], Any] From 0590e5f489d2b6b498a23ca08fa11498c2d09de9 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Fri, 6 Nov 2020 11:35:07 +0100 Subject: [PATCH 11/17] Fix traits and add copy tests --- src/eve/concepts.py | 16 +++++------ src/eve/traits.py | 28 +++++++++++------- tests/tests_eve/unit_tests/test_concepts.py | 32 +++++++++++++++++++++ 3 files changed, 57 insertions(+), 19 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 2bc0f5b..554bd7a 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -167,9 +167,9 @@ def __new__(mcls, name, bases, namespace, **kwargs): for name, model_field in cls.__fields__.items(): assert not _is_private_attr_name(name) if _is_data_annotation_name(name): - raise TypeError(f"Invalid field name ('{name}') looks like a data annotation.") + raise TypeError(f"Invalid field name '{name}' looks like a data annotation.") if _is_internal_field_name(name): - raise TypeError(f"Invalid field name ('{name}') looks like an Eve internal field.") + raise TypeError(f"Invalid field name '{name}' looks like an Eve internal field.") if _is_child_field_name(name): children_metadata[name] = { "definition": model_field, @@ -234,18 +234,18 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): by pydantic for simple serialization purposes: :meth:`dict()` - returns a dictionary of the model's fields and values; cf. exporting models + returns a dictionary of the model's fields and values. :meth:`json()` - returns a JSON string representation dict(); cf. exporting models + returns a JSON string representation dict(). :meth:`copy()` - returns a copy (by default, shallow copy) of the model; cf. exporting models + returns a copy (by default, shallow copy) of the model. :meth:`schema()` - returns a dictionary representing the model as JSON Schema; cf. Schema + returns a dictionary representing the model as JSON Schema. :meth:`schema_json()` - returns a JSON string representation of schema(); cf. Schema + returns a JSON string representation of schema(). Pydantic provides even more helper methods, but they are too `pydantic-specific` - and thus it is recommended to avoid using them in stable Eve code. + and therefore it is recommended to avoid them in stable Eve code. """ diff --git a/src/eve/traits.py b/src/eve/traits.py index e28cd1b..a907ee5 100644 --- a/src/eve/traits.py +++ b/src/eve/traits.py @@ -19,15 +19,28 @@ from __future__ import annotations -import pydantic - from . import concepts, iterators from .type_definitions import SymbolName -from .typingx import Any, Dict, Type +from .typingx import Any, Dict class SymbolTableTrait(concepts.Model): - symtable_: Dict[str, Any] = pydantic.Field(default_factory=dict) + """Trait implementing automatic symbol table creation for nodes. + + Nodes inheriting this trait will collect all the + :class:`eve.type_definitions.SymbolRef` instances defined in the + children nodes and store them in a ``symtable_`` node data annotation. + + Node data annotations: + + symtable_: Dict[str, eve.concepts.BaseNode]: + Mapping from symbol name to symbol node. + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.collect_symbols() @staticmethod def _collect_symbols(root_node: concepts.TreeNode) -> Dict[str, Any]: @@ -42,12 +55,5 @@ def _collect_symbols(root_node: concepts.TreeNode) -> Dict[str, Any]: return collected - @pydantic.root_validator(skip_on_failure=True) - def _collect_symbols_validator( # type: ignore # validators are classmethods - cls: Type[SymbolTableTrait], values: Dict[str, Any] - ) -> Dict[str, Any]: - values["symtable_"] = cls._collect_symbols(values) - return values - def collect_symbols(self) -> None: self.symtable_ = self._collect_symbols(self) diff --git a/tests/tests_eve/unit_tests/test_concepts.py b/tests/tests_eve/unit_tests/test_concepts.py index 1288540..1521de5 100644 --- a/tests/tests_eve/unit_tests/test_concepts.py +++ b/tests/tests_eve/unit_tests/test_concepts.py @@ -15,6 +15,8 @@ # SPDX-License-Identifier: GPL-3.0-or-later +import copy + import pydantic import pytest @@ -57,6 +59,36 @@ def test_inmutability(self, frozen_sample_node): with pytest.raises(TypeError): frozen_sample_node.int_value = 123456 + def test_copy(self, sample_node): + bla = (1, "bla", (1, 2)) + foo = {"a": 1, "b": 2} + sample_node.bla_ = bla + sample_node.foo_ = foo + + node_copy = sample_node.copy() + assert sample_node == node_copy + assert sample_node.__node_id__ == node_copy.__node_id__ + assert sample_node.__node_annotations__ == node_copy.__node_annotations__ + assert sample_node.__node_annotations__ is node_copy.__node_annotations__ + assert sample_node.bla_ is bla + assert sample_node.foo_ is foo + + node_copy = copy.copy(sample_node) + assert sample_node == node_copy + assert sample_node.__node_id__ == node_copy.__node_id__ + assert sample_node.__node_annotations__ == node_copy.__node_annotations__ + assert sample_node.__node_annotations__ is node_copy.__node_annotations__ + assert sample_node.bla_ is bla + assert sample_node.foo_ is foo + + node_copy = copy.deepcopy(sample_node) + assert sample_node == node_copy + assert sample_node.__node_id__ == node_copy.__node_id__ + assert sample_node.__node_annotations__ == node_copy.__node_annotations__ + assert sample_node.__node_annotations__ is not node_copy.__node_annotations__ + # assert sample_node.bla_ is not bla + # assert sample_node.foo_ is not foo + def test_field_naming(self): class NodeWithPrivateAttrs(eve.concepts.BaseNode): _private_int: int = 0 From 11381cf48800c970ab2a72582892703b26eca921 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Fri, 6 Nov 2020 11:49:15 +0100 Subject: [PATCH 12/17] Fix copy tests --- tests/tests_eve/unit_tests/test_concepts.py | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/tests_eve/unit_tests/test_concepts.py b/tests/tests_eve/unit_tests/test_concepts.py index 1521de5..4d1797b 100644 --- a/tests/tests_eve/unit_tests/test_concepts.py +++ b/tests/tests_eve/unit_tests/test_concepts.py @@ -60,34 +60,40 @@ def test_inmutability(self, frozen_sample_node): frozen_sample_node.int_value = 123456 def test_copy(self, sample_node): - bla = (1, "bla", (1, 2)) foo = {"a": 1, "b": 2} - sample_node.bla_ = bla + bla = [1, "bla", (1, 2)] + const_bla = tuple(bla) sample_node.foo_ = foo + sample_node.bla_ = bla + sample_node.const_bla_ = const_bla node_copy = sample_node.copy() assert sample_node == node_copy assert sample_node.__node_id__ == node_copy.__node_id__ assert sample_node.__node_annotations__ == node_copy.__node_annotations__ assert sample_node.__node_annotations__ is node_copy.__node_annotations__ - assert sample_node.bla_ is bla - assert sample_node.foo_ is foo + assert node_copy.foo_ is foo + assert node_copy.bla_ is bla + assert node_copy.const_bla_ is const_bla node_copy = copy.copy(sample_node) assert sample_node == node_copy assert sample_node.__node_id__ == node_copy.__node_id__ assert sample_node.__node_annotations__ == node_copy.__node_annotations__ assert sample_node.__node_annotations__ is node_copy.__node_annotations__ - assert sample_node.bla_ is bla - assert sample_node.foo_ is foo + assert node_copy.foo_ is foo + assert node_copy.bla_ is bla + assert node_copy.const_bla_ is const_bla node_copy = copy.deepcopy(sample_node) assert sample_node == node_copy assert sample_node.__node_id__ == node_copy.__node_id__ assert sample_node.__node_annotations__ == node_copy.__node_annotations__ assert sample_node.__node_annotations__ is not node_copy.__node_annotations__ - # assert sample_node.bla_ is not bla - # assert sample_node.foo_ is not foo + assert node_copy.foo_ is node_copy.__node_annotations__.foo_ + assert node_copy.foo_ is not foo + assert node_copy.bla_ is not bla + assert node_copy.const_bla_ is const_bla def test_field_naming(self): class NodeWithPrivateAttrs(eve.concepts.BaseNode): From 118563ee15d50dbe7b6d828dc3fb6439b117687c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Nov 2020 15:50:26 +0100 Subject: [PATCH 13/17] New loop merging strategy (#89) horizontal loop merge candidates are stored as an index range as a node annotation on the vertical loop --- src/eve/concepts.py | 4 +- src/gt_frontend/py_to_gtscript.py | 2 +- .../nir_passes/merge_horizontal_loops.py | 128 ++++++++---------- .../test_nir_merge_horizontal_loops.py | 66 +++++---- 4 files changed, 98 insertions(+), 102 deletions(-) diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 554bd7a..766d828 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -253,12 +253,12 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): # Node private attributes #: Unique node-id - __node_id__: Optional[str] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr + __node_id__: Optional[str] = pydantic.PrivateAttr( default_factory=utils.UIDGenerator.sequential_id ) #: Node data annotations - __node_annotations__: Optional[Str] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr + __node_annotations__: Optional[Str] = pydantic.PrivateAttr( default_factory=types.SimpleNamespace ) diff --git a/src/gt_frontend/py_to_gtscript.py b/src/gt_frontend/py_to_gtscript.py index ffd737c..ed0ea79 100644 --- a/src/gt_frontend/py_to_gtscript.py +++ b/src/gt_frontend/py_to_gtscript.py @@ -148,7 +148,7 @@ class Patterns: Pass = ast.Pass() - Argument = ast.arg(arg=Capture("name"), annotation=Capture("type_")) + Argument = ast.arg(arg=Capture("name"), annotation=Capture("arg_type")) Computation = ast.FunctionDef( args=ast.arguments(args=Capture("arguments")), diff --git a/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py b/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py index d7121f6..0e9c922 100644 --- a/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py +++ b/src/gtc/unstructured/nir_passes/merge_horizontal_loops.py @@ -14,20 +14,29 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import copy from typing import List import networkx as nx import eve # noqa: F401 -from eve import Node, NodeTranslator, NodeVisitor +from eve import Node, NodeTranslator from gtc.unstructured import nir from gtc.unstructured.nir_passes.field_dependency_graph import generate_dependency_graph -class _FindMergeCandidatesAnalysis(NodeVisitor): +# This is an example of an analysis pass using data annotations. + + +def _has_read_with_offset_after_write(graph: nx.DiGraph, **kwargs): + return any(edge["extent"] for _, _, edge in graph.edges(data=True)) + + +def _find_merge_candidates(root: Node): """Find horizontal loop merge candidates. - Result is a List[List[HorizontalLoop]], where the inner list contains mergable loops. + Result is a List[List[int]], where the inner list contains a range of the horizontal loop indices. + The result is stored as a node data annotation `merge_candidates_` on the VerticalLoop. Currently the merge sets are ordered and disjunct, see question below. In the following examples A, B, C, ... are loops @@ -46,73 +55,57 @@ class _FindMergeCandidatesAnalysis(NodeVisitor): - if the read is without offset, we can fuse - if the read is with offset, we cannot fuse """ - - def __init__(self, **kwargs): - super().__init__() - self.candidates = [] - self.candidate = [] - - @classmethod - def find(cls, root, **kwargs) -> List[List[nir.HorizontalLoop]]: - """Runs the visitor, returns merge candidates. - """ - instance = cls() - instance.visit(root, **kwargs) - if len(instance.candidate) > 1: - instance.candidates.append(instance.candidate) - return instance.candidates - - def has_read_with_offset_after_write(self, graph: nx.DiGraph, **kwargs): - return any(edge["extent"] for _, _, edge in graph.edges(data=True)) - - def visit_HorizontalLoop(self, node: nir.HorizontalLoop, **kwargs): - if len(self.candidate) == 0: - self.candidate.append(node) - return - elif ( - self.candidate[-1].location_type == node.location_type - ): # same location type as previous - dependencies = generate_dependency_graph(self.candidate + [node]) - if not self.has_read_with_offset_after_write(dependencies): - self.candidate.append(node) - return - # cannot merge to previous loop: - if len(self.candidate) > 1: - self.candidates.append(self.candidate) # add a new merge set - self.candidate = [node] - - -def _find_merge_candidates(root: nir.VerticalLoop): - return _FindMergeCandidatesAnalysis().find(root) + vertical_loops = eve.FindNodes().by_type(nir.VerticalLoop, root) + for vloop in vertical_loops: + candidates = [] + candidate: List[nir.HorizontalLoop] = [] + candidate_range: List[int] = [0, 0] + + for index, hloop in enumerate(vloop.horizontal_loops): + if len(candidate) == 0: + candidate.append(hloop) + candidate_range[0] = index + continue + elif ( + candidate[-1].location_type == hloop.location_type + ): # same location type as previous + dependencies = generate_dependency_graph(candidate + [hloop]) + if not _has_read_with_offset_after_write(dependencies): + candidate.append(hloop) + candidate_range[1] = index + continue + # cannot merge to previous loop: + if len(candidate) > 1: + candidates.append(candidate_range) # add a new merge set + candidate = [hloop] + candidate_range = [index, 0] + + if len(candidate) > 1: + candidates.append(candidate_range) # add a new merge set + + vloop.merge_candidates_ = candidates class MergeHorizontalLoops(NodeTranslator): - """ - """ + """""" @classmethod - def apply(cls, root: nir.VerticalLoop, merge_candidates, **kwargs) -> nir.VerticalLoop: - """ - """ - # merge_candidates = _find_merge_candidates(root) - return cls().visit(root, merge_candidates=merge_candidates) - - def visit_VerticalLoop( - self, node: nir.VerticalLoop, *, merge_candidates: List[List[nir.HorizontalLoop]], **kwargs - ): - for candidate in merge_candidates: + def apply(cls, root: Node, **kwargs): + """""" + return cls().visit(root) + + def visit_VerticalLoop(self, node: nir.VerticalLoop, **kwargs): + new_horizontal_loops = copy.deepcopy(node.horizontal_loops) + for merge_group in node.merge_candidates_: declarations = [] statements = [] - location_type = candidate[0].location_type + location_type = node.horizontal_loops[merge_group[0]].location_type - first_index = node.horizontal_loops.index(candidate[0]) - last_index = node.horizontal_loops.index(candidate[-1]) - - for loop in candidate: + for loop in node.horizontal_loops[merge_group[0] : merge_group[1] + 1]: declarations += loop.stmt.declarations statements += loop.stmt.statements - node.horizontal_loops[first_index : last_index + 1] = [ # noqa: E203 + new_horizontal_loops = [ # noqa: E203 nir.HorizontalLoop( stmt=nir.BlockStmt( declarations=declarations, @@ -122,20 +115,13 @@ def visit_VerticalLoop( location_type=location_type, ) ] - - return node + return nir.VerticalLoop(loop_order=node.loop_order, horizontal_loops=new_horizontal_loops) -def merge_horizontal_loops( - root: nir.VerticalLoop, merge_candidates: List[List[nir.HorizontalLoop]] -): - return MergeHorizontalLoops.apply(root, merge_candidates) +def merge_horizontal_loops(root: Node): + return MergeHorizontalLoops.apply(root) def find_and_merge_horizontal_loops(root: Node): - copy = root.copy(deep=True) - vertical_loops = eve.FindNodes().by_type(nir.VerticalLoop, copy) - for loop in vertical_loops: - loop = merge_horizontal_loops(loop, _find_merge_candidates(loop)) - - return copy + _find_merge_candidates(root) + return merge_horizontal_loops(root) diff --git a/tests/tests_gtc/test_nir_merge_horizontal_loops.py b/tests/tests_gtc/test_nir_merge_horizontal_loops.py index d21c5f0..3c565bd 100644 --- a/tests/tests_gtc/test_nir_merge_horizontal_loops.py +++ b/tests/tests_gtc/test_nir_merge_horizontal_loops.py @@ -42,11 +42,12 @@ def test_same_location(self): second_loop = make_empty_horizontal_loop(common.LocationType.Vertex) stencil = make_vertical_loop([first_loop, second_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 1 - assert result[0][0] == first_loop - assert result[0][1] == second_loop + assert result[0][0] == 0 + assert result[0][1] == 1 def test_2_on_same_location_1_other(self): first_loop = make_empty_horizontal_loop(common.LocationType.Vertex) @@ -54,12 +55,12 @@ def test_2_on_same_location_1_other(self): third_loop = make_empty_horizontal_loop(common.LocationType.Edge) stencil = make_vertical_loop([first_loop, second_loop, third_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 1 - assert len(result[0]) == 2 - assert result[0][0] == first_loop - assert result[0][1] == second_loop + assert result[0][0] == 0 + assert result[0][1] == 1 def test_2_sets_of_location(self): first_loop = make_empty_horizontal_loop(common.LocationType.Vertex) @@ -68,17 +69,16 @@ def test_2_sets_of_location(self): fourth_loop = make_empty_horizontal_loop(common.LocationType.Edge) stencil = make_vertical_loop([first_loop, second_loop, third_loop, fourth_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 2 - assert len(result[0]) == 2 - assert result[0][0] == first_loop - assert result[0][1] == second_loop + assert result[0][0] == 0 + assert result[0][1] == 1 - assert len(result[1]) == 2 - assert result[1][0] == third_loop - assert result[1][1] == fourth_loop + assert result[1][0] == 2 + assert result[1][1] == 3 def test_vertex_edge_vertex(self): first_loop = make_empty_horizontal_loop(common.LocationType.Vertex) @@ -86,7 +86,8 @@ def test_vertex_edge_vertex(self): third_loop = make_empty_horizontal_loop(common.LocationType.Vertex) stencil = make_vertical_loop([first_loop, second_loop, third_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 0 @@ -99,7 +100,8 @@ def test_write_read_with_offset(self): second_loop, _, _ = make_horizontal_loop_with_copy("out", "field", True) stencil = make_vertical_loop([first_loop, second_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 0 @@ -110,7 +112,8 @@ def test_read_a_with_offset_write_b_read_b_no_offset(self): second_loop, _, _ = make_horizontal_loop_with_copy("out", "field", False) stencil = make_vertical_loop([first_loop, second_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 1 @@ -121,7 +124,8 @@ def test_write_read_no_offset(self): second_loop, _, _ = make_horizontal_loop_with_copy("out", "field", False) stencil = make_vertical_loop([first_loop, second_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 1 @@ -134,7 +138,8 @@ def test_write_read_no_offset_write_read_no_offset(self): third_loop, _, _ = make_horizontal_loop_with_copy("out", "field2", False) stencil = make_vertical_loop([first_loop, second_loop, third_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 1 @@ -147,12 +152,12 @@ def test_write_read_no_offset_write_read_with_offset(self): third_loop, _, _ = make_horizontal_loop_with_copy("out", "field2", True) stencil = make_vertical_loop([first_loop, second_loop, third_loop]) - result = _find_merge_candidates(stencil) + _find_merge_candidates(stencil) + result = stencil.merge_candidates_ assert len(result) == 1 - assert len(result[0]) == 2 - assert result[0][0] == first_loop - assert result[0][1] == second_loop + assert result[0][0] == 0 + assert result[0][1] == 1 class TestNIRMergeHorizontalLoops: @@ -161,9 +166,8 @@ def test_merge_empty_loops(self): second_loop = make_empty_horizontal_loop(default_location) stencil = make_vertical_loop([first_loop, second_loop]) - merge_candidates = [[first_loop, second_loop]] - - result = merge_horizontal_loops(stencil, merge_candidates) + stencil.merge_candidates_ = [[0, 1]] + result = merge_horizontal_loops(stencil) assert len(result.horizontal_loops) == 1 @@ -177,9 +181,9 @@ def test_merge_loops_with_stats_and_decls(self): second_loop = make_horizontal_loop(make_block_stmt([assignment2], [var2])) stencil = make_vertical_loop([first_loop, second_loop]) - merge_candidates = [[first_loop, second_loop]] + stencil.merge_candidates_ = [[0, 1]] - result = merge_horizontal_loops(stencil, merge_candidates) + result = merge_horizontal_loops(stencil) assert len(result.horizontal_loops) == 1 assert len(result.horizontal_loops[0].stmt.statements) == 2 @@ -221,8 +225,14 @@ def test_find_and_merge_with_2_vertical_loops(self): vloops = FindNodes().by_type(nir.VerticalLoop, result) assert len(vloops) == 2 + for vloop in vloops: # TODO more precise checks assert len(vloop.horizontal_loops) == 1 assert len(vloop.horizontal_loops[0].stmt.statements) == 2 assert len(vloop.horizontal_loops[0].stmt.declarations) == 2 + + # check we didn't touch the input tree + orig_vloops = FindNodes().by_type(nir.VerticalLoop, stencil) + for vloop in orig_vloops: + assert len(vloop.horizontal_loops) == 2 From c0b6d62396ffb4fd2d0a16c31d5182343fce1c2f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 6 Nov 2020 15:12:05 +0000 Subject: [PATCH 14/17] Update notebooks --- .../eve/workshop-2020-10/solution/day1.ipynb | 28 ---------- .../eve/workshop-2020-10/solution/day2.ipynb | 55 ++----------------- 2 files changed, 4 insertions(+), 79 deletions(-) diff --git a/examples/eve/workshop-2020-10/solution/day1.ipynb b/examples/eve/workshop-2020-10/solution/day1.ipynb index 17e606a..888529d 100644 --- a/examples/eve/workshop-2020-10/solution/day1.ipynb +++ b/examples/eve/workshop-2020-10/solution/day1.ipynb @@ -86,69 +86,52 @@ "text": [ ":6 \n", " binop: BinaryOp(\n", - " id_='BinaryOp_3',\n", " left=Literal(\n", - " id_='Literal_1',\n", " value='1',\n", " ),\n", " right=Literal(\n", - " id_='Literal_2',\n", " value='1',\n", " ),\n", " op='+',\n", " ) (BinaryOp)\n", ":16 \n", " lap: Fun(\n", - " id_='Fun_28',\n", " name='lap',\n", " params=[\n", " FieldDecl(\n", - " id_='FieldDecl_26',\n", " name='out',\n", " ),\n", " FieldDecl(\n", - " id_='FieldDecl_27',\n", " name='in',\n", " ),\n", " ],\n", " horizontal_loops=[\n", " HorizontalLoop(\n", - " id_='HorizontalLoop_25',\n", " i_indent=Indent(\n", - " id_='Indent_23',\n", " left=1,\n", " right=1,\n", " ),\n", " j_indent=Indent(\n", - " id_='Indent_24',\n", " left=1,\n", " right=1,\n", " ),\n", " body=[\n", " AssignStmt(\n", - " id_='AssignStmt_22',\n", " left=FieldAccess(\n", - " id_='FieldAccess_21',\n", " name='out',\n", " offset=Offset(\n", - " id_='Offset_20',\n", " i=0,\n", " j=0,\n", " ),\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_19',\n", " left=BinaryOp(\n", - " id_='BinaryOp_7',\n", " left=Literal(\n", - " id_='Literal_4',\n", " value='-4',\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_6',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_5',\n", " i=0,\n", " j=0,\n", " ),\n", @@ -156,23 +139,17 @@ " op='*',\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_18',\n", " left=BinaryOp(\n", - " id_='BinaryOp_12',\n", " left=FieldAccess(\n", - " id_='FieldAccess_9',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_8',\n", " i=-1,\n", " j=0,\n", " ),\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_11',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_10',\n", " i=1,\n", " j=0,\n", " ),\n", @@ -180,21 +157,16 @@ " op='+',\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_17',\n", " left=FieldAccess(\n", - " id_='FieldAccess_14',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_13',\n", " i=0,\n", " j=-1,\n", " ),\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_16',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_15',\n", " i=0,\n", " j=1,\n", " ),\n", diff --git a/examples/eve/workshop-2020-10/solution/day2.ipynb b/examples/eve/workshop-2020-10/solution/day2.ipynb index 72d243a..5382cd4 100644 --- a/examples/eve/workshop-2020-10/solution/day2.ipynb +++ b/examples/eve/workshop-2020-10/solution/day2.ipynb @@ -40,43 +40,32 @@ "text": [ ":11 \n", " stencil: Stencil(\n", - " id_='Stencil_22',\n", " name='lap',\n", " params=[\n", " FieldParam(\n", - " id_='FieldParam_20',\n", " name='out',\n", " ),\n", " FieldParam(\n", - " id_='FieldParam_21',\n", " name='in',\n", " ),\n", " ],\n", " body=[\n", " AssignStmt(\n", - " id_='AssignStmt_19',\n", " left=FieldAccess(\n", - " id_='FieldAccess_18',\n", " name='out',\n", " offset=Offset(\n", - " id_='Offset_17',\n", " i=0,\n", " j=0,\n", " ),\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_16',\n", " left=BinaryOp(\n", - " id_='BinaryOp_4',\n", " left=Literal(\n", - " id_='Literal_1',\n", " value='-4',\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_3',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_2',\n", " i=0,\n", " j=0,\n", " ),\n", @@ -84,23 +73,17 @@ " op='*',\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_15',\n", " left=BinaryOp(\n", - " id_='BinaryOp_9',\n", " left=FieldAccess(\n", - " id_='FieldAccess_6',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_5',\n", " i=-1,\n", " j=0,\n", " ),\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_8',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_7',\n", " i=1,\n", " j=0,\n", " ),\n", @@ -108,21 +91,16 @@ " op='+',\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_14',\n", " left=FieldAccess(\n", - " id_='FieldAccess_11',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_10',\n", " i=0,\n", " j=-1,\n", " ),\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_13',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_12',\n", " i=0,\n", " j=1,\n", " ),\n", @@ -199,8 +177,8 @@ "# self.extents[kwargs[\"cur_assign\"]] += Extent.from_offset(node.offset)\n", "\n", "# def visit_AssignStmt(self, node: AssignStmt, **kwargs):\n", - "# self.extents[node.id_] = Extent.zero()\n", - "# self.visit(node.right, cur_assign = node.id_)\n", + "# self.extents[node.__node_id__] = Extent.zero()\n", + "# self.visit(node.right, cur_assign = node.__node_id__)\n", "\n", "\n", "class ExtentAnalysis(eve.NodeVisitor):\n", @@ -221,7 +199,7 @@ " return self.visit(node.right, **kwargs)\n", " \n", " def visit_Stencil(self, node: Stencil, **kwargs):\n", - " return {s.id_: self.visit(s, **kwargs) for s in node.body}\n" + " return {s.__node_id__: self.visit(s, **kwargs) for s in node.body}\n" ] }, { @@ -270,7 +248,7 @@ " return lir.FieldDecl(name=node.name)\n", "\n", " def visit_AssignStmt(self, node: AssignStmt, **kwargs):\n", - " extent : Extent = self.extents[node.id_]\n", + " extent : Extent = self.extents[node.__node_id__]\n", " return lir.HorizontalLoop(body=[lir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right))], i_indent=lir.Indent(left=-extent.i_left, right=extent.i_right), j_indent=lir.Indent(left=-extent.j_left, right=extent.j_right))\n", "\n", " def visit_Stencil(self, node: Stencil, **kwargs):\n", @@ -288,56 +266,42 @@ "text": [ ":5 \n", " lir_stencil: Fun(\n", - " id_='Fun_47',\n", " name='lap',\n", " params=[\n", " FieldDecl(\n", - " id_='FieldDecl_23',\n", " name='out',\n", " ),\n", " FieldDecl(\n", - " id_='FieldDecl_24',\n", " name='in',\n", " ),\n", " ],\n", " horizontal_loops=[\n", " HorizontalLoop(\n", - " id_='HorizontalLoop_46',\n", " i_indent=Indent(\n", - " id_='Indent_44',\n", " left=1,\n", " right=1,\n", " ),\n", " j_indent=Indent(\n", - " id_='Indent_45',\n", " left=1,\n", " right=1,\n", " ),\n", " body=[\n", " AssignStmt(\n", - " id_='AssignStmt_43',\n", " left=FieldAccess(\n", - " id_='FieldAccess_26',\n", " name='out',\n", " offset=Offset(\n", - " id_='Offset_25',\n", " i=0,\n", " j=0,\n", " ),\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_42',\n", " left=BinaryOp(\n", - " id_='BinaryOp_30',\n", " left=Literal(\n", - " id_='Literal_27',\n", " value=-4,\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_29',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_28',\n", " i=0,\n", " j=0,\n", " ),\n", @@ -345,23 +309,17 @@ " op='*',\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_41',\n", " left=BinaryOp(\n", - " id_='BinaryOp_35',\n", " left=FieldAccess(\n", - " id_='FieldAccess_32',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_31',\n", " i=-1,\n", " j=0,\n", " ),\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_34',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_33',\n", " i=1,\n", " j=0,\n", " ),\n", @@ -369,21 +327,16 @@ " op='+',\n", " ),\n", " right=BinaryOp(\n", - " id_='BinaryOp_40',\n", " left=FieldAccess(\n", - " id_='FieldAccess_37',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_36',\n", " i=0,\n", " j=-1,\n", " ),\n", " ),\n", " right=FieldAccess(\n", - " id_='FieldAccess_39',\n", " name='in',\n", " offset=Offset(\n", - " id_='Offset_38',\n", " i=0,\n", " j=1,\n", " ),\n", From 0004878a4109549ae2615071cba74074c4200066 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Fri, 6 Nov 2020 16:41:46 +0100 Subject: [PATCH 15/17] Tests changes: initial traits tests and refactoring of maker functions --- src/eve/concepts.py | 10 +- src/eve/traits.py | 16 +- src/eve/utils.py | 12 +- tests/tests_eve/common.py | 471 ----------------- tests/tests_eve/common_definitions.py | 548 ++++++++++++++++++++ tests/tests_eve/conftest.py | 15 +- tests/tests_eve/unit_tests/test_concepts.py | 4 +- tests/tests_eve/unit_tests/test_traits.py | 57 ++ tests/tests_eve/unit_tests/test_utils.py | 4 + 9 files changed, 645 insertions(+), 492 deletions(-) delete mode 100644 tests/tests_eve/common.py create mode 100644 tests/tests_eve/common_definitions.py create mode 100644 tests/tests_eve/unit_tests/test_traits.py diff --git a/src/eve/concepts.py b/src/eve/concepts.py index 766d828..a5d5a97 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -228,7 +228,8 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): Attributes: __node_id__: unique id of the node instance. - __node_annotations__: container for arbitrary data annotations. + __node_annotations__: container for arbitrary user data annotations. + __node_impl__: internal container for arbitrary data annotations. Additionally, node classes comes with the following utilities provided by pydantic for simple serialization purposes: @@ -258,7 +259,12 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): ) #: Node data annotations - __node_annotations__: Optional[Str] = pydantic.PrivateAttr( + __node_annotations__: Optional[types.SimpleNamespace] = pydantic.PrivateAttr( + default_factory=types.SimpleNamespace + ) + + #: Node data annotations + __node_impl__: Optional[types.SimpleNamespace] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr default_factory=types.SimpleNamespace ) diff --git a/src/eve/traits.py b/src/eve/traits.py index a907ee5..133b73e 100644 --- a/src/eve/traits.py +++ b/src/eve/traits.py @@ -28,13 +28,14 @@ class SymbolTableTrait(concepts.Model): """Trait implementing automatic symbol table creation for nodes. Nodes inheriting this trait will collect all the - :class:`eve.type_definitions.SymbolRef` instances defined in the - children nodes and store them in a ``symtable_`` node data annotation. + :class:`eve.type_definitions.SymbolName` instances defined in the + descendant nodes and store them in the ``__node_symtable__`` private + attribute. - Node data annotations: - - symtable_: Dict[str, eve.concepts.BaseNode]: - Mapping from symbol name to symbol node. + Attributes: + __node_symtable__: Dict[str, concepts.BaseNode] + mapping from symbol name to the symbol node + where it was defined. """ @@ -56,4 +57,5 @@ def _collect_symbols(root_node: concepts.TreeNode) -> Dict[str, Any]: return collected def collect_symbols(self) -> None: - self.symtable_ = self._collect_symbols(self) + assert isinstance(self, concepts.BaseNode) + self.__node_impl__.symtable = self._collect_symbols(self) diff --git a/src/eve/utils.py b/src/eve/utils.py index 6d3e91d..ce4a247 100644 --- a/src/eve/utils.py +++ b/src/eve/utils.py @@ -172,7 +172,9 @@ class CASE_STYLE(enum.Enum): KEBAB = "kebab" @classmethod - def split(cls, name: str, case_style: CASE_STYLE) -> List[str]: + def split(cls, name: str, case_style: Union[CASE_STYLE, str]) -> List[str]: + if isinstance(case_style, str): + case_style = cls.CASE_STYLE(case_style) assert isinstance(case_style, cls.CASE_STYLE) if case_style == cls.CASE_STYLE.CONCATENATED: raise ValueError("Impossible to split a simply concatenated string") @@ -181,7 +183,9 @@ def split(cls, name: str, case_style: CASE_STYLE) -> List[str]: return splitter(name) @classmethod - def join(cls, words: AnyWordsIterable, case_style: CASE_STYLE) -> str: + def join(cls, words: AnyWordsIterable, case_style: Union[CASE_STYLE, str]) -> str: + if isinstance(case_style, str): + case_style = cls.CASE_STYLE(case_style) assert isinstance(case_style, cls.CASE_STYLE) if isinstance(words, str): words = [words] @@ -192,7 +196,9 @@ def join(cls, words: AnyWordsIterable, case_style: CASE_STYLE) -> str: return joiner(words) @classmethod - def convert(cls, name: str, source_style: CASE_STYLE, target_style: CASE_STYLE) -> str: + def convert( + cls, name: str, source_style: Union[CASE_STYLE, str], target_style: Union[CASE_STYLE, str] + ) -> str: return cls.join(cls.split(name, source_style), target_style) # Following `join_...`` functions are based on: diff --git a/tests/tests_eve/common.py b/tests/tests_eve/common.py deleted file mode 100644 index 2d92b06..0000000 --- a/tests/tests_eve/common.py +++ /dev/null @@ -1,471 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Eve Toolchain - GT4Py Project - GridTools Framework -# -# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - - -import enum -import random -import string -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Type, TypeVar - -from pydantic import Field, validator # noqa: F401 - -from eve.concepts import FrozenNode, Node, VType -from eve.type_definitions import Bool, Bytes, Float, Int, IntEnum, SourceLocation, Str, StrEnum - - -T = TypeVar("T") -S = TypeVar("S") - - -class Factories: - STR_LEN = 6 - - @classmethod - def make_bool(cls) -> bool: - return True - - @classmethod - def make_int(cls) -> int: - return 1 - - @classmethod - def make_neg_int(cls) -> int: - return -2 - - @classmethod - def make_pos_int(cls) -> int: - return 2 - - @classmethod - def make_float(cls) -> float: - return 1.1 - - @classmethod - def make_str(cls, length: Optional[int] = None) -> str: - length = length or cls.STR_LEN - return string.ascii_letters[:length] - - @classmethod - def make_member(cls, values: Sequence[T]) -> T: - return values[0] - - @classmethod - def make_collection( - cls, - item_type: Type[T], - collection_type: Type[Collection[T]] = list, - length: Optional[int] = None, - ) -> Collection[T]: - length = length or cls.STR_LEN - - maker_attr_name = f"make_{item_type.__name__}" - if hasattr(cls, maker_attr_name): - maker = getattr(cls, maker_attr_name) - else: - - def maker(): - return item_type() - - return collection_type([maker() for _ in range(length)]) # type: ignore - - @classmethod - def make_mapping( - cls, - key_type: Type[S], - value_type: Type[T], - mapping_type: Type[Mapping[S, T]] = dict, - length: Optional[int] = None, - ) -> Mapping[S, T]: - length = length or cls.STR_LEN - - key_maker_attr_name = f"make_{key_type.__name__}" - if hasattr(cls, key_maker_attr_name): - key_maker = getattr(cls, key_maker_attr_name) - else: - - def key_maker(): - return key_type() - - value_maker_attr_name = f"make_{value_type.__name__}" - if hasattr(cls, value_maker_attr_name): - value_maker = getattr(cls, value_maker_attr_name) - else: - - def value_maker(): - return value_type() - - return mapping_type({key_maker(): value_maker() for _ in range(length)}) # type: ignore - - -class RandomFactories(Factories): - MIN_INT = -9999 - MAX_INT = 9999 - MIN_FLOAT = -999.0 - MAX_FLOAT = 999.09 - - @classmethod - def make_bool(cls) -> bool: - return random.choice([True, False]) - - @classmethod - def make_int(cls) -> int: - return random.randint(cls.MIN_INT, cls.MAX_INT) - - @classmethod - def make_neg_int(cls) -> int: - return random.randint(cls.MIN_INT, 1) - - @classmethod - def make_pos_int(cls) -> int: - return random.randint(1, cls.MAX_INT) - - @classmethod - def make_float(cls) -> float: - return cls.MIN_FLOAT + random.random() * (cls.MAX_FLOAT - cls.MIN_FLOAT) - - @classmethod - def make_str(cls, length: Optional[int] = None) -> str: - length = length or cls.STR_LEN - return "".join(random.choice(string.ascii_letters) for _ in range(length)) - - @classmethod - def make_member(cls, values: Sequence[T]) -> T: - return random.choice(values) - - -@enum.unique -class IntKind(IntEnum): - """Sample int Enum.""" - - MINUS = -1 - ZERO = 0 - PLUS = 1 - - -@enum.unique -class StrKind(StrEnum): - """Sample string Enum.""" - - FOO = "foo" - BLA = "bla" - FIZ = "fiz" - FUZ = "fuz" - - -SimpleVType = VType("simple") - - -class EmptyNode(Node): - pass - - -class LocationNode(Node): - loc: SourceLocation - - -class SimpleNode(Node): - int_value: Int - bool_value: Bool - float_value: Float - str_value: Str - bytes_value: Bytes - int_kind: IntKind - str_kind: StrKind - - -class SimpleNodeWithOptionals(Node): - int_value: Int - float_value: Optional[Float] - str_value: Optional[Str] - - -class SimpleNodeWithLoc(Node): - int_value: Int - float_value: Float - str_value: Str - loc: Optional[SourceLocation] - - -class SimpleNodeWithCollections(Node): - int_value: Int - int_list: List[Int] - str_set: Set[Str] - str_to_int_dict: Dict[Str, Int] - loc: Optional[SourceLocation] - - -class SimpleNodeWithAbstractCollections(Node): - int_value: Int - int_sequence: Sequence[Int] - str_set: Set[Str] - str_to_int_mapping: Mapping[Str, Int] - loc: Optional[SourceLocation] - - -class CompoundNode(Node): - int_value: Int - location: LocationNode - simple: SimpleNode - simple_loc: SimpleNodeWithLoc - simple_opt: SimpleNodeWithOptionals - other_simple_opt: Optional[SimpleNodeWithOptionals] - - -class FrozenSimpleNode(FrozenNode): - int_value: Int - bool_value: Bool - float_value: Float - str_value: Str - bytes_value: Bytes - int_kind: IntKind - str_kind: StrKind - - -# -- Maker functions -- -def make_source_location(fixed: bool = False) -> SourceLocation: - factories = Factories if fixed else RandomFactories - line = factories.make_pos_int() - column = factories.make_pos_int() - str_value = factories.make_str() - source = f"file_{str_value}.py" - - return SourceLocation(line=line, column=column, source=source) - - -def make_empty_node(fixed: bool = False) -> LocationNode: - return EmptyNode() - - -def make_location_node(fixed: bool = False) -> LocationNode: - return LocationNode(loc=make_source_location(fixed)) - - -def make_simple_node(fixed: bool = False) -> SimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - bool_value = factories.make_bool() - float_value = factories.make_float() - str_value = factories.make_str() - bytes_value = factories.make_str().encode() - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return SimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) - - -def make_simple_node_with_optionals(fixed: bool = False) -> SimpleNodeWithOptionals: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - float_value = factories.make_float() - - return SimpleNodeWithOptionals(int_value=int_value, float_value=float_value) - - -def make_simple_node_with_loc(fixed: bool = False) -> SimpleNodeWithLoc: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - float_value = factories.make_float() - str_value = factories.make_str() - loc = make_source_location(fixed) - - return SimpleNodeWithLoc( - int_value=int_value, float_value=float_value, str_value=str_value, loc=loc - ) - - -def make_simple_node_with_collections(fixed: bool = False) -> SimpleNodeWithCollections: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - int_list = factories.make_collection(int, length=3) - str_set = factories.make_collection(str, set, length=3) - str_to_int_dict = factories.make_mapping(key_type=str, value_type=int, length=3) - loc = make_source_location(fixed) - - return SimpleNodeWithCollections( - int_value=int_value, - int_list=int_list, - str_set=str_set, - str_to_int_dict=str_to_int_dict, - loc=loc, - ) - - -def make_simple_node_with_abstractcollections( - fixed: bool = False, -) -> SimpleNodeWithAbstractCollections: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - int_sequence = factories.make_collection(int, collection_type=tuple, length=3) - str_set = factories.make_collection(str, set, length=3) - str_to_int_mapping = factories.make_mapping(key_type=str, value_type=int, length=3) - - return SimpleNodeWithAbstractCollections( - int_value=int_value, - int_sequence=int_sequence, - str_set=str_set, - str_to_int_mapping=str_to_int_mapping, - ) - - -def make_compound_node(fixed: bool = False) -> CompoundNode: - factories = Factories if fixed else RandomFactories - return CompoundNode( - int_value=factories.make_int(), - location=make_location_node(), - simple=make_simple_node(), - simple_loc=make_simple_node_with_loc(), - simple_opt=make_simple_node_with_optionals(), - other_simple_opt=None, - ) - - -def make_frozen_simple_node(fixed: bool = False) -> FrozenSimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - bool_value = factories.make_bool() - float_value = factories.make_float() - str_value = factories.make_str() - bytes_value = factories.make_str().encode() - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return FrozenSimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) - - -# -- Makers of invalid nodes -- -def make_invalid_location_node(fixed: bool = False) -> LocationNode: - return LocationNode(loc=SourceLocation(line=0, column=-1, source="")) - - -def make_invalid_at_int_simple_node(fixed: bool = False) -> SimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_float() - bool_value = factories.make_bool() - float_value = factories.make_float() - bytes_value = factories.make_str().encode() - str_value = factories.make_str() - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return SimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) - - -def make_invalid_at_float_simple_node(fixed: bool = False) -> SimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - bool_value = factories.make_bool() - float_value = factories.make_int() - str_value = factories.make_str() - bytes_value = factories.make_str().encode() - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return SimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) - - -def make_invalid_at_str_simple_node(fixed: bool = False) -> SimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - bool_value = factories.make_bool() - float_value = factories.make_float() - str_value = factories.make_float() - bytes_value = factories.make_str().encode() - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return SimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) - - -def make_invalid_at_bytes_simple_node(fixed: bool = False) -> SimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - bool_value = factories.make_bool() - float_value = factories.make_float() - str_value = factories.make_float() - bytes_value = [1, "2", (3, 4)] - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return SimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) - - -def make_invalid_at_enum_simple_node(fixed: bool = False) -> SimpleNode: - factories = Factories if fixed else RandomFactories - int_value = factories.make_int() - bool_value = factories.make_bool() - float_value = factories.make_float() - str_value = factories.make_float() - bytes_value = factories.make_str().encode() - int_kind = IntKind.PLUS if fixed else factories.make_member([*IntKind]) - str_kind = StrKind.BLA if fixed else factories.make_member([*StrKind]) - - return SimpleNode( - int_value=int_value, - bool_value=bool_value, - float_value=float_value, - str_value=str_value, - bytes_value=bytes_value, - int_kind=int_kind, - str_kind=str_kind, - ) diff --git a/tests/tests_eve/common_definitions.py b/tests/tests_eve/common_definitions.py new file mode 100644 index 0000000..fd13784 --- /dev/null +++ b/tests/tests_eve/common_definitions.py @@ -0,0 +1,548 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import enum +import random +import string +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Type, TypeVar + +from pydantic import Field, validator # noqa: F401 + +from eve.concepts import FrozenNode, Node, VType +from eve.traits import SymbolTableTrait +from eve.type_definitions import ( + Bool, + Bytes, + Float, + Int, + IntEnum, + SourceLocation, + Str, + StrEnum, + SymbolName, +) +from eve.utils import CaseStyleConverter + + +T = TypeVar("T") +S = TypeVar("S") + + +# -- Node definitions -- +@enum.unique +class IntKind(IntEnum): + """Sample int Enum.""" + + MINUS = -1 + ZERO = 0 + PLUS = 1 + + +@enum.unique +class StrKind(StrEnum): + """Sample string Enum.""" + + FOO = "foo" + BLA = "bla" + FIZ = "fiz" + FUZ = "fuz" + + +SimpleVType = VType("simple") + + +class EmptyNode(Node): + pass + + +class LocationNode(Node): + loc: SourceLocation + + +class SimpleNode(Node): + int_value: Int + bool_value: Bool + float_value: Float + str_value: Str + bytes_value: Bytes + int_kind: IntKind + str_kind: StrKind + + +class SimpleNodeWithOptionals(Node): + int_value: Int + float_value: Optional[Float] + str_value: Optional[Str] + + +class SimpleNodeWithLoc(Node): + int_value: Int + float_value: Float + str_value: Str + loc: Optional[SourceLocation] + + +class SimpleNodeWithCollections(Node): + int_value: Int + int_list: List[Int] + str_set: Set[Str] + str_to_int_dict: Dict[Str, Int] + loc: Optional[SourceLocation] + + +class SimpleNodeWithAbstractCollections(Node): + int_value: Int + int_sequence: Sequence[Int] + str_set: Set[Str] + str_to_int_mapping: Mapping[Str, Int] + loc: Optional[SourceLocation] + + +class SimpleNodeWithSymbolName(Node): + int_value: Int + name: SymbolName + + +class SimpleNodeWithDefaultSymbolName(Node): + int_value: Int + name: SymbolName = SymbolName("symbol_name") + + +class CompoundNode(Node): + int_value: Int + location: LocationNode + simple: SimpleNode + simple_loc: SimpleNodeWithLoc + simple_opt: SimpleNodeWithOptionals + other_simple_opt: Optional[SimpleNodeWithOptionals] + + +class CompoundNodeWithSymbols(Node): + int_value: Int + location: LocationNode + simple: SimpleNode + simple_loc: SimpleNodeWithLoc + simple_opt: SimpleNodeWithOptionals + other_simple_opt: Optional[SimpleNodeWithOptionals] + node_with_name: SimpleNodeWithSymbolName + + +class NodeWithSymbolTable(Node, SymbolTableTrait): + node_with_name: SimpleNodeWithSymbolName + list_with_name: List[SimpleNodeWithSymbolName] + node_with_default_name: SimpleNodeWithDefaultSymbolName + compound_with_name: CompoundNodeWithSymbols + + +class FrozenSimpleNode(FrozenNode): + int_value: Int + bool_value: Bool + float_value: Float + str_value: Str + bytes_value: Bytes + int_kind: IntKind + str_kind: StrKind + + +# -- General maker functions -- +_MIN_INT = -9999 +_MAX_INT = 9999 +_MIN_FLOAT = -999.0 +_MAX_FLOAT = 999.09 +_SEQUENCE_LEN = 6 + + +def make_random_bool_value() -> bool: + return random.choice([True, False]) + + +def make_bool_value(*, fixed: bool = False) -> bool: + return True if fixed else make_random_bool_value() + + +def make_random_int_value() -> int: + return random.randint(_MIN_INT, _MAX_INT) + + +def make_int_value(*, fixed: bool = False) -> int: + return 1 if fixed else make_random_int_value() + + +def make_random_neg_int_value() -> int: + return random.randint(_MIN_INT, 1) + + +def make_neg_int_value(*, fixed: bool = False) -> int: + return -2 if fixed else make_random_neg_int_value() + + +def make_random_pos_int_value() -> int: + return random.randint(1, _MAX_INT) + + +def make_pos_int_value(*, fixed: bool = False) -> int: + return 2 if fixed else make_random_pos_int_value() + + +def make_random_float_value() -> float: + return _MIN_FLOAT + random.random() * (_MAX_FLOAT - _MIN_FLOAT) + + +def make_float_value(*, fixed: bool = False) -> float: + return 1.1 if fixed else make_random_float_value() + + +def make_random_str_value(length: Optional[int] = None) -> str: + length = length or _SEQUENCE_LEN + return "".join(random.choice(string.ascii_letters) for _ in range(length)) + + +def make_str_value(length: int = _SEQUENCE_LEN, *, fixed: bool = False) -> str: + return string.ascii_letters[:length] if fixed else make_random_str_value(length) + + +def make_random_member_value(values: Sequence[T]) -> T: + return random.choice(values) + + +def make_member_value(values: Sequence[T], *, fixed: bool = False) -> T: + return values[0] if fixed else make_random_member_value(values) + + +def make_collection_value( + item_type: Type[T], + collection_type: Type[Collection[T]] = list, + length: Optional[int] = None, + *, + fixed: bool = False, +) -> Collection[T]: + length = length or _SEQUENCE_LEN + + maker_attr_name = f"make_{item_type.__name__}" + try: + maker = globals()[maker_attr_name] + except Exception: + + def maker(): + return item_type() + + return collection_type([maker() for _ in range(length)]) # type: ignore + + +def make_mapping_value( + key_type: Type[S], + value_type: Type[T], + mapping_type: Type[Mapping[S, T]] = dict, + length: Optional[int] = None, + *, + fixed: bool = False, +) -> Mapping[S, T]: + length = length or _SEQUENCE_LEN + + key_maker_attr_name = f"make_{key_type.__name__}" + try: + key_maker = globals()[key_maker_attr_name] + except Exception: + + def key_maker(): + return key_type() + + value_maker_attr_name = f"make_{value_type.__name__}" + try: + value_maker = globals()[value_maker_attr_name] + except Exception: + + def value_maker(): + return value_type() + + return mapping_type({key_maker(): value_maker() for _ in range(length)}) # type: ignore + + +def make_multinode_collection_value( + node_class: Type[Node], + collection_type: Type[Collection[T]] = list, + length: Optional[int] = None, + *, + fixed: bool = False, +) -> Collection[T]: + length = length or _SEQUENCE_LEN + + maker = globals()[f"make_{CaseStyleConverter.convert(node_class.__name__, 'pascal', 'snake')}"] + + return collection_type([maker(fixed=fixed) for _ in range(length)]) # type: ignore + + +# -- Node maker functions -- +def make_source_location(*, fixed: bool = False) -> SourceLocation: + line = make_pos_int_value(fixed=fixed) + column = make_pos_int_value(fixed=fixed) + str_value = make_str_value(fixed=fixed) + source = f"file_{str_value}.py" + + return SourceLocation(line=line, column=column, source=source) + + +def make_empty_node(*, fixed: bool = False) -> LocationNode: + return EmptyNode() + + +def make_location_node(*, fixed: bool = False) -> LocationNode: + return LocationNode(loc=make_source_location(fixed=fixed)) + + +def make_simple_node(*, fixed: bool = False) -> SimpleNode: + int_value = make_int_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + str_value = make_str_value(fixed=fixed) + bytes_value = make_str_value(fixed=fixed).encode() + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind], fixed=fixed) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind], fixed=fixed) + + return SimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) + + +def make_simple_node_with_optionals(*, fixed: bool = False) -> SimpleNodeWithOptionals: + int_value = make_int_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + + return SimpleNodeWithOptionals(int_value=int_value, float_value=float_value) + + +def make_simple_node_with_loc(*, fixed: bool = False) -> SimpleNodeWithLoc: + int_value = make_int_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + str_value = make_str_value(fixed=fixed) + loc = make_source_location(fixed=fixed) + + return SimpleNodeWithLoc( + int_value=int_value, float_value=float_value, str_value=str_value, loc=loc + ) + + +def make_simple_node_with_collections(*, fixed: bool = False) -> SimpleNodeWithCollections: + int_value = make_int_value(fixed=fixed) + int_list = make_collection_value(int, length=3) + str_set = make_collection_value(str, set, length=3) + str_to_int_dict = make_mapping_value(key_type=str, value_type=int, length=3) + loc = make_source_location(fixed=fixed) + + return SimpleNodeWithCollections( + int_value=int_value, + int_list=int_list, + str_set=str_set, + str_to_int_dict=str_to_int_dict, + loc=loc, + ) + + +def make_simple_node_with_abstract_collections( + *, fixed: bool = False, +) -> SimpleNodeWithAbstractCollections: + int_value = make_int_value(fixed=fixed) + int_sequence = make_collection_value(int, collection_type=tuple, length=3) + str_set = make_collection_value(str, set, length=3) + str_to_int_mapping = make_mapping_value(key_type=str, value_type=int, length=3) + + return SimpleNodeWithAbstractCollections( + int_value=int_value, + int_sequence=int_sequence, + str_set=str_set, + str_to_int_mapping=str_to_int_mapping, + ) + + +def make_simple_node_with_symbol_name(*, fixed: bool = False,) -> SimpleNodeWithSymbolName: + int_value = make_int_value(fixed=fixed) + name = make_str_value(fixed=fixed) + + return SimpleNodeWithSymbolName(int_value=int_value, name=name) + + +def make_simple_node_with_default_symbol_name( + *, fixed: bool = False, +) -> SimpleNodeWithDefaultSymbolName: + int_value = make_int_value(fixed=fixed) + + return SimpleNodeWithDefaultSymbolName(int_value=int_value) + + +def make_compound_node(*, fixed: bool = False) -> CompoundNode: + return CompoundNode( + int_value=make_int_value(fixed=fixed), + location=make_location_node(), + simple=make_simple_node(), + simple_loc=make_simple_node_with_loc(), + simple_opt=make_simple_node_with_optionals(), + other_simple_opt=None, + ) + + +def make_compound_node_with_symbols(*, fixed: bool = False) -> CompoundNodeWithSymbols: + return CompoundNodeWithSymbols( + int_value=make_int_value(fixed=fixed), + location=make_location_node(), + simple=make_simple_node(), + simple_loc=make_simple_node_with_loc(), + simple_opt=make_simple_node_with_optionals(), + other_simple_opt=None, + node_with_name=make_simple_node_with_symbol_name(fixed=fixed), + ) + + +def make_node_with_symbol_table(*, fixed: bool = False) -> NodeWithSymbolTable: + return NodeWithSymbolTable( + node_with_name=make_simple_node_with_symbol_name(fixed=fixed), + node_with_default_name=make_simple_node_with_default_symbol_name(fixed=fixed), + list_with_name=make_multinode_collection_value( + SimpleNodeWithSymbolName, length=4, fixed=fixed + ), + compound_with_name=make_compound_node_with_symbols(fixed=fixed), + ) + + +def make_frozen_simple_node(*, fixed: bool = False) -> FrozenSimpleNode: + int_value = make_int_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + str_value = make_str_value(fixed=fixed) + bytes_value = make_str_value(fixed=fixed).encode() + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind], fixed=fixed) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind], fixed=fixed) + + return FrozenSimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) + + +# -- Makers of invalid nodes -- +def make_invalid_location_node(*, fixed: bool = False) -> LocationNode: + return LocationNode(loc=SourceLocation(line=0, column=-1, source="")) + + +def make_invalid_at_int_simple_node(*, fixed: bool = False) -> SimpleNode: + int_value = make_float_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + bytes_value = make_str_value(fixed=fixed).encode() + str_value = make_str_value(fixed=fixed) + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind], fixed=fixed) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind], fixed=fixed) + + return SimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) + + +def make_invalid_at_float_simple_node(*, fixed: bool = False) -> SimpleNode: + int_value = make_int_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_int_value(fixed=fixed) + str_value = make_str_value(fixed=fixed) + bytes_value = make_str_value(fixed=fixed).encode() + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind], fixed=fixed) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind], fixed=fixed) + + return SimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) + + +def make_invalid_at_str_simple_node(*, fixed: bool = False) -> SimpleNode: + int_value = make_int_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + str_value = make_float_value(fixed=fixed) + bytes_value = make_str_value(fixed=fixed).encode() + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind]) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind]) + + return SimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) + + +def make_invalid_at_bytes_simple_node(*, fixed: bool = False) -> SimpleNode: + int_value = make_int_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + str_value = make_float_value(fixed=fixed) + bytes_value = [1, "2", (3, 4)] + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind]) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind]) + + return SimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) + + +def make_invalid_at_enum_simple_node(*, fixed: bool = False) -> SimpleNode: + int_value = make_int_value(fixed=fixed) + bool_value = make_bool_value(fixed=fixed) + float_value = make_float_value(fixed=fixed) + str_value = make_float_value(fixed=fixed) + bytes_value = make_str_value(fixed=fixed).encode() + int_kind = IntKind.PLUS if fixed else make_member_value([*IntKind]) + str_kind = StrKind.BLA if fixed else make_member_value([*StrKind]) + + return SimpleNode( + int_value=int_value, + bool_value=bool_value, + float_value=float_value, + str_value=str_value, + bytes_value=bytes_value, + int_kind=int_kind, + str_kind=str_kind, + ) diff --git a/tests/tests_eve/conftest.py b/tests/tests_eve/conftest.py index 3135cdc..f049fa1 100644 --- a/tests/tests_eve/conftest.py +++ b/tests/tests_eve/conftest.py @@ -20,22 +20,23 @@ import pytest -from . import common +from . import common_definitions NODE_MAKERS = [] FROZEN_NODE_MAKERS = [] INVALID_NODE_MAKERS = [] -# Automatic creation of pytest fixtures from maker functions in .common -for key, value in common.__dict__.items(): + +# Automatic creation of pytest fixtures from maker functions in .common_definitions +for key, value in common_definitions.__dict__.items(): if key.startswith("make_"): name = key[5:] exec( f""" @pytest.fixture def {name}_maker(): - yield common.make_{name} + yield common_definitions.make_{name} @pytest.fixture def {name}({name}_maker): @@ -47,10 +48,10 @@ def fixed_{name}({name}_maker): """ ) - if "node" in key: - if "invalid" in key: + if "_node" in key: + if "_invalid" in key: INVALID_NODE_MAKERS.append(value) - elif "frozen" in key: + elif "_frozen" in key: FROZEN_NODE_MAKERS.append(value) else: NODE_MAKERS.append(value) diff --git a/tests/tests_eve/unit_tests/test_concepts.py b/tests/tests_eve/unit_tests/test_concepts.py index 4d1797b..bd9b76c 100644 --- a/tests/tests_eve/unit_tests/test_concepts.py +++ b/tests/tests_eve/unit_tests/test_concepts.py @@ -22,7 +22,7 @@ import eve -from .. import common +from .. import common_definitions @pytest.fixture( @@ -166,7 +166,7 @@ def test_unique_id(self, sample_node_maker): def test_custom_id(self, source_location, sample_node_maker): custom_id = "my_custom_id" - my_node = common.LocationNode(loc=source_location) + my_node = common_definitions.LocationNode(loc=source_location) other_node = sample_node_maker() my_node.__node_id__ = custom_id diff --git a/tests/tests_eve/unit_tests/test_traits.py b/tests/tests_eve/unit_tests/test_traits.py new file mode 100644 index 0000000..7fa158e --- /dev/null +++ b/tests/tests_eve/unit_tests/test_traits.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# +# Eve Toolchain - GT4Py Project - GridTools Framework +# +# Copyright (c) 2020, CSCS - Swiss National Supercomputing Center, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from __future__ import annotations + +import pytest + +from .. import common_definitions + + +@pytest.fixture +def symtable_node_with_expected_symbols(request): + node = common_definitions.make_node_with_symbol_table() + symbols = { + node.node_with_name.name: node.node_with_name, + node.node_with_default_name.name: node.node_with_default_name, + node.compound_with_name.node_with_name.name: node.compound_with_name.node_with_name, + } + symbols.update({n.name: n for n in node.list_with_name}) + + yield node, symbols + + +class TestSymbolTable: + def test_symbol_table_creation(self, symtable_node_with_expected_symbols): + node, expected_symbols = symtable_node_with_expected_symbols + collected_symtable = node.__node_impl__.symtable + assert isinstance(node.__node_impl__.symtable, dict) + assert all(isinstance(key, str) for key in collected_symtable) + + def test_symbol_table_collection(self, symtable_node_with_expected_symbols): + node, expected_symbols = symtable_node_with_expected_symbols + collected_symtable = node.__node_impl__.symtable + assert collected_symtable == expected_symbols + assert all( + collected_symtable[symbol_name] is symbol_node + for symbol_name, symbol_node in expected_symbols.items() + ) + + # import devtools + # devtools.debug(node) + # print(f"COLLECTED: {collected_symtable}") + # print(f"EXPECTED: {expected_symbols}") diff --git a/tests/tests_eve/unit_tests/test_utils.py b/tests/tests_eve/unit_tests/test_utils.py index 593f1cb..d5b7524 100644 --- a/tests/tests_eve/unit_tests/test_utils.py +++ b/tests/tests_eve/unit_tests/test_utils.py @@ -209,6 +209,10 @@ def test_case_style_converter(name_with_cases): words = name_with_cases.pop("words") for case, cased_string in name_with_cases.items(): + # Try also passing case as a string + if len(case.value) % 2: + case = case.value + assert CaseStyleConverter.join(words, case) == cased_string if case == CaseStyleConverter.CASE_STYLE.CONCATENATED: with pytest.raises(ValueError, match="Impossible to split"): From 8bc2aa37056077fe01d6f4699af24b654c61d9ed Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Fri, 6 Nov 2020 17:23:34 +0100 Subject: [PATCH 16/17] Update mypy version --- .pre-commit-config.yaml | 2 +- requirements_dev.txt | 2 +- src/eve/concepts.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5630e99..a364a10 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: [darglint, flake8-bugbear, flake8-builtins, flake8-docstrings] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.761 + rev: v0.790 hooks: - id: mypy additional_dependencies: [pydantic~=1.0] diff --git a/requirements_dev.txt b/requirements_dev.txt index 52c5e65..33200fb 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -14,7 +14,7 @@ ipython>=7 isort>=4.3.21 jupyter>=1.0 jupyterlab>=2.1 -mypy>=0.761 +mypy>=0.790 pip-tools>=5.1 pre-commit>=2.2 pytest>=5.4 diff --git a/src/eve/concepts.py b/src/eve/concepts.py index a5d5a97..cae584f 100644 --- a/src/eve/concepts.py +++ b/src/eve/concepts.py @@ -264,7 +264,7 @@ class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass): ) #: Node data annotations - __node_impl__: Optional[types.SimpleNamespace] = pydantic.PrivateAttr( # type: ignore # mypy can't find PrivateAttr + __node_impl__: Optional[types.SimpleNamespace] = pydantic.PrivateAttr( default_factory=types.SimpleNamespace ) From 1b872008399acb69e50134eca5cdf82dcd4c9b91 Mon Sep 17 00:00:00 2001 From: "Enrique G. Paredes" Date: Fri, 6 Nov 2020 17:26:47 +0100 Subject: [PATCH 17/17] Add symbol table tests --- src/eve/exceptions.py | 5 +++ src/eve/traits.py | 16 +++++++--- tests/tests_eve/unit_tests/test_traits.py | 38 +++++++++++++++++------ 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/eve/exceptions.py b/src/eve/exceptions.py index 564e362..de76a47 100644 --- a/src/eve/exceptions.py +++ b/src/eve/exceptions.py @@ -42,3 +42,8 @@ class EveValueError(EveError, ValueError): class EveRuntimeError(EveError, RuntimeError): message_template = "Runtime error (info: {info})" + + +# -- Runtime errors -- +class DuplicatedSymbolError(EveRuntimeError): + pass diff --git a/src/eve/traits.py b/src/eve/traits.py index 133b73e..879f8fe 100644 --- a/src/eve/traits.py +++ b/src/eve/traits.py @@ -19,7 +19,7 @@ from __future__ import annotations -from . import concepts, iterators +from . import concepts, exceptions, iterators from .type_definitions import SymbolName from .typingx import Any, Dict @@ -44,15 +44,23 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.collect_symbols() @staticmethod - def _collect_symbols(root_node: concepts.TreeNode) -> Dict[str, Any]: - collected = {} + def _collect_symbols(root_node: concepts.TreeNode) -> Dict[str, concepts.BaseNode]: + collected: Dict[str, concepts.BaseNode] = {} for node in iterators.traverse_tree(root_node): if isinstance(node, concepts.BaseNode): for name, metadata in node.__node_children__.items(): if isinstance(metadata["definition"].type_, type) and issubclass( metadata["definition"].type_, SymbolName ): - collected[getattr(node, name)] = node + symbol_name = getattr(node, name) + if symbol_name in collected: + raise exceptions.DuplicatedSymbolError( + f"""Redefinition of symbol name '{name}': + - previous: {collected[symbol_name]} + - new: {node} +""" + ) + collected[symbol_name] = node return collected diff --git a/tests/tests_eve/unit_tests/test_traits.py b/tests/tests_eve/unit_tests/test_traits.py index 7fa158e..308ddbb 100644 --- a/tests/tests_eve/unit_tests/test_traits.py +++ b/tests/tests_eve/unit_tests/test_traits.py @@ -19,11 +19,14 @@ import pytest +import eve +from eve.typingx import List + from .. import common_definitions @pytest.fixture -def symtable_node_with_expected_symbols(request): +def symtable_node_and_expected_symbols(): node = common_definitions.make_node_with_symbol_table() symbols = { node.node_with_name.name: node.node_with_name, @@ -35,15 +38,31 @@ def symtable_node_with_expected_symbols(request): yield node, symbols +class _NodeWithSymbolName(eve.Node): + name: eve.SymbolName = eve.SymbolName("symbol_name") + + +class _NodeWithSymbolTable(eve.Node, eve.SymbolTableTrait): + symbols: List[_NodeWithSymbolName] + + +@pytest.fixture +def node_with_duplicated_names_maker(): + def _maker(): + return _NodeWithSymbolTable(symbols=[_NodeWithSymbolName(), _NodeWithSymbolName()]) + + yield _maker + + class TestSymbolTable: - def test_symbol_table_creation(self, symtable_node_with_expected_symbols): - node, expected_symbols = symtable_node_with_expected_symbols + def test_symbol_table_creation(self, symtable_node_and_expected_symbols): + node, expected_symbols = symtable_node_and_expected_symbols collected_symtable = node.__node_impl__.symtable assert isinstance(node.__node_impl__.symtable, dict) assert all(isinstance(key, str) for key in collected_symtable) - def test_symbol_table_collection(self, symtable_node_with_expected_symbols): - node, expected_symbols = symtable_node_with_expected_symbols + def test_symbol_table_collection(self, symtable_node_and_expected_symbols): + node, expected_symbols = symtable_node_and_expected_symbols collected_symtable = node.__node_impl__.symtable assert collected_symtable == expected_symbols assert all( @@ -51,7 +70,8 @@ def test_symbol_table_collection(self, symtable_node_with_expected_symbols): for symbol_name, symbol_node in expected_symbols.items() ) - # import devtools - # devtools.debug(node) - # print(f"COLLECTED: {collected_symtable}") - # print(f"EXPECTED: {expected_symbols}") + def test_duplicated_symbols(self, node_with_duplicated_names_maker): + with pytest.raises( + eve.exceptions.DuplicatedSymbolError, match="Redefinition of symbol name" + ): + node_with_duplicated_names_maker()