diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml
index 5baeb6acef..52f8c25386 100644
--- a/.github/workflows/test-next.yml
+++ b/.github/workflows/test-next.yml
@@ -57,13 +57,13 @@ jobs:
run: |
pyversion=${{ matrix.python-version }}
pyversion_no_dot=${pyversion//./}
- tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-env-factor }}
- # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}.json
+ tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-env-factor }}-cpu
+ # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json
# - name: Upload coverage.json artifact
# uses: actions/upload-artifact@v3
# with:
- # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}
- # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}.json
+ # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu
+ # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json
# - name: Gather info
# run: |
# echo ${{ github.ref_type }} >> info.txt
@@ -76,5 +76,5 @@ jobs:
# - name: Upload info artifact
# uses: actions/upload-artifact@v3
# with:
- # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}
+ # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu
# path: info.txt
diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml
index 3dc38bcd97..971a3cfc35 100644
--- a/ci/cscs-ci.yml
+++ b/ci/cscs-ci.yml
@@ -117,3 +117,4 @@ test py310:
- SUBPACKAGE: eve
- SUBPACKAGE: next
VARIANT: [-nomesh, -atlas]
+ SUBVARIANT: [-cuda11x, -cpu]
diff --git a/docs/development/ADRs/0009-Compiled-Backend-Integration.md b/docs/development/ADRs/0009-Compiled-Backend-Integration.md
index 273f954438..27c2f0c73c 100644
--- a/docs/development/ADRs/0009-Compiled-Backend-Integration.md
+++ b/docs/development/ADRs/0009-Compiled-Backend-Integration.md
@@ -159,7 +159,7 @@ Compiled backends may generate code which depends on libraries and tools written
1. can be installed with `pip` (from `PyPI` or another source) automatically.
2. can not be installed with `pip` and not commonly found on HPC machines.
-3. libraries and tools which are left to the user to install and make discoverable: `pybind11`, C++ compilers
+3. libraries and tools which are left to the user to install and make discoverable: `boost`, C++ compilers
Category 1 are made dependencies of `GT4Py`. Examples include `pybind11`, `cmake`, `ninja`.
diff --git a/docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md b/docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md
new file mode 100644
index 0000000000..ac84903514
--- /dev/null
+++ b/docs/development/ADRs/0016-Multiple-Backends-and-Build-Systems.md
@@ -0,0 +1,118 @@
+---
+tags: [backend, gridtools, bindings, libraries, otf]
+---
+
+# Support for Multiple Backends, Build Systems and Libraries
+
+- **Status**: valid
+- **Authors**: Rico Häuselmann (@DropD)
+- **Created**: 2023-10-11
+- **Updated**: 2023-10-11
+
+In the process of enabling CUDA for the GTFN backend, we encountered a potential support matrix of build systems x target language libraries. The current design requires build systems about all the libraries they can be used with. We decided that the matrix is too small for now and to not revisit the existing design yet.
+
+## Context
+
+ADRs [0009](0009-Compiled_Backend_Integration.md), [0011](0011-On_The_Fly_Compilation.md) and [0012](0012-GridTools_Cpp_OTF_Steps.md) detail the design decisions around what is loosely referred as "gt4py.next backends". In summary the goals are:
+
+- extensibility
+ - adding backends should not require changing existing code
+ - adding / modifying backend modules like build systems / compilers should not be blocked by assumptions in other modules.
+- modularity
+ - increase the chance that two different backends (for example GTFN and another C++ backend) can share code.
+
+Therefore the concerns of generating code in the target language, generating python bindings in the target language and of building (compiling) the generated code are separated it code generator, bindings generator and compile step / build system. The compile step is written to be build system agnostic.
+
+There is one category that connects all these concerns: libraries written in the target language and used in generated / bindings code.
+
+Current design:
+
+```mermaid
+graph LR
+
+gtgen("GTFN code generator (C++/Cuda)") --> |GridTools::fn_naive| Compiler
+gtgen("GTFN code generator (C++/Cuda)") --> |GridTools::fn_gpu| Compiler
+nb("nanobind bindings generator") --> |nanobind| Compiler
+Compiler --> CMakeProject --> CMakeListsGenerator
+Compiler --> CompiledbProject --> CMakeListsGenerator
+```
+
+The current design contains two mappings:
+
+- library name -> CMake `find_package()` call
+- library name -> CMake target name
+
+and the gridtools cpu/gpu link targets are differentiated by internally separating between two fictitious "gridtools_cpu" and "gridtools_gpu" libraries.
+
+## concerns
+
+### Usage
+
+The "gridtools_cpu" and "gridtools_gpu" fake library names add to the learning curve for this part of the code. Reuse of the existing components might require this knowledge.
+
+### Scalability
+
+Adding a new backend using the existing build systems but relying on different libraries has to modify existing build system components (at the very least CMakeListsGenerator).
+
+### Separation of concerns
+
+It makes more sense to separate the concerns of how to generate a valid build system configuration and how to use a particular library in a particular build system than to mix the two.
+
+## Decision
+
+Currently the code overhead is in the tens of lines, and there are no concrete plans to add more compiled backends or different build systems. Therefore we decide to keep the current design for now but to redesign as soon as the matrix grows.
+To this end ToDo comments are added in the relevant places
+
+## Consequences
+
+Initial GTFN gpu support will not be blocked by design work.
+
+## Alternatives Considered
+
+### Push build system support to the LibraryDependency instance
+
+```
+#src/gt4py/next/otf/binding/interface.py
+
+...
+class LibraryDependency:
+ name: str
+ version: str
+ link_targets: list[str]
+ include_headers: list[str]
+```
+
+- Simple, choice is made at code generator level, where the knowledge should be
+- Interface might not suit every build system
+- Up to the implementer to make the logic for choosing reusable (or not)
+
+### Create additional data structures to properly separate concerns
+
+```
+class BuildSystemConfig:
+ device_type: core_defs.DeviceType
+ ...
+
+
+class LibraryAdaptor:
+ library: LibraryDependency
+ build_system: CMakeProject
+
+ def config_phase(self, config: BuildSystemConfig) -> str:
+ import gridtools_cpp
+ cmake_dir = gridtools_cpp.get_cmake_dir()
+
+ return f"find_package(... {cmake_dir} ... )"
+
+def build_phase(self, config: BuildSystemConfig) -> str:
+ return "" # header only library
+
+def link_phase(self, main_target_name: str, config: BuildSystemConfig) -> str:
+ return f"target_link_libraries({main_target_name} ...)"
+```
+
+- More general and fully extensible, adaptors can be added for any required library / build system combination without touching existing code (depending on the registering mechanism).
+- More likely to be reusable as choices are explicit and can be overridden separately by sub classing.
+- More design work required. Open questions:
+ - Design the interface to work with any build system
+ - How to register adaptors? entry points? global dictionary?
diff --git a/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py
new file mode 100644
index 0000000000..c5857999ee
--- /dev/null
+++ b/src/gt4py/next/ffront/foast_passes/type_alias_replacement.py
@@ -0,0 +1,105 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from dataclasses import dataclass
+from typing import Any, cast
+
+import gt4py.next.ffront.field_operator_ast as foast
+from gt4py.eve import NodeTranslator, traits
+from gt4py.eve.concepts import SourceLocation, SymbolName, SymbolRef
+from gt4py.next.ffront import dialect_ast_enums
+from gt4py.next.ffront.fbuiltins import TYPE_BUILTIN_NAMES
+from gt4py.next.type_system import type_specifications as ts
+from gt4py.next.type_system.type_translation import from_type_hint
+
+
+@dataclass
+class TypeAliasReplacement(NodeTranslator, traits.VisitorWithSymbolTableTrait):
+ """
+ Replace Type Aliases with their actual type.
+
+ After this pass, the type aliases used for explicit construction of literal
+ values and for casting field values are replaced by their actual types.
+ """
+
+ closure_vars: dict[str, Any]
+
+ @classmethod
+ def apply(
+ cls, node: foast.FunctionDefinition | foast.FieldOperator, closure_vars: dict[str, Any]
+ ) -> tuple[foast.FunctionDefinition, dict[str, Any]]:
+ foast_node = cls(closure_vars=closure_vars).visit(node)
+ new_closure_vars = closure_vars.copy()
+ for key, value in closure_vars.items():
+ if isinstance(value, type) and key not in TYPE_BUILTIN_NAMES:
+ new_closure_vars[value.__name__] = closure_vars[key]
+ return foast_node, new_closure_vars
+
+ def is_type_alias(self, node_id: SymbolName | SymbolRef) -> bool:
+ return (
+ node_id in self.closure_vars
+ and isinstance(self.closure_vars[node_id], type)
+ and node_id not in TYPE_BUILTIN_NAMES
+ )
+
+ def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name:
+ if self.is_type_alias(node.id):
+ return foast.Name(
+ id=self.closure_vars[node.id].__name__, location=node.location, type=node.type
+ )
+ return node
+
+ def _update_closure_var_symbols(
+ self, closure_vars: list[foast.Symbol], location: SourceLocation
+ ) -> list[foast.Symbol]:
+ new_closure_vars: list[foast.Symbol] = []
+ existing_type_names: set[str] = set()
+
+ for var in closure_vars:
+ if self.is_type_alias(var.id):
+ actual_type_name = self.closure_vars[var.id].__name__
+ # Avoid multiple definitions of a type in closure_vars
+ if actual_type_name not in existing_type_names:
+ new_closure_vars.append(
+ foast.Symbol(
+ id=actual_type_name,
+ type=ts.FunctionType(
+ pos_or_kw_args={},
+ kw_only_args={},
+ pos_only_args=[ts.DeferredType(constraint=ts.ScalarType)],
+ returns=cast(
+ ts.DataType, from_type_hint(self.closure_vars[var.id])
+ ),
+ ),
+ namespace=dialect_ast_enums.Namespace.CLOSURE,
+ location=location,
+ )
+ )
+ existing_type_names.add(actual_type_name)
+ elif var.id not in existing_type_names:
+ new_closure_vars.append(var)
+ existing_type_names.add(var.id)
+
+ return new_closure_vars
+
+ def visit_FunctionDefinition(
+ self, node: foast.FunctionDefinition, **kwargs
+ ) -> foast.FunctionDefinition:
+ return foast.FunctionDefinition(
+ id=node.id,
+ params=node.params,
+ body=self.visit(node.body, **kwargs),
+ closure_vars=self._update_closure_var_symbols(node.closure_vars, node.location),
+ location=node.location,
+ )
diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py
index 082939c938..c7c4c3a23f 100644
--- a/src/gt4py/next/ffront/func_to_foast.py
+++ b/src/gt4py/next/ffront/func_to_foast.py
@@ -33,6 +33,7 @@
from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction
from gt4py.next.ffront.foast_passes.dead_closure_var_elimination import DeadClosureVarElimination
from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass
+from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
@@ -91,6 +92,7 @@ def _postprocess_dialect_ast(
closure_vars: dict[str, Any],
annotations: dict[str, Any],
) -> foast.FunctionDefinition:
+ foast_node, closure_vars = TypeAliasReplacement.apply(foast_node, closure_vars)
foast_node = ClosureVarFolding.apply(foast_node, closure_vars)
foast_node = DeadClosureVarElimination.apply(foast_node)
foast_node = ClosureVarTypeDeduction.apply(foast_node, closure_vars)
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py
index eb7e715fb6..4ee93e043f 100644
--- a/src/gt4py/next/iterator/embedded.py
+++ b/src/gt4py/next/iterator/embedded.py
@@ -685,7 +685,7 @@ def _single_vertical_idx(
indices: NamedFieldIndices, column_axis: Tag, column_index: common.IntIndex
) -> NamedFieldIndices:
transformed = {
- axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis`
+ axis: (index if axis != column_axis else index.start + column_index) # type: ignore[union-attr] # trust me, `index` is range in case of `column_axis` # fmt: off
for axis, index in indices.items()
}
return transformed
@@ -1050,7 +1050,7 @@ def __gt_origin__(self) -> tuple[int, ...]:
return (0,)
@classmethod
- def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype
+ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype # fmt: off
raise NotImplementedError()
@property
@@ -1070,7 +1070,7 @@ def remap(self, index_field: common.Field) -> common.Field:
raise NotImplementedError()
def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32:
- if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code
+ if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off
d, r = item[0]
assert d == self._dimension
assert isinstance(r, int)
@@ -1156,7 +1156,7 @@ def __gt_origin__(self) -> tuple[int, ...]:
return tuple()
@classmethod
- def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype
+ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override] # Signature incompatible with supertype # fmt: off
raise NotImplementedError()
@property
diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py
index 9dccddc012..5d54512bd0 100644
--- a/src/gt4py/next/otf/binding/nanobind.py
+++ b/src/gt4py/next/otf/binding/nanobind.py
@@ -17,7 +17,7 @@
from __future__ import annotations
-from typing import Any, Sequence, Union
+from typing import Any, Sequence, TypeVar, Union
import gt4py.eve as eve
from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator
@@ -26,6 +26,9 @@
from gt4py.next.type_system import type_info as ti, type_specifications as ts
+SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL, covariant=True)
+
+
class Expr(eve.Node):
pass
@@ -191,8 +194,8 @@ def make_argument(name: str, type_: ts.TypeSpec) -> str | BufferSID | CompositeS
def create_bindings(
- program_source: stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings],
-) -> stages.BindingSource[languages.Cpp, languages.Python]:
+ program_source: stages.ProgramSource[SrcL, languages.LanguageWithHeaderFilesSettings],
+) -> stages.BindingSource[SrcL, languages.Python]:
"""
Generate Python bindings through which a C++ function can be called.
@@ -201,7 +204,7 @@ def create_bindings(
program_source
The program source for which the bindings are created
"""
- if program_source.language is not languages.Cpp:
+ if program_source.language not in [languages.Cpp, languages.Cuda]:
raise ValueError(
f"Can only create bindings for C++ program sources, received {program_source.language}."
)
@@ -221,7 +224,6 @@ def create_bindings(
"gridtools/common/tuple_util.hpp",
"gridtools/fn/unstructured.hpp",
"gridtools/fn/cartesian.hpp",
- "gridtools/fn/backend/naive.hpp",
"gridtools/storage/adapter/nanobind_adapter.hpp",
],
wrapper=WrapperFunction(
@@ -266,8 +268,6 @@ def create_bindings(
@workflow.make_step
def bind_source(
- inp: stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings],
-) -> stages.CompilableSource[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python
-]:
+ inp: stages.ProgramSource[SrcL, languages.LanguageWithHeaderFilesSettings],
+) -> stages.CompilableSource[SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python]:
return stages.CompilableSource(program_source=inp, binding_source=create_bindings(inp))
diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py
index b281fde7b5..3d36f5d985 100644
--- a/src/gt4py/next/otf/compilation/build_systems/cmake.py
+++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py
@@ -38,7 +38,7 @@ def _generate_next_value_(name, start, count, last_values):
@dataclasses.dataclass
class CMakeFactory(
compiler.BuildSystemProjectGenerator[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python
+ languages.Cpp | languages.Cuda, languages.LanguageWithHeaderFilesSettings, languages.Python
]
):
"""Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings."""
@@ -50,7 +50,7 @@ class CMakeFactory(
def __call__(
self,
source: stages.CompilableSource[
- languages.Cpp,
+ languages.Cpp | languages.Cuda,
languages.LanguageWithHeaderFilesSettings,
languages.Python,
],
@@ -63,16 +63,21 @@ def __call__(
name = source.program_source.entry_point.name
header_name = f"{name}.{source.program_source.language_settings.header_extension}"
bindings_name = f"{name}_bindings.{source.program_source.language_settings.file_extension}"
+ cmake_languages = [cmake_lists.Language(name="CXX")]
+ if source.program_source.language is languages.Cuda:
+ cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")]
+ cmake_lists_src = cmake_lists.generate_cmakelists_source(
+ name,
+ source.library_deps,
+ [header_name, bindings_name],
+ languages=cmake_languages,
+ )
return CMakeProject(
root_path=cache.get_cache_folder(source, cache_strategy),
source_files={
header_name: source.program_source.source_code,
bindings_name: source.binding_source.source_code,
- "CMakeLists.txt": cmake_lists.generate_cmakelists_source(
- name,
- source.library_deps,
- [header_name, bindings_name],
- ),
+ "CMakeLists.txt": cmake_lists_src,
},
program_name=name,
generator_name=self.cmake_generator_name,
diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
index ef222341e3..5ea4ba0519 100644
--- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
+++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
@@ -30,22 +30,31 @@ class LinkDependency(eve.Node):
target: str
+class Language(eve.Node):
+ name: str
+
+
class CMakeListsFile(eve.Node):
project_name: str
find_deps: Sequence[FindDependency]
link_deps: Sequence[LinkDependency]
source_names: Sequence[str]
bin_output_suffix: str
+ languages: Sequence[Language]
class CMakeListsGenerator(eve.codegen.TemplatedGenerator):
CMakeListsFile = as_jinja(
"""
- project({{project_name}})
cmake_minimum_required(VERSION 3.20.0)
+ project({{project_name}})
+
# Languages
- enable_language(CXX)
+ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+ set(CMAKE_CUDA_ARCHITECTURES 60)
+ endif()
+ {{"\\n".join(languages)}}
# Paths
list(APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR})
@@ -77,18 +86,17 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator):
)
def visit_FindDependency(self, dep: FindDependency):
+ # TODO(ricoh): do not add more libraries here
+ # and do not use this design in a new build system.
+ # Instead, design this to be extensible (refer to ADR-0016).
match dep.name:
- case "pybind11":
- import pybind11
-
- return f"find_package(pybind11 CONFIG REQUIRED PATHS {pybind11.get_cmake_dir()} NO_DEFAULT_PATH)"
case "nanobind":
import nanobind
py = "find_package(Python COMPONENTS Interpreter Development REQUIRED)"
nb = f"find_package(nanobind CONFIG REQUIRED PATHS {nanobind.cmake_dir()} NO_DEFAULT_PATHS)"
return py + "\n" + nb
- case "gridtools":
+ case "gridtools_cpu" | "gridtools_gpu":
import gridtools_cpp
return f"find_package(GridTools REQUIRED PATHS {gridtools_cpp.get_cmake_dir()} NO_DEFAULT_PATH)"
@@ -96,13 +104,16 @@ def visit_FindDependency(self, dep: FindDependency):
raise ValueError("Library {name} is not supported".format(name=dep.name))
def visit_LinkDependency(self, dep: LinkDependency):
+ # TODO(ricoh): do not add more libraries here
+ # and do not use this design in a new build system.
+ # Instead, design this to be extensible (refer to ADR-0016).
match dep.name:
- case "pybind11":
- lib_name = "pybind11::module"
case "nanobind":
lib_name = "nanobind-static"
- case "gridtools":
+ case "gridtools_cpu":
lib_name = "GridTools::fn_naive"
+ case "gridtools_gpu":
+ lib_name = "GridTools::fn_gpu"
case _:
raise ValueError("Library {name} is not supported".format(name=dep.name))
@@ -118,11 +129,14 @@ def visit_LinkDependency(self, dep: LinkDependency):
lnk = f"target_link_libraries({dep.target} PUBLIC {lib_name})"
return cfg + "\n" + lnk
+ Language = as_jinja("enable_language({{name}})")
+
def generate_cmakelists_source(
project_name: str,
dependencies: tuple[interface.LibraryDependency, ...],
source_names: Sequence[str],
+ languages: Sequence[Language] = (Language(name="CXX"),),
) -> str:
"""
Generate CMakeLists file contents.
@@ -135,5 +149,6 @@ def generate_cmakelists_source(
link_deps=[LinkDependency(name=d.name, target=project_name) for d in dependencies],
source_names=source_names,
bin_output_suffix=common.python_module_suffix(),
+ languages=languages,
)
return CMakeListsGenerator.apply(cmakelists_file)
diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py
index 34f2f85081..84a69859c0 100644
--- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py
+++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py
@@ -20,7 +20,7 @@
import re
import shutil
import subprocess
-from typing import Optional
+from typing import Optional, TypeVar
from gt4py.next.otf import languages, stages
from gt4py.next.otf.binding import interface
@@ -28,10 +28,13 @@
from gt4py.next.otf.compilation.build_systems import cmake, cmake_lists
+SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL)
+
+
@dataclasses.dataclass
class CompiledbFactory(
compiler.BuildSystemProjectGenerator[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python
+ SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python
]
):
"""
@@ -48,7 +51,7 @@ class CompiledbFactory(
def __call__(
self,
source: stages.CompilableSource[
- languages.Cpp,
+ SrcL,
languages.LanguageWithHeaderFilesSettings,
languages.Python,
],
@@ -66,6 +69,8 @@ def __call__(
deps=source.library_deps,
build_type=self.cmake_build_type,
cmake_flags=self.cmake_extra_flags or [],
+ language=source.program_source.language,
+ language_settings=source.program_source.language_settings,
)
if self.renew_compiledb or not (
@@ -92,9 +97,7 @@ def __call__(
@dataclasses.dataclass()
class CompiledbProject(
- stages.BuildSystemProject[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python
- ]
+ stages.BuildSystemProject[SrcL, languages.LanguageWithHeaderFilesSettings, languages.Python]
):
"""
Compiledb build system for gt4py programs.
@@ -113,18 +116,21 @@ class CompiledbProject(
compile_commands_cache: pathlib.Path
bindings_file_name: str
- def build(self):
+ def build(self) -> None:
self._write_files()
- if build_data.read_data(self.root_path).status < build_data.BuildStatus.CONFIGURED:
+ current_data = build_data.read_data(self.root_path)
+ if current_data is None or current_data.status < build_data.BuildStatus.CONFIGURED:
self._run_config()
+ current_data = build_data.read_data(self.root_path) # update after config
if (
- build_data.BuildStatus.CONFIGURED
- <= build_data.read_data(self.root_path).status
+ current_data is not None
+ and build_data.BuildStatus.CONFIGURED
+ <= current_data.status
< build_data.BuildStatus.COMPILED
):
self._run_build()
- def _write_files(self):
+ def _write_files(self) -> None:
def ignore_not_libraries(folder: str, children: list[str]) -> list[str]:
pattern = r"((lib.*\.a)|(.*\.lib))"
libraries = [child for child in children if re.match(pattern, child)]
@@ -151,7 +157,7 @@ def ignore_not_libraries(folder: str, children: list[str]) -> list[str]:
path=self.root_path,
)
- def _run_config(self):
+ def _run_config(self) -> None:
compile_db = json.loads(self.compile_commands_cache.read_text())
(self.root_path / "build").mkdir(exist_ok=True)
@@ -176,7 +182,7 @@ def _run_config(self):
self.root_path,
)
- def _run_build(self):
+ def _run_build(self) -> None:
logfile = self.root_path / "log_build.txt"
compile_db = json.loads((self.root_path / "compile_commands.json").read_text())
assert compile_db
@@ -212,19 +218,16 @@ def _cc_prototype_program_source(
deps: tuple[interface.LibraryDependency, ...],
build_type: cmake.BuildType,
cmake_flags: list[str],
+ language: type[SrcL],
+ language_settings: languages.LanguageWithHeaderFilesSettings,
) -> stages.ProgramSource:
name = _cc_prototype_program_name(deps, build_type.value, cmake_flags)
return stages.ProgramSource(
entry_point=interface.Function(name=name, parameters=()),
source_code="",
library_deps=deps,
- language=languages.Cpp,
- language_settings=languages.LanguageWithHeaderFilesSettings(
- formatter_key="",
- formatter_style=None,
- file_extension="",
- header_extension="",
- ),
+ language=language,
+ language_settings=language_settings,
)
@@ -251,16 +254,26 @@ def _cc_create_compiledb(
stages.CompilableSource(prototype_program_source, None), cache_strategy
)
+ header_ext = prototype_program_source.language_settings.header_extension
+ src_ext = prototype_program_source.language_settings.file_extension
+ prog_src_name = f"{name}.{header_ext}"
+ binding_src_name = f"{name}.{src_ext}"
+ cmake_languages = [cmake_lists.Language(name="CXX")]
+ if prototype_program_source.language is languages.Cuda:
+ cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")]
+
prototype_project = cmake.CMakeProject(
generator_name="Ninja",
build_type=build_type,
extra_cmake_flags=cmake_flags,
root_path=cache_path,
source_files={
- f"{name}.hpp": "",
- f"{name}.cpp": "",
+ **{name: "" for name in [binding_src_name, prog_src_name]},
"CMakeLists.txt": cmake_lists.generate_cmakelists_source(
- name, prototype_program_source.library_deps, [f"{name}.hpp", f"{name}.cpp"]
+ name,
+ prototype_program_source.library_deps,
+ [binding_src_name, prog_src_name],
+ cmake_languages,
),
},
program_name=name,
@@ -290,21 +303,21 @@ def _cc_create_compiledb(
entry["command"]
.replace(f"CMakeFiles/{name}.dir", ".")
.replace(str(cache_path), "$SRC_PATH")
- .replace(f"{name}.cpp", "$BINDINGS_FILE")
- .replace(f"{name}", "$NAME")
+ .replace(binding_src_name, "$BINDINGS_FILE")
+ .replace(name, "$NAME")
.replace("-I$SRC_PATH/build/_deps", f"-I{cache_path}/build/_deps")
)
entry["file"] = (
entry["file"]
.replace(f"CMakeFiles/{name}.dir", ".")
.replace(str(cache_path), "$SRC_PATH")
- .replace(f"{name}.cpp", "$BINDINGS_FILE")
+ .replace(binding_src_name, "$BINDINGS_FILE")
)
entry["output"] = (
entry["output"]
.replace(f"CMakeFiles/{name}.dir", ".")
- .replace(f"{name}.cpp", "$BINDINGS_FILE")
- .replace(f"{name}", "$NAME")
+ .replace(binding_src_name, "$BINDINGS_FILE")
+ .replace(name, "$NAME")
)
compile_db_path = cache_path / "compile_commands.json"
diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py
index 32c5469333..dacb444207 100644
--- a/src/gt4py/next/otf/compilation/compiler.py
+++ b/src/gt4py/next/otf/compilation/compiler.py
@@ -23,7 +23,7 @@
from gt4py.next.otf.step_types import LS, SrcL, TgtL
-SourceLanguageType = TypeVar("SourceLanguageType", bound=languages.LanguageTag)
+SourceLanguageType = TypeVar("SourceLanguageType", bound=languages.NanobindSrcL)
LanguageSettingsType = TypeVar("LanguageSettingsType", bound=languages.LanguageSettings)
T = TypeVar("T")
diff --git a/src/gt4py/next/otf/languages.py b/src/gt4py/next/otf/languages.py
index e2738615ac..b0d01d91ab 100644
--- a/src/gt4py/next/otf/languages.py
+++ b/src/gt4py/next/otf/languages.py
@@ -57,6 +57,14 @@ class Python(LanguageTag):
...
-class Cpp(LanguageTag):
+class NanobindSrcL(LanguageTag):
+ ...
+
+
+class Cpp(NanobindSrcL):
settings_class = LanguageWithHeaderFilesSettings
...
+
+
+class Cuda(NanobindSrcL):
+ settings_class = LanguageWithHeaderFilesSettings
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py
index d144533798..4c6cdc273d 100644
--- a/src/gt4py/next/otf/recipes.py
+++ b/src/gt4py/next/otf/recipes.py
@@ -14,27 +14,21 @@
from __future__ import annotations
import dataclasses
-from typing import Generic, TypeVar
-from gt4py.next.otf import languages, stages, step_types, workflow
-
-
-SrcL = TypeVar("SrcL", bound=languages.LanguageTag)
-TgtL = TypeVar("TgtL", bound=languages.LanguageTag)
-LS = TypeVar("LS", bound=languages.LanguageSettings)
+from gt4py.next.otf import stages, step_types, workflow
@dataclasses.dataclass(frozen=True)
-class OTFCompileWorkflow(workflow.NamedStepSequence, Generic[SrcL, LS, TgtL]):
+class OTFCompileWorkflow(workflow.NamedStepSequence):
"""The typical compiled backend steps composed into a workflow."""
- translation: step_types.TranslationStep[SrcL, LS]
+ translation: step_types.TranslationStep
bindings: workflow.Workflow[
- stages.ProgramSource[SrcL, LS],
- stages.CompilableSource[SrcL, LS, TgtL],
+ stages.ProgramSource,
+ stages.CompilableSource,
]
compilation: workflow.Workflow[
- stages.CompilableSource[SrcL, LS, TgtL],
+ stages.CompilableSource,
stages.CompiledProgram,
]
decoration: workflow.Workflow[stages.CompiledProgram, stages.CompiledProgram]
diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/step_types.py
index 54fe2e5389..5eeb5c495b 100644
--- a/src/gt4py/next/otf/step_types.py
+++ b/src/gt4py/next/otf/step_types.py
@@ -50,7 +50,10 @@ def __call__(
...
-class CompilationStep(Protocol[SrcL, LS, TgtL]):
+class CompilationStep(
+ workflow.Workflow[stages.CompilableSource[SrcL, LS, TgtL], stages.CompiledProgram],
+ Protocol[SrcL, LS, TgtL],
+):
"""Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram)."""
def __call__(self, source: stages.CompilableSource[SrcL, LS, TgtL]) -> stages.CompiledProgram:
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py
index 8cd910e40f..645d1f742f 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py
@@ -220,6 +220,7 @@ def visit_FencilDefinition(
return self.generic_visit(
node,
grid_type_str=self._grid_type_str[node.grid_type],
+ block_sizes=self._block_sizes(node.offset_definitions),
**kwargs,
)
@@ -261,6 +262,8 @@ def visit_TemporaryAllocation(self, node, **kwargs):
${'\\n'.join(offset_definitions)}
${'\\n'.join(function_definitions)}
+ ${block_sizes}
+
inline auto ${id} = [](auto... connectivities__){
return [connectivities__...](auto backend, ${','.join('auto&& ' + p for p in params)}){
auto tmp_alloc__ = gtfn::backend::tmp_allocator(backend);
@@ -273,6 +276,18 @@ def visit_TemporaryAllocation(self, node, **kwargs):
"""
)
+ def _block_sizes(self, offset_definitions: list[gtfn_ir.TagDefinition]) -> str:
+ block_dims = []
+ block_sizes = [32, 8] + [1] * (len(offset_definitions) - 2)
+ for i, tag in enumerate(offset_definitions):
+ if tag.alias is None:
+ block_dims.append(
+ f"gridtools::meta::list<{tag.name.id}_t, "
+ f"gridtools::integral_constant>"
+ )
+ sizes_str = ",\n".join(block_dims)
+ return f"using block_sizes_t = gridtools::meta::list<{sizes_str}>;"
+
@classmethod
def apply(cls, root: Any, **kwargs: Any) -> str:
generated_code = super().apply(root, **kwargs)
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
index 5e24e855b5..7bf310f4e1 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -16,10 +16,11 @@
import dataclasses
import warnings
-from typing import Any, Final, Optional, TypeVar
+from typing import Any, Final, Optional
import numpy as np
+from gt4py._core import definitions as core_defs
from gt4py.eve import trees, utils
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
@@ -32,8 +33,6 @@
from gt4py.next.type_system import type_specifications as ts, type_translation
-T = TypeVar("T")
-
GENERATED_CONNECTIVITY_PARAM_PREFIX = "gt_conn_"
@@ -45,14 +44,30 @@ def get_param_description(name: str, obj: Any) -> interface.Parameter:
class GTFNTranslationStep(
workflow.ChainableWorkflowMixin[
stages.ProgramCall,
- stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings],
+ stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings],
],
- step_types.TranslationStep[languages.Cpp, languages.LanguageWithHeaderFilesSettings],
+ step_types.TranslationStep[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings],
):
- language_settings: languages.LanguageWithHeaderFilesSettings = cpp_interface.CPP_DEFAULT
- enable_itir_transforms: bool = True # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135
+ language_settings: Optional[languages.LanguageWithHeaderFilesSettings] = None
+ # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135
+ enable_itir_transforms: bool = True
use_imperative_backend: bool = False
lift_mode: Optional[LiftMode] = None
+ device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
+
+ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
+ match self.device_type:
+ case core_defs.DeviceType.CUDA:
+ return languages.LanguageWithHeaderFilesSettings(
+ formatter_key=cpp_interface.CPP_DEFAULT.formatter_key,
+ formatter_style=cpp_interface.CPP_DEFAULT.formatter_style,
+ file_extension="cu",
+ header_extension="cuh",
+ )
+ case core_defs.DeviceType.CPU:
+ return cpp_interface.CPP_DEFAULT
+ case _:
+ raise self._not_implemented_for_device_type()
def _process_regular_arguments(
self,
@@ -98,7 +113,7 @@ def _process_regular_arguments(
isinstance(
dim, fbuiltins.FieldOffset
) # TODO(havogt): remove support for FieldOffset as Dimension
- or dim.kind == common.DimensionKind.LOCAL
+ or dim.kind is common.DimensionKind.LOCAL
):
# translate sparse dimensions to tuple dtype
dim_name = dim.value
@@ -159,7 +174,7 @@ def _process_connectivity_args(
def __call__(
self,
inp: stages.ProgramCall,
- ) -> stages.ProgramSource[languages.Cpp, languages.LanguageWithHeaderFilesSettings]:
+ ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]:
"""Generate GTFN C++ code from the ITIR definition."""
program: itir.FencilDefinition = inp.program
@@ -189,7 +204,8 @@ def __call__(
# combine into a format that is aligned with what the backend expects
parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters
- args_expr: list[str] = ["gridtools::fn::backend::naive{}", *regular_args_expr]
+ backend_arg = self._backend_type()
+ args_expr: list[str] = [backend_arg, *regular_args_expr]
function = interface.Function(program.id, tuple(parameters))
decl_body = (
@@ -205,9 +221,9 @@ def __call__(
**inp.kwargs,
)
source_code = interface.format_source(
- self.language_settings,
+ self._language_settings(),
f"""
- #include
+ #include <{self._backend_header()}>
#include
#include
{stencil_src}
@@ -215,16 +231,69 @@ def __call__(
""".strip(),
)
- module = stages.ProgramSource(
+ module: stages.ProgramSource[
+ languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings
+ ] = stages.ProgramSource(
entry_point=function,
- library_deps=(interface.LibraryDependency("gridtools", "master"),),
+ library_deps=(interface.LibraryDependency(self._library_name(), "master"),),
source_code=source_code,
- language=languages.Cpp,
- language_settings=self.language_settings,
+ language=self._language(),
+ language_settings=self._language_settings(),
)
return module
+ def _backend_header(self) -> str:
+ match self.device_type:
+ case core_defs.DeviceType.CUDA:
+ return "gridtools/fn/backend/gpu.hpp"
+ case core_defs.DeviceType.CPU:
+ return "gridtools/fn/backend/naive.hpp"
+ case _:
+ raise self._not_implemented_for_device_type()
+
+ def _backend_type(self) -> str:
+ match self.device_type:
+ case core_defs.DeviceType.CUDA:
+ return "gridtools::fn::backend::gpu{}"
+ case core_defs.DeviceType.CPU:
+ return "gridtools::fn::backend::naive{}"
+ case _:
+ raise self._not_implemented_for_device_type()
+
+ def _language(self) -> type[languages.NanobindSrcL]:
+ match self.device_type:
+ case core_defs.DeviceType.CUDA:
+ return languages.Cuda
+ case core_defs.DeviceType.CPU:
+ return languages.Cpp
+ case _:
+ raise self._not_implemented_for_device_type()
+
+ def _language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
+ return (
+ self.language_settings
+ if self.language_settings is not None
+ else self._default_language_settings()
+ )
+
+ def _library_name(self) -> str:
+ match self.device_type:
+ case core_defs.DeviceType.CUDA:
+ return "gridtools_gpu"
+ case core_defs.DeviceType.CPU:
+ return "gridtools_cpu"
+ case _:
+ raise self._not_implemented_for_device_type()
+
+ def _not_implemented_for_device_type(self) -> NotImplementedError:
+ return NotImplementedError(
+ f"{self.__class__.__name__} is not implemented for "
+ f"device type {self.device_type.name}"
+ )
+
+
+translate_program_cpu: Final[step_types.TranslationStep] = GTFNTranslationStep()
-translate_program: Final[
- step_types.TranslationStep[languages.Cpp, languages.LanguageWithHeaderFilesSettings]
-] = GTFNTranslationStep()
+translate_program_gpu: Final[step_types.TranslationStep] = GTFNTranslationStep(
+ device_type=core_defs.DeviceType.CUDA
+)
diff --git a/src/gt4py/next/program_processors/otf_compile_executor.py b/src/gt4py/next/program_processors/otf_compile_executor.py
index a22028414b..cd08c16933 100644
--- a/src/gt4py/next/program_processors/otf_compile_executor.py
+++ b/src/gt4py/next/program_processors/otf_compile_executor.py
@@ -20,15 +20,15 @@
from gt4py.next.program_processors import processor_interface as ppi
-SrcL = TypeVar("SrcL", bound=languages.LanguageTag)
+SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL)
TgtL = TypeVar("TgtL", bound=languages.LanguageTag)
LS = TypeVar("LS", bound=languages.LanguageSettings)
HashT = TypeVar("HashT")
@dataclasses.dataclass(frozen=True)
-class OTFCompileExecutor(ppi.ProgramExecutor, Generic[SrcL, LS, TgtL, HashT]):
- otf_workflow: recipes.OTFCompileWorkflow[SrcL, LS, TgtL]
+class OTFCompileExecutor(ppi.ProgramExecutor):
+ otf_workflow: recipes.OTFCompileWorkflow
name: Optional[str] = None
def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None:
@@ -42,7 +42,7 @@ def __name__(self) -> str:
@dataclasses.dataclass(frozen=True)
-class CachedOTFCompileExecutor(ppi.ProgramExecutor, Generic[SrcL, LS, TgtL, HashT]):
+class CachedOTFCompileExecutor(ppi.ProgramExecutor, Generic[HashT]):
otf_workflow: workflow.CachedStep[stages.ProgramCall, stages.CompiledProgram, HashT]
name: Optional[str] = None
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index f78d90095c..1c1bed9c5e 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -11,29 +11,44 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
-
-from typing import Any, Mapping, Sequence
+import hashlib
+from typing import Any, Mapping, Optional, Sequence
import dace
import numpy as np
+from dace.codegen.compiled_sdfg import CompiledSDFG
+from dace.transformation.auto import auto_optimize as autoopt
import gt4py.next.iterator.ir as itir
-from gt4py.next import common
-from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
+from gt4py.next.common import Dimension, Domain, UnitRange, is_field
+from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
-from gt4py.next.type_system import type_translation
+from gt4py.next.type_system import type_specifications as ts, type_translation
from .itir_to_sdfg import ItirToSDFG
-from .utility import connectivity_identifier, filter_neighbor_tables
+from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims
+
+
+def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]:
+ sorted_dims = get_sorted_dims(domain.dims)
+ return [domain.ranges[dim_index] for dim_index, _ in sorted_dims]
+
+
+""" Default build configuration in DaCe backend """
+_build_type = "Release"
+# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins
+_cpu_args = (
+ "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label"
+)
def convert_arg(arg: Any):
- if common.is_field(arg):
- sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value)
+ if is_field(arg):
+ sorted_dims = get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
- dim_indices = [dim[0] for dim in sorted_dims]
+ dim_indices = [dim_index for dim_index, _ in sorted_dims]
assert isinstance(arg.ndarray, np.ndarray)
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
return arg
@@ -69,6 +84,17 @@ def get_shape_args(
}
+def get_offset_args(
+ arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
+) -> Mapping[str, int]:
+ return {
+ str(sym): -drange.start
+ for param, arg in zip(params, args)
+ if is_field(arg)
+ for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
+ }
+
+
def get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
@@ -85,17 +111,89 @@ def get_stride_args(
return stride_args
+_build_cache_cpu: dict[str, CompiledSDFG] = {}
+_build_cache_gpu: dict[str, CompiledSDFG] = {}
+
+
+def get_cache_id(
+ program: itir.FencilDefinition,
+ arg_types: Sequence[ts.TypeSpec],
+ column_axis: Optional[Dimension],
+ offset_provider: Mapping[str, Any],
+) -> str:
+ max_neighbors = [
+ (k, v.max_neighbors)
+ for k, v in offset_provider.items()
+ if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider))
+ ]
+ cache_id_args = [
+ str(arg)
+ for arg in (
+ program,
+ *arg_types,
+ column_axis,
+ *max_neighbors,
+ )
+ ]
+ m = hashlib.sha256()
+ for s in cache_id_args:
+ m.update(s.encode())
+ return m.hexdigest()
+
+
@program_executor
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
+ # build parameters
+ auto_optimize = kwargs.get("auto_optimize", False)
+ build_type = kwargs.get("build_type", "RelWithDebInfo")
+ run_on_gpu = kwargs.get("run_on_gpu", False)
+ build_cache = kwargs.get("build_cache", None)
+ # ITIR parameters
column_axis = kwargs.get("column_axis", None)
offset_provider = kwargs["offset_provider"]
- neighbor_tables = filter_neighbor_tables(offset_provider)
- program = preprocess_program(program, offset_provider)
arg_types = [type_translation.from_value(arg) for arg in args]
- sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
- sdfg: dace.SDFG = sdfg_genenerator.visit(program)
- sdfg.simplify()
+ neighbor_tables = filter_neighbor_tables(offset_provider)
+
+ cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
+ if build_cache is not None and cache_id in build_cache:
+ # retrieve SDFG program from build cache
+ sdfg_program = build_cache[cache_id]
+ sdfg = sdfg_program.sdfg
+ else:
+ # visit ITIR and generate SDFG
+ program = preprocess_program(program, offset_provider)
+ sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
+ sdfg = sdfg_genenerator.visit(program)
+ sdfg.simplify()
+
+ # set array storage for GPU execution
+ if run_on_gpu:
+ device = dace.DeviceType.GPU
+ sdfg._name = f"{sdfg.name}_gpu"
+ for _, _, array in sdfg.arrays_recursive():
+ if not array.transient:
+ array.storage = dace.dtypes.StorageType.GPU_Global
+ else:
+ device = dace.DeviceType.CPU
+
+ # run DaCe auto-optimization heuristics
+ if auto_optimize:
+ # TODO Investigate how symbol definitions improve autoopt transformations,
+ # in which case the cache table should take the symbols map into account.
+ symbols: dict[str, int] = {}
+ sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols)
+
+ # compile SDFG and retrieve SDFG program
+ sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"
+ with dace.config.temporary_config():
+ dace.config.Config.set("compiler", "build_type", value=build_type)
+ dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
+ sdfg_program = sdfg.compile(validate=False)
+
+ # store SDFG program in build cache
+ if build_cache is not None:
+ build_cache[cache_id] = sdfg_program
dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
@@ -103,9 +201,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
- dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args)
-
- sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"
+ dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
+ dace_offsets = get_offset_args(sdfg.arrays, program.params, args)
all_args = {
**dace_args,
@@ -113,16 +210,40 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
- **dace_conn_stirdes,
+ **dace_conn_strides,
+ **dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
+
with dace.config.temporary_config():
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
- dace.config.Config.set("compiler", "build_type", value="Debug")
- dace.config.Config.set("compiler", "cpu", "args", value="-O0")
dace.config.Config.set("frontend", "check_args", value=True)
- sdfg(**expected_args)
+ sdfg_program(**expected_args)
+
+
+@program_executor
+def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
+ run_dace_iterator(
+ program,
+ *args,
+ **kwargs,
+ build_cache=_build_cache_cpu,
+ build_type=_build_type,
+ run_on_gpu=False,
+ )
+
+
+@program_executor
+def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
+ run_dace_iterator(
+ program,
+ *args,
+ **kwargs,
+ build_cache=_build_cache_gpu,
+ build_type=_build_type,
+ run_on_gpu=True,
+ )
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index 56031d8555..580486aa4a 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -38,6 +38,8 @@
create_memlet_at,
create_memlet_full,
filter_neighbor_tables,
+ flatten_list,
+ get_sorted_dims,
map_nested_sdfg_symbols,
unique_var_name,
)
@@ -79,9 +81,10 @@ def get_scan_dim(
- scan_dim_dtype: data type along the scan dimension
"""
output_type = cast(ts.FieldType, storage_types[output.id])
+ sorted_dims = [dim for _, dim in get_sorted_dims(output_type.dims)]
return (
column_axis.value,
- output_type.dims.index(column_axis),
+ sorted_dims.index(column_axis),
output_type.dtype,
)
@@ -105,18 +108,30 @@ def __init__(
self.offset_provider = offset_provider
self.storage_types = {}
- def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec):
+ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
if isinstance(type_, ts.FieldType):
shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
+ offset = (
+ [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
+ if has_offset
+ else None
+ )
dtype = as_dace_type(type_.dtype)
- sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)
+ sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype)
elif isinstance(type_, ts.ScalarType):
sdfg.add_symbol(name, as_dace_type(type_))
else:
raise NotImplementedError()
self.storage_types[name] = type_
+ def get_output_nodes(
+ self, closure: itir.StencilClosure, context: Context
+ ) -> dict[str, dace.nodes.AccessNode]:
+ translator = PythonTaskletCodegen(self.offset_provider, context, self.node_types)
+ output_nodes = flatten_list(translator.visit(closure.output))
+ return {node.value.data: node.value for node in output_nodes}
+
def visit_FencilDefinition(self, node: itir.FencilDefinition):
program_sdfg = dace.SDFG(name=node.id)
last_state = program_sdfg.add_state("program_entry")
@@ -134,54 +149,33 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
- self.add_storage(program_sdfg, connectivity_identifier(offset), type_)
+ self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)
# Create a nested SDFG for all stencil closures.
for closure in node.closures:
- assert isinstance(closure.output, itir.SymRef)
-
- # filter out arguments with scalar type, because they are passed as symbols
- input_names = [
- str(inp.id)
- for inp in closure.inputs
- if isinstance(self.storage_types[inp.id], ts.FieldType)
- ]
- connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
- output_names = [str(closure.output.id)]
-
# Translate the closure and its stencil's body to an SDFG.
- closure_sdfg = self.visit(closure, array_table=program_sdfg.arrays)
+ closure_sdfg, input_names, output_names = self.visit(
+ closure, array_table=program_sdfg.arrays
+ )
# Create a new state for the closure.
last_state = program_sdfg.add_state_after(last_state)
# Create memlets to transfer the program parameters
- input_memlets = [
- create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names
- ]
- connectivity_memlets = [
- create_memlet_full(name, program_sdfg.arrays[name]) for name in connectivity_names
- ]
- output_memlets = [
- create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names
- ]
-
- input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)}
- connectivity_mapping = {
- param: arg for param, arg in zip(connectivity_names, connectivity_memlets)
+ input_mapping = {
+ name: create_memlet_full(name, program_sdfg.arrays[name]) for name in input_names
}
output_mapping = {
- param: arg_memlet for param, arg_memlet in zip(output_names, output_memlets)
+ name: create_memlet_full(name, program_sdfg.arrays[name]) for name in output_names
}
- array_mapping = {**input_mapping, **connectivity_mapping}
- symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, array_mapping)
+ symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping)
# Insert the closure's SDFG as a nested SDFG of the program.
nsdfg_node = last_state.add_nested_sdfg(
sdfg=closure_sdfg,
parent=program_sdfg,
- inputs=set(input_names) | set(connectivity_names),
+ inputs=set(input_names),
outputs=set(output_names),
symbol_mapping=symbol_mapping,
)
@@ -191,49 +185,78 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
access_node = last_state.add_access(inner_name)
last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet)
- for inner_name, memlet in connectivity_mapping.items():
- access_node = last_state.add_access(inner_name)
- last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet)
-
for inner_name, memlet in output_mapping.items():
access_node = last_state.add_access(inner_name)
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)
+
program_sdfg.validate()
return program_sdfg
def visit_StencilClosure(
self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
- ) -> dace.SDFG:
+ ) -> tuple[dace.SDFG, list[str], list[str]]:
assert ItirToSDFG._check_no_lifts(node)
assert ItirToSDFG._check_shift_offsets_are_literals(node)
- assert isinstance(node.output, itir.SymRef)
-
- neighbor_tables = filter_neighbor_tables(self.offset_provider)
- input_names = [str(inp.id) for inp in node.inputs]
- conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
- output_name = str(node.output.id)
# Create the closure's nested SDFG and single state.
closure_sdfg = dace.SDFG(name="closure")
closure_state = closure_sdfg.add_state("closure_entry")
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init")
- # Add DaCe arrays for inputs, output and connectivities to closure SDFG.
- for name in [*input_names, *conn_names, output_name]:
- assert name not in closure_sdfg.arrays or (name in input_names and name == output_name)
+ program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {}
+ closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms)
+ neighbor_tables = filter_neighbor_tables(self.offset_provider)
+
+ input_names = [str(inp.id) for inp in node.inputs]
+ conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
+
+ output_nodes = self.get_output_nodes(node, closure_ctx)
+ output_names = [k for k, _ in output_nodes.items()]
+
+ # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG.
+ input_transients_mapping = {}
+ for name in [*input_names, *conn_names, *output_names]:
if name in closure_sdfg.arrays:
- # in/out parameter, container already added for in parameter
- continue
- if isinstance(self.storage_types[name], ts.FieldType):
+ assert name in input_names and name in output_names
+ # In case of closures with in/out fields, there is risk of race condition
+ # between read/write access nodes in the (asynchronous) map tasklet.
+ transient_name = unique_var_name()
+ closure_sdfg.add_array(
+ transient_name,
+ shape=array_table[name].shape,
+ strides=array_table[name].strides,
+ dtype=array_table[name].dtype,
+ transient=True,
+ )
+ closure_init_state.add_nedge(
+ closure_init_state.add_access(name),
+ closure_init_state.add_access(transient_name),
+ create_memlet_full(name, closure_sdfg.arrays[name]),
+ )
+ input_transients_mapping[name] = transient_name
+ elif isinstance(self.storage_types[name], ts.FieldType):
closure_sdfg.add_array(
name,
shape=array_table[name].shape,
strides=array_table[name].strides,
dtype=array_table[name].dtype,
)
+ else:
+ assert isinstance(self.storage_types[name], ts.ScalarType)
- # Get output domain of the closure
- program_arg_syms: dict[str, ValueExpr | IteratorExpr | SymbolExpr] = {}
+ input_field_names = [
+ input_name
+ for input_name in input_names
+ if isinstance(self.storage_types[input_name], ts.FieldType)
+ ]
+
+ # Closure outputs should all be fields
+ assert all(
+ isinstance(self.storage_types[output_name], ts.FieldType)
+ for output_name in output_names
+ )
+
+ # Update symbol table and get output domain of the closure
for name, type_ in self.storage_types.items():
if isinstance(type_, ts.ScalarType):
if name in input_names:
@@ -246,78 +269,69 @@ def visit_StencilClosure(
)
access = closure_init_state.add_access(out_name)
value = ValueExpr(access, dtype)
- memlet = create_memlet_at(out_name, ("0",))
+ memlet = dace.Memlet.simple(out_name, "0")
closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet)
program_arg_syms[name] = value
else:
program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_))
- domain_ctx = Context(closure_sdfg, closure_state, program_arg_syms)
- closure_domain = self._visit_domain(node.domain, domain_ctx)
+ closure_domain = self._visit_domain(node.domain, closure_ctx)
# Map SDFG tasklet arguments to parameters
input_access_names = [
- input_name
- if isinstance(self.storage_types[input_name], ts.FieldType)
+ input_transients_mapping[input_name]
+ if input_name in input_transients_mapping
+ else input_name
+ if input_name in input_field_names
else cast(ValueExpr, program_arg_syms[input_name]).value.data
for input_name in input_names
]
input_memlets = [
create_memlet_full(name, closure_sdfg.arrays[name]) for name in input_access_names
]
- conn_memlet = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names]
+ conn_memlets = [create_memlet_full(name, closure_sdfg.arrays[name]) for name in conn_names]
- transient_to_arg_name_mapping = {}
# create and write to transient that is then copied back to actual output array to avoid aliasing of
# same memory in nested SDFG with different names
- nsdfg_output_name = unique_var_name()
- output_descriptor = closure_sdfg.arrays[output_name]
- transient_to_arg_name_mapping[nsdfg_output_name] = output_name
+ output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names}
# scan operator should always be the first function call in a closure
if is_scan(node.stencil):
- nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure(
- node, closure_sdfg.arrays, closure_domain, nsdfg_output_name
+ assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs"
+ transient_name, output_name = next(iter(output_connectors_mapping.items()))
+
+ nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure(
+ node, closure_sdfg.arrays, closure_domain, transient_name
)
- results = [nsdfg_output_name]
+ results = [transient_name]
_, (scan_lb, scan_ub) = closure_domain[scan_dim_index]
output_subset = f"{scan_lb.value}:{scan_ub.value}"
- closure_sdfg.add_array(
- nsdfg_output_name,
- dtype=output_descriptor.dtype,
- shape=(array_table[output_name].shape[scan_dim_index],),
- strides=(array_table[output_name].strides[scan_dim_index],),
- transient=True,
- )
-
- output_memlet = create_memlet_at(
- output_name,
- tuple(
- f"i_{dim}"
- if f"i_{dim}" in map_domain
- else f"0:{output_descriptor.shape[scan_dim_index]}"
- for dim, _ in closure_domain
- ),
- )
+ output_memlets = [
+ create_memlet_at(
+ output_name,
+ tuple(
+ f"i_{dim}"
+ if f"i_{dim}" in map_ranges
+ else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}"
+ for dim, _ in closure_domain
+ ),
+ )
+ ]
else:
- nsdfg, map_domain, results = self._visit_parallel_stencil_closure(
+ nsdfg, map_ranges, results = self._visit_parallel_stencil_closure(
node, closure_sdfg.arrays, closure_domain
)
- assert len(results) == 1
output_subset = "0"
- closure_sdfg.add_scalar(
- nsdfg_output_name,
- dtype=output_descriptor.dtype,
- transient=True,
- )
-
- output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys()))
+ output_memlets = [
+ create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys()))
+ for output_name in output_connectors_mapping.values()
+ ]
input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)}
- output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])}
- conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlet)}
+ output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, output_memlets)}
+ conn_mapping = {param: arg for param, arg in zip(conn_names, conn_memlets)}
array_mapping = {**input_mapping, **conn_mapping}
symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping)
@@ -325,15 +339,16 @@ def visit_StencilClosure(
nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg(
closure_state,
sdfg=nsdfg,
- map_ranges=map_domain or {"__dummy": "0"},
+ map_ranges=map_ranges or {"__dummy": "0"},
inputs=array_mapping,
outputs=output_mapping,
symbol_mapping=symbol_mapping,
+ output_nodes=output_nodes,
)
access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)}
for edge in closure_state.in_edges(map_exit):
memlet = edge.data
- if memlet.data not in transient_to_arg_name_mapping:
+ if memlet.data not in output_connectors_mapping:
continue
transient_access = closure_state.add_access(memlet.data)
closure_state.add_edge(
@@ -341,28 +356,16 @@ def visit_StencilClosure(
edge.src_conn,
transient_access,
None,
- dace.Memlet(data=memlet.data, subset=output_subset),
+ dace.Memlet.simple(memlet.data, output_subset),
)
- inner_memlet = dace.Memlet(
- data=memlet.data, subset=output_subset, other_subset=memlet.subset
+ inner_memlet = dace.Memlet.simple(
+ memlet.data, output_subset, other_subset_str=memlet.subset
)
closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet)
closure_state.remove_edge(edge)
- access_nodes[memlet.data].data = transient_to_arg_name_mapping[memlet.data]
-
- for _, (lb, ub) in closure_domain:
- for b in lb, ub:
- if isinstance(b, SymbolExpr):
- continue
- map_entry.add_in_connector(b.value.data)
- closure_state.add_edge(
- b.value,
- None,
- map_entry,
- b.value.data,
- create_memlet_at(b.value.data, ("0",)),
- )
- return closure_sdfg
+ access_nodes[memlet.data].data = output_connectors_mapping[memlet.data]
+
+ return closure_sdfg, input_field_names + conn_names, output_names
def _visit_scan_stencil_closure(
self,
@@ -390,12 +393,12 @@ def _visit_scan_stencil_closure(
connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
# find the scan dimension, same as output dimension, and exclude it from the map domain
- map_domain = {}
+ map_ranges = {}
for dim, (lb, ub) in closure_domain:
lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value
ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value
if not dim == scan_dim:
- map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}"
+ map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}"
else:
scan_lb_str = lb_str
scan_ub_str = ub_str
@@ -481,29 +484,28 @@ def _visit_scan_stencil_closure(
"__result",
carry_node1,
None,
- dace.Memlet(data=f"{scan_carry_name}", subset="0"),
+ dace.Memlet.simple(scan_carry_name, "0"),
)
carry_node2 = lambda_state.add_access(scan_carry_name)
lambda_state.add_memlet_path(
carry_node2,
scan_inner_node,
- memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"),
+ memlet=dace.Memlet.simple(scan_carry_name, "0"),
src_conn=None,
dst_conn=lambda_carry_name,
)
# connect access nodes to lambda inputs
for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names):
- data_subset = (
- ", ".join([f"i_{dim}" for dim, _ in closure_domain])
- if isinstance(self.storage_types[data_name], ts.FieldType)
- else "0"
- )
+ if isinstance(self.storage_types[data_name], ts.FieldType):
+ memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain))
+ else:
+ memlet = dace.Memlet.simple(data_name, "0")
lambda_state.add_memlet_path(
lambda_state.add_access(data_name),
scan_inner_node,
- memlet=dace.Memlet(data=f"{data_name}", subset=data_subset),
+ memlet=memlet,
src_conn=None,
dst_conn=inner_name,
)
@@ -527,12 +529,13 @@ def _visit_scan_stencil_closure(
data_name,
shape=(array_table[node.output.id].shape[scan_dim_index],),
strides=(array_table[node.output.id].strides[scan_dim_index],),
+ offset=(array_table[node.output.id].offset[scan_dim_index],),
dtype=array_table[node.output.id].dtype,
)
lambda_state.add_memlet_path(
scan_inner_node,
lambda_state.add_access(data_name),
- memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"),
+ memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"),
src_conn=lambda_connector.value.label,
dst_conn=None,
)
@@ -544,10 +547,10 @@ def _visit_scan_stencil_closure(
lambda_update_state.add_memlet_path(
result_node,
carry_node3,
- memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"),
+ memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"),
)
- return scan_sdfg, map_domain, scan_dim_index
+ return scan_sdfg, map_ranges, scan_dim_index
def _visit_parallel_stencil_closure(
self,
@@ -562,11 +565,11 @@ def _visit_parallel_stencil_closure(
conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
# find the scan dimension, same as output dimension, and exclude it from the map domain
- map_domain = {}
+ map_ranges = {}
for dim, (lb, ub) in closure_domain:
lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value
ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value
- map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}"
+ map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}"
# Create an SDFG for the tasklet that computes a single item of the output domain.
index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain}
@@ -583,7 +586,7 @@ def _visit_parallel_stencil_closure(
self.node_types,
)
- return context.body, map_domain, [r.value.data for r in results]
+ return context.body, map_ranges, [r.value.data for r in results]
def _visit_domain(
self, node: itir.FunCall, context: Context
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
index 2e7a598d9a..b28703feef 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
@@ -23,7 +23,7 @@
from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols
import gt4py.eve.codegen
-from gt4py.next import Dimension, type_inference as next_typing
+from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.ir import FunCall, Lambda
@@ -34,9 +34,9 @@
add_mapped_nested_sdfg,
as_dace_type,
connectivity_identifier,
- create_memlet_at,
create_memlet_full,
filter_neighbor_tables,
+ flatten_list,
map_nested_sdfg_symbols,
unique_name,
unique_var_name,
@@ -244,10 +244,8 @@ def builtin_neighbors(
)
# select full shape only in the neighbor-axis dimension
field_subset = [
- f"0:{sdfg.arrays[iterator.field.data].shape[idx]}"
- if dim == table.neighbor_axis.value
- else f"i_{dim}"
- for idx, dim in enumerate(iterator.dimensions)
+ f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}"
+ for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape)
]
state.add_memlet_path(
iterator.field,
@@ -426,32 +424,36 @@ def visit_Lambda(
context.body.add_array(name, shape=shape, strides=strides, dtype=dtype)
# Translate the function's body
- result: ValueExpr | SymbolExpr = self.visit(node.expr)[0]
- # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors
- if isinstance(result, ValueExpr):
- result_name = unique_var_name()
- self.context.body.add_scalar(result_name, result.dtype, transient=True)
- result_access = self.context.state.add_access(result_name)
- self.context.state.add_edge(
- result.value,
- None,
- result_access,
- None,
- # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution
- dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
- )
- result = ValueExpr(value=result_access, dtype=result.dtype)
- else:
- result = self.add_expr_tasklet([], result.value, result.dtype, "forward")[0]
- self.context.body.arrays[result.value.data].transient = False
- self.context = prev_context
+ results: list[ValueExpr] = []
+ # We are flattening the returned list of value expressions because the multiple outputs of a lamda
+ # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this.
+ for expr in flatten_list(self.visit(node.expr)):
+ if isinstance(expr, ValueExpr):
+ result_name = unique_var_name()
+ self.context.body.add_scalar(result_name, expr.dtype, transient=True)
+ result_access = self.context.state.add_access(result_name)
+ self.context.state.add_edge(
+ expr.value,
+ None,
+ result_access,
+ None,
+ # in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution
+ dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
+ )
+ result = ValueExpr(value=result_access, dtype=expr.dtype)
+ else:
+ # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors
+ result = self.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0]
+ self.context.body.arrays[result.value.data].transient = False
+ results.append(result)
+ self.context = prev_context
for node in context.state.nodes():
if isinstance(node, dace.nodes.AccessNode):
if context.state.out_degree(node) == 0 and context.state.in_degree(node) == 0:
context.state.remove_node(node)
- return context, inputs, [result]
+ return context, inputs, results
def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr:
if node.id not in self.context.symbol_map:
@@ -576,6 +578,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
return iterator
args: list[ValueExpr]
+ sorted_dims = sorted(iterator.dimensions)
if self.context.reduce_limit:
# we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing
result_name = unique_var_name()
@@ -595,9 +598,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
)
# if dim is not found in iterator indices, we take the neighbor index over the reduction domain
- array_index = [
+ flat_index = [
f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name
- for dim in sorted(iterator.dimensions)
+ for dim in sorted_dims
]
args = [ValueExpr(iterator.field, iterator.dtype)] + [
ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices
@@ -608,7 +611,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
name="deref",
inputs=set(internals),
outputs={"__result"},
- code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]",
+ code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]",
)
for arg, internal in zip(args, internals):
@@ -630,12 +633,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
return [ValueExpr(value=result_access, dtype=iterator.dtype)]
else:
- sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0])
- flat_index = [
- ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions
+ args = [ValueExpr(iterator.field, iterator.dtype)] + [
+ ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims
]
-
- args = [ValueExpr(iterator.field, int), *flat_index]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]}[{', '.join(internals[1:])}]"
return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref")
@@ -702,18 +702,31 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr:
element = tail[1].value
assert isinstance(element, int)
- table: NeighborTableOffsetProvider = self.offset_provider[offset]
- shifted_dim = table.origin_axis.value
- target_dim = table.neighbor_axis.value
+ if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider):
+ table = self.offset_provider[offset]
+ shifted_dim = table.origin_axis.value
+ target_dim = table.neighbor_axis.value
- conn = self.context.state.add_access(connectivity_identifier(offset))
+ conn = self.context.state.add_access(connectivity_identifier(offset))
+
+ args = [
+ ValueExpr(conn, table.table.dtype),
+ ValueExpr(iterator.indices[shifted_dim], dace.int64),
+ ]
+
+ internals = [f"{arg.value.data}_v" for arg in args]
+ expr = f"{internals[0]}[{internals[1]}, {element}]"
+ else:
+ offset_provider = self.offset_provider[offset]
+ assert isinstance(offset_provider, StridedNeighborOffsetProvider)
+
+ shifted_dim = offset_provider.origin_axis.value
+ target_dim = offset_provider.neighbor_axis.value
+ offset_value = iterator.indices[shifted_dim]
+ args = [ValueExpr(offset_value, dace.int64)]
+ internals = [f"{offset_value.data}_v"]
+ expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}"
- args = [
- ValueExpr(conn, table.table.dtype),
- ValueExpr(iterator.indices[shifted_dim], dace.int64),
- ]
- internals = [f"{arg.value.data}_v" for arg in args]
- expr = f"{internals[0]}[{internals[1]}, {element}]"
shifted_value = self.add_expr_tasklet(
list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr"
)[0].value
@@ -849,7 +862,7 @@ def _visit_reduce(self, node: itir.FunCall):
p.apply_pass(lambda_context.body, {})
input_memlets = [
- create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args)
+ dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args)
]
output_memlet = dace.Memlet.simple(result_name, "0")
@@ -928,7 +941,7 @@ def add_expr_tasklet(
)
self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet)
- memlet = create_memlet_at(result_access.data, ("0",))
+ memlet = dace.Memlet.simple(result_access.data, "0")
self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet)
return [ValueExpr(result_access, result_type)]
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
index 889a1ab150..1fdd022a49 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
@@ -11,11 +11,12 @@
# distribution for a copy of the license or check .
#
# SPDX-License-Identifier: GPL-3.0-or-later
-
-from typing import Any
+import itertools
+from typing import Any, Sequence
import dace
+from gt4py.next import Dimension
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.type_system import type_specifications as ts
@@ -49,7 +50,7 @@ def connectivity_identifier(name: str):
def create_memlet_full(source_identifier: str, source_array: dace.data.Array):
bounds = [(0, size) for size in source_array.shape]
subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds)
- return dace.Memlet(data=source_identifier, subset=subset)
+ return dace.Memlet.simple(source_identifier, subset)
def create_memlet_at(source_identifier: str, index: tuple[str, ...]):
@@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]):
return dace.Memlet(data=source_identifier, subset=subset)
+def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]:
+ return sorted(enumerate(dims), key=lambda v: v[1].value)
+
+
def map_nested_sdfg_symbols(
parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet]
) -> dict[str, str]:
@@ -161,3 +166,11 @@ def unique_name(prefix):
def unique_var_name():
return unique_name("__var")
+
+
+def flatten_list(node_list: list[Any]) -> list[Any]:
+ return list(
+ itertools.chain.from_iterable(
+ [flatten_list(e) if e.__class__ == list else [e] for e in node_list]
+ )
+ )
diff --git a/src/gt4py/next/program_processors/runners/gtfn_cpu.py b/src/gt4py/next/program_processors/runners/gtfn.py
similarity index 76%
rename from src/gt4py/next/program_processors/runners/gtfn_cpu.py
rename to src/gt4py/next/program_processors/runners/gtfn.py
index 31b8323474..35c10fe353 100644
--- a/src/gt4py/next/program_processors/runners/gtfn_cpu.py
+++ b/src/gt4py/next/program_processors/runners/gtfn.py
@@ -16,11 +16,12 @@
import numpy.typing as npt
+from gt4py._core import definitions as core_defs
from gt4py.eve.utils import content_hash
from gt4py.next import common
from gt4py.next.iterator.transforms import LiftMode
-from gt4py.next.otf import languages, recipes, stages, workflow
-from gt4py.next.otf.binding import cpp_interface, nanobind
+from gt4py.next.otf import languages, recipes, stages, step_types, workflow
+from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import cache, compiler
from gt4py.next.otf.compilation.build_systems import compiledb
from gt4py.next.program_processors import otf_compile_executor
@@ -91,11 +92,23 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
)
-GTFN_DEFAULT_TRANSLATION_STEP = gtfn_module.GTFNTranslationStep(
- cpp_interface.CPP_DEFAULT, enable_itir_transforms=True, use_imperative_backend=False
+GTFN_DEFAULT_TRANSLATION_STEP: step_types.TranslationStep[
+ languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings
+] = gtfn_module.GTFNTranslationStep(
+ enable_itir_transforms=True,
+ use_imperative_backend=False,
+ device_type=core_defs.DeviceType.CPU,
)
-GTFN_DEFAULT_COMPILE_STEP = compiler.Compiler(
+GTFN_GPU_TRANSLATION_STEP: step_types.TranslationStep[
+ languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings
+] = gtfn_module.GTFNTranslationStep(
+ enable_itir_transforms=True,
+ use_imperative_backend=False,
+ device_type=core_defs.DeviceType.CUDA,
+)
+
+GTFN_DEFAULT_COMPILE_STEP: step_types.CompilationStep = compiler.Compiler(
cache_strategy=cache.Strategy.SESSION, builder_factory=compiledb.CompiledbFactory()
)
@@ -108,30 +121,35 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
)
-run_gtfn = otf_compile_executor.OTFCompileExecutor[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any
-](name="run_gtfn", otf_workflow=GTFN_DEFAULT_WORKFLOW)
+GTFN_GPU_WORKFLOW = recipes.OTFCompileWorkflow(
+ translation=GTFN_GPU_TRANSLATION_STEP,
+ bindings=nanobind.bind_source,
+ compilation=GTFN_DEFAULT_COMPILE_STEP,
+ decoration=convert_args,
+)
+
+
+run_gtfn = otf_compile_executor.OTFCompileExecutor(
+ name="run_gtfn", otf_workflow=GTFN_DEFAULT_WORKFLOW
+)
-run_gtfn_imperative = otf_compile_executor.OTFCompileExecutor[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any
-](
+run_gtfn_imperative = otf_compile_executor.OTFCompileExecutor(
name="run_gtfn_imperative",
otf_workflow=run_gtfn.otf_workflow.replace(
translation=run_gtfn.otf_workflow.translation.replace(use_imperative_backend=True),
),
)
-run_gtfn_cached = otf_compile_executor.CachedOTFCompileExecutor[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any
-](
+run_gtfn_cached = otf_compile_executor.CachedOTFCompileExecutor(
name="run_gtfn_cached",
otf_workflow=workflow.CachedStep(step=run_gtfn.otf_workflow, hash_function=compilation_hash),
) # todo(ricoh): add API for converting an executor to a cached version of itself and vice versa
+run_gtfn_gpu = otf_compile_executor.OTFCompileExecutor(
+ name="run_gtfn_gpu", otf_workflow=GTFN_GPU_WORKFLOW
+)
-run_gtfn_with_temporaries = otf_compile_executor.OTFCompileExecutor[
- languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python, Any
-](
+run_gtfn_with_temporaries = otf_compile_executor.OTFCompileExecutor(
name="run_gtfn_with_temporaries",
otf_workflow=run_gtfn.otf_workflow.replace(
translation=run_gtfn.otf_workflow.translation.replace(lift_mode=LiftMode.FORCE_TEMPORARIES),
diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py
index 27ccb29095..98ac9352c3 100644
--- a/tests/next_tests/exclusion_matrices.py
+++ b/tests/next_tests/exclusion_matrices.py
@@ -61,7 +61,6 @@
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
(USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
- (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
]
#: Skip matrix, contains for each backend processor a list of tuples with following fields:
@@ -81,11 +80,18 @@
(USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
],
- GTFN_CPU: GTFN_SKIP_TEST_LIST,
- GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST,
+ GTFN_CPU: GTFN_SKIP_TEST_LIST
+ + [
+ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
+ ],
+ GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST
+ + [
+ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
+ ],
GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST
+ [
(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
+ (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
],
GTFN_FORMAT_SOURCECODE: [
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py
new file mode 100644
index 0000000000..6c43e2f12a
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/__init__.py
@@ -0,0 +1,13 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
index 383716484e..93296ae85f 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
@@ -21,8 +21,8 @@
import gt4py.next as gtx
from gt4py.next.ffront import decorator
-from gt4py.next.iterator import embedded, ir as itir
-from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip
+from gt4py.next.iterator import ir as itir
+from gt4py.next.program_processors.runners import gtfn, roundtrip
try:
@@ -49,9 +49,9 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non
@pytest.fixture(
params=[
roundtrip.executor,
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]
+ OPTIONAL_PROCESSORS,
ids=lambda p: next_tests.get_processor_id(p),
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
index 1402649127..deb1382dfb 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py
@@ -20,22 +20,11 @@
import pytest
from gt4py.next import errors
-from gt4py.next.common import Field
-from gt4py.next.errors.exceptions import TypeError_
from gt4py.next.ffront.decorator import field_operator, program, scan_operator
-from gt4py.next.ffront.fbuiltins import broadcast, int32, int64
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.ffront.fbuiltins import broadcast, int32
from next_tests.integration_tests import cases
-from next_tests.integration_tests.cases import (
- IDim,
- IField,
- IJKField,
- IJKFloatField,
- JDim,
- KDim,
- cartesian_case,
-)
+from next_tests.integration_tests.cases import IDim, IField, IJKFloatField, KDim, cartesian_case
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
fieldview_backend,
)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index 865950eeab..f974e07ad8 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -27,13 +27,12 @@
float64,
int32,
int64,
- maximum,
minimum,
neighbor_sum,
where,
)
from gt4py.next.ffront.experimental import as_offset
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.runners import gtfn
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
@@ -159,7 +158,6 @@ def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
cases.verify(cartesian_case, testee, a, b, out=out, ref=a.ndarray[1:] + b.ndarray[2:])
-@pytest.mark.uses_tuple_returns
def test_tuples(cartesian_case): # noqa: F811 # fixtures
@gtx.field_operator
def testee(a: cases.IJKFloatField, b: cases.IJKFloatField) -> cases.IJKFloatField:
@@ -400,7 +398,6 @@ def testee(a: cases.IKField, offset_field: cases.IKField) -> gtx.Field[[IDim, KD
assert np.allclose(out, ref)
-@pytest.mark.uses_tuple_returns
def test_nested_tuple_return(cartesian_case):
@gtx.field_operator
def pack_tuple(
@@ -476,7 +473,7 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField
)
-@pytest.mark.uses_tuple_returns
+@pytest.mark.uses_constant_fields
def test_tuple_with_local_field_in_reduction_shifted(unstructured_case):
@gtx.field_operator
def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField:
@@ -528,12 +525,12 @@ def simple_scan_operator(carry: float) -> float:
@pytest.mark.uses_lift_expressions
def test_solve_triag(cartesian_case):
if cartesian_case.backend in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail("Nested `scan`s requires creating temporaries.")
- if cartesian_case.backend == gtfn_cpu.run_gtfn_with_temporaries:
+ if cartesian_case.backend == gtfn.run_gtfn_with_temporaries:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
@gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0))
@@ -632,7 +629,7 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField:
def test_ternary_scan(cartesian_case):
- if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]:
+ if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
@gtx.scan_operator(axis=KDim, forward=True, init=0.0)
@@ -655,7 +652,7 @@ def simple_scan_operator(carry: float, a: float) -> float:
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.uses_tuple_returns
def test_scan_nested_tuple_output(forward, cartesian_case):
- if cartesian_case.backend in [gtfn_cpu.run_gtfn_with_temporaries]:
+ if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
init = (1, (2, 3))
@@ -692,7 +689,9 @@ def test_scan_nested_tuple_input(cartesian_case):
inp2 = gtx.np_as_located_field(KDim)(np.arange(0.0, k_size, 1))
out = gtx.np_as_located_field(KDim)(np.zeros((k_size,)))
- prev_levels_iterator = lambda i: range(i + 1)
+ def prev_levels_iterator(i):
+ return range(i + 1)
+
expected = np.asarray(
[
reduce(lambda prev, i: prev + inp1[i] + inp2[i], prev_levels_iterator(i), init)
@@ -760,9 +759,9 @@ def program_domain(a: cases.IField, out: cases.IField):
def test_domain_input_bounds(cartesian_case):
if cartesian_case.backend in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail("FloorDiv not fully supported in gtfn.")
@@ -840,7 +839,6 @@ def program_domain(
)
-@pytest.mark.uses_tuple_returns
def test_domain_tuple(cartesian_case):
@gtx.field_operator
def fieldop_domain_tuple(
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py
new file mode 100644
index 0000000000..290cece3fa
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gpu_backend.py
@@ -0,0 +1,43 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+
+import gt4py.next as gtx
+from gt4py.next.iterator import embedded
+from gt4py.next.program_processors.runners import gtfn
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import cartesian_case # noqa: F401
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401
+ fieldview_backend,
+)
+
+
+@pytest.mark.requires_gpu
+@pytest.mark.parametrize("fieldview_backend", [gtfn.run_gtfn_gpu])
+def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures
+ import cupy as cp # TODO(ricoh): replace with storages solution when available
+
+ @gtx.field_operator(backend=fieldview_backend)
+ def testee(a: cases.IJKField) -> cases.IJKField:
+ return a
+
+ inp_arr = cp.full(shape=(3, 4, 5), fill_value=3, dtype=cp.int32)
+ outp_arr = cp.zeros_like(inp_arr)
+ inp = embedded.np_as_located_field(cases.IDim, cases.JDim, cases.KDim)(inp_arr)
+ outp = embedded.np_as_located_field(cases.IDim, cases.JDim, cases.KDim)(outp_arr)
+
+ testee(inp, out=outp, offset_provider={})
+ assert cp.allclose(inp_arr, outp_arr)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
index 0ae874f3a6..56d5e35b3a 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
@@ -17,8 +17,8 @@
import pytest
import gt4py.next as gtx
-from gt4py.next import broadcast, float64, int32, int64, max_over, min_over, neighbor_sum, where
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next import broadcast, float64, int32, max_over, min_over, neighbor_sum, where
+from gt4py.next.program_processors.runners import gtfn
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import (
@@ -30,7 +30,6 @@
Joff,
KDim,
V2EDim,
- Vertex,
cartesian_case,
unstructured_case,
)
@@ -47,9 +46,9 @@
)
def test_maxover_execution_(unstructured_case, strategy):
if unstructured_case.backend in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail("`maxover` broken in gtfn, see #1289.")
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py
index 85826c1ac0..034ce56fee 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py
@@ -37,7 +37,7 @@
tanh,
trunc,
)
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.runners import gtfn
from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import IDim, cartesian_case, unstructured_case
@@ -69,9 +69,9 @@ def pow(inp1: cases.IField) -> cases.IField:
def test_floordiv(cartesian_case):
if cartesian_case.backend in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail(
"FloorDiv not yet supported."
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
index f489126fa7..d86bc21679 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
@@ -128,7 +128,6 @@ def fo_from_fo_program(in_field: cases.IFloatField, out: cases.IFloatField):
)
-@pytest.mark.uses_tuple_returns
def test_tuple_program_return_constructed_inside(cartesian_case):
@gtx.field_operator
def pack_tuple(
@@ -155,7 +154,6 @@ def prog(
assert np.allclose((a, b), (out_a, out_b))
-@pytest.mark.uses_tuple_returns
def test_tuple_program_return_constructed_inside_with_slicing(cartesian_case):
@gtx.field_operator
def pack_tuple(
@@ -183,7 +181,6 @@ def prog(
assert out_a[0] == 0 and out_b[0] == 0
-@pytest.mark.uses_tuple_returns
def test_tuple_program_return_constructed_inside_nested(cartesian_case):
@gtx.field_operator
def pack_tuple(
diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py
index ca29c5b18b..e2bbbaa553 100644
--- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py
@@ -52,7 +52,7 @@
xor_,
)
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
-from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data
from next_tests.unit_tests.conftest import program_processor, run_processor
diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py
index bd5a717bb2..67b439507c 100644
--- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py
+++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py
@@ -148,7 +148,6 @@ def stencil(inp1, inp2, inp3, inp4):
"stencil",
[tuple_output1, tuple_output2],
)
-@pytest.mark.uses_tuple_returns
def test_tuple_of_field_output_constructed_inside(program_processor, stencil):
program_processor, validate = program_processor
@@ -194,7 +193,6 @@ def fencil(size0, size1, size2, inp1, inp2, out1, out2):
assert np.allclose(inp2, out2)
-@pytest.mark.uses_tuple_returns
def test_asymetric_nested_tuple_of_field_output_constructed_inside(program_processor):
program_processor, validate = program_processor
@@ -288,7 +286,7 @@ def tuple_input(inp):
return tuple_get(0, inp_deref) + tuple_get(1, inp_deref)
-@pytest.mark.uses_tuple_returns
+@pytest.mark.uses_tuple_args
def test_tuple_field_input(program_processor):
program_processor, validate = program_processor
@@ -348,7 +346,7 @@ def tuple_tuple_input(inp):
)
-@pytest.mark.uses_tuple_returns
+@pytest.mark.uses_tuple_args
def test_tuple_of_tuple_of_field_input(program_processor):
program_processor, validate = program_processor
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
index 8db9a4c36e..64fb238470 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
@@ -18,7 +18,7 @@
import pytest
import gt4py.next as gtx
-from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip
+from gt4py.next.program_processors.runners import gtfn, roundtrip
from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
fieldview_backend,
@@ -214,9 +214,9 @@ class setup:
@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):
if fieldview_backend in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail("Needs implementation of scan projector.")
@@ -234,7 +234,7 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup, fieldview_backend):
@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):
- if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]:
+ if fieldview_backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail(
"Needs implementation of scan projector. Breaks in type inference as executed"
"again after CollapseTuple."
@@ -256,7 +256,7 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup, fieldview_backend):
@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):
- if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]:
+ if fieldview_backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
solve_nonhydro_stencil_52_like.with_backend(fieldview_backend)(
test_setup.z_alpha,
@@ -273,7 +273,7 @@ def test_solve_nonhydro_stencil_52_like(test_setup, fieldview_backend):
@pytest.mark.uses_tuple_returns
def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup, fieldview_backend):
- if fieldview_backend in [gtfn_cpu.run_gtfn_with_temporaries]:
+ if fieldview_backend in [gtfn.run_gtfn_with_temporaries]:
pytest.xfail("Temporary extraction does not work correctly in combination with scans.")
if fieldview_backend == roundtrip.executor:
pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].")
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py
index 16d839a8ab..4e295e92af 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py
@@ -18,7 +18,7 @@
import gt4py.next as gtx
from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.runners import gtfn
from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor
@@ -79,9 +79,9 @@ def test_anton_toy(program_processor, lift_mode):
program_processor, validate = program_processor
if program_processor in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
from gt4py.next.iterator import transforms
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py
index 41d6c8f0f9..04cf8c6f9c 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py
@@ -149,7 +149,7 @@ def k_level_condition_upper_tuple(k_idx, k_level):
),
],
)
-@pytest.mark.uses_tuple_returns
+@pytest.mark.uses_tuple_args
def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function):
program_processor, validate = program_processor
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py
index 42de13ef44..445b73548b 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py
@@ -16,7 +16,7 @@
import pytest
-pytest.importorskip("atlas4py")
+pytest.importorskip("atlas4py") # isort: skip
import gt4py.next as gtx
from gt4py.next.iterator import library
@@ -37,7 +37,6 @@
)
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
from gt4py.next.iterator.transforms.pass_manager import LiftMode
-from gt4py.next.program_processors.runners import gtfn_cpu
from next_tests.integration_tests.multi_feature_tests.iterator_tests.fvm_nabla_setup import (
assert_close,
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py
index 7bd028b7c3..af70dd590f 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py
@@ -18,7 +18,7 @@
import gt4py.next as gtx
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef, offset
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.runners import gtfn
from next_tests.integration_tests.cases import IDim, JDim
from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import (
@@ -75,9 +75,9 @@ def hdiff(inp, coeff, out, x, y):
def test_hdiff(hdiff_reference, program_processor, lift_mode):
program_processor, validate = program_processor
if program_processor in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
# TODO(tehrengruber): check if still true
from gt4py.next.iterator import transforms
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py
index f11046cb5d..a0471e8baa 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py
@@ -19,10 +19,8 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fendef, fundef
from gt4py.next.iterator.transforms import LiftMode
-from gt4py.next.program_processors.formatters.gtfn import (
- format_sourcecode as gtfn_format_sourcecode,
-)
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters
+from gt4py.next.program_processors.runners import gtfn
from next_tests.integration_tests.cases import IDim, JDim, KDim
from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor
@@ -121,16 +119,16 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode):
if (
program_processor
in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
- gtfn_format_sourcecode,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
+ gtfn_formatters.format_sourcecode,
]
and lift_mode == LiftMode.FORCE_INLINE
):
pytest.skip("gtfn does only support lifted scans when using temporaries")
if (
- program_processor == gtfn_cpu.run_gtfn_with_temporaries
+ program_processor == gtfn.run_gtfn_with_temporaries
or lift_mode == LiftMode.FORCE_TEMPORARIES
):
pytest.xfail("tuple_get on columns not supported.")
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py
index 92b93ddb63..d475fab3a8 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py
@@ -30,15 +30,13 @@
shift,
)
from gt4py.next.iterator.runtime import fundef
-from gt4py.next.program_processors.formatters import gtfn
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.runners import gtfn
from next_tests.toy_connectivity import (
C2E,
E2V,
V2E,
V2V,
- C2EDim,
Cell,
E2VDim,
Edge,
@@ -409,9 +407,9 @@ def shift_sparse_stencil2(inp):
def test_shift_sparse_input_field2(program_processor, lift_mode):
program_processor, validate = program_processor
if program_processor in [
- gtfn_cpu.run_gtfn,
- gtfn_cpu.run_gtfn_imperative,
- gtfn_cpu.run_gtfn_with_temporaries,
+ gtfn.run_gtfn,
+ gtfn.run_gtfn_imperative,
+ gtfn.run_gtfn_with_temporaries,
]:
pytest.xfail(
"Bug in bindings/compilation/caching: only the first program seems to be compiled."
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py
index 4e456637cf..c60079eaf1 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/otf_tests/test_gtfn_workflow.py
@@ -14,7 +14,7 @@
import numpy as np
import gt4py.next as gtx
-from gt4py.next.program_processors.runners import gtfn_cpu
+from gt4py.next.program_processors.runners import gtfn
from next_tests.integration_tests.cases import IDim, JDim
@@ -37,7 +37,7 @@ def test_different_buffer_sizes():
)
out = gtx.np_as_located_field(IDim, JDim)(np.zeros((out_nx, out_ny), dtype=np.int32))
- @gtx.field_operator(backend=gtfn_cpu.run_gtfn)
+ @gtx.field_operator(backend=gtfn.run_gtfn)
def copy(inp: gtx.Field[[IDim, JDim], gtx.int32]) -> gtx.Field[[IDim, JDim], gtx.int32]:
return inp
diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py
index 7a62778be1..747431599a 100644
--- a/tests/next_tests/unit_tests/conftest.py
+++ b/tests/next_tests/unit_tests/conftest.py
@@ -22,8 +22,8 @@
from gt4py import eve
from gt4py.next.iterator import ir as itir, pretty_parser, pretty_printer, runtime, transforms
from gt4py.next.program_processors import processor_interface as ppi
-from gt4py.next.program_processors.formatters import gtfn, lisp, type_check
-from gt4py.next.program_processors.runners import double_roundtrip, gtfn_cpu, roundtrip
+from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters, lisp, type_check
+from gt4py.next.program_processors.runners import double_roundtrip, gtfn, roundtrip
try:
@@ -78,10 +78,10 @@ def pretty_format_and_check(root: itir.FencilDefinition, *args, **kwargs) -> str
(roundtrip.executor, True),
(type_check.check, False),
(double_roundtrip.executor, True),
- (gtfn_cpu.run_gtfn, True),
- (gtfn_cpu.run_gtfn_imperative, True),
- (gtfn_cpu.run_gtfn_with_temporaries, True),
- (gtfn.format_sourcecode, False),
+ (gtfn.run_gtfn, True),
+ (gtfn.run_gtfn_imperative, True),
+ (gtfn.run_gtfn_with_temporaries, True),
+ (gtfn_formatters.format_sourcecode, False),
]
+ OPTIONAL_PROCESSORS,
ids=lambda p: next_tests.get_processor_id(p[0]),
diff --git a/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py
new file mode 100644
index 0000000000..e87f869352
--- /dev/null
+++ b/tests/next_tests/unit_tests/ffront_tests/foast_passes_tests/test_type_alias_replacement.py
@@ -0,0 +1,44 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import ast
+import typing
+from typing import TypeAlias
+
+import pytest
+
+import gt4py.next as gtx
+from gt4py.next import float32, float64
+from gt4py.next.ffront.fbuiltins import astype
+from gt4py.next.ffront.func_to_foast import FieldOperatorParser
+
+
+TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests.
+vpfloat: TypeAlias = float32
+wpfloat: TypeAlias = float64
+
+
+@pytest.mark.parametrize("test_input,expected", [(vpfloat, "float32"), (wpfloat, "float64")])
+def test_type_alias_replacement(test_input, expected):
+ def fieldop_with_typealias(
+ a: gtx.Field[[TDim], test_input], b: gtx.Field[[TDim], float32]
+ ) -> gtx.Field[[TDim], test_input]:
+ return test_input("3.1418") + astype(a, test_input)
+
+ foast_tree = FieldOperatorParser.apply_to_function(fieldop_with_typealias)
+
+ assert (
+ foast_tree.body.stmts[0].value.left.func.id == expected
+ and foast_tree.body.stmts[0].value.right.args[1].id == expected
+ )
diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py
index 1fab2643b5..45ef85e37c 100644
--- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py
+++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py
@@ -78,7 +78,7 @@ def make_program_source(name: str) -> stages.ProgramSource:
entry_point=entry_point,
source_code=src,
library_deps=[
- interface.LibraryDependency("gridtools", "master"),
+ interface.LibraryDependency("gridtools_cpu", "master"),
],
language=languages.Cpp,
language_settings=cpp_interface.CPP_DEFAULT,
diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
index 93be884687..ae5f582e47 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
@@ -65,9 +65,9 @@ def fencil_example():
def test_codegen(fencil_example):
fencil, parameters = fencil_example
- module = gtfn_module.translate_program(
+ module = gtfn_module.translate_program_cpu(
stages.ProgramCall(fencil, parameters, {"offset_provider": {}})
)
assert module.entry_point.name == fencil.id
- assert any(d.name == "gridtools" for d in module.library_deps)
+ assert any(d.name == "gridtools_cpu" for d in module.library_deps)
assert module.language is languages.Cpp
diff --git a/tox.ini b/tox.ini
index e16aaff27f..18a6ff8e84 100644
--- a/tox.ini
+++ b/tox.ini
@@ -71,7 +71,7 @@ commands =
python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests
python -m pytest --doctest-modules src{/}gt4py{/}eve
-[testenv:next-py{310}-{nomesh,atlas}]
+[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}]
description = Run 'gt4py.next' tests
pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH
deps =
@@ -81,8 +81,10 @@ set_env =
{[testenv]set_env}
PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/}
commands =
- nomesh: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas" {posargs} tests{/}next_tests
- atlas: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas" {posargs} tests{/}next_tests
+ nomesh-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and not requires_gpu" {posargs} tests{/}next_tests
+ nomesh-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_atlas and requires_gpu" {posargs} tests{/}next_tests
+ atlas-cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and not requires_gpu" {posargs} tests{/}next_tests
+ atlas-gpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests
pytest --doctest-modules src{/}gt4py{/}next
[testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]