diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index c2a872c1c4..349089ebfa 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -50,8 +50,6 @@ stages: CUPY_PACKAGE: cupy-cuda12x CUPY_VERSION: 13.3.0 UBUNTU_VERSION: 22.04 - # TODO: enable CI job when Todi is back in operational state - when: manual build_py311_baseimage_x86_64: extends: .build_baseimage_x86_64 @@ -133,7 +131,7 @@ build_py310_image_aarch64: VARIANT: [-nomesh, -atlas] SUBVARIANT: [-cuda11x, -cpu] .test_helper_aarch64: - extends: [.container-runner-todi-gh200, .test_helper] + extends: [.container-runner-daint-gh200, .test_helper] parallel: matrix: - SUBPACKAGE: [cartesian, storage] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index d187095019..d1631a461d 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -80,7 +80,9 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[common.OffsetProviderType] = None + connectivities: Optional[common.OffsetProvider] = ( + None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information + ) @classmethod def from_function( @@ -304,7 +306,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: try: - from gt4py.next.program_processors.runners.dace_iterator import Program + from gt4py.next.program_processors.runners.dace_fieldview.program import Program except ImportError: pass diff --git a/src/gt4py/next/iterator/transforms/extractors.py b/src/gt4py/next/iterator/transforms/extractors.py new file mode 100644 index 0000000000..04c2b09139 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/extractors.py @@ -0,0 +1,72 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts + + +class SymbolNameSetExtractor(eve.NodeVisitor): + """Extract a set of symbol names""" + + def visit_Literal(self, node: itir.Literal) -> set[str]: + return set() + + def generic_visitor(self, node: itir.Node) -> set[str]: + input_fields: set[str] = set() + for child in eve.trees.iter_children_values(node): + input_fields |= self.visit(child) + return input_fields + + def visit_Node(self, node: itir.Node) -> set[str]: + return set() + + def visit_Program(self, node: itir.Program) -> set[str]: + names = set() + for stmt in node.body: + names |= self.visit(stmt) + return names + + def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: + names = set() + for stmt in node.true_branch + node.false_branch: + names |= self.visit(stmt) + return names + + def visit_Temporary(self, node: itir.Temporary) -> set[str]: + return set() + + def visit_SymRef(self, node: itir.SymRef) -> set[str]: + return {str(node.id)} + + @classmethod + def only_fields(cls, program: itir.Program) -> set[str]: + field_param_names = [ + str(param.id) for param in program.params if isinstance(param.type, ts.FieldType) + ] + return {name for name in cls().visit(program) if name in field_param_names} + + +class InputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names passed into field operators within a program.""" + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.expr) + + def visit_FunCall(self, node: itir.FunCall) -> set[str]: + input_fields = set() + for arg in node.args: + input_fields |= self.visit(arg) + return input_fields + + +class OutputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names written to within a program""" + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.target) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/program.py b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py new file mode 100644 index 0000000000..7f809152c5 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/program.py @@ -0,0 +1,248 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import collections +import dataclasses +import itertools +import typing +from typing import Any, ClassVar, Optional, Sequence + +import dace +import numpy as np + +from gt4py.next import backend as next_backend, common +from gt4py.next.ffront import decorator +from gt4py.next.iterator import ir as itir, transforms as itir_transforms +from gt4py.next.iterator.transforms import extractors as extractors +from gt4py.next.otf import arguments, recipes, toolchain +from gt4py.next.program_processors.runners.dace_common import utility as dace_utils +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass(frozen=True) +class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): + """Extension of GT4Py Program implementing the SDFGConvertible interface via GTIR.""" + + sdfg_closure_cache: dict[str, Any] = dataclasses.field(default_factory=dict) + # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, + # there is no name mangling of the connectivity tables used across the nested SDFGs + # since they share the same memory address. + connectivity_tables_data_descriptors: ClassVar[ + dict[str, dace.data.Array] + ] = {} # symbolically defined + + def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: + if (self.backend is None) or "dace" not in self.backend.name.lower(): + raise ValueError("The SDFG can be generated only for the DaCe backend.") + + offset_provider: common.OffsetProvider = { + **(self.connectivities or {}), + **self._implicit_offset_provider, + } + column_axis = kwargs.get("column_axis", None) + + # TODO(ricoh): connectivity tables required here for now. + gtir_stage = typing.cast(next_backend.Transforms, self.backend.transforms).past_to_itir( + toolchain.CompilableProgram( + data=self.past_stage, + args=arguments.CompileTimeArgs( + args=tuple(p.type for p in self.past_stage.past_node.params), + kwargs={}, + column_axis=column_axis, + offset_provider=offset_provider, + ), + ) + ) + program = gtir_stage.data + program = itir_transforms.apply_fieldview_transforms( # run the transforms separately because they require the runtime info + program, offset_provider=offset_provider + ) + object.__setattr__( + gtir_stage, + "data", + program, + ) + object.__setattr__( + gtir_stage.args, "offset_provider", gtir_stage.args.offset_provider_type + ) # TODO(ricoh): currently this is circumventing the frozenness of CompileTimeArgs + # in order to isolate DaCe from the runtime tables in connectivities.offset_provider. + # These are needed at the time of writing for mandatory GTIR passes. + # Remove this as soon as Program does not expect connectivity tables anymore. + + _crosscheck_dace_parsing( + dace_parsed_args=[*args, *kwargs.values()], + gt4py_program_args=[p.type for p in program.params], + ) + + compile_workflow = typing.cast( + recipes.OTFCompileWorkflow, + self.backend.executor + if not hasattr(self.backend.executor, "step") + else self.backend.executor.step, + ) # We know which backend we are using, but we don't know if the compile workflow is cached. + # TODO(ricoh): switch 'itir_transforms_off=True' because we ran them separately previously + # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with + # the other parts of the workaround when possible. + sdfg = dace.SDFG.from_json( + compile_workflow.translation.replace(itir_transforms_off=True)(gtir_stage).source_code + ) + + self.sdfg_closure_cache["arrays"] = sdfg.arrays + + # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, + # offset_providers_per_input_field. Add them as dynamic attributes to the SDFG + field_params = { + str(param.id): param for param in program.params if isinstance(param.type, ts.FieldType) + } + + def single_horizontal_dim_per_field( + fields: typing.Iterable[itir.Sym], + ) -> typing.Iterator[tuple[str, common.Dimension]]: + for field in fields: + assert isinstance(field.type, ts.FieldType) + horizontal_dims = [ + dim for dim in field.type.dims if dim.kind is common.DimensionKind.HORIZONTAL + ] + # do nothing for fields with multiple horizontal dimensions + # or without horizontal dimensions + # this is only meant for use with unstructured grids + if len(horizontal_dims) == 1: + yield str(field.id), horizontal_dims[0] + + input_fields = ( + field_params[name] for name in extractors.InputNamesExtractor.only_fields(program) + ) + sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields)) + + output_fields = ( + field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program) + ) + sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields)) + + # TODO (ricoh): bring back sdfg.offset_providers_per_input_field. + # A starting point would be to use the "trace_shifts" pass on GTIR + # and associate the extracted shifts with each input field. + # Analogous to the version in `runners.dace_iterator.__init__`, which + # was removed when merging #1742. + + return sdfg + + def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: + """ + Return the closure arrays of the SDFG represented by this object + as a mapping between array name and the corresponding value. + + The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. + The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that + the offset providers are not part of GT4Py Program's arguments. + Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. + """ + closure_dict: dict[str, Any] = {} + + if self.connectivities: + symbols = {} + with_table = [ + name for name, conn in self.connectivities.items() if common.is_neighbor_table(conn) + ] + in_arrays_with_id = [ + (name, conn_id) + for name in with_table + if (conn_id := dace_utils.connectivity_identifier(name)) + in self.sdfg_closure_cache["arrays"] + ] + in_arrays = (name for name, _ in in_arrays_with_id) + name_axis = list(itertools.product(in_arrays, [0, 1])) + + def size_symbol_name(name: str, axis: int) -> str: + return dace_utils.field_size_symbol_name( + dace_utils.connectivity_identifier(name), axis + ) + + connectivity_tables_size_symbols = { + (sname := size_symbol_name(name, axis)): dace.symbol(sname) + for name, axis in name_axis + } + + def stride_symbol_name(name: str, axis: int) -> str: + return dace_utils.field_stride_symbol_name( + dace_utils.connectivity_identifier(name), axis + ) + + connectivity_table_stride_symbols = { + (sname := stride_symbol_name(name, axis)): dace.symbol(sname) + for name, axis in name_axis + } + + symbols = connectivity_tables_size_symbols | connectivity_table_stride_symbols + + # Define the storage location (e.g. CPU, GPU) of the connectivity tables + if "storage" not in self.connectivity_tables_data_descriptors: + for _, conn_id in in_arrays_with_id: + self.connectivity_tables_data_descriptors["storage"] = self.sdfg_closure_cache[ + "arrays" + ][conn_id].storage + break + + # Build the closure dictionary + for name, conn_id in in_arrays_with_id: + if conn_id not in self.connectivity_tables_data_descriptors: + conn = self.connectivities[name] + assert common.is_neighbor_table(conn) + self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( + dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type), + shape=[ + symbols[dace_utils.field_size_symbol_name(conn_id, 0)], + symbols[dace_utils.field_size_symbol_name(conn_id, 1)], + ], + strides=[ + symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], + symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], + ], + storage=Program.connectivity_tables_data_descriptors["storage"], + ) + closure_dict[conn_id] = self.connectivity_tables_data_descriptors[conn_id] + + return closure_dict + + def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: + return [p.id for p in self.past_stage.past_node.params], [] + + +def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None: + for dace_parsed_arg, gt4py_program_arg in zip( + dace_parsed_args, + gt4py_program_args, + strict=False, # dace does not see implicit size args + ): + match dace_parsed_arg: + case dace.data.Scalar(): + assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) + case bool() | np.bool_(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind == ts.ScalarKind.BOOL + case int() | np.integer(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] + case float() | np.floating(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + case str() | np.str_(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind == ts.ScalarKind.STRING + case dace.data.Array(): + assert isinstance(gt4py_program_arg, ts.FieldType) + assert isinstance(gt4py_program_arg.dtype, ts.ScalarType) + assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) + assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) + case dace.data.Structure() | dict() | collections.OrderedDict(): + # offset provider + pass + case _: + raise ValueError( + f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}" + ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index a38a50d886..779dc8a1c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -37,6 +37,7 @@ class DaCeTranslator( ): device_type: core_defs.DeviceType auto_optimize: bool + itir_transforms_off: bool = False def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -51,7 +52,8 @@ def generate_sdfg( auto_opt: bool, on_gpu: bool, ) -> dace.SDFG: - ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + if not self.itir_transforms_off: + ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) sdfg = gtir_sdfg.build_sdfg_from_gtir( ir, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 5a43144b4b..57c52eae12 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -583,6 +583,11 @@ def test_K_offset_write(backend): if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") + if backend == "dace:gpu": + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1684" + ) + arraylib = get_array_library(backend) array_shape = (1, 1, 4) K_values = arraylib.arange(start=40, stop=44) diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 08904c06f3..cd71c306eb 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -12,15 +12,16 @@ import gt4py.next as gtx from gt4py.next import allocators as gtx_allocators, common as gtx_common +from gt4py._core import definitions as core_defs from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case, unstructured_case +from next_tests.integration_tests.cases import cartesian_case, unstructured_case # noqa: F401 from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( E2V, E2VDim, Edge, Vertex, - exec_alloc_descriptor, - mesh_descriptor, + exec_alloc_descriptor, # noqa: F401 + mesh_descriptor, # noqa: F401 ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, @@ -37,23 +38,17 @@ pytestmark = pytest.mark.requires_dace -def test_sdfgConvertible_laplap(cartesian_case): +def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(ricoh): enable test after adding GTIR support - pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") - - allocator, backend = unstructured_case.allocator, unstructured_case.backend - - if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): - import cupy as xp - else: - import numpy as xp + backend = cartesian_case.backend in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() + xp = in_field.array_ns + # Test DaCe closure support @dace.program def sdfg(): @@ -88,16 +83,13 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed @pytest.mark.uses_unstructured_shift -def test_sdfgConvertible_connectivities(unstructured_case): +def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - # TODO(ricoh): enable test after adding GTIR support - pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") - allocator, backend = unstructured_case.allocator, unstructured_case.backend - if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): + if gtx_allocators.is_field_allocator_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global @@ -113,6 +105,15 @@ def test_sdfgConvertible_connectivities(unstructured_case): name="OffsetProvider", ) + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, + ) + + testee2 = testee.with_backend(backend).with_connectivities({"E2V": e2v}) + @dace.program def sdfg( a: dace.data.Array(dtype=dace.float64, shape=(rows,), storage=dace_storage_type), @@ -120,17 +121,10 @@ def sdfg( offset_provider: OffsetProvider_t, connectivities: dace.compiletime, ): - testee.with_backend(backend).with_connectivities(connectivities)( - a, out, offset_provider=offset_provider - ) + testee2.with_connectivities(connectivities)(a, out, offset_provider=offset_provider) + return out - e2v = gtx.as_connectivity( - [Edge, E2VDim], - codomain=Vertex, - data=xp.asarray([[0, 1], [1, 2], [2, 0]]), - allocator=allocator, - ) - connectivities = {"E2V": e2v.__gt_type__()} + connectivities = {"E2V": e2v} # replace 'e2v' with 'e2v.__gt_type__()' when GTIR is AOT offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -138,23 +132,21 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) - e2v_ndarray_copy = ( - e2v.ndarray.copy() - ) # otherwise DaCe complains about the gt4py custom allocated view - # This is a low level interface to call the compiled SDFG. - # It is not supposed to be used in user code. - # The high level interface should be provided by a DaCe Orchestrator, - # i.e. decorator that hides the low level operations. - # This test checks only that the SDFGConvertible interface works correctly. + + def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> int: + # NumPy strides: number of bytes to jump + # DaCe strides: number of elements to jump + return arg.strides[axis] // arg.itemsize + cSDFG( a, out, offset_provider, rows=3, cols=2, - connectivity_E2V=e2v_ndarray_copy, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), ) e2v_np = e2v.asnumpy() @@ -166,18 +158,19 @@ def sdfg( data=xp.asarray([[1, 0], [2, 1], [0, 2]]), allocator=allocator, ) - e2v_ndarray_copy = e2v.ndarray.copy() offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) - cSDFG( - a, - out, - offset_provider, - rows=3, - cols=2, - connectivity_E2V=e2v_ndarray_copy, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), - ) + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + cSDFG( + a, + out, + offset_provider, + rows=3, + cols=2, + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), + ) e2v_np = e2v.asnumpy() assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py new file mode 100644 index 0000000000..db0f90b409 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -0,0 +1,134 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py import next as gtx +from gt4py.next import common + +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + Cell, + Edge, + IDim, + JDim, + KDim, + Vertex, + mesh_descriptor, # noqa: F401 +) + + +try: + import dace + + from gt4py.next.program_processors.runners import dace as dace_backends +except ImportError: + from types import ModuleType + from typing import Optional + + from gt4py.next import backend as next_backend + + dace: Optional[ModuleType] = None + dace_backends: Optional[ModuleType] = None + + +@pytest.fixture( + params=[ + pytest.param(dace_backends.run_dace_cpu, marks=pytest.mark.requires_dace), + pytest.param( + dace_backends.run_dace_gpu, marks=(pytest.mark.requires_gpu, pytest.mark.requires_dace) + ), + ] +) +def gtir_dace_backend(request): + yield request.param + + +@pytest.fixture +def cartesian(request, gtir_dace_backend): + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider={ + "Ioff": IDim, + "Joff": JDim, + "Koff": KDim, + }, + default_sizes={IDim: 10, JDim: 10, KDim: 10}, + grid_type=common.GridType.CARTESIAN, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.fixture +def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider=mesh_descriptor.offset_provider, + default_sizes={ + Vertex: mesh_descriptor.num_vertices, + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + KDim: 10, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_halo_exchange_helper_attrs(unstructured): + local_int = gtx.int + + @gtx.field_operator(backend=unstructured.backend) + def testee_op( + a: gtx.Field[[Vertex, KDim], gtx.int], + ) -> gtx.Field[[Vertex, KDim], gtx.int]: + return a + local_int(10) + + @gtx.program(backend=unstructured.backend) + def testee_prog( + a: gtx.Field[[Vertex, KDim], gtx.int], + b: gtx.Field[[Vertex, KDim], gtx.int], + c: gtx.Field[[Vertex, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + dace_storage_type = ( + dace.StorageType.GPU_Global + if unstructured.backend == dace_backends.run_dace_gpu + else dace.StorageType.Default + ) + + rows = dace.symbol("rows") + cols = dace.symbol("cols") + + @dace.program + def testee_dace( + a: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + b: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + c: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + ): + testee_prog(a, b, c) + + # if simplify=True, DaCe might inline the nested SDFG coming from Program.__sdfg__, + # effectively erasing the attributes we want to test for here + sdfg = testee_dace.to_sdfg(simplify=False) + + testee = next( + subgraph for subgraph in sdfg.all_sdfgs_recursive() if subgraph.name == "testee_prog" + ) + + assert testee.gt4py_program_input_fields == {"a": Vertex, "b": Vertex} + assert testee.gt4py_program_output_fields == {"b": Vertex, "c": Vertex} diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py new file mode 100644 index 0000000000..7358ab3d8f --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py @@ -0,0 +1,102 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import typing + +import pytest + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import extractors + +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + IDim, + JDim, + KDim, +) + + +if typing.TYPE_CHECKING: + from types import ModuleType + from typing import Optional + +try: + import dace + + from gt4py.next.program_processors.runners.dace import run_dace_cpu +except ImportError: + from gt4py.next import backend as next_backend + + dace: Optional[ModuleType] = None + run_dace_cpu: Optional[next_backend.Backend] = None + + +@pytest.fixture(params=[pytest.param(run_dace_cpu, marks=pytest.mark.requires_dace), gtx.gtfn_cpu]) +def gtir_dace_backend(request): + yield request.param + + +@pytest.fixture +def cartesian(request, gtir_dace_backend): + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider={ + "Ioff": IDim, + "Joff": JDim, + "Koff": KDim, + }, + default_sizes={IDim: 10, JDim: 10, KDim: 10}, + grid_type=common.GridType.CARTESIAN, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_input_names_extractor_cartesian(cartesian): + @gtx.field_operator(backend=cartesian.backend) + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program(backend=cartesian.backend) + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + input_field_names = extractors.InputNamesExtractor.only_fields(testee.gtir) + assert input_field_names == {"a", "b"} + + +@pytest.mark.skipif(dace is None, reason="DaCe not found") +def test_output_names_extractor(cartesian): + @gtx.field_operator(backend=cartesian.backend) + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program(backend=cartesian.backend) + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(a, out=b) + testee_op(a, out=c) + + output_field_names = extractors.OutputNamesExtractor.only_fields(testee.gtir) + assert output_field_names == {"b", "c"}