From d7f55522beacfc77c12964f6bbb1962899d8821d Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 25 Nov 2024 14:22:24 +0100 Subject: [PATCH 01/13] feat[next]: remove NeighborTableOffsetProvider, use gtx.as_connectivity (#1729) User-facing change: use `gtx.as_connectivity` to create a connectivity/neighbor table instead of `NeighborTableOffsetProvider` which is deprecated (and the backward-compatible mechanism broken for some use-cases). The internal concepts of `Connectivity` and `NeighborTable` are updated. `ConnectivityType` is introduced which contains the compile-time info of a `Connectivity`. See ADR 19. Additionally, the compile-time info is used (instead of the run-time connectivities) in many places of the toolchain when possible. --- .gitpod/.vscode/launch.json | 13 +- .../0008-Mapping_Domain_to_Cpp-Backend.md | 2 +- docs/development/ADRs/0019-Connectivities.md | 55 +++++ docs/user/next/QuickstartGuide.md | 6 +- .../exercises/2_divergence_exercise.ipynb | 4 +- .../2_divergence_exercise_solution.ipynb | 4 +- .../exercises/3_gradient_exercise.ipynb | 4 +- .../3_gradient_exercise_solution.ipynb | 4 +- .../workshop/exercises/4_curl_exercise.ipynb | 4 +- .../exercises/4_curl_exercise_solution.ipynb | 4 +- .../exercises/5_vector_laplace_exercise.ipynb | 10 +- .../5_vector_laplace_exercise_solution.ipynb | 10 +- .../8_diffusion_exercise_solution.ipynb | 8 +- docs/user/next/workshop/slides/slides_2.ipynb | 10 +- src/gt4py/_core/definitions.py | 10 +- src/gt4py/next/__init__.py | 6 +- src/gt4py/next/common.py | 170 ++++++++++---- src/gt4py/next/constructors.py | 24 +- src/gt4py/next/embedded/nd_array_field.py | 35 ++- src/gt4py/next/ffront/decorator.py | 47 ++-- src/gt4py/next/ffront/experimental.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 30 +-- src/gt4py/next/iterator/embedded.py | 215 +++++++++++------- .../next/iterator/ir_utils/domain_utils.py | 26 +-- src/gt4py/next/iterator/runtime.py | 10 +- .../iterator/transforms/collapse_tuple.py | 6 +- 
src/gt4py/next/iterator/transforms/cse.py | 6 +- .../iterator/transforms/fuse_as_fieldop.py | 9 +- .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/inline_scalar.py | 4 +- .../next/iterator/transforms/pass_manager.py | 29 ++- .../transforms/pass_manager_legacy.py | 14 +- .../next/iterator/transforms/unroll_reduce.py | 28 +-- .../next/iterator/type_system/inference.py | 34 +-- .../iterator/type_system/type_synthesizer.py | 48 ++-- src/gt4py/next/otf/arguments.py | 54 +---- .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py | 76 +------ .../codegens/gtfn/gtfn_module.py | 47 ++-- .../codegens/gtfn/itir_to_gtfn_ir.py | 31 +-- .../runners/dace_common/dace_backend.py | 21 +- .../runners/dace_common/utility.py | 15 +- .../runners/dace_fieldview/gtir_dataflow.py | 75 +++--- .../runners/dace_fieldview/gtir_sdfg.py | 33 ++- .../runners/dace_fieldview/workflow.py | 6 +- .../runners/dace_iterator/__init__.py | 53 +++-- .../runners/dace_iterator/itir_to_sdfg.py | 45 ++-- .../runners/dace_iterator/itir_to_tasklet.py | 97 ++++---- .../runners/dace_iterator/utility.py | 10 +- .../runners/dace_iterator/workflow.py | 6 +- .../next/program_processors/runners/gtfn.py | 16 +- .../program_processors/runners/roundtrip.py | 16 +- .../next/type_system/type_specifications.py | 1 + .../feature_tests/dace/test_orchestration.py | 86 ++++--- .../ffront_tests/ffront_test_utils.py | 91 +++++--- .../ffront_tests/test_execution.py | 36 +-- .../ffront_tests/test_external_local_field.py | 12 +- .../ffront_tests/test_gt4py_builtins.py | 18 +- .../test_temporaries_with_sizes.py | 2 +- .../iterator_tests/test_builtins.py | 40 +--- .../test_strided_offset_provider.py | 9 +- .../ffront_tests/test_ffront_fvm_nabla.py | 64 +++--- .../multi_feature_tests/fvm_nabla_setup.py | 56 +++-- .../iterator_tests/test_fvm_nabla.py | 114 ++++------ .../test_with_toy_connectivity.py | 38 ++-- tests/next_tests/toy_connectivity.py | 7 + tests/next_tests/unit_tests/conftest.py | 25 +- 
.../embedded_tests/test_nd_array_field.py | 15 +- .../test_embedded_field_with_list.py | 10 +- .../iterator_tests/test_runtime_domain.py | 10 +- .../iterator_tests/test_type_inference.py | 34 +-- .../transforms_tests/test_cse.py | 14 +- .../transforms_tests/test_domain_inference.py | 13 +- .../transforms_tests/test_fuse_as_fieldop.py | 13 +- .../transforms_tests/test_global_tmps.py | 8 +- .../transforms_tests/test_prune_casts.py | 6 +- .../transforms_tests/test_unroll_reduce.py | 69 ++++-- .../gtfn_tests/test_itir_to_gtfn_ir.py | 4 +- .../runners_tests/dace_tests/test_dace.py | 24 +- .../dace_tests/test_gtir_to_sdfg.py | 134 ++++++----- .../unit_tests/test_constructors.py | 14 +- 80 files changed, 1293 insertions(+), 1170 deletions(-) create mode 100644 docs/development/ADRs/0019-Connectivities.md diff --git a/.gitpod/.vscode/launch.json b/.gitpod/.vscode/launch.json index f682b56388..b25a182648 100644 --- a/.gitpod/.vscode/launch.json +++ b/.gitpod/.vscode/launch.json @@ -6,7 +6,7 @@ "configurations": [ { "name": "Python: Current File (just my code)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", @@ -14,11 +14,20 @@ }, { "name": "Python: Current File (all)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", "justMyCode": false + }, + { + "name": "Python: Debug Tests", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": true } ] } diff --git a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md index a1ee8575d2..1ce83431ee 100644 --- a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md +++ b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md @@ -20,7 +20,7 @@ The Python embedded execution for Iterator IR keeps track of the current locatio ### Python 
side -On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` (aka `NeighborTableOffsetProvider` in the current implementation) describes the mapping between location types. +On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` describes the mapping between location types. ### C++ side diff --git a/docs/development/ADRs/0019-Connectivities.md b/docs/development/ADRs/0019-Connectivities.md new file mode 100644 index 0000000000..76e85e49a6 --- /dev/null +++ b/docs/development/ADRs/0019-Connectivities.md @@ -0,0 +1,55 @@ +--- +tags: [] +--- + +# [Connectivities] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2024-11-08 +- **Updated**: 2024-11-08 + +The representation of Connectivities (neighbor tables, `NeighborTableOffsetProvider`) and their identifier (offset tag, `FieldOffset`, etc.) was extended and modified based on the needs of different parts of the toolchain. Here we outline the ideas for consolidating the different closely-related concepts. + +## History + +In the early days of Iterator IR (ITIR), an `offset` was a literal in the IR. Its meaning was only provided at execution time by a mapping from `offset` tag to an entity that we labelled `OffsetProvider`. We had mainly 2 kinds of `OffsetProvider`: a `Dimension` representing a Cartesian shift and a `NeighborTableOffsetProvider` for unstructured shifts. 
Since the type of `offset` needs to be known for compilation (strided for Cartesian, lookup-table for unstructured), this prevents a clean interface for ahead-of-time compilation.
+For the frontend type-checking we later introduced a `FieldOffset` which contained type information of the mapped dimensions.
+For (field-view) embedded we introduced a `ConnectivityField` (now `Connectivity`) which could be generated from the OffsetProvider information.
+
+These different concepts had overlap but were not 1-to-1 replacements.
+
+## Decision
+
+We update and introduce the following concepts
+
+### Conceptual definitions
+
+**Connectivity** is a mapping from index (or product of indices) to index. It covers 1-to-1 mappings, e.g. Cartesian shifts, NeighborTables (2D mappings) and dynamic Cartesian shifts.
+
+**NeighborConnectivity** is a 2D mapping of the N neighbors of a Location A to a Location B.
+
+**NeighborTable** is a _NeighborConnectivity_ backed by a buffer.
+
+**ConnectivityType**, **NeighborConnectivityType** contain all information that is needed for compilation.
+
+### Full definitions
+
+See `next.common` module
+
+Note: Currently, the compiled backends support only `NeighborConnectivity`s that are `NeighborTable`s. We do not yet encode this in the type and postpone discussion to the point where we support alternative implementations (e.g. `StridedNeighborConnectivity`).
+
+## Which parts of the toolchain use which concept?
+
+### Embedded
+
+Embedded execution of field-view supports any kind of `Connectivity`.
+Embedded execution of iterator (local) view supports only `NeighborConnectivity`s.
+
+### IR transformations and compiled backends
+
+All transformations and code-generation should use `ConnectivityType`, not the `Connectivity` which contains the runtime mapping.
+
+Note: currently the `global_tmps` pass uses runtime information, therefore this is not strictly enforced. 
+ +The only supported `Connectivity`s in compiled backends (currently) are `NeighborTable`s. diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 81604c7770..2cb6647519 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -155,8 +155,6 @@ This section approaches the pseudo-laplacian by introducing the required APIs pr - [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) - [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) -+++ - #### Defining the mesh and its connectivities The examples related to unstructured meshes use the mesh below. The edges (in blue) and the cells (in red) are numbered with zero-based indices. @@ -237,7 +235,7 @@ E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) Note that the field offset does not contain the actual connectivity table, that's provided through an _offset provider_: ```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2) +E2C_offset_provider = gtx.as_connectivity([EdgeDim, E2CDim], codomain=CellDim, data=edge_to_cell_table, skip_value=-1) ``` The field operator `nearest_cell_to_edge` below shows an example of applying this transform. There is a little twist though: the subscript in `E2C[0]` means that only the value of the first connected cell is taken, the second (if exists) is ignored. 
@@ -385,7 +383,7 @@ As explained in the section outline, the pseudo-laplacian needs the cell-to-edge C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) C2E = gtx.FieldOffset("C2E", source=EdgeDim, target=(CellDim, C2EDim)) -C2E_offset_provider = gtx.NeighborTableOffsetProvider(cell_to_edge_table, CellDim, EdgeDim, 3) +C2E_offset_provider = gtx.as_connectivity([CellDim, C2EDim], codomain=EdgeDim, data=cell_to_edge_table, skip_value=-1) ``` **Weights of edge differences:** diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb index 50349e52b0..b0a1980d0f 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -113,7 +113,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb index 6baac2b8c0..573ee6a44e 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -118,7 +118,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], 
codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb index c8914120d3..2b422b1823 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -110,7 +110,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb index 5e940a4b71..85044b989f 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -123,7 +123,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb index 4a6b37baf7..dc321f1bdd 
100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -134,7 +134,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb index 065cf02de7..251fe8239a 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -139,7 +139,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb index 832375a86b..30f568de6f 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -272,10 +272,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = 
gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb index be846d199d..eaeb8c7b02 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -293,10 +293,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " 
e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb index d4bcdb33d5..b278cee26d 100644 --- a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -156,10 +156,8 @@ " dt,\n", " )\n", "\n", - " e2c2v_connectivity = gtx.NeighborTableOffsetProvider(\n", - " e2c2v_table, E, V, 4, has_skip_values=False\n", - " )\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " e2c2v_connectivity = gtx.as_connectivity([E, E2C2VDim], codomain=V, data=e2c2v_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " diffusion_step(\n", " u,\n", diff --git a/docs/user/next/workshop/slides/slides_2.ipynb b/docs/user/next/workshop/slides/slides_2.ipynb index 1e8925087f..c6967df4b2 100644 --- a/docs/user/next/workshop/slides/slides_2.ipynb +++ b/docs/user/next/workshop/slides/slides_2.ipynb @@ -281,17 +281,19 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6d30a5e1", "metadata": {}, "outputs": [], "source": [ - "E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, Edge, Cell, 2)" + "E2C_offset_provider = gtx.as_connectivity(\n", + " [Edge, E2CDim], codomain=Cell, data=e2c_table, skip_value=-1\n", + ")" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "d62f6c98", "metadata": {}, "outputs": [ @@ -311,7 +313,7 @@ " 
return cell_field(E2C[0]) # 0th index to isolate edge dimension\n", "\n", "\n", - "@gtx.program # uses skip_values, therefore we cannot use embedded\n", + "@gtx.program\n", "def run_nearest_cell_to_edge(\n", " cell_field: gtx.Field[Dims[Cell], float64], edge_field: gtx.Field[Dims[Edge], float64]\n", "):\n", diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 9d07b2eb79..8f62788b8f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -439,13 +439,21 @@ def ndim(self) -> int: ... @property def shape(self) -> tuple[int, ...]: ... + @property + def strides(self) -> tuple[int, ...]: ... + @property def dtype(self) -> Any: ... + @property + def itemsize(self) -> int: ... + def item(self) -> Any: ... def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... + def any(self) -> bool: ... + def __getitem__(self, item: Any) -> NDArrayObject: ... def __abs__(self) -> NDArrayObject: ... @@ -496,4 +504,4 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 80bb276c70..4fa9215706 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -20,6 +20,7 @@ from . 
import common, ffront, iterator, program_processors from .common import ( + Connectivity, Dimension, DimensionKind, Dims, @@ -39,8 +40,7 @@ from .ffront.fbuiltins import * # noqa: F403 [undefined-local-with-import-star] explicitly reexport all from fbuiltins.__all__ from .ffront.fbuiltins import FieldOffset from .iterator.embedded import ( - NeighborTableOffsetProvider, - StridedNeighborOffsetProvider, + NeighborTableOffsetProvider, # TODO(havogt): deprecated index_field, np_as_located_field, ) @@ -61,6 +61,7 @@ "Dimension", "DimensionKind", "Field", + "Connectivity", "GridType", "domain", "Domain", @@ -75,7 +76,6 @@ "as_connectivity", # from iterator "NeighborTableOffsetProvider", - "StridedNeighborOffsetProvider", "index_field", "np_as_located_field", # from ffront diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4aa0dd03aa..9b2870e1c0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -18,7 +18,6 @@ from collections.abc import Mapping, Sequence import numpy as np -import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import utils @@ -95,7 +94,7 @@ def __str__(self) -> str: def __call__(self, val: int) -> NamedIndex: return NamedIndex(self, val) - def __add__(self, offset: int) -> ConnectivityField: + def __add__(self, offset: int) -> Connectivity: # TODO(sf-n): just to avoid circular import. Move or refactor the FieldOffset to avoid this. from gt4py.next.ffront import fbuiltins @@ -104,7 +103,7 @@ def __add__(self, offset: int) -> ConnectivityField: dimension_to_implicit_offset(self.value), source=self, target=(self,) )[offset] - def __sub__(self, offset: int) -> ConnectivityField: + def __sub__(self, offset: int) -> Connectivity: return self + (-offset) @@ -678,6 +677,9 @@ def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... 
+ # TODO(havogt) + # This property is wrong, because for a function field we would not know to which NDArrayObject we want to convert + # at the very least, we need to take an allocator and rename this to `as_ndarray`. @property def ndarray(self) -> core_defs.NDArrayObject: ... @@ -688,7 +690,7 @@ def __str__(self) -> str: def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def premap(self, index_field: Connectivity | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod def restrict(self, item: AnyIndexSpec) -> Self: ... @@ -700,8 +702,8 @@ def as_scalar(self) -> core_defs.ScalarT: ... @abc.abstractmethod def __call__( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, ) -> Field: ... @abc.abstractmethod @@ -811,12 +813,64 @@ def remapping(cls) -> ConnectivityKind: return cls.ALTER_DIMS | cls.ALTER_STRUCT +@dataclasses.dataclass(frozen=True) +class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import + domain: tuple[Dimension, ...] 
+ codomain: Dimension + skip_value: Optional[core_defs.IntegralScalar] + dtype: core_defs.DType + + @property + def has_skip_values(self) -> bool: + return self.skip_value is not None + + +@dataclasses.dataclass(frozen=True) +class NeighborConnectivityType(ConnectivityType): + # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain + max_neighbors: int + + @property + def source_dim(self) -> Dimension: + return self.domain[0] + + @property + def neighbor_dim(self) -> Dimension: + return self.domain[1] + + @runtime_checkable # type: ignore[misc] # DimT should be covariant, but then it breaks in other places -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): +class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: ... + def codomain(self) -> DimT: + """ + The `codomain` is the set of all indices in a certain `Dimension`. + + We use the `Dimension` itself to describe the (infinite) set of all indices. + + Note: + We could restrict the infinite codomain to only the indices that are actually contained in the mapping. + Currently, this would just complicate implementation as we do not use this information. + """ + + def __gt_type__(self) -> ConnectivityType: + if is_neighbor_connectivity(self): + return NeighborConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) + else: + return ConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + ) @property def kind(self) -> ConnectivityKind: @@ -831,61 +885,61 @@ def skip_value(self) -> Optional[core_defs.IntegralScalar]: ... 
# Operators def __abs__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise 
TypeError("'Connectivity' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> 
Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") # Utility function to construct a `Field` from different buffer representations. @@ -911,38 +965,58 @@ def _connectivity( domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, -) -> ConnectivityField: +) -> Connectivity: raise NotImplementedError -@runtime_checkable -class Connectivity(Protocol): - max_neighbors: int - has_skip_values: bool - origin_axis: Dimension - neighbor_axis: Dimension - index_type: type[int] | type[np.int32] | type[np.int64] +class NeighborConnectivity(Connectivity, Protocol): + # TODO(havogt): work towards encoding this properly in the type + def __gt_type__(self) -> NeighborConnectivityType: ... + - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - """Return neighbor index.""" +def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: + if not isinstance(obj, Connectivity): + return False + domain_dims = obj.domain.dims + return ( + len(domain_dims) == 2 + and domain_dims[0].kind is DimensionKind.HORIZONTAL + and domain_dims[1].kind is DimensionKind.LOCAL + ) -@runtime_checkable -class NeighborTable(Connectivity, Protocol): - table: npt.NDArray +class NeighborTable( + NeighborConnectivity, Protocol +): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) + @property + def ndarray(self) -> core_defs.NDArrayObject: + # Note that this property is currently already there from inheriting from `Field`, + # however this seems wrong, therefore we explicitly introduce it here (or it should come + # implicitly from the `NdArrayConnectivityField` protocol). + ... 
-OffsetProviderElem: TypeAlias = Dimension | Connectivity +def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: + return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") + + +OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity +OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] +OffsetProviderType: TypeAlias = Mapping[Tag, OffsetProviderTypeElem] + + +def offset_provider_to_type(offset_provider: OffsetProvider) -> OffsetProviderType: + return { + k: v.__gt_type__() if isinstance(v, Connectivity) else v for k, v in offset_provider.items() + } DomainDimT = TypeVar("DomainDimT", bound="Dimension") @dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(ConnectivityField[Dims[DomainDimT], DimT]): +class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): domain_dim: DomainDimT codomain: DimT offset: int = 0 @@ -981,7 +1055,7 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: return core_defs.Int32DType() # type: ignore[return-value] # This is a workaround to make this class concrete, since `codomain` is an - # abstract property of the `ConnectivityField` Protocol. + # abstract property of the `Connectivity` Protocol. 
if not TYPE_CHECKING: @functools.cached_property @@ -1024,9 +1098,9 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa def premap( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, - ) -> ConnectivityField: + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, + ) -> Connectivity: raise NotImplementedError() __call__ = premap diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index dd52559e85..7b39511674 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -290,22 +290,24 @@ def as_connectivity( *, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - skip_value: Optional[core_defs.IntegralScalar] = None, + skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, # TODO: copy=False -) -> common.ConnectivityField: +) -> common.Connectivity: """ - Construct a connectivity field from the given domain, codomain, and data. + Construct a `Connectivity` from the given domain, codomain, and data. Arguments: - domain: The domain of the connectivity field. It can be either a `common.DomainLike` object or a + domain: The domain of the connectivity. It can be either a `common.DomainLike` object or a sequence of `common.Dimension` objects. - codomain: The codomain dimension of the connectivity field. + codomain: The codomain dimension of the connectivity. data: The data used to construct the connectivity field. - dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. - allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + dtype: The data type of the connectivity. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity. 
If not provided, a default allocator will be used. - device: The device on which the connectivity field will be allocated. If not provided, the default + device: The device on which the connectivity will be allocated. If not provided, the default device will be used. + skip_value: The value that signals missing entries in the neighbor table. Defaults to the default + skip value if it is found in data, otherwise to `None` (= no skip value). Returns: The constructed connectivity field. @@ -313,9 +315,15 @@ def as_connectivity( Raises: ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. """ + if skip_value is eve.NOTHING: + skip_value = ( + common._DEFAULT_SKIP_VALUE if (data == common._DEFAULT_SKIP_VALUE).any() else None + ) + assert ( skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE ) # TODO(havogt): not yet configurable + skip_value = cast(Optional[core_defs.IntegralScalar], skip_value) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9ff5feaaee..e15fb4266a 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -36,7 +36,6 @@ exceptions as embedded_exceptions, ) from gt4py.next.ffront import experimental, fbuiltins -from gt4py.next.iterator import embedded as itir_embedded try: @@ -189,10 +188,10 @@ def from_array( def premap( self: NdArrayField, - *connectivities: common.ConnectivityField | fbuiltins.FieldOffset, + *connectivities: common.Connectivity | fbuiltins.FieldOffset, ) -> NdArrayField: """ - Rearrange the field content using the provided connectivity fields as index mappings. + Rearrange the field content using the provided connectivities (index mappings). 
This operation is conceptually equivalent to a regular composition of mappings `f∘c`, being `c` the `connectivity` argument and `f` the `self` data field. @@ -206,7 +205,7 @@ def premap( argument used in the right hand side of the operator should therefore have the same product of dimensions `c: S × T → A × B`. Such a mapping can also be expressed as a pair of mappings `c1: S × T → A` and `c2: S × T → B`, and this - is actually the only supported form in GT4Py because `ConnectivityField` instances + is actually the only supported form in GT4Py because `Connectivity` instances can only deal with a single dimension in its codomain. This approach makes connectivities reusable for any combination of dimensions in a field domain and matches the NumPy advanced indexing API, which basically is a @@ -261,15 +260,15 @@ def premap( """ # noqa: RUF002 # TODO(egparedes): move docstring to the `premap` builtin function when it exists - conn_fields: list[common.ConnectivityField] = [] + conn_fields: list[common.Connectivity] = [] codomains_counter: collections.Counter[common.Dimension] = collections.Counter() for connectivity in connectivities: - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField - if not isinstance(connectivity, common.ConnectivityField): + # For neighbor reductions, a FieldOffset is passed instead of an actual Connectivity + if not isinstance(connectivity, common.Connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() - assert isinstance(connectivity, common.ConnectivityField) + assert isinstance(connectivity, common.Connectivity) # Current implementation relies on skip_value == -1: # if we assume the indexed array has at least one element, @@ -318,8 +317,8 @@ def premap( def __call__( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | 
fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: return functools.reduce( lambda field, current_index_field: field.premap(current_index_field), @@ -460,7 +459,7 @@ def _dace_descriptor(self) -> Any: @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ - common.ConnectivityField[common.DimsT, common.DimT], + common.Connectivity[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT @@ -579,7 +578,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: __getitem__ = restrict -def _domain_premap(data: NdArrayField, *connectivities: common.ConnectivityField) -> NdArrayField: +def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> NdArrayField: """`premap` implementation transforming only the field domain not the data (i.e. translation and relocation).""" new_domain = data.domain for connectivity in connectivities: @@ -668,7 +667,7 @@ def _reshuffling_premap( ) -def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField) -> NdArrayField: +def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> NdArrayField: new_dims = {*connectivity.domain.dims} - {connectivity.codomain} if repeated_dims := (new_dims & {*data.domain.dims}): raise ValueError(f"Remapped field will contain repeated dimensions '{repeated_dims}'.") @@ -693,7 +692,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField if restricted_connectivity_domain != connectivity.domain else connectivity ) - assert isinstance(restricted_connectivity, common.ConnectivityField) + assert isinstance(restricted_connectivity, common.Connectivity) # 2- then compute the index array new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start @@ -971,7 +970,7 @@ def _concat_where( return cls_.from_array(result_array, 
domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] def _make_reduction( @@ -996,15 +995,15 @@ def _builtin_op( offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) + assert common.is_neighbor_table(offset_definition) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( - slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis + slice(None) if d in [axis, offset_definition.domain.dims[0]] else xp.newaxis for d in field.domain.dims ) masked_array = xp.where( - xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, + xp.asarray(offset_definition.ndarray[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index dc2421e1d2..9ce07d01bb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -30,7 +30,6 @@ embedded as next_embedded, errors, ) -from gt4py.next.common import Connectivity, Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, @@ -82,15 +81,15 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[dict[str, Connectivity]] + connectivities: Optional[common.OffsetProviderType] = None @classmethod def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend], - grid_type: Optional[GridType] = None, - connectivities: Optional[dict[str, Connectivity]] = None, + grid_type: Optional[common.GridType] = None, + 
connectivities: Optional[common.OffsetProviderType] = None, ) -> Program: program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) return cls(definition_stage=program_def, backend=backend, connectivities=connectivities) @@ -140,10 +139,10 @@ def _frontend_transforms(self) -> next_backend.Transforms: def with_backend(self, backend: next_backend.Backend) -> Program: return dataclasses.replace(self, backend=backend) - def with_connectivities(self, connectivities: dict[str, Connectivity]) -> Program: + def with_connectivities(self, connectivities: common.OffsetProviderType) -> Program: return dataclasses.replace(self, connectivities=connectivities) - def with_grid_type(self, grid_type: GridType) -> Program: + def with_grid_type(self, grid_type: common.GridType) -> Program: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -199,7 +198,7 @@ def itir(self) -> itir.FencilDefinition: return self._frontend_transforms.past_to_itir(no_args_past).data @functools.cached_property - def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]: + def _implicit_offset_provider(self) -> dict[str, common.Dimension]: """ Add all implicit offset providers. 
@@ -226,9 +225,7 @@ def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderEle ) return implicit_offset_provider - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: offset_provider = offset_provider | self._implicit_offset_provider if self.backend is None: warnings.warn( @@ -287,19 +284,17 @@ def definition(self) -> str: def with_backend(self, backend: next_backend.Backend) -> FrozenProgram: return self.__class__(program=self.program, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FrozenProgram: + def with_grid_type(self, grid_type: common.GridType) -> FrozenProgram: return self.__class__( program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend ) def jit( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any ) -> stages.CompiledProgram: return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs) - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs) if not self._compiled_program: @@ -328,7 +323,7 @@ class ProgramFromPast(Program): past_stage: ffront_stages.PastProgramDefinition - def __call__(self, *args: Any, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: if self.backend is None: raise NotImplementedError( "Programs created from a PAST node (without a function definition) can not be executed in embedded mode" @@ -350,7 +345,7 @@ def __post_init__(self): 
class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): + def __call__(self, *args, offset_provider: common.OffsetProvider, **kwargs): type_ = self.past_stage.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -436,7 +431,7 @@ def program( *, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) backend: next_backend.Backend | eve.NOTHING = eve.NOTHING, - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -506,7 +501,7 @@ def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend.Backend], - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, @@ -557,7 +552,7 @@ def __gt_type__(self) -> ts.CallableType: def with_backend(self, backend: next_backend.Backend) -> FieldOperator: return dataclasses.replace(self, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FieldOperator: + def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -688,33 +683,33 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast. def scan_operator( definition: types.FunctionType, *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> FieldOperator[foast.ScanOperator]: ... 
@typing.overload def scan_operator( *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... def scan_operator( definition: Optional[types.FunctionType] = None, *, - axis: Dimension, + axis: common.Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, backend=eve.NOTHING, - grid_type: GridType = None, + grid_type: common.GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 8a94c20832..bd22aebe57 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -14,7 +14,7 @@ @BuiltInFunction -def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.ConnectivityField: +def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index d932431b51..b60fa63f95 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -16,7 +16,6 @@ import numpy as np from numpy import float32, float64, int32, int64 -import gt4py.next as gtx from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS @@ -55,7 +54,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.DimensionType elif t is FieldOffset: return ts.OffsetType - elif t is common.ConnectivityField: + elif t is common.Connectivity: return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType @@ -321,7 +320,7 @@ def __post_init__(self) -> None: def __gt_type__(self) -> ts.OffsetType: return 
ts.OffsetType(source=self.source, target=self.target) - def __getitem__(self, offset: int) -> common.ConnectivityField: + def __getitem__(self, offset: int) -> common.Connectivity: """Serve as a connectivity factory.""" from gt4py.next import embedded # avoid circular import @@ -330,22 +329,19 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: assert current_offset_provider is not None offset_definition = current_offset_provider[self.value] - connectivity: common.ConnectivityField + connectivity: common.Connectivity if isinstance(offset_definition, common.Dimension): connectivity = common.CartesianConnectivity(offset_definition, offset) - elif isinstance( - offset_definition, (gtx.NeighborTableOffsetProvider, common.ConnectivityField) - ): - unrestricted_connectivity = self.as_connectivity_field() - assert unrestricted_connectivity.domain.ndim > 1 + elif isinstance(offset_definition, common.Connectivity): + assert common.is_neighbor_connectivity(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) - connectivity = unrestricted_connectivity[named_index] + connectivity = offset_definition[named_index] else: raise NotImplementedError() return connectivity - def as_connectivity_field(self) -> common.ConnectivityField: + def as_connectivity_field(self) -> common.Connectivity: """Convert to connectivity field using the offset providers in current embedded execution context.""" from gt4py.next import embedded # avoid circular import @@ -356,18 +352,8 @@ def as_connectivity_field(self) -> common.ConnectivityField: cache_key = id(offset_definition) if (connectivity := self._cache.get(cache_key, None)) is None: - if isinstance(offset_definition, common.ConnectivityField): + if isinstance(offset_definition, common.Connectivity): connectivity = offset_definition - elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): - connectivity = gtx.as_connectivity( - domain=self.target, - codomain=self.source, - 
data=offset_definition.table, - dtype=offset_definition.index_type, - skip_value=( - common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None - ), - ) else: raise NotImplementedError() diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 6221c95522..3c63ffef30 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -93,77 +93,113 @@ class SparseTag(Tag): ... -class NeighborTableOffsetProvider: +@xtyping.deprecated("Use a 'Connectivity' instead.") +def NeighborTableOffsetProvider( + table: core_defs.NDArrayObject, + origin_axis: common.Dimension, + neighbor_axis: common.Dimension, + max_neighbors: int, + has_skip_values=True, +) -> common.Connectivity: + return common._connectivity( + table, + codomain=neighbor_axis, + domain={ + origin_axis: table.shape[0], + common.Dimension( + value="_DummyLocalDim", kind=common.DimensionKind.LOCAL + ): max_neighbors, + }, + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + ) + + +# TODO(havogt): complete implementation and make available for fieldview embedded +@dataclasses.dataclass(frozen=True) +class StridedConnectivityField(common.Connectivity): + domain_dims: tuple[common.Dimension, common.Dimension] + codomain_dim: common.Dimension + _max_neighbors: int + def __init__( self, - table: core_defs.NDArrayObject, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, + domain_dims: Sequence[common.Dimension], + codomain_dim: common.Dimension, max_neighbors: int, - has_skip_values=True, - ) -> None: - self.table = table - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - assert not hasattr(table, "shape") or table.shape[1] == max_neighbors - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = table.dtype - - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - res = self.table[(primary, 
neighbor_idx)] - assert common.is_int_index(res) - return res + ): + object.__setattr__(self, "domain_dims", tuple(domain_dims)) + object.__setattr__(self, "codomain_dim", codomain_dim) + object.__setattr__(self, "_max_neighbors", max_neighbors) - if dace: - # Extension of NeighborTableOffsetProvider adding SDFGConvertible support in GT4Py Programs - def _dace_data_ptr(self) -> int: - obj = self.table - if dace.dtypes.is_array(obj): - if hasattr(obj, "__array_interface__"): - return obj.__array_interface__["data"][0] - if hasattr(obj, "__cuda_array_interface__"): - return obj.__cuda_array_interface__["data"][0] - raise ValueError("Unsupported data container.") - - def _dace_descriptor(self) -> dace.data.Data: - return dace.data.create_datadescriptor(self.table) - else: + @property + def __gt_origin__(self) -> xtyping.Never: + raise NotImplementedError + + def __gt_type__(self) -> common.NeighborConnectivityType: + return common.NeighborConnectivityType( + domain=self.domain_dims, + codomain=self.codomain_dim, + max_neighbors=self._max_neighbors, + skip_value=self.skip_value, + dtype=self.dtype, + ) - def _dace_data_ptr(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "data_ptr is only supported when the 'dace' module is available." - ) + @property + def domain(self) -> common.Domain: + return common.Domain( + dims=self.domain_dims, + ranges=(common.UnitRange.infinite(), common.unit_range(self._max_neighbors)), + ) - def _dace_descriptor(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "__descriptor__ is only supported when the 'dace' module is available." 
- ) + @property + def codomain(self) -> common.Dimension: + return self.codomain_dim - data_ptr = _dace_data_ptr - __descriptor__ = _dace_descriptor + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + @property + def ndarray(self) -> core_defs.NDArrayObject: + raise NotImplementedError -class StridedNeighborOffsetProvider: - def __init__( + def asnumpy(self) -> np.ndarray: + raise NotImplementedError + + def premap(self, index_field: common.Connectivity | fbuiltins.FieldOffset) -> common.Field: + raise NotImplementedError + + def restrict( # type: ignore[override] self, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, - max_neighbors: int, - has_skip_values=True, - ) -> None: - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = int + item: common.AnyIndexSpec, + ) -> common.Field: + if not isinstance(item, tuple) or (isinstance(item, tuple) and not len(item) == 2): + raise NotImplementedError() # TODO(havogt): add proper slicing + index = item[0] * self._max_neighbors + item[1] # type: ignore[operator, call-overload] + return ConstantField(index) + + def as_scalar(self) -> xtyping.Never: + raise NotImplementedError() + + def __call__( + self, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, + ) -> common.Field: + raise NotImplementedError() - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - return primary * self.max_neighbors + neighbor_idx + __getitem__ = restrict # type: ignore[assignment] + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + raise NotImplementedError + + @property + def skip_value( + self, + ) -> None: + return None # Offsets @@ -597,10 
+633,11 @@ def execute_shift( new_entry[i] = 0 else: offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] + assert common.is_neighbor_connectivity(offset_implementation) + source_dim = offset_implementation.__gt_type__().source_dim + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: @@ -620,22 +657,22 @@ def execute_shift( else: raise AssertionError() return new_pos - else: - assert isinstance(offset_implementation, common.Connectivity) - assert offset_implementation.origin_axis.value in pos + elif common.is_neighbor_connectivity(offset_implementation): + source_dim = offset_implementation.__gt_type__().source_dim + assert source_dim.value in pos new_pos = pos.copy() - new_pos.pop(offset_implementation.origin_axis.value) - cur_index = pos[offset_implementation.origin_axis.value] + new_pos.pop(source_dim.value) + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: return None else: - new_index = offset_implementation.mapped_index(cur_index, index) + new_index = offset_implementation[cur_index, index].as_scalar() assert new_index is not None - new_pos[offset_implementation.neighbor_axis.value] = int(new_index) + new_pos[offset_implementation.codomain.value] = int(new_index) return new_pos @@ -1196,8 +1233,8 @@ def as_scalar(self) -> core_defs.IntegralScalar: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | 
fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1322,8 +1359,8 @@ def asnumpy(self) -> np.ndarray: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1428,10 +1465,12 @@ def __gt_type__(self) -> itir_ts.ListType: assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( - element_type=element_type, - offset_type=common.Dimension(value=offset_tag, kind=common.DimensionKind.LOCAL), - ) + offset_provider = embedded_context.offset_provider.get() + assert offset_provider is not None + connectivity = offset_provider[offset_tag] + assert common.is_neighbor_connectivity(connectivity) + local_dim = connectivity.__gt_type__().neighbor_dim + return itir_ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1457,11 +1496,11 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[offset_str] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if (shifted := it.shift(offset_str, i)).can_deref() ), offset=offset, @@ -1533,11 +1572,11 @@ def deref(self) -> Any: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not 
None connectivity = offset_provider[self.list_offset] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( values=tuple( shifted.deref() - for i in range(connectivity.max_neighbors) + for i in range(connectivity.__gt_type__().max_neighbors) if ( shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) ).can_deref() @@ -1671,9 +1710,9 @@ def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} -def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: +def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: if isinstance(domain, runtime.CartesianDomain): - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()): raise RuntimeError( "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." 
) @@ -1770,10 +1809,10 @@ def _fieldspec_list_to_value( offset_type = type_.offset_type assert isinstance(offset_type, common.Dimension) connectivity = offset_provider[offset_type.value] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return domain.insert( len(domain), - common.named_range((offset_type, connectivity.max_neighbors)), + common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)), ), type_.element_type return domain, type_ @@ -1809,7 +1848,7 @@ def closure( ) -> None: assert embedded_context.within_valid_context() offset_provider = embedded_context.offset_provider.get() - _validate_domain(domain_, offset_provider) + _validate_domain(domain_, common.offset_provider_to_type(offset_provider)) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (isinstance(out, common.Field) or is_tuple_of_field(out)): raise TypeError("'Out' needs to be a located field.") diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8f842e1c13..f5625b509c 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -12,7 +12,6 @@ import functools from typing import Any, Literal, Mapping, Optional -import gt4py.next as gtx from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -23,20 +22,19 @@ def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> di """ Extract horizontal domain sizes from an `offset_provider`. - Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum - value inside the neighbor table to get the size of each `neighbor_axis`. + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. 
""" sizes = dict[str, int]() for provider in offset_provider.values(): - if isinstance(provider, gtx.NeighborTableOffsetProvider): - assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL - assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL - sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] ) - sizes[provider.neighbor_axis.value] = max( - sizes.get(provider.neighbor_axis.value, 0), - provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) return sizes @@ -114,7 +112,7 @@ def translate( new_ranges[current_dim] = SymbolicRange.translate( self.ranges[current_dim], val.value ) - elif isinstance(nbt_provider, common.Connectivity): + elif common.is_neighbor_connectivity(nbt_provider): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -132,8 +130,8 @@ def translate( for k, v in _max_domain_sizes_by_location_type(offset_provider).items() } - old_dim = nbt_provider.origin_axis - new_dim = nbt_provider.neighbor_axis + old_dim = nbt_provider.__gt_type__().source_dim + new_dim = nbt_provider.__gt_type__().codomain assert new_dim not in new_ranges or old_dim == new_dim diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ad85d154cb..d42f961202 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -12,7 +12,7 @@ import functools import types from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union 
+from typing import TYPE_CHECKING, Callable, Optional, Union import devtools @@ -127,7 +127,9 @@ def fendef( ) -def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[str, Any]): +def _deduce_domain( + domain: dict[common.Dimension, range], offset_provider_type: common.OffsetProviderType +): if isinstance(domain, UnstructuredDomain): domain_builtin = builtins.unstructured_domain elif isinstance(domain, CartesianDomain): @@ -135,7 +137,7 @@ def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[ else: domain_builtin = ( builtins.unstructured_domain - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()) + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()) else builtins.cartesian_domain ) @@ -160,7 +162,7 @@ def impl(out, *inps): elif isinstance(dom, dict): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None - dom = _deduce_domain(dom, offset_provider) + dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) closure(dom, self.fundef_dispatcher, out, [*inps]) return impl diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index f84714e779..e71a24127f 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -105,7 +105,7 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider: Optional[common.OffsetProvider] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, within_stencil: Optional[bool] = None, # manually passing flags is mostly for allowing separate testing of the modes flags: Optional[Flag] = None, @@ -126,7 +126,7 @@ def apply( `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ flags = flags or cls.flags - 
offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} if isinstance(node, (ir.Program, ir.FencilDefinition)): within_stencil = False @@ -138,7 +138,7 @@ def apply( if not ignore_tuple_size: node = itir_type_inference.infer( node, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, allow_undeclared_symbols=allow_undeclared_symbols, ) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 38ea1fd53d..824adfdd8d 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -411,7 +411,7 @@ def apply( cls, node: ProgramOrExpr, within_stencil: bool | None = None, - offset_provider: common.OffsetProvider | None = None, + offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) if is_program: @@ -422,9 +422,9 @@ def apply( within_stencil is not None ), "The expression's context must be specified using `within_stencil`." 
- offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program + node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) return cls().visit(node, within_stencil=within_stencil) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index da238733da..9076bf2d3f 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -11,6 +11,7 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import ( @@ -89,7 +90,7 @@ class FuseAsFieldOp(eve.NodeTranslator): ) >>> print( ... FuseAsFieldOp.apply( - ... nested_as_fieldop, offset_provider={}, allow_undeclared_symbols=True + ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... 
) as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) @@ -134,12 +135,14 @@ def apply( cls, node: itir.Program, *, - offset_provider, + offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, ): node = type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=allow_undeclared_symbols + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, ) if not uids: diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 90f8a6cded..a6d39883e3 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -187,7 +187,9 @@ def create_global_tmps( arguments into temporaries. """ program = infer_domain.infer_program(program, offset_provider) - program = type_inference.infer(program, offset_provider=offset_provider) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if not uids: uids = eve_utils.UIDGenerator(prefix="__tmp") diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index c6e2c38b90..87b576d14d 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -17,8 +17,8 @@ class InlineScalar(eve.NodeTranslator): @classmethod - def apply(cls, program: itir.Program, offset_provider: common.OffsetProvider): - program = itir_inference.infer(program, offset_provider=offset_provider) + def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): + program = itir_inference.infer(program, offset_provider_type=offset_provider_type) return cls().visit(program) def visit_Expr(self, node: itir.Expr): diff --git 
a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 52a452155a..ec6f89685a 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -43,8 +43,8 @@ def __call__( def apply_common_transforms( ir: itir.Program | itir.FencilDefinition, *, + offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, - offset_provider=None, unroll_reduce=False, common_subexpression_elimination=True, force_inline_lambda_args=False, @@ -56,7 +56,12 @@ def apply_common_transforms( #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this if isinstance(ir, itir.FencilDefinition): ir = fencil_to_program.FencilToProgram.apply(ir) @@ -75,7 +80,7 @@ def apply_common_transforms( # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( ir, # type: ignore[arg-type] # always an itir.Program offset_provider=offset_provider, @@ -89,15 +94,15 @@ def apply_common_transforms( inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply(inlined, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program - inlined = InlineScalar.apply(inlined, offset_provider=offset_provider) + inlined = CollapseTuple.apply(inlined, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) # This pass is required to run after CollapseTuple as otherwise we can not inline # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns # a list. Such expressions must be inlined however because no backend supports such # field operators right now. 
inlined = fuse_as_fieldop.FuseAsFieldOp.apply( - inlined, uids=mergeasfop_uids, offset_provider=offset_provider + inlined, uids=mergeasfop_uids, offset_provider_type=offset_provider_type ) if inlined == ir: @@ -108,19 +113,21 @@ def apply_common_transforms( # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) ir = MergeLet().visit(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True) if extract_temporaries: - ir = infer(ir, inplace=True, offset_provider=offset_provider) + ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. 
if unconditionally_collapse_tuples: - ir = CollapseTuple.apply(ir, ignore_tuple_size=True, offset_provider=offset_provider) # type: ignore[assignment] # always an itir.Program + ir = CollapseTuple.apply( + ir, ignore_tuple_size=True, offset_provider_type=offset_provider_type + ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -129,7 +136,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled # type: ignore[assignment] # still a `itir.Program` @@ -156,6 +163,8 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) - ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + ir = CollapseTuple.apply( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) # type: ignore[assignment] # type is still `itir.Program` ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py index 792bb421f1..94c962e92d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py @@ -10,6 +10,7 @@ from typing import Callable, Optional from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet @@ -75,8 +76,13 @@ def apply_common_transforms( Callable[[itir.StencilClosure], 
Callable[[itir.Expr], bool]] ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: assert isinstance(ir, itir.FencilDefinition) + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + ir = fencil_to_program.FencilToProgram().apply(ir) icdlv_uids = eve_utils.UIDGenerator() @@ -109,7 +115,7 @@ def apply_common_transforms( # is constant-folded the surrounding tuple_get calls can be removed. inlined = CollapseTuple.apply( inlined, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -134,7 +140,7 @@ def apply_common_transforms( ir = CollapseTuple.apply( ir, ignore_tuple_size=True, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, # TODO(tehrengruber): disabled since it increases compile-time too much right now flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, ) @@ -149,7 +155,7 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break ir = unrolled @@ -164,7 +170,7 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program ir = MergeLet().visit(ir) ir = InlineLambdas.apply( diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py 
b/src/gt4py/next/iterator/transforms/unroll_reduce.py index ec9c3efb2b..042a86cd8e 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -64,16 +64,16 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: def _get_connectivity( applied_reduce_node: itir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: + offset_provider_type: common.OffsetProviderType, +) -> common.NeighborConnectivityType: """Return single connectivity that is compatible with the arguments of the reduce.""" if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - connectivities: list[common.Connectivity] = [] + connectivities: list[common.NeighborConnectivityType] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) + conn = offset_provider_type[o] + assert isinstance(conn, common.NeighborConnectivityType) connectivities.append(conn) if not connectivities: @@ -120,15 +120,15 @@ class UnrollReduce(PreserveLocationVisitor, NodeTranslator): uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) @classmethod - def apply(cls, node: itir.Node, **kwargs) -> itir.Node: - return cls().visit(node, **kwargs) - - def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - connectivity = _get_connectivity(node, offset_provider) - max_neighbors = connectivity.max_neighbors - has_skip_values = connectivity.has_skip_values + def apply(cls, node: itir.Node, offset_provider_type: common.OffsetProviderType) -> itir.Node: + return cls().visit(node, offset_provider_type=offset_provider_type) + + def _visit_reduce( + self, node: itir.FunCall, offset_provider_type: 
common.OffsetProviderType + ) -> itir.Expr: + connectivity_type = _get_connectivity(node, offset_provider_type) + max_neighbors = connectivity_type.max_neighbors + has_skip_values = connectivity_type.has_skip_values acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 66d8345b94..987eb0f308 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -155,7 +155,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> square_func_type_synthesizer = type_synthesizer.TypeSynthesizer( ... type_synthesizer=lambda base: power(base, int_type) ... ) - >>> square_func_type_synthesizer(float_type, offset_provider={}) + >>> square_func_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such @@ -169,7 +169,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... node=square_func, ... store_inferred_type_in_node=True, ... ) - >>> o_type_synthesizer(float_type, offset_provider={}) + >>> o_type_synthesizer(float_type, offset_provider_type={}) ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... 
pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type @@ -225,13 +225,15 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: def __call__( self, *args: type_synthesizer.TypeOrTypeSynthesizer, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: assert all( isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" - return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) + return_type_or_synthesizer = self.type_synthesizer( + *args, offset_provider_type=offset_provider_type + ) # return type is a typing rule by itself if isinstance(return_type_or_synthesizer, type_synthesizer.TypeSynthesizer): @@ -250,18 +252,18 @@ def __call__( def _get_dimensions_from_offset_provider( - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} - for offset_name, provider in offset_provider.items(): + for offset_name, provider in offset_provider_type.items(): dimensions[offset_name] = common.Dimension( value=offset_name, kind=common.DimensionKind.LOCAL ) if isinstance(provider, common.Dimension): dimensions[provider.value] = provider - elif isinstance(provider, common.Connectivity): - dimensions[provider.origin_axis.value] = provider.origin_axis - dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + elif isinstance(provider, common.NeighborConnectivityType): + dimensions[provider.source_dim.value] = provider.source_dim + dimensions[provider.codomain.value] = provider.codomain return dimensions @@ -318,7 +320,7 @@ class ITIRTypeInference(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) - offset_provider: common.OffsetProvider + offset_provider_type: 
common.OffsetProviderType #: Mapping from a dimension name to the actual dimension instance. dimensions: dict[str, common.Dimension] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. @@ -329,7 +331,7 @@ def apply( cls, node: T, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: @@ -340,7 +342,7 @@ def apply( node: The :class:`itir.Node` to infer the types of. Keyword Arguments: - offset_provider: Offset provider dictionary. + offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding declaration. This is useful for testing or inference on partially inferred sub-nodes. @@ -403,9 +405,9 @@ def apply( ) instance = cls( - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, dimensions=( - _get_dimensions_from_offset_provider(offset_provider) + _get_dimensions_from_offset_provider(offset_provider_type) | _get_dimensions_from_types( node.pre_walk_values() .if_isinstance(itir.Node) @@ -540,7 +542,7 @@ def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.Stenc for input_ in inputs ] stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider=self.offset_provider + *stencil_args, offset_provider_type=self.offset_provider_type ) return it_ts.StencilClosureType( @@ -632,7 +634,7 @@ def visit_FunCall( fun = self.visit(node.fun, ctx=ctx) args = self.visit(node.args, ctx=ctx) - result = fun(*args, offset_provider=self.offset_provider) + result = fun(*args, offset_provider_type=self.offset_provider_type) if isinstance(result, ObservableTypeSynthesizer): assert not result.node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 
43c4465576..5be9ed7438 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -35,20 +35,20 @@ class TypeSynthesizer: - isinstance checks to determine if an object is actually (meant to be) a type synthesizer and not just any callable. - writing simple type synthesizers without cluttering the signature with the additional - offset_provider argument that is only needed by some. + offset_provider_type argument that is only needed by some. """ type_synthesizer: Callable[..., TypeOrTypeSynthesizer] def __post_init__(self): - if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: + if "offset_provider_type" not in inspect.signature(self.type_synthesizer).parameters: synthesizer = self.type_synthesizer - self.type_synthesizer = lambda *args, offset_provider: synthesizer(*args) + self.type_synthesizer = lambda *args, offset_provider_type: synthesizer(*args) def __call__( - self, *args: TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider + self, *args: TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType ) -> TypeOrTypeSynthesizer: - return self.type_synthesizer(*args, offset_provider=offset_provider) + return self.type_synthesizer(*args, offset_provider_type=offset_provider_type) TypeOrTypeSynthesizer = Union[ts.TypeSpec, TypeSynthesizer] @@ -212,7 +212,7 @@ def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) - def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def apply_lift( - *its: it_ts.IteratorType, offset_provider: common.OffsetProvider + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType ) -> it_ts.IteratorType: assert all(isinstance(it, it_ts.IteratorType) for it in its) stencil_args = [ @@ -224,7 +224,7 @@ def apply_lift( ) for it in its ] - stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) + stencil_return_type = 
stencil(*stencil_args, offset_provider_type=offset_provider_type) assert isinstance(stencil_return_type, ts.DataType) position_dims = its[0].position_dims if its else [] @@ -282,7 +282,7 @@ def as_fieldop( stencil: TypeSynthesizer, domain: Optional[it_ts.DomainType] = None, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result # type. In order to still allow some passes which don't need this information to run before the @@ -308,7 +308,7 @@ def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( @@ -328,8 +328,10 @@ def scan( assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL @TypeSynthesizer - def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) -> ts.DataType: - result = scan_pass(init, *its, offset_provider=offset_provider) + def apply_scan( + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType + ) -> ts.DataType: + result = scan_pass(init, *its, offset_provider_type=offset_provider_type) assert isinstance(result, ts.DataType) return result @@ -340,12 +342,12 @@ def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider: common.OffsetProvider + *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType ) -> it_ts.ListType: assert len(args) > 0 assert all(isinstance(arg, it_ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] - el_type = op(*arg_el_types, 
offset_provider=offset_provider) + el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) return it_ts.ListType(element_type=el_type) @@ -355,15 +357,17 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider): + def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): assert all(isinstance(arg, it_ts.ListType) for arg in args) - return op(init, *(arg.element_type for arg in args), offset_provider=offset_provider) + return op( + init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type + ) return applied_reduce @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider_type: common.OffsetProviderType) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, @@ -379,19 +383,19 @@ def apply_shift( assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( offset_axis.value, common.Dimension ) - provider = offset_provider[offset_axis.value.value] # TODO: naming - if isinstance(provider, common.Dimension): + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): pass - elif isinstance(provider, common.Connectivity): + elif isinstance(type_, common.NeighborConnectivityType): found = False for i, dim in enumerate(new_position_dims): - if dim.value == provider.origin_axis.value: + if dim.value == type_.source_dim.value: assert not found - new_position_dims[i] = provider.neighbor_axis + new_position_dims[i] = type_.codomain found = True assert found else: - raise NotImplementedError() + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") return it_ts.IteratorType( 
position_dims=new_position_dims, defined_dims=it.defined_dims, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 802ad2155f..69d8985beb 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -26,7 +26,6 @@ import typing from typing import Any, Iterable, Iterator, Optional -import numpy as np from typing_extensions import Self from gt4py.next import common @@ -49,47 +48,19 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: return cls(args=args, kwargs=kwargs) -@dataclasses.dataclass(frozen=True) -class CompileTimeConnectivity(common.Connectivity): - """Compile-time standin for a GTX connectivity, retaining everything except the connectivity tables.""" - - max_neighbors: int - has_skip_values: bool - origin_axis: common.Dimension - neighbor_axis: common.Dimension - index_type: type[int] | type[np.int32] | type[np.int64] - - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - raise NotImplementedError( - "A CompileTimeConnectivity instance should not call `mapped_index`." - ) - - @classmethod - def from_connectivity(cls, connectivity: common.Connectivity) -> Self: - return cls( - max_neighbors=connectivity.max_neighbors, - has_skip_values=connectivity.has_skip_values, - origin_axis=connectivity.origin_axis, - neighbor_axis=connectivity.neighbor_axis, - index_type=connectivity.index_type, - ) - - @property - def table(self) -> None: - return None - - @dataclasses.dataclass(frozen=True) class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] 
kwargs: dict[str, ts.TypeSpec] - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + @property + def offset_provider_type(self) -> common.OffsetProviderType: + return common.offset_provider_to_type(self.offset_provider) + @classmethod def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" @@ -98,8 +69,7 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: offset_provider = kwargs_copy.pop("offset_provider", {}) return cls( args=compile_args, - offset_provider=offset_provider, # TODO(ricoh): replace with the line below once the temporaries pass is AOT-ready. If unsure, just try it and run the tests. - # offset_provider={k: connectivity_or_dimension(v) for k, v in offset_provider.items()}, # noqa: ERA001 [commented-out-code] + offset_provider=offset_provider, column_axis=kwargs_copy.pop("column_axis", None), kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None @@ -138,18 +108,6 @@ def adapted_jit_to_aot_args_factory() -> ( return toolchain.ArgsOnlyAdapter(jit_to_aot_args) -def connectivity_or_dimension( - some_offset_provider: common.Connectivity | common.Dimension, -) -> CompileTimeConnectivity | common.Dimension: - match some_offset_provider: - case common.Dimension(): - return some_offset_provider - case common.Connectivity(): - return CompileTimeConnectivity.from_connectivity(some_offset_provider) - case _: - raise ValueError - - def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: for element in tuple_arg: match element: diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py 
index cc57c137bf..b2aea05641 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -12,7 +12,6 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, concepts from gt4py.eve.utils import UIDGenerator -from gt4py.next import common from gt4py.next.program_processors.codegens.gtfn import gtfn_ir, gtfn_ir_common from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ( AssignStmt, @@ -84,54 +83,9 @@ def _is_reduce(node: gtfn_ir.FunCall) -> TypeGuard[gtfn_ir.FunCall]: ) -def _get_connectivity( - applied_reduce_node: gtfn_ir.FunCall, - offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: - """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - - connectivities: list[common.Connectivity] = [] - for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) - connectivities.append(conn) - - if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") - - if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: - # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. 
- raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") - return connectivities[0] - - # TODO: end of code clone -def _make_dense_acess( - shift_call: gtfn_ir.FunCall, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="deref"), - args=[ - gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="shift"), args=[*shift_call.args, nbh_iter] - ) - ], - ) - - -def _make_sparse_acess( - field_ref: gtfn_ir_common.SymRef, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="tuple_get"), - args=[nbh_iter, gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref])], - ) - - class PlugInCurrentIdx(NodeTranslator): def visit_SymRef( self, node: gtfn_ir_common.SymRef @@ -225,32 +179,6 @@ def _expand_symref( ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) - def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - - connectivity = _get_connectivity(node, offset_provider) - - args = node.args - # do the following transformations to the node arguments - # dense fields: shift(dense_f, X2Y) -> deref(shift(dense_f, X2Y, nbh_iterator) - # sparse_fields: sparse_f -> tuple_get(nbh_iterator, deref(sparse_f))) - new_args = [] - nbh_iter = gtfn_ir_common.SymRef(id="nbh_iter") - for arg in args: - if isinstance(arg, gtfn_ir.FunCall) and arg.fun.id == "shift": # type: ignore - new_args.append(_make_dense_acess(arg, nbh_iter)) - if isinstance(arg, gtfn_ir_common.SymRef): - new_args.append(_make_sparse_acess(arg, nbh_iter)) - - red_idx = self.uids.sequential_id(prefix="red") - if isinstance(node.fun.args[0], gtfn_ir.Lambda): # type: ignore - self._expand_lambda(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - elif isinstance(node.fun.args[0], gtfn_ir_common.SymRef): # type: 
ignore - self._expand_symref(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - - return gtfn_ir_common.SymRef(id=red_idx) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: if any(isinstance(arg, gtfn_ir.Lambda) for arg in node.args): # do not try to lower constructs that take lambdas as argument to something more readable @@ -278,7 +206,9 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if _is_reduce(node): - return self.handle_Reduction(node, **kwargs) + raise AssertionError( + "Not implemented. The code-path was removed as it was not actively used and tested." + ) if isinstance(node.fun, gtfn_ir_common.SymRef) and node.fun.id == "make_tuple": tupl_id = self.uids.sequential_id(prefix="tupl") tuple_fun = self.commit_args(node, tupl_id, "make_tuple", **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ce459f7970..f1649112a7 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -82,7 +82,7 @@ def _process_regular_arguments( self, program: itir.FencilDefinition | itir.Program, arg_types: tuple[ts.TypeSpec, ...], - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -104,22 +104,22 @@ def _process_regular_arguments( ): # translate sparse dimensions to tuple dtype dim_name = dim.value - connectivity = offset_provider[dim_name] - assert isinstance(connectivity, common.Connectivity) + connectivity = offset_provider_type[dim_name] + assert isinstance(connectivity, common.NeighborConnectivityType) size = 
connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider: dict[str, common.Connectivity | common.Dimension] + self, offset_provider_type: common.OffsetProviderType ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] - for name, connectivity in offset_provider.items(): - if isinstance(connectivity, common.Connectivity): - if connectivity.index_type not in [np.int32, np.int64]: + for name, connectivity_type in offset_provider_type.items(): + if isinstance(connectivity_type, common.NeighborConnectivityType): + if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) @@ -129,15 +129,8 @@ def _process_connectivity_args( interface.Parameter( name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(), type_=ts.FieldType( - dims=[ - connectivity.origin_axis, - common.Dimension( - name, kind=common.DimensionKind.LOCAL - ), # TODO(havogt): we should not use the name of the offset as the name of the local dimension - ], - dtype=ts.ScalarType( - type_translation.get_scalar_kind(connectivity.index_type) - ), + dims=list(connectivity_type.domain), + dtype=type_translation.from_dtype(connectivity_type.dtype), ), ) ) @@ -145,19 +138,19 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity.origin_axis.value}_t, " - f"generated::{name}_t, {connectivity.max_neighbors}" + f"generated::{connectivity_type.source_dim.value}_t, " + f"generated::{name}_t, {connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity, common.Dimension): + elif 
isinstance(connectivity_type, common.Dimension): pass else: raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"got '{type(connectivity).__name__}'." + f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', " + f"got '{type(connectivity_type).__name__}'." ) return parameters, arg_exprs @@ -165,7 +158,7 @@ def _process_connectivity_args( def _preprocess_program( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, @@ -194,7 +187,7 @@ def _preprocess_program( def generate_stencil_source( self, program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: if self.enable_itir_transforms: @@ -204,7 +197,9 @@ def generate_stencil_source( new_program = program gtfn_ir = GTFN_lowering.apply( - new_program, offset_provider=offset_provider, column_axis=column_axis + new_program, + offset_provider_type=common.offset_provider_to_type(offset_provider), + column_axis=column_axis, ) if self.use_imperative_backend: @@ -224,13 +219,13 @@ def __call__( # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) regular_parameters, regular_args_expr = self._process_regular_arguments( - program, inp.args.args, inp.args.offset_provider + program, inp.args.args, inp.args.offset_provider_type ) # handle connectivity parameters and arguments (i.e. 
what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.args.offset_provider + inp.args.offset_provider_type ) # combine into a format that is aligned with what the backend expects diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index bc2bd645e8..129d81d6f9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -159,7 +159,7 @@ def _collect_dimensions_from_domain( def _collect_offset_definitions( node: itir.Node, grid_type: common.GridType, - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: used_offset_tags: set[itir.OffsetLiteral] = ( node.walk_values() @@ -167,13 +167,13 @@ def _collect_offset_definitions( .filter(lambda offset_literal: isinstance(offset_literal.value, str)) .getattr("value") ).to_set() - if not used_offset_tags.issubset(set(offset_provider.keys())): + if not used_offset_tags.issubset(set(offset_provider_type.keys())): raise AssertionError("ITIR contains an offset tag without a corresponding offset provider.") offset_definitions = {} - for offset_name, dim_or_connectivity in offset_provider.items(): - if isinstance(dim_or_connectivity, common.Dimension): - dim: common.Dimension = dim_or_connectivity + for offset_name, dim_or_connectivity_type in offset_provider_type.items(): + if isinstance(dim_or_connectivity_type, common.Dimension): + dim: common.Dimension = dim_or_connectivity_type if grid_type == common.GridType.CARTESIAN: # create alias from offset to dimension offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) @@ -201,12 +201,13 @@ def _collect_offset_definitions( offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), 
alias=SymRef(id=dim.value) ) - elif isinstance(dim_or_connectivity, common.Connectivity): + elif isinstance( + connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType + ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) - connectivity: common.Connectivity = dim_or_connectivity - for dim in [connectivity.origin_axis, connectivity.neighbor_axis]: + for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: raise NotImplementedError() offset_definitions[dim.value] = TagDefinition( @@ -323,7 +324,7 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): } _unary_op_map: ClassVar[dict[str, str]] = {"not_": "!"} - offset_provider: dict + offset_provider_type: common.OffsetProviderType column_axis: Optional[common.Dimension] grid_type: common.GridType @@ -338,18 +339,18 @@ def apply( cls, node: itir.Program, *, - offset_provider: dict, + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> Program: if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) grid_type = _get_gridtype(node.body) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( - offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type + offset_provider_type=offset_provider_type, column_axis=column_axis, grid_type=grid_type ).visit(node) def visit_Sym(self, node: itir.Sym, **kwargs: Any) -> Sym: @@ -484,8 +485,8 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: if "stencil" in kwargs: shift_offsets = self._collect_offset_or_axis_node(itir.OffsetLiteral, kwargs["stencil"]) 
for o in shift_offsets: - if o in self.offset_provider and isinstance( - self.offset_provider[o], common.Connectivity + if o in self.offset_provider_type and isinstance( + self.offset_provider_type[o], common.NeighborConnectivityType ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( @@ -679,7 +680,7 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { **_collect_dimensions_from_domain(node.body), - **_collect_offset_definitions(node, self.grid_type, self.offset_provider), + **_collect_offset_definitions(node, self.grid_type, self.offset_provider_type), } return Program( id=SymbolName(node.id), diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index db0df7d121..56ba08015b 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -12,6 +12,7 @@ import dace import numpy as np +from gt4py._core import definitions as core_defs from gt4py.next import common as gtx_common, utils as gtx_utils from . 
import utility as dace_utils @@ -65,8 +66,8 @@ def _get_args( def _ensure_is_on_device( - connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType -) -> np.typing.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: dace.dtypes.DeviceType +) -> core_defs.NDArrayObject: if device == dace.dtypes.DeviceType.GPU: if not isinstance(connectivity_arg, cp.ndarray): warnings.warn( @@ -78,7 +79,7 @@ def _ensure_is_on_device( def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): @@ -103,7 +104,7 @@ def _get_shape_args( def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: stride_args = {} for name, value in args.items(): @@ -134,7 +135,7 @@ def get_sdfg_conn_args( sdfg: dace.SDFG, offset_provider: gtx_common.OffsetProvider, on_gpu: bool, -) -> dict[str, np.typing.NDArray]: +) -> dict[str, core_defs.NDArrayObject]: """ Extracts the connectivity tables that are used in the sdfg and ensures that the memory buffers are allocated for the target device. 
@@ -142,11 +143,11 @@ def get_sdfg_conn_args( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU connectivity_args = {} - for offset, connectivity in dace_utils.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - param = dace_utils.connectivity_identifier(offset) - if param in sdfg.arrays: - connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + for offset, connectivity in offset_provider.items(): + if gtx_common.is_neighbor_table(connectivity): + param = dace_utils.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.ndarray, device) return connectivity_args diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index bc01e2abda..29395a30c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -79,19 +79,18 @@ def debug_info( return default -def filter_connectivities( - offset_provider: gtx_common.OffsetProvider, -) -> dict[str, gtx_common.Connectivity]: +def filter_connectivity_types( + offset_provider_type: gtx_common.OffsetProviderType, +) -> dict[str, gtx_common.NeighborConnectivityType]: """ - Filter offset providers of type `Connectivity`. + Filter offset provider types of type `NeighborConnectivityType`. In other words, filter out the cartesian offset providers. - Returns a new dictionary containing only `Connectivity` values. 
""" return { - offset: table - for offset, table in offset_provider.items() - if isinstance(table, gtx_common.Connectivity) + offset: conn + for offset, conn in offset_provider_type.items() + if isinstance(conn, gtx_common.NeighborConnectivityType) } diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 73b6e2ed4c..74142dec66 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -527,14 +527,14 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.args[0], gtir.OffsetLiteral) offset = node.args[0].value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider = self.subgraph_builder.get_offset_provider_type(offset) + assert isinstance(offset_provider, gtx_common.NeighborConnectivityType) it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.neighbor_axis in it.dimensions - assert offset_provider.origin_axis in it.indices - origin_index = it.indices[offset_provider.origin_axis] + assert offset_provider.codomain in it.dimensions + assert offset_provider.source_dim in it.indices + origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) @@ -561,7 +561,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: subset=sbs.Range.from_string( ",".join( it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis + if dim != offset_provider.codomain else f"0:{size}" for dim, size in zip(it.dimensions, field_desc.shape, strict=True) ) @@ -657,7 +657,9 @@ def _visit_map(self, node: gtir.FunCall) -> 
ValueExpr: tasklet_expression = f"{output_connector} = {fun_python_code}" input_args = [self.visit(arg) for arg in node.args] - input_connectivities: dict[gtx_common.Dimension, gtx_common.Connectivity] = {} + input_connectivity_types: dict[ + gtx_common.Dimension, gtx_common.NeighborConnectivityType + ] = {} for input_arg in input_args: assert isinstance(input_arg.gt_dtype, itir_ts.ListType) assert input_arg.gt_dtype.offset_type is not None @@ -665,11 +667,11 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: if offset_type == _CONST_DIM: # this input argument is the result of `make_const_list` continue - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) - input_connectivities[offset_type] = offset_provider + offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType) + input_connectivity_types[offset_type] = offset_provider_t - if len(input_connectivities) == 0: + if len(input_connectivity_types) == 0: raise ValueError(f"Missing information on local dimension for map node {node}.") # GT4Py guarantees that all connectivities used to generate lists of neighbors @@ -678,14 +680,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: len( set( (conn.has_skip_values, conn.max_neighbors) - for conn in input_connectivities.values() + for conn in input_connectivity_types.values() ) ) != 1 ): raise ValueError("Unexpected arguments to map expression with different neighborhood.") - offset_type, offset_provider = next(iter(input_connectivities.items())) - local_size = offset_provider.max_neighbors + offset_type, offset_provider_type = next(iter(input_connectivity_types.items())) + local_size = offset_provider_type.max_neighbors map_index = dace_gtir_utils.get_map_variable(offset_type) # The dataflow we build in this class has some loose connections on input edges. 
@@ -717,14 +719,14 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: result, _ = self.sdfg.add_temp_transient((local_size,), dc_dtype) result_node = self.state.add_access(result) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: # In case the `map_` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. connectivity = dace_utils.connectivity_identifier(offset_type.value) connectivity_desc = self.sdfg.arrays[connectivity] connectivity_desc.transient = False - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) connectivity_slice = self._construct_local_view( MemletExpr( @@ -733,7 +735,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: element_type=node.type.element_type, offset_type=offset_type ), subset=sbs.Range.from_string( - f"{origin_map_index}, 0:{offset_provider.max_neighbors}" + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) ) @@ -774,7 +776,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: def _make_reduce_with_skip_values( self, input_expr: ValueExpr | MemletExpr, - offset_provider: gtx_common.Connectivity, + offset_provider_type: gtx_common.NeighborConnectivityType, reduce_init: SymbolExpr, reduce_identity: SymbolExpr, reduce_wcr: str, @@ -792,7 +794,7 @@ def _make_reduce_with_skip_values( corresponding neighbor index in the connectivity table is valid, or the identity value if the neighbor index is missing. """ - origin_map_index = dace_gtir_utils.get_map_variable(offset_provider.origin_axis) + origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( isinstance(input_expr.gt_dtype, itir_ts.ListType) @@ -815,7 +817,7 @@ def _make_reduce_with_skip_values( f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." 
) local_dim_index = local_dim_indices[0] - assert desc.shape[local_dim_index] == offset_provider.max_neighbors + assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors # we lower the reduction map with WCR out memlet in a nested SDFG nsdfg = dace.SDFG(name=self.unique_nsdfg_name("reduce_with_skip_values")) @@ -853,7 +855,7 @@ def _make_reduce_with_skip_values( # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. st_reduce.add_mapped_tasklet( name="reduce_with_skip_values", - map_ranges={"i": f"0:{offset_provider.max_neighbors}"}, + map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, inputs={ "__val": dace.Memlet(data="values", subset="i"), "__neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), @@ -882,7 +884,7 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider.max_neighbors}"), + sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), nsdfg_node, "neighbor_indices", ) @@ -910,12 +912,17 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type - offset_provider = self.subgraph_builder.get_offset_provider(offset_type.value) - assert isinstance(offset_provider, gtx_common.Connectivity) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) - if offset_provider.has_skip_values: + if offset_provider_type.has_skip_values: self._make_reduce_with_skip_values( - input_expr, offset_provider, reduce_init, reduce_identity, reduce_wcr, result_node + input_expr, + offset_provider_type, + reduce_init, + reduce_identity, + reduce_wcr, + result_node, ) else: @@ -1082,16 +1089,16 @@ def _make_dynamic_neighbor_offset( def _make_unstructured_shift( self, it: 
IteratorExpr, - connectivity: gtx_common.Connectivity, + connectivity: gtx_common.NeighborConnectivityType, offset_table_node: dace.nodes.AccessNode, offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.neighbor_axis in it.dimensions - neighbor_dim = connectivity.neighbor_axis + assert connectivity.codomain in it.dimensions + neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices - origin_dim = connectivity.origin_axis + origin_dim = connectivity.source_dim assert origin_dim in it.indices origin_index = it.indices[origin_dim] assert isinstance(origin_index, SymbolExpr) @@ -1132,7 +1139,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: assert isinstance(offset_provider_arg, gtir.OffsetLiteral) offset = offset_provider_arg.value assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) # second argument should be the offset value, which could be a symbolic expression or a dynamic offset offset_expr = ( SymbolExpr(offset_value_arg.value, IndexDType) @@ -1140,8 +1147,8 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: else self.visit(offset_value_arg) ) - if isinstance(offset_provider, gtx_common.Dimension): - return self._make_cartesian_shift(it, offset_provider, offset_expr) + if isinstance(offset_provider_type, gtx_common.Dimension): + return self._make_cartesian_shift(it, offset_provider_type, offset_expr) else: # initially, the storage for the connectivity tables is created as transient; # when the tables are used, the storage is changed to non-transient, @@ -1151,7 +1158,7 @@ def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: offset_table_node = self.state.add_access(offset_table) return self._make_unstructured_shift( - it, offset_provider, offset_table_node, offset_expr + it, offset_provider_type, 
offset_table_node, offset_expr ) def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index ad8f490f12..52284edfac 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -41,7 +41,7 @@ class DataflowBuilder(Protocol): """Visitor interface to build a dataflow subgraph.""" @abc.abstractmethod - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: ... @abc.abstractmethod def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: ... @@ -155,7 +155,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): from where to continue building the SDFG. """ - offset_provider: gtx_common.OffsetProvider + offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") @@ -164,8 +164,8 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") ) - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: - return self.offset_provider[offset] + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: + return self.offset_provider_type[offset] def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -195,10 +195,10 @@ def _make_array_shape_and_strides( Two lists of symbols, one for the shape and the other for the strides of the array. 
""" dc_dtype = gtir_builtin_translators.INDEX_DTYPE - neighbor_tables = dace_utils.filter_connectivities(self.offset_provider) + neighbor_table_types = dace_utils.filter_connectivity_types(self.offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + neighbor_table_types[dim.value].max_neighbors if dim.kind == gtx_common.DimensionKind.LOCAL else dace.symbol(dace_utils.field_size_symbol_name(name, i), dc_dtype) ) @@ -374,13 +374,12 @@ def _add_sdfg_params( self.global_symbols[pname] = param.type # add SDFG storage for connectivity tables - for offset, offset_provider in dace_utils.filter_connectivities( - self.offset_provider + for offset, connectivity_type in dace_utils.filter_connectivity_types( + self.offset_provider_type ).items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) + scalar_type = tt.from_dtype(connectivity_type.dtype) gt_type = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) # We store all connectivity tables as transient arrays here; later, while building # the field operator expressions, we change to non-transient (i.e. 
allocated externally) @@ -585,7 +584,7 @@ def visit_Lambda( } # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider, lambda_symbols) + lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -630,7 +629,7 @@ def _flatten_tuples( ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) - for offset in dace_utils.filter_connectivities(self.offset_provider) + for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) } input_memlets = {} @@ -778,7 +777,7 @@ def visit_SymRef( def build_sdfg_from_gtir( ir: gtir.Program, - offset_provider: gtx_common.OffsetProvider, + offset_provider_type: gtx_common.OffsetProviderType, ) -> dace.SDFG: """ Receives a GTIR program and lowers it to a DaCe SDFG. @@ -788,15 +787,15 @@ def build_sdfg_from_gtir( Args: ir: The GTIR program node to be lowered to SDFG - offset_provider: The definitions of offset providers used by the program node + offset_provider_type: The definitions of offset providers used by the program node Returns: An SDFG in the DaCe canonical form (simplified) """ - ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider) + sdfg_genenerator = GTIRToSDFG(offset_provider_type) sdfg = sdfg_genenerator.visit(ir) assert isinstance(sdfg, dace.SDFG) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index aa4fd0cd3e..40d44f5ab0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -52,7 +52,9 @@ def generate_sdfg( on_gpu: bool, ) -> dace.SDFG: ir 
= itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) - sdfg = gtir_sdfg.build_sdfg_from_gtir(ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir( + ir, offset_provider_type=common.offset_provider_to_type(offset_provider) + ) if auto_opt: gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) @@ -75,7 +77,7 @@ def __call__( sdfg = self.generate_sdfg( program, - inp.args.offset_provider, + inp.args.offset_provider, # TODO(havogt): should be offset_provider_type once the transformation don't require run-time info inp.args.column_axis, auto_opt=self.auto_optimize, on_gpu=(self.device_type == gtx_allocators.CUPY_DEVICE), diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fc2772027e..ef09cf51cd 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -9,7 +9,7 @@ import dataclasses import warnings from collections import OrderedDict -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from dataclasses import field from inspect import currentframe, getframeinfo from pathlib import Path @@ -38,7 +38,7 @@ def preprocess_program( program: itir.FencilDefinition, - offset_provider: Mapping[str, Any], + offset_provider_type: common.OffsetProviderType, lift_mode: legacy_itir_transforms.LiftMode, symbolic_domain_sizes: Optional[dict[str, str]] = None, temporary_extraction_heuristics: Optional[ @@ -51,13 +51,13 @@ def preprocess_program( common_subexpression_elimination=False, force_inline_lambda_args=True, lift_mode=lift_mode, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, symbolic_domain_sizes=symbolic_domain_sizes, temporary_extraction_heuristics=temporary_extraction_heuristics, unroll_reduce=unroll_reduce, ) - node = 
itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) if isinstance(node, itir.Program): fencil_definition = program_to_fencil.program_to_fencil(node) @@ -72,7 +72,7 @@ def preprocess_program( def build_sdfg_from_itir( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, auto_optimize: bool = False, on_gpu: bool = False, column_axis: Optional[common.Dimension] = None, @@ -109,10 +109,18 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program, tmps = preprocess_program( - program, offset_provider, lift_mode, symbolic_domain_sizes, temporary_extraction_heuristics + program, + offset_provider_type, + lift_mode, + symbolic_domain_sizes, + temporary_extraction_heuristics, ) sdfg_genenerator = ItirToSDFG( - list(arg_types), offset_provider, tmps, use_field_canonical_representation, column_axis + list(arg_types), + offset_provider_type, + tmps, + use_field_canonical_representation, + column_axis, ) sdfg = sdfg_genenerator.visit(program) if sdfg is None: @@ -186,14 +194,12 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: raise ValueError( "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." 
) - offset_provider = ( - self.connectivities | self._implicit_offset_provider - ) # tables are None at this point + offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] self.itir, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, column_axis=kwargs.get("column_axis", None), ) self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ @@ -238,7 +244,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: sdfg.offset_providers_per_input_field = {} itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider=offset_provider + self.itir, offset_provider_type=offset_provider_type ) itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) for closure in itir_tmp_fencil.closures: @@ -267,7 +273,7 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ the offset providers are not part of GT4Py Program's arguments. Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. 
""" - offset_provider = self.connectivities + offset_provider_type = self.connectivities # Define DaCe symbols connectivity_table_size_symbols = { @@ -276,9 +282,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -288,9 +294,9 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ ): dace.symbol( dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) ) - for k, v in offset_provider.items() # type: ignore[union-attr] + for k, v in offset_provider_type.items() # type: ignore[union-attr] for axis in [0, 1] - if hasattr(v, "table") + if isinstance(v, common.NeighborConnectivityType) and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] } @@ -298,8 +304,8 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Define the storage location (e.g. 
CPU, GPU) of the connectivity tables if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider.items(): # type: ignore[union-attr] - if not hasattr(v, "table"): + for k, v in offset_provider_type.items(): # type: ignore[union-attr] + if not isinstance(v, common.NeighborConnectivityType): continue if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: Program.connectivity_tables_data_descriptors["storage"] = ( @@ -311,12 +317,15 @@ def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[ # Build the closure dictionary closure_dict = {} - for k, v in offset_provider.items(): # type: ignore[union-attr] + for k, v in offset_provider_type.items(): # type: ignore[union-attr] conn_id = dace_utils.connectivity_identifier(k) - if hasattr(v, "table") and conn_id in self.sdfg_closure_vars["sdfg.arrays"]: + if ( + isinstance(v, common.NeighborConnectivityType) + and conn_id in self.sdfg_closure_vars["sdfg.arrays"] + ): if conn_id not in Program.connectivity_tables_data_descriptors: Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.index_type == np.int64 else dace.int32, + dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, shape=[ symbols[dace_utils.field_size_symbol_name(conn_id, 0)], symbols[dace_utils.field_size_symbol_name(conn_id, 1)], diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index a0f4b83d35..823943cfd5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -7,14 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings -from typing import Any, Mapping, Optional, Sequence, cast +from typing import Optional, Sequence, cast import dace from dace.sdfg.state import LoopRegion import gt4py.eve 
as eve -from gt4py.next import Dimension, DimensionKind -from gt4py.next.common import Connectivity +from gt4py.next import Dimension, DimensionKind, common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef @@ -91,7 +90,10 @@ def _get_scan_dim( def _make_array_shape_and_strides( - name: str, dims: Sequence[Dimension], offset_provider: Mapping[str, Any], sort_dims: bool + name: str, + dims: Sequence[Dimension], + offset_provider_type: common.OffsetProviderType, + sort_dims: bool, ) -> tuple[list[dace.symbol], list[dace.symbol]]: """ Parse field dimensions and allocate symbols for array shape and strides. @@ -106,10 +108,10 @@ def _make_array_shape_and_strides( """ dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - neighbor_tables = dace_utils.filter_connectivities(offset_provider) + connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) shape = [ ( - neighbor_tables[dim.value].max_neighbors + connectivity_types[dim.value].max_neighbors if dim.kind == DimensionKind.LOCAL # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) @@ -144,21 +146,21 @@ class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] column_axis: Optional[Dimension] - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType unique_id: int use_field_canonical_representation: bool def __init__( self, param_types: list[ts.TypeSpec], - offset_provider: dict[str, Connectivity | Dimension], + offset_provider_type: common.OffsetProviderType, tmps: list[itir.Temporary], use_field_canonical_representation: bool, column_axis: Optional[Dimension] = None, ): self.param_types = 
param_types self.column_axis = column_axis - self.offset_provider = offset_provider + self.offset_provider_type = offset_provider_type self.storage_types = {} self.tmps = tmps self.use_field_canonical_representation = use_field_canonical_representation @@ -166,7 +168,7 @@ def __init__( def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): if isinstance(type_, ts.FieldType): shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider, sort_dimensions + name, type_.dims, self.offset_provider_type, sort_dimensions ) dtype = dace_utils.as_dace_type(type_.dtype) sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -255,7 +257,7 @@ def get_output_nodes( # Visit output node again to generate the corresponding tasklet context = Context(sdfg, state, output_symbols_pass.symbol_refs) translator = PythonTaskletCodegen( - self.offset_provider, context, self.use_field_canonical_representation + self.offset_provider_type, context, self.use_field_canonical_representation ) output_nodes = flatten_list(translator.visit(closure.output)) return {node.value.data: node.value for node in output_nodes} @@ -266,7 +268,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): entry_state = program_sdfg.add_state("program_entry", is_start_block=True) # Filter neighbor tables from offset providers. - neighbor_tables = get_used_connectivities(node, self.offset_provider) + connectivity_types = get_used_connectivities(node, self.offset_provider_type) # Add program parameters as SDFG storages. for param, type_ in zip(node.params, self.param_types): @@ -285,11 +287,10 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state = entry_state # Add connectivities as SDFG storages. 
- for offset, offset_provider in neighbor_tables.items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + for offset, connectivity_type in connectivity_types.items(): + scalar_type = tt.from_dtype(connectivity_type.dtype) type_ = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type ) self.add_storage( program_sdfg, @@ -362,7 +363,7 @@ def visit_StencilClosure( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -568,7 +569,7 @@ def _visit_scan_stencil_closure( ) assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -673,7 +674,7 @@ def _visit_scan_stencil_closure( connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] lambda_context, lambda_outputs = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, lambda_domain, input_arrays, connectivity_arrays, @@ -738,7 +739,7 @@ def _visit_parallel_stencil_closure( tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... 
], ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider) + neighbor_tables = get_used_connectivities(node, self.offset_provider_type) assert all( isinstance(inp, SymRef) for inp in node.inputs ) # backend only supports SymRef inputs, not `index` calls @@ -762,7 +763,7 @@ def _visit_parallel_stencil_closure( context, results = closure_to_tasklet_sdfg( node, - self.offset_provider, + self.offset_provider_type, index_domain, input_arrays, connectivity_arrays, @@ -788,7 +789,7 @@ def _visit_domain( lower_bound = named_range.args[1] upper_bound = named_range.args[2] translator = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, context, self.use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 991053b4a5..2b2669187a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -19,8 +19,8 @@ import gt4py.eve.codegen from gt4py import eve -from gt4py.next import Dimension -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity +from gt4py.next import common +from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir import FunCall, Lambda from gt4py.next.iterator.type_system import type_specifications as it_ts @@ -187,15 +187,15 @@ def _visit_lift_in_neighbors_reduction( transformer: PythonTaskletCodegen, node: itir.FunCall, node_args: Sequence[IteratorExpr | list[ValueExpr]], - offset_provider: Connectivity, + connectivity_type: common.NeighborConnectivityType, map_entry: dace.nodes.MapEntry, map_exit: dace.nodes.MapExit, neighbor_index_node: dace.nodes.AccessNode, 
neighbor_value_node: dace.nodes.AccessNode, ) -> list[ValueExpr]: assert transformer.context.reduce_identity is not None - neighbor_dim = offset_provider.neighbor_axis.value - origin_dim = offset_provider.origin_axis.value + neighbor_dim = connectivity_type.codomain.value + origin_dim = connectivity_type.source_dim.value lifted_args: list[IteratorExpr | ValueExpr] = [] for arg in node_args: @@ -232,7 +232,7 @@ def _visit_lift_in_neighbors_reduction( assert isinstance(y, ValueExpr) input_nodes[x] = y.value - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider) + neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -294,7 +294,7 @@ def _visit_lift_in_neighbors_reduction( memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), ) - if offset_provider.has_skip_values: + if connectivity_type.has_skip_values: # check neighbor validity on if/else inter-state edge # use one branch for connectivity case start_state = lift_context.body.add_state_before( @@ -333,8 +333,8 @@ def builtin_neighbors( assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value assert isinstance(offset_dim, str) - offset_provider = transformer.offset_provider[offset_dim] - if not isinstance(offset_provider, Connectivity): + connectivity_type = transformer.offset_provider_type[offset_dim] + if not isinstance(connectivity_type, common.NeighborConnectivityType): raise NotImplementedError( "Neighbor reduction only implemented for connectivity based on neighbor tables." 
) @@ -351,7 +351,7 @@ def builtin_neighbors( iterator = transformer.visit(data) assert isinstance(iterator, IteratorExpr) field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[offset_provider.origin_axis.value] + origin_index_node = iterator.indices[connectivity_type.source_dim.value] assert transformer.context.reduce_identity is not None assert transformer.context.reduce_identity.dtype == iterator.dtype @@ -361,7 +361,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_value_var, dtype=iterator.dtype, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) @@ -375,7 +375,7 @@ def builtin_neighbors( neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") me, mx = state.add_map( f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{offset_provider.max_neighbors}"}, + ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, debuginfo=di, ) @@ -414,7 +414,7 @@ def builtin_neighbors( transformer, lift_node, lift_args, - offset_provider, + connectivity_type, me, mx, neighbor_index_node, @@ -423,13 +423,13 @@ def builtin_neighbors( else: sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v" + connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" data_access_tasklet = state.add_tasklet( "data_access", code=f"__data = __field[{data_access_index}] " + ( f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values + if connectivity_type.has_skip_values else "" ), inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, @@ -445,7 +445,7 @@ def builtin_neighbors( ) for dim in iterator.dimensions: connector = f"{dim}_v" - if dim 
== offset_provider.neighbor_axis.value: + if dim == connectivity_type.codomain.value: state.add_edge( neighbor_index_node, None, @@ -470,7 +470,7 @@ def builtin_neighbors( src_conn="__data", ) - if not offset_provider.has_skip_values: + if not connectivity_type.has_skip_values: return [ValueExpr(neighbor_value_node, iterator.dtype)] else: """ @@ -483,7 +483,7 @@ def builtin_neighbors( sdfg.add_array( neighbor_valid_var, dtype=dace.dtypes.bool, - shape=(offset_provider.max_neighbors,), + shape=(connectivity_type.max_neighbors,), transient=True, ) neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) @@ -572,7 +572,7 @@ def build_if_state(arg, state): symbol_map = copy.deepcopy(transformer.context.symbol_map) node_context = Context(sdfg, state, symbol_map) node_taskgen = PythonTaskletCodegen( - transformer.offset_provider, + transformer.offset_provider_type, node_context, transformer.use_field_canonical_representation, ) @@ -884,21 +884,12 @@ def visit_SymRef(self, node: itir.SymRef): ) +@dataclasses.dataclass class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider: dict[str, Any] + offset_provider_type: common.OffsetProviderType context: Context use_field_canonical_representation: bool - def __init__( - self, - offset_provider: dict[str, Any], - context: Context, - use_field_canonical_representation: bool, - ): - self.offset_provider = offset_provider - self.context = context - self.use_field_canonical_representation = use_field_canonical_representation - def get_sorted_field_dimensions(self, dims: Sequence[str]): return sorted(dims) if self.use_field_canonical_representation else dims @@ -914,7 +905,7 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = ( - get_used_connectivities(node, self.offset_provider) if use_neighbor_tables else {} + get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} ) connectivity_names = [ 
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() @@ -974,7 +965,7 @@ def visit_Lambda( reduce_identity=self.context.reduce_identity, ) lambda_taskgen = PythonTaskletCodegen( - self.offset_provider, + self.offset_provider_type, lambda_context, self.use_field_canonical_representation, ) @@ -1066,7 +1057,7 @@ def _visit_call(self, node: itir.FunCall): store, self.context.body.arrays[store] ) - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider) + neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) for offset in neighbor_tables.keys(): var = dace_utils.connectivity_identifier(offset) nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) @@ -1136,12 +1127,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] assert len(dims_not_indexed) == 1 offset = dims_not_indexed[0] - offset_provider = self.offset_provider[offset] - neighbor_dim = offset_provider.neighbor_axis.value + offset_provider_type = self.offset_provider_type[offset] + assert isinstance(offset_provider_type, common.NeighborConnectivityType) + neighbor_dim = offset_provider_type.codomain.value result_name = unique_var_name() self.context.body.add_array( - result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True + result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True ) result_array = self.context.body.arrays[result_name] result_node = self.context.state.add_access(result_name, debuginfo=di) @@ -1158,7 +1150,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: # we create a mapped tasklet for array slicing index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider.max_neighbors}"} + map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} src_subset = ",".join( [f"_i_{dim}" if dim in 
iterator.indices else index_name for dim in sorted_dims] ) @@ -1212,27 +1204,30 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: offset_node = self.visit(tail[1])[0] assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - if isinstance(self.offset_provider[offset_dim], Connectivity): - offset_provider = self.offset_provider[offset_dim] + if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): + offset_provider_type = cast( + common.NeighborConnectivityType, self.offset_provider_type[offset_dim] + ) # ensured by condition connectivity = self.context.state.add_access( dace_utils.connectivity_identifier(offset_dim), debuginfo=di ) - shifted_dim = offset_provider.origin_axis.value - target_dim = offset_provider.neighbor_axis.value + shifted_dim_tag = offset_provider_type.source_dim.value + target_dim_tag = offset_provider_type.codomain.value args = [ ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), + ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node, ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" else: - assert isinstance(self.offset_provider[offset_dim], Dimension) + shifted_dim = self.offset_provider_type[offset_dim] + assert isinstance(shifted_dim, common.Dimension) - shifted_dim = self.offset_provider[offset_dim].value - target_dim = shifted_dim - args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node] + shifted_dim_tag = shifted_dim.value + target_dim_tag = shifted_dim_tag + args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]} + {internals[1]}" @@ -1241,8 +1236,8 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} - 
del shifted_index[shifted_dim] - shifted_index[target_dim] = shifted_value + del shifted_index[shifted_dim_tag] + shifted_index[target_dim_tag] = shifted_value return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) @@ -1506,7 +1501,7 @@ def is_scan(node: itir.Node) -> bool: def closure_to_tasklet_sdfg( node: itir.StencilClosure, - offset_provider: dict[str, Any], + offset_provider_type: common.OffsetProviderType, domain: dict[str, str], inputs: Sequence[tuple[str, ts.TypeSpec]], connectivities: Sequence[tuple[dace.ndarray, str]], @@ -1547,7 +1542,9 @@ def closure_to_tasklet_sdfg( body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen(offset_provider, context, use_field_canonical_representation) + translator = PythonTaskletCodegen( + offset_provider_type, context, use_field_canonical_representation + ) args = [itir.SymRef(id=name) for name, _ in inputs] if is_scan(node.stencil): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index d367eb0883..72bb32f003 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -7,21 +7,21 @@ # SPDX-License-Identifier: BSD-3-Clause import itertools -from typing import Any, Mapping +from typing import Any import dace import gt4py.next.iterator.ir as itir from gt4py import eve -from gt4py.next.common import Connectivity +from gt4py.next import common from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.program_processors.runners.dace_common import utility as dace_utils def get_used_connectivities( - node: itir.Node, offset_provider: Mapping[str, Any] -) -> dict[str, Connectivity]: - connectivities = dace_utils.filter_connectivities(offset_provider) + node: itir.Node, offset_provider_type: 
common.OffsetProviderType +) -> dict[str, common.NeighborConnectivityType]: + connectivities = dace_utils.filter_connectivity_types(offset_provider_type) offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 740f1979cd..653ed4719d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -52,7 +52,7 @@ def generate_sdfg( self, program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> dace.SDFG: on_gpu = ( @@ -64,7 +64,7 @@ def generate_sdfg( return build_sdfg_from_itir( program, arg_types, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, @@ -87,7 +87,7 @@ def __call__( sdfg = self.generate_sdfg( program, inp.args.args, - inp.args.offset_provider, + common.offset_provider_to_type(inp.args.offset_provider), inp.args.column_axis, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 965c6417b2..1f3778f227 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -12,14 +12,12 @@ import diskcache import factory -import numpy.typing as npt import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.common import 
Connectivity, Dimension from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind @@ -63,8 +61,8 @@ def decorated_program( def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType +) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -79,17 +77,17 @@ def _ensure_is_on_device( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[npt.NDArray, tuple[int, ...]]]: +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings - args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] for name, conn in offset_provider.items(): if isinstance(conn, common.Connectivity): - if not isinstance(conn, common.NeighborTable): + if not common.is_neighbor_table(conn): raise NotImplementedError( "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.table, device) + conn_arg = _ensure_is_on_device(conn.ndarray, device) args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass @@ -125,7 +123,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: the program, sorted offset_provider, and column_axis. 
""" program: itir.FencilDefinition | itir.Program = inp.data - offset_provider: dict[str, Connectivity | Dimension] = inp.args.offset_provider + offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis program_hash = utils.content_hash( diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 4d518d7fcc..1dd568b95a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -94,7 +94,7 @@ def fencil_generator( ir: itir.Program | itir.FencilDefinition, debug: bool, use_embedded: bool, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, transforms: itir_transforms.ITIRTransform, ) -> stages.CompiledProgram: """ @@ -111,7 +111,15 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash( + ( + ir, + transforms, + debug, + use_embedded, + tuple(common.offset_provider_to_type(offset_provider).items()), + ) + ) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") @@ -151,7 +159,9 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as source_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: source_file_name = source_file.name if debug: print(source_file_name) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0827d99cdc..fa8c9b9ab1 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -63,6 +63,7 @@ class DimensionType(TypeSpec): 
@dataclass(frozen=True) class OffsetType(TypeSpec): + # TODO(havogt): replace by ConnectivityType source: func_common.Dimension target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 1da34db3c0..f5646c71e4 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,30 +6,32 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -from typing import Optional from types import ModuleType +from typing import Optional + +import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend -from gt4py.next.otf import arguments +from gt4py.next import backend as next_backend, common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + E2V, + E2VDim, + Edge, + Vertex, exec_alloc_descriptor, mesh_descriptor, - Vertex, - Edge, - E2V, ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, - laplap_program, lap_ref, + laplap_program, ) + try: import dace from gt4py.next.program_processors.runners.dace import ( @@ -57,25 +59,20 @@ def test_sdfgConvertible_laplap(cartesian_case): in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() - connectivities = {} # Dict of NeighborOffsetProviders, where self.table = None - for k, v in cartesian_case.offset_provider.items(): - if hasattr(v, "table"): - connectivities[k] = arguments.CompileTimeConnectivity( - 
v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype - ) - else: - connectivities[k] = v - # Test DaCe closure support @dace.program def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(in_field, tmp_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + in_field, tmp_field + ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( cartesian_case.backend - ).with_connectivities(connectivities)(tmp_field, out_field) + ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + tmp_field, out_field + ) sdfg() @@ -130,13 +127,13 @@ def sdfg( a, out, offset_provider=offset_provider ) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[0, 1], [1, 2], [2, 0]]), Edge, Vertex, 2, False - ) - connectivities = {} - connectivities["E2V"] = arguments.CompileTimeConnectivity( - e2v.max_neighbors, e2v.has_skip_values, e2v.origin_axis, e2v.neighbor_axis, e2v.table.dtype + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, ) + connectivities = {"E2V": e2v.__gt_type__()} offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -144,6 +141,9 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) + e2v_ndarray_copy = ( + e2v.ndarray.copy() + ) # otherwise DaCe complains about the gt4py custom allocated view # This is a low level interface to call the compiled SDFG. # It is not supposed to be used in user code. 
# The high level interface should be provided by a DaCe Orchestrator, @@ -155,21 +155,21 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[1, 0], [2, 1], [0, 2]]), Edge, Vertex, 2, False + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[1, 0], [2, 1], [0, 2]]), + allocator=allocator, ) + e2v_ndarray_copy = e2v.ndarray.copy() offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) cSDFG( a, @@ -177,17 +177,13 @@ def sdfg( offset_provider, rows=3, cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), + connectivity_E2V=e2v_ndarray_copy, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v_ndarray_copy, 1), ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 
0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) def get_stride_from_numpy_to_dace(numpy_array: np.ndarray, axis: int) -> int: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index c64efb27d2..794dd06709 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -152,7 +152,10 @@ def num_edges(self) -> int: ... def num_levels(self) -> int: ... @property - def offset_provider(self) -> dict[str, common.Connectivity]: ... + def offset_provider(self) -> common.OffsetProvider: ... + + @property + def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh() -> MeshDescriptor: @@ -211,25 +214,40 @@ def simple_mesh() -> MeshDescriptor: assert all(len(row) == 2 for row in e2v_arr) e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 4}, + skip_value=None, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 4}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 4}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="simple_mesh", num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 4, has_skip_values=False - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: 
gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 4, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 4, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -287,25 +305,40 @@ def skip_value_mesh() -> MeshDescriptor: dtype=gtx.IndexType, ) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 5}, + skip_value=common._DEFAULT_SKIP_VALUE, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 3}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 3}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="skip_value_mesh", num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 5, has_skip_values=True - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 3, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 3, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a5453151e6..1a51e3667d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ 
b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -89,7 +89,7 @@ def testee(a: cases.VField) -> cases.EField: cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: a[unstructured_case.offset_provider["E2V"].table[:, 0]], + ref=lambda a: a[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -115,16 +115,16 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_flat, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_intermediate_result, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], comparison=lambda inp, tmp: np.all(inp == tmp), ) @@ -132,8 +132,8 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) @@ -583,11 +583,11 @@ def testee(a: cases.VField) -> cases.VField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0)[ - unstructured_case.offset_provider["V2E"].table + np.sum(a[unstructured_case.offset_provider["E2V"].ndarray], axis=1, initial=0)[ + 
unstructured_case.offset_provider["V2E"].ndarray ], axis=1, - where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, + where=unstructured_case.offset_provider["V2E"].ndarray != common._DEFAULT_SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) @@ -606,8 +606,8 @@ def testee(inp: cases.EField) -> cases.EField: unstructured_case, testee, ref=lambda inp: np.sum( - np.sum(inp[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table + np.sum(inp[unstructured_case.offset_provider["V2E"].ndarray], axis=1)[ + unstructured_case.offset_provider["E2V"].ndarray ], axis=1, ), @@ -627,8 +627,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField unstructured_case, testee, ref=lambda a, b: [ - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1), - np.sum(b[unstructured_case.offset_provider["V2E"].table], axis=1), + np.sum(a[unstructured_case.offset_provider["V2E"].ndarray], axis=1), + np.sum(b[unstructured_case.offset_provider["V2E"].ndarray], axis=1), ], comparison=lambda a, tmp: (np.all(a[0] == tmp[0]), np.all(a[1] == tmp[1])), ) @@ -649,11 +649,11 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: unstructured_case, reduce_tuple_element, ref=lambda e, v: np.sum( - e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, + e[v2e.ndarray] + np.tile(v, (v2e.shape[1], 1)).T, axis=1, initial=0, - where=v2e.table != common._DEFAULT_SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table[:, 0]], + where=v2e.ndarray != common._DEFAULT_SKIP_VALUE, + )[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -780,7 +780,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) return tmp - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( 
unstructured_case, testee, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 37f4ee2cd1..33832fb5f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -33,11 +33,11 @@ def testee( ) # multiplication with shifted `ones` because reduction of only non-shifted field with local dimension is not supported inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify( unstructured_case, testee, @@ -57,7 +57,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 return neighbor_sum(inp, axis=V2EDim) inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) cases.verify( @@ -65,7 +65,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 testee, inp, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), + ref=np.sum(unstructured_case.offset_provider["V2E"].ndarray, axis=1), ) @@ -76,7 +76,7 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: return inp(V2E) out = unstructured_case.as_field( - [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].ndarray) ) 
inp = cases.allocate(unstructured_case, testee, "inp")() cases.verify( @@ -84,5 +84,5 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: testee, inp, out=out, - ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].ndarray], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 29966c30ad..7648d34db7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -52,7 +52,7 @@ def testee(edge_f: cases.EField) -> cases.VField: inp = cases.allocate(unstructured_case, testee, "edge_f", strategy=strategy)() out = cases.allocate(unstructured_case, testee, cases.RETURN)() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray ref = np.max( inp.asnumpy()[v2e_table], axis=1, @@ -69,7 +69,7 @@ def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) return out - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, minover, @@ -106,7 +106,7 @@ def reduction_ke_field( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray edge_f = cases.allocate(unstructured_case, fop, "edge_f")() @@ -157,7 +157,7 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + 
v2e_table = unstructured_case.offset_provider["V2E"].ndarray field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() @@ -190,7 +190,7 @@ def reduce_expr(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduce_expr(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, fencil, @@ -210,7 +210,7 @@ def test_reduction_with_common_expression(unstructured_case): def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, @@ -226,7 +226,7 @@ def test_reduction_expression_with_where(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -255,7 +255,7 @@ def test_reduction_expression_with_where_and_tuples(unstructured_case): def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) @@ -284,7 +284,7 @@ def test_reduction_expression_with_where_and_scalar(unstructured_case): 
def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray mask = unstructured_case.as_field( [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 11e28de9e1..66c56c4827 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -90,7 +90,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].ndarray[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 3fc4ed9945..5e3a2fcd14 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -248,11 +248,14 @@ def test_can_deref(program_processor, stencil): program_processor, validate = program_processor Node = gtx.Dimension("Node") + NeighDim = gtx.Dimension("Neighbor", kind=gtx.DimensionKind.LOCAL) inp = gtx.as_field([Node], np.ones((1,), dtype=np.int32)) out = gtx.as_field([Node], 
np.asarray([0], dtype=inp.dtype)) - no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[-1]]), Node, Node, 1) + no_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[-1]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -264,7 +267,9 @@ def test_can_deref(program_processor, stencil): if validate: assert np.allclose(out.asnumpy(), -1.0) - a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) + a_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[0]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -277,37 +282,6 @@ def test_can_deref(program_processor, stencil): assert np.allclose(out.asnumpy(), 1.0) -# def test_can_deref_lifted(program_processor): -# program_processor, validate = program_processor - -# Neighbor = offset("Neighbor") -# Node = gtx.Dimension("Node") - -# @fundef -# def _can_deref(inp): -# shifted = shift(Neighbor, 0)(inp) -# return if_(can_deref(shifted), 1, -1) - -# inp = gtx.as_field([Node], np.zeros((1,))) -# out = gtx.as_field([Node], np.asarray([0])) - -# no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[None]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": no_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), -1.0) - -# a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": a_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), 1.0) - - @pytest.mark.parametrize( "input_value, dtype, np_dtype", [ diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py 
b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 69786b323b..7bde55bfd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor +from gt4py.next.iterator.embedded import StridedConnectivityField LocA = gtx.Dimension("LocA") @@ -21,8 +22,10 @@ LocB = gtx.Dimension("LocB") # unused LocA2LocAB = offset("O") -LocA2LocAB_offset_provider = gtx.StridedNeighborOffsetProvider( - origin_axis=LocA, neighbor_axis=LocAB, max_neighbors=2, has_skip_values=False +LocA2LocAB_offset_provider = StridedConnectivityField( + domain_dims=(LocA, gtx.Dimension("Dummy", kind=gtx.DimensionKind.LOCAL)), + codomain_dim=LocAB, + max_neighbors=2, ) @@ -41,7 +44,7 @@ def test_strided_offset_provider(program_processor): program_processor, validate = program_processor LocA_size = 2 - max_neighbors = LocA2LocAB_offset_provider.max_neighbors + max_neighbors = LocA2LocAB_offset_provider.__gt_type__().max_neighbors LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index eb59c77201..6c6ca7e4bc 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -11,7 +11,6 @@ import numpy as np import pytest - pytest.importorskip("atlas4py") from gt4py import next as gtx @@ -22,20 +21,17 @@ exec_alloc_descriptor, ) from 
next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( + E2V, + V2E, + E2VDim, + Edge, + V2EDim, + Vertex, assert_close, nabla_setup, ) -Vertex = gtx.Dimension("Vertex") -Edge = gtx.Dimension("Edge") -V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) -E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) - -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) - - @gtx.field_operator def compute_zavgS( pp: gtx.Field[[Vertex], float], S_M: gtx.Field[[Edge], float] @@ -67,21 +63,19 @@ def pnabla( def test_ffront_compute_zavgS(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) + setup = nabla_setup(allocator=allocator) zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - - compute_zavgS.with_backend(exec_alloc_descriptor)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + compute_zavgS.with_backend( + None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor + )( + setup.input_field, + setup.S_fields[0], + out=zavgS, + offset_provider={"E2V": setup.edges2node_connectivity}, ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) @@ -89,27 +83,23 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): def test_ffront_nabla(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() + _, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - sign = 
gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) + setup = nabla_setup(allocator=allocator) pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - v2e = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 - ) - - pnabla.with_backend(exec_alloc_descriptor)( - pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} + pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)( + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + out=(pnabla_MXX, pnabla_MYY), + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) # TODO this check is not sensitive enough, need to implement a proper numpy reference! 
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 8d7324f438..6a5865134d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -20,6 +20,18 @@ functionspace, ) +from gt4py import next as gtx +from gt4py.next.iterator import atlas_utils + + +Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) + +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) + def assert_close(expected, actual): assert math.isclose(expected, actual), "expected={}, actual={}".format(expected, actual) @@ -33,9 +45,10 @@ def _default_config(): config["angle"] = 20.0 return config - def __init__(self, *, grid=StructuredGrid("O32"), config=None): + def __init__(self, *, allocator, grid=StructuredGrid("O32"), config=None): if config is None: config = self._default_config() + self.allocator = allocator mesh = StructuredMeshGenerator(config).generate(grid) fs_edges = functionspace.EdgeColumns(mesh, halo=1) @@ -55,12 +68,22 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=None): self.edges_per_node = edges_per_node @property - def edges2node_connectivity(self): - return self.mesh.edges.node_connectivity + def edges2node_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Edge: self.edges_size, E2VDim: 2}, + codomain=Vertex, + data=atlas_utils.AtlasTable(self.mesh.edges.node_connectivity).asnumpy(), + allocator=self.allocator, + ) @property - def nodes2edge_connectivity(self): - return self.mesh.nodes.edge_connectivity + def nodes2edge_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Vertex: 
self.nodes_size, V2EDim: self.edges_per_node}, + codomain=Edge, + data=atlas_utils.AtlasTable(self.mesh.nodes.edge_connectivity).asnumpy(), + allocator=self.allocator, + ) @property def nodes_size(self): @@ -75,16 +98,16 @@ def _is_pole_edge(e, edge_flags): return Topology.check(edge_flags[e], Topology.POLE) @property - def is_pole_edge_field(self): + def is_pole_edge_field(self) -> gtx.Field: edge_flags = np.array(self.mesh.edges.flags()) pole_edge_field = np.zeros((self.edges_size,), dtype=bool) for e in range(self.edges_size): pole_edge_field[e] = self._is_pole_edge(e, edge_flags) - return pole_edge_field + return gtx.as_field([Edge], pole_edge_field, allocator=self.allocator) @property - def sign_field(self): + def sign_field(self) -> gtx.Field: node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) edge_flags = np.array(self.mesh.edges.flags()) @@ -100,10 +123,10 @@ def sign_field(self): node2edge_sign[jnode, jedge] = -1.0 if self._is_pole_edge(iedge, edge_flags): node2edge_sign[jnode, jedge] = 1.0 - return node2edge_sign + return gtx.as_field([Vertex, V2EDim], node2edge_sign, allocator=self.allocator) @property - def S_fields(self): + def S_fields(self) -> tuple[gtx.Field, gtx.Field]: S = np.array(self.mesh.edges.field("dual_normals"), copy=False) S_MXX = np.zeros((self.edges_size)) S_MYY = np.zeros((self.edges_size)) @@ -124,10 +147,12 @@ def S_fields(self): assert math.isclose(min(S_MYY), -2001577.7946404363) assert math.isclose(max(S_MYY), 2001577.7946404363) - return S_MXX, S_MYY + return gtx.as_field([Edge], S_MXX, allocator=self.allocator), gtx.as_field( + [Edge], S_MYY, allocator=self.allocator + ) @property - def vol_field(self): + def vol_field(self) -> gtx.Field: rpi = 2.0 * math.asin(1.0) radius = 6371.22e03 deg2rad = 2.0 * rpi / 360.0 @@ -142,10 +167,10 @@ def vol_field(self): # VOL(min/max): 57510668192.214096 851856184496.32886 assert_close(57510668192.214096, min(vol)) assert_close(851856184496.32886, max(vol)) - return vol + return 
gtx.as_field([Vertex], vol, allocator=self.allocator) @property - def input_field(self): + def input_field(self) -> gtx.Field: klevel = 0 MXX = 0 MYY = 1 @@ -200,4 +225,5 @@ def input_field(self): assert_close(0.0000000000000000, min(rzs)) assert_close(1965.4980340735883, max(rzs)) - return rzs[:, klevel] + + return gtx.as_field([Vertex], rzs[:, klevel], allocator=self.allocator) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 3db4497910..4487681abf 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -111,25 +111,18 @@ def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): @pytest.mark.requires_atlas def test_compute_zavgS(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + setup = nabla_setup(allocator=None) zavgS = gtx.as_field([Edge], np.zeros((setup.edges_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS_fencil, program_processor, setup.edges_size, zavgS, - pp, - S_MXX, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[0], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -141,9 +134,9 @@ def test_compute_zavgS(program_processor): program_processor, setup.edges_size, zavgS, - pp, - S_MYY, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[1], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -158,29 +151,21 @@ def compute_zavgS2_fencil(n_edges, out, pp, S_M): 
@pytest.mark.requires_atlas def test_compute_zavgS2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - - S = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) + setup = nabla_setup(allocator=None) zavgS = ( gtx.as_field([Edge], np.zeros((setup.edges_size))), gtx.as_field([Edge], np.zeros((setup.edges_size))), ) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS2_fencil, program_processor, setup.edges_size, zavgS, - pp, - S, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields, + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -195,34 +180,27 @@ def test_compute_zavgS2(program_processor): def test_nabla(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, + setup.input_field, S_MXX, S_MYY, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -245,33 +223,24 @@ def nabla2(n_nodes, out, pp, S, sign, vol): 
@pytest.mark.requires_atlas def test_nabla2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) - vol = gtx.as_field([Vertex], setup.vol_field) + setup = nabla_setup(allocator=None) pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla2, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, - S_M, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -325,36 +294,29 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ def test_nabla_sign(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla_sign, 
program_processor, setup.nodes_size, pnabla_MXX, pnabla_MYY, - pp, + setup.input_field, S_MXX, S_MYY, - vol, + setup.vol_field, gtx.index_field(Vertex), - is_pole_edge, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.is_pole_edge_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fdc6a77a1..ac7ce9e544 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -38,9 +38,13 @@ V2VDim, Vertex, c2e_arr, + c2e_conn, e2v_arr, + e2v_conn, v2e_arr, + v2e_conn, v2v_arr, + v2v_conn, ) from next_tests.unit_tests.conftest import program_processor, run_processor @@ -89,7 +93,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -111,7 +115,7 @@ def test_map_neighbors(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -134,7 +138,7 @@ def test_map_make_const_list(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -157,8 +161,8 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo inp, out=out, offset_provider={ - 
"E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + "E2V": e2v_conn, + "C2E": c2e_conn, }, ) if validate: @@ -185,7 +189,7 @@ def test_sparse_input_field(program_processor): non_sparse, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: @@ -208,8 +212,8 @@ def test_sparse_input_field_v2v(program_processor): inp, out=out, offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2V": v2v_conn, + "V2E": v2e_conn, }, ) @@ -235,7 +239,7 @@ def test_slice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -259,7 +263,7 @@ def test_slice_twice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -284,7 +288,7 @@ def test_shift_sliced_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -309,7 +313,7 @@ def test_slice_shifted_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -337,7 +341,7 @@ def test_lift(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -360,7 +364,7 @@ def 
test_shift_sparse_input_field(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -393,8 +397,8 @@ def test_shift_sparse_input_field2(program_processor): out2 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "E2V": e2v_conn, + "V2E": v2e_conn, } domain = {Vertex: range(0, 9)} @@ -448,7 +452,7 @@ def test_sparse_shifted_stencil_reduce(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 82c91a5e74..50db24b880 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -49,6 +49,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) + v2v_arr = np.array( [ [1, 3, 2, 6], @@ -64,6 +66,8 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) + e2v_arr = np.array( [ [0, 1], @@ -88,6 +92,7 @@ dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) +e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) # order east, north, west, south (counter-clock wise) v2e_arr = np.array( @@ -104,3 +109,5 @@ ], dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), ) + +v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index ca66b45d6d..f1269f1ed8 100644 --- 
a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -14,11 +14,11 @@ import pytest import gt4py.next as gtx -from gt4py.next import backend +from gt4py.next import backend, common +from gt4py.next.embedded import nd_array_field from gt4py.next.iterator import runtime from gt4py.next.program_processors import program_formatter - import next_tests @@ -97,12 +97,21 @@ def run_processor( @dataclasses.dataclass -class DummyConnectivity: +class DummyConnectivity(common.Connectivity): max_neighbors: int has_skip_values: int - origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") - neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") - index_type: type[int] = int + source_dim: gtx.Dimension = gtx.Dimension("dummy_origin") + codomain: gtx.Dimension = gtx.Dimension("dummy_neighbor") + + +def nd_array_implementation_params(): + for xp in nd_array_field._nd_array_implementations: + if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: + yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) + else: + yield pytest.param(xp, id=xp.__name__) + - def mapped_index(_, __) -> int: - return 0 +@pytest.fixture(params=nd_array_implementation_params()) +def nd_array_implementation(request): + yield request.param diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 063e79d92e..9dde5bb40a 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex +from gt4py.next.common import Dimension, Domain, NamedIndex, NamedRange, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from 
gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -28,19 +28,6 @@ D2 = Dimension("D2") -def nd_array_implementation_params(): - for xp in nd_array_field._nd_array_implementations: - if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: - yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) - else: - yield pytest.param(xp, id=xp.__name__) - - -@pytest.fixture(params=nd_array_implementation_params()) -def nd_array_implementation(request): - yield request.param - - @pytest.fixture( params=[ operator.add, diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index dcc3a306f2..a91dbeb608 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -31,12 +31,10 @@ # 0 --0-- 1 --1-- 2 e2v_arr = np.array([[0, 1], [1, 2]]) -e2v_conn = gtx.NeighborTableOffsetProvider( - table=e2v_arr, - origin_axis=E, - neighbor_axis=V, - max_neighbors=2, - has_skip_values=False, +e2v_conn = gtx.as_connectivity( + domain={E: 2, E2VDim: 2}, + codomain=V, + data=e2v_arr, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 1f08362f4f..13e8637d1a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -10,18 +10,22 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.iterator.builtins import deref from gt4py.next.iterator.runtime import CartesianDomain, UnstructuredDomain, _deduce_domain, fundef -from next_tests.unit_tests.conftest import DummyConnectivity - @fundef def foo(inp): return deref(inp) -connectivity = DummyConnectivity(max_neighbors=0, 
has_skip_values=True) +connectivity = common.ConnectivityType( + domain=[gtx.Dimension("dummy_origin"), gtx.Dimension("dummy_neighbor")], + codomain=gtx.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE, + dtype=None, +) def test_deduce_domain(): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7b6214fb1b..65a5b5888d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -218,11 +218,11 @@ def expression_test_cases(): @pytest.mark.parametrize("test_case", expression_test_cases()) def test_expression_type(test_case): mesh = simple_mesh() - offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} + offset_provider_type = {**mesh.offset_provider_type, "Ioff": IDim, "Joff": JDim, "Koff": KDim} testee, expected_type = test_case result = itir_type_inference.infer( - testee, offset_provider=offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=offset_provider_type, allow_undeclared_symbols=True ) assert result.type == expected_type @@ -231,14 +231,16 @@ def test_adhoc_polymorphism(): func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.type == ts.TupleType(types=[bool_type, int_type]) def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) - result = itir_type_inference.infer(testee, offset_provider={}) + result = itir_type_inference.infer(testee, offset_provider_type={}) assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, 
kw_only_args={}, returns=int_type @@ -253,7 +255,7 @@ def test_late_offset_axis(): testee = im.call(func)(im.ensure_offset("V2E")) result = itir_type_inference.infer( - testee, offset_provider=mesh.offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=mesh.offset_provider_type, allow_undeclared_symbols=True ) assert result.type == it_on_e_of_e_type @@ -265,7 +267,9 @@ def test_cast_first_arg_inference(): testee = im.call("cast_")( im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.args[0].type == int_type assert result.type == float64_type @@ -291,7 +295,7 @@ def test_cartesian_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -336,7 +340,7 @@ def test_unstructured_fencil_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[Vertex, KDim]), @@ -384,7 +388,7 @@ def test_function_definition(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) closure_type = it_ts.StencilClosureType( domain=it_ts.DomainType(dims=[IDim]), @@ -429,7 +433,7 @@ def test_fencil_with_nb_field_input(): ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) assert 
result.closures[0].stencil.expr.args[0].type == float64_list_type assert result.closures[0].stencil.type.returns == float64_type @@ -456,7 +460,7 @@ def test_program_tuple_setat_short_target(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.TupleType) @@ -487,7 +491,7 @@ def test_program_setat_without_domain(): ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.DeferredType) @@ -512,7 +516,9 @@ def test_if_stmt(): false_branch=[], ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field @@ -522,7 +528,7 @@ def test_as_fieldop_without_domain(): im.ref("inp", float_i_field) ) result = itir_type_inference.infer( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert result.type == ts.DeferredType(constraint=ts.FieldType) assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index e04856b75f..f4ea2d7fe1 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -21,7 +21,7 @@ @pytest.fixture -def offset_provider(request): +def offset_provider_type(request): return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)} @@ 
-137,7 +137,7 @@ def common_expr(): assert actual == expected -def test_if_can_deref_no_extraction(offset_provider): +def test_if_can_deref_no_extraction(offset_provider_type): # Test that a subexpression only occurring in one branch of an `if_` is not moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -157,11 +157,11 @@ def test_if_can_deref_no_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_can_deref_eligible_extraction(offset_provider): +def test_if_can_deref_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in both branches of an `if_` is moved outside the # if statement. A case using `can_deref` is used here as it is common. @@ -178,11 +178,11 @@ def test_if_can_deref_eligible_extraction(offset_provider): ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_eligible_extraction(offset_provider): +def test_if_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in the condition of an `if_` is moved outside the # if statement. 
@@ -191,7 +191,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, within_stencil=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 141091b450..817c06e8f0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -14,11 +14,12 @@ from gt4py import eve from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next import constructors from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import infer_domain from gt4py.next.iterator.ir_utils import domain_utils from gt4py.next.common import Dimension -from gt4py.next import common, NeighborTableOffsetProvider +from gt4py.next import common from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next import utils @@ -29,6 +30,7 @@ KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) Edge = common.Dimension(value="Edge", kind=common.DimensionKind.HORIZONTAL) +E2VDim = common.Dimension(value="E2V", kind=common.DimensionKind.LOCAL) @pytest.fixture @@ -39,11 +41,10 @@ def offset_provider(): @pytest.fixture def unstructured_offset_provider(): return { - "E2V": NeighborTableOffsetProvider( - np.array([[0, 1]], dtype=np.int32), - Edge, - Vertex, - 2, + "E2V": 
constructors.as_connectivity( + domain={Edge: 1, E2VDim: 2}, + codomain=Vertex, + data=np.array([[0, 1]], dtype=np.int32), ) } diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index b5b9a62009..168e9490e0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -13,6 +13,7 @@ from gt4py.next.iterator.transforms import fuse_as_fieldop from gt4py.next.type_system import type_specifications as ts + IDim = gtx.Dimension("IDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) @@ -30,7 +31,7 @@ def test_trivial(): d, )(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -40,7 +41,7 @@ def test_trivial_literal(): testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -65,7 +66,7 @@ def test_tuple_arg(): d, )() actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) assert actual == expected @@ -85,7 +86,7 @@ def test_symref_used_twice(): d, )("inp1", "inp2") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={}, allow_undeclared_symbols=True + testee, offset_provider_type={}, allow_undeclared_symbols=True ) 
assert actual == expected @@ -100,7 +101,7 @@ def test_no_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee @@ -132,6 +133,6 @@ def test_partial_inline(): d1, )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") actual = fuse_as_fieldop.FuseAsFieldOp.apply( - testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 23f62842c4..9d51dc4f33 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -52,7 +52,7 @@ def test_trivial(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -87,7 +87,7 @@ def test_trivial_let(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( @@ -128,7 +128,7 @@ def test_top_level_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, 
offset_provider=offset_provider) expected = program_factory( @@ -186,7 +186,7 @@ def test_nested_if(): ) ], ) - testee = type_inference.infer(testee, offset_provider=offset_provider) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) testee = infer_domain.infer_program(testee, offset_provider=offset_provider) expected = program_factory( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py index 7c991fb9a8..77d3323fb4 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -8,16 +8,16 @@ from gt4py import next as gtx from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.prune_casts import PruneCasts from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts def test_prune_casts_simple(): x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) testee = im.call("plus")(im.call("cast_")(x_ref, "float64"), im.call("cast_")(y_ref, "float64")) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = im.call("plus")(im.call("cast_")(x_ref, "float64"), y_ref) actual = PruneCasts.apply(testee) @@ -32,7 +32,7 @@ def test_prune_casts_fieldop(): im.cast_as_fieldop("float64")(x_ref), im.cast_as_fieldop("float64")(y_ref), ) - testee = type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) expected = 
im.op_as_fieldop("plus")( im.cast_as_fieldop("float64")(x_ref), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 28bd88b853..0760247996 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -11,11 +11,20 @@ import pytest from gt4py.eve.utils import UIDs +from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags -from next_tests.unit_tests.conftest import DummyConnectivity + +def dummy_connectivity_type(max_neighbors: int, has_skip_values: bool): + return common.NeighborConnectivityType( + domain=[common.Dimension("dummy_origin"), common.Dimension("dummy_neighbor")], + codomain=common.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + dtype=None, + max_neighbors=max_neighbors, + ) @pytest.fixture(params=[True, False]) @@ -67,7 +76,7 @@ def reduction_if(): ], ) def test_get_partial_offsets(reduction, request): - offset_provider = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} + offset_provider_type = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} partial_offsets = _get_partial_offset_tags(request.getfixturevalue(reduction).args) assert set(partial_offsets) == {"Dim"} @@ -108,63 +117,73 @@ def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): def test_basic(basic_reduction, has_skip_values): expected = _expected(basic_reduction, "Dim", 3, has_skip_values) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=3, has_skip_values=has_skip_values)} 
- actual = UnrollReduce.apply(basic_reduction, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply(basic_reduction, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, has_skip_values): expected = _expected(reduction_with_shift_on_second_arg, "Dim", 1, has_skip_values, 1) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=1, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(reduction_with_shift_on_second_arg, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=1, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply( + reduction_with_shift_on_second_arg, offset_provider_type=offset_provider_type + ) assert actual == expected def test_reduction_with_if(reduction_if): expected = _expected(reduction_if, "Dim", 2, False) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=2, has_skip_values=False)} - actual = UnrollReduce.apply(reduction_if, offset_provider=offset_provider) + offset_provider_type = {"Dim": dummy_connectivity_type(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "IrrelevantDim": DummyConnectivity( + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "IrrelevantDim": dummy_connectivity_type( max_neighbors=1, has_skip_values=True ), # different max_neighbors and skip value to trigger error } actual = UnrollReduce.apply( - 
reduction_with_irrelevant_full_shift, offset_provider=offset_provider + reduction_with_irrelevant_full_shift, offset_provider_type=offset_provider_type ) assert actual == expected @pytest.mark.parametrize( - "offset_provider", + "offset_provider_type", [ { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=3, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=3, has_skip_values=True), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=True), }, ], ) -def test_reduction_with_incompatible_shifts(reduction_with_incompatible_shifts, offset_provider): - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), +def test_reduction_with_incompatible_shifts( + reduction_with_incompatible_shifts, offset_provider_type +): + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), } with pytest.raises(RuntimeError, match="incompatible"): - UnrollReduce.apply(reduction_with_incompatible_shifts, offset_provider=offset_provider) + UnrollReduce.apply( + reduction_with_incompatible_shifts, offset_provider_type=offset_provider_type + ) diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 1a86f7b0f8..97591122e5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -21,7 +21,7 @@ def test_funcall_to_op(): ) actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual @@ -32,7 +32,7 @@ def test_unapplied_funcall_to_function_object(): expected = gtfn_ir.SymRef(id="plus") actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 329b2814d2..62d88d9f0a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -11,6 +11,7 @@ import ctypes import unittest import unittest.mock +from unittest.mock import patch import numpy as np import pytest @@ -20,19 +21,15 @@ from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import ( - E2V, - cartesian_case, - unstructured_case, -) +from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils 
import ( exec_alloc_descriptor, mesh_descriptor, ) -from unittest.mock import patch from . import pytestmark + dace = pytest.importorskip("dace") @@ -151,14 +148,14 @@ def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): # check that test connectivities are allocated on host memory # this is an assumption to test that fast_call cannot be used for gpu tests - assert isinstance(connectivity_E2V.table, np.ndarray) + assert isinstance(connectivity_E2V.ndarray, np.ndarray) @gtx.field_operator def testee(a: cases.VField) -> cases.EField: return a(E2V[0]) (a,), kwfields = cases.get_default_data(unstructured_case, testee) - numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + numpy_ref = lambda a: a[connectivity_E2V.ndarray[:, 0]] mock_fast_call, mock_construct_args = make_mocks(monkeypatch) @@ -194,12 +191,11 @@ def verify_testee(offset_provider): # Here we copy the connectivity to gpu memory, and resuse the same cupy array # on multiple program calls, in order to ensure that fast_call is used. 
offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider( - table=cp.asarray(connectivity_E2V.table), - origin_axis=connectivity_E2V.origin_axis, - neighbor_axis=connectivity_E2V.neighbor_axis, - max_neighbors=connectivity_E2V.max_neighbors, - has_skip_values=connectivity_E2V.has_skip_values, + "E2V": gtx.as_connectivity( + domain=connectivity_E2V.domain, + codomain=connectivity_E2V.codomain, + data=cp.asarray(connectivity_E2V.ndarray), + skip_value=connectivity_E2V.skip_value, ) } diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e0c0c3fa4e..9c52ea81c3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from gt4py.next import common as gtx_common +from gt4py.next import common as gtx_common, constructors from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -50,13 +50,7 @@ "IDim": IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() -SIMPLE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SIMPLE_MESH.offset_provider | CARTESIAN_OFFSETS -) SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() -SKIP_VALUE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SKIP_VALUE_MESH.offset_provider | CARTESIAN_OFFSETS -) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( __w_size_0=N, @@ -83,20 +77,20 @@ def make_mesh_symbols(mesh: MeshDescriptor): __vertices_size_0=mesh.num_vertices, __vertices_stride_0=1, __connectivity_C2E_size_0=mesh.num_cells, - 
__connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, __connectivity_C2V_size_0=mesh.num_cells, - __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, __connectivity_E2V_size_0=mesh.num_edges, - __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, __connectivity_V2E_size_0=mesh.num_vertices, - __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) @@ -1018,14 +1012,14 @@ def test_gtir_connectivity_shift(): CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) EDGE_OFFSET_FTYPE = ts.FieldType(dims=[Edge], dtype=SIZE_TYPE) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) ev = 
np.random.rand(SIMPLE_MESH.num_edges, SIMPLE_MESH.num_vertices) - ref = ev[connectivity_C2E.table[:, C2E_neighbor_idx], :][ - :, connectivity_E2V.table[:, E2V_neighbor_idx] + ref = ev[connectivity_C2E.ndarray[:, C2E_neighbor_idx], :][ + :, connectivity_E2V.ndarray[:, E2V_neighbor_idx] ] for i, stencil in enumerate( @@ -1053,7 +1047,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1062,8 +1056,8 @@ def test_gtir_connectivity_shift(): ev, c2e_offset=np.full(SIMPLE_MESH.num_cells, C2E_neighbor_idx, dtype=np.int32), e2v_offset=np.full(SIMPLE_MESH.num_edges, E2V_neighbor_idx, dtype=np.int32), - connectivity_C2E=connectivity_C2E.table, - connectivity_E2V=connectivity_E2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __ce_field_size_0=SIMPLE_MESH.num_cells, @@ -1114,15 +1108,17 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) - ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] + ref = e[ + connectivity_V2E.ndarray[connectivity_E2V.ndarray[:, E2V_neighbor_idx], V2E_neighbor_idx] + ] # new empty output field e_out = np.empty_like(e) @@ -1130,8 +1126,8 @@ def 
test_gtir_connectivity_shift_chain(): sdfg( e, e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, + connectivity_E2V=connectivity_E2V.ndarray, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __edges_out_size_0=SIMPLE_MESH.num_edges, @@ -1174,30 +1170,30 @@ def test_gtir_neighbors_as_input(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.shape[1]) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) v_ref = [ functools.reduce(lambda x, y: x + y, v2e_values + e[v2e_neighbors], init_value) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1210,7 +1206,7 @@ def test_gtir_neighbors_as_output(): gtx_common.GridType.UNSTRUCTURED, ranges={ Vertex: (0, "nvertices"), - V2EDim: (0, SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors), + V2EDim: (0, SIMPLE_MESH.offset_provider_type["V2E"].max_neighbors), }, ) 
vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) @@ -1232,9 +1228,9 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) @@ -1243,7 +1239,7 @@ def test_gtir_neighbors_as_output(): sdfg( e, v2e_field, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), __v2e_field_size_0=SIMPLE_MESH.num_vertices, @@ -1251,7 +1247,7 @@ def test_gtir_neighbors_as_output(): __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, ) - assert np.allclose(v2e_field, e[connectivity_V2E.table]) + assert np.allclose(v2e_field, e[connectivity_V2E.ndarray]) def test_gtir_reduce(): @@ -1278,13 +1274,13 @@ def test_gtir_reduce(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) v_ref = [ functools.reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1305,7 +1301,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1313,7 +1309,7 @@ def test_gtir_reduce(): sdfg( e, v, 
- connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) @@ -1344,7 +1340,7 @@ def test_gtir_reduce_with_skip_values(): ) )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SKIP_VALUE_MESH.num_edges) @@ -1354,7 +1350,7 @@ def test_gtir_reduce_with_skip_values(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1375,7 +1371,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1383,7 +1379,7 @@ def test_gtir_reduce_with_skip_values(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), ) @@ -1394,10 +1390,10 @@ def test_gtir_reduce_dot_product(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1409,7 +1405,7 @@ def 
test_gtir_reduce_dot_product(): ), init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field) ] testee = gtir.Program( @@ -1448,17 +1444,17 @@ def test_gtir_reduce_dot_product(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1500,14 +1496,14 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SKIP_VALUE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(*connectivity_V2E.shape) e = np.random.rand(SKIP_VALUE_MESH.num_edges) for use_sparse in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH.offset_provider_type) v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ @@ -1525,19 +1521,19 @@ def test_gtir_reduce_with_cond_neighbors(): [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], init_value, ) - for v2e_neighbors, v2e_values in zip(connectivity_V2E.table, v2e_field, strict=True) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( 
np.bool_(use_sparse), v2e_field, e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), __v2e_field_size_0=SKIP_VALUE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1631,9 +1627,9 @@ def test_gtir_let_lambda_with_connectivity(): C2V_neighbor_idx = 2 cell_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Cell: (0, "ncells")}) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_C2V = SIMPLE_MESH_OFFSET_PROVIDER["C2V"] + connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] assert isinstance(connectivity_C2V, gtx_common.NeighborTable) testee = gtir.Program( @@ -1669,22 +1665,22 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) c = np.empty(SIMPLE_MESH.num_cells) ref = ( - e[connectivity_C2E.table[:, C2E_neighbor_idx]] - + v[connectivity_C2V.table[:, C2V_neighbor_idx]] + e[connectivity_C2E.ndarray[:, C2E_neighbor_idx]] + + v[connectivity_C2V.ndarray[:, C2V_neighbor_idx]] ) sdfg( cells=c, edges=e, vertices=v, - connectivity_C2E=connectivity_C2E.table, - connectivity_C2V=connectivity_C2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_C2V=connectivity_C2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 
6e9dfa3d64..0998ab8eab 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -11,10 +11,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, common, float32 -from gt4py.next.program_processors.runners import roundtrip - -from next_tests.integration_tests import cases +from gt4py.next import allocators as next_allocators, common I = gtx.Dimension("I") @@ -154,3 +151,12 @@ def test_field_wrong_origin(): @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index(): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), aligned_index=[I, 0]) + + +@pytest.mark.parametrize( + "data, skip_value", + [([0, 1, 2], None), ([0, 1, common._DEFAULT_SKIP_VALUE], common._DEFAULT_SKIP_VALUE)], +) +def test_as_connectivity(nd_array_implementation, data, skip_value): + testee = gtx.as_connectivity([I], J, nd_array_implementation.array(data)) + assert testee.skip_value is skip_value From 3fb206e46ceecf07b7ef6c668239d62d79028503 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 26 Nov 2024 10:53:19 +0100 Subject: [PATCH 02/13] feat[next][dace]: Symbolic domain without dace array offsets (#1735) Add support for field operator domain with symbolic shape, with dimension extent in non zero-based range. 
--- .../runners/dace_common/utility.py | 10 +- .../gtir_builtin_translators.py | 127 ++++++++++----- .../runners/dace_fieldview/gtir_dataflow.py | 100 +++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 148 +++++++++++++----- .../runners/dace_fieldview/utility.py | 11 +- .../dace_tests/test_gtir_to_sdfg.py | 123 +++++++++++++-- 6 files changed, 367 insertions(+), 152 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 29395a30c1..3e96ef3cec 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Optional, Sequence +from typing import Final, Literal, Optional, Sequence import dace @@ -51,12 +51,16 @@ def connectivity_identifier(name: str) -> str: return f"connectivity_{name}" +def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: + return f"__{field_name}_{sym}_{axis}" + + def field_size_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_size_{axis}" + return field_symbol_name(field_name, axis, "size") def field_stride_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_stride_{axis}" + return field_symbol_name(field_name, axis, "stride") def is_field_symbol(name: str) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69aedf44d6..60dcd8ddc9 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, 
TypeAlias +from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace import dace.subsets as sbs @@ -33,6 +33,34 @@ from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg +def _get_domain_indices( + dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None +) -> sbs.Indices: + """ + Helper function to construct the list of indices for a field domain, applying + an optional offset in each dimension as start index. + + Args: + dims: The field dimensions. + offsets: The range start index in each dimension. + + Returns: + A list of indices for field access in dace arrays. As this list is returned + as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before + being used in memlet subset because ranges are better supported throughout DaCe. + """ + index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] + if offsets is None: + return sbs.Indices(index_variables) + else: + return sbs.Indices( + [ + index - offset if offset != 0 else index + for index, offset in zip(index_variables, offsets, strict=True) + ] + ) + + @dataclasses.dataclass(frozen=True) class FieldopData: """ @@ -45,42 +73,59 @@ class FieldopData: Args: dc_node: DaCe access node to the data storage. gt_type: GT4Py type definition, which includes the field domain information. + offset: List of index offsets, in each dimension, when the dimension range + does not start from zero; assume zero offset, if not set. 
""" dc_node: dace.nodes.AccessNode gt_type: ts.FieldType | ts.ScalarType + offset: Optional[list[dace.symbolic.SymExpr]] + + def make_copy(self, data_node: dace.nodes.AccessNode) -> FieldopData: + """Create a copy of this data descriptor with a different access node.""" + assert data_node != self.dc_node + return FieldopData(data_node, self.gt_type, self.offset) def get_local_view( self, domain: FieldopDomain ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: - """Helper method to access a field in local view, given a field operator domain.""" + """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { - dim: gtir_dataflow.SymbolExpr(dace_gtir_utils.get_map_variable(dim), INDEX_DTYPE) - for dim, _, _ in domain + domain_dims = [dim for dim, _, _ in domain] + domain_indices = _get_domain_indices(domain_dims) + it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) + for dim, index in zip(domain_dims, domain_indices) } + field_domain = [ + (dim, dace.symbolic.SymExpr(0) if self.offset is None else self.offset[i]) + for i, dim in enumerate(self.gt_type.dims) + ] local_dims = [ dim for dim in self.gt_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL ] - if len(local_dims) == 0: return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, self.gt_type.dims, indices + self.dc_node, self.gt_type.dtype, field_domain, it_indices ) elif len(local_dims) == 1: field_dtype = itir_ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) - field_dims = [ - dim for dim in self.gt_type.dims if dim.kind != gtx_common.DimensionKind.LOCAL + field_domain = [ + (dim, offset) + for dim, 
offset in field_domain + if dim.kind != gtx_common.DimensionKind.LOCAL ] - return gtir_dataflow.IteratorExpr(self.dc_node, field_dtype, field_dims, indices) + return gtir_dataflow.IteratorExpr( + self.dc_node, field_dtype, field_domain, it_indices + ) else: raise ValueError( @@ -155,9 +200,9 @@ def _parse_fieldop_arg( return arg.get_local_view(domain) -def _get_field_shape( +def _get_field_layout( domain: FieldopDomain, -) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr]]: +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: """ Parse the field operator domain and generates the shape of the result field. @@ -174,11 +219,14 @@ def _get_field_shape( domain: The field operator domain. Returns: - A tuple of two lists: the list of field dimensions and the list of dace - array sizes in each dimension. + A tuple of three lists containing: + - the domain dimensions + - the domain offset in each dimension + - the domain size in each dimension """ - domain_dims, _, domain_ubs = zip(*domain) - return list(domain_dims), list(domain_ubs) + domain_dims, domain_lbs, domain_ubs = zip(*domain) + domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] + return list(domain_dims), list(domain_lbs), domain_sizes def _create_temporary_field( @@ -189,7 +237,7 @@ def _create_temporary_field( dataflow_output: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: """Helper method to allocate a temporary field where to write the output of a field operator.""" - field_dims, field_shape = _get_field_shape(domain) + field_dims, field_offset, field_shape = _get_field_layout(domain) output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): @@ -197,6 +245,7 @@ def _create_temporary_field( assert isinstance(node_type.dtype.element_type, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by 
the field operator (e.g. `neighbors`) + field_offset.extend(output_desc.offset) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) @@ -215,7 +264,11 @@ def _create_temporary_field( assert dataflow_output.result.gt_dtype.offset_type is not None field_dims.append(dataflow_output.result.gt_dtype.offset_type) - return FieldopData(field_node, ts.FieldType(field_dims, field_dtype)) + return FieldopData( + field_node, + ts.FieldType(field_dims, field_dtype), + offset=(field_offset if set(field_offset) != {0} else None), + ) def extract_domain(node: gtir.Node) -> FieldopDomain: @@ -285,7 +338,8 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_indices = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) + domain_dims, domain_offsets, _ = zip(*domain) + domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] @@ -350,10 +404,8 @@ def translate_broadcast_scalar( assert cpm.is_ref_to(stencil_expr, "deref") domain = extract_domain(domain_expr) - field_dims, field_shape = _get_field_shape(domain) - field_subset = sbs.Range.from_string( - ",".join(dace_gtir_utils.get_map_variable(dim) for dim in field_dims) - ) + output_dims, output_offset, output_shape = _get_field_layout(domain) + output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) assert len(node.args) == 1 scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) @@ -369,26 +421,15 @@ def translate_broadcast_scalar( assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) if len(node.args[0].type.dims) == 0: # zero-dimensional field input_subset = "0" - elif all( - isinstance(scalar_expr.indices[dim], 
gtir_dataflow.SymbolExpr) - for dim in scalar_expr.dimensions - if dim not in field_dims - ): - input_subset = ",".join( - dace_gtir_utils.get_map_variable(dim) - if dim in field_dims - else scalar_expr.indices[dim].value # type: ignore[union-attr] # catched by exception above - for dim in scalar_expr.dimensions - ) else: - raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") + input_subset = scalar_expr.get_memlet_subset(sdfg) input_node = scalar_expr.field gt_dtype = node.args[0].type.dtype else: raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - output, _ = sdfg.add_temp_transient(field_shape, input_node.desc(sdfg).dtype) + output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) output_node = state.add_access(output) sdfg_builder.add_mapped_tasklet( @@ -400,13 +441,13 @@ def translate_broadcast_scalar( }, inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=field_subset)}, + outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, input_nodes={input_node.data: input_node}, output_nodes={output_node.data: output_node}, external_edges=True, ) - return FieldopData(output_node, ts.FieldType(field_dims, gt_dtype)) + return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) def translate_if( @@ -467,7 +508,7 @@ def construct_output(inner_data: FieldopData) -> FieldopData: outer, _ = sdfg.add_temp_transient_like(inner_desc) outer_node = state.add_access(outer) - return FieldopData(outer_node, inner_data.gt_type) + return inner_data.make_copy(outer_node) result_temps = gtx_utils.tree_map(construct_output)(true_br_args) @@ -513,7 +554,7 @@ def _get_data_nodes( ) -> FieldopResult: if isinstance(data_type, ts.FieldType): data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return 
sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.ScalarType): if data_name in sdfg.symbols: @@ -522,7 +563,7 @@ def _get_data_nodes( ) else: data_node = state.add_access(data_name) - return FieldopData(data_node, data_type) + return sdfg_builder.make_field(data_node, data_type) elif isinstance(data_type, ts.TupleType): tuple_fields = dace_gtir_utils.get_tuple_fields(data_name, data_type) @@ -579,7 +620,7 @@ def translate_literal( data_type = node.type data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - return FieldopData(data_node, data_type) + return FieldopData(data_node, data_type, offset=None) def translate_make_tuple( @@ -708,7 +749,7 @@ def translate_scalar_expr( dace.Memlet(data=temp_name, subset="0"), ) - return FieldopData(temp_node, node.type) + return FieldopData(temp_node, node.type, offset=None) def translate_symbol_ref( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 74142dec66..cfba4d61e5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -90,17 +90,42 @@ class IteratorExpr: Args: field: Access node to the field this iterator operates on. gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. - dimensions: Field domain represented as a sorted list of dimensions, needed - to order the map index variables and dereference an element in the field. + field_domain: Field domain represented as a sorted list of dimensions and offset values, + used to find the position of a map index variable in the memlet subset. 
The offset + value is either the start index of dimension range or the compile-time value of + a shift expression, or a composition of both, and it must be subtracted to the index + variable when constructing the memlet subset range. indices: Maps each dimension to an index value, which could be either a symbolic value or the result of a tasklet computation like neighbors connectivity or dynamic offset. """ field: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - dimensions: list[gtx_common.Dimension] + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] + def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): + raise ValueError(f"Cannot deref iterator {self}.") + + field_desc = self.field.desc(sdfg) + if isinstance(self.gt_dtype, itir_ts.ListType): + assert len(field_desc.shape) == len(self.field_domain) + 1 + assert self.gt_dtype.offset_type is not None + field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] + else: + assert len(field_desc.shape) == len(self.field_domain) + field_domain = self.field_domain + + return sbs.Range.from_string( + ",".join( + str(self.indices[dim].value - offset) # type: ignore[union-attr] + if dim in self.indices + else f"0:{size}" + for (dim, offset), size in zip(field_domain, field_desc.shape, strict=True) + ) + ) + class DataflowInputEdge(Protocol): """ @@ -271,8 +296,17 @@ def _add_input_data_edge( src_subset: sbs.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, + src_offset: Optional[list[dace.symbolic.SymExpr]] = None, ) -> None: - edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) + input_subset = ( + src_subset + if src_offset is None + else sbs.Range( + (start - off, stop - off, step) + for (start, stop, step), off in zip(src_subset, src_offset, strict=True) + ) + ) + edge = MemletInputEdge(self.state, 
src, input_subset, dst_node, dst_conn) self.input_edges.append(edge) def _add_edge( @@ -440,34 +474,21 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: field_desc = arg_expr.field.desc(self.sdfg) if isinstance(field_desc, dace.data.Scalar): # deref a zero-dimensional field - assert len(arg_expr.dimensions) == 0 + assert len(arg_expr.field_domain) == 0 assert isinstance(node.type, ts.ScalarType) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") # default case: deref a field with one or more dimensions if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - if isinstance(arg_expr.gt_dtype, itir_ts.ListType): - assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 - assert arg_expr.gt_dtype.offset_type is not None - field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] - else: - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_dims = arg_expr.dimensions - - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(field_dims, field_desc.shape) - ) + # when all indices are symbolic expressions, we can perform direct field access through a memlet + field_subset = arg_expr.get_memlet_subset(self.sdfg) return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, # either indirection through connectivity table or dynamic cartesian offset. 
- assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - assert len(field_desc.shape) == len(arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] + assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) + assert len(field_desc.shape) == len(arg_expr.field_domain) + field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] index_connectors = [ IndexConnectorFmt.format(dim=dim.value) for dim, index in field_indices @@ -494,6 +515,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: sbs.Range.from_array(field_desc), deref_node, "field", + src_offset=[offset for (_, offset) in arg_expr.field_domain], ) for dim, index_expr in field_indices: @@ -532,7 +554,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: it = self.visit(node.args[1]) assert isinstance(it, IteratorExpr) - assert offset_provider.codomain in it.dimensions + assert any(dim == offset_provider.codomain for dim, _ in it.field_domain) assert offset_provider.source_dim in it.indices origin_index = it.indices[offset_provider.source_dim] assert isinstance(origin_index, SymbolExpr) @@ -560,10 +582,12 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=node.type, subset=sbs.Range.from_string( ",".join( - it.indices[dim].value # type: ignore[union-attr] + str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) + for (dim, offset), size in zip( + it.field_domain, field_desc.shape, strict=True + ) ) ), ) @@ -971,14 +995,13 @@ def _make_cartesian_shift( self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr ) -> IteratorExpr: """Implements cartesian shift along one dimension.""" - assert offset_dim in it.dimensions + assert any(dim == offset_dim for dim, _ in it.field_domain) new_index: SymbolExpr | ValueExpr - assert 
offset_dim in it.indices index_expr = it.indices[offset_dim] if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): # purely symbolic expression which can be interpreted at compile time new_index = SymbolExpr( - dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, + index_expr.value + offset_expr.value, index_expr.dc_dtype, ) else: @@ -1032,15 +1055,10 @@ def _make_cartesian_shift( ) # a new iterator with a shifted index along one dimension - return IteratorExpr( - field=it.field, - gt_dtype=it.gt_dtype, - dimensions=it.dimensions, - indices={ - dim: (new_index if dim == offset_dim else index) - for dim, index in it.indices.items() - }, - ) + shifted_indices = { + dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() + } + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _make_dynamic_neighbor_offset( self, @@ -1094,7 +1112,7 @@ def _make_unstructured_shift( offset_expr: DataExpr, ) -> IteratorExpr: """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.codomain in it.dimensions + assert any(dim == connectivity.codomain for dim, _ in it.field_domain) neighbor_dim = connectivity.codomain assert neighbor_dim not in it.indices @@ -1117,9 +1135,7 @@ def _make_unstructured_shift( offset_expr, offset_table_node, origin_index ) - return IteratorExpr( - field=it.field, gt_dtype=it.gt_dtype, dimensions=it.dimensions, indices=shifted_indices - ) + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: # convert builtin-index type to dace type diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 52284edfac..f15287e64c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -16,6 +16,7 @@ import abc import dataclasses +import functools import itertools import operator from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union @@ -98,9 +99,16 @@ def add_mapped_tasklet( class SDFGBuilder(DataflowBuilder, Protocol): """Visitor interface available to GTIR-primitive translators.""" + @abc.abstractmethod + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + """Retrieve the field data descriptor including the domain offset information.""" + ... + @abc.abstractmethod def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the program.""" + """Retrieve the GT4Py type of a symbol used in the SDFG.""" ... @abc.abstractmethod @@ -141,6 +149,15 @@ def _collect_symbols_in_domain_expressions( ) +def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. + """ + return ts.TupleType( + types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + @dataclasses.dataclass(frozen=True) class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): """Provides translation capability from a GTIR program to a DaCe SDFG. 
@@ -157,6 +174,9 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): offset_provider_type: gtx_common.OffsetProviderType global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) + field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( + default_factory=lambda: {} + ) map_uids: eve.utils.UIDGenerator = dataclasses.field( init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") ) @@ -167,6 +187,15 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: return self.offset_provider_type[offset] + def make_field( + self, data_node: dace.nodes.AccessNode, data_type: ts.FieldType | ts.ScalarType + ) -> gtir_builtin_translators.FieldopData: + if isinstance(data_type, ts.FieldType): + domain_offset = self.field_offsets.get(data_node.data, None) + else: + domain_offset = None + return gtir_builtin_translators.FieldopData(data_node, data_type, domain_offset) + def get_symbol_type(self, symbol_name: str) -> ts.DataType: return self.global_symbols[symbol_name] @@ -248,12 +277,10 @@ def _add_storage( """ if isinstance(gt_type, ts.TupleType): tuple_fields = [] - for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, gt_type, flatten=True - ): + for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): tuple_fields.extend( self._add_storage( - sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name + sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name ) ) return tuple_fields @@ -275,7 +302,6 @@ def _add_storage( tuple_name, gt_type.dims ) sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) - return [(name, gt_type)] elif isinstance(gt_type, ts.ScalarType): @@ -344,7 +370,7 @@ def make_temps( head_state.add_nedge( field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) ) - return 
gtir_builtin_translators.FieldopData(temp_node, field.gt_type) + return field.make_copy(temp_node) temp_result = gtx_utils.tree_map(make_temps)(result) return list(gtx_utils.flatten_nested_tuple((temp_result,))) @@ -405,6 +431,10 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: if node.function_definitions: raise NotImplementedError("Functions expected to be inlined as lambda calls.") + # Since program field arguments are passed to the SDFG as full-shape arrays, + # there is no offset that needs to be compensated. + assert len(self.field_offsets) == 0 + sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) @@ -459,7 +489,7 @@ def visit_SetAt( The SDFG head state, eventually updated if the target write requires a new state. """ - temp_fields = self._visit_expression(stmt.expr, sdfg, state) + source_fields = self._visit_expression(stmt.expr, sdfg, state) # the target expression could be a `SymRef` to an output node or a `make_tuple` expression # in case the statement returns more than one field @@ -482,17 +512,26 @@ def visit_SetAt( } target_state: Optional[dace.SDFGState] = None - for temp, target in zip(temp_fields, target_fields, strict=True): + for source, target in zip(source_fields, target_fields, strict=True): target_desc = sdfg.arrays[target.dc_node.data] assert not target_desc.transient if isinstance(target.gt_type, ts.FieldType): - subset = ",".join( + target_subset = ",".join( f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.gt_type.dims ) + source_subset = ( + target_subset + if source.offset is None + else ",".join( + f"{domain[dim][0] - offset}:{domain[dim][1] - offset}" + for dim, offset in zip(target.gt_type.dims, source.offset, strict=True) + ) + ) else: assert len(domain) == 0 - subset = "0" + target_subset = "0" + source_subset = "0" if target.dc_node.data in state_input_data: # if inout argument, write the result in separate next state @@ -501,17 +540,21 @@ def visit_SetAt( target_state = 
sdfg.add_state_after(state, f"post_{state.label}") # create new access nodes in the target state target_state.add_nedge( - target_state.add_access(temp.dc_node.data), + target_state.add_access(source.dc_node.data), target_state.add_access(target.dc_node.data), - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) # remove isolated access node state.remove_node(target.dc_node) else: state.add_nedge( - temp.dc_node, + source.dc_node, target.dc_node, - dace.Memlet(data=target.dc_node.data, subset=subset, other_subset=subset), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), ) return target_state or state @@ -574,17 +617,65 @@ def visit_Lambda( (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) ] + def flatten_tuples( + name: str, + arg: gtir_builtin_translators.FieldopResult, + ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: + if isinstance(arg, tuple): + tuple_type = _get_tuple_type(arg) + tuple_field_names = [ + arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) + ] + tuple_args = zip(tuple_field_names, arg, strict=True) + return list( + itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) + ) + else: + return [(name, arg)] + + lambda_arg_nodes = dict( + itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) + ) + # inherit symbols from parent scope but eventually override with local symbols lambda_symbols = { sym: self.global_symbols[sym] for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) } | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type + pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type for pname, arg in lambda_args_mapping } + def get_field_domain_offset( + p_name: str, p_type: 
ts.DataType + ) -> dict[str, Optional[list[dace.symbolic.SymExpr]]]: + if isinstance(p_type, ts.FieldType): + if p_name in lambda_arg_nodes: + arg = lambda_arg_nodes[p_name] + assert isinstance(arg, gtir_builtin_translators.FieldopData) + return {p_name: arg.offset} + elif field_domain_offset := self.field_offsets.get(p_name, None): + return {p_name: field_domain_offset} + elif isinstance(p_type, ts.TupleType): + p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) + return functools.reduce( + lambda field_offsets, field: ( + field_offsets | get_field_domain_offset(field[0], field[1]) + ), + p_fields, + {}, + ) + return {} + + # populate mapping from field name to domain offset + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} + for p_name, p_type in lambda_symbols.items(): + lambda_field_offsets |= get_field_domain_offset(p_name, p_type) + # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider_type, lambda_symbols) + lambda_translator = GTIRToSDFG( + self.offset_provider_type, lambda_symbols, lambda_field_offsets + ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) nstate = nsdfg.add_state("lambda") @@ -603,30 +694,11 @@ def visit_Lambda( head_state=nstate, ) - def _flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: - if isinstance(arg, tuple): - tuple_type = dace_gtir_utils.get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[_flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - # Process lambda inputs # # All input arguments are passed as parameters to the nested SDFG, therefore # we they are stored as non-transient array and scalar objects. 
# - lambda_arg_nodes = dict( - itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) - ) connectivity_arrays = { dace_utils.connectivity_identifier(offset) for offset in dace_utils.filter_connectivity_types(self.offset_provider_type) @@ -739,7 +811,7 @@ def construct_output_for_nested_sdfg( head_state.add_edge( nsdfg_node, connector, outer_node, None, sdfg.make_array_memlet(outer) ) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) elif inner_data.dc_node.data in lambda_arg_nodes: # This if branch and the next one handle the non-transient result nodes. # Non-transient nodes are just input nodes that are immediately returned @@ -748,7 +820,7 @@ def construct_output_for_nested_sdfg( outer_data = lambda_arg_nodes[inner_data.dc_node.data] else: outer_node = head_state.add_access(inner_data.dc_node.data) - outer_data = gtir_builtin_translators.FieldopData(outer_node, inner_data.gt_type) + outer_data = inner_data.make_copy(outer_node) # Isolated access node will make validation fail. # Isolated access nodes can be found in the join-state of an if-expression # or in lambda expressions that just construct tuples from input arguments. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index caec6cd87e..118f0449c8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import itertools -from typing import Any, Dict, TypeVar +from typing import Dict, TypeVar import dace @@ -58,15 +58,6 @@ def get_tuple_fields( return fields -def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. 
- """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] - ) - - def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: """ Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 9c52ea81c3..f5191fbaaa 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -47,7 +47,7 @@ VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { - "IDim": IDim, + IDim.value: IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() @@ -735,13 +735,13 @@ def test_gtir_cartesian_shift_left(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -749,13 +749,15 @@ def test_gtir_cartesian_shift_left(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + 
im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -764,14 +766,14 @@ def test_gtir_cartesian_shift_left(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -828,13 +830,13 @@ def test_gtir_cartesian_shift_right(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", -OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, -OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -842,13 +844,15 @@ def test_gtir_cartesian_shift_right(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), 
DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -857,14 +861,14 @@ def test_gtir_cartesian_shift_right(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -1539,6 +1543,91 @@ def test_gtir_reduce_with_cond_neighbors(): assert np.allclose(v, v_ref) +def test_gtir_symbolic_domain(): + MARGIN = 2 + assert MARGIN < N + OFFSET = 1000 * 1000 * 1000 + domain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + left_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + ) + right_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + ) + shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) + shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) + testee = gtir.Program( + id="symbolic_domain", + function_definitions=[], + params=[ + 
gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "xᐞ1", + im.op_as_fieldop("multiplies", left_domain)( + 4.0, + im.as_fieldop( + shift_left_stencil, + left_domain, + )("x"), + ), + )( + im.let( + "xᐞ2", + im.op_as_fieldop("multiplies", right_domain)( + 3.0, + im.as_fieldop( + shift_right_stencil, + right_domain, + )("x"), + ), + )( + im.let( + "xᐞ3", + im.as_fieldop( + shift_right_stencil, + domain, + )("xᐞ1"), + )( + im.let( + "xᐞ4", + im.as_fieldop( + shift_left_stencil, + domain, + )("xᐞ2"), + )( + im.let("xᐞ5", im.op_as_fieldop("plus", domain)("xᐞ3", "xᐞ4"))( + im.op_as_fieldop("plus", domain)("xᐞ5", "x") + ) + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) + + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, **FSYMBOLS) + assert np.allclose(b, ref) + + def test_gtir_let_lambda(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) @@ -1722,7 +1811,7 @@ def test_gtir_let_lambda_with_cond(): def test_gtir_let_lambda_with_tuple1(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) testee = gtir.Program( id="let_lambda_with_tuple1", function_definitions=[], @@ -1753,10 +1842,12 @@ def test_gtir_let_lambda_with_tuple1(): sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a)) + a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) + b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) sdfg(a, b, *z_fields, 
**FSYMBOLS) - assert np.allclose(z_fields[0], a) - assert np.allclose(z_fields[1], b) + assert np.allclose(z_fields[0], a_ref) + assert np.allclose(z_fields[1], b_ref) def test_gtir_let_lambda_with_tuple2(): From f6c219bd989e3c5325da1173bade4bff2ac9e650 Mon Sep 17 00:00:00 2001 From: SF-N Date: Tue, 26 Nov 2024 15:59:58 +0100 Subject: [PATCH 03/13] bug[next]: Fix SetAt type inference for ts.DeferredType (#1747) Fix to correctly handle tuples of ts.DeferredType. --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/type_system/inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 987eb0f308..249019769b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -509,7 +509,10 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # the target can have fewer elements than the expr in which case the output from the # expression is simply discarded. expr_type = functools.reduce( - lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents + # `ts.DeferredType` only occurs for scans returning a tuple + if not isinstance(tuple_type, ts.DeferredType) + else ts.DeferredType(constraint=None), path, node.expr.type, ) From f6c0498dbffd85a80a32281e5a53bfb35e00e745 Mon Sep 17 00:00:00 2001 From: edopao Date: Wed, 27 Nov 2024 09:55:46 +0100 Subject: [PATCH 04/13] feat[next][dace]: Lowering to SDFG of index builtin (#1751) Implements the lowering to SDFG of the GTIR index builtin. 
--- src/gt4py/next/iterator/ir_utils/ir_makers.py | 14 ++++ .../gtir_builtin_translators.py | 83 ++++++++++++++++--- .../runners/dace_fieldview/gtir_sdfg.py | 2 + tests/next_tests/definitions.py | 1 - .../dace_tests/test_gtir_to_sdfg.py | 50 ++++++++++- 5 files changed, 134 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2864c7f727..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall: return _impl +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 60dcd8ddc9..94ab3a6f76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from 
gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain: the corresponding lower and upper bounds. The returned lower bound is inclusive, the upper bound is exclusive: [lower_bound, upper_bound[ """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = ( - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ) - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") return domain @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData: return result_temps +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: 
dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. + """ + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, lower_bound, upper_bound = domain[0] + dim_index = dace_gtir_utils.get_map_variable(dim) + + field_dims, field_offset, field_shape = _get_field_layout(domain) + field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) + + output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) + output_node = state.add_access(output) + + sdfg_builder.add_mapped_tasklet( + "index", + state, + map_ranges={ + dim_index: f"{lower_bound}:{upper_bound}", + }, + inputs={}, + code=f"__val = {dim_index}", + outputs={ + "__val": dace.Memlet( + data=output_node.data, + subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), + ) + }, + input_nodes={}, + output_nodes={output_node.data: output_node}, + external_edges=True, + ) + + return FieldopData(output_node, field_type, field_offset) + + def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, @@ -777,6 +833,7 @@ def translate_symbol_ref( translate_as_fieldop, translate_broadcast_scalar, translate_if, + translate_index, translate_literal, translate_make_tuple, translate_tuple_get, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f15287e64c..6b5e164458 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -568,6 +568,8 @@ def visit_FunCall( # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): return 
gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 01fd18897d..349d3e9f70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index f5191fbaaa..c7466b853f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,15 +12,15 @@ Note: this test module covers the fieldview flavour of ITIR. 
""" -import copy import functools import numpy as np import pytest -from gt4py.next import common as gtx_common, constructors +from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -1973,3 +1973,49 @@ def test_gtir_if_values(): sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_cast", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.empty(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. 
+ testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref) From 3ece412f0d78f32893d8f01ed0e74c8b38388854 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 28 Nov 2024 13:13:55 -0500 Subject: [PATCH 05/13] fix[cartesian]: Deactivate K offset write in `gt:gpu` (#1755) Following the issue logged as https://github.com/GridTools/gt4py/issues/1754 we are deactivating the K-offset write feature until we can figure out why it's failing. I will monitor any activity on the ticket if users are hit by this. --------- Co-authored-by: Hannes Vogt --- src/gt4py/cartesian/frontend/gtscript_frontend.py | 7 +++++++ .../multi_feature_tests/test_code_generation.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ade05921ef..f155ea6209 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list: loc=nodes.Location.from_ast_node(t), ) + if self.backend_name in ["gt:gpu"]: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." 
+ "Please refer to https://github.com/GridTools/gt4py/issues/1754.", + loc=nodes.Location.from_ast_node(t), + ) + if not self._is_known(name): if name in self.temp_decls: field_decl = self.temp_decls[name] diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c4d07d7337..7c4956b3ef 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -667,6 +667,10 @@ def test_K_offset_write_conditional(backend): pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) From 886058496c1ebcb90ba530a796213d1fec7c7095 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 08:46:06 +0100 Subject: [PATCH 06/13] refact[next][dace]: Helper function for field operator constructor (#1743) Includes refactoring of the code for construction of field operators, in order to make it usable by the three lowering functions that construct fields: `translate_as_fieldop()`, `translate_broadcast_scalar()`, and `translate_index()`. 
--- .../gtir_builtin_translators.py | 242 +++++++----------- 1 file changed, 94 insertions(+), 148 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 94ab3a6f76..ff011c4193 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,11 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -229,40 +233,75 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_temporary_field( +def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, node_type: ts.FieldType, - dataflow_output: gtir_dataflow.DataflowOutputEdge, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Sequence[gtir_dataflow.DataflowInputEdge], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: - """Helper method to allocate a temporary field where to write the output of a field operator.""" + """ + Helper method to allocate a temporary field to store the output of a field operator. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. 
+ node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edge: Edge representing the dataflow output data. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ field_dims, field_offset, field_shape = _get_field_layout(domain) + field_indices = _get_domain_indices(field_dims, field_offset) + + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - output_desc = dataflow_output.result.dc_node.desc(sdfg) - if isinstance(output_desc, dace.data.Array): + field_subset = sbs.Range.from_indices(field_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == node_type.dtype + assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + field_dtype = output_edge.result.gt_dtype + else: assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) + assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. 
`neighbors`) - field_offset.extend(output_desc.offset) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") + assert output_edge.result.gt_dtype.offset_type is not None + field_dims.append(output_edge.result.gt_dtype.offset_type) + field_shape.extend(dataflow_output_desc.shape) + field_offset.extend(dataflow_output_desc.offset) + field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) - field_node = state.add_access(temp_name) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) - if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): - field_dtype = dataflow_output.result.gt_dtype - else: - assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = dataflow_output.result.gt_dtype.element_type - assert dataflow_output.result.gt_dtype.offset_type is not None - field_dims.append(dataflow_output.result.gt_dtype.offset_type) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) return FieldopData( field_node, @@ -341,7 +380,8 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. 
It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -349,117 +389,18 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_dims, domain_offsets, _ = zip(*domain) - domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.dc_node.desc(sdfg) - - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. 
- output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_indices) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.dc_node, output_subset) - - return result_field - + input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) -def translate_broadcast_scalar( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, -) -> FieldopResult: - """ - Generates the dataflow subgraph for the 'as_fieldop' builtin function for the - special case where the argument to 'as_fieldop' is a 'deref' scalar expression, - rather than a lambda function. This case corresponds to broadcasting the scalar - value over the field domain. Therefore, it is lowered to a mapped tasklet that - just writes the scalar value out to all elements of the result field. 
- """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert cpm.is_ref_to(stencil_expr, "deref") - - domain = extract_domain(domain_expr) - output_dims, output_offset, output_shape = _get_field_layout(domain) - output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) - - assert len(node.args) == 1 - scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - - if isinstance(node.args[0].type, ts.ScalarType): - assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) - input_subset = ( - str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" - ) - input_node = scalar_expr.dc_node - gt_dtype = node.args[0].type - elif isinstance(node.args[0].type, ts.FieldType): - assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field - input_subset = "0" - else: - input_subset = scalar_expr.get_memlet_subset(sdfg) - - input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype - else: - raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - - output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( - "broadcast", - state, - map_ranges={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, - code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, - input_nodes={input_node.data: input_node}, - output_nodes={output_node.data: output_node}, - external_edges=True, + return _create_field_operator( + sdfg, 
state, domain, node.type, sdfg_builder, input_edges, output_edge ) - return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) - def translate_if( node: gtir.Node, @@ -567,38 +508,44 @@ def translate_index( index values to a transient array. The extent of the index range is taken from the domain information that should be present in the node annex. """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + assert "domain" in node.annex domain = extract_domain(node.annex.domain) assert len(domain) == 1 - dim, lower_bound, upper_bound = domain[0] + dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) - - output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( + index_data = sdfg.temp_data_name() + sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( "index", state, - map_ranges={ - dim_index: f"{lower_bound}:{upper_bound}", - }, inputs={}, + outputs={"__val"}, code=f"__val = {dim_index}", - outputs={ - "__val": dace.Memlet( - data=output_node.data, - subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), - ) - }, - input_nodes={}, - output_nodes={output_node.data: output_node}, - external_edges=True, + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), ) - return FieldopData(output_node, field_type, field_offset) + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + 
return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + ) def _get_data_nodes( @@ -831,7 +778,6 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, - translate_broadcast_scalar, translate_if, translate_index, translate_literal, From d9b38f476ee5df1995d27b7497037f3f19c9b6e6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 29 Nov 2024 02:50:43 -0500 Subject: [PATCH 07/13] hotfix[cartesian]: Fixing k offset write utest deactivate (#1757) Missed a utest in #1755 --- .../multi_feature_tests/test_code_generation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 7c4956b3ef..e51b3ef09d 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -582,13 +582,17 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -660,7 +664,7 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def 
test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: From 791f67d031127872fc6375819267f59faeaf85ba Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 29 Nov 2024 10:02:34 +0100 Subject: [PATCH 08/13] test[next]: Fix flaky failure in GTIR to SDFG tests (#1759) The SDFG name has to be unique to avoid issues with parallel build in CI tests. --- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c7466b853f..b1ba4ccf22 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1984,7 +1984,7 @@ def test_gtir_index(): ) testee = gtir.Program( - id="gtir_cast", + id="gtir_index", function_definitions=[], params=[ gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), From 04513ba859d5ed55ea99999f6fd826a2a542a627 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Fri, 29 Nov 2024 13:57:10 +0100 Subject: [PATCH 09/13] fix[next]: use current working directory as default cache folder root (#1744) Change the root folder of the gt4py cache directory from the system temp folder to the current working directory, which is more visible and also avoids polluting shared filesystems in hpc clusters. 
--------- Co-authored-by: Hannes Vogt --- src/gt4py/next/config.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. 
#: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] From d581060e5c6e8b6f64b72cce041d539956ca4727 Mon Sep 17 00:00:00 2001 From: SF-N Date: Sat, 30 Nov 2024 09:39:26 +0100 Subject: [PATCH 10/13] bug[next]: ConstantFolding after create_global_tmps (#1756) Do `ConstantFolding` within `domain_union` to avoid nested minima and maxima by `create_global_tmps` --------- Co-authored-by: Till Ehrengruber --- src/gt4py/next/iterator/ir_utils/domain_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) From a26d91f409ea5d67f168bbbc4a2157df2ed1080b Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 21:31:13 +0100 Subject: [PATCH 11/13] 
fix[next]: Fix annex & type preservation in inline_lambdas (#1760) Co-authored-by: SF-N --- src/gt4py/next/iterator/transforms/inline_lambdas.py | 11 +++++------ src/gt4py/next/iterator/transforms/remap_symbols.py | 5 ++++- src/gt4py/next/iterator/type_system/inference.py | 7 +++++-- .../transforms_tests/test_inline_lambdas.py | 7 +++++++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 5ec9ec5d0b..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,7 +97,6 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: new_expr = ir.FunCall( fun=ir.Lambda( @@ -111,11 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) - for attr in ("type", "recorded_shifts", "domain"): - if hasattr(node.annex, attr): - setattr(new_expr.annex, attr, getattr(node.annex, attr)) - itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) - return new_expr + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 08d896121d..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, 
NodeTranslator): @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 249019769b..ffca6cc7a7 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. 
""" assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) # type: ignore[arg-type] + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 2e0a83d33b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -84,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) From 99c53004663b0b58c7ce8335bcc30e347d3686b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 1 Dec 2024 22:08:39 +0100 Subject: [PATCH 12/13] refactor[next]: Use `set_at` & `as_fieldop` instead of `closure` in iterator tests (#1691) --- .../test_cartesian_offset_provider.py | 12 +++--- .../iterator_tests/test_conditional.py | 2 +- .../test_strided_offset_provider.py | 7 ++-- .../iterator_tests/test_trivial.py | 10 ++--- .../iterator_tests/test_tuple.py | 28 +++++-------- .../iterator_tests/test_anton_toy.py | 21 +++++----- .../iterator_tests/test_fvm_nabla.py | 40 ++++++++----------- .../iterator_tests/test_hdiff.py | 10 ++--- 8 files changed, 55 insertions(+), 75 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py 
b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from 
next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bde55bfd2..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,8 +10,8 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor from gt4py.next.iterator.embedded import StridedConnectivityField @@ -36,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..fe89fe7c9d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from 
gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..39d0bd69c3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, 
inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + ) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 3ce9d6b470..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 4487681abf..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,8 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas @@ -212,12 +211,8 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, 
n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..e44e92013f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - 
hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) @pytest.mark.uses_origin From 6f49699f00ceb9e466fa4448bab779bc061df047 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 2 Dec 2024 13:09:47 +0100 Subject: [PATCH 13/13] style[eve]: remove unused imports and fix typos (#1748) Small cleanup PR in the eve framework: - Removes a stale `.gitignore` file. As far as I understood from the git history, earlier versions of this codebase had many `.gitignore` files in many places. Looks like this one is a leftover from a previous time. - Remove a couple of stale includes. The language server marked them as unused and since tests still pass, I guess we really don't need them anymore. - Fixed a couple of typos in comments - Fixed two typos in the github PR template --- .github/pull_request_template.md | 4 ++-- src/gt4py/eve/.gitignore | 1 - src/gt4py/eve/__init__.py | 14 ++------------ src/gt4py/eve/codegen.py | 6 +++--- src/gt4py/eve/datamodels/__init__.py | 4 ++-- src/gt4py/eve/datamodels/core.py | 16 ++++++++-------- src/gt4py/eve/extended_typing.py | 4 ---- src/gt4py/eve/trees.py | 8 -------- src/gt4py/eve/type_validation.py | 2 +- src/gt4py/eve/utils.py | 2 +- src/gt4py/next/ffront/decorator.py | 2 +- 11 files changed, 20 insertions(+), 43 deletions(-) delete mode 100644 src/gt4py/eve/.gitignore diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7284a7df04..83304a9c62 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi - test: Adding missing tests or correcting existing tests : cartesian | eve | next | storage - # ONLY if changes are limited to a specific subsytem + # ONLY if changes are limited to a specific subsystem - PR Description: @@ -27,7 +27,7 @@ Delete this 
comment and add a proper description of the changes contained in thi ## Requirements - [ ] All fixes and/or new features come with corresponding tests. -- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. +- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. If this PR contains code authored by new contributors please make sure: diff --git a/src/gt4py/eve/.gitignore b/src/gt4py/eve/.gitignore deleted file mode 100644 index 050cda3ca5..0000000000 --- a/src/gt4py/eve/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_version.py diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 0b8cfa7d62..5adac47da3 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -24,8 +24,7 @@ """ -from __future__ import annotations # isort:skip - +from __future__ import annotations from .concepts import ( AnnexManager, @@ -89,15 +88,6 @@ "SymbolRef", "VType", "register_annex_user", - "# datamodels" "Coerced", - "DataModel", - "FrozenModel", - "GenericDataModel", - "Unchecked", - "concretize", - "datamodel", - "field", - "frozenmodel", # datamodels "Coerced", "DataModel", @@ -122,7 +112,7 @@ "pre_walk_values", "walk_items", "walk_values", - "# type_definition", + # type_definitions "NOTHING", "ConstrainedStr", "Enum", diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 15fda4f3b4..3869ff313b 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -347,7 +347,7 @@ def __str__(self) -> str: class Template(Protocol): """Protocol (abstract base class) defining the Template interface. - Direct subclassess of this base class only need to implement the + Direct subclasses of this base class only need to implement the abstract methods to adapt different template engines to this interface. 
@@ -654,8 +654,8 @@ def apply( # redefinition of symbol Args: root: An IR node. - node_templates (optiona): see :class:`NodeDumper`. - dump_function (optiona): see :class:`NodeDumper`. + node_templates (optional): see :class:`NodeDumper`. + dump_function (optional): see :class:`NodeDumper`. ``**kwargs`` (optional): custom extra parameters forwarded to `visit_NODE_TYPE_NAME()`. Returns: diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 68ddea2510..6fd9c7bb21 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -11,7 +11,7 @@ Data Models can be considered as enhanced `attrs `_ / `dataclasses `_ providing additional features like automatic run-time type validation. Values assigned to fields -at initialization can be validated with automatic type checkings using the +at initialization can be validated with automatic type checking using the field type definition. Custom field validation methods can also be added with the :func:`validator` decorator, and global instance validation methods with :func:`root_validator`. @@ -33,7 +33,7 @@ 1. ``__init__()``. a. If a custom ``__init__`` already exists in the class, it will not be overwritten. - It is your responsability to call ``__auto_init__`` from there to obtain + It is your responsibility to call ``__auto_init__`` from there to obtain the described behavior. b. If there is not custom ``__init__``, the one generated by datamodels will be called first. 
diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..1b0e995156 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -24,7 +24,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -270,7 +270,7 @@ def datamodel( @overload -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Type[_T], /, *, @@ -289,7 +289,7 @@ def datamodel( # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -867,7 +867,7 @@ def _substitute_typevars( def _make_counting_attr_from_attribute( field_attrib: Attribute, *, include_type: bool = False, **kwargs: Any -) -> Any: # attr.s lies a bit in some typing definitons +) -> Any: # attr.s lies a bit in some typing definitions args = [ "default", "validator", @@ -965,7 +965,7 @@ def _type_converter(value: Any) -> _T: return value if isinstance(value, type_annotation) else type_annotation(value) except Exception as error: raise TypeError( - f"Error during coertion of given value '{value}' for field '{name}'." + f"Error during coercion of given value '{value}' for field '{name}'." ) from error return _type_converter @@ -996,7 +996,7 @@ def _type_converter(value: Any) -> _T: return _make_type_converter(origin_type, name) raise exceptions.EveTypeError( - f"Automatic type coertion for {type_annotation} types is not supported." + f"Automatic type coercion for {type_annotation} types is not supported." 
) @@ -1085,7 +1085,7 @@ def _make_datamodel( ) else: - # Create field converter if automatic coertion is enabled + # Create field converter if automatic coercion is enabled converter: TypeConverter = cast( TypeConverter, _make_type_converter(type_hint, qualified_field_name) if coerce_field else None, @@ -1099,7 +1099,7 @@ def _make_datamodel( if isinstance(attr_value_in_cls, _KNOWN_MUTABLE_TYPES): warnings.warn( f"'{attr_value_in_cls.__class__.__name__}' value used as default in '{cls.__name__}.{key}'.\n" - "Mutable types should not defbe normally used as field defaults (use 'default_factory' instead).", + "Mutable types should not be used as field defaults (use 'default_factory' instead).", stacklevel=_stacklevel_offset + 2, ) setattr( diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e276f3bccf..bf44824b49 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -14,12 +14,8 @@ from __future__ import annotations -import abc as _abc import array as _array -import collections.abc as _collections_abc -import ctypes as _ctypes import dataclasses as _dataclasses -import enum as _enum import functools as _functools import inspect as _inspect import mmap as _mmap diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index c8e8658413..8a3cc30f4b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -31,14 +31,6 @@ from .type_definitions import Enum -try: - # For performance reasons, try to use cytoolz when possible (using cython) - import cytoolz as toolz -except ModuleNotFoundError: - # Fall back to pure Python toolz - import toolz # noqa: F401 [unused-import] - - TreeKey = Union[int, str] diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..e150832295 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -311,7 +311,7 @@ def __call__( # ... 
# # Since this can be an arbitrary type (not something regular like a collection) there is - # no way to check if the type parameter is verifed in the actual instance. + # no way to check if the type parameter is verified in the actual instance. # The only check can be done at run-time is to verify that the value is an instance of # the original type, completely ignoring the annotation. Ideally, the static type checker # can do a better job to try figure out if the type parameter is ok ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8cb68845d7..2c66d39290 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -69,7 +69,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9ce07d01bb..61756f30c9 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -230,7 +230,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." ), stacklevel=2, )