From e29016dbb4f0899f717b7d3f431c46d9f7334276 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 23 Sep 2024 08:25:56 -0400 Subject: [PATCH 01/11] feat[cartesian]: K offset write (#1452) ## Description Looking at enabling offset write in K to help with physics. Write is allowed on `FORWARD` and `BACKWARD`, disallowed for `PARALLEL`. TODO: - [x] Make tests for `conditional` - [ ] Explore auto `extend` calculation - [x] Fix `dace:X` backends Link to #131 Discussion happened on [GridTools concept ](https://github.com/GridTools/concepts/pull/34) ## Requirements - [x] All fixes and/or new features come with corresponding tests. - [x] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. --------- Co-authored-by: Hannes Vogt Co-authored-by: Florian Deconinck --- src/gt4py/cartesian/frontend/base.py | 1 + .../cartesian/frontend/gtscript_frontend.py | 28 ++- .../gtc/dace/expansion/daceir_builder.py | 31 ++- .../gtc/dace/expansion/tasklet_codegen.py | 7 +- src/gt4py/cartesian/gtc/dace/utils.py | 10 +- src/gt4py/cartesian/stencil_builder.py | 4 +- tests/cartesian_tests/definitions.py | 12 ++ .../feature_tests/test_field_layouts.py | 17 +- .../test_code_generation.py | 136 ++++++++++++- .../frontend_tests/test_gtscript_frontend.py | 179 ++++++++++++++---- .../frontend_tests/test_ir_maker.py | 2 +- 11 files changed, 354 insertions(+), 73 deletions(-) diff --git a/src/gt4py/cartesian/frontend/base.py b/src/gt4py/cartesian/frontend/base.py index 3ba54f3356..5e542cd36d 100644 --- a/src/gt4py/cartesian/frontend/base.py +++ b/src/gt4py/cartesian/frontend/base.py @@ -74,6 +74,7 @@ def generate( externals: Dict[str, Any], dtypes: Dict[Type, Type], options: BuildOptions, + backend_name: str, ) -> gtir.Stencil: """ Generate a StencilDefinition from a stencil Python function. diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index d21aba674c..962d175eb1 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -708,6 +708,7 @@ def __init__( fields: dict, parameters: dict, local_symbols: dict, + backend_name: str, *, domain: nodes.Domain, temp_decls: Optional[Dict[str, nodes.FieldDecl]] = None, @@ -721,6 +722,7 @@ def __init__( isinstance(value, (type, np.dtype)) for value in local_symbols.values() ) + self.backend_name = backend_name self.fields = fields self.parameters = parameters self.local_symbols = local_symbols @@ -1432,11 +1434,26 @@ def visit_Assign(self, node: ast.Assign) -> list: for t in node.targets[0].elts if isinstance(node.targets[0], ast.Tuple) else node.targets: name, spatial_offset, data_index = self._parse_assign_target(t) if spatial_offset: - if any(offset != 0 for offset in spatial_offset): + if spatial_offset[0] != 0 or spatial_offset[1] != 0: raise GTScriptSyntaxError( - message="Assignment to non-zero offsets is not supported.", + message="Assignment to non-zero offsets is not supported in IJ.", loc=nodes.Location.from_ast_node(t), ) + # Case of K-offset + if len(spatial_offset) == 3 and spatial_offset[2] != 0: + if self.iteration_order == nodes.IterationOrder.PARALLEL: + raise GTScriptSyntaxError( + message="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", + loc=nodes.Location.from_ast_node(t), + ) + if self.backend_name in ["gt:gpu", "dace:gpu"]: + import cupy as cp + + if cp.cuda.runtime.runtimeGetVersion() < 12000: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} for CUDA<12. Please update CUDA.", + loc=nodes.Location.from_ast_node(t), + ) if not self._is_known(name): if name in self.temp_decls: @@ -1997,7 +2014,7 @@ def extract_arg_descriptors(self): return api_signature, fields_decls, parameter_decls - def run(self): + def run(self, backend_name: str): assert ( isinstance(self.ast_root, ast.Module) and "body" in self.ast_root._fields @@ -2055,6 +2072,7 @@ def run(self): fields=fields_decls, parameters=parameter_decls, local_symbols={}, # Not used + backend_name=backend_name, domain=domain, temp_decls=temp_decls, dtypes=self.dtypes, @@ -2110,14 +2128,14 @@ def prepare_stencil_definition(cls, definition, externals): return GTScriptParser.annotate_definition(definition, externals) @classmethod - def generate(cls, definition, externals, dtypes, options): + def generate(cls, definition, externals, dtypes, options, backend_name): if options.build_info is not None: start_time = time.perf_counter() if not hasattr(definition, "_gtscript_"): cls.prepare_stencil_definition(definition, externals) translator = GTScriptParser(definition, externals=externals, dtypes=dtypes, options=options) - definition_ir = translator.run() + definition_ir = translator.run(backend_name) # GTIR only supports LatLonGrids if definition_ir.domain != nodes.Domain.LatLonGrid(): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index d5b1c91466..a8a3a3cb54 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -321,17 +321,29 @@ def visit_FieldAccess( is_target: bool, targets: Set[eve.SymbolRef], var_offset_fields: Set[eve.SymbolRef], + K_write_with_offset: Set[eve.SymbolRef], **kwargs: Any, ) -> Union[dcir.IndexAccess, dcir.ScalarAccess]: + """Generate the relevant accessor to match the memlet that was previously setup. + + When a Field is written in K, we force the usage of the OUT memlet throughout the stencil + to make sure all side effects are being properly resolved. Frontend checks ensure that no + parallel code issues sips here. + """ + res: Union[dcir.IndexAccess, dcir.ScalarAccess] - if node.name in var_offset_fields: + if node.name in var_offset_fields.union(K_write_with_offset): + # If write in K, we consider the variable to always be a target + is_target = is_target or node.name in targets or node.name in K_write_with_offset + name = get_tasklet_symbol(node.name, node.offset, is_target=is_target) res = dcir.IndexAccess( - name=node.name + "__", + name=name, offset=self.visit( node.offset, - is_target=False, + is_target=is_target, targets=targets, var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, **kwargs, ), data_index=node.data_index, @@ -799,11 +811,23 @@ def visit_VerticalLoop( ) ) + # Variable offsets var_offset_fields = { acc.name for acc in node.walk_values().if_isinstance(oir.FieldAccess) if isinstance(acc.offset, oir.VariableKOffset) } + + # We book keep - all write offset to K + K_write_with_offset = set() + for assign_node in node.walk_values().if_isinstance(oir.AssignStmt): + if isinstance(assign_node.left, oir.FieldAccess): + if ( + isinstance(assign_node.left.offset, common.CartesianOffset) + and assign_node.left.offset.k != 0 + ): + K_write_with_offset.add(assign_node.left.name) + sections_idx = next( idx for idx, item in enumerate(global_ctx.library_node.expansion_specification) @@ -821,6 +845,7 @@ def visit_VerticalLoop( iteration_ctx=iteration_ctx, symbol_collector=symbol_collector, var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, **kwargs, ) ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index c219667a4a..696dc27387 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -81,7 +81,12 @@ def visit_IndexAccess( # if this node is not a target, it will still use the symbol of the write memlet if the # field was previously written in the same memlet. memlets = kwargs["read_memlets"] + kwargs["write_memlets"] - memlet = next(mem for mem in memlets if mem.connector == node.name) + try: + memlet = next(mem for mem in memlets if mem.connector == node.name) + except StopIteration: + raise ValueError( + "Memlet connector and tasklet variable mismatch, DaCe IR error." + ) from None index_strs = [] if node.offset is not None: diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 9be2e9a07d..b5c23d2735 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -61,9 +61,9 @@ def get_tasklet_symbol( name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool ): if is_target: - return f"__{name}" + return f"gtOUT__{name}" - acc_name = name + "__" + acc_name = f"gtIN__{name}" if offset is not None: offset_strs = [] for axis in dcir.Axis.dims_3d(): @@ -230,9 +230,12 @@ def _make_access_info( region, he_grid, grid_subset, + is_write, ) -> dcir.FieldAccessInfo: + # Check we have expression offsets in K + # OR write offsets in K offset = [offset_node.to_dict()[k] for k in "ijk"] - if isinstance(offset_node, oir.VariableKOffset): + if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write): variable_offset_axes = [dcir.Axis.K] else: variable_offset_axes = [] @@ -291,6 +294,7 @@ def visit_FieldAccess( region=region, he_grid=he_grid, grid_subset=grid_subset, + is_write=is_write, ) ctx.access_infos[node.name] = access_info.union( ctx.access_infos.get(node.name, access_info) diff --git a/src/gt4py/cartesian/stencil_builder.py b/src/gt4py/cartesian/stencil_builder.py index 07d58f25f5..c0f58c0bc9 100644 --- a/src/gt4py/cartesian/stencil_builder.py +++ b/src/gt4py/cartesian/stencil_builder.py @@ -277,7 +277,9 @@ def gtir_pipeline(self) -> GtirPipeline: return self._build_data.get("gtir_pipeline") or self._build_data.setdefault( "gtir_pipeline", GtirPipeline( - self.frontend.generate(self.definition, self.externals, self.dtypes, self.options), + self.frontend.generate( + self.definition, self.externals, self.dtypes, self.options, self.backend.name + ), self.stencil_id, ), ) diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 9ed4e3dfb3..7499ad4a95 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -15,6 +15,7 @@ import datetime +import numpy as np import pytest from gt4py import cartesian as gt4pyc @@ -54,3 +55,14 @@ def _get_backends_with_storage_info(storage_info_kind: str): @pytest.fixture() def id_version(): return gt_utils.shashed_id(str(datetime.datetime.now())) + + +def get_array_library(backend: str): + """Return device ready array maker library""" + backend_cls = gt4pyc.backend.from_name(backend) + assert backend_cls is not None + if backend_cls.storage_info["device"] == "gpu": + assert cp is not None + return cp + else: + return np diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py index d032e16419..c1b4e58f97 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py @@ -6,13 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np import pytest from gt4py import cartesian as gt4pyc, storage as gt_storage from gt4py.cartesian import gtscript -from cartesian_tests.definitions import ALL_BACKENDS, PERFORMANCE_BACKENDS +from cartesian_tests.definitions import ALL_BACKENDS, PERFORMANCE_BACKENDS, get_array_library from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import copy_stencil @@ -22,20 +21,10 @@ cp = None -def _get_array_library(backend: str): - backend_cls = gt4pyc.backend.from_name(backend) - assert backend_cls is not None - if backend_cls.storage_info["device"] == "gpu": - assert cp is not None - return cp - else: - return np - - @pytest.mark.parametrize("backend", ALL_BACKENDS) @pytest.mark.parametrize("order", ["C", "F"]) def test_numpy_allocators(backend, order): - xp = _get_array_library(backend) + xp = get_array_library(backend) shape = (20, 10, 5) inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float_) outp = xp.zeros(shape=shape, order=order, dtype=xp.float_) @@ -48,7 +37,7 @@ def test_numpy_allocators(backend, order): @pytest.mark.parametrize("backend", PERFORMANCE_BACKENDS) def test_bad_layout_warns(backend): - xp = _get_array_library(backend) + xp = get_array_library(backend) backend_cls = gt4pyc.backend.from_name(backend) assert backend_cls is not None diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index cb8bb8c5d9..976f9a89af 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -27,7 +27,7 @@ ) from gt4py.storage.cartesian import utils as storage_utils -from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS +from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( EXTERNALS_REGISTRY as externals_registry, REGISTRY as stencil_definitions, @@ -190,12 +190,20 @@ def stencil( assert field_3d.shape == full_shape[:] field_2d = gt_storage.zeros( - full_shape[:-1], dtype, backend=backend, aligned_index=aligned_index[:-1], dimensions="IJ" + full_shape[:-1], + dtype, + backend=backend, + aligned_index=aligned_index[:-1], + dimensions="IJ", ) assert field_2d.shape == full_shape[:-1] field_1d = gt_storage.ones( - full_shape[-1:], dtype, backend=backend, aligned_index=(aligned_index[-1],), dimensions="K" + full_shape[-1:], + dtype, + backend=backend, + aligned_index=(aligned_index[-1],), + dimensions="K", ) assert list(field_1d.shape) == [full_shape[-1]] @@ -273,7 +281,8 @@ def copy_2to3( def test_lower_dimensional_inputs_2d_to_3d_forward(backend): @gtscript.stencil(backend=backend) def copy_2to3( - inp: gtscript.Field[gtscript.IJ, np.float_], outp: gtscript.Field[gtscript.IJK, np.float_] + inp: gtscript.Field[gtscript.IJ, np.float_], + outp: gtscript.Field[gtscript.IJK, np.float_], ): with computation(FORWARD), interval(...): outp[0, 0, 0] = inp @@ -574,6 +583,125 @@ def test(out: Field[np.float64], inp: Field[np.float64]): test(out, inp) +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_K_offset_write(backend): + # Cuda generates bad code for the K offset + if backend == "cuda": + pytest.skip("cuda K-offset write generates bad code") + if backend in ["gt:gpu", "dace:gpu"]: + import cupy as cp + + if cp.cuda.runtime.runtimeGetVersion() < 12000: + pytest.skip( + f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" + ) + + arraylib = get_array_library(backend) + array_shape = (1, 1, 4) + K_values = arraylib.arange(start=40, stop=44) + + # Simple case of writing ot an offset. + # A is untouched + # B is written in K+1 and should have K_values, except for the first element (FORWARD) + @gtscript.stencil(backend=backend) + def simple(A: Field[np.float64], B: Field[np.float64]): + with computation(FORWARD), interval(...): + B[0, 0, 1] = A + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + simple(A, B) + assert (B[:, :, 0] == 0).all() + assert (B[:, :, 1:3] == K_values[0:2]).all() + + # Order of operations: FORWARD with negative offset + # means while A is update B will have non-updated values of A + # Because of the interval, value of B[0] is 0 + @gtscript.stencil(backend=backend) + def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + with computation(FORWARD), interval(1, None): + A[0, 0, -1] = scalar + B[0, 0, 0] = A + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + forward(A, B, 2.0) + assert (A[:, :, :3] == 2.0).all() + assert (A[:, :, 3] == K_values[3]).all() + assert (B[:, :, 0] == 0).all() + assert (B[:, :, 1:] == K_values[1:]).all() + + # Order of operations: BACKWARD with negative offset + # means A is update B will get the updated values of A + @gtscript.stencil(backend=backend) + def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + with computation(BACKWARD), interval(1, None): + A[0, 0, -1] = scalar + B[0, 0, 0] = A + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.empty( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + backward(A, B, 2.0) + assert (A[:, :, :3] == 2.0).all() + assert (A[:, :, 3] == K_values[3]).all() + assert (B[:, :, :] == A[:, :, :]).all() + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_K_offset_write_conditional(backend): + if backend == "cuda": + pytest.skip("Cuda backend is not capable of K offset write") + if backend in ["gt:gpu", "dace:gpu"]: + import cupy as cp + + if cp.cuda.runtime.runtimeGetVersion() < 12000: + pytest.skip( + f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" + ) + + arraylib = get_array_library(backend) + array_shape = (1, 1, 4) + K_values = arraylib.arange(start=40, stop=44) + + @gtscript.stencil(backend=backend) + def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + with computation(BACKWARD), interval(1, None): + if A > 0 and B > 0: + A[0, 0, -1] = scalar + B[0, 0, 1] = A + lev = 1 + while A >= 0 and B >= 0: + A[0, 0, lev] = -1 + B = -1 + lev = lev + 1 + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.ones( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + column_physics_conditional(A, B, 2.0) + assert (A[0, 0, :] == arraylib.array([2, 2, -1, -1])).all() + assert (B[0, 0, :] == arraylib.array([1, -1, 2, 42])).all() + + @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_direct_datadims_index(backend): F64_VEC4 = (np.float64, (2, 2, 2, 2)) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 1034176789..e62f878746 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -19,6 +19,7 @@ from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend, nodes from gt4py.cartesian.gtscript import ( __INLINED, + FORWARD, IJ, IJK, PARALLEL, @@ -62,7 +63,7 @@ def parse_definition( ) definition_ir = gt_frontend.GTScriptParser( definition_func, externals=externals or {}, options=build_options, dtypes=dtypes - ).run() + ).run("numpy") setattr(definition_func, "__annotations__", original_annotations) @@ -108,7 +109,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSymbolError, match=r".*MISSING_CONSTANT.*"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def definition_func(inout_field: gtscript.Field[float]): @@ -116,10 +119,13 @@ def definition_func(inout_field: gtscript.Field[float]): inout_field = inout_field[0, 0, 0] + GLOBAL_NESTED_CONSTANTS.missing with pytest.raises( - gt_frontend.GTScriptDefinitionError, match=r".*GLOBAL_NESTED_CONSTANTS.missing.*" + gt_frontend.GTScriptDefinitionError, + match=r".*GLOBAL_NESTED_CONSTANTS.missing.*", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_recursive_function_imports(self): @@ -200,7 +206,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptDefinitionError, match=r".*WRONG_VALUE_CONSTANT.*"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -215,7 +223,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(TypeError, match=r"func is not a gtscript function"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_use_in_expr(self): @@ -290,7 +300,9 @@ def definition_func(inout_field: gtscript.Field[float]): "Please assign the function results to symbols first.", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_use_in_call_arg_multiple_return(self): @@ -316,7 +328,9 @@ def definition_func(inout_field: gtscript.Field[float]): "Please assign the function results to symbols first.", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_recursive_function_call_two_externals(self): @@ -412,7 +426,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_bad_dup_add(self): @@ -422,7 +438,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_bad_dup_axis(self): @@ -432,7 +450,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_bad_out_of_order(self): @@ -442,7 +462,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -495,7 +517,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptDefinitionError, match=r".*MISSING_CONSTANT.*"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def definition_func(inout_field: gtscript.Field[float]): @@ -613,10 +637,13 @@ def definition_func(field: gtscript.Field[float]): field = 0 with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="Invalid interval range specification" + gt_frontend.GTScriptSyntaxError, + match="Invalid interval range specification", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_error_do_not_mix(self): @@ -626,7 +653,9 @@ def definition_func(field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError, match="Two-argument syntax"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_reversed_interval(self): @@ -635,10 +664,13 @@ def definition_func(field: gtscript.Field[float]): field = 0 with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="Invalid interval range specification" + gt_frontend.GTScriptSyntaxError, + match="Invalid interval range specification", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_overlapping_intervals_none(self): @@ -651,7 +683,9 @@ def definition_func(field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError, match="Overlapping intervals"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_overlapping_intervals(self): @@ -664,7 +698,9 @@ def definition_func(field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError, match="Overlapping intervals"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_nonoverlapping_intervals(self): @@ -806,7 +842,10 @@ def _stage_laplacian_y(dy, phi): @gtscript.function def _stage_laplacian(dx, dy, phi): - from gt4py.cartesian.__externals__ import stage_laplacian_x, stage_laplacian_y + from gt4py.cartesian.__externals__ import ( + stage_laplacian_x, + stage_laplacian_y, + ) lap_x = stage_laplacian_x(dx=dx, phi=phi) lap_y = stage_laplacian_y(dy=dy, phi=phi) @@ -876,10 +915,13 @@ def definition_func(phi: gtscript.Field[np.float64]): phi = test_no_return(phi) with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="should have a single return statement" + gt_frontend.GTScriptSyntaxError, + match="should have a single return statement", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_number_return_args(self): @@ -896,7 +938,9 @@ def definition_func(phi: gtscript.Field[np.float64]): match="Number of returns values does not match arguments on left side", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_multiple_return(self): @@ -910,10 +954,13 @@ def definition_func(phi: gtscript.Field[np.float64]): phi = test_multiple_return(phi) with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="should have a single return statement" + gt_frontend.GTScriptSyntaxError, + match="should have a single return statement", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_conditional_return(self): @@ -1217,7 +1264,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -1271,7 +1320,8 @@ def definition_func( ) @pytest.mark.parametrize( - "id_case,test_dtype", list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])) + "id_case,test_dtype", + list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])), ) def test_invalid_inlined_dtypes(self, id_case, test_dtype): with pytest.raises(ValueError, match=r".*data type descriptor.*"): @@ -1285,11 +1335,14 @@ def definition_func( out_field = in_field + param @pytest.mark.parametrize( - "id_case,test_dtype", list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])) + "id_case,test_dtype", + list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])), ) def test_invalid_external_dtypes(self, id_case, test_dtype): def definition_func( - in_field: gtscript.Field["dtype"], out_field: gtscript.Field["dtype"], param: "dtype" + in_field: gtscript.Field["dtype"], + out_field: gtscript.Field["dtype"], + param: "dtype", ): with computation(PARALLEL), interval(...): out_field = in_field + param @@ -1343,7 +1396,10 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float with pytest.raises(gt_frontend.GTScriptSyntaxError): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func( + in_field: gtscript.Field[np.float_], + out_field: gtscript.Field[np.float_], + ): with computation(PARALLEL), interval(...): out_field[0, 0, 1] = in_field @@ -1364,7 +1420,7 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float with pytest.raises( gt_frontend.GTScriptSyntaxError, - match="Assignment to non-zero offsets is not supported.", + match="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", ): parse_definition( func, @@ -1403,16 +1459,21 @@ def definition_func( with pytest.raises( gt_frontend.GTScriptSyntaxError, - match="Assignment to non-zero offsets is not supported.", + match="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_slice(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func( + in_field: gtscript.Field[np.float_], + out_field: gtscript.Field[np.float_], + ): with computation(PARALLEL), interval(...): out_field[:, :, :] = in_field @@ -1421,7 +1482,10 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float def test_string(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func( + in_field: gtscript.Field[np.float_], + out_field: gtscript.Field[np.float_], + ): with computation(PARALLEL), interval(...): out_field["a_key"] = in_field @@ -1437,6 +1501,24 @@ def func(in_field: gtscript.Field[np.float_]): parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + def test_K_offset_write(self): + def func(out: gtscript.Field[np.float64], inp: gtscript.Field[np.float64]): + with computation(FORWARD), interval(...): + out[0, 0, 1] = inp + + parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match=r"(.*?)Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.(.*)", + ): + + def func(out: gtscript.Field[np.float64], inp: gtscript.Field[np.float64]): + with computation(PARALLEL), interval(...): + out[0, 0, 1] = inp + + parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + def test_datadims_direct_access(self): # Check classic data dimensions are working def data_dims( @@ -1499,7 +1581,9 @@ def data_dims_with_at( out_field = global_field.A[1, 0, 2] parse_definition( - data_dims_with_at, name=inspect.stack()[0][3], module=self.__class__.__name__ + data_dims_with_at, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -1541,7 +1625,9 @@ def definition_bw( match=r"(.*?)Intervals must be specified in order of execution(.*)", ): parse_definition( - definition, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -1653,7 +1739,11 @@ def sumdiff_defs( @pytest.mark.parametrize("dtype_scalar", [int, np.float32, np.float64]) def test_set_arg_dtypes(self, dtype_in, dtype_out, dtype_scalar): definition = self.sumdiff_defs - dtypes = {"dtype_in": dtype_in, "dtype_out": dtype_out, "dtype_scalar": dtype_scalar} + dtypes = { + "dtype_in": dtype_in, + "dtype_out": dtype_out, + "dtype_scalar": dtype_scalar, + } original_annotations = gtscript._set_arg_dtypes(definition, dtypes) @@ -1701,10 +1791,17 @@ def test_set_arg_dtypes(self, dtype_in, dtype_out, dtype_scalar): @pytest.mark.parametrize("dtype_scalar", [int, np.float32, np.float64]) def test_parsing(self, dtype_in, dtype_out, dtype_scalar): definition = self.sumdiff_defs - dtypes = {"dtype_in": dtype_in, "dtype_out": dtype_out, "dtype_scalar": dtype_scalar} + dtypes = { + "dtype_in": dtype_in, + "dtype_out": dtype_out, + "dtype_scalar": dtype_scalar, + } parse_definition( - definition, dtypes=dtypes, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition, + dtypes=dtypes, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) annotations = getattr(definition, "__annotations__", {}) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py index e054b7a715..7857ae3e00 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py @@ -13,7 +13,7 @@ def test_AugAssign(): - ir_maker = IRMaker(None, None, None, domain=None) + ir_maker = IRMaker(None, None, None, None, domain=None) aug_assign = ast.parse("a += 1", feature_version=PYTHON_AST_VERSION).body[0] _, result = ir_maker.visit_AugAssign(aug_assign) From 7b2ae0781e4b4cae5449a6a3f3c30ea53ab355ad Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 23 Sep 2024 22:31:13 +0200 Subject: [PATCH 02/11] feat[next]: Apply common transforms from program (#1593) --- .../inline_center_deref_lift_vars.py | 19 +++++++------ .../iterator/transforms/inline_fundefs.py | 4 +-- .../next/iterator/transforms/pass_manager.py | 20 +++++++------- tests/next_tests/definitions.py | 3 ++- .../test_temporaries_with_sizes.py | 6 +++++ .../test_inline_center_deref_lift_vars.py | 27 ++++++++++--------- 6 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 5def306bbc..95c761d7ba 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -9,7 +9,7 @@ import dataclasses from typing import ClassVar, Optional -import gt4py.next.iterator.ir_utils.common_pattern_matcher as common_pattern_matcher +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir @@ -50,27 +50,30 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): uids: eve_utils.UIDGenerator @classmethod - def apply(cls, node: itir.FencilDefinition, uids: Optional[eve_utils.UIDGenerator] = None): + def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): if not uids: uids = eve_utils.UIDGenerator() return cls(uids=uids).visit(node) - def visit_StencilClosure(self, node: itir.StencilClosure, **kwargs): + def visit_FunCall(self, node: itir.FunCall, **kwargs): # TODO(tehrengruber): move the analysis out of this pass and just make it a requirement # such that we don't need to run in multiple times if multiple passes use it. - trace_shifts.trace_stencil(node.stencil, num_args=len(node.inputs), save_to_annex=True) - return self.generic_visit(node, **kwargs) + if cpm.is_call_to(node.fun, "as_fieldop"): + assert len(node.fun.args) in [1, 2] + trace_shifts.trace_stencil( + node.fun.args[0], num_args=len(node.args), save_to_annex=True + ) - def visit_FunCall(self, node: itir.FunCall, **kwargs): node = self.generic_visit(node) - if common_pattern_matcher.is_let(node): + + if cpm.is_let(node): assert isinstance(node.fun, itir.Lambda) # to make mypy happy eligible_params = [False] * len(node.fun.params) new_args = [] bound_scalars: dict[str, itir.Expr] = {} for i, (param, arg) in enumerate(zip(node.fun.params, node.args)): - if common_pattern_matcher.is_applied_lift(arg) and is_center_derefed_only(param): + if cpm.is_applied_lift(arg) and is_center_derefed_only(param): eligible_params[i] = True bound_arg_name = self.uids.sequential_id(prefix="_icdlv") capture_lift = im.promote_to_const_iterator(bound_arg_name) diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index ad2c1b36c0..0541719f0e 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -21,7 +21,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): ) return self.generic_visit(node) - def visit_FencilDefinition(self, node: ir.FencilDefinition): + def visit_Program(self, node: ir.Program): return self.generic_visit(node, symtable=node.annex.symtable) @@ -37,7 +37,7 @@ def visit_SymRef(self, node: ir.SymRef, *, referenced: Set[str], second_pass: bo referenced.add(node.id) return node - def visit_FencilDefinition(self, node: ir.FencilDefinition): + def visit_Program(self, node: ir.Program): referenced: Set[str] = set() self.generic_visit(node, referenced=referenced, second_pass=False) return self.generic_visit(node, referenced=referenced, second_pass=True) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 5989a401ca..f431e8e501 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -80,10 +80,15 @@ def apply_common_transforms( ] = None, symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> itir.Program: - if isinstance(ir, itir.Program): - # TODO(havogt): during refactoring to GTIR, we bypass transformations in case we already translated to itir.Program - # (currently the case when using the roundtrip backend) - return ir + if isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)): + ir = fencil_to_program.FencilToProgram().apply( + ir + ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR + else: + assert isinstance(ir, itir.Program) + # FIXME[#1582](havogt): note: currently the case when using the roundtrip backend + pass + icdlv_uids = eve_utils.UIDGenerator() if lift_mode is None: @@ -131,6 +136,8 @@ def apply_common_transforms( raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") if lift_mode != LiftMode.FORCE_INLINE: + # FIXME[#1582](tehrengruber): implement new temporary pass here + raise NotImplementedError() assert offset_provider is not None ir = CreateGlobalTmps().visit( ir, @@ -175,11 +182,6 @@ def apply_common_transforms( ir = FuseMaps().visit(ir) ir = CollapseListGet().visit(ir) - assert isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)) - ir = fencil_to_program.FencilToProgram().apply( - ir - ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR - if unroll_reduce: for _ in range(10): unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 41ec0bd147..c0066872f3 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -191,12 +191,13 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [(USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], + + [(ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], ProgramFormatterId.GTFN_CPP_FORMATTER: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + (ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 64887e0ec9..07e40675d6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -66,6 +66,9 @@ def prog( def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): + # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented + pytest.xfail("Temporary pass not implemented.") + unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, offset_provider=mesh_descriptor.offset_provider, @@ -99,6 +102,9 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): + # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented + pytest.xfail("Temporary pass not implemented.") + itir_with_tmp = apply_common_transforms( testee.itir, lift_mode=LiftMode.USE_TEMPORARIES, diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py index 5d75e1db3e..6cc2f7cd28 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py @@ -11,31 +11,32 @@ from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -def wrap_in_fencil(expr: itir.Expr) -> itir.FencilDefinition: - return itir.FencilDefinition( +def wrap_in_program(expr: itir.Expr) -> itir.Program: + return itir.Program( id="f", function_definitions=[], params=[im.sym("d"), im.sym("inp"), im.sym("out")], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("it")(expr))(im.ref("inp")), domain=im.call("cartesian_domain")(), - stencil=im.lambda_("it")(expr), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ) ], ) -def unwrap_from_fencil(fencil: itir.FencilDefinition) -> itir.Expr: - return fencil.closures[0].stencil.expr +def unwrap_from_program(program: itir.Program) -> itir.Expr: + stencil = program.body[0].expr.fun.args[0] + return stencil.expr def test_simple(): testee = im.let("var", im.lift("deref")("it"))(im.deref("var")) expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)" - actual = unwrap_from_fencil(InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))) + actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -43,14 +44,14 @@ def test_double_deref(): testee = im.let("var", im.lift("deref")("it"))(im.plus(im.deref("var"), im.deref("var"))) expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)" - actual = unwrap_from_fencil(InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))) + actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected def test_deref_at_non_center_different_pos(): testee = im.let("var", im.lift("deref")("it"))(im.deref(im.shift("I", 1)("var"))) - actual = unwrap_from_fencil(InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))) + actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert testee == actual @@ -59,5 +60,5 @@ def test_deref_at_multiple_pos(): im.plus(im.deref("var"), im.deref(im.shift("I", 1)("var"))) ) - actual = unwrap_from_fencil(InlineCenterDerefLiftVars.apply(wrap_in_fencil(testee))) + actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert testee == actual From c53c6f32988f4ce6ff448f0c7ffc724c44c1bec2 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Tue, 24 Sep 2024 08:59:22 +0200 Subject: [PATCH 03/11] feat[next]: AOT toolchain (#1545) ## Description Partial implementation of https://hackmd.io/3EMvJ7zJQPepDd18qDU3sg Missing: - Integrate transforms and compilation steps into one workflow. This PR does prepare the ground for such a design change however. ### Background: The `gt4py.next` machinery to transform DSL functions into executable code has so far had unrestricted access to the arguments the user wished to pass to the executable function, including full fields and their contents, as well as connectivity tables. In other words, only JIT compilation was possible, even though in principle AOT compilation should also work. In principle this should not be necessary, and for most of the reusable, pluggable steps in this machinery this access was restricted in a previous PR. However, the restrictions were introduced on a per-step level to make the changes easier to integrate. In order to get pre-compiled programs, `icon4py` has used a workaround to store the result of the first JIT compilation. ### Changes This PR makes introduces the restriction that the machinery should only have access to the types of field and scalar arguments, and everything but the connectivity tables for connectivityes on a toolchain level. With the exception that one single ITIR pass still requires access to connectivity tables, however it has been tested that nothing else does. This also enabled the following design changes: - The field operator / program stages are now decoupled from the args and kwargs, which means less of them are required - The transformation steps could be simplified, because they only need to deal with the part of the data they actually require and an adapter can ensure interoperability - "optional" workflow steps have been replaced with step orders that depend on input type (see `next.otf.workflow.MultiWorkflow` and `next.backend.Transforms` for details. - `next.backend.Backend` now requires a `ModularExecutor` and will call it's `.executor.otf_workflow` directly and not `.executor.__call__`. This is what prepares for unifying transform and compile workflows. This also means that any custom code in `.executor.__call__` is not going to be called by backends. For new `ProgramExecutors` it is highly recommended to put all such code into their `.otf_workflow.decorate` step. The alternative is to only be usable with a custom backend. - `decorator.FrozenProgram` JIT compiles on first call and subsequently skips the whole toolchain. It is accessible via `@gtx.program(frozen=True, backend=...) or `decorator.Program.freeze()`. It is in principle also compatible with field operators. ### Required steps before merging - [x] test with `icon4py` (passed: https://github.com/C2SM/icon4py/pull/513) and add required changes ### Future steps It would be convenient to create a nice API to automatically create compile time arguments from type annotations to DSL functions. But this should probably be done in a follow-up PR. --------- Co-authored-by: Edoardo Paone Co-authored-by: Till Ehrengruber --- docs/user/next/advanced/HackTheToolchain.md | 49 +- .../next/advanced/ToolchainWalkthrough.md | 523 ++++-------------- docs/user/next/advanced/WorkflowPatterns.md | 14 +- src/gt4py/next/__init__.py | 2 - src/gt4py/next/backend.py | 214 ++++--- src/gt4py/next/ffront/decorator.py | 194 +++++-- src/gt4py/next/ffront/foast_to_gtir.py | 20 + src/gt4py/next/ffront/foast_to_itir.py | 22 +- src/gt4py/next/ffront/foast_to_past.py | 229 +++++--- src/gt4py/next/ffront/func_to_foast.py | 76 ++- src/gt4py/next/ffront/func_to_past.py | 80 +-- src/gt4py/next/ffront/past_passes/linters.py | 18 +- src/gt4py/next/ffront/past_process_args.py | 46 +- src/gt4py/next/ffront/past_to_itir.py | 107 ++-- src/gt4py/next/ffront/signature.py | 153 +++++ src/gt4py/next/ffront/stages.py | 58 +- src/gt4py/next/iterator/embedded.py | 36 +- src/gt4py/next/iterator/ir.py | 2 + .../iterator/transforms/fencil_to_program.py | 2 + .../next/iterator/transforms/global_tmps.py | 2 + src/gt4py/next/otf/arguments.py | 244 ++++++++ .../compilation/build_systems/compiledb.py | 3 + src/gt4py/next/otf/compilation/cache.py | 3 +- src/gt4py/next/otf/compilation/compiler.py | 11 +- src/gt4py/next/otf/stages.py | 23 +- src/gt4py/next/otf/step_types.py | 2 +- src/gt4py/next/otf/toolchain.py | 66 +++ src/gt4py/next/otf/workflow.py | 43 +- .../codegens/gtfn/gtfn_module.py | 24 +- .../program_processors/modular_executor.py | 13 +- .../next/program_processors/runners/dace.py | 20 +- .../runners/dace_common/workflow.py | 10 +- .../runners/dace_fieldview/workflow.py | 13 +- .../runners/dace_iterator/itir_to_sdfg.py | 18 +- .../runners/dace_iterator/workflow.py | 15 +- .../runners/double_roundtrip.py | 1 + .../next/program_processors/runners/gtfn.py | 31 +- .../program_processors/runners/roundtrip.py | 108 ++-- .../next/type_system/type_translation.py | 10 + .../feature_tests/dace/test_orchestration.py | 5 +- .../ffront_tests/ffront_test_utils.py | 2 +- .../ffront_tests/test_decorator.py | 36 ++ .../test_temporaries_with_sizes.py | 1 + .../iterator_tests/test_builtins.py | 1 + .../unit_tests/ffront_tests/test_stages.py | 97 +--- .../build_systems_tests/conftest.py | 1 + .../unit_tests/otf_tests/test_languages.py | 2 + .../gtfn_tests/test_gtfn_module.py | 9 +- tests/next_tests/unit_tests/test_config.py | 4 +- 49 files changed, 1547 insertions(+), 1116 deletions(-) create mode 100644 src/gt4py/next/ffront/signature.py create mode 100644 src/gt4py/next/otf/arguments.py create mode 100644 src/gt4py/next/otf/toolchain.py diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 70681796ee..546be95784 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -3,7 +3,8 @@ import dataclasses import typing from gt4py import next as gtx -from gt4py.next.otf import workflow +from gt4py.next.otf import toolchain, workflow +from gt4py.next.ffront import stages as ff_stages from gt4py import eve ``` @@ -13,43 +14,35 @@ from gt4py import eve ## Replace Steps ```python -cached_lowering_toolchain = gtx.backend.DEFAULT_PROG_TRANSFORMS.replace( - past_to_itir=workflow.CachedStep( - step=gtx.ffront.past_to_itir.PastToItirFactory(), - hash_function=eve.utils.content_hash - ) +cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( + past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False) ) ``` ## Skip Steps / Change Order ```python -gtx.backend.DEFAULT_PROG_TRANSFORMS.step_order +DUMMY_FOP = toolchain.CompilableProgram(data=ff_stages.FieldOperatorDefinition(definition=None), args=None) ``` - ['func_to_past', - 'past_lint', - 'past_inject_args', - 'past_transform_args', - 'past_to_itir'] +```python +gtx.backend.DEFAULT_TRANSFORMS.step_order(DUMMY_FOP) +``` ```python @dataclasses.dataclass(frozen=True) -class SkipLinting(gtx.backend.ProgramTransformWorkflow): - @property - def step_order(self): - return [ - "func_to_past", - # not running "past_lint" - "past_inject_args", - "past_transform_args", - "past_to_itir", - ] - -same_steps = dataclasses.asdict(gtx.backend.DEFAULT_PROG_TRANSFORMS) +class SkipLinting(gtx.backend.Transforms): + def step_order(self, inp): + order = super().step_order(inp) + if "past_lint" in order: + order.remove("past_lint") # not running "past_lint" + return order + +same_steps = dataclasses.asdict(gtx.backend.DEFAULT_TRANSFORMS) skip_linting_transforms = SkipLinting( **same_steps ) +skip_linting_transforms.step_order(DUMMY_FOP) ``` ## Alternative Factory @@ -63,7 +56,7 @@ class Cpp2BindingsGen: class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory): translation: workflow.Workflow[ - gtx.otf.stages.ProgramCall, gtx.otf.stages.ProgramSource] = MyCodeGen() + gtx.otf.stages.AOTProgram, gtx.otf.stages.ProgramSource] = MyCodeGen() bindings: workflow.Workflow[ gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableSource] = Cpp2BindingsGen() @@ -72,12 +65,12 @@ PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) ## Invent new Workflow Types -````mermaid +```mermaid graph LR IN_T --> i{{split}} --> A_T --> a{{track_a}} --> B_T --> o{{combine}} --> OUT_T i --> X_T --> x{{track_x}} --> Y_T --> o - +``` ```python IN_T = typing.TypeVar("IN_T") @@ -126,4 +119,4 @@ class PartiallyModularDiamond( b=self.track_a(a), y=self.track_x(x) ) -```` +``` diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index d44663a72c..83975d4106 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -1,6 +1,7 @@ ```python import dataclasses import inspect +import pprint import gt4py.next as gtx from gt4py.next import backend @@ -22,28 +23,14 @@ OFFSET_PROVIDER = {"Ioff": I} ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) - -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") ``` # Walkthrough from Field Operator @@ -64,45 +51,19 @@ start = example_fo.definition_stage gtx.ffront.stages.FieldOperatorDefinition? ``` - Init signature: - gtx.ffront.stages.FieldOperatorDefinition( -  definition: 'types.FunctionType', -  grid_type: 'Optional[common.GridType]' = None, -  node_class: 'type[OperatorNodeT]' = <class 'gt4py.next.ffront.field_operator_ast.FieldOperator'>, -  attributes: 'dict[str, Any]' = <factory>, - ) -> None - Docstring: FieldOperatorDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None, node_class: 'type[OperatorNodeT]' = , attributes: 'dict[str, Any]' = ) - File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py - Type: type - Subclasses: - ## DSL -> FOAST ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta - -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) - -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") style fdef fill:red style foast fill:red @@ -110,25 +71,16 @@ linkStyle 0 stroke:red,stroke-width:4px,color:pink ``` ```python -foast = backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(start) + +foast = backend.DEFAULT_TRANSFORMS.func_to_foast( + gtx.otf.toolchain.CompilableProgram(start, gtx.otf.arguments.CompileTimeArgs.empty()) +) ``` ```python -gtx.ffront.stages.FoastOperatorDefinition? +foast.data.__class__? ``` - Init signature: - gtx.ffront.stages.FoastOperatorDefinition( -  foast_node: 'OperatorNodeT', -  closure_vars: 'dict[str, Any]', -  grid_type: 'Optional[common.GridType]' = None, -  attributes: 'dict[str, Any]' = <factory>, - ) -> None - Docstring: FoastOperatorDefinition(foast_node: 'OperatorNodeT', closure_vars: 'dict[str, Any]', grid_type: 'Optional[common.GridType]' = None, attributes: 'dict[str, Any]' = ) - File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py - Type: type - Subclasses: - ## FOAST -> ITIR This also happens inside the `decorator.FieldOperator.__gt_itir__` method during the lowering from calling Programs to ITIR @@ -136,28 +88,14 @@ This also happens inside the `decorator.FieldOperator.__gt_itir__` method during ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) - -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") style foast fill:red style itir_expr fill:red @@ -165,192 +103,113 @@ linkStyle 1 stroke:red,stroke-width:4px,color:pink ``` ```python -fitir = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(foast) +fitir = backend.DEFAULT_TRANSFORMS.foast_to_itir(foast) ``` ```python fitir.__class__ ``` - gt4py.next.iterator.ir.FunctionDefinition - -## FOAST -> FOAST closure +## FOAST with args -> PAST with args -This is preparation for "directly calling" a field operator. +This auto-generates a program for us, directly in PAST representation and forwards the call arguments to it ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta - -style foasta fill:red -style fclos fill:red +style foast fill:red +style past fill:red linkStyle 2 stroke:red,stroke-width:4px,color:pink ``` -Here we have to manually combine the previous result with the call arguments. When we call the toolchain as a whole later we will only have to do this once at the beginning. +So far we have gotten away with empty compile time arguments, now we need to supply actual types. The easiest way to do that is from concrete arguments. ```python -fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_foast_closure( - gtx.otf.workflow.InputWithArgs( - data=foast, - args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), - kwargs={ - "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64), - "from_fieldop": example_fo - }, - ) +jit_args = gtx.otf.arguments.JITArgs.from_signature( + gtx.ones(domain={I: 10}, dtype=gtx.float64), + out=gtx.zeros(domain={I: 10}, dtype=gtx.float64), + offset_provider=OFFSET_PROVIDER +) + +aot_args = gtx.otf.arguments.CompileTimeArgs.from_concrete_no_size( + *jit_args.args, **jit_args.kwargs ) ``` ```python -fclos.closure_vars["example_fo"].backend +pclos = backend.DEFAULT_TRANSFORMS.field_view_op_to_prog(gtx.otf.toolchain.CompilableProgram(data=foast.data, args=aot_args)) ``` ```python -gtx.ffront.stages.FoastClosure?? +pclos.data.__class__? ``` - Init signature: - gtx.ffront.stages.FoastClosure( -  foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', -  args: 'tuple[Any, ...]', -  kwargs: 'dict[str, Any]', -  closure_vars: 'dict[str, Any]', - ) -> None - Docstring: FoastClosure(foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]', closure_vars: 'dict[str, Any]') - Source: - @dataclasses.dataclass(frozen=True) - class FoastClosure(Generic[OperatorNodeT]): -  foast_op_def: FoastOperatorDefinition[OperatorNodeT] -  args: tuple[Any, ...] -  kwargs: dict[str, Any] -  closure_vars: dict[str, Any] - File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py - Type: type - Subclasses: - -## FOAST with args -> PAST closure +## Lint ProgramAST -This auto-generates a program for us, directly in PAST representation and forwards the call arguments to it +This checks the generated (or manually passed) PAST node. ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta - -style fclos fill:red -style pclos fill:red +style past fill:red +%%style tapast fill:red linkStyle 3 stroke:red,stroke-width:4px,color:pink ``` ```python -pclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure(fclos) +linted = backend.DEFAULT_TRANSFORMS.past_lint(pclos) ``` -```python -gtx.ffront.stages.PastClosure? -``` - - Init signature: - gtx.ffront.stages.PastClosure( -  closure_vars: 'dict[str, Any]', -  past_node: 'past.Program', -  grid_type: 'Optional[common.GridType]', -  args: 'tuple[Any, ...]', -  kwargs: 'dict[str, Any]', - ) -> None - Docstring: PastClosure(closure_vars: 'dict[str, Any]', past_node: 'past.Program', grid_type: 'Optional[common.GridType]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]') - File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py - Type: type - Subclasses: - ## Transform PAST closure arguments -Don't ask me, seems to be necessary though +This turns data arguments (or rather, their compile-time standins) passed as keyword args (allowed in DSL programs) into positional args (the only way supported by all compiled programs). Included in this is the 'out' argument which is automatically added when generating a fieldview program from a fieldview operator. ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta - -style pclos fill:red -%%style pclos fill:red +style past fill:red +style tapast fill:red linkStyle 4 stroke:red,stroke-width:4px,color:pink ``` ```python -pclost = backend.DEFAULT_PROG_TRANSFORMS.past_transform_args(pclos) +pclost = backend.DEFAULT_TRANSFORMS.field_view_prog_args_transform(pclos) ``` ```python -pclost.kwargs +pprint.pprint(pclos.args) ``` - {} +```python +pprint.pprint(pclost.args) +``` ## Lower PAST -> ITIR @@ -359,67 +218,41 @@ still forwarding the call arguments ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta - -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) - -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta - -style pclos fill:red +style tapast fill:red style pcall fill:red linkStyle 5 stroke:red,stroke-width:4px,color:pink ``` ```python -pitir = backend.DEFAULT_PROG_TRANSFORMS.past_to_itir(pclost) +pitir = backend.DEFAULT_TRANSFORMS.past_to_itir(pclost) ``` ```python -gtx.otf.stages.ProgramCall? +pitir.__class__? ``` - Init signature: - gtx.otf.stages.ProgramCall( -  program: 'itir.FencilDefinition', -  args: 'tuple[Any, ...]', -  kwargs: 'dict[str, Any]', - ) -> None - Docstring: Iterator IR representaion of a program together with arguments to be passed to it. - File: ~/Code/gt4py/src/gt4py/next/otf/stages.py - Type: type - Subclasses: - ## Executing The Result ```python -gtx.program_processors.runners.roundtrip.executor(pitir.program, *pitir.args, offset_provider=OFFSET_PROVIDER, **pitir.kwargs) +pprint.pprint(jit_args) ``` ```python -pitir.args +gtx.program_processors.runners.roundtrip.executor.otf_workflow(pitir)(*jit_args.args, **jit_args.kwargs) ``` - (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), - NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])), - 10, - 10) +```python +pprint.pprint(jit_args) +``` ## Full Field Operator Toolchain @@ -428,47 +261,28 @@ using the default step order ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta - -style fdefa fill:red -style fuwr fill:red style fdef fill:red -style fargs fill:red style foast fill:red -style fiwr fill:red -style foasta fill:red -style fclos fill:red -style pclos fill:red +style past fill:red +style tapast fill:red style pcall fill:red -linkStyle 0,2,3,4,5,9,10,11,12,13,14 stroke:red,stroke-width:4px,color:pink +linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink ``` ### Starting from DSL ```python -pitir2 = backend.DEFAULT_FIELDOP_TRANSFORMS( - gtx.otf.workflow.InputWithArgs(data=start, args=fclos.args, kwargs=fclos.kwargs | {"from_fieldop": example_fo}) +pitir2 = backend.DEFAULT_TRANSFORMS( + gtx.otf.toolchain.CompilableProgram(data=start, args=aot_args) ) assert pitir2 == pitir ``` @@ -476,46 +290,32 @@ assert pitir2 == pitir #### Pass The result to the compile workflow and execute ```python -example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflow( - dataclasses.replace(pitir2, kwargs=pitir2.kwargs | {"offset_provider": OFFSET_PROVIDER}) -) +example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflow(pitir2) ``` ```python -example_compiled(*pitir2.args, offset_provider=OFFSET_PROVIDER) +example_compiled(*jit_args.args, **jit_args.kwargs) ``` We can re-run with the output from the previous run as in- and output. ```python -example_compiled(pitir2.args[1], *pitir2.args[1:], offset_provider=OFFSET_PROVIDER) +example_compiled(jit_args.kwargs["out"], *jit_args.args[1:], **jit_args.kwargs) ``` ```python -pitir2.args[2] +pprint.pprint(jit_args) ``` - 10 - -```python -pitir.args -``` - - (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), - NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.])), - 10, - 10) - ### Starting from FOAST Note that it is the exact same call but with a different input stage ```python -pitir3 = backend.DEFAULT_FIELDOP_TRANSFORMS( - gtx.otf.workflow.InputWithArgs( - data=foast, - args=fclos.args, - kwargs=fclos.kwargs | {"from_fieldop": example_fo} +pitir3 = backend.DEFAULT_TRANSFORMS( + gtx.otf.toolchain.CompilableProgram( + data=foast.data, + args=aot_args ) ) assert pitir3 == pitir @@ -536,46 +336,22 @@ p_start = example_prog.definition_stage ``` ```python -gtx.ffront.stages.ProgramDefinition? +p_start.__class__? ``` - Init signature: - gtx.ffront.stages.ProgramDefinition( -  definition: 'types.FunctionType', -  grid_type: 'Optional[common.GridType]' = None, - ) -> None - Docstring: ProgramDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None) - File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py - Type: type - Subclasses: - ## DSL -> PAST ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) - -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") style pdef fill:red style past fill:red @@ -583,50 +359,8 @@ linkStyle 6 stroke:red,stroke-width:4px,color:pink ``` ```python -p_past = backend.DEFAULT_PROG_TRANSFORMS.func_to_past(p_start) -``` - -## PAST -> Closure - -```mermaid -graph LR - -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) -foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) -past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta - -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) - -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta - -style pasta fill:red -style pclos fill:red -linkStyle 8 stroke:red,stroke-width:4px,color:pink -``` - -```python -pclos = backend.DEFAULT_PROG_TRANSFORMS( - gtx.otf.workflow.InputWithArgs( - data=p_past, - args=fclos.args, - kwargs=fclos.kwargs - ) -) +p_past = backend.DEFAULT_TRANSFORMS.func_to_past( + gtx.otf.toolchain.CompilableProgram(data=p_start, args=gtx.otf.arguments.CompileTimeArgs.empty())) ``` ## Full Program Toolchain @@ -634,59 +368,38 @@ pclos = backend.DEFAULT_PROG_TRANSFORMS( ```mermaid graph LR -fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +fdef("CompilableProgram[FieldOperatorDefinition, AOT]") -->|func_to_foast| foast("CompilableProgram[FoastOperatorDefinition, AOT]") foast -->|foast_to_itir| itir_expr(itir.Expr) -foasta -->|foast_to_foast_closure| fclos(FoastClosure) -fclos -->|foast_to_past_closure| pclos(PastClosure) -pclos -->|past_process_args| pclos -pclos -->|past_to_itir| pcall(ProgramCall) - -pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +foast -->|field_view_op_to_prog| past("CompilableProgram[PastProgramDefinition, AOT]") past -->|past_lint| past -pasta -->|past_to_past_closure| pclos(ProgramClosure) - -fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef -fuwr --> fargs(args, kwargs) - -foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) -fargs --> foasta - -pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef -puwr --> pargs(args, kwargs) +past -->|field_view_prog_args_transform| tapast("CompilableProgram[PastProgramDefinition, AOT]") +tapast -->|past_to_itir| pcall(AOTProgram) -past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) -pargs --> pasta +pdef("CompilableProgram[ProgramDefinition, AOT]") -->|func_to_past| past("CompilableProgram[PastProgramDefinition, AOT]") -style pdefa fill:red -style puwr fill:red style pdef fill:red -style pargs fill:red style past fill:red -style piwr fill:red -style pasta fill:red -style pclos fill:red +style tapast fill:red style pcall fill:red -linkStyle 4,5,6,7,8,15,16,17,18,19,20 stroke:red,stroke-width:4px,color:pink +linkStyle 3,4,5,6 stroke:red,stroke-width:4px,color:pink ``` ### Starting from DSL ```python -p_itir1 = backend.DEFAULT_PROG_TRANSFORMS( - gtx.otf.workflow.InputWithArgs( +p_itir1 = backend.DEFAULT_TRANSFORMS( + gtx.otf.toolchain.CompilableProgram( data=p_start, - args=fclos.args, - kwargs=fclos.kwargs + args=jit_args ) ) ``` ```python -p_itir2 = backend.DEFAULT_PROG_TRANSFORMS( - gtx.otf.workflow.InputWithArgs( - data=p_past, - args=fclos.args, - kwargs=fclos.kwargs +p_itir2 = backend.DEFAULT_TRANSFORMS( + gtx.otf.toolchain.CompilableProgram( + data=p_past.data, + args=aot_args ) ) ``` diff --git a/docs/user/next/advanced/WorkflowPatterns.md b/docs/user/next/advanced/WorkflowPatterns.md index 76880d86f0..b5ae74479e 100644 --- a/docs/user/next/advanced/WorkflowPatterns.md +++ b/docs/user/next/advanced/WorkflowPatterns.md @@ -6,7 +6,7 @@ jupyter: extension: .md format_name: markdown format_version: "1.3" - jupytext_version: 1.16.1 + jupytext_version: 1.16.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -98,7 +98,7 @@ add_3(1) ### Example in the Wild -```python jupyter={"outputs_hidden": true} +```python gtx.ffront.func_to_past.func_to_past.steps.inner[0]?? ``` @@ -134,7 +134,7 @@ add_3_times_2(1) ### Example in the Wild -```python jupyter={"outputs_hidden": true} +```python gtx.program_processors.runners.roundtrip.Roundtrip?? ``` @@ -180,7 +180,7 @@ cached_calc(1) ### Example in the Wild -```python jupyter={"outputs_hidden": true} +```python gtx.backend.DEFAULT_PROG_TRANSFORMS.past_lint?? ``` @@ -268,7 +268,7 @@ strint_calc(1) == strint_calc("1") -```python jupyter={"outputs_hidden": true} editable=true slideshow={"slide_type": ""} +```python editable=true slideshow={"slide_type": ""} gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past?? ``` @@ -486,7 +486,3 @@ gtx.program_processors.runners.gtfn.run_gtfn_gpu.executor.otf_workflow?? ```python gtx.program_processors.runners.gtfn.GTFNBackendFactory?? ``` - -```python - -``` diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 406ac9b7ed..80bb276c70 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -39,7 +39,6 @@ from .ffront.fbuiltins import * # noqa: F403 [undefined-local-with-import-star] explicitly reexport all from fbuiltins.__all__ from .ffront.fbuiltins import FieldOffset from .iterator.embedded import ( - CompileTimeConnectivity, NeighborTableOffsetProvider, StridedNeighborOffsetProvider, index_field, @@ -76,7 +75,6 @@ "as_connectivity", # from iterator "NeighborTableOffsetProvider", - "CompileTimeConnectivity", "StridedNeighborOffsetProvider", "index_field", "np_as_located_field", diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index cdf5b402b5..d4e9965fb2 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import typing from typing import Any, Generic from gt4py._core import definitions as core_defs @@ -20,143 +21,138 @@ func_to_past, past_process_args, past_to_itir, - stages as ffront_stages, + signature, ) from gt4py.next.ffront.past_passes import linters as past_linters +from gt4py.next.ffront.stages import ( + AOT_DSL_FOP, + AOT_DSL_PRG, + AOT_FOP, + AOT_PRG, + DSL_FOP, + DSL_PRG, + FOP, + PRG, +) from gt4py.next.iterator import ir as itir -from gt4py.next.otf import stages, workflow -from gt4py.next.program_processors import processor_interface as ppi - - -@workflow.make_step -def foast_to_foast_closure( - inp: workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition], -) -> ffront_stages.FoastClosure: - from_fieldop = inp.kwargs.pop("from_fieldop") - return ffront_stages.FoastClosure( - foast_op_def=inp.data, - args=inp.args, - kwargs=inp.kwargs, - closure_vars={inp.data.foast_node.id: from_fieldop}, - ) +from gt4py.next.otf import arguments, stages, toolchain, workflow +from gt4py.next.program_processors import modular_executor -@dataclasses.dataclass(frozen=True) -class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs): - """Modular workflow for transformations with access to intermediates.""" - - func_to_foast: workflow.SkippableStep[ - ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, - ffront_stages.FoastOperatorDefinition, - ] = dataclasses.field( - default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True) - ) - foast_to_foast_closure: workflow.Workflow[ - workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition], ffront_stages.FoastClosure - ] = dataclasses.field(default=foast_to_foast_closure, metadata={"takes_args": True}) - foast_to_past_closure: workflow.Workflow[ - ffront_stages.FoastClosure, ffront_stages.PastClosure - ] = dataclasses.field( - default_factory=lambda: foast_to_past.FoastToPastClosure( - foast_to_past=workflow.CachedStep( - foast_to_past.foast_to_past, hash_function=ffront_stages.fingerprint_stage - ) - ) - ) - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( - dataclasses.field(default=past_process_args.past_process_args) - ) - past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field(default_factory=past_to_itir.PastToItirFactory) - ) +ARGS: typing.TypeAlias = arguments.JITArgs +CARG: typing.TypeAlias = arguments.CompileTimeArgs +IT_PRG: typing.TypeAlias = itir.FencilDefinition - foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( - dataclasses.field( - default_factory=lambda: workflow.CachedStep( - step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.fingerprint_stage - ) - ) - ) - @property - def step_order(self) -> list[str]: - return [ - "func_to_foast", - "foast_to_foast_closure", - "foast_to_past_closure", - "past_transform_args", - "past_to_itir", - ] +INPUT_DATA: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PRG | IT_PRG +INPUT_PAIR: typing.TypeAlias = toolchain.CompilableProgram[INPUT_DATA, ARGS | CARG] -DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow() +@dataclasses.dataclass(frozen=True) +class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.AOTProgram]): + """ + Modular workflow for transformations with access to intermediates. + + The set and order of transformation steps depends on the input type. + Thus this workflow can be applied to DSL field operator and program definitions, + as well as their AST representations. Even to Iterator IR programs, although in that + case it will be a no-op. + + The input to the workflow as well as each step must be a `CompilableProgram`. The arguments + inside the `CompilableProgram` passed to the whole workflow may be concrete (`JITArgs`) + or compile-time (`CompileTimeArgs`). The individual steps (apart from `.aotify_args`) + require compile-time arguments. Some of the steps can work with an empty `CompileTimeArgs` instance. + """ + + aotify_args: workflow.Workflow[ + toolchain.CompilableProgram[INPUT_DATA, ARGS], toolchain.CompilableProgram[INPUT_DATA, CARG] + ] = dataclasses.field(default_factory=arguments.adapted_jit_to_aot_args_factory) + + func_to_foast: workflow.Workflow[AOT_DSL_FOP, AOT_FOP] = dataclasses.field( + default_factory=func_to_foast.adapted_func_to_foast_factory + ) + func_to_past: workflow.Workflow[AOT_DSL_PRG, AOT_PRG] = dataclasses.field( + default_factory=func_to_past.adapted_func_to_past_factory + ) -@dataclasses.dataclass(frozen=True) -class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): - """Modular workflow for transformations with access to intermediates.""" - - func_to_past: workflow.SkippableStep[ - ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, - ffront_stages.PastProgramDefinition, - ] = dataclasses.field( - default_factory=lambda: func_to_past.OptionalFuncToPastFactory(cached=True) + foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] = dataclasses.field( + default_factory=foast_to_itir.adapted_foast_to_itir_factory ) - past_lint: workflow.Workflow[ - ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition - ] = dataclasses.field(default_factory=past_linters.LinterFactory) - past_to_past_closure: workflow.Workflow[ - ffront_stages.PastProgramDefinition, ffront_stages.PastClosure - ] = dataclasses.field( - default=lambda inp: ffront_stages.PastClosure( - past_node=inp.data.past_node, - closure_vars=inp.data.closure_vars, - grid_type=inp.data.grid_type, - args=inp.args, - kwargs=inp.kwargs, - ), - metadata={"takes_args": True}, + + field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field( + default_factory=foast_to_past.operator_to_program_factory ) - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( - dataclasses.field( - default=past_process_args.past_process_args, metadata={"takes_args": False} - ) + + past_lint: workflow.Workflow[AOT_PRG, AOT_PRG] = dataclasses.field( + default_factory=past_linters.adapted_linter_factory ) - past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + + field_view_prog_args_transform: workflow.Workflow[AOT_PRG, AOT_PRG] = dataclasses.field( + default_factory=past_process_args.transform_program_args_factory ) + past_to_itir: workflow.Workflow[AOT_PRG, stages.AOTProgram] = dataclasses.field( + default_factory=past_to_itir.past_to_itir_factory + ) -DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow() + def step_order(self, inp: INPUT_PAIR) -> list[str]: + steps: list[str] = [] + if isinstance(inp.args, ARGS): + steps.append("aotify_args") + match inp.data: + case DSL_FOP(): + steps.extend( + [ + "func_to_foast", + "field_view_op_to_prog", + "past_lint", + "field_view_prog_args_transform", + ] + ) + case FOP(): + steps.extend( + ["field_view_op_to_prog", "past_lint", "field_view_prog_args_transform"] + ) + case DSL_PRG(): + steps.extend(["func_to_past", "past_lint", "field_view_prog_args_transform"]) + case PRG(): + steps.extend(["past_lint", "field_view_prog_args_transform"]) + case _: + pass + steps.append("past_to_itir") + return steps + + +DEFAULT_TRANSFORMS: Transforms = Transforms() @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): - executor: ppi.ProgramExecutor + executor: modular_executor.ModularExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transforms_fop: FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS - transforms_prog: ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS + transforms: workflow.Workflow[INPUT_PAIR, stages.AOTProgram] def __call__( self, - program: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition, + program: INPUT_DATA, *args: Any, **kwargs: Any, ) -> None: - if isinstance( - program, (ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition) - ): - offset_provider = kwargs.pop("offset_provider") - from_fieldop = kwargs.pop("from_fieldop") - program_call = self.transforms_fop( - workflow.InputWithArgs(program, args, kwargs | {"from_fieldop": from_fieldop}) - ) - program_call = dataclasses.replace( - program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider} - ) - else: - program_call = self.transforms_prog(workflow.InputWithArgs(program, args, kwargs)) - self.executor(program_call.program, *program_call.args, **program_call.kwargs) + if not isinstance(program, IT_PRG): + args, kwargs = signature.convert_to_positional(program, *args, **kwargs) + self.jit(program, *args, **kwargs)(*args, **kwargs) + + def jit(self, program: INPUT_DATA, *args: Any, **kwargs: Any) -> stages.CompiledProgram: + if not isinstance(program, IT_PRG): + args, kwargs = signature.convert_to_positional(program, *args, **kwargs) + aot_args = arguments.CompileTimeArgs.from_concrete_no_size(*args, **kwargs) + return self.compile(program, aot_args) + + def compile(self, program: INPUT_DATA, compile_time_args: CARG) -> stages.CompiledProgram: + return self.executor.otf_workflow( + self.transforms(toolchain.CompilableProgram(data=program, args=compile_time_args)) + ) @property def __name__(self) -> str: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 64db9ef58d..a39d1219cb 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -35,7 +35,7 @@ from gt4py.next.ffront import ( field_operator_ast as foast, past_process_args, - past_to_itir, + signature, stages as ffront_stages, transform_utils, type_specifications as ts_ffront, @@ -48,6 +48,7 @@ ref, sym, ) +from gt4py.next.otf import arguments, stages, toolchain from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_info, type_specifications as ts, type_translation @@ -101,15 +102,21 @@ def definition(self): @functools.cached_property def past_stage(self): # backwards compatibility for backends that do not support the full toolchain - if self.backend is not None and self.backend.transforms_prog is not None: - return self.backend.transforms_prog.func_to_past(self.definition_stage) - return next_backend.DEFAULT_PROG_TRANSFORMS.func_to_past(self.definition_stage) + no_args_def = toolchain.CompilableProgram( + self.definition_stage, arguments.CompileTimeArgs.empty() + ) + if self.backend is not None and self.backend.transforms is not None: + return self.backend.transforms.func_to_past(no_args_def).data + return next_backend.DEFAULT_TRANSFORMS.func_to_past(no_args_def).data # TODO(ricoh): linting should become optional, up to the backend. def __post_init__(self): - if self.backend is not None and self.backend.transforms_prog is not None: - self.backend.transforms_prog.past_lint(self.past_stage) - return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage) + no_args_past = toolchain.CompilableProgram( + self.past_stage, arguments.CompileTimeArgs.empty() + ) + if self.backend is not None and self.backend.transforms is not None: + return self.backend.transforms.past_lint(no_args_past).data + return next_backend.DEFAULT_TRANSFORMS.past_lint(no_args_past).data @property def __name__(self) -> str: @@ -175,16 +182,17 @@ def _all_closure_vars(self) -> dict[str, Any]: @functools.cached_property def itir(self) -> itir.FencilDefinition: - no_args_past = ffront_stages.PastClosure( - past_node=self.past_stage.past_node, - closure_vars=self.past_stage.closure_vars, - grid_type=self.definition_stage.grid_type, - args=[], - kwargs={}, + no_args_past = toolchain.CompilableProgram( + data=ffront_stages.PastProgramDefinition( + past_node=self.past_stage.past_node, + closure_vars=self.past_stage.closure_vars, + grid_type=self.definition_stage.grid_type, + ), + args=arguments.CompileTimeArgs.empty(), ) - if self.backend is not None and self.backend.transforms_prog is not None: - return self.backend.transforms_prog.past_to_itir(no_args_past).program - return past_to_itir.PastToItirFactory()(no_args_past).program + if self.backend is not None and self.backend.transforms is not None: + return self.backend.transforms.past_to_itir(no_args_past).data + return next_backend.DEFAULT_TRANSFORMS.past_to_itir(no_args_past).data @functools.cached_property def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]: @@ -214,12 +222,14 @@ def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderEle ) return implicit_offset_provider - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: + def __call__( + self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + ) -> None: offset_provider = offset_provider | self._implicit_offset_provider if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." ), stacklevel=2, ) @@ -234,9 +244,68 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend( - self.definition_stage, *args, **(kwargs | {"offset_provider": offset_provider}) + self.definition_stage, + *args, + **(kwargs | {"offset_provider": offset_provider}), + ) + + def freeze(self) -> FrozenProgram: + if self.backend is None: + raise ValueError("Can not freeze a program without backend (embedded execution).") + return FrozenProgram( + self.definition_stage if self.definition_stage else self.past_stage, + backend=self.backend, + ) + + +@dataclasses.dataclass(frozen=True) +class FrozenProgram: + """ + Simplified program instance, which skips the whole toolchain after the first execution. + + Does not work in embedded execution. + """ + + program: ffront_stages.DSL_PRG | ffront_stages.PRG + backend: next_backend.Backend + _compiled_program: Optional[stages.CompiledProgram] = dataclasses.field( + init=False, default=None + ) + + def __post_init__(self) -> None: + if self.backend is None: + raise ValueError("Can not JIT-compile programs without backend (embedded execution).") + + @property + def definition(self) -> str: + return self.program.definition + + def with_backend(self, backend: ppi.ProgramExecutor) -> FrozenProgram: + return self.__class__(program=self.program, backend=backend) + + def with_grid_type(self, grid_type: GridType) -> FrozenProgram: + return self.__class__( + program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend ) + def jit( + self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + ) -> stages.CompiledProgram: + return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs) + + def __call__( + self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + ) -> None: + ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) + + args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs) + + if not self._compiled_program: + super().__setattr__( + "_compiled_program", self.jit(*args, offset_provider=offset_provider, **kwargs) + ) + self._compiled_program(*args, offset_provider=offset_provider, **kwargs) + try: from gt4py.next.program_processors.runners.dace_iterator import Program @@ -255,20 +324,25 @@ class ProgramFromPast(Program): past_stage: ffront_stages.PastProgramDefinition - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): + def __call__(self, *args: Any, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: if self.backend is None: raise NotImplementedError( "Programs created from a PAST node (without a function definition) can not be executed in embedded mode" ) ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) - self.backend(self.past_stage, *args, **(kwargs | {"offset_provider": offset_provider})) + # TODO(ricoh): add test that does the equivalent of IDim + 1 in a ProgramFromPast + self.backend( + self.past_stage, + *args, + **(kwargs | {"offset_provider": offset_provider | self._implicit_offset_provider}), + ) # TODO(ricoh): linting should become optional, up to the backend. def __post_init__(self): - if self.backend is not None and self.backend.transforms_prog is not None: - self.backend.transforms_prog.past_lint(self.past_stage) - return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage) + if self.backend is not None and self.backend.transforms is not None: + self.backend.transforms.past_lint(self.past_stage) + return next_backend.DEFAULT_TRANSFORMS.past_lint(self.past_stage) @dataclasses.dataclass(frozen=True) @@ -357,12 +431,13 @@ def program( def program( - definition=None, + definition: Optional[types.FunctionType] = None, *, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) - backend=eve.NOTHING, - grid_type=None, -) -> Program | Callable[[types.FunctionType], Program]: + backend: next_backend.Backend | eve.NOTHING = eve.NOTHING, + grid_type: Optional[GridType] = None, + frozen: bool = False, +) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ Generate an implementation of a program from a Python function object. @@ -382,9 +457,14 @@ def program( """ def program_inner(definition: types.FunctionType) -> Program: - return Program.from_function( - definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type + program = Program.from_function( + definition, + DEFAULT_BACKEND if backend is eve.NOTHING else backend, + grid_type, ) + if frozen: + return program.freeze() + return program return program_inner if definition is None else program_inner(definition) @@ -447,9 +527,13 @@ def __post_init__(self): @functools.cached_property def foast_stage(self) -> ffront_stages.FoastOperatorDefinition: - if self.backend is not None and self.backend.transforms_fop is not None: - return self.backend.transforms_fop.func_to_foast(self.definition_stage) - return next_backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(self.definition_stage) + if self.backend is not None and self.backend.transforms is not None: + return self.backend.transforms.func_to_foast( + toolchain.CompilableProgram(self.definition_stage, args=None) + ).data + return next_backend.DEFAULT_TRANSFORMS.func_to_foast( + toolchain.CompilableProgram(self.definition_stage, None) + ).data @property def __name__(self) -> str: @@ -473,40 +557,33 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator: ) def __gt_itir__(self) -> itir.FunctionDefinition: - if self.backend is not None and self.backend.transforms_fop is not None: - return self.backend.transforms_fop.foast_to_itir(self.foast_stage) - return next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(self.foast_stage) + if self.backend is not None and self.backend.transforms is not None: + return self.backend.transforms.foast_to_itir( + toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty()) + ) + return next_backend.DEFAULT_TRANSFORMS.foast_to_itir( + toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty()) + ) def __gt_closure_vars__(self) -> dict[str, Any]: return self.foast_stage.closure_vars - def as_program( - self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec] - ) -> Program: + def as_program(self, compiletime_args: arguments.CompileTimeArgs) -> Program: foast_with_types = ( - ffront_stages.FoastWithTypes( - foast_op_def=self.foast_stage, - arg_types=tuple(arg_types), - kwarg_types=kwarg_types, - closure_vars={self.foast_stage.foast_node.id: self}, + toolchain.CompilableProgram( + data=self.foast_stage, + args=compiletime_args, ), ) past_stage = None - if self.backend is not None and self.backend.transforms_fop is not None: - past_stage = self.backend.transforms_fop.foast_to_past_closure.foast_to_past( + if self.backend is not None and self.backend.transforms is not None: + past_stage = self.backend.transforms.field_view_op_to_prog.foast_to_past( foast_with_types - ) + ).data else: - past_stage = ( - next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure.foast_to_past( - ffront_stages.FoastWithTypes( - foast_op_def=self.foast_stage, - arg_types=tuple(arg_types), - kwarg_types=kwarg_types, - closure_vars={self.foast_stage.foast_node.id: self}, - ), - ) - ) + past_stage = next_backend.DEFAULT_TRANSFORMS.foast_to_past_closure.foast_to_past( + foast_with_types + ).data return ProgramFromPast(definition_stage=None, past_stage=past_stage, backend=self.backend) def __call__(self, *args, **kwargs) -> None: @@ -527,7 +604,6 @@ def __call__(self, *args, **kwargs) -> None: *args, out=out, offset_provider=offset_provider, - from_fieldop=self, **kwargs, ) else: @@ -566,7 +642,7 @@ class FieldOperatorFromFoast(FieldOperator): foast_stage: ffront_stages.FoastOperatorDefinition def __call__(self, *args, **kwargs) -> None: - return self.backend(self.foast_stage, *args, from_fieldop=self, **kwargs) + return self.backend(self.foast_stage, *args, **kwargs) @typing.overload diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 8f6ac7673b..4d3230a540 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -24,15 +24,35 @@ stages as ffront_stages, type_specifications as ts_ffront, ) +from gt4py.next.ffront.stages import AOT_FOP, FOP from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts def foast_to_gtir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: + """ + Lower a FOAST field operator node to GTIR. + + See the docstring of `FieldOperatorLowering` for details. + """ return FieldOperatorLowering.apply(inp.foast_node) +def foast_to_gtir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: + """Wrap `foast_to_gtir` into a chainable and, optionally, cached workflow step.""" + wf = foast_to_gtir + if cached: + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + return wf + + +def adapted_foast_to_gtir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: + """Wrap the `foast_to_gtir` workflow step into an adapter to fit into backend transform workflows.""" + return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs)) + + def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: if not type_info.contains_local_field(node.type): return lambda x: im.op_as_fieldop("make_const_list")(x) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 64b7c66161..b32bf744f5 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -26,15 +26,35 @@ from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind +from gt4py.next.ffront.stages import AOT_FOP, FOP from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts -def foast_to_itir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: +def foast_to_itir(inp: FOP) -> itir.Expr: + """ + Lower a FOAST field operator node to Iterator IR. + + See the docstring of `FieldOperatorLowering` for details. + """ return FieldOperatorLowering.apply(inp.foast_node) +def foast_to_itir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: + """Wrap `foast_to_itir` into a chainable and, optionally, cached workflow step.""" + wf = foast_to_itir + if cached: + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + return wf + + +def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: + """Wrap the `foast_to_itir` workflow step into an adapter to fit into backend transform workflows.""" + return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) + + def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: if not type_info.contains_local_field(node.type): return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index ed7ec5a9ed..a294dc33a9 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -7,117 +7,166 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses +from typing import Any, Optional from gt4py.eve import utils as eve_utils from gt4py.next.ffront import ( dialect_ast_enums, + foast_to_itir, program_ast as past, stages as ffront_stages, type_specifications as ts_ffront, ) from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction -from gt4py.next.otf import workflow +from gt4py.next.ffront.stages import AOT_FOP, AOT_PRG +from gt4py.next.iterator import ir as itir +from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -def foast_to_past(inp: ffront_stages.FoastWithTypes) -> ffront_stages.PastProgramDefinition: - # TODO(tehrengruber): implement mechanism to deduce default values - # of arg and kwarg types - # TODO(tehrengruber): check foast operator has no out argument that clashes - # with the out argument of the program we generate here. +@dataclasses.dataclass(frozen=True) +class ItirShim: + """ + A wrapper for a FOAST operator definition with `__gt_*__` special methods. + + Can be placed in a PAST program definition's closure variables so the program + lowering has access to the relevant information. + """ + + definition: AOT_FOP + foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] + + def __gt_closure_vars__(self) -> Optional[dict[str, Any]]: + return self.definition.data.closure_vars + + def __gt_type__(self) -> ts.CallableType: + return self.definition.data.foast_node.type + + def __gt_itir__(self) -> itir.Expr: + return self.foast_to_itir(self.definition) + + +@dataclasses.dataclass(frozen=True) +class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): + """ + Generate a PAST program definition from a FOAST operator definition. + + This workflow step must must be given a FOAST -> ITIR lowering step so that it can place + valid `ItirShim` instances into the closure variables of the generated program. + + Example: + >>> from gt4py import next as gtx + >>> from gt4py.next.otf import arguments, toolchain + >>> IDim = gtx.Dimension("I") + + >>> @gtx.field_operator + ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: + ... return a + + >>> op_to_prog = OperatorToProgram(foast_to_itir.adapted_foast_to_itir_factory()) + + >>> compile_time_args = arguments.CompileTimeArgs.from_concrete_no_size( + ... *( + ... arguments.CompileTimeArg(param.type) + ... for param in copy.foast_stage.foast_node.definition.params + ... ), + ... offset_provider={"I", IDim}, + ... ) + + >>> copy_program = op_to_prog(toolchain.CompilableProgram(copy.foast_stage, compile_time_args)) + + >>> print(copy_program.data.past_node.id) + __field_operator_copy + + >>> assert copy_program.data.closure_vars["copy"].definition.data is copy.foast_stage + """ + + foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] + + def __call__(self, inp: AOT_FOP) -> AOT_PRG: + # TODO(tehrengruber): implement mechanism to deduce default values + # of arg and kwarg types + # TODO(tehrengruber): check foast operator has no out argument that clashes + # with the out argument of the program we generate here. - loc = inp.foast_op_def.foast_node.location - # use a new UID generator to allow caching - param_sym_uids = eve_utils.UIDGenerator() + arg_types = [type_translation.from_value(arg) for arg in inp.args.args] + kwarg_types = {k: type_translation.from_value(v) for k, v in inp.args.kwargs.items()} - type_ = inp.foast_op_def.foast_node.type - params_decl: list[past.Symbol] = [ - past.DataSymbol( - id=param_sym_uids.sequential_id(prefix="__sym"), - type=arg_type, + loc = inp.data.foast_node.location + # use a new UID generator to allow caching + param_sym_uids = eve_utils.UIDGenerator() + + type_ = inp.data.foast_node.type + params_decl: list[past.Symbol] = [ + past.DataSymbol( + id=param_sym_uids.sequential_id(prefix="__sym"), + type=arg_type, + namespace=dialect_ast_enums.Namespace.LOCAL, + location=loc, + ) + for arg_type in arg_types + ] + params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] + out_sym: past.Symbol = past.DataSymbol( + id="out", + type=type_info.return_type(type_, with_args=list(arg_types), with_kwargs=kwarg_types), namespace=dialect_ast_enums.Namespace.LOCAL, location=loc, ) - for arg_type in inp.arg_types - ] - params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] - out_sym: past.Symbol = past.DataSymbol( - id="out", - type=type_info.return_type( - type_, with_args=list(inp.arg_types), with_kwargs=inp.kwarg_types - ), - namespace=dialect_ast_enums.Namespace.LOCAL, - location=loc, - ) - out_ref = past.Name(id="out", location=loc) - - if inp.foast_op_def.foast_node.id in inp.foast_op_def.closure_vars: - raise RuntimeError("A closure variable has the same name as the field operator itself.") - closure_symbols: list[past.Symbol] = [ - past.Symbol( - id=inp.foast_op_def.foast_node.id, - type=ts.DeferredType(constraint=None), - namespace=dialect_ast_enums.Namespace.CLOSURE, - location=loc, - ), - ] - - untyped_past_node = past.Program( - id=f"__field_operator_{inp.foast_op_def.foast_node.id}", - type=ts.DeferredType(constraint=ts_ffront.ProgramType), - params=[*params_decl, out_sym], - body=[ - past.Call( - func=past.Name(id=inp.foast_op_def.foast_node.id, location=loc), - args=params_ref, - kwargs={"out": out_ref}, - location=loc, - ) - ], - closure_vars=closure_symbols, - location=loc, - ) - untyped_past_node = closure_var_type_deduction.ClosureVarTypeDeduction.apply( - untyped_past_node, inp.closure_vars - ) - past_node = type_deduction.ProgramTypeDeduction.apply(untyped_past_node) + out_ref = past.Name(id="out", location=loc) - return ffront_stages.PastProgramDefinition( - past_node=past_node, - closure_vars=inp.closure_vars, - grid_type=inp.foast_op_def.grid_type, - ) + if inp.data.foast_node.id in inp.data.closure_vars: + raise RuntimeError("A closure variable has the same name as the field operator itself.") + closure_symbols: list[past.Symbol] = [ + past.Symbol( + id=inp.data.foast_node.id, + type=ts.DeferredType(constraint=None), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=loc, + ), + ] -@dataclasses.dataclass(frozen=True) -class FoastToPastClosure(workflow.NamedStepSequence): - foast_to_past: workflow.Workflow[ - ffront_stages.FoastWithTypes, ffront_stages.PastProgramDefinition - ] - - def __call__(self, inp: ffront_stages.FoastClosure) -> ffront_stages.PastClosure: - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - arg_types = [] - for arg in inp.args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in inp.kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - past_def = super().__call__( - ffront_stages.FoastWithTypes( - foast_op_def=inp.foast_op_def, - arg_types=tuple(arg_types), - kwarg_types=kwarg_types, - closure_vars=inp.closure_vars, - ) + fieldop_itir_closure_vars = {inp.data.foast_node.id: ItirShim(inp, self.foast_to_itir)} + + untyped_past_node = past.Program( + id=f"__field_operator_{inp.data.foast_node.id}", + type=ts.DeferredType(constraint=ts_ffront.ProgramType), + params=[*params_decl, out_sym], + body=[ + past.Call( + func=past.Name(id=inp.data.foast_node.id, location=loc), + args=params_ref, + kwargs={"out": out_ref}, + location=loc, + ) + ], + closure_vars=closure_symbols, + location=loc, ) + untyped_past_node = closure_var_type_deduction.ClosureVarTypeDeduction.apply( + untyped_past_node, fieldop_itir_closure_vars + ) + past_node = type_deduction.ProgramTypeDeduction.apply(untyped_past_node) - return ffront_stages.PastClosure( - past_node=past_def.past_node, - closure_vars=past_def.closure_vars, - grid_type=past_def.grid_type, + return toolchain.CompilableProgram( + data=ffront_stages.PastProgramDefinition( + past_node=past_node, + closure_vars=fieldop_itir_closure_vars, + grid_type=inp.data.grid_type, + ), args=inp.args, - kwargs=inp.kwargs, ) + + +def operator_to_program_factory( + foast_to_itir_step: Optional[workflow.Workflow[AOT_FOP, itir.Expr]] = None, + cached: bool = True, +) -> workflow.Workflow[AOT_FOP, AOT_PRG]: + """Optionally wrap `OperatorToProgram` in a `CachedStep`.""" + wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram( + foast_to_itir_step or foast_to_itir.adapted_foast_to_itir_factory() + ) + if cached: + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + return wf diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 120068220f..887e6cecba 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -10,12 +10,9 @@ import ast import builtins -import dataclasses import typing from typing import Any, Callable, Iterable, Mapping, Type -import factory - import gt4py.eve as eve from gt4py.next import errors from gt4py.next.ffront import ( @@ -39,14 +36,33 @@ from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.otf import workflow +from gt4py.next.ffront.stages import AOT_DSL_FOP, AOT_FOP, DSL_FOP, FOP +from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -@workflow.make_step -def func_to_foast( - inp: ffront_stages.FieldOperatorDefinition[ffront_stages.OperatorNodeT], -) -> ffront_stages.FoastOperatorDefinition[ffront_stages.OperatorNodeT]: +def func_to_foast(inp: DSL_FOP) -> FOP: + """ + Turn a DSL field operator definition into a FOAST operator definition, adding metadata. + + Examples: + + >>> from gt4py import next as gtx + >>> IDim = gtx.Dimension("I") + + >>> const = gtx.float32(2.0) + >>> def dsl_operator(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: + ... return a * const + + >>> dsl_operator_def = gtx.ffront.stages.FieldOperatorDefinition(definition=dsl_operator) + >>> foast_definition = func_to_foast(dsl_operator_def) + + >>> print(foast_definition.foast_node.id) + dsl_operator + + >>> print(foast_definition.closure_vars) + {'const': 2.0} + """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) annotations = typing.get_type_hints(inp.definition) @@ -68,43 +84,21 @@ def func_to_foast( closure_vars=closure_vars, grid_type=inp.grid_type, attributes=inp.attributes, + debug=inp.debug, ) -@dataclasses.dataclass(frozen=True) -class OptionalFuncToFoast(workflow.SkippableStep): - step: workflow.Workflow[ - ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition - ] = func_to_foast - - def skip_condition( - self, inp: ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition - ) -> bool: - match inp: - case ffront_stages.FieldOperatorDefinition(): - return False - case ffront_stages.FoastOperatorDefinition(): - return True - - -@dataclasses.dataclass(frozen=True) -class OptionalFuncToFoastFactory(factory.Factory): - class Meta: - model = OptionalFuncToFoast - - class Params: - workflow: workflow.Workflow[ - ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition - ] = func_to_foast - cached = factory.Trait( - step=factory.LazyAttribute( - lambda o: workflow.CachedStep( - step=o.workflow, hash_function=ffront_stages.fingerprint_stage - ) - ) - ) +def func_to_foast_factory(cached: bool = True) -> workflow.Workflow[DSL_FOP, FOP]: + """Wrap `func_to_foast` in a chainable and optionally cached workflow step.""" + wf = workflow.make_step(func_to_foast) + if cached: + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + return wf + - step = factory.LazyAttribute(lambda o: o.workflow) +def adapted_func_to_foast_factory(**kwargs: Any) -> workflow.Workflow[AOT_DSL_FOP, AOT_FOP]: + """Wrap the `func_to_foast step in an adapter to fit into transform toolchains.`""" + return toolchain.DataOnlyAdapter(func_to_foast_factory(**kwargs)) class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index e1de316b15..f415c95b63 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -13,8 +13,6 @@ import typing from typing import Any, cast -import factory - from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, @@ -26,12 +24,35 @@ from gt4py.next.ffront.dialect_parser import DialectParser from gt4py.next.ffront.past_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction -from gt4py.next.otf import workflow +from gt4py.next.ffront.stages import AOT_DSL_PRG, AOT_PRG, DSL_PRG, PRG +from gt4py.next.otf import toolchain, workflow from gt4py.next.type_system import type_specifications as ts, type_translation -@workflow.make_step -def func_to_past(inp: ffront_stages.ProgramDefinition) -> ffront_stages.PastProgramDefinition: +def func_to_past(inp: DSL_PRG) -> PRG: + """ + Turn a DSL program definition into a PAST Program definition, adding metadata. + + Examples: + + >>> from gt4py import next as gtx + >>> IDim = gtx.Dimension("I") + + >>> @gtx.field_operator + ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: + ... return a + + >>> def dsl_program(a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]): + ... copy(a, out=out) + + >>> dsl_definition = gtx.ffront.stages.ProgramDefinition(definition=dsl_program) + >>> past_definition = func_to_past(dsl_definition) + + >>> print(past_definition.past_node.id) + dsl_program + + >>> assert "copy" in past_definition.closure_vars + """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) annotations = typing.get_type_hints(inp.definition) @@ -39,40 +60,29 @@ def func_to_past(inp: ffront_stages.ProgramDefinition) -> ffront_stages.PastProg past_node=ProgramParser.apply(source_def, closure_vars, annotations), closure_vars=closure_vars, grid_type=inp.grid_type, + debug=inp.debug, ) -@dataclasses.dataclass(frozen=True) -class OptionalFuncToPast(workflow.SkippableStep): - step: workflow.Workflow[ - ffront_stages.ProgramDefinition, ffront_stages.PastProgramDefinition - ] = func_to_past - - def skip_condition( - self, inp: ffront_stages.PastProgramDefinition | ffront_stages.ProgramDefinition - ) -> bool: - match inp: - case ffront_stages.ProgramDefinition(): - return False - case ffront_stages.PastProgramDefinition(): - return True - - -class OptionalFuncToPastFactory(factory.Factory): - class Meta: - model = OptionalFuncToPast - - class Params: - workflow = func_to_past - cached = factory.Trait( - step=factory.LazyAttribute( - lambda o: workflow.CachedStep( - step=o.workflow, hash_function=ffront_stages.fingerprint_stage - ) - ) - ) +def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]: + """ + Wrap `func_to_past` in a chainable and optionally cached workflow step. + + Caching is switched off by default, because whether recompiling is necessary can only be known after + the closure variables have been collected (which is done in this step). In special cases where it can + be guaranteed that the closure variables do not change, switching caching on should be safe. + """ + wf = workflow.make_step(func_to_past) + if cached: + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + return wf + - step = factory.LazyAttribute(lambda o: o.workflow) +def adapted_func_to_past_factory(**kwargs: Any) -> workflow.Workflow[AOT_DSL_PRG, AOT_PRG]: + """ + Wrap an adapter around the DSL definition -> PAST definition step to fit into transform toolchains. + """ + return toolchain.DataOnlyAdapter(func_to_past_factory(**kwargs)) @dataclasses.dataclass(frozen=True, kw_only=True) diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py index 27dfbb6b35..1a5e7a757b 100644 --- a/src/gt4py/next/ffront/past_passes/linters.py +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -6,10 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import factory +from typing import Any from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils -from gt4py.next.otf import workflow +from gt4py.next.ffront.stages import AOT_PRG, PRG +from gt4py.next.otf import toolchain, workflow @workflow.make_step @@ -45,9 +46,12 @@ def lint_undefined_symbols( return inp -class LinterFactory(factory.Factory): - class Meta: - model = workflow.CachedStep +def linter_factory(cached: bool = True, adapter: bool = True) -> workflow.Workflow[PRG, PRG]: + wf = lint_misnamed_functions.chain(lint_undefined_symbols) + if cached: + wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) + return wf - step = lint_misnamed_functions.chain(lint_undefined_symbols) - hash_function = ffront_stages.fingerprint_stage + +def adapted_linter_factory(**kwargs: Any) -> workflow.Workflow[AOT_PRG, AOT_PRG]: + return toolchain.DataOnlyAdapter(linter_factory(**kwargs)) diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 3cd4d9b0eb..03326dfc57 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Iterator, Optional +from typing import Any, Iterator, TypeAlias from gt4py.next import common, errors from gt4py.next.ffront import ( @@ -14,27 +14,37 @@ stages as ffront_stages, type_specifications as ts_ffront, ) -from gt4py.next.otf import workflow +from gt4py.next.otf import arguments, toolchain, workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -@workflow.make_step -def past_process_args(inp: ffront_stages.PastClosure) -> ffront_stages.PastClosure: - extra_kwarg_names = ["offset_provider", "column_axis"] - extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names} - kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names} +AOT_PRG: TypeAlias = toolchain.CompilableProgram[ + ffront_stages.PastProgramDefinition, arguments.CompileTimeArgs +] + + +def transform_program_args(inp: AOT_PRG) -> AOT_PRG: rewritten_args, size_args, kwargs = _process_args( - past_node=inp.past_node, args=list(inp.args), kwargs=kwargs + past_node=inp.data.past_node, args=list(inp.args.args), kwargs=inp.args.kwargs ) - return ffront_stages.PastClosure( - past_node=inp.past_node, - closure_vars=inp.closure_vars, - grid_type=inp.grid_type, - args=tuple([*rewritten_args, *size_args]), - kwargs=kwargs | extra_kwargs, + return toolchain.CompilableProgram( + data=inp.data, + args=arguments.CompileTimeArgs( + args=tuple((*rewritten_args, *(size_args))), + kwargs=kwargs, + offset_provider=inp.args.offset_provider, + column_axis=inp.args.column_axis, + ), ) +def transform_program_args_factory(cached: bool = True) -> workflow.Workflow[AOT_PRG, AOT_PRG]: + wf = transform_program_args + if cached: + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + return wf + + def _validate_args(past_node: past.Program, args: list, kwargs: dict[str, Any]) -> None: arg_types = [type_translation.from_value(arg) for arg in args] kwarg_types = {k: type_translation.from_value(v) for k, v in kwargs.items()} @@ -67,7 +77,7 @@ def _process_args( ) # extract size of all field arguments - size_args: list[Optional[int]] = [] + size_args: list[int | type_translation.SizeArg] = [] rewritten_args = list(args) for param_idx, param in enumerate(past_node.params): if implicit_domain and isinstance(param.type, (ts.FieldType, ts.TupleType)): @@ -84,7 +94,7 @@ def _process_args( "Constituents of composite arguments (e.g. the elements of a" " tuple) need to have the same shape and dimensions." ) - size_args.extend(shape if shape else [None] * len(dims)) + size_args.extend(shape if shape else [type_translation.SizeArg()] * len(dims)) # type: ignore[arg-type] ##(ricoh) mypy is unable to correctly defer the type of the ternary expression return tuple(rewritten_args), tuple(size_args), kwargs @@ -98,7 +108,9 @@ def _field_constituents_shape_and_dims( yield from _field_constituents_shape_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) - if dims: + if isinstance(arg, arguments.CompileTimeArg): + yield (tuple(), dims) + elif dims: assert hasattr(arg, "shape") and len(arg.shape) == len(dims) yield (arg.shape, dims) else: diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 6dd3fa1753..fe2a20ae07 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,10 +9,10 @@ from __future__ import annotations import dataclasses +import functools from typing import Any, Optional, cast import devtools -import factory from gt4py.eve import NodeTranslator, concepts, traits from gt4py.next import common, config, errors @@ -25,48 +25,87 @@ transform_utils, type_specifications as ts_ffront, ) +from gt4py.next.ffront.stages import AOT_PRG from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts -@dataclasses.dataclass(frozen=True) -class PastToItir(workflow.ChainableWorkflowMixin): - to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR - - def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: - all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) - offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( - all_closure_vars, fbuiltins.FieldOffset, common.Dimension - ) - grid_type = transform_utils._deduce_grid_type( - inp.grid_type, offsets_and_dimensions.values() - ) - - gt_callables = transform_utils._filter_closure_vars_by_type( - all_closure_vars, gtcallable.GTCallable - ).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] +# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR +def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.AOTProgram: + """ + Lower a PAST program definition to Iterator IR. + + Example: + >>> from gt4py import next as gtx + >>> from gt4py.next.otf import arguments, toolchain + >>> IDim = gtx.Dimension("I") + + >>> @gtx.field_operator + ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: + ... return a + + >>> @gtx.program + ... def copy_program(a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]): + ... copy(a, out=out) + + >>> compile_time_args = arguments.CompileTimeArgs.from_concrete( + ... *( + ... arguments.CompileTimeArg(param.type) + ... for param in copy_program.past_stage.past_node.params + ... ), + ... offset_provider={"I", IDim}, + ... ) # this will include field dim size arguments automatically. + + >>> itir_copy = past_to_itir( + ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) + ... ) + + >>> print(itir_copy.data.id) + copy_program + + >>> print(type(itir_copy.data)) + + """ + all_closure_vars = transform_utils._get_closure_vars_recursively(inp.data.closure_vars) + offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( + all_closure_vars, fbuiltins.FieldOffset, common.Dimension + ) + grid_type = transform_utils._deduce_grid_type( + inp.data.grid_type, offsets_and_dimensions.values() + ) + + gt_callables = transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).values() + # TODO(ricoh): The following calls to .__gt_itir__, which will use whatever + # backend is set for each of these field operators (GTCallables). Instead + # we should use the current toolchain to lower these to ITIR. This will require + # making this step aware of the toolchain it is called by (it can be part of multiple). + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] - itir_program = ProgramLowering.apply( - inp.past_node, - function_definitions=lowered_funcs, - grid_type=grid_type, - to_gtir=self.to_gtir, - ) + itir_program = ProgramLowering.apply( + inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir + ) - if config.DEBUG or "debug" in inp.kwargs: - devtools.debug(itir_program) + if config.DEBUG or inp.data.debug: + devtools.debug(itir_program) - return stages.ProgramCall( - itir_program, inp.args, inp.kwargs | {"column_axis": _column_axis(all_closure_vars)} - ) + return stages.AOTProgram( + data=itir_program, + args=dataclasses.replace(inp.args, column_axis=_column_axis(all_closure_vars)), + ) -class PastToItirFactory(factory.Factory): - class Meta: - model = PastToItir +# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR +def past_to_itir_factory( + cached: bool = True, to_gtir: bool = False +) -> workflow.Workflow[AOT_PRG, stages.AOTProgram]: + wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) + if cached: + wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) + return wf def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]: @@ -207,8 +246,10 @@ def visit_Program( params = self.visit(node.params) + implicit_domain = False if any("domain" not in body_entry.kwargs for body_entry in node.body): params = params + self._gen_size_params_from_program(node) + implicit_domain = True if self.to_gtir: set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] @@ -218,6 +259,7 @@ def visit_Program( params=params, declarations=[], body=set_ats, + implicit_domain=implicit_domain, ) else: closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] @@ -226,6 +268,7 @@ def visit_Program( function_definitions=function_definitions, params=params, closures=closures, + implicit_domain=implicit_domain, ) def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt: diff --git a/src/gt4py/next/ffront/signature.py b/src/gt4py/next/ffront/signature.py new file mode 100644 index 0000000000..9752ceaf32 --- /dev/null +++ b/src/gt4py/next/ffront/signature.py @@ -0,0 +1,153 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +# TODO(ricoh): This overlaps with `canonicalize_arguments`, solutions: +# - merge the two +# - extract the signature gathering functionality from canonicalize_arguments +# and use it to pass the signature through the toolchain so that the +# decorate step can take care of it. Then get rid of all pre-toolchain +# arguments rearranging (including this module) + +from __future__ import annotations + +import functools +import inspect +import types +from typing import Any, Callable + +from gt4py.next.ffront import ( + field_operator_ast as foast, + program_ast as past, + stages as ffront_stages, +) +from gt4py.next.type_system import type_specifications as ts + + +def should_be_positional(param: inspect.Parameter) -> bool: + return (param.kind is inspect.Parameter.POSITIONAL_ONLY) or ( + param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + + +@functools.singledispatch +def make_signature(func: Any) -> inspect.Signature: + """Make a signature for a Python or DSL callable, which suffices for use in 'convert_to_positional'.""" + if isinstance(func, types.FunctionType): + return inspect.signature(func) + raise NotImplementedError(f"'make_signature' not implemented for {type(func)}.") + + +@make_signature.register(foast.ScanOperator) +@make_signature.register(past.Program) +@make_signature.register(foast.FieldOperator) +def signature_from_fieldop(func: foast.FieldOperator) -> inspect.Signature: + if isinstance(func.type, ts.DeferredType): + raise NotImplementedError( + f"'make_signature' not implemented for pre type deduction {type(func)}." + ) + fieldview_signature = func.type.definition + return inspect.Signature( + parameters=[ + inspect.Parameter(name=str(i), kind=inspect.Parameter.POSITIONAL_ONLY) + for i, param in enumerate(fieldview_signature.pos_only_args) + ] + + [ + inspect.Parameter(name=k, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) + for k in fieldview_signature.pos_or_kw_args + ], + ) + + +@make_signature.register(ffront_stages.FieldOperatorDefinition) +def signature_from_fieldop_def(func: ffront_stages.FieldOperatorDefinition) -> inspect.Signature: + signature = make_signature(func.definition) + if func.node_class == foast.ScanOperator: + return inspect.Signature(list(signature.parameters.values())[1:]) + return signature + + +@make_signature.register(ffront_stages.ProgramDefinition) +def signature_from_program_def(func: ffront_stages.ProgramDefinition) -> inspect.Signature: + return make_signature(func.definition) + + +@make_signature.register(ffront_stages.FoastOperatorDefinition) +def signature_from_foast_stage(func: ffront_stages.FoastOperatorDefinition) -> inspect.Signature: + return make_signature(func.foast_node) + + +@make_signature.register +def signature_from_past_stage(func: ffront_stages.PastProgramDefinition) -> inspect.Signature: + return make_signature(func.past_node) + + +def convert_to_positional( + func: Callable + | foast.FieldOperator + | foast.ScanOperator + | ffront_stages.FieldOperatorDefinition + | ffront_stages.FoastOperatorDefinition + | ffront_stages.ProgramDefinition + | ffront_stages.PastProgramDefinition, + *args: Any, + **kwargs: Any, +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """ + Convert arguments given as keyword args to positional ones where possible. + + Raises en error if and only if there are clearly missing positional arguments, + Without awareness of the peculiarities of DSL function signatures. A more + thorough check on whether the signature is fulfilled is expected to happen + later in the toolchain. + + Note that positional-or-keyword arguments with defaults will have their defaults + inserted even if not strictly necessary. This is to reduce complexity and should + be changed if the current behavior is found harmful in some way. + + Examples: + >>> def example(posonly, /, pos_or_key, pk_with_default=42, *, key_only=43): + ... pass + + >>> convert_to_positional(example, 1, pos_or_key=2, key_only=3) + ((1, 2, 42), {'key_only': 3}) + >>> # inserting the default value '42' here could be avoided + >>> # but this is not the current behavior. + """ + signature = make_signature(func) + new_args = list(args) + modified_kwargs = kwargs.copy() + missing = [] + interesting_params = [p for p in signature.parameters.values() if should_be_positional(p)] + + for param in interesting_params[len(args) :]: + if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and param.name in modified_kwargs: + # if keyword allowed, check if was given as kwarg + new_args.append(modified_kwargs.pop(param.name)) + else: + # add default and report as missing if no default + # note: this treats POSITIONAL_ONLY params correctly, as they can not have a default. + new_args.append(param.default) + if param.default is inspect._empty: + missing.append(param.name) + if missing: + raise TypeError(f"Missing positional argument(s): {', '.join(missing)}.") + return tuple(new_args), modified_kwargs diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 7402922ae9..bf3bee4b56 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -21,11 +21,7 @@ from gt4py.eve import extended_typing as xtyping from gt4py.next import common from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils -from gt4py.next.type_system import type_specifications as ts - - -if typing.TYPE_CHECKING: - pass +from gt4py.next.otf import arguments, toolchain OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) @@ -37,6 +33,11 @@ class FieldOperatorDefinition(Generic[OperatorNodeT]): grid_type: Optional[common.GridType] = None node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + debug: bool = False + + +DSL_FOP: typing.TypeAlias = FieldOperatorDefinition +AOT_DSL_FOP: typing.TypeAlias = toolchain.CompilableProgram[DSL_FOP, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) @@ -45,28 +46,22 @@ class FoastOperatorDefinition(Generic[OperatorNodeT]): closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + debug: bool = False -@dataclasses.dataclass(frozen=True) -class FoastWithTypes(Generic[OperatorNodeT]): - foast_op_def: FoastOperatorDefinition[OperatorNodeT] - arg_types: tuple[ts.TypeSpec, ...] - kwarg_types: dict[str, ts.TypeSpec] - closure_vars: dict[str, Any] - - -@dataclasses.dataclass(frozen=True) -class FoastClosure(Generic[OperatorNodeT]): - foast_op_def: FoastOperatorDefinition[OperatorNodeT] - args: tuple[Any, ...] - kwargs: dict[str, Any] - closure_vars: dict[str, Any] +FOP: typing.TypeAlias = FoastOperatorDefinition +AOT_FOP: typing.TypeAlias = toolchain.CompilableProgram[FOP, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) class ProgramDefinition: definition: types.FunctionType grid_type: Optional[common.GridType] = None + debug: bool = False + + +DSL_PRG: typing.TypeAlias = ProgramDefinition +AOT_DSL_PRG: typing.TypeAlias = toolchain.CompilableProgram[DSL_PRG, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) @@ -74,15 +69,11 @@ class PastProgramDefinition: past_node: past.Program closure_vars: dict[str, Any] grid_type: Optional[common.GridType] = None + debug: bool = False -@dataclasses.dataclass(frozen=True) -class PastClosure: - closure_vars: dict[str, Any] - past_node: past.Program - grid_type: Optional[common.GridType] - args: tuple[Any, ...] - kwargs: dict[str, Any] +PRG: typing.TypeAlias = PastProgramDefinition +AOT_PRG: typing.TypeAlias = toolchain.CompilableProgram[PRG, arguments.CompileTimeArgs] def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: @@ -109,17 +100,21 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) -@add_content_to_fingerprint.register(FoastWithTypes) -@add_content_to_fingerprint.register(FoastClosure) -@add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) -@add_content_to_fingerprint.register(PastClosure) -def add_content_to_fingerprint_stages(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: +@add_content_to_fingerprint.register(toolchain.CompilableProgram) +@add_content_to_fingerprint.register(arguments.CompileTimeArgs) +def add_stage_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: add_content_to_fingerprint(obj.__class__, hasher) for field in dataclasses.fields(obj): add_content_to_fingerprint(getattr(obj, field.name), hasher) +def add_jit_args_id_to_fingerprint( + obj: arguments.JITArgs, hasher: xtyping.HashlibAlgorithm +) -> None: + add_content_to_fingerprint(str(id(obj)), hasher) + + @add_content_to_fingerprint.register def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None: sourcedef = source_utils.SourceDefinition.from_function(obj) @@ -153,3 +148,4 @@ def add_foast_located_node_to_fingerprint( ) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), hasher) + add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index fb8888271d..e1b52043ed 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,6 +54,7 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime +from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation @@ -144,36 +145,6 @@ def _dace_descriptor(self) -> NoReturn: # type: ignore[misc] __descriptor__ = _dace_descriptor -@dataclasses.dataclass(frozen=True) -class CompileTimeConnectivity: - max_neighbors: int - has_skip_values: bool - origin_axis: common.Dimension - neighbor_axis: common.Dimension - index_type: type[int] | type[np.int32] | type[np.int64] - - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - raise NotImplementedError( - "A CompileTimeConnectivity instance should not call `mapped_index`." - ) - - @classmethod - def from_connectivity(cls, connectivity: common.Connectivity) -> Self: - return cls( - max_neighbors=connectivity.max_neighbors, - has_skip_values=connectivity.has_skip_values, - origin_axis=connectivity.origin_axis, - neighbor_axis=connectivity.neighbor_axis, - index_type=connectivity.index_type, - ) - - @property - def table(self): - return None - - class StridedNeighborOffsetProvider: def __init__( self, @@ -1764,6 +1735,11 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any): common.UnitRange(0, 0), # empty: indicates column operation, will update later ) + import inspect + + if len(args) < len(inspect.getfullargspec(fun).args): + args = (*args, *arguments.iter_size_args(args)) + with embedded_context.new_context(**context_vars) as ctx: ctx.run(fun, *args) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 163c4b9b01..28adaaddf1 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -197,6 +197,7 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): function_definitions: List[FunctionDefinition] params: List[Sym] closures: List[StencilClosure] + implicit_domain: bool = False _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] @@ -222,6 +223,7 @@ class Program(Node, ValidatedSymbolTableTrait): params: List[Sym] declarations: List[Temporary] body: List[Stmt] + implicit_domain: bool = False _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS] diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py index 90a071ca1d..db0b81a837 100644 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ b/src/gt4py/next/iterator/transforms/fencil_to_program.py @@ -30,6 +30,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: params=node.params, declarations=[], body=self.visit(node.closures), + implicit_domain=node.implicit_domain, ) def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program: @@ -39,4 +40,5 @@ def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) - params=node.params, declarations=node.tmps, body=self.visit(node.fencil.closures), + implicit_domain=node.fencil.implicit_domain, ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 9bbbaa5c8f..f00a5e9f70 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -353,6 +353,7 @@ def always_extract_heuristics(_: ir.StencilClosure) -> Callable[[ir.Expr], bool] params=node.params + [im.sym(name) for name, _ in tmps] + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), location=node.location, + implicit_domain=node.implicit_domain, ), params=node.params, tmps=[ir.Temporary(id=name, dtype=type_) for name, type_ in tmps], @@ -626,6 +627,7 @@ def update_domains( params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again closures=list(reversed(closures)), location=node.fencil.location, + implicit_domain=node.fencil.implicit_domain, ), params=node.params, tmps=node.tmps, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py new file mode 100644 index 0000000000..2bd6c2ebe9 --- /dev/null +++ b/src/gt4py/next/otf/arguments.py @@ -0,0 +1,244 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +import dataclasses +import typing +from typing import Any, Iterable, Iterator, Optional + +import numpy as np +from typing_extensions import Self + +from gt4py.next import common +from gt4py.next.otf import toolchain, workflow +from gt4py.next.type_system import type_specifications as ts, type_translation + + +DATA_T = typing.TypeVar("DATA_T") + + +@dataclasses.dataclass(frozen=True) +class JITArgs: + """Concrete (runtime) arguments to a GTX program in a format that can be passed into the toolchain.""" + + args: tuple[Any, ...] + kwargs: dict[str, Any] + + @classmethod + def from_signature(cls, *args: Any, **kwargs: Any) -> Self: + return cls(args=args, kwargs=kwargs) + + +@dataclasses.dataclass(frozen=True) +class CompileTimeArg: + """Standin (at compile-time) for a GTX program argument, retaining only the type information.""" + + gt_type: ts.TypeSpec + + def __gt_type__(self) -> ts.TypeSpec: + return self.gt_type + + @classmethod + def from_concrete(cls, value: Any) -> Self | tuple[Self | tuple, ...]: + gt_type = type_translation.from_value(value) + match gt_type: + case ts.TupleType(): + return tuple(cls.from_concrete(element) for element in value) + case _: + return cls(gt_type) + + +@dataclasses.dataclass(frozen=True) +class CompileTimeConnectivity(common.Connectivity): + """Compile-time standin for a GTX connectivity, retaining everything except the connectivity tables.""" + + max_neighbors: int + has_skip_values: bool + origin_axis: common.Dimension + neighbor_axis: common.Dimension + index_type: type[int] | type[np.int32] | type[np.int64] + + def mapped_index( + self, cur_index: int | np.integer, neigh_index: int | np.integer + ) -> Optional[int | np.integer]: + raise NotImplementedError( + "A CompileTimeConnectivity instance should not call `mapped_index`." + ) + + @classmethod + def from_connectivity(cls, connectivity: common.Connectivity) -> Self: + return cls( + max_neighbors=connectivity.max_neighbors, + has_skip_values=connectivity.has_skip_values, + origin_axis=connectivity.origin_axis, + neighbor_axis=connectivity.neighbor_axis, + index_type=connectivity.index_type, + ) + + @property + def table(self) -> None: + return None + + +@dataclasses.dataclass(frozen=True) +class CompileTimeArgs: + """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" + + args: tuple[CompileTimeArg | tuple, ...] + kwargs: dict[str, CompileTimeArg | tuple] + offset_provider: dict[str, common.Connectivity | common.Dimension] + column_axis: Optional[common.Dimension] + + @classmethod + def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: + """Convert concrete GTX program arguments into their compile-time counterparts.""" + compile_args = tuple(CompileTimeArg.from_concrete(arg) for arg in args) + kwargs_copy = kwargs.copy() + offset_provider = kwargs_copy.pop("offset_provider", {}) + return cls( + args=compile_args, + offset_provider=offset_provider, # TODO(ricoh): replace with the line below once the temporaries pass is AOT-ready. If unsure, just try it and run the tests. + # offset_provider={k: connectivity_or_dimension(v) for k, v in offset_provider.items()}, # noqa: ERA001 [commented-out-code] + column_axis=kwargs_copy.pop("column_axis", None), + kwargs={ + k: CompileTimeArg.from_concrete(v) for k, v in kwargs_copy.items() if v is not None + }, + ) + + @classmethod + def from_concrete(cls, *args: Any, **kwargs: Any) -> Self: + """Convert concrete GTX program arguments to compile-time, adding (compile-time) dimension size arguments.""" + no_size = cls.from_concrete_no_size(*args, **kwargs) + return cls( + args=(*no_size.args, *iter_size_compile_args(no_size.args)), + offset_provider=no_size.offset_provider, + column_axis=no_size.column_axis, + kwargs=no_size.kwargs, + ) + + @classmethod + def empty(cls) -> Self: + return cls(tuple(), {}, {}, None) + + +def jit_to_aot_args( + inp: JITArgs, +) -> CompileTimeArgs: + return CompileTimeArgs.from_concrete_no_size(*inp.args, **inp.kwargs) + + +def adapted_jit_to_aot_args_factory() -> ( + workflow.Workflow[ + toolchain.CompilableProgram[DATA_T, JITArgs], + toolchain.CompilableProgram[DATA_T, CompileTimeArgs], + ] +): + """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" + return toolchain.ArgsOnlyAdapter(jit_to_aot_args) + + +def connectivity_or_dimension( + some_offset_provider: common.Connectivity | common.Dimension, +) -> CompileTimeConnectivity | common.Dimension: + match some_offset_provider: + case common.Dimension(): + return some_offset_provider + case common.Connectivity(): + return CompileTimeConnectivity.from_connectivity(some_offset_provider) + case _: + raise ValueError + + +def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: + for element in tuple_arg: + match element: + case tuple(): + found = find_first_field(element) + if found: + return found + case common.Field(): + return element + case _: + pass + return None + + +def find_first_field_type( + tuple_arg: tuple[Any, ...], +) -> Optional[CompileTimeArg]: + for element in tuple_arg: + match type_translation.from_value(element): + case ts.TupleType(): + found = find_first_field_type(element) + if found: + return found + case ts.FieldType(): + return element + case _: + pass + return None + + +def iter_size_args(args: tuple[Any, ...], inside_tuple: bool = False) -> Iterator[int]: + """ + Yield the size of each field argument in each dimension. + + This can be used to generate domain size arguments for FieldView Programs that use an implicit domain. + """ + print(f"iter_size_args: matching args {tuple(type(arg) for arg in args)}") + for arg in args: + print(f"iter_size_args: matching arg {arg}") + match arg: + case tuple(): + # we only need the first field, because all fields in a tuple must have the same dims and sizes + first_field = find_first_field(arg) + if first_field: + yield from iter_size_args((first_field,)) + case common.Field(): + print(f"iter_size_args: yielding from {arg.ndarray.shape}") + yield from arg.ndarray.shape + case _: + pass + + +def iter_size_compile_args( + args: Iterable[CompileTimeArg | tuple], +) -> Iterator[CompileTimeArg | tuple]: + """ + Yield a compile-time size argument for every compile-time field argument in each dimension. + + This can be used inside transformation workflows to generate compile-time domain size arguments for FieldView Programs that use an implicit domain. + """ + for arg in args: + match argt := type_translation.from_value(arg): + case ts.TupleType(): + # we only need the first field, because all fields in a tuple must have the same dims and sizes + first_field = find_first_field_type(typing.cast(tuple, arg)) + if first_field: + yield from iter_size_compile_args((first_field,)) + case ts.FieldType(): + yield from [ + CompileTimeArg(ts.ScalarType(kind=ts.ScalarKind.INT32)) for dim in argt.dims + ] + case _: + pass diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index 5199e68aed..debf8217cf 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -64,6 +64,7 @@ def __call__( cmake_flags=self.cmake_extra_flags or [], language=source.program_source.language, language_settings=source.program_source.language_settings, + implicit_domain=source.program_source.implicit_domain, ) if self.renew_compiledb or not ( @@ -213,6 +214,7 @@ def _cc_prototype_program_source( cmake_flags: list[str], language: type[SrcL], language_settings: languages.LanguageWithHeaderFilesSettings, + implicit_domain: bool, ) -> stages.ProgramSource: name = _cc_prototype_program_name(deps, build_type.value, cmake_flags) return stages.ProgramSource( @@ -221,6 +223,7 @@ def _cc_prototype_program_source( library_deps=deps, language=language, language_settings=language_settings, + implicit_domain=implicit_domain, ) diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index a8e62ed277..8907cd81c0 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -38,7 +38,8 @@ def _serialize_source(source: stages.ProgramSource) -> str: name: {source.entry_point.name} params: {', '.join(parameters)} deps: {', '.join(dependencies)} - src: {source.source_code}\ + src: {source.source_code} + implicit_domain: {source.implicit_domain} """ diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index b91a0a3019..9553faa34c 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -64,7 +64,7 @@ class Compiler( def __call__( self, inp: stages.CompilableSource[SourceLanguageType, LanguageSettingsType, languages.Python], - ) -> stages.CompiledProgram: + ) -> stages.ExtendedCompiledProgram: src_dir = cache.get_cache_folder(inp, self.cache_lifetime) data = build_data.read_data(src_dir) @@ -79,10 +79,17 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - return getattr( + compiled_prog = getattr( importer.import_from_path(src_dir / new_data.module), new_data.entry_point_name ) + @dataclasses.dataclass(frozen=True) + class Wrapper(stages.ExtendedCompiledProgram): + implicit_domain: bool = inp.program_source.implicit_domain + __call__: stages.CompiledProgram = compiled_prog + + return Wrapper() + class CompilerFactory(factory.Factory): class Meta: diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 78ee8fede9..785a53fa40 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -9,13 +9,15 @@ from __future__ import annotations import dataclasses -from typing import Any, Generic, Optional, Protocol, TypeVar +from typing import Any, Generic, Optional, Protocol, TypeAlias, TypeVar from gt4py.next.iterator import ir as itir -from gt4py.next.otf import languages +from gt4py.next.otf import arguments, languages, toolchain from gt4py.next.otf.binding import interface +PrgT = TypeVar("PrgT") +ArgT = TypeVar("ArgT") SrcL = TypeVar("SrcL", bound=languages.LanguageTag) TgtL = TypeVar("TgtL", bound=languages.LanguageTag) SettingT = TypeVar("SettingT", bound=languages.LanguageSettings) @@ -24,13 +26,9 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -@dataclasses.dataclass(frozen=True) -class ProgramCall: - """ITIR/GTIR representation of a program together with arguments to be passed to it.""" - - program: itir.FencilDefinition | itir.Program - args: tuple[Any, ...] - kwargs: dict[str, Any] +AOTProgram: TypeAlias = toolchain.CompilableProgram[ + itir.FencilDefinition | itir.Program, arguments.CompileTimeArgs +] @dataclasses.dataclass(frozen=True) @@ -49,6 +47,7 @@ class ProgramSource(Generic[SrcL, SettingT]): library_deps: tuple[interface.LibraryDependency, ...] language: type[SrcL] language_settings: SettingT + implicit_domain: bool def __post_init__(self) -> None: if not isinstance(self.language_settings, self.language.settings_class): @@ -110,6 +109,12 @@ class CompiledProgram(Protocol): def __call__(self, *args: Any, **kwargs: Any) -> None: ... +class ExtendedCompiledProgram(CompiledProgram): + """Executable python representation of a program with extra info.""" + + implicit_domain: bool + + def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]: """ Filter out multiple occurrences of the same ``interface.LibraryDependency``. diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/step_types.py index 20df113daf..82381234d5 100644 --- a/src/gt4py/next/otf/step_types.py +++ b/src/gt4py/next/otf/step_types.py @@ -22,7 +22,7 @@ class TranslationStep( - workflow.ReplaceEnabledWorkflowMixin[stages.ProgramCall, stages.ProgramSource[SrcL, LS]], + workflow.ReplaceEnabledWorkflowMixin[stages.AOTProgram, stages.ProgramSource[SrcL, LS]], Protocol[SrcL, LS], ): """Translate a GT4Py program to source code (ProgramCall -> ProgramSource).""" diff --git a/src/gt4py/next/otf/toolchain.py b/src/gt4py/next/otf/toolchain.py new file mode 100644 index 0000000000..4b7bd0b7ef --- /dev/null +++ b/src/gt4py/next/otf/toolchain.py @@ -0,0 +1,66 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import dataclasses +import typing +from typing import Generic + +from gt4py.next.otf import workflow + + +PrgT = typing.TypeVar("PrgT") +ArgT = typing.TypeVar("ArgT") +StartT = typing.TypeVar("StartT") +EndT = typing.TypeVar("EndT") + + +@dataclasses.dataclass +class CompilableProgram(Generic[PrgT, ArgT]): + data: PrgT + args: ArgT + + +@dataclasses.dataclass(frozen=True) +class DataOnlyAdapter( + workflow.ChainableWorkflowMixin, + workflow.ReplaceEnabledWorkflowMixin, + workflow.Workflow[CompilableProgram[StartT, ArgT], CompilableProgram[EndT, ArgT]], + Generic[ArgT, StartT, EndT], +): + step: workflow.Workflow[StartT, EndT] + + def __call__(self, inp: CompilableProgram[StartT, ArgT]) -> CompilableProgram[EndT, ArgT]: + return CompilableProgram(data=self.step(inp.data), args=inp.args) + + +@dataclasses.dataclass(frozen=True) +class ArgsOnlyAdapter( + workflow.ChainableWorkflowMixin, + workflow.ReplaceEnabledWorkflowMixin, + workflow.Workflow[CompilableProgram[PrgT, StartT], CompilableProgram[PrgT, EndT]], + Generic[PrgT, StartT, EndT], +): + step: workflow.Workflow[StartT, EndT] + + def __call__(self, inp: CompilableProgram[PrgT, StartT]) -> CompilableProgram[PrgT, EndT]: + return CompilableProgram(data=inp.data, args=self.step(inp.args)) + + +@dataclasses.dataclass(frozen=True) +class StripArgsAdapter( + workflow.ChainableWorkflowMixin, + workflow.ReplaceEnabledWorkflowMixin, + workflow.Workflow[CompilableProgram[StartT, ArgT], EndT], + Generic[ArgT, StartT, EndT], +): + step: workflow.Workflow[StartT, EndT] + + def __call__(self, inp: CompilableProgram[StartT, ArgT]) -> EndT: + return self.step(inp.data) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 06057e5f5e..a63801c97e 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -8,6 +8,7 @@ from __future__ import annotations +import abc import dataclasses import functools import typing @@ -23,6 +24,8 @@ NewEndT = TypeVar("NewEndT") IntermediateT = TypeVar("IntermediateT") HashT = TypeVar("HashT") +DataT = TypeVar("DataT") +ArgT = TypeVar("ArgT") def make_step(function: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT, EndT]: @@ -152,6 +155,23 @@ def step_order(self) -> list[str]: return step_names +@dataclasses.dataclass(frozen=True) +class MultiWorkflow( + ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT] +): + """A flexible workflow, where the sequence of steps depends on the input type.""" + + def __call__(self, inp: StartT) -> EndT: + step_result: Any = inp + for step_name in self.step_order(inp): + step_result = getattr(self, step_name)(step_result) + return step_result + + @abc.abstractmethod + def step_order(self, inp: StartT) -> list[str]: + pass + + @dataclasses.dataclass(frozen=True) class StepSequence(ChainableWorkflowMixin[StartT, EndT]): """ @@ -259,26 +279,3 @@ def __call__(self, inp: StartT) -> EndT: def skip_condition(self, inp: StartT) -> bool: raise NotImplementedError() - - -@dataclasses.dataclass -class InputWithArgs(Generic[StartT]): - data: StartT - args: tuple[Any] - kwargs: dict[str, Any] - - -@dataclasses.dataclass(frozen=True) -class NamedStepSequenceWithArgs(NamedStepSequence[InputWithArgs[StartT], EndT]): - def __call__(self, inp: InputWithArgs[StartT]) -> EndT: - args = inp.args - kwargs = inp.kwargs - step_result: Any = inp.data - fields = {f.name: f for f in dataclasses.fields(self)} - for step_name in self.step_order: - step = getattr(self, step_name) - if fields[step_name].metadata.get("takes_args", False): - step_result = step(InputWithArgs(step_result, args, kwargs)) - else: - step_result = step(step_result) - return step_result diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 7f066df1dd..ac5325aade 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -39,11 +39,14 @@ def get_param_description(name: str, obj: Any) -> interface.Parameter: @dataclasses.dataclass(frozen=True) class GTFNTranslationStep( + workflow.ReplaceEnabledWorkflowMixin[ + stages.AOTProgram, + stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], + ], workflow.ChainableWorkflowMixin[ - stages.ProgramCall, + stages.AOTProgram, stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], ], - step_types.TranslationStep[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings], ): language_settings: Optional[languages.LanguageWithHeaderFilesSettings] = None # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 @@ -202,22 +205,22 @@ def generate_stencil_source( return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( - self, inp: stages.ProgramCall + self, inp: stages.AOTProgram ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program = inp.program + program: itir.FencilDefinition | itir.Program = inp.data assert isinstance(program, itir.FencilDefinition) # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) regular_parameters, regular_args_expr = self._process_regular_arguments( - program, inp.args, inp.kwargs["offset_provider"] + program, inp.args.args, inp.args.offset_provider ) # handle connectivity parameters and arguments (i.e. what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.kwargs["offset_provider"] + inp.args.offset_provider ) # combine into a format that is aligned with what the backend expects @@ -233,8 +236,8 @@ def __call__( decl_src = cpp_interface.render_function_declaration(function, body=decl_body) stencil_src = self.generate_stencil_source( program, - inp.kwargs["offset_provider"], - inp.kwargs.get("column_axis", None), + inp.args.offset_provider, + inp.args.column_axis, ) source_code = interface.format_source( self._language_settings(), @@ -254,6 +257,7 @@ def __call__( source_code=source_code, language=self._language(), language_settings=self._language_settings(), + implicit_domain=inp.data.implicit_domain, ) return module @@ -312,8 +316,8 @@ class Meta: model = GTFNTranslationStep -translate_program_cpu: Final[step_types.TranslationStep] = GTFNTranslationStep() +translate_program_cpu: Final[step_types.TranslationStep] = GTFNTranslationStepFactory() -translate_program_gpu: Final[step_types.TranslationStep] = GTFNTranslationStep( +translate_program_gpu: Final[step_types.TranslationStep] = GTFNTranslationStepFactory( device_type=core_defs.DeviceType.CUDA ) diff --git a/src/gt4py/next/program_processors/modular_executor.py b/src/gt4py/next/program_processors/modular_executor.py index 206a012dd0..b5048c93ef 100644 --- a/src/gt4py/next/program_processors/modular_executor.py +++ b/src/gt4py/next/program_processors/modular_executor.py @@ -13,18 +13,21 @@ import gt4py.next.program_processors.processor_interface as ppi from gt4py.next.iterator import ir as itir -from gt4py.next.otf import stages, workflow +from gt4py.next.otf import arguments, stages, workflow @dataclasses.dataclass(frozen=True) class ModularExecutor(ppi.ProgramExecutor): - otf_workflow: workflow.Workflow[stages.ProgramCall, stages.CompiledProgram] + otf_workflow: workflow.Workflow[stages.AOTProgram, stages.CompiledProgram] name: Optional[str] = None def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: - self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( - *args, offset_provider=kwargs["offset_provider"] - ) + self.otf_workflow( + stages.AOTProgram( + data=program, + args=arguments.CompileTimeArgs.from_concrete(*args, **kwargs), + ) + )(*args, offset_provider=kwargs["offset_provider"]) @property def __name__(self) -> str: diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index e1b06d8729..1ac62765fa 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,9 +8,8 @@ import factory -from gt4py.next import allocators as next_allocators, backend as next_backend -from gt4py.next.ffront import foast_to_gtir, past_to_itir, stages as ffront_stages -from gt4py.next.otf import workflow +from gt4py.next import allocators as next_allocators, backend +from gt4py.next.ffront import foast_to_gtir, past_to_itir from gt4py.next.program_processors import modular_executor from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow @@ -34,6 +33,8 @@ class Params: ) use_field_canonical_representation: bool = False + transforms = backend.DEFAULT_TRANSFORMS + run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False) @@ -44,18 +45,13 @@ class Params: itir_cpu = run_dace_cpu itir_gpu = run_dace_gpu -gtir_cpu = next_backend.Backend( +gtir_cpu = backend.Backend( executor=modular_executor.ModularExecutor( otf_workflow=dace_fieldview_workflow.DaCeWorkflowFactory(), name="dace.gtir.cpu" ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms_fop=next_backend.FieldopTransformWorkflow( - past_to_itir=past_to_itir.PastToItirFactory(to_gtir=True), - foast_to_itir=workflow.CachedStep( - step=foast_to_gtir.foast_to_gtir, hash_function=ffront_stages.fingerprint_stage - ), - ), - transforms_prog=next_backend.ProgramTransformWorkflow( - past_to_itir=past_to_itir.PastToItirFactory(to_gtir=True) + transforms=backend.Transforms( + past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), ), ) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 9b30fe7052..dbe2b70ff8 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -18,7 +18,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config -from gt4py.next.otf import languages, stages, step_types, workflow +from gt4py.next.otf import arguments, languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.runners.dace_common import dace_backend @@ -111,8 +111,14 @@ def convert_args( on_gpu = True if device == core_defs.DeviceType.CUDA else False def decorated_program( - *args: Any, offset_provider: dict[str, common.Connectivity | common.Dimension] + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, ) -> Any: + if out is not None: + args = (*args, out) + if len(sdfg.arg_names) > len(args): + args = (*args, *arguments.iter_size_args(args)) if sdfg_program._lastargs: # The scalar arguments should be replaced with the actual value; for field arguments, # the data pointer should remain the same otherwise fast-call cannot be used and diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index d5c538ff87..af36c0cbc5 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -29,7 +29,7 @@ @dataclasses.dataclass(frozen=True) class DaCeTranslator( workflow.ChainableWorkflowMixin[ - stages.ProgramCall, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] + stages.AOTProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] ], step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): @@ -52,21 +52,21 @@ def generate_sdfg( return gtir_to_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider) def __call__( - self, inp: stages.ProgramCall + self, inp: stages.AOTProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" - program = inp.program + program: itir.FencilDefinition | itir.Program = inp.data assert isinstance(program, itir.Program) sdfg = self.generate_sdfg( program, - inp.kwargs["offset_provider"], - inp.kwargs.get("column_axis", None), + inp.args.offset_provider, + inp.args.column_axis, ) param_types = tuple( interface.Parameter(param, tt.from_value(arg)) - for param, arg in zip(sdfg.arg_names, inp.args) + for param, arg in zip(sdfg.arg_names, inp.args.args) ) module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( @@ -76,6 +76,7 @@ def __call__( library_deps=tuple(), language=languages.SDFG, language_settings=self._language_settings(), + implicit_domain=inp.data.implicit_domain, ) ) return module diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 784c18464a..034fd2e60b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -77,7 +77,8 @@ def _get_scan_dim( - scan_dim_index: domain index of the scan dimension - scan_dim_dtype: data type along the scan dimension """ - output_type = cast(ts.FieldType, storage_types[output.id]) + output_type = storage_types[output.id] + assert isinstance(output_type, ts.FieldType) sorted_dims = [ dim for _, dim in ( @@ -235,7 +236,8 @@ def add_storage_for_temporaries( return tmp_symbols def create_memlet_at(self, field_name: str, index: dict[str, str]): - field_type = cast(ts.FieldType, self.storage_types[field_name]) + field_type = self.storage_types[field_name] + assert isinstance(field_type, ts.FieldType) if self.use_field_canonical_representation: field_index = [ index[dim.value] for _, dim in dace_common_util.get_sorted_dims(field_type.dims) @@ -433,6 +435,18 @@ def visit_StencilClosure( program_arg_syms[name] = value else: program_arg_syms[name] = SymbolExpr(name, dtype) + else: + assert isinstance(type_, ts.FieldType) + # make shape symbols (corresponding to field size) available as arguments to domain visitor + if name in input_names or name in output_names: + field_symbols = [ + val + for val in closure_sdfg.arrays[name].shape + if isinstance(val, dace.symbol) and str(val) not in input_names + ] + for sym in field_symbols: + sym_name = str(sym) + program_arg_syms[sym_name] = SymbolExpr(sym, sym.dtype) closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) closure_domain = self._visit_domain(node.domain, closure_ctx) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 8c96d9dede..72cc15a46e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -31,7 +31,7 @@ @dataclasses.dataclass(frozen=True) class DaCeTranslator( workflow.ChainableWorkflowMixin[ - stages.ProgramCall, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] + stages.AOTProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] ], step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): @@ -74,23 +74,23 @@ def generate_sdfg( ) def __call__( - self, inp: stages.ProgramCall + self, inp: stages.AOTProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the ITIR definition.""" - program = inp.program + program: itir.FencilDefinition | itir.Program = inp.data assert isinstance(program, itir.FencilDefinition) - arg_types = [tt.from_value(arg) for arg in inp.args] + arg_types = [tt.from_value(arg) for arg in inp.args.args] sdfg = self.generate_sdfg( program, arg_types, - inp.kwargs["offset_provider"], - inp.kwargs.get("column_axis", None), + inp.args.offset_provider, + inp.args.column_axis, ) param_types = tuple( interface.Parameter(param, tt.from_value(arg)) - for param, arg in zip(sdfg.arg_names, inp.args) + for param, arg in zip(sdfg.arg_names, inp.args.args) ) module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( @@ -100,6 +100,7 @@ def __call__( library_deps=tuple(), language=languages.SDFG, language_settings=self._language_settings(), + implicit_domain=inp.data.implicit_domain, ) ) return module diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index b400e20736..f03aaabf8f 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -13,6 +13,7 @@ backend = next_backend.Backend( + transforms=next_backend.DEFAULT_TRANSFORMS, executor=roundtrip.RoundtripExecutorFactory(dispatch_backend=roundtrip.default.executor), allocator=roundtrip.default.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 95e72983a8..a2badd5191 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -19,7 +19,7 @@ from gt4py.next import backend, common, config from gt4py.next.iterator import transforms from gt4py.next.iterator.transforms import global_tmps -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -41,14 +41,23 @@ def convert_arg(arg: Any) -> Any: def convert_args( - inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU + inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( - *args: Any, offset_provider: dict[str, common.Connectivity | common.Dimension] + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, ) -> None: + if out is not None: + args = (*args, out) converted_args = [convert_arg(arg) for arg in args] conn_args = extract_connectivity_args(offset_provider, device) - return inp(*converted_args, *conn_args) + # generate implicit domain size arguments only if necessary, using `iter_size_args()` + return inp( + *converted_args, + *(arguments.iter_size_args(args) if inp.implicit_domain else ()), + *conn_args, + ) return decorated_program @@ -92,20 +101,20 @@ def extract_connectivity_args( return args -def compilation_hash(otf_closure: stages.ProgramCall) -> int: +def compilation_hash(otf_closure: stages.AOTProgram) -> int: """Given closure compute a hash uniquely determining if we need to recompile.""" - offset_provider = otf_closure.kwargs["offset_provider"] + offset_provider = otf_closure.args.offset_provider return hash( ( - otf_closure.program, + otf_closure.data, # As the frontend types contain lists they are not hashable. As a workaround we just # use content_hash here. - content_hash(tuple(from_value(arg) for arg in otf_closure.args)), + content_hash(tuple(from_value(arg) for arg in otf_closure.args.args)), # Directly using the `id` of the offset provider is not possible as the decorator adds # the implicitly defined ones (i.e. to allow the `TDim + 1` syntax) resulting in a # different `id` every time. Instead use the `id` of each individual offset provider. tuple((k, id(v)) for (k, v) in offset_provider.items()) if offset_provider else None, - otf_closure.kwargs.get("column_axis", None), + otf_closure.args.column_axis, ) ) @@ -124,7 +133,8 @@ class Params: ) translation = factory.SubFactory( - gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type") + gtfn_module.GTFNTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), ) bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source @@ -180,6 +190,7 @@ class Params: lambda o: modular_executor.ModularExecutor(otf_workflow=o.otf_workflow, name=o.name) ) allocator = next_allocators.StandardCPUFieldBufferAllocator() + transforms = backend.DEFAULT_TRANSFORMS run_gtfn = GTFNBackendFactory() diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 73e462e20e..8337a2b44a 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -13,6 +13,7 @@ import pathlib import tempfile import textwrap +import typing from collections.abc import Callable, Iterable from typing import Any, Optional @@ -21,9 +22,9 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config -from gt4py.next.ffront import foast_to_gtir, past_to_itir, stages as ffront_stages -from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms -from gt4py.next.otf import stages, workflow +from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next.iterator import ir as itir, transforms as itir_transforms +from gt4py.next.otf import arguments, stages, workflow from gt4py.next.program_processors import modular_executor, processor_interface as ppi from gt4py.next.type_system import type_specifications as ts @@ -92,8 +93,8 @@ def fencil_generator( debug: bool, lift_mode: itir_transforms.LiftMode, use_embedded: bool, - offset_provider: dict[str, embedded.NeighborTableOffsetProvider], -) -> Callable: + offset_provider: dict[str, common.Connectivity | common.Dimension], +) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -112,7 +113,7 @@ def fencil_generator( if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") - return _FENCIL_CACHE[cache_key] + return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key]) ir = itir_transforms.apply_common_transforms( ir, lift_mode=lift_mode, offset_provider=offset_provider @@ -165,7 +166,6 @@ def fencil_generator( source_file.write("\n".join(axis_literals)) source_file.write("\n") source_file.write(program) - try: spec = importlib.util.spec_from_file_location("module.name", source_file_name) mod = importlib.util.module_from_spec(spec) # type: ignore @@ -180,51 +180,48 @@ def fencil_generator( _FENCIL_CACHE[cache_key] = fencil - return fencil - - -@ppi.program_executor # type: ignore[arg-type] -def execute_roundtrip( - ir: itir.Node, - *args: Any, - column_axis: Optional[common.Dimension] = None, - offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - debug: Optional[bool] = None, - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, - dispatch_backend: Optional[ppi.ProgramExecutor] = None, -) -> None: - debug = debug if debug is not None else config.DEBUG - fencil = fencil_generator( - ir, - offset_provider=offset_provider, - debug=debug, - lift_mode=lift_mode, - use_embedded=dispatch_backend is None, - ) - - new_kwargs: dict[str, Any] = {"offset_provider": offset_provider, "column_axis": column_axis} - if dispatch_backend: - new_kwargs["backend"] = dispatch_backend - - return fencil(*args, **new_kwargs) + return typing.cast(stages.CompiledProgram, fencil) @dataclasses.dataclass(frozen=True) -class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): +class Roundtrip(workflow.Workflow[stages.AOTProgram, stages.CompiledProgram]): debug: Optional[bool] = None lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True + dispatch_backend: Optional[ppi.ProgramExecutor] = None - def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: + def __call__(self, inp: stages.AOTProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug - return fencil_generator( - inp.program, - offset_provider=inp.kwargs.get("offset_provider", None), + + fencil = fencil_generator( + inp.data, + offset_provider=inp.args.offset_provider, debug=debug, lift_mode=self.lift_mode, use_embedded=self.use_embedded, ) + def decorated_fencil( + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, + column_axis: Optional[common.Dimension] = None, + **kwargs: Any, + ) -> None: + if out is not None: + args = (*args, out) + if not column_axis: + column_axis = inp.args.column_axis + fencil( + *args, + offset_provider=offset_provider, + backend=self.dispatch_backend, + column_axis=inp.args.column_axis, + **kwargs, + ) + + return decorated_fencil + class RoundtripFactory(factory.Factory): class Meta: @@ -236,10 +233,16 @@ class RoundtripExecutor(modular_executor.ModularExecutor): dispatch_backend: Optional[ppi.ProgramExecutor] = None def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: - kwargs["backend"] = self.dispatch_backend - self.otf_workflow(stages.ProgramCall(program=program, args=args, kwargs=kwargs))( - *args, **kwargs - ) + argspec = arguments.CompileTimeArgs.from_concrete_no_size(*args, **kwargs) + self.otf_workflow( + stages.AOTProgram( + data=program, + args=dataclasses.replace( + argspec, + kwargs=argspec.kwargs, + ), + ) + )(*args, **kwargs) class RoundtripExecutorFactory(factory.Factory): @@ -263,22 +266,21 @@ class Params: ) default = next_backend.Backend( - executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() + executor=executor, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + transforms=next_backend.DEFAULT_TRANSFORMS, ) with_temporaries = next_backend.Backend( - executor=executor_with_temporaries, allocator=next_allocators.StandardCPUFieldBufferAllocator() + executor=executor_with_temporaries, + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + transforms=next_backend.DEFAULT_TRANSFORMS, ) gtir = next_backend.Backend( executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms_fop=next_backend.FieldopTransformWorkflow( - past_to_itir=past_to_itir.PastToItirFactory(to_gtir=True), - foast_to_itir=workflow.CachedStep( - step=foast_to_gtir.foast_to_gtir, hash_function=ffront_stages.fingerprint_stage - ), - ), - transforms_prog=next_backend.ProgramTransformWorkflow( - past_to_itir=past_to_itir.PastToItirFactory(to_gtir=True) + transforms=next_backend.Transforms( + past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), ), ) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 62a6781316..d8054ad3da 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -25,6 +25,16 @@ from gt4py.next.type_system import type_info, type_specifications as ts +class SizeArg: + def __gt_type__(self) -> ts.ScalarType: + return ts.ScalarType(kind=ts.ScalarKind.INT32) + + def __eq__(self, other: object) -> bool: + if isinstance(other, SizeArg): + return True + return False + + def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: # make int & float precision platform independent. dt: np.dtype diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 3d797acb7c..1bc5899645 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -13,6 +13,7 @@ import gt4py.next as gtx from gt4py.next import backend as next_backend +from gt4py.next.otf import arguments from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case @@ -56,7 +57,7 @@ def test_sdfgConvertible_laplap(cartesian_case): connectivities = {} # Dict of NeighborOffsetProviders, where self.table = None for k, v in cartesian_case.offset_provider.items(): if hasattr(v, "table"): - connectivities[k] = gtx.CompileTimeConnectivity( + connectivities[k] = arguments.CompileTimeConnectivity( v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype ) else: @@ -130,7 +131,7 @@ def sdfg( xp.asarray([[0, 1], [1, 2], [2, 0]]), Edge, Vertex, 2, False ) connectivities = {} - connectivities["E2V"] = gtx.CompileTimeConnectivity( + connectivities["E2V"] = arguments.CompileTimeConnectivity( e2v.max_neighbors, e2v.has_skip_values, e2v.origin_axis, e2v.neighbor_axis, e2v.table.dtype ) offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 55332c0a7c..b04c1fde12 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -32,7 +32,7 @@ def __call__(self, program, *args, **kwargs) -> None: raise ValueError("No backend selected! Backend selection is mandatory in tests.") -no_backend = NoBackend(executor=no_exec, transforms_prog=None, allocator=None) +no_backend = NoBackend(executor=no_exec, transforms=None, allocator=None) @pytest.fixture( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index d847c947d9..061af6a132 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -7,9 +7,13 @@ # SPDX-License-Identifier: BSD-3-Clause # TODO(dropd): Remove as soon as `gt4py.next.ffront.decorator` is type checked. +import numpy as np +import pytest + from gt4py import next as gtx from gt4py.next.iterator import ir as itir +from next_tests import definitions as test_definitions from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -28,3 +32,35 @@ def testee(a: cases.IField, out: cases.IField): assert isinstance(testee.itir, itir.FencilDefinition) assert isinstance(testee.with_backend(cartesian_case.executor).itir, itir.FencilDefinition) + + +def test_frozen(cartesian_case): + if cartesian_case.executor is None: + pytest.xfail("Frozen Program with embedded execution is not possible.") + + @gtx.field_operator + def testee_op(a: cases.IField) -> cases.IField: + return a + + @gtx.program(backend=cartesian_case.executor, frozen=True) + def testee(a: cases.IField, out: cases.IField): + testee_op(a, out=out) + + assert isinstance(testee, gtx.ffront.decorator.FrozenProgram) + + # first run should JIT compile + args_1, kwargs_1 = cases.get_default_data(cartesian_case, testee) + testee(*args_1, offset_provider=cartesian_case.offset_provider, **kwargs_1) + + # _compiled_program should be set after JIT compiling + args_2, kwargs_2 = cases.get_default_data(cartesian_case, testee) + testee._compiled_program(*args_2, offset_provider=cartesian_case.offset_provider, **kwargs_2) + + # and give expected results + assert np.allclose(kwargs_2["out"].ndarray, args_2[0].ndarray) + + # with_backend returns a new instance, which is frozen but not compiled yet + assert testee.with_backend(cartesian_case.executor)._compiled_program is None + + # with_grid_type returns a new instance, which is frozen but not compiled yet + assert testee.with_grid_type(cartesian_case.grid_type)._compiled_program is None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 07e40675d6..df3b1ca106 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -33,6 +33,7 @@ @pytest.fixture def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( + transforms=backend.DEFAULT_TRANSFORMS, executor=modular_executor.ModularExecutor( name="run_gtfn_with_temporaries_and_sizes", otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index fc07559923..02104038dd 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -16,6 +16,7 @@ import gt4py.next as gtx import gt4py.next.program_processors.processor_interface as ppi +from gt4py.next import common from gt4py.next.iterator import builtins as it_builtins from gt4py.next.iterator.builtins import ( and_, diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index 2421f7708a..c1503f3e7c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -6,13 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import dataclasses - import pytest from gt4py import next as gtx from gt4py.next.ffront import stages -from gt4py.next.otf import workflow +from gt4py.next.otf import arguments, toolchain @pytest.fixture @@ -95,73 +93,22 @@ def test_fingerprint_stage_field_op_def(fieldop, samecode_fieldop, different_fie def test_fingerprint_stage_foast_op_def(fieldop, samecode_fieldop, different_fieldop): - foast = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(fieldop.definition_stage) - samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( - samecode_fieldop.definition_stage - ) - different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( - different_fieldop.definition_stage - ) - - assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast) - assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) - - -@dataclasses.dataclass(frozen=True) -class ToFoastClosure(workflow.NamedStepSequenceWithArgs): - func_to_foast: workflow.Workflow = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast - foast_to_closure: workflow.Workflow = dataclasses.field( - default=gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_foast_closure, - metadata={"takes_args": True}, - ) - - -def test_fingerprint_stage_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): - toolchain = ToFoastClosure() - foast_closure = toolchain( - workflow.InputWithArgs( - data=fieldop.definition_stage, - args=(gtx.zeros({idim: 10}, gtx.int32),), - kwargs={ - "out": gtx.zeros({idim: 10}, gtx.int32), - "from_fieldop": fieldop, - }, - ), - ) - samecode = toolchain( - workflow.InputWithArgs( - data=samecode_fieldop.definition_stage, - args=(gtx.zeros({idim: 10}, gtx.int32),), - kwargs={ - "out": gtx.zeros({idim: 10}, gtx.int32), - "from_fieldop": samecode_fieldop, - }, + foast = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( + toolchain.CompilableProgram(fieldop.definition_stage, arguments.CompileTimeArgs.empty()) + ).data + samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( + toolchain.CompilableProgram( + samecode_fieldop.definition_stage, arguments.CompileTimeArgs.empty() ) - ) - different = toolchain( - workflow.InputWithArgs( - data=different_fieldop.definition_stage, - args=(gtx.zeros({jdim: 10}, gtx.int32),), - kwargs={ - "out": gtx.zeros({jdim: 10}, gtx.int32), - "from_fieldop": different_fieldop, - }, - ) - ) - different_args = toolchain( - workflow.InputWithArgs( - data=fieldop.definition_stage, - args=(gtx.zeros({idim: 11}, gtx.int32),), - kwargs={ - "out": gtx.zeros({idim: 11}, gtx.int32), - "from_fieldop": fieldop, - }, + ).data + different = gtx.backend.DEFAULT_TRANSFORMS.func_to_foast( + toolchain.CompilableProgram( + different_fieldop.definition_stage, arguments.CompileTimeArgs.empty() ) - ) + ).data - assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast_closure) - assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast_closure) - assert stages.fingerprint_stage(different_args) != stages.fingerprint_stage(foast_closure) + assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) def test_fingerprint_stage_program_def(program, samecode_program, different_program): @@ -174,9 +121,19 @@ def test_fingerprint_stage_program_def(program, samecode_program, different_prog def test_fingerprint_stage_past_def(program, samecode_program, different_program): - past = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(program.definition_stage) - samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage) - different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage) + past = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( + toolchain.CompilableProgram(program.definition_stage, arguments.CompileTimeArgs.empty()) + ) + samecode = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( + toolchain.CompilableProgram( + samecode_program.definition_stage, arguments.CompileTimeArgs.empty() + ) + ) + different = gtx.backend.DEFAULT_TRANSFORMS.func_to_past( + toolchain.CompilableProgram( + different_program.definition_stage, arguments.CompileTimeArgs.empty() + ) + ) assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(past) assert stages.fingerprint_stage(different) != stages.fingerprint_stage(past) diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 80814582cd..79507f46b3 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -74,6 +74,7 @@ def make_program_source(name: str) -> stages.ProgramSource: library_deps=[interface.LibraryDependency("gridtools_cpu", "master")], language=languages.Cpp, language_settings=cpp_interface.CPP_DEFAULT, + implicit_domain=False, ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_languages.py b/tests/next_tests/unit_tests/otf_tests/test_languages.py index 71b958622a..98f642016e 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_languages.py +++ b/tests/next_tests/unit_tests/otf_tests/test_languages.py @@ -22,6 +22,7 @@ def test_basic_settings_with_cpp_rejected(): language_settings=languages.LanguageSettings( formatter_key="cpp", formatter_style="llvm", file_extension="cpp" ), + implicit_domain=False, ) @@ -37,4 +38,5 @@ def test_header_files_settings_with_cpp_accepted(): file_extension="cpp", header_extension="hpp", ), + implicit_domain=False, ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 051685efda..bef82ad86f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -11,9 +11,9 @@ import gt4py.next as gtx from gt4py.next.iterator import ir as itir -from gt4py.next.otf import languages, stages -from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.otf import arguments, languages, stages +from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.type_system import type_translation @@ -61,7 +61,10 @@ def fencil_example(): def test_codegen(fencil_example): fencil, parameters = fencil_example module = gtfn_module.translate_program_cpu( - stages.ProgramCall(fencil, parameters, {"offset_provider": {}}) + stages.AOTProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}), + ) ) assert module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) diff --git a/tests/next_tests/unit_tests/test_config.py b/tests/next_tests/unit_tests/test_config.py index ccaeb5d775..a33bd5734a 100644 --- a/tests/next_tests/unit_tests/test_config.py +++ b/tests/next_tests/unit_tests/test_config.py @@ -21,6 +21,8 @@ def env_var(): yield env_var_name if saved is not None: os.environ[env_var_name] = saved + else: + _ = os.environ.pop(env_var_name, None) @pytest.mark.parametrize("value", ["False", "false", "0", "off"]) @@ -42,5 +44,5 @@ def test_env_flag_to_bool_invalid(env_var): def test_env_flag_to_bool_unset(env_var): - del os.environ[env_var] + _ = os.environ.pop(env_var, None) assert config.env_flag_to_bool(env_var, default=False) is False From b1dc8c4d9fee199a51645bcdfc6e1c0d26b236f2 Mon Sep 17 00:00:00 2001 From: Felix Thaler Date: Tue, 24 Sep 2024 09:54:48 +0200 Subject: [PATCH 04/11] Fix Some Typos in README (#1658) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 021da018be..b782e20f63 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ The following backends are supported: GT4Py can be installed as a regular Python package using `pip` (or any other PEP-517 frontend). As usual, we strongly recommended to create a new virtual environment to work on this project. -The performance backends also require the `Boost `\_\_ library, a dependency of [GridTools C++](https://github.com/GridTools/gridtools), which needs to be installed by the user. +The performance backends also require the [Boost](https://www.boost.org) library, a dependency of [GridTools C++](https://github.com/GridTools/gridtools), which needs to be installed by the user. ## ⚙ Configuration @@ -54,7 +54,7 @@ Other commonly used environment variables are: - `GT_CACHE_DIR_NAME`: Name of the compiler's cache directory (defaults to `.gt_cache`) - `GT_CACHE_ROOT`: Path to the compiler cache (defaults to `./`) -More options and details are available in [`config.py`](https://github.com/GridTools/gt4py/blob/main/src/gt4py/cartesian/config.py>). +More options and details are available in [`config.py`](https://github.com/GridTools/gt4py/blob/main/src/gt4py/cartesian/config.py). ## 📖 Documentation From e9e1fa32bae6911923f404231fd89b91efcaa07b Mon Sep 17 00:00:00 2001 From: Felix Thaler Date: Tue, 24 Sep 2024 14:15:12 +0200 Subject: [PATCH 05/11] fix: GitHub License Detection (#1657) GitHub fails to detect the license of GT4Py correctly (thus we get the ugly banner ![license](https://img.shields.io/github/license/GridTools/gt4py)). GitHub uses [licensee](https://github.com/licensee/licensee/) to detect the license as documented [here](https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/licensing-a-repository#detecting-a-license). This PR fixes license detection using licensee, which previously was broken due to an additional line above the typical BSD-3 copyright notice and the file `LICENSE_HEADER.txt` which confused licensee, too. --- .pre-commit-config.yaml | 2 +- CODING_GUIDELINES.md | 2 +- LICENSE_HEADER.txt => HEADER.txt | 0 LICENSE.txt | 2 -- 4 files changed, 2 insertions(+), 4 deletions(-) rename LICENSE_HEADER.txt => HEADER.txt (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5ad89cccd9..5e0314bca3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: - id: insert-license exclude: ^\..*$ types: [python] - args: [--comment-style, "|#|", --license-filepath, ./LICENSE_HEADER.txt, --fuzzy-match-generates-todo] + args: [--comment-style, "|#|", --license-filepath, ./HEADER.txt, --fuzzy-match-generates-todo] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md index 3f23e23f34..b4b27bbe9d 100644 --- a/CODING_GUIDELINES.md +++ b/CODING_GUIDELINES.md @@ -105,7 +105,7 @@ We highly encourage the [doctest][doctest] format for code examples in docstring In general, you should structure new Python modules in the following way: 1. _shebang_ line: `#! /usr/bin/env python3` (only for **executable scripts**!). -2. License header (see `LICENSE_HEADER.txt`). +2. License header (see `HEADER.txt`). 3. Module docstring. 4. Imports, alphabetically ordered within each block (fixed automatically by `ruff-formatter`): 1. Block of imports from the standard library. diff --git a/LICENSE_HEADER.txt b/HEADER.txt similarity index 100% rename from LICENSE_HEADER.txt rename to HEADER.txt diff --git a/LICENSE.txt b/LICENSE.txt index e188487720..4153d4b535 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,5 +1,3 @@ - GT4Py - GridTools Framework - Copyright (c) 2014-2024, ETH Zurich All rights reserved. From e10873ddb4850273a519e47c28aad0a9292ba42c Mon Sep 17 00:00:00 2001 From: Nicoletta Farabullini <41536517+nfarabullini@users.noreply.github.com> Date: Tue, 24 Sep 2024 16:17:06 +0200 Subject: [PATCH 06/11] feat[next]: unary operator for number type cast (#1659) edits in reference to issue #1643 ```python def unary_float(): return float(-1) ``` --- src/gt4py/next/ffront/foast_to_gtir.py | 27 ++++++++++------- src/gt4py/next/ffront/foast_to_itir.py | 29 +++++++++++-------- src/gt4py/next/ffront/func_to_foast.py | 15 ++++++---- .../ffront_tests/test_math_unary_builtins.py | 25 ++++++++++++++++ .../ffront_tests/test_foast_to_gtir.py | 22 ++++++++++++-- .../ffront_tests/test_foast_to_itir.py | 22 ++++++++++++-- 6 files changed, 107 insertions(+), 33 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4d3230a540..2e007c28bc 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -405,16 +405,23 @@ def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if isinstance(node.args[0], foast.Constant): - node_kind = self.visit(node.type).kind.name.lower() - target_type = fbuiltins.BUILTINS[node_kind] - source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] - if target_type is bool and source_type is not bool: - return im.literal(str(bool(source_type(node.args[0].value))), "bool") - return im.literal(str(node.args[0].value), node_kind) - raise FieldOperatorLoweringError( - f"Encountered a type cast, which is not supported: {node}." - ) + el = node.args[0] + node_kind = self.visit(node.type).kind.name.lower() + source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] + target_type = fbuiltins.BUILTINS[node_kind] + + if isinstance(el, foast.Constant): + val = source_type(el.value) + elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): + operand = source_type(el.operand.value) + val = eval(f"lambda arg: {el.op}arg")(operand) + else: + raise FieldOperatorLoweringError( + f"Type cast only supports literal arguments, {node.type} not supported." + ) + val = target_type(val) + + return im.literal(str(val), node_kind) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: if isinstance(type_, ts.TupleType): diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index b32bf744f5..7936eda1cf 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -447,18 +447,23 @@ def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if isinstance(node.args[0], foast.Constant): - node_kind = self.visit(node.type).kind.name.lower() - target_type = fbuiltins.BUILTINS[node_kind] - source_type = {**fbuiltins.BUILTINS, "string": str}[node.args[0].type.__str__().lower()] - if target_type is bool and source_type is not bool: - return im.promote_to_const_iterator( - im.literal(str(bool(source_type(node.args[0].value))), "bool") - ) - return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) - raise FieldOperatorLoweringError( - f"Encountered a type cast, which is not supported: {node}." - ) + el = node.args[0] + node_kind = self.visit(node.type).kind.name.lower() + source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] + target_type = fbuiltins.BUILTINS[node_kind] + + if isinstance(el, foast.Constant): + val = source_type(el.value) + elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): + operand = source_type(el.operand.value) + val = eval(f"lambda arg: {el.op}arg")(operand) + else: + raise FieldOperatorLoweringError( + f"Type cast only supports literal arguments, {node.type} not supported." + ) + val = target_type(val) + + return im.promote_to_const_iterator(im.literal(str(val), node_kind)) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 887e6cecba..ebe12d3a8b 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -503,11 +503,16 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): - raise errors.DSLError( - self.get_location(node), - f"'{self._func_name(node)}()' only takes literal arguments.", - ) + if len(node.args) > 0: + arg = node.args[0] + if not ( + isinstance(arg, ast.Constant) + or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) + ): + raise errors.DSLError( + self.get_location(node), + f"'{self._func_name(node)}()' only takes literal arguments.", + ) def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index a6eff8db5c..89c341e9a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -11,6 +11,7 @@ import gt4py.next as gtx from gt4py.next import ( + broadcast, cbrt, ceil, cos, @@ -127,6 +128,30 @@ def uneg(inp: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, uneg, ref=lambda inp1: -inp1) +def test_unary_neg_float_conversion(cartesian_case): + @gtx.field_operator + def uneg_float() -> cases.IFloatField: + inp_f = broadcast(float(-1), (IDim,)) + return inp_f + + size = cartesian_case.default_sizes[IDim] + ref = cartesian_case.as_field([IDim], np.full(size, -1.0, dtype=float)) + out = cases.allocate(cartesian_case, uneg_float, cases.RETURN)() + cases.verify(cartesian_case, uneg_float, out=out, ref=ref) + + +def test_unary_neg_bool_conversion(cartesian_case): + @gtx.field_operator + def uneg_bool() -> cases.IBoolField: + inp_f = broadcast(bool(-1), (IDim,)) + return inp_f + + size = cartesian_case.default_sizes[IDim] + ref = cartesian_case.as_field([IDim], np.full(size, True, dtype=bool)) + out = cases.allocate(cartesian_case, uneg_bool, cases.RETURN)() + cases.verify(cartesian_case, uneg_bool, out=out, ref=ref) + + def test_unary_invert(cartesian_case): @gtx.field_operator def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 081c2514d9..706de8a3eb 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -395,6 +395,22 @@ def foo(inp: gtx.Field[[TDim], float64]): assert lowered.expr == reference +@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) +def test_unary_op_type_conversion(var, var_type): + def unary_float(): + return float(-1) + + def unary_bool(): + return bool(-1) + + fun = unary_bool if var_type == "bool" else unary_float + parsed = FieldOperatorParser.apply_to_function(fun) + lowered = FieldOperatorLowering.apply(parsed) + reference = im.literal(var, var_type) + + assert lowered.expr == reference + + def test_unpacking(): """Unpacking assigns should get separated.""" @@ -862,9 +878,9 @@ def foo() -> tuple[float, float, float32, float64, float, float32, float64]: im.literal("0.1", "float64"), im.literal("0.1", "float32"), im.literal("0.1", "float64"), - im.literal(".1", "float64"), - im.literal(".1", "float32"), - im.literal(".1", "float64"), + im.literal("0.1", "float64"), + im.literal("0.1", "float32"), + im.literal("0.1", "float64"), ) assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py index faedcf1544..c102df9d57 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py @@ -177,6 +177,22 @@ def unary(inp: gtx.Field[[TDim], float64]): assert lowered.expr == reference +@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) +def test_unary_op_type_conversion(var, var_type): + def unary_float(): + return float(-1) + + def unary_bool(): + return bool(-1) + + fun = unary_bool if var_type == "bool" else unary_float + parsed = FieldOperatorParser.apply_to_function(fun) + lowered = FieldOperatorLowering.apply(parsed) + reference = im.promote_to_const_iterator(im.literal(var, var_type)) + + assert lowered.expr == reference + + def test_unpacking(): """Unpacking assigns should get separated.""" @@ -553,9 +569,9 @@ def float_constrs() -> tuple[float, float, float32, float64, float, float32, flo im.promote_to_const_iterator(im.literal("0.1", "float64")), im.promote_to_const_iterator(im.literal("0.1", "float32")), im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal(".1", "float64")), - im.promote_to_const_iterator(im.literal(".1", "float32")), - im.promote_to_const_iterator(im.literal(".1", "float64")), + im.promote_to_const_iterator(im.literal("0.1", "float64")), + im.promote_to_const_iterator(im.literal("0.1", "float32")), + im.promote_to_const_iterator(im.literal("0.1", "float64")), ) assert lowered.expr == reference From 1ccf8c4d23a5a3cf8322a97f0af08caf37387147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 25 Sep 2024 11:34:41 +0200 Subject: [PATCH 07/11] fix: add ROCm compatiblity to storages (#1655) Reimplement some cartesian storage utilities to handle better some corner cases with GPU storages, and to emulate the missing CUDA Array Interface in CuPy-ROCm. The previous hack to support CuPy-ROCm storages (`__hip_array_interface__`) has been removed and therefore it could be also removed from [GridTools-C++] (https://github.com/GridTools/gridtools/blob/master/include/gridtools/storage/adapter/python_sid_adapter.hpp) at some point. --- src/gt4py/_core/definitions.py | 4 ++ src/gt4py/storage/allocators.py | 12 ---- src/gt4py/storage/cartesian/utils.py | 94 ++++++++++++++++++++++++---- 3 files changed, 85 insertions(+), 25 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 01fbd51476..9d07b2eb79 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -379,6 +379,8 @@ class DeviceType(enum.IntEnum): METAL = 8 VPI = 9 ROCM = 10 + CUDA_MANAGED = 13 + ONE_API = 14 CPUDeviceTyping: TypeAlias = Literal[DeviceType.CPU] @@ -389,6 +391,8 @@ class DeviceType(enum.IntEnum): MetalDeviceTyping: TypeAlias = Literal[DeviceType.METAL] VPIDeviceTyping: TypeAlias = Literal[DeviceType.VPI] ROCMDeviceTyping: TypeAlias = Literal[DeviceType.ROCM] +CUDAManagedDeviceTyping: TypeAlias = Literal[DeviceType.CUDA_MANAGED] +OneApiDeviceTyping: TypeAlias = Literal[DeviceType.ONE_API] DeviceTypeT = TypeVar( diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index fa9005e86b..298b9c2e5a 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -259,18 +259,6 @@ def allocate( buffer, dtype, shape, padded_shape, item_size, strides, byte_offset ) - if self.device_type == core_defs.DeviceType.ROCM: - # until we can rely on dlpack - ndarray.__hip_array_interface__ = { # type: ignore[attr-defined] - "shape": ndarray.shape, # type: ignore[union-attr] - "typestr": ndarray.dtype.descr[0][1], # type: ignore[union-attr] - "descr": ndarray.dtype.descr, # type: ignore[union-attr] - "stream": 1, - "version": 3, - "strides": ndarray.strides, # type: ignore[union-attr, attr-defined] - "data": (ndarray.data.ptr, False), # type: ignore[union-attr, attr-defined] - } - return TensorBuffer( buffer=buffer, memory_address=memory_address, diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 0d7fcab201..052238fe24 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -9,6 +9,7 @@ from __future__ import annotations import collections.abc +import functools import math import numbers from typing import Any, Final, Literal, Optional, Sequence, Tuple, Union, cast @@ -52,13 +53,38 @@ if CUPY_DEVICE == core_defs.DeviceType.CUDA: _GPUBufferAllocator = allocators.NDArrayBufferAllocator( - device_type=core_defs.DeviceType.CUDA, array_utils=allocators.cupy_array_utils + device_type=core_defs.DeviceType.CUDA, + array_utils=allocators.cupy_array_utils, ) - else: + elif CUPY_DEVICE == core_defs.DeviceType.ROCM: _GPUBufferAllocator = allocators.NDArrayBufferAllocator( - device_type=core_defs.DeviceType.ROCM, array_utils=allocators.cupy_array_utils + device_type=core_defs.DeviceType.ROCM, + array_utils=allocators.cupy_array_utils, ) + class CUDAArrayInterfaceNDArray(cp.ndarray): + def __new__(cls, input_array: "cp.ndarray") -> CUDAArrayInterfaceNDArray: + return ( + input_array + if isinstance(input_array, CUDAArrayInterfaceNDArray) + else cp.asarray(input_array).view(cls) + ) + + @property + def __cuda_array_interface__(self) -> dict: + return { + "shape": self.shape, + "typestr": self.dtype.descr[0][1], + "descr": self.dtype.descr, + "stream": 1, + "version": 3, + "strides": self.strides, + "data": (self.data.ptr, False), + } + + else: + raise ValueError("CuPy is available but no suitable device was found.") + def _idx_from_order(order): return list(np.argsort(order)) @@ -188,14 +214,36 @@ def asarray( # extract the buffer from a gt4py.next.Field # TODO(havogt): probably `Field` should provide the array interface methods when applicable array = array.ndarray - if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")): - return cp.asarray(array) - if device == "cpu" or ( - not device and (hasattr(array, "__array_interface__") or hasattr(array, "__array__")) - ): - return np.asarray(array) - - if device: + + xp = None + if device == "cpu": + xp = np + elif device == "gpu": + assert cp is not None + xp = cp + elif not device: + if hasattr(array, "__dlpack_device__"): + kind, _ = array.__dlpack_device__() + if kind in [core_defs.DeviceType.CPU, core_defs.DeviceType.CPU_PINNED]: + xp = np + elif kind in [ + core_defs.DeviceType.CUDA, + core_defs.DeviceType.ROCM, + ]: + if cp is None: + raise RuntimeError("CuPy is required for GPU arrays") + xp = cp + elif hasattr(array, "__cuda_array_interface__"): + if cp is None: + raise RuntimeError("CuPy is required for GPU arrays") + xp = cp + elif hasattr(array, "__array_interface__") or hasattr(array, "__array__"): + xp = np + + if xp: + return xp.asarray(array) + + if device is not None: raise ValueError(f"Invalid device: {device!s}") raise TypeError(f"Cannot convert {type(array)} to ndarray") @@ -241,9 +289,10 @@ def allocate_gpu( alignment_bytes: int, aligned_index: Optional[Sequence[int]], ) -> Tuple["cp.ndarray", "cp.ndarray"]: + assert cp is not None assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" device = core_defs.Device( # type: ignore[type-var] - core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA, 0 + (core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA), 0 ) buffer = _GPUBufferAllocator.allocate( shape, @@ -253,4 +302,23 @@ def allocate_gpu( byte_alignment=alignment_bytes, aligned_index=aligned_index, ) - return buffer.buffer, cast("cp.ndarray", buffer.ndarray) + + buffer_ndarray = cast("cp.ndarray", buffer.ndarray) + + return buffer.buffer, buffer_ndarray + + +if CUPY_DEVICE == core_defs.DeviceType.ROCM: + + @functools.wraps(allocate_gpu) + def allocate_gpu_rocm( + shape: Sequence[int], + layout_map: allocators.BufferLayoutMap, + dtype: DTypeLike, + alignment_bytes: int, + aligned_index: Optional[Sequence[int]], + ) -> Tuple["cp.ndarray", "cp.ndarray"]: + buffer, ndarray = allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) + return buffer, CUDAArrayInterfaceNDArray(ndarray) + + allocate_gpu = allocate_gpu_rocm From f7e970782db20b66dd768fc347dfd9936c87af7d Mon Sep 17 00:00:00 2001 From: Stefano Ubbiali Date: Wed, 25 Sep 2024 21:37:19 +0200 Subject: [PATCH 08/11] fix[cartesian]: fix bugs in CuPy-ROCm storage allocation (#1662) Fix multiple bugs in the allocation of CuPy-ROCm storages introduced in previous PR #1655 (https://github.com/GridTools/gt4py/pull/1655), including an infinite recursion call in the main allocation function. --------- Co-authored-by: Enrique Gonzalez Paredes --- src/gt4py/storage/cartesian/utils.py | 58 +++++++++++++++------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 052238fe24..50500e536b 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -61,27 +61,6 @@ device_type=core_defs.DeviceType.ROCM, array_utils=allocators.cupy_array_utils, ) - - class CUDAArrayInterfaceNDArray(cp.ndarray): - def __new__(cls, input_array: "cp.ndarray") -> CUDAArrayInterfaceNDArray: - return ( - input_array - if isinstance(input_array, CUDAArrayInterfaceNDArray) - else cp.asarray(input_array).view(cls) - ) - - @property - def __cuda_array_interface__(self) -> dict: - return { - "shape": self.shape, - "typestr": self.dtype.descr[0][1], - "descr": self.dtype.descr, - "stream": 1, - "version": 3, - "strides": self.strides, - "data": (self.data.ptr, False), - } - else: raise ValueError("CuPy is available but no suitable device was found.") @@ -282,7 +261,7 @@ def allocate_cpu( return buffer.buffer, cast(np.ndarray, buffer.ndarray) -def allocate_gpu( +def _allocate_gpu( shape: Sequence[int], layout_map: allocators.BufferLayoutMap, dtype: DTypeLike, @@ -292,7 +271,8 @@ def allocate_gpu( assert cp is not None assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" device = core_defs.Device( # type: ignore[type-var] - (core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA), 0 + (core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA), + 0, ) buffer = _GPUBufferAllocator.allocate( shape, @@ -308,17 +288,41 @@ def allocate_gpu( return buffer.buffer, buffer_ndarray +allocate_gpu = _allocate_gpu + if CUPY_DEVICE == core_defs.DeviceType.ROCM: - @functools.wraps(allocate_gpu) - def allocate_gpu_rocm( + class CUDAArrayInterfaceNDArray(cp.ndarray): + def __new__(cls, input_array: "cp.ndarray") -> CUDAArrayInterfaceNDArray: + return ( + input_array + if isinstance(input_array, CUDAArrayInterfaceNDArray) + else cp.asarray(input_array).view(cls) + ) + + @property + def __cuda_array_interface__(self) -> dict: + return { + "shape": self.shape, + "typestr": self.dtype.descr[0][1], + "descr": self.dtype.descr, + "stream": 1, + "version": 3, + "strides": self.strides, + "data": (self.data.ptr, False), + } + + __hip_array_interface__ = __cuda_array_interface__ + + @functools.wraps(_allocate_gpu) + def _allocate_gpu_rocm( shape: Sequence[int], layout_map: allocators.BufferLayoutMap, dtype: DTypeLike, alignment_bytes: int, aligned_index: Optional[Sequence[int]], ) -> Tuple["cp.ndarray", "cp.ndarray"]: - buffer, ndarray = allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) + buffer, ndarray = _allocate_gpu(shape, layout_map, dtype, alignment_bytes, aligned_index) return buffer, CUDAArrayInterfaceNDArray(ndarray) - allocate_gpu = allocate_gpu_rocm + allocate_gpu = _allocate_gpu_rocm From 48b13cc0251618014830cb54a11952fde5ba157e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 26 Sep 2024 15:37:07 +0200 Subject: [PATCH 09/11] fext[next]: HIP support (#1661) Introduces the HIP language in the CMake setup. `gtfn_gpu` uses the implementation (CUDA or HIP) as detected by cupy. --- src/gt4py/next/otf/binding/nanobind.py | 2 +- .../otf/compilation/build_systems/cmake.py | 12 +++++++----- .../compilation/build_systems/cmake_lists.py | 3 +++ .../compilation/build_systems/compiledb.py | 4 ++-- src/gt4py/next/otf/languages.py | 8 ++++++-- .../codegens/gtfn/gtfn_module.py | 19 ++++++++++++++----- .../runners/dace_common/workflow.py | 2 +- .../runners/dace_iterator/workflow.py | 6 +++++- .../next/program_processors/runners/gtfn.py | 4 ++-- .../build_systems_tests/conftest.py | 2 +- .../unit_tests/otf_tests/test_languages.py | 4 ++-- .../gtfn_tests/test_gtfn_module.py | 2 +- 12 files changed, 45 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 27f788d224..24913a1365 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -188,7 +188,7 @@ def create_bindings( program_source The program source for which the bindings are created """ - if program_source.language not in [languages.Cpp, languages.Cuda]: + if program_source.language not in [languages.CPP, languages.CUDA, languages.HIP]: raise ValueError( f"Can only create bindings for C++ program sources, received '{program_source.language}'." ) diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake.py b/src/gt4py/next/otf/compilation/build_systems/cmake.py index 14377d1b82..33212848a6 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake.py @@ -22,7 +22,9 @@ @dataclasses.dataclass class CMakeFactory( compiler.BuildSystemProjectGenerator[ - languages.Cpp | languages.Cuda, languages.LanguageWithHeaderFilesSettings, languages.Python + languages.CPP | languages.CUDA | languages.HIP, + languages.LanguageWithHeaderFilesSettings, + languages.Python, ] ): """Create a CMakeProject from a ``CompilableSource`` stage object with given CMake settings.""" @@ -34,7 +36,7 @@ class CMakeFactory( def __call__( self, source: stages.CompilableSource[ - languages.Cpp | languages.Cuda, + languages.CPP | languages.CUDA | languages.HIP, languages.LanguageWithHeaderFilesSettings, languages.Python, ], @@ -48,8 +50,8 @@ def __call__( header_name = f"{name}.{source.program_source.language_settings.header_extension}" bindings_name = f"{name}_bindings.{source.program_source.language_settings.file_extension}" cmake_languages = [cmake_lists.Language(name="CXX")] - if source.program_source.language is languages.Cuda: - cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")] + if (src_lang := source.program_source.language) in [languages.CUDA, languages.HIP]: + cmake_languages = [*cmake_languages, cmake_lists.Language(name=src_lang.__name__)] cmake_lists_src = cmake_lists.generate_cmakelists_source( name, source.library_deps, [header_name, bindings_name], languages=cmake_languages ) @@ -70,7 +72,7 @@ def __call__( @dataclasses.dataclass class CMakeProject( stages.BuildSystemProject[ - languages.Cpp, languages.LanguageWithHeaderFilesSettings, languages.Python + languages.CPP, languages.LanguageWithHeaderFilesSettings, languages.Python ] ): """ diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index ae63d6613a..0533adac81 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -48,6 +48,9 @@ class CMakeListsGenerator(eve.codegen.TemplatedGenerator): if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES 60) endif() + if(NOT DEFINED CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES gfx90a) + endif() {{"\\n".join(languages)}} # Paths diff --git a/src/gt4py/next/otf/compilation/build_systems/compiledb.py b/src/gt4py/next/otf/compilation/build_systems/compiledb.py index debf8217cf..6bacde4937 100644 --- a/src/gt4py/next/otf/compilation/build_systems/compiledb.py +++ b/src/gt4py/next/otf/compilation/build_systems/compiledb.py @@ -255,8 +255,8 @@ def _cc_create_compiledb( prog_src_name = f"{name}.{header_ext}" binding_src_name = f"{name}.{src_ext}" cmake_languages = [cmake_lists.Language(name="CXX")] - if prototype_program_source.language is languages.Cuda: - cmake_languages = [*cmake_languages, cmake_lists.Language(name="CUDA")] + if (src_lang := prototype_program_source.language) in [languages.CUDA, languages.HIP]: + cmake_languages = [*cmake_languages, cmake_lists.Language(name=src_lang.__name__)] prototype_project = cmake.CMakeProject( generator_name="Ninja", diff --git a/src/gt4py/next/otf/languages.py b/src/gt4py/next/otf/languages.py index db1094db57..1564a300bb 100644 --- a/src/gt4py/next/otf/languages.py +++ b/src/gt4py/next/otf/languages.py @@ -59,10 +59,14 @@ class SDFG(LanguageTag): class NanobindSrcL(LanguageTag): ... -class Cpp(NanobindSrcL): +class CPP(NanobindSrcL): settings_class = LanguageWithHeaderFilesSettings ... -class Cuda(NanobindSrcL): +class CUDA(NanobindSrcL): + settings_class = LanguageWithHeaderFilesSettings + + +class HIP(NanobindSrcL): settings_class = LanguageWithHeaderFilesSettings diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index ac5325aade..5349464edd 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -68,6 +68,13 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting file_extension="cu", header_extension="cuh", ) + case core_defs.DeviceType.ROCM: + return languages.LanguageWithHeaderFilesSettings( + formatter_key=cpp_interface.CPP_DEFAULT.formatter_key, + formatter_style=cpp_interface.CPP_DEFAULT.formatter_style, + file_extension="hip", + header_extension="h", + ) case core_defs.DeviceType.CPU: return cpp_interface.CPP_DEFAULT case _: @@ -263,7 +270,7 @@ def __call__( def _backend_header(self) -> str: match self.device_type: - case core_defs.DeviceType.CUDA: + case core_defs.DeviceType.CUDA | core_defs.DeviceType.ROCM: return "gridtools/fn/backend/gpu.hpp" case core_defs.DeviceType.CPU: return "gridtools/fn/backend/naive.hpp" @@ -272,7 +279,7 @@ def _backend_header(self) -> str: def _backend_type(self) -> str: match self.device_type: - case core_defs.DeviceType.CUDA: + case core_defs.DeviceType.CUDA | core_defs.DeviceType.ROCM: return "gridtools::fn::backend::gpu{}" case core_defs.DeviceType.CPU: return "gridtools::fn::backend::naive{}" @@ -282,9 +289,11 @@ def _backend_type(self) -> str: def _language(self) -> type[languages.NanobindSrcL]: match self.device_type: case core_defs.DeviceType.CUDA: - return languages.Cuda + return languages.CUDA + case core_defs.DeviceType.ROCM: + return languages.HIP case core_defs.DeviceType.CPU: - return languages.Cpp + return languages.CPP case _: raise self._not_implemented_for_device_type() @@ -297,7 +306,7 @@ def _language_settings(self) -> languages.LanguageWithHeaderFilesSettings: def _library_name(self) -> str: match self.device_type: - case core_defs.DeviceType.CUDA: + case core_defs.DeviceType.CUDA | core_defs.DeviceType.ROCM: return "gridtools_gpu" case core_defs.DeviceType.CPU: return "gridtools_cpu" diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index dbe2b70ff8..73437664ff 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -108,7 +108,7 @@ def convert_args( ) -> stages.CompiledProgram: sdfg_program = inp.sdfg_program sdfg = sdfg_program.sdfg - on_gpu = True if device == core_defs.DeviceType.CUDA else False + on_gpu = True if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] else False def decorated_program( *args: Any, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 72cc15a46e..0b23fde9d1 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -56,7 +56,11 @@ def generate_sdfg( offset_provider: dict[str, common.Dimension | common.Connectivity], column_axis: Optional[common.Dimension], ) -> dace.SDFG: - on_gpu = True if self.device_type == core_defs.DeviceType.CUDA else False + on_gpu = ( + True + if self.device_type in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] + else False + ) return build_sdfg_from_itir( program, diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a2badd5191..9c4e73520f 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -65,7 +65,7 @@ def decorated_program( def _ensure_is_on_device( connectivity_arg: npt.NDArray, device: core_defs.DeviceType ) -> npt.NDArray: - if device == core_defs.DeviceType.CUDA: + if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp if not isinstance(connectivity_arg, cp.ndarray): @@ -160,7 +160,7 @@ class Params: name_postfix = "" gpu = factory.Trait( allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=core_defs.DeviceType.CUDA, + device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, name_device="gpu", ) cached = factory.Trait( diff --git a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py index 79507f46b3..0848aba5a6 100644 --- a/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py +++ b/tests/next_tests/unit_tests/otf_tests/compilation_tests/build_systems_tests/conftest.py @@ -72,7 +72,7 @@ def make_program_source(name: str) -> stages.ProgramSource: entry_point=entry_point, source_code=src, library_deps=[interface.LibraryDependency("gridtools_cpu", "master")], - language=languages.Cpp, + language=languages.CPP, language_settings=cpp_interface.CPP_DEFAULT, implicit_domain=False, ) diff --git a/tests/next_tests/unit_tests/otf_tests/test_languages.py b/tests/next_tests/unit_tests/otf_tests/test_languages.py index 98f642016e..7a3fc0c007 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_languages.py +++ b/tests/next_tests/unit_tests/otf_tests/test_languages.py @@ -18,7 +18,7 @@ def test_basic_settings_with_cpp_rejected(): entry_point=interface.Function(name="basic_settings_with_cpp", parameters=[]), source_code="", library_deps=(), - language=languages.Cpp, + language=languages.CPP, language_settings=languages.LanguageSettings( formatter_key="cpp", formatter_style="llvm", file_extension="cpp" ), @@ -31,7 +31,7 @@ def test_header_files_settings_with_cpp_accepted(): entry_point=interface.Function(name="basic_settings_with_cpp", parameters=[]), source_code="", library_deps=(), - language=languages.Cpp, + language=languages.CPP, language_settings=languages.LanguageWithHeaderFilesSettings( formatter_key="cpp", formatter_style="llvm", diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index bef82ad86f..e5abcd7f0a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -68,4 +68,4 @@ def test_codegen(fencil_example): ) assert module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) - assert module.language is languages.Cpp + assert module.language is languages.CPP From 6d011ea01b8a386088b61a05dc7aa323e6080d2e Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 27 Sep 2024 08:41:01 +0200 Subject: [PATCH 10/11] fix[next]: Fix usage of DaCe fast-call to SDFG (#1656) This PR addresses some flaky test failures observed in GT4Py CI. The root cause was that the dace backend did not check the connectivity arrays, which are passed as keyword-arguments to the SDFG. It did only check the positional arguments. The connectivity arrays do not have to be allocated on the device memory: for gpu execution, the backend ensures that the connectivity arrays are copied to device memory just before passing them to the SDFG call. Previous implementation worked sometimes, when by chance cupy was reusing the same array on gpu memory, hence the flaky behavior of the tests. New test is added for the connectivity case. The previous test case is cleaned up and improved, by invalidating all scalar positional arguments at each SDFG call: this allows to test that they are overridden before fast_call. Additionally, this PR reduces the overhead of regular SDFG call: previous implementation was copying all the connectivity arrays to gpu memory, with this PR we only allocate cupy arrays for the connectivities used in the SDFG. --- .../runners/dace_common/dace_backend.py | 50 ++--- .../runners/dace_common/utility.py | 11 +- .../runners/dace_common/workflow.py | 76 ++++---- .../runners_tests/dace_tests/test_dace.py | 181 +++++++++++------- 4 files changed, 185 insertions(+), 133 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 71e66ca771..7063faee16 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -73,20 +73,9 @@ def _ensure_is_on_device( return connectivity_arg -def _get_connectivity_args( - neighbor_tables: Mapping[str, gtx_common.NeighborTable], device: dace.dtypes.DeviceType -) -> dict[str, Any]: - return { - dace_util.connectivity_identifier(offset): _ensure_is_on_device( - offset_provider.table, device - ) - for offset, offset_provider in neighbor_tables.items() - } - - def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] -) -> Mapping[str, int]: + arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] +) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): for sym, size in zip(arrays[name].shape, value.shape, strict=True): @@ -101,8 +90,8 @@ def _get_shape_args( def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] -) -> Mapping[str, int]: + arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] +) -> dict[str, int]: stride_args = {} for name, value in args.items(): for sym, stride_size in zip(arrays[name].strides, value.strides, strict=True): @@ -121,6 +110,27 @@ def _get_stride_args( return stride_args +def get_sdfg_conn_args( + sdfg: dace.SDFG, + offset_provider: dict[str, Any], + on_gpu: bool, +) -> dict[str, np.typing.NDArray]: + """ + Extracts the connectivity tables that are used in the sdfg and ensures + that the memory buffers are allocated for the target device. + """ + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + connectivity_args = {} + for offset, connectivity in dace_util.filter_connectivities(offset_provider).items(): + assert isinstance(connectivity, gtx_common.NeighborTable) + param = dace_util.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + + return connectivity_args + + def get_sdfg_args( sdfg: dace.SDFG, *args: Any, @@ -138,17 +148,9 @@ def get_sdfg_args( """ offset_provider = kwargs["offset_provider"] - neighbor_tables: dict[str, gtx_common.NeighborTable] = {} - for offset, connectivity in dace_util.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - neighbor_tables[offset] = connectivity - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - dace_args = _get_args(sdfg, args, use_field_canonical_representation) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = _get_connectivity_args(neighbor_tables, device) - # keep only connectivity tables that are used in the sdfg - dace_conn_args = {n: v for n, v in dace_conn_args.items() if n in sdfg.arrays} + dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = _get_stride_args(sdfg.arrays, dace_field_args) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index c8f9e37a6b..a892040303 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -8,7 +8,8 @@ from __future__ import annotations -from typing import Any, Mapping, Optional, Sequence +import re +from typing import Any, Final, Mapping, Optional, Sequence import dace @@ -17,6 +18,10 @@ from gt4py.next.type_system import type_specifications as ts +# regex to match the symbols for field shape and strides +FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile("__.+_(size|stride)_\d+") + + def as_scalar_type(typestr: str) -> ts.ScalarType: """Obtain GT4Py scalar type from generic numpy string representation.""" try: @@ -38,6 +43,10 @@ def field_stride_symbol_name(field_name: str, axis: int) -> str: return f"__{field_name}_stride_{axis}" +def is_field_symbol(name: str) -> bool: + return FIELD_SYMBOL_RE.match(name) is not None + + def debug_info( node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None ) -> Optional[dace.dtypes.DebugInfo]: diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 73437664ff..1caa4684c9 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -10,7 +10,7 @@ import ctypes import dataclasses -from typing import Any, Optional +from typing import Any import dace import factory @@ -20,26 +20,23 @@ from gt4py.next import common, config from gt4py.next.otf import arguments, languages, stages, step_types, workflow from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.runners.dace_common import dace_backend +from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils class CompiledDaceProgram(stages.CompiledProgram): sdfg_program: dace.CompiledSDFG - # Map SDFG argument to its position in program ABI; scalar arguments that are not used in the SDFG will not be present. - sdfg_arg_position: list[Optional[int]] - def __init__(self, program: dace.CompiledSDFG): - # extract position of arguments in program ABI - sdfg_arglist = program.sdfg.signature_arglist(with_types=False) - sdfg_arg_pos_mapping = {param: pos for pos, param in enumerate(sdfg_arglist)} - sdfg_used_symbols = program.sdfg.used_symbols(all_symbols=False) + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_arglist: list[tuple[str, dace.dtypes.Data]] + def __init__(self, program: dace.CompiledSDFG): self.sdfg_program = program - self.sdfg_arg_position = [ - sdfg_arg_pos_mapping[param] - if param in program.sdfg.arrays or param in sdfg_used_symbols - else None - for param in program.sdfg.arg_names + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. + self.sdfg_arglist = [ + (arg_name, arg_type) for arg_name, arg_type in program.sdfg.arglist().items() ] def __call__(self, *args: Any, **kwargs: Any) -> None: @@ -94,13 +91,6 @@ class Meta: model = DaCeCompiler -def _get_ctype_value(arg: Any, dtype: dace.dtypes.dataclass) -> Any: - if not isinstance(arg, (ctypes._SimpleCData, ctypes._Pointer)): - actype = dtype.as_ctypes() - return actype(arg) - return arg - - def convert_args( inp: CompiledDaceProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU, @@ -119,28 +109,36 @@ def decorated_program( args = (*args, out) if len(sdfg.arg_names) > len(args): args = (*args, *arguments.iter_size_args(args)) + if sdfg_program._lastargs: - # The scalar arguments should be replaced with the actual value; for field arguments, - # the data pointer should remain the same otherwise fast-call cannot be used and - # the args list needs to be reconstructed. + kwargs = dict(zip(sdfg.arg_names, args, strict=True)) + kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) + use_fast_call = True - for arg, param, pos in zip(args, sdfg.arg_names, inp.sdfg_arg_position, strict=True): - if isinstance(arg, common.Field): - desc = sdfg.arrays[param] - assert isinstance(desc, dace.data.Array) - assert isinstance(sdfg_program._lastargs[0][pos], ctypes.c_void_p) - if sdfg_program._lastargs[0][pos].value != get_array_interface_ptr( - arg.ndarray, desc.storage - ): + last_call_args = sdfg_program._lastargs[0] + # The scalar arguments should be overridden with the new value; for field arguments, + # the data pointer should remain the same otherwise fast_call cannot be used and + # the arguments list has to be reconstructed. + for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist): + if isinstance(arg_type, dace.data.Array): + assert arg_name in kwargs, f"Argument '{arg_name}' not found." + data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage) + assert isinstance(last_call_args[i], ctypes.c_void_p) + if last_call_args[i].value != data_ptr: use_fast_call = False break - elif param in sdfg.arrays: - desc = sdfg.arrays[param] - assert isinstance(desc, dace.data.Scalar) - sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, desc.dtype) - elif pos: - sym_dtype = sdfg.symbols[param] - sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, sym_dtype) + else: + assert isinstance(arg_type, dace.data.Scalar) + assert isinstance(last_call_args[i], ctypes._SimpleCData) + if arg_name in kwargs: + # override the scalar value used in previous program call + actype = arg_type.dtype.as_ctypes() + last_call_args[i] = actype(kwargs[arg_name]) + else: + # shape and strides of arrays are supposed not to change, and can therefore be omitted + assert dace_utils.is_field_symbol( + arg_name + ), f"Argument '{arg_name}' not found." if use_fast_call: return sdfg_program.fast_call(*sdfg_program._lastargs) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index dcbab29efc..953491bde3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -10,52 +10,81 @@ import ctypes import unittest -from typing import Any import numpy as np import pytest +import gt4py._core.definitions as core_defs import gt4py.next as gtx -from gt4py.next import int32 from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( + E2V, cartesian_case, + unstructured_case, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, + mesh_descriptor, ) +from unittest.mock import patch from . import pytestmark dace = pytest.importorskip("dace") -def get_scalar_values_from_sdfg_args( - args: tuple[list[ctypes._SimpleCData], list[ctypes._SimpleCData]], -) -> list[Any]: - runtime_args, init_args = args - return [ - arg.value for arg in [*runtime_args, *init_args] if not isinstance(arg, ctypes.c_void_p) - ] +def make_mocks(monkeypatch): + # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object + mock_fast_call = unittest.mock.MagicMock() + dace_fast_call = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call + + def mocked_fast_call(self, *args, **kwargs): + mock_fast_call.__call__(*args, **kwargs) + fast_call_result = dace_fast_call(self, *args, **kwargs) + # invalidate all scalar positional arguments to ensure that they are properly set + # next time the SDFG is executed before fast_call + positional_args = set(self.sdfg.arg_names) + sdfg_arglist = self.sdfg.arglist() + for i, (arg_name, arg_type) in enumerate(sdfg_arglist.items()): + if arg_name in positional_args and isinstance(arg_type, dace.data.Scalar): + assert isinstance(self._lastargs[0][i], ctypes.c_int) + self._lastargs[0][i].value = -1 + return fast_call_result + + monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call) + + # Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object + mock_construct_args = unittest.mock.MagicMock() + dace_construct_args = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args + + def mocked_construct_args(self, *args, **kwargs): + mock_construct_args.__call__(*args, **kwargs) + return dace_construct_args(self, *args, **kwargs) + + monkeypatch.setattr( + dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args + ) + + return mock_fast_call, mock_construct_args def test_dace_fastcall(cartesian_case, monkeypatch): """Test reuse of SDFG arguments between program calls by means of SDFG fastcall API.""" if not cartesian_case.executor or "dace" not in cartesian_case.executor.__name__: - pytest.skip("DaCe-specific testcase.") + pytest.skip("requires dace backend") @gtx.field_operator def testee( a: cases.IField, a_idx: cases.IField, unused_field: cases.IField, - a0: int32, - a1: int32, - a2: int32, - unused_scalar: int32, + a0: gtx.int32, + a1: gtx.int32, + a2: gtx.int32, + unused_scalar: gtx.int32, ) -> cases.IField: t0 = where(a_idx == 0, a + a0, a) t1 = where(a_idx == 1, t0 + a1, t0) @@ -70,27 +99,7 @@ def testee( unused_field = cases.allocate(cartesian_case, testee, "unused_field")() out = cases.allocate(cartesian_case, testee, cases.RETURN)() - # Wrap `compiled_sdfg.CompiledSDFG.fast_call` with mock object - mock_fast_call = unittest.mock.MagicMock() - mock_fast_call_attr = dace.codegen.compiled_sdfg.CompiledSDFG.fast_call - - def mocked_fast_call(self, *args, **kwargs): - mock_fast_call.__call__(*args, **kwargs) - return mock_fast_call_attr(self, *args, **kwargs) - - monkeypatch.setattr(dace.codegen.compiled_sdfg.CompiledSDFG, "fast_call", mocked_fast_call) - - # Wrap `compiled_sdfg.CompiledSDFG._construct_args` with mock object - mock_construct_args = unittest.mock.MagicMock() - mock_construct_args_attr = dace.codegen.compiled_sdfg.CompiledSDFG._construct_args - - def mocked_construct_args(self, *args, **kwargs): - mock_construct_args.__call__(*args, **kwargs) - return mock_construct_args_attr(self, *args, **kwargs) - - monkeypatch.setattr( - dace.codegen.compiled_sdfg.CompiledSDFG, "_construct_args", mocked_construct_args - ) + mock_fast_call, mock_construct_args = make_mocks(monkeypatch) # Reset mock objects and run/verify GT4Py program def verify_testee(): @@ -111,54 +120,88 @@ def verify_testee(): # On first run, the SDFG arguments will have to be constructed verify_testee() mock_construct_args.assert_called_once() - # here we store the reference to the tuple of arguments passed to `fast_call` on first run and compare on successive runs - fast_call_args = mock_fast_call.call_args.args - # and the scalar values in the order they appear in the program ABI - fast_call_scalar_values = get_scalar_values_from_sdfg_args(fast_call_args) - - def check_one_scalar_arg_changed(prev_scalar_args): - new_scalar_args = get_scalar_values_from_sdfg_args(mock_fast_call.call_args.args) - diff = np.array(new_scalar_args) - np.array(prev_scalar_args) - assert np.count_nonzero(diff) == 1 - - def check_scalar_args_all_same(prev_scalar_args): - new_scalar_args = get_scalar_values_from_sdfg_args(mock_fast_call.call_args.args) - diff = np.array(new_scalar_args) - np.array(prev_scalar_args) - assert np.count_nonzero(diff) == 0 - - def check_pointer_args_all_same(): - for arg, prev in zip(mock_fast_call.call_args.args, fast_call_args, strict=True): - if isinstance(arg, ctypes._Pointer): - assert arg == prev # Now modify the scalar arguments, used and unused ones: reuse previous SDFG arguments for i in range(4): a_offset[i] += 1 verify_testee() mock_construct_args.assert_not_called() - assert mock_fast_call.call_args.args == fast_call_args - check_pointer_args_all_same() - if i < 3: - # same arguments tuple object but one scalar value is changed - check_one_scalar_arg_changed(fast_call_scalar_values) - # update reference scalar values - fast_call_scalar_values = get_scalar_values_from_sdfg_args(fast_call_args) - else: - # unused scalar argument: the symbol is removed from the SDFG arglist and therefore no change - check_scalar_args_all_same(fast_call_scalar_values) # Modify content of current buffer: reuse previous SDFG arguments for buff in (a, unused_field): buff[0] += 1 verify_testee() mock_construct_args.assert_not_called() - # same arguments tuple object and same content - assert mock_fast_call.call_args.args == fast_call_args - check_pointer_args_all_same() - check_scalar_args_all_same(fast_call_scalar_values) # Pass a new buffer, which should trigger reconstruct of SDFG arguments: fastcall API will not be used a = cases.allocate(cartesian_case, testee, "a")() verify_testee() mock_construct_args.assert_called_once() - assert mock_fast_call.call_args.args != fast_call_args + + +def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): + """Test reuse of SDFG arguments between program calls by means of SDFG fastcall API.""" + + if not unstructured_case.executor or "dace" not in unstructured_case.executor.__name__: + pytest.skip("requires dace backend") + + connectivity_E2V = unstructured_case.offset_provider["E2V"] + assert isinstance(connectivity_E2V, gtx.common.NeighborTable) + + # check that test connectivities are allocated on host memory + # this is an assumption to test that fast_call cannot be used for gpu tests + assert isinstance(connectivity_E2V.table, np.ndarray) + + @gtx.field_operator + def testee(a: cases.VField) -> cases.EField: + return a(E2V[0]) + + (a,), kwfields = cases.get_default_data(unstructured_case, testee) + numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + + mock_fast_call, mock_construct_args = make_mocks(monkeypatch) + + # Reset mock objects and run/verify GT4Py program + def verify_testee(offset_provider): + mock_construct_args.reset_mock() + mock_fast_call.reset_mock() + cases.verify( + unstructured_case, + testee, + a, + **kwfields, + offset_provider=offset_provider, + ref=numpy_ref(a.asnumpy()), + ) + mock_fast_call.assert_called_once() + + if gtx.allocators.is_field_allocator_for( + unstructured_case.executor.allocator, core_defs.DeviceType.CPU + ): + offset_provider = unstructured_case.offset_provider + else: + assert gtx.allocators.is_field_allocator_for( + unstructured_case.executor.allocator, gtx.allocators.CUPY_DEVICE + ) + + import cupy as cp + + # The test connectivities are numpy arrays, by default, and they are copied + # to gpu memory at each program call (see `dace_backend._ensure_is_on_device`), + # therefore fast_call cannot be used (unless cupy reuses the same cupy array + # from the its memory pool, but this behavior is random and unpredictable). + # Here we copy the connectivity to gpu memory, and resuse the same cupy array + # on multiple program calls, in order to ensure that fast_call is used. + offset_provider = { + "E2V": gtx.NeighborTableOffsetProvider( + table=cp.asarray(connectivity_E2V.table), + origin_axis=connectivity_E2V.origin_axis, + neighbor_axis=connectivity_E2V.neighbor_axis, + max_neighbors=connectivity_E2V.max_neighbors, + has_skip_values=connectivity_E2V.has_skip_values, + ) + } + + verify_testee(offset_provider) + verify_testee(offset_provider) + mock_construct_args.assert_not_called() From fb1d494c143a91bb7a2347f15355335c531071b5 Mon Sep 17 00:00:00 2001 From: edopao Date: Fri, 27 Sep 2024 11:02:48 +0200 Subject: [PATCH 11/11] fix[next]: remove debug messages in OTF workflow (#1666) Remove some debug messages in OTF workflow related to AOT-toolchain. --- src/gt4py/next/otf/arguments.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 2bd6c2ebe9..fd9a0c225a 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -205,9 +205,7 @@ def iter_size_args(args: tuple[Any, ...], inside_tuple: bool = False) -> Iterato This can be used to generate domain size arguments for FieldView Programs that use an implicit domain. """ - print(f"iter_size_args: matching args {tuple(type(arg) for arg in args)}") for arg in args: - print(f"iter_size_args: matching arg {arg}") match arg: case tuple(): # we only need the first field, because all fields in a tuple must have the same dims and sizes @@ -215,7 +213,6 @@ def iter_size_args(args: tuple[Any, ...], inside_tuple: bool = False) -> Iterato if first_field: yield from iter_size_args((first_field,)) case common.Field(): - print(f"iter_size_args: yielding from {arg.ndarray.shape}") yield from arg.ndarray.shape case _: pass