diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 1b0e995156..31e63bdf9f 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -16,6 +16,7 @@ import dataclasses import functools import sys +import types import typing import warnings @@ -1254,8 +1255,11 @@ def _make_concrete_with_cache( if not is_generic_datamodel_class(datamodel_cls): raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.") for t in type_args: + _accepted_types: tuple[type, ...] = (type, type(None), xtyping.StdGenericAliasType) + if sys.version_info >= (3, 10): + _accepted_types = (*_accepted_types, types.UnionType) if not ( - isinstance(t, (type, type(None), xtyping.StdGenericAliasType)) + isinstance(t, _accepted_types) or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions")) ): raise TypeError( diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index e150832295..695ab69dc3 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -14,6 +14,8 @@ import collections.abc import dataclasses import functools +import sys +import types import typing from . import exceptions, extended_typing as xtyping, utils @@ -193,6 +195,12 @@ def __call__( if type_annotation is None: type_annotation = type(None) + if sys.version_info >= (3, 10): + if isinstance( + type_annotation, types.UnionType + ): # see https://github.com/python/cpython/issues/105499 + type_annotation = typing.Union[type_annotation.__args__] + # Non-generic types if xtyping.is_actual_type(type_annotation): assert not xtyping.get_args(type_annotation) @@ -277,6 +285,7 @@ def __call__( if issubclass(origin_type, (collections.abc.Sequence, collections.abc.Set)): assert len(type_args) == 1 + make_recursive(type_args[0]) if (member_validator := make_recursive(type_args[0])) is None: raise exceptions.EveValueError( f"{type_args[0]} type annotation is not supported." diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index d334487ae1..6b40cbb77f 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, TypeVar, cast +from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -48,7 +48,7 @@ def with_altered_scalar_kind( if isinstance(type_spec, ts.FieldType): return ts.FieldType( dims=type_spec.dims, - dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape), + dtype=with_altered_scalar_kind(type_spec.dtype, new_scalar_kind), ) elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) @@ -68,13 +68,18 @@ def construct_tuple_type( >>> mask_type = ts.FieldType( ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) ... ) - >>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)] + >>> true_branch_types = [ + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), + ... ] >>> false_branch_types = [ - ... ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), - ... ts.ScalarType(kind=ts.ScalarKind), + ... ts.FieldType( + ... 
dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ), + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... ] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) - [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] + [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ element_types_new = true_branch_types for i, element in enumerate(true_branch_types): @@ -105,8 +110,8 @@ def promote_to_mask_type( >>> I, J = (Dimension(value=dim) for dim in ["I", "J"]) >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype)) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) + >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) ... ) @@ -360,7 +365,7 @@ def visit_Assign(self, node: foast.Assign, **kwargs: Any) -> foast.Assign: def visit_TupleTargetAssign( self, node: foast.TupleTargetAssign, **kwargs: Any ) -> foast.TupleTargetAssign: - TargetType = list[foast.Starred | foast.Symbol] + TargetType: TypeAlias = list[foast.Starred | foast.Symbol] values = self.visit(node.value, **kwargs) if isinstance(values.type, ts.TupleType): @@ -374,7 +379,7 @@ def visit_TupleTargetAssign( ) new_targets: TargetType = [] - new_type: ts.TupleType | ts.DataType + new_type: ts.DataType for i, index in enumerate(indices): old_target = targets[i] @@ -391,7 +396,8 @@ def visit_TupleTargetAssign( location=old_target.location, ) else: - new_type = values.type.types[index] + new_type = values.type.types[index] # type: ignore[assignment] # see check in next line + assert isinstance(new_type, ts.DataType) new_target = self.visit( old_target, refine_type=new_type, location=old_target.location, **kwargs ) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3c65695aec..4519b4e571 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -236,6 +236,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") @@ -417,12 +418,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) min_value, _ = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(min_value), dtype) return self._make_reduction_expr(node, "maximum", init_expr, 
**kwargs) def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) _, max_value = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(max_value), dtype) return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 92f7327218..9355273588 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -104,6 +104,15 @@ def visit_Program(self, node: past.Program, **kwargs: Any) -> past.Program: location=node.location, ) + def visit_Slice(self, node: past.Slice, **kwargs: Any) -> past.Slice: + return past.Slice( + lower=self.visit(node.lower, **kwargs), + upper=self.visit(node.upper, **kwargs), + step=self.visit(node.step, **kwargs), + type=ts.DeferredType(constraint=None), + location=node.location, + ) + def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript: value = self.visit(node.value, **kwargs) return past.Subscript( diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 7958b7a8d3..1add668791 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -109,6 +109,7 @@ def _field_constituents_shape_and_dims( match arg_type: case ts.TupleType(): for el, el_type in zip(arg, arg_type.types): + assert isinstance(el_type, ts.DataType) yield from _field_constituents_shape_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 8160a2c42d..83ecf92839 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -169,7 +169,9 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType -------- >>> _scan_param_promotion( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64), + ... ts.FieldType( + ... dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ), ... ) FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ @@ -252,8 +254,8 @@ def function_signature_incompatibilities_scanop( # build a function type to leverage the already existing signature checking capabilities function_type = ts.FunctionType( pos_only_args=[], - pos_or_kw_args=promoted_params, # type: ignore[arg-type] # dict is invariant, but we don't care here. - kw_only_args=promoted_kwparams, # type: ignore[arg-type] # same as above + pos_or_kw_args=promoted_params, + kw_only_args=promoted_kwparams, returns=ts.DeferredType(constraint=None), ) diff --git a/src/gt4py/next/ffront/type_specifications.py b/src/gt4py/next/ffront/type_specifications.py index e4f6c826fe..b76a116297 100644 --- a/src/gt4py/next/ffront/type_specifications.py +++ b/src/gt4py/next/ffront/type_specifications.py @@ -6,23 +6,19 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass import gt4py.next.type_system.type_specifications as ts -from gt4py.next import common as func_common +from gt4py.next import common -@dataclass(frozen=True) class ProgramType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class FieldOperatorType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class ScanOperatorType(ts.TypeSpec, ts.CallableType): - axis: func_common.Dimension + axis: common.Dimension definition: ts.FunctionType diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 13c64e264e..5949d29432 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,7 +54,6 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation @@ -1460,7 +1459,7 @@ class _List(Generic[DT]): def __getitem__(self, i: int): return self.values[i] - def __gt_type__(self) -> itir_ts.ListType: + def __gt_type__(self) -> ts.ListType: offset_tag = self.offset.value assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) @@ -1470,7 +1469,7 @@ def __gt_type__(self) -> itir_ts.ListType: connectivity = offset_provider[offset_tag] assert common.is_neighbor_connectivity(connectivity) local_dim = connectivity.__gt_type__().neighbor_dim - return itir_ts.ListType(element_type=element_type, offset_type=local_dim) + return ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1480,10 +1479,10 @@ class _ConstList(Generic[DT]): def __getitem__(self, _): return self.value - def __gt_type__(self) -> itir_ts.ListType: + def __gt_type__(self) -> ts.ListType: element_type = type_translation.from_value(self.value) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( + return ts.ListType( element_type=element_type, offset_type=_CONST_DIM, ) @@ -1801,7 +1800,7 @@ def _fieldspec_list_to_value( domain: common.Domain, type_: ts.TypeSpec ) -> tuple[common.Domain, ts.TypeSpec]: """Translate the list element type into the domain.""" - if isinstance(type_, itir_ts.ListType): + if isinstance(type_, ts.ListType): if type_.offset_type == _CONST_DIM: return domain.insert( len(domain), common.named_range((_CONST_DIM, 1)) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index b7087472e0..cc42896f2b 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -20,10 +20,7 @@ inline_lifts, trace_shifts, ) -from gt4py.next.iterator.type_system import ( - inference as type_inference, - type_specifications as it_ts, -) +from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -140,7 +137,7 @@ def fuse_as_fieldop( if arg.type and not isinstance(arg.type, ts.DeferredType): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) - assert not isinstance(dtype, it_ts.ListType) + assert not isinstance(dtype, ts.ListType) new_param: str if isinstance( arg, itir.SymRef @@ -246,7 +243,7 @@ def visit_FunCall(self, 
node: itir.FunCall): ) or cpm.is_call_to(arg, "if_") ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) + and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) ) ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 334fb330d7..ac7fcb8f1c 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -98,12 +98,12 @@ def _transform_by_pattern( tmp_expr.type, tuple_constructor=lambda *elements: tuple(elements), ) - tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( - type_info.apply_to_primitive_constituents( - type_info.extract_dtype, - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) + tmp_dtypes: ( + ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] + ) = type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), ) # allocate temporary for all tuple elements diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 1b980783fa..1da59546c0 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -275,8 +275,8 @@ def _get_dimensions(obj: Any): if isinstance(obj, common.Dimension): yield obj elif isinstance(obj, ts.TypeSpec): - for field in dataclasses.fields(obj.__class__): - yield from _get_dimensions(getattr(obj, field.name)) + for field in obj.__datamodel_fields__.keys(): + yield from _get_dimensions(getattr(obj, field)) elif isinstance(obj, collections.abc.Mapping): for el in obj.values(): yield from _get_dimensions(el) @@ -479,7 +479,7 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype, ) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index eef8c75d0f..7825bf1c98 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -6,43 +6,29 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import dataclasses -from typing import Literal, Optional +from typing import Literal from gt4py.next import common from gt4py.next.type_system import type_specifications as ts -@dataclasses.dataclass(frozen=True) class NamedRangeType(ts.TypeSpec): dim: common.Dimension -@dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): dims: list[common.Dimension] | Literal["unknown"] -@dataclasses.dataclass(frozen=True) class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension -@dataclasses.dataclass(frozen=True) -class ListType(ts.DataType): - element_type: ts.DataType - # TODO(havogt): the `offset_type` is not yet used in type_inference, - # it is meant to describe the neighborhood (via the local dimension) - offset_type: Optional[common.Dimension] = None - - -@dataclasses.dataclass(frozen=True) class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] element_type: ts.DataType -@dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..22a04ec04a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -155,18 +155,18 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType @_register_builtin_type_synthesizer -def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: +def make_const_list(scalar: ts.ScalarType) -> ts.ListType: assert isinstance(scalar, ts.ScalarType) - return it_ts.ListType(element_type=scalar) + return ts.ListType(element_type=scalar) @_register_builtin_type_synthesizer -def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType: +def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: ts.ListType) -> ts.DataType: if isinstance(index, it_ts.OffsetLiteralType): assert isinstance(index.value, ts.ScalarType) index = index.value assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) - assert isinstance(list_, it_ts.ListType) + assert isinstance(list_, ts.ListType) return list_.element_type @@ -198,14 +198,14 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) and offset_literal.value.kind == common.DimensionKind.LOCAL ) assert isinstance(it, it_ts.IteratorType) - return it_ts.ListType(element_type=it.element_type) + return ts.ListType(element_type=it.element_type) @_register_builtin_type_synthesizer @@ -270,7 +270,7 @@ def _convert_as_fieldop_input_to_iterator( else: defined_dims.append(dim) if is_nb_field: - element_type = it_ts.ListType(element_type=element_type) + element_type = ts.ListType(element_type=element_type) return it_ts.IteratorType( position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type @@ -342,14 +342,14 @@ def apply_scan( def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType - ) -> 
it_ts.ListType: + *args: ts.ListType, offset_provider_type: common.OffsetProviderType + ) -> ts.ListType: assert len(args) > 0 - assert all(isinstance(arg, it_ts.ListType) for arg in args) + assert all(isinstance(arg, ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) - return it_ts.ListType(element_type=el_type) + return ts.ListType(element_type=el_type) return applied_map @@ -357,8 +357,8 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): - assert all(isinstance(arg, it_ts.ListType) for arg in args) + def applied_reduce(*args: ts.ListType, offset_provider_type: common.OffsetProviderType): + assert all(isinstance(arg, ts.ListType) for arg in args) return op( init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 24913a1365..edd56fad48 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) + assert isinstance(type_.dtype, ts.ScalarType) dtype = cpp_interface.render_scalar_type(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index d5b34fd5b9..f7bb1805e0 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -701,7 +701,7 @@ def visit_Temporary( def dtype_to_cpp(x: ts.DataType) -> str: if isinstance(x, ts.TupleType): assert all(isinstance(i, ts.ScalarType) for i in x.types) - return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" + return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert assert isinstance(x, ts.ScalarType) res = pytype_to_cpptype(x) assert isinstance(res, str) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index cffbd74c90..354a9692d8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -23,7 +23,6 @@ domain_utils, ir_makers as im, ) -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, @@ -119,7 +118,7 @@ def get_local_view( ) elif len(local_dims) == 1: - field_dtype = itir_ts.ListType( + field_dtype = ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) field_domain = [ @@ -267,10 +266,11 @@ def _create_field_operator( if isinstance(output_edge.result.gt_dtype, 
ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert isinstance(node_type.dtype, ts.ScalarType) assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) field_dtype = output_edge.result.gt_dtype else: - assert isinstance(node_type.dtype, itir_ts.ListType) + assert isinstance(node_type.dtype, ts.ListType) assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type assert isinstance(dataflow_output_desc, dace.data.Array) assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index a3653fb519..0376143883 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -31,7 +31,6 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, @@ -64,7 +63,7 @@ class ValueExpr: """ dc_node: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType @dataclasses.dataclass(frozen=True) @@ -79,7 +78,7 @@ class MemletExpr: """ dc_node: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType subset: dace_subsets.Range @@ -112,7 +111,7 @@ class IteratorExpr: """ field: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] @@ -121,7 +120,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: raise ValueError(f"Cannot deref iterator {self}.") field_desc = self.field.desc(sdfg) - if isinstance(self.gt_dtype, itir_ts.ListType): + if isinstance(self.gt_dtype, ts.ListType): assert len(field_desc.shape) == len(self.field_domain) + 1 assert self.gt_dtype.offset_type is not None field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] @@ -444,7 +443,7 @@ def _construct_tasklet_result( return ValueExpr( dc_node=temp_node, gt_dtype=( - itir_ts.ListType(element_type=data_type, offset_type=_CONST_DIM) + ts.ListType(element_type=data_type, offset_type=_CONST_DIM) if use_array else data_type ), @@ -547,7 +546,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: - assert isinstance(node.type, itir_ts.ListType) + assert isinstance(node.type, ts.ListType) assert len(node.args) == 2 assert isinstance(node.args[0], gtir.OffsetLiteral) @@ -650,7 +649,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) return ValueExpr( - dc_node=neighbors_node, gt_dtype=itir_ts.ListType(node.type.element_type, offset_type) + dc_node=neighbors_node, gt_dtype=ts.ListType(node.type.element_type, offset_type) ) def _visit_map(self, node: gtir.FunCall) -> ValueExpr: @@ -669,7 +668,7 @@ def _visit_map(self, node: gtir.FunCall) -> 
ValueExpr: In above example, the result would be an array with size V2E.max_neighbors, containing the V2E neighbor values incremented by 1.0. """ - assert isinstance(node.type, itir_ts.ListType) + assert isinstance(node.type, ts.ListType) assert isinstance(node.fun, gtir.FunCall) assert len(node.fun.args) == 1 # the operation to be mapped on the arguments @@ -689,7 +688,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gtx_common.Dimension, gtx_common.NeighborConnectivityType ] = {} for input_arg in input_args: - assert isinstance(input_arg.gt_dtype, itir_ts.ListType) + assert isinstance(input_arg.gt_dtype, ts.ListType) assert input_arg.gt_dtype.offset_type is not None offset_type = input_arg.gt_dtype.offset_type if offset_type == _CONST_DIM: @@ -759,7 +758,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: connectivity_slice = self._construct_local_view( MemletExpr( dc_node=self.state.add_access(connectivity), - gt_dtype=itir_ts.ListType( + gt_dtype=ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), subset=dace_subsets.Range.from_string( @@ -798,7 +797,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: return ValueExpr( dc_node=result_node, - gt_dtype=itir_ts.ListType(node.type.element_type, offset_type), + gt_dtype=ts.ListType(node.type.element_type, offset_type), ) def _make_reduce_with_skip_values( @@ -825,7 +824,7 @@ def _make_reduce_with_skip_values( origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( - isinstance(input_expr.gt_dtype, itir_ts.ListType) + isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type @@ -938,7 +937,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: input_expr = self.visit(node.args[0]) assert isinstance(input_expr, (MemletExpr, ValueExpr)) assert ( - isinstance(input_expr.gt_dtype, itir_ts.ListType) + isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type @@ -1232,7 +1231,7 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: connector, ) - if isinstance(node.type, itir_ts.ListType): + if isinstance(node.type, ts.ListType): # The only builtin function (so far) handled here that returns a list # is 'make_const_list'. There are other builtin functions (map_, neighbors) # that return a list but they are handled in specialized visit methods. 
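
Illustrative standalone sketch, not part of the patch: the hunks above replace every `itir_ts.ListType` reference with `ts.ListType`, i.e. the list type now lives in the base type system (`gt4py.next.type_system.type_specifications`) instead of the iterator-specific module, so backend code such as this DaCe dataflow translator needs no extra import. Field names follow the `ListType` definition introduced further down in this diff; the `V2E` dimension below is only an example value.

from gt4py.next import common
from gt4py.next.type_system import type_specifications as ts

# A local ("neighbor list") dimension, e.g. a V2E neighborhood.
v2e = common.Dimension("V2E", kind=common.DimensionKind.LOCAL)

# After this change, ListType is a regular member of the base type system.
nb_list = ts.ListType(
    element_type=ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
    offset_type=v2e,  # optional local dimension describing the neighborhood
)
assert isinstance(nb_list, ts.DataType)
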
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 9bd40f75f8..10895ce66e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -290,6 +290,7 @@ def _add_storage( # represent zero-dimensional fields as scalar arguments return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions + assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) if tuple_name is None: # Use symbolic shape, which allows to invoke the program with fields of different size; diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 118f0449c8..c46420c24b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -45,17 +45,18 @@ def get_tuple_fields( ... ("a_1_1", sty), ... ] """ + assert all(isinstance(t, ts.DataType) for t in tuple_type.types) fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] if flatten: - expanded_fields = [ + expanded_fields: list[list[tuple[str, ts.DataType]]] = [ get_tuple_fields(field_name, field_type) if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] + else [(field_name, field_type)] # type: ignore[list-item] # checked in assert for field_name, field_type in fields ] return list(itertools.chain(*expanded_fields)) else: - return fields + return fields # type: ignore[return-value] # checked in assert def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 66f8937dc5..983063a9cb 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -78,15 +78,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: >>> type_class(ts.TupleType(types=[])).__name__ 'TupleType' """ - match symbol_type: - case ts.DeferredType(constraint): - if constraint is None: - raise ValueError(f"No type information available for '{symbol_type}'.") - elif isinstance(constraint, tuple): - raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") - return constraint - case ts.TypeSpec() as concrete_type: - return concrete_type.__class__ + if isinstance(symbol_type, ts.DeferredType): + constraint = symbol_type.constraint + if constraint is None: + raise ValueError(f"No type information available for '{symbol_type}'.") + elif isinstance(constraint, tuple): + raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") + return constraint + if isinstance(symbol_type, ts.TypeSpec): + return symbol_type.__class__ raise ValueError( f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'." ) @@ -197,7 +197,7 @@ def apply_to_primitive_constituents( return fun(*symbol_types) -def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: +def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: """ Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`. 
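
Illustrative sketch of the caller-side pattern implied by widening `extract_dtype` to return `ts.ScalarType | ts.ListType`: callers narrow the result with `isinstance` before reading `.kind`, exactly as the `is_floating_point` and `is_logical` changes in the following hunks do. The helper name `is_float64` is invented here for illustration and is not part of the library.

from gt4py.next.type_system import type_info, type_specifications as ts

def is_float64(symbol_type: ts.TypeSpec) -> bool:
    # Narrow first: a ListType result has no `.kind` attribute.
    dtype = type_info.extract_dtype(symbol_type)
    return isinstance(dtype, ts.ScalarType) and dtype.kind is ts.ScalarKind.FLOAT64

assert is_float64(ts.ScalarType(kind=ts.ScalarKind.FLOAT64))
assert not is_float64(ts.ScalarType(kind=ts.ScalarKind.INT32))
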
@@ -234,7 +234,10 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + return isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind in [ + ts.ScalarKind.FLOAT32, + ts.ScalarKind.FLOAT64, + ] def is_integer(symbol_type: ts.TypeSpec) -> bool: @@ -295,7 +298,10 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: - return extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL + return ( + isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + and dtype.kind is ts.ScalarKind.BOOL + ) def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: @@ -385,11 +391,10 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: >>> extract_dims(ts.FieldType(dims=[I, J], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64))) [Dimension(value='I', kind=), Dimension(value='J', kind=)] """ - match symbol_type: - case ts.ScalarType(): - return [] - case ts.FieldType(dims): - return dims + if isinstance(symbol_type, ts.ScalarType): + return [] + if isinstance(symbol_type, ts.FieldType): + return symbol_type.dims raise ValueError(f"Can not extract dimensions from '{symbol_type}'.") @@ -502,7 +507,9 @@ def promote( return types[0] elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types): dims = common.promote_dims(*(extract_dims(type_) for type_ in types)) - dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types))) + extracted_dtypes = [extract_dtype(type_) for type_ in types] + assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes) + dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType` return ts.FieldType(dims=dims, dtype=dtype) raise TypeError("Expected a 'FieldType' or 'ScalarType'.") diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index fa8c9b9ab1..060d56aea2 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,21 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass from typing import Iterator, Optional, Sequence, Union -from gt4py.eve.type_definitions import IntEnum -from gt4py.eve.utils import content_hash -from gt4py.next import common as func_common +from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types +from gt4py.next import common -@dataclass(frozen=True) -class TypeSpec: - def __hash__(self) -> int: - return hash(content_hash(self)) - - def __init_subclass__(cls) -> None: - cls.__hash__ = TypeSpec.__hash__ # type: ignore[method-assign] +class TypeSpec(eve_datamodels.DataModel, kw_only=False, frozen=True): ... # type: ignore[call-arg] class DataType(TypeSpec): @@ -40,14 +32,12 @@ class CallableType: """ -@dataclass(frozen=True) class DeferredType(TypeSpec): """Dummy used to represent a type not yet inferred.""" constraint: Optional[type[TypeSpec] | tuple[type[TypeSpec], ...]] -@dataclass(frozen=True) class VoidType(TypeSpec): """ Return type of a function without return values. 
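
A minimal sketch (hypothetical class, not from the codebase) of what switching `TypeSpec` from `@dataclass(frozen=True)` to an eve `DataModel` base buys the subclasses in this file: each subclass only declares annotated fields and inherits construction, immutability, and runtime type validation of those fields, with no per-class decorator. Positional construction follows from `kw_only=False`, consistent with the `UnknownPythonObject(value)` call below.

from typing import Optional
from gt4py.eve import datamodels as eve_datamodels

class MySpec(eve_datamodels.DataModel, kw_only=False, frozen=True):
    name: str
    rank: Optional[int] = None

s = MySpec("field", 2)   # positional construction allowed because kw_only=False
# s.rank = 3             # would raise: instances are frozen
# MySpec("field", "2")   # would raise: field types are validated at construction
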
@@ -56,22 +46,20 @@ class VoidType(TypeSpec): """ -@dataclass(frozen=True) class DimensionType(TypeSpec): - dim: func_common.Dimension + dim: common.Dimension -@dataclass(frozen=True) class OffsetType(TypeSpec): # TODO(havogt): replace by ConnectivityType - source: func_common.Dimension - target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] + source: common.Dimension + target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] def __str__(self) -> str: return f"Offset[{self.source}, {self.target}]" -class ScalarKind(IntEnum): +class ScalarKind(eve_types.IntEnum): BOOL = 1 INT32 = 32 INT64 = 64 @@ -80,7 +68,6 @@ class ScalarKind(IntEnum): STRING = 3001 -@dataclass(frozen=True) class ScalarType(DataType): kind: ScalarKind shape: Optional[list[int]] = None @@ -92,31 +79,43 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" -@dataclass(frozen=True) -class TupleType(DataType): - types: list[DataType] - - def __str__(self) -> str: - return f"tuple[{', '.join(map(str, self.types))}]" +class ListType(DataType): + """Represents a neighbor list in the ITIR representation. - def __iter__(self) -> Iterator[DataType]: - yield from self.types + Note: not used in the frontend. + """ - def __len__(self) -> int: - return len(self.types) + element_type: DataType + # TODO(havogt): the `offset_type` is not yet used in type_inference, + # it is meant to describe the neighborhood (via the local dimension) + offset_type: Optional[common.Dimension] = None -@dataclass(frozen=True) class FieldType(DataType, CallableType): - dims: list[func_common.Dimension] - dtype: ScalarType + dims: list[common.Dimension] + dtype: ScalarType | ListType def __str__(self) -> str: dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" -@dataclass(frozen=True) +class TupleType(DataType): + # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously + # introduced before we checked the annotations at runtime. All attributes of + # a type that are types themselves must be concrete. 
+ types: list[DataType | DimensionType | DeferredType] + + def __str__(self) -> str: + return f"tuple[{', '.join(map(str, self.types))}]" + + def __iter__(self) -> Iterator[DataType | DimensionType | DeferredType]: + yield from self.types + + def __len__(self) -> int: + return len(self.types) + + class FunctionType(TypeSpec, CallableType): pos_only_args: Sequence[TypeSpec] pos_or_kw_args: dict[str, TypeSpec] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 62a6781316..e601556e55 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -10,7 +10,6 @@ import builtins import collections.abc -import dataclasses import functools import types import typing @@ -105,7 +104,7 @@ def from_type_hint( raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") tuple_types = [recursive_make_symbol(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) - return ts.TupleType(types=tuple_types) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=tuple_types) case common.Field: if (n_args := len(args)) != 2: @@ -168,7 +167,6 @@ def from_type_hint( raise ValueError(f"'{type_hint}' type is not supported.") -@dataclasses.dataclass(frozen=True) class UnknownPythonObject(ts.TypeSpec): _object: Any @@ -217,9 +215,9 @@ def from_value(value: Any) -> ts.TypeSpec: # not needed anymore. elems = [from_value(el) for el in value] assert all(isinstance(elem, ts.DataType) for elem in elems) - return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=elems) elif isinstance(value, types.ModuleType): - return UnknownPythonObject(_object=value) + return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) symbol_type = from_type_hint(type_) diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 05be5f3db0..75b07fd8a0 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -10,9 +10,9 @@ import enum import numbers +import sys import types import typing -from typing import Set # noqa: F401 [unused-import] used in exec() context from typing import ( Any, Callable, @@ -26,6 +26,7 @@ MutableSequence, Optional, Sequence, + Set, # noqa: F401 [unused-import] used in exec() context Tuple, Type, TypeVar, @@ -555,6 +556,18 @@ class WrongModel: ("typing.MutableSequence[int]", ([1, 2, 3], []), ((1, 2, 3), tuple(), 1, [1.0], {1})), ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), + pytest.param( + "int | float | str", + [1, 3.0, "one"], + [[1], [], 1j], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), + pytest.param( + "typing.List[int|float]", + [[1, 2.0], []], + [1, 2.0, [1, "2.0"]], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7eb4e86adb..b6b70af07c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ 
b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -46,8 +46,8 @@ bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -float64_list_type = it_ts.ListType(element_type=float64_type) -int_list_type = it_ts.ListType(element_type=int_type) +float64_list_type = ts.ListType(element_type=float64_type) +int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) @@ -77,8 +77,8 @@ def expression_test_cases(): (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), (im.if_(True, 1, 2), int_type), - (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), - (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), + (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), + (im.call("list_get")(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), it_ts.NamedRangeType(dim=Vertex), @@ -110,7 +110,7 @@ def expression_test_cases(): # neighbors ( im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), - it_ts.ListType(element_type=it_on_e_of_e_type.element_type), + ts.ListType(element_type=it_on_e_of_e_type.element_type), ), # cast (im.call("cast_")(1, "int32"), int_type),
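
Closing illustration, a minimal sketch of the PEP 604 handling added in `eve/type_validation.py` and exercised by the new `int | float | str` test cases above: on Python 3.10+ an `X | Y` annotation is a `types.UnionType`, which the existing typing-introspection paths do not accept, so it is normalized to the equivalent `typing.Union` first (see https://github.com/python/cpython/issues/105499). The function name `normalize_annotation` is chosen here for illustration only; the patch performs this conversion inline.

import sys
import types
import typing

def normalize_annotation(annotation):
    # `int | float` is a types.UnionType on 3.10+; rebuild the typing.Union
    # spelling from __args__ so downstream typing helpers accept it.
    if sys.version_info >= (3, 10) and isinstance(annotation, types.UnionType):
        annotation = typing.Union[annotation.__args__]
    return annotation

if sys.version_info >= (3, 10):
    assert normalize_annotation(int | float) == typing.Union[int, float]
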