diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py deleted file mode 100644 index 94c962e92d..0000000000 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ /dev/null @@ -1,181 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR -import enum -from typing import Callable, Optional - -from gt4py.eve import utils as eve_utils -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs -from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet -from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan -from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts -from gt4py.next.iterator.transforms.merge_let import MergeLet -from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce - - -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. 
-        ).visit(ir)
-    else:
-        raise ValueError()
-
-    return ir
-
-
-def _inline_into_scan(ir, *, max_iter=10):
-    for _ in range(max_iter):
-        # in case there are multiple levels of lambdas around the scan we have to do multiple iterations
-        inlined = InlineIntoScan().visit(ir)
-        inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True)
-        if inlined == ir:
-            break
-        ir = inlined
-    else:
-        raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.")
-    return ir
-
-
-def apply_common_transforms(
-    ir: itir.Node,
-    *,
-    lift_mode=None,
-    offset_provider=None,
-    unroll_reduce=False,
-    common_subexpression_elimination=True,
-    force_inline_lambda_args=False,
-    unconditionally_collapse_tuples=False,
-    temporary_extraction_heuristics: Optional[
-        Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
-    ] = None,
-    symbolic_domain_sizes: Optional[dict[str, str]] = None,
-    offset_provider_type: Optional[common.OffsetProviderType] = None,
-) -> itir.Program:
-    assert isinstance(ir, itir.FencilDefinition)
-    # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps
-    if offset_provider_type is None:
-        offset_provider_type = common.offset_provider_to_type(offset_provider)
-
-    ir = fencil_to_program.FencilToProgram().apply(ir)
-    icdlv_uids = eve_utils.UIDGenerator()
-
-    if lift_mode is None:
-        lift_mode = LiftMode.FORCE_INLINE
-    assert isinstance(lift_mode, LiftMode)
-    ir = MergeLet().visit(ir)
-    ir = inline_fundefs.InlineFundefs().visit(ir)
-
-    ir = inline_fundefs.prune_unreferenced_fundefs(ir)  # type: ignore[arg-type]  # all previous passes return itir.Program
-    ir = PropagateDeref.apply(ir)
-    ir = NormalizeShifts().visit(ir)
-
-    for _ in range(10):
-        inlined = ir
-
-        inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids)  # type: ignore[arg-type]  # always a fencil
-        inlined = _inline_lifts(inlined, lift_mode)
-
-        inlined = InlineLambdas.apply(
-            inlined,
-            opcount_preserving=True,
-            force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE),
-            # If trivial lifts are not inlined we might create temporaries for constants. In all
-            # other cases we want it anyway.
-            force_inline_trivial_lift_args=True,
-        )
-        inlined = ConstantFolding.apply(inlined)
-        # This pass is required to be in the loop such that when an `if_` call with tuple arguments
-        # is constant-folded the surrounding tuple_get calls can be removed.
-        inlined = CollapseTuple.apply(
-            inlined,
-            offset_provider_type=offset_provider_type,
-            # TODO(tehrengruber): disabled since it increases compile-time too much right now
-            flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
-        )
-        # This pass is required such that a deref outside of a
-        # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the
-        # `tuple_get` is removed by the `CollapseTuple` pass.
-        inlined = PropagateDeref.apply(inlined)
-
-        if inlined == ir:
-            break
-        ir = inlined
-    else:
-        raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.")
-
-    if lift_mode != LiftMode.FORCE_INLINE:
-        raise NotImplementedError()
-
-    # Since `CollapseTuple` relies on the type inference which does not support returning tuples
-    # larger than the number of closure outputs as given by the unconditional collapse, we can
-    # only run the unconditional version here instead of in the loop above.
- if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( - ir, - ignore_tuple_size=True, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) - - ir = NormalizeShifts().visit(ir) - - ir = FuseMaps().visit(ir) - ir = CollapseListGet().visit(ir) - - if unroll_reduce: - for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) - if unrolled == ir: - break - ir = unrolled - ir = CollapseListGet().visit(ir) - ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) - ir = NormalizeShifts().visit(ir) - else: - raise RuntimeError("Reduction unrolling failed.") - - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program - ir = MergeLet().visit(ir) - - ir = InlineLambdas.apply( - ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args - ) - - assert isinstance(ir, itir.Program) - return ir diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 95186e0b5d..1b3b930818 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,45 +8,34 @@ import factory +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators from gt4py.next import backend +from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow -from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory -class DaCeIteratorBackendFactory(GTFNBackendFactory): +class DaCeFieldviewBackendFactory(GTFNBackendFactory): + class Meta: + model = backend.Backend + class Params: - otf_workflow = factory.SubFactory( - dace_iterator_workflow.DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), + name_device = "cpu" + name_cached = "" + name_postfix = "" + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + name_device="gpu", ) - auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_postfix="_opt" + cached = factory.Trait( + executor=factory.LazyAttribute( + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + ), + name_cached="_cached", ) - use_field_canonical_representation: bool = False - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" - ) - - transforms = backend.LEGACY_TRANSFORMS - - -run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) -run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False) - -run_dace_gpu = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=True) -run_dace_gpu_noopt = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=False) - -itir_cpu = run_dace_cpu -itir_gpu = 
run_dace_gpu - - -class DaCeFieldviewBackendFactory(GTFNBackendFactory): - class Params: + device_type = core_defs.DeviceType.CPU otf_workflow = factory.SubFactory( dace_fieldview_workflow.DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), @@ -55,11 +44,16 @@ class Params: auto_optimize = factory.Trait(name_postfix="_opt") name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" + lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" ) + executor = factory.LazyAttribute(lambda o: o.otf_workflow) + allocator = next_allocators.StandardCPUFieldBufferAllocator() transforms = backend.DEFAULT_TRANSFORMS -gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) -gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) +run_dace_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=True) +run_dace_cpu_noopt = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) + +run_dace_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=True) +run_dace_gpu_noopt = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 56ba08015b..90e7e07ad5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -24,7 +24,7 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any: +def _convert_arg(arg: Any, sdfg_param: str) -> Any: if not isinstance(arg, gtx_common.Field): return arg if len(arg.domain.dims) == 0: @@ -41,26 +41,14 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: raise RuntimeError( f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." ) - if not use_field_canonical_representation: - return arg.ndarray - # the canonical representation requires alphabetical ordering of the dimensions in field domain definition - sorted_dims = dace_utils.get_sorted_dims(arg.domain.dims) - ndim = len(sorted_dims) - dim_indices = [dim_index for dim_index, _ in sorted_dims] - if isinstance(arg.ndarray, np.ndarray): - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) - else: - assert cp is not None and isinstance(arg.ndarray, cp.ndarray) - return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) - - -def _get_args( - sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool -) -> dict[str, Any]: + return arg.ndarray + + +def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) return { - sdfg_param: _convert_arg(arg, sdfg_param, use_field_canonical_representation) + sdfg_param: _convert_arg(arg, sdfg_param) for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) } @@ -154,10 +142,10 @@ def get_sdfg_conn_args( def get_sdfg_args( sdfg: dace.SDFG, + offset_provider: gtx_common.OffsetProvider, *args: Any, check_args: bool = False, on_gpu: bool = False, - use_field_canonical_representation: bool = True, **kwargs: Any, ) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. @@ -166,10 +154,10 @@ def get_sdfg_args( Args: sdfg: The SDFG for which we want to get the arguments. 
+ offset_provider: Offset provider. """ - offset_provider = kwargs["offset_provider"] - dace_args = _get_args(sdfg, args, use_field_canonical_representation) + dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 3e96ef3cec..ac15bc1cbf 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Literal, Optional, Sequence +from typing import Final, Literal, Optional import dace @@ -96,10 +96,3 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } - - -def get_sorted_dims( - dims: Sequence[gtx_common.Dimension], -) -> Sequence[tuple[int, gtx_common.Dimension]]: - """Sort list of dimensions in alphabetical order.""" - return sorted(enumerate(dims), key=lambda v: v[1].value) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 91e83dba9d..5d9ac863c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -150,9 +150,9 @@ def decorated_program( sdfg_args = dace_backend.get_sdfg_args( sdfg, + offset_provider, *args, check_args=False, - offset_provider=offset_provider, on_gpu=on_gpu, use_field_canonical_representation=use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py deleted file mode 100644 index ef09cf51cd..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ /dev/null @@ -1,377 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause
-
-import dataclasses
-import warnings
-from collections import OrderedDict
-from collections.abc import Callable, Sequence
-from dataclasses import field
-from inspect import currentframe, getframeinfo
-from pathlib import Path
-from typing import Any, ClassVar, Optional
-
-import dace
-import numpy as np
-from dace.sdfg import utils as sdutils
-from dace.transformation.auto import auto_optimize as autoopt
-
-import gt4py.next.iterator.ir as itir
-from gt4py.next import common
-from gt4py.next.ffront import decorator
-from gt4py.next.iterator import transforms as itir_transforms
-from gt4py.next.iterator.ir import SymRef
-from gt4py.next.iterator.transforms import (
-    pass_manager_legacy as legacy_itir_transforms,
-    program_to_fencil,
-)
-from gt4py.next.iterator.type_system import inference as itir_type_inference
-from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
-from gt4py.next.type_system import type_specifications as ts
-
-from .itir_to_sdfg import ItirToSDFG
-
-
-def preprocess_program(
-    program: itir.FencilDefinition,
-    offset_provider_type: common.OffsetProviderType,
-    lift_mode: legacy_itir_transforms.LiftMode,
-    symbolic_domain_sizes: Optional[dict[str, str]] = None,
-    temporary_extraction_heuristics: Optional[
-        Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
-    ] = None,
-    unroll_reduce: bool = False,
-):
-    node = legacy_itir_transforms.apply_common_transforms(
-        program,
-        common_subexpression_elimination=False,
-        force_inline_lambda_args=True,
-        lift_mode=lift_mode,
-        offset_provider_type=offset_provider_type,
-        symbolic_domain_sizes=symbolic_domain_sizes,
-        temporary_extraction_heuristics=temporary_extraction_heuristics,
-        unroll_reduce=unroll_reduce,
-    )
-
-    node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type)
-
-    if isinstance(node, itir.Program):
-        fencil_definition = program_to_fencil.program_to_fencil(node)
-        tmps = node.declarations
-        assert all(isinstance(tmp, itir.Temporary) for tmp in tmps)
-    else:
-        raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.")
-
-    return fencil_definition, tmps
-
-
-def build_sdfg_from_itir(
-    program: itir.FencilDefinition,
-    arg_types: Sequence[ts.TypeSpec],
-    offset_provider_type: common.OffsetProviderType,
-    auto_optimize: bool = False,
-    on_gpu: bool = False,
-    column_axis: Optional[common.Dimension] = None,
-    lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE,
-    symbolic_domain_sizes: Optional[dict[str, str]] = None,
-    temporary_extraction_heuristics: Optional[
-        Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
-    ] = None,
-    load_sdfg_from_file: bool = False,
-    save_sdfg: bool = True,
-    use_field_canonical_representation: bool = True,
-) -> dace.SDFG:
-    """Translate a Fencil into an SDFG.
-
-    Args:
-        program: The Fencil that should be translated.
-        arg_types: Types of the arguments passed to the fencil.
-        offset_provider_type: The set of offset provider types that should be used.
-        auto_optimize: Apply DaCe's `auto_optimize` heuristic.
-        on_gpu: Performs the translation for GPU, defaults to `False`.
-        column_axis: The column axis to be used, defaults to `None`.
-        lift_mode: Which lift mode should be used, defaults to `FORCE_INLINE`.
-        symbolic_domain_sizes: Used for generation of liskov bindings when temporaries are enabled.
-        load_sdfg_from_file: Allows reading the SDFG from a file instead of generating it; for debugging only.
-        save_sdfg: If `True` (the default), the SDFG is stored as a file and can be loaded later, which allows skipping the lowering step; requires `load_sdfg_from_file` set to `True`.
-        use_field_canonical_representation: If `True`, assume that the field's dimensions are sorted alphabetically.
-    """
-
-    sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg"
-    if load_sdfg_from_file and Path(sdfg_filename).exists():
-        sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename)
-        sdfg.validate()
-        return sdfg
-
-    # visit ITIR and generate SDFG
-    program, tmps = preprocess_program(
-        program,
-        offset_provider_type,
-        lift_mode,
-        symbolic_domain_sizes,
-        temporary_extraction_heuristics,
-    )
-    sdfg_generator = ItirToSDFG(
-        list(arg_types),
-        offset_provider_type,
-        tmps,
-        use_field_canonical_representation,
-        column_axis,
-    )
-    sdfg = sdfg_generator.visit(program)
-    if sdfg is None:
-        raise RuntimeError(f"Visit failed for program {program.id}.")
-
-    for nested_sdfg in sdfg.all_sdfgs_recursive():
-        if not nested_sdfg.debuginfo:
-            _, frameinfo = (
-                warnings.warn(
-                    f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.",
-                    stacklevel=2,
-                ),
-                getframeinfo(currentframe()),  # type: ignore[arg-type]
-            )
-            nested_sdfg.debuginfo = dace.dtypes.DebugInfo(
-                start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename
-            )
-
-    # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct
-    sdutils.inline_loop_blocks(sdfg)
-
-    # run DaCe transformations to simplify the SDFG
-    sdfg.simplify()
-
-    # run DaCe auto-optimization heuristics
-    if auto_optimize:
-        # TODO: Investigate performance improvement from SDFG specialization with constant symbols,
-        # for array shape and strides, although this would imply JIT compilation.
-        symbols: dict[str, int] = {}
-        device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
-        sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
-    elif on_gpu:
-        autoopt.apply_gpu_storage(sdfg)
-
-    if on_gpu:
-        sdfg.apply_gpu_transformations()
-
-    # Store the sdfg such that we can later reuse it.
-    if save_sdfg:
-        sdfg.save(sdfg_filename)
-
-    return sdfg
-
-
-@dataclasses.dataclass(frozen=True)
-class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible):
-    """Extension of GT4Py Program implementing the SDFGConvertible interface."""
-
-    sdfg_closure_vars: dict[str, Any] = field(default_factory=dict)
-
-    # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs,
-    # there is no name mangling of the connectivity tables used across the nested SDFGs
-    # since they share the same memory address.
-    connectivity_tables_data_descriptors: ClassVar[
-        dict[str, dace.data.Array]
-    ] = {}  # symbolically defined
-
-    def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
-        if "dace" not in self.backend.name.lower():  # type: ignore[union-attr]
-            raise ValueError("The SDFG can be generated only for the DaCe backend.")
-
-        params = {str(p.id): p.type for p in self.itir.params}
-        fields = {str(p.id): p.type for p in self.itir.params if hasattr(p.type, "dims")}
-        arg_types = [*params.values()]
-
-        dace_parsed_args = [*args, *kwargs.values()]
-        gt4py_program_args = [*params.values()]
-        _crosscheck_dace_parsing(dace_parsed_args, gt4py_program_args)
-
-        if self.connectivities is None:
-            raise ValueError(
-                "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method."
- ) - offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} - - sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] - self.itir, - arg_types, - offset_provider_type=offset_provider_type, - column_axis=kwargs.get("column_axis", None), - ) - self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ - - # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field - # Add them as dynamic properties to the SDFG - - assert all( - isinstance(in_field, SymRef) - for closure in self.itir.closures - for in_field in closure.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_fields = [ - str(in_field.id) # type: ignore[union-attr] # ensured by assert - for closure in self.itir.closures - for in_field in closure.inputs - if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert - ] - sdfg.gt4py_program_input_fields = { - in_field: dim - for in_field in input_fields - for dim in fields[in_field].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - output_fields = [] - for closure in self.itir.closures: - output = closure.output - if isinstance(output, itir.SymRef): - if str(output.id) in fields: - output_fields.append(str(output.id)) - else: - for arg in output.args: - if str(arg.id) in fields: # type: ignore[attr-defined] - output_fields.append(str(arg.id)) # type: ignore[attr-defined] - sdfg.gt4py_program_output_fields = { - output: dim - for output in output_fields - for dim in fields[output].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - sdfg.offset_providers_per_input_field = {} - itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider_type=offset_provider_type - ) - itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) - for closure in itir_tmp_fencil.closures: - params_shifts = itir_transforms.trace_shifts.trace_stencil( - closure.stencil, num_args=len(closure.inputs) - ) - for param, shifts in zip(closure.inputs, params_shifts): - assert isinstance( - param, SymRef - ) # backend only supports SymRef inputs, not `index` calls - if not isinstance(param.id, str): - continue - if param.id not in sdfg.gt4py_program_input_fields: - continue - sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) - - return sdfg - - def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: - """ - Returns the closure arrays of the SDFG represented by this object - as a mapping between array name and the corresponding value. - - The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. - The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that - the offset providers are not part of GT4Py Program's arguments. - Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. 
- """ - offset_provider_type = self.connectivities - - # Define DaCe symbols - connectivity_table_size_symbols = { - dace_utils.field_size_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - connectivity_table_stride_symbols = { - dace_utils.field_stride_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - symbols = {**connectivity_table_size_symbols, **connectivity_table_stride_symbols} - - # Define the storage location (e.g. CPU, GPU) of the connectivity tables - if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - if not isinstance(v, common.NeighborConnectivityType): - continue - if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: - Program.connectivity_tables_data_descriptors["storage"] = ( - self.sdfg_closure_vars[ - "sdfg.arrays" - ][dace_utils.connectivity_identifier(k)].storage - ) - break - - # Build the closure dictionary - closure_dict = {} - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - conn_id = dace_utils.connectivity_identifier(k) - if ( - isinstance(v, common.NeighborConnectivityType) - and conn_id in self.sdfg_closure_vars["sdfg.arrays"] - ): - if conn_id not in Program.connectivity_tables_data_descriptors: - Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, - shape=[ - symbols[dace_utils.field_size_symbol_name(conn_id, 0)], - symbols[dace_utils.field_size_symbol_name(conn_id, 1)], - ], - strides=[ - symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], - symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], - ], - storage=Program.connectivity_tables_data_descriptors["storage"], - ) - closure_dict[conn_id] = Program.connectivity_tables_data_descriptors[conn_id] - - return closure_dict - - def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: - args = [] - for arg in self.past_stage.past_node.params: - args.append(arg.id) - return (args, []) - - -def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> bool: - for dace_parsed_arg, gt4py_program_arg in zip(dace_parsed_args, gt4py_program_args): - if isinstance(dace_parsed_arg, dace.data.Scalar): - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) - elif isinstance( - dace_parsed_arg, (bool, int, float, str, np.bool_, np.integer, np.floating, np.str_) - ): # compile-time constant scalar - assert isinstance(gt4py_program_arg, ts.ScalarType) - if isinstance(dace_parsed_arg, (bool, np.bool_)): - assert gt4py_program_arg.kind == ts.ScalarKind.BOOL - elif isinstance(dace_parsed_arg, (int, np.integer)): - assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] - elif isinstance(dace_parsed_arg, (float, np.floating)): - 
assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] - elif isinstance(dace_parsed_arg, (str, np.str_)): - assert gt4py_program_arg.kind == ts.ScalarKind.STRING - elif isinstance(dace_parsed_arg, dace.data.Array): - assert isinstance(gt4py_program_arg, ts.FieldType) - assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) - elif isinstance( - dace_parsed_arg, (dace.data.Structure, dict, OrderedDict) - ): # offset_provider - continue - else: - raise ValueError(f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}") - - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py deleted file mode 100644 index 823943cfd5..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ /dev/null @@ -1,809 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import warnings -from typing import Optional, Sequence, cast - -import dace -from dace.sdfg.state import LoopRegion - -import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind, common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt - -from .itir_to_tasklet import ( - Context, - GatherOutputSymbolsPass, - PythonTaskletCodegen, - SymbolExpr, - TaskletExpr, - ValueExpr, - closure_to_tasklet_sdfg, - is_scan, -) -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_var_name, -) - - -def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: - """ - Parse stencil expression to extract the scan arguments. - - Returns - ------- - tuple(is_forward, init_carry) - The output tuple fields verify the following semantics: - - is_forward: forward boolean flag - - init_carry: carry initial value - """ - stencil_fobj = cast(FunCall, stencil) - is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) - init_carry = stencil_fobj.args[2] - assert isinstance(init_carry, Literal) - return is_forward.value == "True", init_carry - - -def _get_scan_dim( - column_axis: Dimension, - storage_types: dict[str, ts.TypeSpec], - output: SymRef, - use_field_canonical_representation: bool, -) -> tuple[str, int, ts.ScalarType]: - """ - Extract information about the scan dimension. 
-
-    Returns
-    -------
-    tuple(scan_dim_name, scan_dim_index, scan_dim_dtype)
-        The output tuple fields verify the following semantics:
-            - scan_dim_name: name of the scan dimension
-            - scan_dim_index: domain index of the scan dimension
-            - scan_dim_dtype: data type along the scan dimension
-    """
-    output_type = storage_types[output.id]
-    assert isinstance(output_type, ts.FieldType)
-    sorted_dims = [
-        dim
-        for _, dim in (
-            dace_utils.get_sorted_dims(output_type.dims)
-            if use_field_canonical_representation
-            else enumerate(output_type.dims)
-        )
-    ]
-    return (column_axis.value, sorted_dims.index(column_axis), output_type.dtype)
-
-
-def _make_array_shape_and_strides(
-    name: str,
-    dims: Sequence[Dimension],
-    offset_provider_type: common.OffsetProviderType,
-    sort_dims: bool,
-) -> tuple[list[dace.symbol], list[dace.symbol]]:
-    """
-    Parse field dimensions and allocate symbols for array shape and strides.
-
-    For local dimensions, the size is known at compile-time and therefore
-    the corresponding array shape dimension is set to an integer literal value.
-
-    Returns
-    -------
-    tuple(shape, strides)
-        The output tuple fields are arrays of dace symbolic expressions.
-    """
-    dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType)
-    sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims))
-    connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type)
-    shape = [
-        (
-            connectivity_types[dim.value].max_neighbors
-            if dim.kind == DimensionKind.LOCAL
-            # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain
-            else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype)
-        )
-        for i, dim in sorted_dims
-    ]
-    strides = [
-        dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i, _ in sorted_dims
-    ]
-    return shape, strides
-
-
-def _check_no_lifts(node: itir.StencilClosure):
-    """
-    Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions.
-
-    Returns
-    -------
-    True if lift expressions appear in the ITIR only inside neighbor reductions; False otherwise.
- """ - neighbors_call_count = 0 - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): - if getattr(fun, "id", "") == "neighbors": - neighbors_call_count = 3 - elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: - return False - neighbors_call_count = max(0, neighbors_call_count - 1) - return True - - -class ItirToSDFG(eve.NodeVisitor): - param_types: list[ts.TypeSpec] - storage_types: dict[str, ts.TypeSpec] - column_axis: Optional[Dimension] - offset_provider_type: common.OffsetProviderType - unique_id: int - use_field_canonical_representation: bool - - def __init__( - self, - param_types: list[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - tmps: list[itir.Temporary], - use_field_canonical_representation: bool, - column_axis: Optional[Dimension] = None, - ): - self.param_types = param_types - self.column_axis = column_axis - self.offset_provider_type = offset_provider_type - self.storage_types = {} - self.tmps = tmps - self.use_field_canonical_representation = use_field_canonical_representation - - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): - if isinstance(type_, ts.FieldType): - shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider_type, sort_dimensions - ) - dtype = dace_utils.as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - elif isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: - sdfg.add_symbol(name, dtype) - - else: - raise NotImplementedError() - self.storage_types[name] = type_ - - def add_storage_for_temporaries( - self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG - ) -> dict[str, str]: - symbol_map: dict[str, TaskletExpr] = {} - # The shape of temporary arrays might be defined based on scalar values passed as program arguments. - # Here we collect these values in a symbol map. - for sym in node_params: - if isinstance(sym.type, ts.ScalarType): - name_ = str(sym.id) - symbol_map[name_] = SymbolExpr(name_, dace_utils.as_dace_type(sym.type)) - - tmp_symbols: dict[str, str] = {} - for tmp in self.tmps: - tmp_name = str(tmp.id) - - # We visit the domain of the temporary field, passing the set of available symbols. - assert isinstance(tmp.domain, itir.FunCall) - domain_ctx = Context(program_sdfg, defs_state, symbol_map) - tmp_domain = self._visit_domain(tmp.domain, domain_ctx) - - if isinstance(tmp.type, ts.TupleType): - raise NotImplementedError("Temporaries of tuples are not supported.") - assert isinstance(tmp.type, ts.FieldType) and isinstance(tmp.dtype, ts.ScalarType) - - # We store the FieldType for this temporary array. - self.storage_types[tmp_name] = tmp.type - - # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. - # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics - # to assign optimal stride values. - tmp_shape, _ = new_array_symbols(tmp_name, len(tmp.type.dims)) - _, tmp_array = program_sdfg.add_array( - tmp_name, tmp_shape, dace_utils.as_dace_type(tmp.dtype), transient=True - ) - - # Loop through all dimensions to visit the symbolic expressions for array shape and offset. - # These expressions are later mapped to interstate symbols. 
- for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): - # The temporary field has a dimension range defined by `begin` and `end` values. - # Therefore, the actual size is given by the difference `end.value - begin.value`. - # Instead of allocating the actual size, we allocate space to enable indexing from 0 - # because we want to avoid using dace array offsets (which will be deprecated soon). - # The result should still be valid, but the stencil will be using only a subset - # of the array. - if not (isinstance(begin, SymbolExpr) and begin.value == "0"): - warnings.warn( - f"Domain start offset for temporary {tmp_name} is ignored.", stacklevel=2 - ) - tmp_symbols[str(shape_sym)] = end.value - - return tmp_symbols - - def create_memlet_at(self, field_name: str, index: dict[str, str]): - field_type = self.storage_types[field_name] - assert isinstance(field_type, ts.FieldType) - if self.use_field_canonical_representation: - field_index = [ - index[dim.value] for _, dim in dace_utils.get_sorted_dims(field_type.dims) - ] - else: - field_index = [index[dim.value] for dim in field_type.dims] - subset = ", ".join(field_index) - return dace.Memlet(data=field_name, subset=subset) - - def get_output_nodes( - self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dict[str, dace.nodes.AccessNode]: - # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes - output_symbols_pass = GatherOutputSymbolsPass(sdfg, state) - output_symbols_pass.visit(closure.output) - # Visit output node again to generate the corresponding tasklet - context = Context(sdfg, state, output_symbols_pass.symbol_refs) - translator = PythonTaskletCodegen( - self.offset_provider_type, context, self.use_field_canonical_representation - ) - output_nodes = flatten_list(translator.visit(closure.output)) - return {node.value.data: node.value for node in output_nodes} - - def visit_FencilDefinition(self, node: itir.FencilDefinition): - program_sdfg = dace.SDFG(name=node.id) - program_sdfg.debuginfo = dace_utils.debug_info(node) - entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - - # Filter neighbor tables from offset providers. - connectivity_types = get_used_connectivities(node, self.offset_provider_type) - - # Add program parameters as SDFG storages. - for param, type_ in zip(node.params, self.param_types): - self.add_storage( - program_sdfg, str(param.id), type_, self.use_field_canonical_representation - ) - - if self.tmps: - tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) - # on the first interstate edge define symbols for shape and offsets of temporary arrays - last_state = program_sdfg.add_state("init_symbols_for_temporaries") - program_sdfg.add_edge( - entry_state, last_state, dace.InterstateEdge(assignments=tmp_symbols) - ) - else: - last_state = entry_state - - # Add connectivities as SDFG storages. - for offset, connectivity_type in connectivity_types.items(): - scalar_type = tt.from_dtype(connectivity_type.dtype) - type_ = ts.FieldType( - [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type - ) - self.add_storage( - program_sdfg, - dace_utils.connectivity_identifier(offset), - type_, - sort_dimensions=False, - ) - - # Create a nested SDFG for all stencil closures. - for closure in node.closures: - # Translate the closure and its stencil's body to an SDFG. 
-            closure_sdfg, input_names, output_names = self.visit(
-                closure, array_table=program_sdfg.arrays
-            )
-
-            # Create a new state for the closure.
-            last_state = program_sdfg.add_state_after(last_state)
-
-            # Create memlets to transfer the program parameters.
-            input_mapping = {
-                name: dace.Memlet.from_array(name, program_sdfg.arrays[name])
-                for name in input_names
-            }
-            output_mapping = {
-                name: dace.Memlet.from_array(name, program_sdfg.arrays[name])
-                for name in output_names
-            }
-
-            symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping)
-
-            # Insert the closure's SDFG as a nested SDFG of the program.
-            nsdfg_node = last_state.add_nested_sdfg(
-                sdfg=closure_sdfg,
-                parent=program_sdfg,
-                inputs=set(input_names),
-                outputs=set(output_names),
-                symbol_mapping=symbol_mapping,
-                debuginfo=closure_sdfg.debuginfo,
-            )
-
-            # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges.
-            for inner_name, memlet in input_mapping.items():
-                access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo)
-                last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet)
-
-            for inner_name, memlet in output_mapping.items():
-                access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo)
-                last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)
-
-        # Create the call signature for the SDFG.
-        # Only the arguments required by the Fencil, i.e. `node.params`, are added as positional arguments.
-        # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keyword-only arguments.
-        program_sdfg.arg_names = [str(a) for a in node.params]
-
-        program_sdfg.validate()
-        return program_sdfg
-
-    def visit_StencilClosure(
-        self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
-    ) -> tuple[dace.SDFG, list[str], list[str]]:
-        assert _check_no_lifts(node)
-
-        # Create the closure's nested SDFG and single state.
-        closure_sdfg = dace.SDFG(name="closure")
-        closure_sdfg.debuginfo = dace_utils.debug_info(node)
-        closure_state = closure_sdfg.add_state("closure_entry")
-        closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True)
-
-        assert all(
-            isinstance(inp, SymRef) for inp in node.inputs
-        )  # backend only supports SymRef inputs, not `index` calls
-        input_names = [str(inp.id) for inp in node.inputs]  # type: ignore[union-attr]  # ensured by assert
-        neighbor_tables = get_used_connectivities(node, self.offset_provider_type)
-        connectivity_names = [
-            dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
-        ]
-
-        output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state)
-        output_names = [k for k, _ in output_nodes.items()]
-
-        # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG.
-        input_transients_mapping = {}
-        for name in [*input_names, *connectivity_names, *output_names]:
-            if name in closure_sdfg.arrays:
-                assert name in input_names and name in output_names
-                # In case of closures with in/out fields, there is a risk of race conditions
-                # between read/write access nodes in the (asynchronous) map tasklet.
- transient_name = unique_var_name() - closure_sdfg.add_array( - transient_name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - transient=True, - ) - closure_init_state.add_nedge( - closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), - closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), - dace.Memlet.from_array(name, closure_sdfg.arrays[name]), - ) - input_transients_mapping[name] = transient_name - elif isinstance(self.storage_types[name], ts.FieldType): - closure_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - assert isinstance(self.storage_types[name], ts.ScalarType) - - input_field_names = [ - input_name - for input_name in input_names - if isinstance(self.storage_types[input_name], ts.FieldType) - ] - - # Closure outputs should all be fields - assert all( - isinstance(self.storage_types[output_name], ts.FieldType) - for output_name in output_names - ) - - # Update symbol table and get output domain of the closure - program_arg_syms: dict[str, TaskletExpr] = {} - for name, type_ in self.storage_types.items(): - if isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in input_names: - out_name = unique_var_name() - closure_sdfg.add_scalar(out_name, dtype, transient=True) - out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", - {}, - {"__result"}, - f"__result = {name}", - debuginfo=closure_sdfg.debuginfo, - ) - access = closure_init_state.add_access( - out_name, debuginfo=closure_sdfg.debuginfo - ) - value = ValueExpr(access, dtype) - memlet = dace.Memlet(data=out_name, subset="0") - closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) - program_arg_syms[name] = value - else: - program_arg_syms[name] = SymbolExpr(name, dtype) - else: - assert isinstance(type_, ts.FieldType) - # make shape symbols (corresponding to field size) available as arguments to domain visitor - if name in input_names or name in output_names: - field_symbols = [ - val - for val in closure_sdfg.arrays[name].shape - if isinstance(val, dace.symbol) and str(val) not in input_names - ] - for sym in field_symbols: - sym_name = str(sym) - program_arg_syms[sym_name] = SymbolExpr(sym, sym.dtype) - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, closure_ctx) - - # Map SDFG tasklet arguments to parameters - input_local_names = [ - ( - input_transients_mapping[input_name] - if input_name in input_transients_mapping - else ( - input_name - if input_name in input_field_names - else cast(ValueExpr, program_arg_syms[input_name]).value.data - ) - ) - for input_name in input_names - ] - input_memlets = [ - dace.Memlet.from_array(name, closure_sdfg.arrays[name]) - for name in [*input_local_names, *connectivity_names] - ] - - # create and write to transient that is then copied back to actual output array to avoid aliasing of - # same memory in nested SDFG with different names - output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} - # scan operator should always be the first function call in a closure - if is_scan(node.stencil): - assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" - transient_name, output_name = next(iter(output_connectors_mapping.items())) - - nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - 
node, closure_sdfg.arrays, closure_domain, transient_name - ) - results = [transient_name] - - _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] - output_subset = f"{scan_lb.value}:{scan_ub.value}" - - domain_subset = { - dim: ( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" - ) - for dim, _ in closure_domain - } - output_memlets = [self.create_memlet_at(output_name, domain_subset)] - else: - nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( - node, closure_sdfg.arrays, closure_domain - ) - - output_subset = "0" - - output_memlets = [ - self.create_memlet_at(output_name, {dim: f"i_{dim}" for dim, _ in closure_domain}) - for output_name in output_connectors_mapping.values() - ] - - input_mapping = { - param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) - } - output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) - - nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( - closure_state, - sdfg=nsdfg, - map_ranges=map_ranges or {"__dummy": "0"}, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - output_nodes=output_nodes, - debuginfo=nsdfg.debuginfo, - ) - access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} - for edge in closure_state.in_edges(map_exit): - memlet = edge.data - if memlet.data not in output_connectors_mapping: - continue - transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) - closure_state.add_edge( - nsdfg_node, - edge.src_conn, - transient_access, - None, - dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo), - ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset - ) - closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) - closure_state.remove_edge(edge) - access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - - return closure_sdfg, input_field_names + connectivity_names, output_names - - def _visit_scan_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... 
- ], - output_name: str, - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: - # extract scan arguments - is_forward, init_carry_value = _get_scan_args(node.stencil) - # select the scan dimension based on program argument for column axis - assert self.column_axis - assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( - self.column_axis, - self.storage_types, - node.output, - self.use_field_canonical_representation, - ) - - assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - if not dim == scan_dim: - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - else: - scan_lb_str = lb_str - scan_ub_str = ub_str - - # the scan operator is implemented as an SDFG to be nested in the closure SDFG - scan_sdfg = dace.SDFG(name="scan") - scan_sdfg.debuginfo = dace_utils.debug_info(node) - - # the carry value of the scan operator exists only in the scope of the scan sdfg - scan_carry_name = unique_var_name() - scan_sdfg.add_scalar( - scan_carry_name, dtype=dace_utils.as_dace_type(scan_dtype), transient=True - ) - - # create a loop region for lambda call over the scan dimension - scan_loop_var = f"i_{scan_dim}" - if is_forward: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} < {scan_ub_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_lb_str}", - update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", - inverted=False, - ) - else: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} >= {scan_lb_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", - update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", - inverted=False, - ) - scan_sdfg.add_node(scan_loop) - compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) - update_state = scan_loop.add_state("lambda_update") - scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) - - start_state = scan_sdfg.add_state("start", is_start_block=True) - scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) - - # tasklet for initialization of carry - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", - {}, - {"__result"}, - f"__result = {init_carry_value}", - debuginfo=scan_sdfg.debuginfo, - ) - start_state.add_edge( - carry_init_tasklet, - "__result", - start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), - None, - dace.Memlet(data=scan_carry_name, subset="0"), - ) - - # add storage to scan SDFG for inputs - for name in [*input_names, *connectivity_names]: - assert name not in scan_sdfg.arrays - if isinstance(self.storage_types[name], ts.FieldType): - scan_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - scan_sdfg.add_scalar( - name, - 
dtype=dace_utils.as_dace_type(cast(ts.ScalarType, self.storage_types[name])), - ) - # add storage to scan SDFG for output - scan_sdfg.add_array( - output_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - - # implement the lambda function as a nested SDFG that computes a single item in the scan dimension - lambda_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - input_arrays = [(scan_carry_name, scan_dtype)] + [ - (name, self.storage_types[name]) for name in input_names - ] - connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - lambda_context, lambda_outputs = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - lambda_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - lambda_input_names = [name for name, _ in input_arrays] - lambda_output_names = [connector.value.data for connector in lambda_outputs] - - input_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in lambda_input_names - ] - connectivity_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in connectivity_names - ] - input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - - scan_inner_node = compute_state.add_nested_sdfg( - lambda_context.body, - parent=scan_sdfg, - inputs=set(lambda_input_names) | set(connectivity_names), - outputs=set(lambda_output_names), - symbol_mapping=symbol_mapping, - debuginfo=lambda_context.body.debuginfo, - ) - - # connect scan SDFG to lambda inputs - for name, memlet in array_mapping.items(): - access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) - - output_names = [output_name] - assert len(lambda_output_names) == 1 - # connect lambda output to scan SDFG - for name, connector in zip(output_names, lambda_output_names): - compute_state.add_edge( - scan_inner_node, - connector, - compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), - None, - dace.Memlet(data=name, subset=scan_loop_var), - ) - - update_state.add_nedge( - update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), - dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"), - ) - - return scan_sdfg, map_ranges, scan_dim_index - - def _visit_parallel_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... 
- ], - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - - # Create an SDFG for the tasklet that computes a single item of the output domain. - index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - - input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in connectivity_names] - - context, results = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - index_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - return context.body, map_ranges, [r.value.data for r in results] - - def _visit_domain( - self, node: itir.FunCall, context: Context - ) -> tuple[tuple[str, tuple[SymbolExpr | ValueExpr, SymbolExpr | ValueExpr]], ...]: - assert isinstance(node.fun, itir.SymRef) - assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" - - bounds: list[tuple[str, tuple[ValueExpr, ValueExpr]]] = [] - - for named_range in node.args: - assert isinstance(named_range, itir.FunCall) - assert isinstance(named_range.fun, itir.SymRef) - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - translator = PythonTaskletCodegen( - self.offset_provider_type, - context, - self.use_field_canonical_representation, - ) - lb = translator.visit(lower_bound)[0] - ub = translator.visit(upper_bound)[0] - bounds.append((dimension.value, (lb, ub))) - - return tuple(bounds) - - @staticmethod - def _check_shift_offsets_are_literals(node: itir.StencilClosure): - fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) - shifts = [nd for nd in fun_calls if getattr(nd.fun, "id", "") == "shift"] - for shift in shifts: - if not all(isinstance(arg, (itir.Literal, itir.OffsetLiteral)) for arg in shift.args): - return False - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py deleted file mode 100644 index 2b2669187a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ /dev/null @@ -1,1564 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -import dataclasses -import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, TypeAlias, cast - -import dace -import numpy as np - -import gt4py.eve.codegen -from gt4py import eve -from gt4py.next import common -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import FunCall, Lambda -from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_name, - unique_var_name, -) - - -_TYPE_MAPPING = { - "float": dace.float64, - "float32": dace.float32, - "float64": dace.float64, - "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, - "int32": dace.int32, - "int64": dace.int64, - "bool": dace.bool_, -} - - -def itir_type_as_dace_type(type_: ts.TypeSpec): - # TODO(tehrengruber): this function just converts the scalar type of whatever it is given, - # let it be a field, iterator, or directly a scalar. The caller should take care of the - # extraction. - dtype: ts.TypeSpec - if isinstance(type_, ts.FieldType): - dtype = type_.dtype - elif isinstance(type_, it_ts.IteratorType): - dtype = type_.element_type - else: - dtype = type_ - assert isinstance(dtype, ts.ScalarType) - return _TYPE_MAPPING[dtype.kind.name.lower()] - - -def get_reduce_identity_value(op_name_: str, type_: Any): - if op_name_ == "plus": - init_value = type_(0) - elif op_name_ == "multiplies": - init_value = type_(1) - elif op_name_ == "minimum": - init_value = type_("inf") - elif op_name_ == "maximum": - init_value = type_("-inf") - else: - raise NotImplementedError() - - return init_value - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -# Define type of variables used for field indexing -_INDEX_DTYPE = 
_TYPE_MAPPING["int64"] - - -@dataclasses.dataclass -class SymbolExpr: - value: dace.symbolic.SymbolicType - dtype: dace.typeclass - - -@dataclasses.dataclass -class ValueExpr: - value: dace.nodes.AccessNode - dtype: dace.typeclass - - -@dataclasses.dataclass -class IteratorExpr: - field: dace.nodes.AccessNode - indices: dict[str, dace.nodes.AccessNode] - dtype: dace.typeclass - dimensions: list[str] - - -# Union of possible expression types -TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr - - -@dataclasses.dataclass -class Context: - body: dace.SDFG - state: dace.SDFGState - symbol_map: dict[str, TaskletExpr] - # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_identity: Optional[SymbolExpr] - - def __init__( - self, - body: dace.SDFG, - state: dace.SDFGState, - symbol_map: dict[str, TaskletExpr], - reduce_identity: Optional[SymbolExpr] = None, - ): - self.body = body - self.state = state - self.symbol_map = symbol_map - self.reduce_identity = reduce_identity - - -def _visit_lift_in_neighbors_reduction( - transformer: PythonTaskletCodegen, - node: itir.FunCall, - node_args: Sequence[IteratorExpr | list[ValueExpr]], - connectivity_type: common.NeighborConnectivityType, - map_entry: dace.nodes.MapEntry, - map_exit: dace.nodes.MapExit, - neighbor_index_node: dace.nodes.AccessNode, - neighbor_value_node: dace.nodes.AccessNode, -) -> list[ValueExpr]: - assert transformer.context.reduce_identity is not None - neighbor_dim = connectivity_type.codomain.value - origin_dim = connectivity_type.source_dim.value - - lifted_args: list[IteratorExpr | ValueExpr] = [] - for arg in node_args: - if isinstance(arg, IteratorExpr): - if origin_dim in arg.indices: - lifted_indices = arg.indices.copy() - lifted_indices.pop(origin_dim) - lifted_indices[neighbor_dim] = neighbor_index_node - lifted_args.append( - IteratorExpr(arg.field, lifted_indices, arg.dtype, arg.dimensions) - ) - else: - lifted_args.append(arg) - else: - lifted_args.append(arg[0]) - - lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) - assert len(inner_outputs) == 1 - inner_out_connector = inner_outputs[0].value.data - - input_nodes = {} - iterator_index_nodes = {} - lifted_index_connectors = [] - - for x, y in inner_inputs: - if isinstance(y, IteratorExpr): - field_connector, inner_index_table = x - input_nodes[field_connector] = y.field - for dim, connector in inner_index_table.items(): - if dim == neighbor_dim: - lifted_index_connectors.append(connector) - iterator_index_nodes[connector] = y.indices[dim] - else: - assert isinstance(y, ValueExpr) - input_nodes[x] = y.value - - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - parent_sdfg = transformer.context.body - parent_state = transformer.context.state - - input_mapping = { - connector: dace.Memlet.from_array(node.data, node.desc(parent_sdfg)) - for connector, node in input_nodes.items() - } - connectivity_mapping = { - name: dace.Memlet.from_array(name, parent_sdfg.arrays[name]) for name in connectivity_names - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) - - nested_sdfg_node = parent_state.add_nested_sdfg( - lift_context.body, - parent_sdfg, - inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, - 
outputs={inner_out_connector}, - symbol_mapping=symbol_mapping, - debuginfo=lift_context.body.debuginfo, - ) - - for connectivity_connector, memlet in connectivity_mapping.items(): - parent_state.add_memlet_path( - parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), - map_entry, - nested_sdfg_node, - dst_conn=connectivity_connector, - memlet=memlet, - ) - - for inner_connector, access_node in input_nodes.items(): - parent_state.add_memlet_path( - access_node, - map_entry, - nested_sdfg_node, - dst_conn=inner_connector, - memlet=input_mapping[inner_connector], - ) - - for inner_connector, access_node in iterator_index_nodes.items(): - memlet = dace.Memlet(data=access_node.data, subset="0") - if inner_connector in lifted_index_connectors: - parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) - else: - parent_state.add_memlet_path( - access_node, map_entry, nested_sdfg_node, dst_conn=inner_connector, memlet=memlet - ) - - parent_state.add_memlet_path( - nested_sdfg_node, - map_exit, - neighbor_value_node, - src_conn=inner_out_connector, - memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), - ) - - if connectivity_type.has_skip_values: - # check neighbor validity on if/else inter-state edge - # use one branch for connectivity case - start_state = lift_context.body.add_state_before( - lift_context.body.start_state, - "start", - condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}", - ) - # use the other branch for skip value case - skip_neighbor_state = lift_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), - None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), - ) - lift_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), - ) - - return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] - - -def builtin_neighbors( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - sdfg: dace.SDFG = transformer.context.body - state: dace.SDFGState = transformer.context.state - - di = dace_utils.debug_info(node, default=sdfg.debuginfo) - offset_literal, data = node_args - assert isinstance(offset_literal, itir.OffsetLiteral) - offset_dim = offset_literal.value - assert isinstance(offset_dim, str) - connectivity_type = transformer.offset_provider_type[offset_dim] - if not isinstance(connectivity_type, common.NeighborConnectivityType): - raise NotImplementedError( - "Neighbor reduction only implemented for connectivity based on neighbor tables." 
- ) - - lift_node = None - if isinstance(data, FunCall): - assert isinstance(data.fun, itir.FunCall) - fun_node = data.fun - if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": - lift_node = fun_node - lift_args = transformer.visit(data.args) - iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) - if lift_node is None: - iterator = transformer.visit(data) - assert isinstance(iterator, IteratorExpr) - field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[connectivity_type.source_dim.value] - - assert transformer.context.reduce_identity is not None - assert transformer.context.reduce_identity.dtype == iterator.dtype - - # gather the neighbors in a result array dimensioned for `max_neighbors` - neighbor_value_var = unique_var_name() - sdfg.add_array( - neighbor_value_var, - dtype=iterator.dtype, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) - - # allocate scalar to store index for direct addressing of neighbor field - neighbor_index_var = unique_var_name() - sdfg.add_scalar(neighbor_index_var, _INDEX_DTYPE, transient=True) - neighbor_index_node = state.add_access(neighbor_index_var, debuginfo=di) - - # generate unique map index name to avoid conflict with other maps inside same state - neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") - me, mx = state.add_map( - f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, - debuginfo=di, - ) - - table_name = dace_utils.connectivity_identifier(offset_dim) - shift_tasklet = state.add_tasklet( - "shift", - code=f"__result = __table[__idx, {neighbor_map_index}]", - inputs={"__table", "__idx"}, - outputs={"__result"}, - debuginfo=di, - ) - state.add_memlet_path( - state.add_access(table_name, debuginfo=di), - me, - shift_tasklet, - memlet=dace.Memlet.from_array(table_name, sdfg.arrays[table_name]), - dst_conn="__table", - ) - state.add_memlet_path( - origin_index_node, - me, - shift_tasklet, - memlet=dace.Memlet(data=origin_index_node.data, subset="0"), - dst_conn="__idx", - ) - state.add_edge( - shift_tasklet, - "__result", - neighbor_index_node, - None, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - - if lift_node is not None: - _visit_lift_in_neighbors_reduction( - transformer, - lift_node, - lift_args, - connectivity_type, - me, - mx, - neighbor_index_node, - neighbor_value_node, - ) - else: - sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) - data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" - data_access_tasklet = state.add_tasklet( - "data_access", - code=f"__data = __field[{data_access_index}] " - + ( - f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if connectivity_type.has_skip_values - else "" - ), - inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, - outputs={"__data"}, - debuginfo=di, - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=dace.Memlet.from_array(iterator.field.data, field_desc), - dst_conn="__field", - ) - for dim in iterator.dimensions: - connector = f"{dim}_v" - if dim == connectivity_type.codomain.value: - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - connector, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - else: - 
state.add_memlet_path( - iterator.indices[dim], - me, - data_access_tasklet, - dst_conn=connector, - memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"), - ) - - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index), - src_conn="__data", - ) - - if not connectivity_type.has_skip_values: - return [ValueExpr(neighbor_value_node, iterator.dtype)] - else: - """ - In case of neighbor tables with skip values, in addition to the array of neighbor values this function also - returns an array of booleans to indicate if the neighbor value is present or not. This node is only used - for neighbor reductions with lambda functions, a very specific case. For single input neighbor reductions, - the regular case, this node will be removed by the simplify pass. - """ - neighbor_valid_var = unique_var_name() - sdfg.add_array( - neighbor_valid_var, - dtype=dace.dtypes.bool, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) - - neighbor_valid_tasklet = state.add_tasklet( - f"check_valid_neighbor_{offset_dim}", - {"__idx"}, - {"__valid"}, - f"__valid = True if __idx != {neighbor_skip_value} else False", - debuginfo=di, - ) - state.add_edge( - neighbor_index_node, - None, - neighbor_valid_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - neighbor_valid_tasklet, - mx, - neighbor_valid_node, - memlet=dace.Memlet(data=neighbor_valid_var, subset=neighbor_map_index), - src_conn="__valid", - ) - return [ - ValueExpr(neighbor_value_node, iterator.dtype), - ValueExpr(neighbor_valid_node, dace.dtypes.bool), - ] - - -def builtin_can_deref( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - # first visit shift, to get set of indices for deref - can_deref_callable = node_args[0] - assert isinstance(can_deref_callable, itir.FunCall) - shift_callable = can_deref_callable.fun - assert isinstance(shift_callable, itir.FunCall) - assert isinstance(shift_callable.fun, itir.SymRef) - assert shift_callable.fun.id == "shift" - iterator = transformer._visit_shift(can_deref_callable) - - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it - if not isinstance(iterator, IteratorExpr): - assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) - # We can always deref a value expression, therefore hard-code `can_deref` to True. - # Returning a SymbolExpr would be preferable, but it requires update to type-checking. 
- result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name, debuginfo=di) - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di - ), - "_out", - result_node, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_node, dace.dtypes.bool)] - - # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] - internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", dace_debuginfo=di - ) - - -def builtin_if( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - assert len(node_args) == 3 - sdfg = transformer.context.body - current_state = transformer.context.state - is_start_state = sdfg.start_block == current_state - - # build an empty state to join true and false branches - join_state = sdfg.add_state_before(current_state, "join") - - def build_if_state(arg, state): - symbol_map = copy.deepcopy(transformer.context.symbol_map) - node_context = Context(sdfg, state, symbol_map) - node_taskgen = PythonTaskletCodegen( - transformer.offset_provider_type, - node_context, - transformer.use_field_canonical_representation, - ) - return node_taskgen.visit(arg) - - # represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state - stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state) - stmt_node = build_if_state(node_args[0], stmt_state)[0] - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - assert sdfg.arrays[stmt_node.value.data].shape == (1,) - - # visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state - tbr_state = sdfg.add_state("true_branch") - sdfg.add_edge( - stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") - ) - sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) - tbr_values = flatten_list(build_if_state(node_args[1], tbr_state)) - # - fbr_state = sdfg.add_state("false_branch") - sdfg.add_edge( - stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") - ) - sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) - fbr_values = flatten_list(build_if_state(node_args[2], fbr_state)) - - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - # make the result of the if-statement evaluation available inside current state - ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype) - - # we distinguish between select if-statements, where both true and false branches are symbolic expressions, - # and therefore do not require exclusive branch execution, and regular if-statements where at least one branch - # is a value expression, which has to be evaluated at runtime with conditional state transition - result_values = [] - assert len(tbr_values) == len(fbr_values) - for tbr_value, fbr_value in zip(tbr_values, fbr_values): - assert isinstance(tbr_value, (SymbolExpr, ValueExpr)) - assert isinstance(fbr_value, (SymbolExpr, 
ValueExpr))
-        assert tbr_value.dtype == fbr_value.dtype
-
-        if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)):
-            # both branches return symbolic expressions, therefore the if-node can be translated
-            # to a select-tasklet inside the current state
-            # TODO: use select-memlet when it becomes available in dace
-            code = f"{tbr_value.value} if _cond else {fbr_value.value}"
-            if_expr = transformer.add_expr_tasklet(
-                [(ctx_stmt_node, "_cond")], code, tbr_value.dtype, "if_select"
-            )[0]
-            result_values.append(if_expr)
-        else:
-            # at least one of the two branches contains a value expression, which should be evaluated
-            # only if the corresponding true/false condition is satisfied
-            desc = sdfg.arrays[
-                tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data
-            ]
-            var = unique_var_name()
-            if isinstance(desc, dace.data.Scalar):
-                sdfg.add_scalar(var, desc.dtype, transient=True)
-            else:
-                sdfg.add_array(var, desc.shape, desc.dtype, transient=True)
-
-            # write result to transient data container and access it in the original state
-            for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]:
-                val_node = state.add_access(var)
-                if isinstance(expr, ValueExpr):
-                    state.add_nedge(
-                        expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc)
-                    )
-                else:
-                    assert desc.shape == (1,)
-                    state.add_edge(
-                        state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"),
-                        "_out",
-                        val_node,
-                        None,
-                        dace.Memlet(var, "0"),
-                    )
-            result_values.append(ValueExpr(current_state.add_access(var), desc.dtype))
-
-    if tbr_state.is_empty() and fbr_state.is_empty():
-        # if all branches are symbolic expressions, the true/false and join states can be removed
-        # as well as the conditional state transition
-        sdfg.remove_nodes_from([join_state, tbr_state, fbr_state])
-        sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge())
-    elif tbr_state.is_empty():
-        # use direct edge from if-statement to join state for true branch
-        tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition
-        sdfg.edges_between(stmt_state, join_state)[0].condition = tbr_condition
-        sdfg.remove_node(tbr_state)
-    elif fbr_state.is_empty():
-        # use direct edge from if-statement to join state for false branch
-        fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition
-        sdfg.edges_between(stmt_state, join_state)[0].condition = fbr_condition
-        sdfg.remove_node(fbr_state)
-    else:
-        # remove direct edge from if-statement to join state
-        sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0])
-        # the if-statement condition is not used in the current state
-        current_state.remove_node(ctx_stmt_node.value)
-
-    return result_values
-
-
-def builtin_list_get(
-    transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr]
-) -> list[ValueExpr]:
-    di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo)
-    args = list(itertools.chain(*transformer.visit(node_args)))
-    assert len(args) == 2
-    # index node
-    if isinstance(args[0], SymbolExpr):
-        index_value = args[0].value
-        result_name = unique_var_name()
-        transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True)
-        result_node = transformer.context.state.add_access(result_name)
-        transformer.context.state.add_nedge(
-            args[1].value, result_node, dace.Memlet(data=args[1].value.data, subset=index_value)
-        )
-        return [ValueExpr(result_node, args[1].dtype)]
-
-    else:
-        expr_args = [(arg, f"{arg.value.data}_v") for arg in args]
-        internals = [f"{arg.value.data}_v"
for arg in args] - expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet( - expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di - ) - - -def builtin_cast( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = transformer.visit(node_args[0]) - internals = [f"{arg.value.data}_v" for arg in args] - target_type = node_args[1] - assert isinstance(target_type, itir.SymRef) - expr = _MATH_BUILTINS_MAPPING[target_type.id].format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di - ) - - -def builtin_make_const_list( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = [transformer.visit(arg)[0] for arg in node_args] - assert all(isinstance(x, (SymbolExpr, ValueExpr)) for x in args) - args_dtype = [x.dtype for x in args] - assert len(set(args_dtype)) == 1 - dtype = args_dtype[0] - - var_name = unique_var_name() - transformer.context.body.add_array(var_name, (len(args),), dtype, transient=True) - var_node = transformer.context.state.add_access(var_name, debuginfo=di) - - for i, arg in enumerate(args): - if isinstance(arg, SymbolExpr): - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - f"get_arg{i}", {}, {"val"}, f"val = {arg.value}" - ), - "val", - var_node, - None, - dace.Memlet(data=var_name, subset=f"{i}"), - ) - else: - assert arg.value.desc(transformer.context.body).shape == (1,) - transformer.context.state.add_nedge( - arg.value, - var_node, - dace.Memlet(data=arg.value.data, subset="0", other_subset=f"{i}"), - ) - - return [ValueExpr(var_node, dtype)] - - -def builtin_make_tuple( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - args = [transformer.visit(arg) for arg in node_args] - return args - - -def builtin_tuple_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - elements = transformer.visit(node_args[1]) - index = node_args[0] - if isinstance(index, itir.Literal): - return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants.") - - -_GENERAL_BUILTIN_MAPPING: dict[ - str, Callable[[PythonTaskletCodegen, itir.Expr, list[itir.Expr]], list[ValueExpr]] -] = { - "can_deref": builtin_can_deref, - "cast_": builtin_cast, - "if_": builtin_if, - "list_get": builtin_list_get, - "make_const_list": builtin_make_const_list, - "make_tuple": builtin_make_tuple, - "neighbors": builtin_neighbors, - "tuple_get": builtin_tuple_get, -} - - -class GatherLambdaSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] - _parent_symbol_map: dict[str, TaskletExpr] - - def __init__(self, sdfg, state, parent_symbol_map): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - self._parent_symbol_map = parent_symbol_map - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the lambda expression.""" - return self._symbol_map - - def _add_symbol(self, param, arg): - if isinstance(arg, ValueExpr): - # create storage in lambda sdfg - 
self._sdfg.add_scalar(param, dtype=arg.dtype) - # update table of lambda symbols - self._symbol_map[param] = ValueExpr( - self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype - ) - elif isinstance(arg, IteratorExpr): - # create storage in lambda sdfg - ndims = len(arg.dimensions) - shape, strides = new_array_symbols(param, ndims) - self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) - index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} - for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) - # update table of lambda symbols - field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - indices = { - dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) - for dim, index_arg in index_names.items() - } - self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - else: - assert isinstance(arg, SymbolExpr) - self._symbol_map[param] = arg - - def _add_tuple(self, param, args): - nodes = [] - # create storage in lambda sdfg for each tuple element - for arg in args: - var = unique_var_name() - self._sdfg.add_scalar(var, dtype=arg.dtype) - arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo) - nodes.append(ValueExpr(arg_node, arg.dtype)) - # update table of lambda symbols - self._symbol_map[param] = tuple(nodes) - - def visit_SymRef(self, node: itir.SymRef): - name = str(node.id) - if name in self._parent_symbol_map and name not in self._symbol_map: - arg = self._parent_symbol_map[name] - self._add_symbol(name, arg) - - def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): - if args is not None: - if len(node.params) == len(args): - for param, arg in zip(node.params, args): - self._add_symbol(str(param.id), arg) - else: - # implicitly make tuple - assert len(node.params) == 1 - self._add_tuple(str(node.params[0].id), args) - self.visit(node.expr) - - -class GatherOutputSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr] - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the output expression.""" - return self._symbol_map - - def __init__(self, sdfg, state): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - - def visit_SymRef(self, node: itir.SymRef): - param = str(node.id) - if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: - access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - self._symbol_map[param] = ValueExpr( - access_node, - dtype=itir_type_as_dace_type(node.type), # type: ignore[arg-type] # ensure by type inference - ) - - -@dataclasses.dataclass -class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider_type: common.OffsetProviderType - context: Context - use_field_canonical_representation: bool - - def get_sorted_field_dimensions(self, dims: Sequence[str]): - return sorted(dims) if self.use_field_canonical_representation else dims - - def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): - raise NotImplementedError() - - def visit_Lambda( - self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True - ) -> tuple[ - Context, - list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], - list[ValueExpr], - ]: - func_name = f"lambda_{abs(hash(node)):x}" - neighbor_tables = ( - get_used_connectivities(node, self.offset_provider_type) if 
use_neighbor_tables else {}
-        )
-        connectivity_names = [
-            dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
-        ]
-
-        # Create the SDFG for the lambda's body
-        lambda_sdfg = dace.SDFG(func_name)
-        lambda_sdfg.debuginfo = dace_utils.debug_info(node, default=self.context.body.debuginfo)
-        lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True)
-
-        lambda_symbols_pass = GatherLambdaSymbolsPass(
-            lambda_sdfg, lambda_state, self.context.symbol_map
-        )
-        lambda_symbols_pass.visit(node, args=args)
-
-        # Add input nodes for lambda symbols
-        inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = []
-        for sym, input_node in lambda_symbols_pass.symbol_refs.items():
-            params = [str(p.id) for p in node.params]
-            try:
-                param_index = params.index(sym)
-            except ValueError:
-                param_index = -1
-            if param_index >= 0:
-                outer_node = args[param_index]
-            else:
-                # the symbol is not among the lambda's parameters, so it is inherited from the parent scope
-                outer_node = self.context.symbol_map[sym]
-            if isinstance(input_node, IteratorExpr):
-                assert isinstance(outer_node, IteratorExpr)
-                index_params = {
-                    dim: index_node.data for dim, index_node in input_node.indices.items()
-                }
-                inputs.append(((sym, index_params), outer_node))
-            elif isinstance(input_node, ValueExpr):
-                assert isinstance(outer_node, ValueExpr)
-                inputs.append((sym, outer_node))
-            elif isinstance(input_node, tuple):
-                assert param_index >= 0
-                for i, input_node_i in enumerate(input_node):
-                    arg_i = args[param_index + i]
-                    assert isinstance(arg_i, ValueExpr)
-                    assert isinstance(input_node_i, ValueExpr)
-                    inputs.append((input_node_i.value.data, arg_i))
-
-        # Add connectivities as arrays
-        for name in connectivity_names:
-            shape, strides = new_array_symbols(name, ndim=2)
-            dtype = self.context.body.arrays[name].dtype
-            lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)
-
-        # Translate the lambda's body in its own context
-        lambda_context = Context(
-            lambda_sdfg,
-            lambda_state,
-            lambda_symbols_pass.symbol_refs,
-            reduce_identity=self.context.reduce_identity,
-        )
-        lambda_taskgen = PythonTaskletCodegen(
-            self.offset_provider_type,
-            lambda_context,
-            self.use_field_canonical_representation,
-        )
-
-        results: list[ValueExpr] = []
-        # We flatten the returned list of value expressions because the multiple outputs of a lambda
-        # should form a flat list of nodes without tuple structure. Ideally, an ITIR transformation could do this.
- node.expr.location = node.location - for expr in flatten_list(lambda_taskgen.visit(node.expr)): - if isinstance(expr, ValueExpr): - result_name = unique_var_name() - lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access( - result_name, debuginfo=lambda_sdfg.debuginfo - ) - lambda_state.add_nedge( - expr.value, result_access, dace.Memlet(data=result_access.data, subset="0") - ) - result = ValueExpr(value=result_access, dtype=expr.dtype) - else: - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet( - [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo - )[0] - lambda_sdfg.arrays[result.value.data].transient = False - results.append(result) - - # remove isolated access nodes for connectivity arrays not consumed by lambda - for sub_node in lambda_state.nodes(): - if isinstance(sub_node, dace.nodes.AccessNode): - if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0: - lambda_state.remove_node(sub_node) - - return lambda_context, inputs, results - - def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: - param = str(node.id) - value = self.context.symbol_map[param] - if isinstance(value, (ValueExpr, SymbolExpr)): - return [value] - return value - - def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: - return [SymbolExpr(node.value, itir_type_as_dace_type(node.type))] - - def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: - node.fun.location = node.location - if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": - return self._visit_deref(node) - if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): - if node.fun.fun.id == "shift": - return self._visit_shift(node) - elif node.fun.fun.id == "reduce": - return self._visit_reduce(node) - - if isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - elif builtin_name in _GENERAL_BUILTIN_MAPPING: - return self._visit_general_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - return self._visit_call(node) - - def _visit_call(self, node: itir.FunCall): - args = self.visit(node.args) - args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] - args = list(itertools.chain(*args)) - node.fun.location = node.location - func_context, func_inputs, results = self.visit(node.fun, args=args) - - nsdfg_inputs = {} - for name, value in func_inputs: - if isinstance(value, ValueExpr): - nsdfg_inputs[name] = dace.Memlet.from_array( - value.value.data, self.context.body.arrays[value.value.data] - ) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - nsdfg_inputs[field] = dace.Memlet.from_array( - value.field.data, self.context.body.arrays[value.field.data] - ) - for dim, var in indices.items(): - store = value.indices[dim].data - nsdfg_inputs[var] = dace.Memlet.from_array( - store, self.context.body.arrays[store] - ) - - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) - - symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) - - 
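
The nested-SDFG call just below is the generic DaCe embedding pattern used throughout this lowering: translate the callee into its own SDFG, attach it with `add_nested_sdfg`, and resolve the callee's free symbols (array shapes and strides) against the parent through `symbol_mapping`. A minimal self-contained sketch of that pattern, using only public DaCe API calls; the names `inner`, `outer`, `M`, and `N` are illustrative assumptions, not identifiers from this module:

    import dace

    N = dace.symbol("N")  # parent-side size symbol
    M = dace.symbol("M")  # callee-side size symbol, resolved via symbol_mapping

    # callee SDFG: y[i] = 2 * x[i] over its own symbolic size M
    inner = dace.SDFG("inner")
    inner.add_array("x", shape=(M,), dtype=dace.float64)
    inner.add_array("y", shape=(M,), dtype=dace.float64)
    istate = inner.add_state("compute")
    istate.add_mapped_tasklet(
        "double",
        map_ranges={"i": "0:M"},
        inputs={"a": dace.Memlet("x[i]")},
        code="b = 2 * a",
        outputs={"b": dace.Memlet("y[i]")},
        external_edges=True,
    )

    # parent SDFG: embed the callee and wire data through memlets
    outer = dace.SDFG("outer")
    outer.add_array("field", shape=(N,), dtype=dace.float64)
    outer.add_array("result", shape=(N,), dtype=dace.float64)
    state = outer.add_state("main")
    nsdfg = state.add_nested_sdfg(
        inner, None, inputs={"x"}, outputs={"y"}, symbol_mapping={"M": "N"}
    )
    state.add_edge(state.add_read("field"), None, nsdfg, "x", dace.Memlet("field[0:N]"))
    state.add_edge(nsdfg, "y", state.add_write("result"), None, dace.Memlet("result[0:N]"))
    outer.validate()

Validation succeeds because `symbol_mapping` rewrites every free symbol of `inner` (here `M`) to an expression the parent can evaluate, which is exactly what `map_nested_sdfg_symbols` computes above for array shapes and strides.
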
nsdfg_node = self.context.state.add_nested_sdfg( - func_context.body, - None, - inputs=set(nsdfg_inputs.keys()), - outputs=set(r.value.data for r in results), - symbol_mapping=symbol_mapping, - debuginfo=dace_utils.debug_info(node, default=func_context.body.debuginfo), - ) - - for name, value in func_inputs: - if isinstance(value, ValueExpr): - value_memlet = nsdfg_inputs[name] - self.context.state.add_edge(value.value, None, nsdfg_node, name, value_memlet) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - field_memlet = nsdfg_inputs[field] - self.context.state.add_edge(value.field, None, nsdfg_node, field, field_memlet) - for dim, var in indices.items(): - store = value.indices[dim] - idx_memlet = nsdfg_inputs[var] - self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) - self.context.state.add_edge(access, None, nsdfg_node, var, memlet) - - result_exprs = [] - for result in results: - name = unique_var_name() - self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) - result_exprs.append(ValueExpr(result_access, result.dtype)) - memlet = dace.Memlet.from_array(name, self.context.body.arrays[name]) - self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) - - return result_exprs - - def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # already a list of ValueExpr - return iterator - - sorted_dims = self.get_sorted_field_dimensions(iterator.dimensions) - if all([dim in iterator.indices for dim in iterator.dimensions]): - # The deref iterator has index values on all dimensions: the result will be a scalar - args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet( - list(zip(args, internals)), expr, iterator.dtype, "deref", dace_debuginfo=di - ) - - else: - dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] - assert len(dims_not_indexed) == 1 - offset = dims_not_indexed[0] - offset_provider_type = self.offset_provider_type[offset] - assert isinstance(offset_provider_type, common.NeighborConnectivityType) - neighbor_dim = offset_provider_type.codomain.value - - result_name = unique_var_name() - self.context.body.add_array( - result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True - ) - result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name, debuginfo=di) - - deref_connectors = ["_inp"] + [ - f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices - ] - deref_nodes = [iterator.field] + [ - iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices - ] - deref_memlets = [ - dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body)) - ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]] - - # we create a mapped tasklet for array slicing - index_name = 
unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} - src_subset = ",".join( - [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] - ) - self.context.state.add_mapped_tasklet( - "deref", - map_ranges, - inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, - outputs={"_out": dace.Memlet.from_array(result_name, result_array)}, - code=f"_out[{index_name}] = _inp[{src_subset}]", - external_edges=True, - input_nodes={node.data: node for node in deref_nodes}, - output_nodes={result_name: result_node}, - debuginfo=di, - ) - return [ValueExpr(result_node, iterator.dtype)] - - def _split_shift_args( - self, args: list[itir.Expr] - ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: - pairs = [args[i : i + 2] for i in range(0, len(args), 2)] - assert len(pairs) >= 1 - assert all(len(pair) == 2 for pair in pairs) - return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None - - def _make_shift_for_rest(self, rest, iterator): - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), - args=[iterator], - location=iterator.location, - ) - - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - shift = node.fun - assert isinstance(shift, itir.FunCall) - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR pass is able to catch it - assert isinstance(iterator, list) and len(iterator) == 1 - assert isinstance(iterator[0], ValueExpr) - return iterator - - assert isinstance(tail[0], itir.OffsetLiteral) - offset_dim = tail[0].value - assert isinstance(offset_dim, str) - offset_node = self.visit(tail[1])[0] - assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - - if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): - offset_provider_type = cast( - common.NeighborConnectivityType, self.offset_provider_type[offset_dim] - ) # ensured by condition - connectivity = self.context.state.add_access( - dace_utils.connectivity_identifier(offset_dim), debuginfo=di - ) - - shifted_dim_tag = offset_provider_type.source_dim.value - target_dim_tag = offset_provider_type.codomain.value - args = [ - ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), - offset_node, - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" - else: - shifted_dim = self.offset_provider_type[offset_dim] - assert isinstance(shifted_dim, common.Dimension) - - shifted_dim_tag = shifted_dim.value - target_dim_tag = shifted_dim_tag - args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {internals[1]}" - - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - del shifted_index[shifted_dim_tag] - shifted_index[target_dim_tag] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) 
- - def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - offset = node.value - assert isinstance(offset, int) - offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var, debuginfo=di) - tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di - ) - self.context.state.add_edge( - tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0") - ) - return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] - - def _visit_reduce(self, node: itir.FunCall): - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - reduce_dtype = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - - if len(node.args) == 1: - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - reduce_identity = node.fun.args[1] - assert isinstance(reduce_identity, itir.Literal) - - # set reduction state - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args[0]) - - assert 1 <= len(args) <= 2 - reduce_input_node = args[0].value - - else: - assert isinstance(node.fun, itir.FunCall) - assert isinstance(node.fun.args[0], itir.Lambda) - fun_node = node.fun.args[0] - assert isinstance(fun_node.expr, itir.FunCall) - - op_name = fun_node.expr.fun - assert isinstance(op_name, itir.SymRef) - reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype) - - # set reduction state in visit context - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args) - - # clear context - self.context.reduce_identity = None - - # check that all neighbor expressions have the same shape - args_shape = [ - arg[0].value.desc(self.context.body).shape - for arg in args - if arg[0].value.desc(self.context.body).shape != (1,) - ] - assert len(set(args_shape)) == 1 - nreduce_shape = args_shape[0] - - input_args = [arg[0] for arg in args] - input_valid_args = [arg[1] for arg in args if len(arg) == 2] - - assert len(nreduce_shape) == 1 - nreduce_index = unique_name("_i") - nreduce_domain = {nreduce_index: f"0:{nreduce_shape[0]}"} - - reduce_input_name = unique_var_name() - self.context.body.add_array( - reduce_input_name, nreduce_shape, reduce_dtype, transient=True - ) - - lambda_node = itir.Lambda( - expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location - ) - lambda_context, inner_inputs, inner_outputs = self.visit( - lambda_node, args=input_args, use_neighbor_tables=False - ) - - input_mapping = { - param: ( - dace.Memlet(data=arg.value.data, subset="0") - if arg.value.desc(self.context.body).shape == (1,) - else dace.Memlet(data=arg.value.data, subset=nreduce_index) - ) - for (param, _), arg in zip(inner_inputs, input_args) - } - output_mapping = { - inner_outputs[0].value.data: dace.Memlet( - data=reduce_input_name, subset=nreduce_index - ) - } - symbol_mapping = map_nested_sdfg_symbols( - self.context.body, lambda_context.body, input_mapping - ) - - if input_valid_args: - """ - The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. 
- These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select - the result of field access or the identity value, respectively. - If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the select tasklet below is also skipped. - """ - input_args.append(input_valid_args[0]) - input_valid_node = input_valid_args[0].value - lambda_output_node = inner_outputs[0].value - # add input connector to nested sdfg - lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) - input_mapping["_valid_neighbor"] = dace.Memlet( - data=input_valid_node.data, subset=nreduce_index - ) - # add select tasklet before writing to output node - # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API - output_edge = lambda_context.state.in_edges(lambda_output_node)[0] - assert isinstance( - lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar - ) - select_tasklet = lambda_context.state.add_tasklet( - "neighbor_select", - {"_inp", "_valid"}, - {"_out"}, - f"_out = _inp if _valid else {reduce_identity}", - ) - lambda_context.state.add_edge( - output_edge.src, - None, - select_tasklet, - "_inp", - dace.Memlet(data=output_edge.src.data, subset="0"), - ) - lambda_context.state.add_edge( - lambda_context.state.add_access("_valid_neighbor"), - None, - select_tasklet, - "_valid", - dace.Memlet(data="_valid_neighbor", subset="0"), - ) - lambda_context.state.add_edge( - select_tasklet, - "_out", - lambda_output_node, - None, - dace.Memlet(data=lambda_output_node.data, subset="0"), - ) - lambda_context.state.remove_edge(output_edge) - - reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) - - nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( - self.context.state, - sdfg=lambda_context.body, - map_ranges=nreduce_domain, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - input_nodes={arg.value.data: arg.value for arg in input_args}, - output_nodes={reduce_input_name: reduce_input_node}, - debuginfo=di, - ) - - reduce_input_desc = reduce_input_node.desc(self.context.body) - - result_name = unique_var_name() - # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node - self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") - reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) - self.context.state.add_nedge( - reduce_input_node, - reduce_node, - dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), - ) - self.context.state.add_nedge( - reduce_node, result_access, dace.Memlet(data=result_name, subset="0") - ) - - return [ValueExpr(result_access, reduce_dtype)] - - def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = flatten_list(self.visit(node.args)) - expr_args = [ - (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) - ] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args - ] - expr = fmt.format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return 
self.add_expr_tasklet( - expr_args, - expr, - type_, - "numeric", - dace_debuginfo=dace_utils.debug_info(node, default=self.context.body.debuginfo), - ) - - def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - expr_func = _GENERAL_BUILTIN_MAPPING[str(node.fun.id)] - return expr_func(self, node, node.args) - - def add_expr_tasklet( - self, - args: list[tuple[ValueExpr, str]], - expr: str, - result_type: Any, - name: str, - dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, - ) -> list[ValueExpr]: - di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - expr_tasklet = self.context.state.add_tasklet( - name=name, - inputs={internal for _, internal in args}, - outputs={"__result"}, - code=f"__result = {expr}", - debuginfo=di, - ) - - for arg, internal in args: - edges = self.context.state.in_edges(expr_tasklet) - used = False - for edge in edges: - if edge.dst_conn == internal: - used = True - break - if used: - continue - elif not isinstance(arg, SymbolExpr): - memlet = dace.Memlet.from_array( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - - memlet = dace.Memlet(data=result_access.data, subset="0") - self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) - - return [ValueExpr(result_access, result_type)] - - -def is_scan(node: itir.Node) -> bool: - return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") - - -def closure_to_tasklet_sdfg( - node: itir.StencilClosure, - offset_provider_type: common.OffsetProviderType, - domain: dict[str, str], - inputs: Sequence[tuple[str, ts.TypeSpec]], - connectivities: Sequence[tuple[dace.ndarray, str]], - use_field_canonical_representation: bool, -) -> tuple[Context, Sequence[ValueExpr]]: - body = dace.SDFG("tasklet_toplevel") - body.debuginfo = dace_utils.debug_info(node) - state = body.add_state("tasklet_toplevel_entry", True) - symbol_map: dict[str, TaskletExpr] = {} - - idx_accesses = {} - for dim, idx in domain.items(): - name = f"{idx}_value" - body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet( - f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo - ) - access = state.add_access(name, debuginfo=body.debuginfo) - idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) - for name, ty in inputs: - if isinstance(ty, ts.FieldType): - ndim = len(ty.dims) - shape, strides = new_array_symbols(name, ndim) - dims = [dim.value for dim in ty.dims] - dtype = dace_utils.as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name, debuginfo=body.debuginfo) - indices = {dim: idx_accesses[dim] for dim in domain.keys()} - symbol_map[name] = IteratorExpr(field, indices, dtype, dims) - else: - assert isinstance(ty, ts.ScalarType) - dtype = dace_utils.as_dace_type(ty) - body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) - for arr, name in connectivities: - shape, strides = new_array_symbols(name, ndim=2) - body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) - - context = Context(body, state, 
symbol_map) - translator = PythonTaskletCodegen( - offset_provider_type, context, use_field_canonical_representation - ) - - args = [itir.SymRef(id=name) for name, _ in inputs] - if is_scan(node.stencil): - stencil = cast(FunCall, node.stencil) - assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda( - expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location - ) - fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) - else: - fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) - - results = translator.visit(fun_node) - for r in results: - context.body.arrays[r.value.data].transient = False - - return context, results diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py deleted file mode 100644 index 72bb32f003..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ /dev/null @@ -1,149 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import itertools -from typing import Any - -import dace - -import gt4py.next.iterator.ir as itir -from gt4py import eve -from gt4py.next import common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils - - -def get_used_connectivities( - node: itir.Node, offset_provider_type: common.OffsetProviderType -) -> dict[str, common.NeighborConnectivityType]: - connectivities = dace_utils.filter_connectivity_types(offset_provider_type) - offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) - return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} - - -def map_nested_sdfg_symbols( - parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] -) -> dict[str, str]: - symbol_mapping: dict[str, str] = {} - for param, arg in array_mapping.items(): - arg_array = parent_sdfg.arrays[arg.data] - param_array = nested_sdfg.arrays[param] - if not isinstance(param_array, dace.data.Scalar): - assert len(arg.subset.size()) == len(param_array.shape) - for arg_shape, param_shape in zip(arg.subset.size(), param_array.shape): - if isinstance(param_shape, dace.symbol): - symbol_mapping[str(param_shape)] = str(arg_shape) - assert len(arg_array.strides) == len(param_array.strides) - for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): - if isinstance(param_stride, dace.symbol): - symbol_mapping[str(param_stride)] = str(arg_stride) - else: - assert arg.subset.num_elements() == 1 - for sym in nested_sdfg.free_symbols: - if str(sym) not in symbol_mapping: - symbol_mapping[str(sym)] = str(sym) - return symbol_mapping - - -def add_mapped_nested_sdfg( - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, -) -> tuple[dace.nodes.NestedSDFG, 
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py
deleted file mode 100644
index 653ed4719d..0000000000
--- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# GT4Py - GridTools Framework
-#
-# Copyright (c) 2014-2024, ETH Zurich
-# All rights reserved.
-#
-# Please, refer to the LICENSE file in the root directory.
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
-import dataclasses
-import functools
-from typing import Callable, Optional, Sequence
-
-import dace
-import factory
-
-from gt4py._core import definitions as core_defs
-from gt4py.next import common, config
-from gt4py.next.iterator import ir as itir
-from gt4py.next.iterator.transforms import program_to_fencil
-from gt4py.next.otf import languages, recipes, stages, step_types, workflow
-from gt4py.next.otf.binding import interface
-from gt4py.next.otf.languages import LanguageSettings
-from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow
-from gt4py.next.type_system import type_specifications as ts
-
-from . import build_sdfg_from_itir
-
-
-@dataclasses.dataclass(frozen=True)
-class DaCeTranslator(
-    workflow.ChainableWorkflowMixin[
-        stages.CompilableProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings]
-    ],
-    step_types.TranslationStep[languages.SDFG, languages.LanguageSettings],
-):
-    auto_optimize: bool = False
-    device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
-    symbolic_domain_sizes: Optional[dict[str, str]] = None
-    temporary_extraction_heuristics: Optional[
-        Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
-    ] = None
-    use_field_canonical_representation: bool = False
-
-    def _language_settings(self) -> languages.LanguageSettings:
-        return languages.LanguageSettings(
-            formatter_key="", formatter_style="", file_extension="sdfg"
-        )
-
-    def generate_sdfg(
-        self,
-        program: itir.FencilDefinition,
-        arg_types: Sequence[ts.TypeSpec],
-        offset_provider_type: common.OffsetProviderType,
-        column_axis: Optional[common.Dimension],
-    ) -> dace.SDFG:
-        on_gpu = (
-            True
-            if self.device_type in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]
-            else False
-        )
-
-        return build_sdfg_from_itir(
-            program,
-            arg_types,
-            offset_provider_type=offset_provider_type,
-            auto_optimize=self.auto_optimize,
-            on_gpu=on_gpu,
-            column_axis=column_axis,
-            symbolic_domain_sizes=self.symbolic_domain_sizes,
-            temporary_extraction_heuristics=self.temporary_extraction_heuristics,
-            load_sdfg_from_file=False,
-            save_sdfg=False,
-            use_field_canonical_representation=self.use_field_canonical_representation,
-        )
-
-    def __call__(
-        self, inp: stages.CompilableProgram
-    ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]:
-        """Generate DaCe SDFG file from the ITIR definition."""
-        program: itir.FencilDefinition | itir.Program = inp.data
-
-        if isinstance(program, itir.Program):
-            program = program_to_fencil.program_to_fencil(program)
-
-        sdfg = self.generate_sdfg(
-            program,
-            inp.args.args,
-            common.offset_provider_to_type(inp.args.offset_provider),
-            inp.args.column_axis,
-        )
-
-        param_types = tuple(
-            interface.Parameter(param, arg) for param, arg in zip(sdfg.arg_names, inp.args.args)
-        )
-
-        module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = (
-            stages.ProgramSource(
-                entry_point=interface.Function(program.id, param_types),
-                source_code=sdfg.to_json(),
-                library_deps=tuple(),
-                language=languages.SDFG,
-                language_settings=self._language_settings(),
-                implicit_domain=inp.data.implicit_domain,
-            )
-        )
-        return module
-
-
-class DaCeTranslationStepFactory(factory.Factory):
-    class Meta:
-        model = DaCeTranslator
-
-
-def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource:
-    return stages.CompilableSource(program_source=inp, binding_source=None)
-
-
-class DaCeWorkflowFactory(factory.Factory):
-    class Meta:
-        model = recipes.OTFCompileWorkflow
-
-    class Params:
-        device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
-        cmake_build_type: config.CMakeBuildType = factory.LazyFunction(
-            lambda: config.CMAKE_BUILD_TYPE
-        )
-        use_field_canonical_representation: bool = False
-
-    translation = factory.SubFactory(
-        DaCeTranslationStepFactory,
-        device_type=factory.SelfAttribute("..device_type"),
-        use_field_canonical_representation=factory.SelfAttribute(
-            "..use_field_canonical_representation"
-        ),
-    )
-    bindings = _no_bindings
-    compilation = factory.SubFactory(
-        dace_workflow.DaCeCompilationStepFactory,
-        cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME),
-        cmake_build_type=factory.SelfAttribute("..cmake_build_type"),
-    )
-    decoration = factory.LazyAttribute(
-        lambda o: functools.partial(
-            dace_workflow.convert_args,
-            device=o.device_type,
-            use_field_canonical_representation=o.use_field_canonical_representation,
-        )
-    )
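The deleted workflow module leans on factory_boy: `Meta.model` names the callable to instantiate, declared attributes become default constructor arguments, and `SelfAttribute`/`LazyFunction`/`SubFactory` resolve values when the factory is invoked. A minimal sketch of that mechanism, with a hypothetical `Step` dataclass standing in for `DaCeTranslator`:

import dataclasses

import factory  # factory_boy

@dataclasses.dataclass(frozen=True)
class Step:  # hypothetical stand-in for DaCeTranslator
    device_type: str = "cpu"
    auto_optimize: bool = False

class StepFactory(factory.Factory):
    class Meta:
        model = Step

    device_type = "gpu"  # factory-level default, overridable per call

step = StepFactory(auto_optimize=True)
assert step == Step(device_type="gpu", auto_optimize=True)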
diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py
index 349d3e9f70..1593ab3ba6 100644
--- a/tests/next_tests/definitions.py
+++ b/tests/next_tests/definitions.py
@@ -11,11 +11,10 @@
 import dataclasses
 import enum
 import importlib
-from typing import Final, Optional, Protocol
 
 import pytest
 
-from gt4py.next import allocators as next_allocators, backend as next_backend
+from gt4py.next import allocators as next_allocators
 
 
 # Skip definitions
@@ -67,10 +66,10 @@ class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum):
 
 
 class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
-    DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu"
-    DACE_GPU = "gt4py.next.program_processors.runners.dace.itir_gpu"
-    GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu"
-    GTIR_DACE_GPU = "gt4py.next.program_processors.runners.dace.gtir_gpu"
+    DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu"
+    DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu"
+    DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt"
+    DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt"
 
 
 class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
@@ -139,21 +138,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
     (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE),
     (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE),
 ]
-DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
-    (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
-    (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE),
-    (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
-    (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
-]
-GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
+DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
     (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
     (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
     (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
@@ -189,10 +174,16 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
 BACKEND_SKIP_TEST_MATRIX = {
     EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST,
     EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST,
-    OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST,
-    OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST,
-    OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST,
-    OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST,
+    OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST
+    + [
+        (ALL, SKIP, UNSUPPORTED_MESSAGE)
+    ],  # TODO(edopao): Enable once the optimization pipeline is merged
+    OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST
+    + [
+        (ALL, SKIP, UNSUPPORTED_MESSAGE)
+    ],  # TODO(edopao): Enable once the optimization pipeline is merged.
+    OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST,
+    OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST,
     ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST
     + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)],
     ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST
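Each entry of BACKEND_SKIP_TEST_MATRIX pairs a feature marker with an action (XFAIL or SKIP) and a reason string. How such a matrix reaches pytest is repository plumbing not shown here; a hypothetical sketch of the general mechanism (the names SKIP_MATRIX and apply_skip_matrix are illustrative only, not the real helpers):

import pytest

UNSUPPORTED = "not supported by this backend"
SKIP_MATRIX = {"dace_cpu": [("uses_scan", pytest.mark.xfail, UNSUPPORTED)]}

def apply_skip_matrix(item: pytest.Item, backend_id: str) -> None:
    # Attach xfail/skip marks to collected items whose feature markers
    # appear in the matrix for the active backend.
    for marker_name, action, reason in SKIP_MATRIX.get(backend_id, []):
        if item.get_closest_marker(marker_name):
            item.add_marker(action(reason=reason))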
diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
index f5646c71e4..08904c06f3 100644
--- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
+++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py
@@ -6,14 +6,11 @@
 # Please, refer to the LICENSE file in the root directory.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from types import ModuleType
-from typing import Optional
-
 import numpy as np
 import pytest
 
 import gt4py.next as gtx
-from gt4py.next import backend as next_backend, common
+from gt4py.next import allocators as gtx_allocators, common as gtx_common
 
 from next_tests.integration_tests import cases
 from next_tests.integration_tests.cases import cartesian_case, unstructured_case
@@ -34,24 +31,22 @@
 try:
     import dace
 
-    from gt4py.next.program_processors.runners.dace import (
-        itir_cpu as run_dace_cpu,
-        itir_gpu as run_dace_gpu,
-    )
 except ImportError:
     dace: Optional[ModuleType] = None  # type:ignore[no-redef]
-    run_dace_cpu: Optional[next_backend.Backend] = None
-    run_dace_gpu: Optional[next_backend.Backend] = None
 
 
 pytestmark = pytest.mark.requires_dace
 
 
 def test_sdfgConvertible_laplap(cartesian_case):
-    # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional
-    if cartesian_case.backend not in [run_dace_cpu, run_dace_gpu]:
+    if not cartesian_case.backend or "dace" not in cartesian_case.backend.name:
         pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")
 
-    if cartesian_case.backend == run_dace_gpu:
+    # TODO(ricoh): enable test after adding GTIR support
+    pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.")
+
+    allocator, backend = unstructured_case.allocator, unstructured_case.backend
+
+    if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE):
         import cupy as xp
     else:
         import numpy as xp
@@ -64,13 +59,13 @@ def sdfg():
         tmp_field = xp.empty_like(out_field)
         lap_program.with_grid_type(cartesian_case.grid_type).with_backend(
-            cartesian_case.backend
-        ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))(
+            backend
+        ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))(
             in_field, tmp_field
         )
         lap_program.with_grid_type(cartesian_case.grid_type).with_backend(
-            cartesian_case.backend
-        ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))(
+            backend
+        ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))(
             tmp_field, out_field
         )
@@ -94,13 +89,15 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed
 
 @pytest.mark.uses_unstructured_shift
 def test_sdfgConvertible_connectivities(unstructured_case):
-    # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional
-    if unstructured_case.backend not in [run_dace_cpu, run_dace_gpu]:
+    if not unstructured_case.backend or "dace" not in unstructured_case.backend.name:
         pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs")
 
+    # TODO(ricoh): enable test after adding GTIR support
+    pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.")
+
     allocator, backend = unstructured_case.allocator, unstructured_case.backend
-    if backend == run_dace_gpu:
+
+    if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE):
         import cupy as xp
 
         dace_storage_type = dace.StorageType.GPU_Global
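After this change the orchestration tests identify DaCe backends by name and pick the array namespace from the allocator, rather than comparing against imported backend objects. Condensed into a hypothetical helper (_array_module_for is not part of the test suite; `case` is assumed to expose `.backend` and `.allocator` as above):

import pytest

from gt4py.next import allocators as gtx_allocators

def _array_module_for(case):
    # Skip unless the case runs on a DaCe backend, mirroring the guard above.
    if not case.backend or "dace" not in case.backend.name:
        pytest.skip("DaCe-related test")
    # cupy for GPU-allocating backends, numpy otherwise.
    if gtx_allocators.is_field_allocator_factory_for(case.allocator, gtx_allocators.CUPY_DEVICE):
        import cupy as xp
    else:
        import numpy as xp
    return xp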
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
index 794dd06709..1147f4bc3e 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
@@ -66,11 +66,11 @@ def __gt_allocator__(
             marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu),
         ),
         pytest.param(
-            next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU,
+            next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT,
             marks=pytest.mark.requires_dace,
         ),
         pytest.param(
-            next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_GPU,
+            next_tests.definitions.OptionalProgramBackendId.DACE_GPU_NO_OPT,
             marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu),
         ),
     ],
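The hunk above only swaps backend ids inside `pytest.param` entries; the `marks` arguments are what gate collection on optional dependencies and hardware. A minimal sketch of the same pattern, with illustrative ids and a trivial test body:

import pytest

@pytest.mark.parametrize(
    "backend_id",
    [
        pytest.param("dace_cpu_noopt", marks=pytest.mark.requires_dace),
        pytest.param(
            "dace_gpu_noopt",
            marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu),
        ),
    ],
)
def test_backend_id_is_string(backend_id):
    assert isinstance(backend_id, str)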