diff --git a/AUTHORS.md b/AUTHORS.md index 1aea538625..52091a2f59 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -23,6 +23,7 @@ - Madonna, Alberto. ETH Zurich - CSCS - Mariotti, Kean. ETH Zurich - CSCS - Müller, Christoph. MeteoSwiss +- Müller, Philip. ETH Zurich - CSCS - Osuna, Carlos. MeteoSwiss - Paone, Edoardo. ETH Zurich - CSCS - Röthlin, Matthias. MeteoSwiss diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py new file mode 100644 index 0000000000..9a7d1316ed --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -0,0 +1,43 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Transformation and optimization pipeline for the DaCe backend in GT4Py. + +Please also see [this HackMD document](https://hackmd.io/@gridtools/rklwk4OIR#Requirements-on-SDFG) +that explains the general structure and requirements on the SDFG. +""" + +from .auto_opt import dace_auto_optimize, gt_auto_optimize, gt_set_iteration_order, gt_simplify +from .gpu_utils import ( + GPUSetBlockSize, + SerialMapPromoterGPU, + gt_gpu_transformation, + gt_set_gpu_blocksize, +) +from .k_blocking import KBlocking +from .map_fusion_parallel import ParallelMapFusion +from .map_fusion_serial import SerialMapFusion +from .map_orderer import MapIterationOrder +from .map_promoter import SerialMapPromoter + + +__all__ = [ + "GPUSetBlockSize", + "KBlocking", + "MapIterationOrder", + "SerialMapFusion", + "SerialMapPromoter", + "SerialMapPromoterGPU", + "ParallelMapFusion", + "dace_auto_optimize", + "gt_auto_optimize", + "gt_gpu_transformation", + "gt_set_iteration_order", + "gt_set_gpu_blocksize", + "gt_simplify", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py new file mode 100644 index 0000000000..3b3dcce421 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -0,0 +1,354 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Fast access to the auto optimization on DaCe.""" + +from typing import Any, Optional, Sequence + +import dace +from dace.transformation import dataflow as dace_dataflow +from dace.transformation.auto import auto_optimize as dace_aoptimize + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +__all__ = [ + "dace_auto_optimize", + "gt_simplify", + "gt_set_iteration_order", + "gt_auto_optimize", +] + + +def dace_auto_optimize( + sdfg: dace.SDFG, + device: dace.DeviceType = dace.DeviceType.CPU, + use_gpu_storage: bool = True, + **kwargs: Any, +) -> dace.SDFG: + """This is a convenient wrapper arround DaCe's `auto_optimize` function. + + Args: + sdfg: The SDFG that should be optimized in place. + device: the device for which optimizations should be done, defaults to CPU. + use_gpu_storage: Assumes that the SDFG input is already on the GPU. + This parameter is `False` in DaCe but here is changed to `True`. + kwargs: Are forwarded to the underlying auto optimized exposed by DaCe. + """ + return dace_aoptimize.auto_optimize( + sdfg, + device=device, + use_gpu_storage=use_gpu_storage, + **kwargs, + ) + + +def gt_simplify( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, + skip: Optional[set[str]] = None, +) -> Any: + """Performs simplifications on the SDFG in place. + + Instead of calling `sdfg.simplify()` directly, you should use this function, + as it is specially tuned for GridTool based SDFGs. + + Args: + sdfg: The SDFG to optimize. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + skip: List of simplify passes that should not be applied. + + Note: + The reason for this function is that we can influence how simplify works. + Since some parts in simplify might break things in the SDFG. + However, currently nothing is customized yet, and the function just calls + the simplification pass directly. + """ + from dace.transformation.passes.simplify import SimplifyPass + + return SimplifyPass( + validate=validate, + validate_all=validate_all, + verbose=False, + skip=skip, + ).apply_pass(sdfg, {}) + + +def gt_set_iteration_order( + sdfg: dace.SDFG, + leading_dim: gtx_common.Dimension, + validate: bool = True, + validate_all: bool = False, +) -> Any: + """Set the iteration order of the Maps correctly. + + Modifies the order of the Map parameters such that `leading_dim` + is the fastest varying one, the order of the other dimensions in + a Map is unspecific. `leading_dim` should be the dimensions were + the stride is one. + + Args: + sdfg: The SDFG to process. + leading_dim: The leading dimensions. + validate: Perform validation during the steps. + validate_all: Perform extensive validation. + """ + return sdfg.apply_transformations_once_everywhere( + gtx_transformations.MapIterationOrder( + leading_dim=leading_dim, + ) + ) + + +def gt_auto_optimize( + sdfg: dace.SDFG, + gpu: bool, + leading_dim: Optional[gtx_common.Dimension] = None, + aggressive_fusion: bool = True, + make_persistent: bool = True, + gpu_block_size: Optional[Sequence[int | str] | str] = None, + block_dim: Optional[gtx_common.Dimension] = None, + blocking_size: int = 10, + reuse_transients: bool = False, + validate: bool = True, + validate_all: bool = False, + **kwargs: Any, +) -> dace.SDFG: + """Performs GT4Py specific optimizations on the SDFG in place. + + The auto optimization works in different phases, that focuses each on + different aspects of the SDFG. The initial SDFG is assumed to have a + very large number of rather simple Maps. + + 1. Some general simplification transformations, beyond classical simplify, + are applied to the SDFG. + 2. In this phase the function tries to reduce the number of maps. This + process mostly relies on the map fusion transformation. If + `aggressive_fusion` is set the function will also promote certain Maps, to + make them fusable. For this it will add dummy dimensions. However, currently + the function will only add horizonal dimensions. + In this phase some optimizations inside the bigger kernels themselves might + be applied as well. + 3. After the function created big kernels it will apply some optimization, + inside the kernels itself. For example fuse maps inside them. + 4. Afterwards it will process the map ranges and iteration order. For this + the function assumes that the dimension indicated by `leading_dim` is the + one with stride one. + 5. If requested the function will now apply blocking, on the dimension indicated + by `leading_dim`. (The reason that it is not done in the kernel optimization + phase is a restriction dictated by the implementation.) + 6. If requested the SDFG will be transformed to GPU. For this the + `gt_gpu_transformation()` function is used, that might apply several other + optimizations. + 7. Afterwards some general transformations to the SDFG are applied. + This includes: + - Use fast implementation for library nodes. + - Move small transients to stack. + - Make transients persistent (if requested). + - Apply DaCe's `TransientReuse` transformation (if requested). + + Args: + sdfg: The SDFG that should be optimized in place. + gpu: Optimize for GPU or CPU. + leading_dim: Leading dimension, indicates where the stride is 1. + aggressive_fusion: Be more aggressive in fusion, will lead to the promotion + of certain maps. + make_persistent: Turn all transients to persistent lifetime, thus they are + allocated over the whole lifetime of the program, even if the kernel exits. + Thus the SDFG can not be called by different threads. + gpu_block_size: The thread block size for maps in GPU mode, currently only + one for all. + block_dim: On which dimension blocking should be applied. + blocking_size: How many elements each block should process. + reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. + validate: Perform validation during the steps. + validate_all: Perform extensive validation. + + Todo: + - Make sure that `SDFG.simplify()` is not called indirectly, by temporarily + overwriting it with `gt_simplify()`. + - Specify arguments to set the size of GPU thread blocks depending on the + dimensions. I.e. be able to use a different size for 1D than 2D Maps. + - Add a parallel version of Map fusion. + - Implement some model to further guide to determine what we want to fuse. + Something along the line "Fuse if operational intensity goes up, but + not if we have too much internal space (register pressure). + - Create a custom array elimination pass that honors rule 1. + - Check if a pipeline could be used to speed up some computations. + """ + device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU + + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + dace.Config.set("store_history", value=False) + + # TODO(phimuell): Should there be a zeroth phase, in which we generate + # a chanonical form of the SDFG, for example move all local maps + # to internal serial maps, such that they do not block fusion? + + # Phase 1: Initial Cleanup + gt_simplify(sdfg) + sdfg.apply_transformations_repeated( + [ + dace_dataflow.TrivialMapElimination, + # TODO(phimuell): Investigate if these two are appropriate. + dace_dataflow.MapReduceFusion, + dace_dataflow.MapWCRFusion, + ], + validate=validate, + validate_all=validate_all, + ) + + # Compute the SDFG hash to see if something has changed. + sdfg_hash = sdfg.hash_sdfg() + + # Phase 2: Kernel Creation + # We will now try to reduce the number of kernels and create large Maps/kernels. + # For this we essentially use Map fusion. We do this is a loop because + # after a graph modification followed by simplify new fusing opportunities + # might arise. We use the hash of the SDFG to detect if we have reached a + # fix point. + # TODO(phimuell): Find a better upper bound for the starvation protection. + for _ in range(100): + # Use map fusion to reduce their number and to create big kernels + # TODO(phimuell): Use a cost measurement to decide if fusion should be done. + # TODO(phimuell): Add parallel fusion transformation. Should it run after + # or with the serial one? + sdfg.apply_transformations_repeated( + [ + gtx_transformations.SerialMapFusion( + only_toplevel_maps=True, + ), + gtx_transformations.ParallelMapFusion( + only_toplevel_maps=True, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + # Now do some cleanup task, that may enable further fusion opportunities. + # Note for performance reasons simplify is deferred. + phase2_cleanup = [] + phase2_cleanup.append(dace_dataflow.TrivialTaskletElimination()) + + # TODO(phimuell): Should we do this all the time or only once? (probably the later) + # TODO(phimuell): Add a criteria to decide if we should promote or not. + phase2_cleanup.append( + gtx_transformations.SerialMapPromoter( + only_toplevel_maps=True, + promote_vertical=True, + promote_horizontal=False, + promote_local=False, + ) + ) + + sdfg.apply_transformations_once_everywhere( + phase2_cleanup, + validate=validate, + validate_all=validate_all, + ) + + # Use the hash to determine if the transformations did modify the SDFG. + # If not we have optimized the SDFG as much as we could, in this phase. + old_sdfg_hash = sdfg_hash + sdfg_hash = sdfg.hash_sdfg() + if old_sdfg_hash == sdfg_hash: + break + + # The SDFG was modified by the transformations above. The SDFG was + # modified. Call Simplify and try again to further optimize. + gt_simplify(sdfg) + + else: + raise RuntimeWarning("Optimization of the SDFG did not converge.") + + # Phase 3: Optimizing the kernels themselves. + # Currently this only applies fusion inside Maps. + sdfg.apply_transformations_repeated( + [ + gtx_transformations.SerialMapFusion( + only_inner_maps=True, + ), + # TODO(phimuell): This might be a bit to aggressive, there should be + # more control about what to fuse; Serial fusing here should be good + # most of the times. + gtx_transformations.ParallelMapFusion( + only_inner_maps=True, + ), + ], + validate=validate, + validate_all=validate_all, + ) + gt_simplify(sdfg) + + # Phase 4: Iteration Space + # This essentially ensures that the stride 1 dimensions are handled + # by the inner most loop nest (CPU) or x-block (GPU) + if leading_dim is not None: + gt_set_iteration_order( + sdfg=sdfg, + leading_dim=leading_dim, + validate=validate, + validate_all=validate_all, + ) + + # Phase 5: Apply blocking + if block_dim is not None: + sdfg.apply_transformations_once_everywhere( + gtx_transformations.KBlocking( + blocking_size=blocking_size, + block_dim=block_dim, + ), + validate=validate, + validate_all=validate_all, + ) + + # Phase 6: Going to GPU + if gpu: + # TODO(phimuell): The GPU function might modify the map iteration order. + # This is because how it is implemented (promotion and + # fusion). However, because of its current state, this + # should not happen, but we have to look into it. + gpu_launch_factor: Optional[int] = kwargs.get("gpu_launch_factor", None) + gpu_launch_bounds: Optional[int] = kwargs.get("gpu_launch_bounds", None) + gtx_transformations.gt_gpu_transformation( + sdfg, + gpu_block_size=gpu_block_size, + gpu_launch_bounds=gpu_launch_bounds, + gpu_launch_factor=gpu_launch_factor, + validate=validate, + validate_all=validate_all, + try_removing_trivial_maps=True, + ) + + # Phase 7: General Optimizations + # The following operations apply regardless if we have a GPU or CPU. + # The DaCe auto optimizer also uses them. Note that the reuse transient + # is not done by DaCe. + if reuse_transients: + # TODO(phimuell): Investigate if we should enable it, it may make things + # harder for the compiler. Maybe write our own to + # only consider big transients and not small ones (~60B) + transient_reuse = dace.transformation.passes.TransientReuse() + transient_reuse.apply_pass(sdfg, {}) + + # Set the implementation of the library nodes. + dace_aoptimize.set_fast_implementations(sdfg, device) + # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. + dace_aoptimize.move_small_arrays_to_stack(sdfg) + if make_persistent: + # TODO(phimuell): Allow to also to set the lifetime to `SDFG`. + dace_aoptimize.make_transients_persistent(sdfg, device) + + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py new file mode 100644 index 0000000000..70ac77b89d --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -0,0 +1,370 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Functions for turning an SDFG into a GPU SDFG.""" + +import copy +from typing import Any, Optional, Sequence, Union + +import dace +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +__all__ = [ + "SerialMapPromoterGPU", + "GPUSetBlockSize", + "gt_gpu_transformation", + "gt_set_gpu_blocksize", +] + + +def gt_gpu_transformation( + sdfg: dace.SDFG, + try_removing_trivial_maps: bool = True, + use_gpu_storage: bool = True, + gpu_block_size: Optional[Sequence[int | str] | str] = None, + validate: bool = True, + validate_all: bool = False, + **kwargs: Any, +) -> dace.SDFG: + """Transform an SDFG into a GPU SDFG. + + The transformation expects a rather optimized SDFG and turn it into an SDFG + capable of running on the GPU. + The function performs the following steps: + - If requested, modify the storage location of the non transient arrays such + that they reside in GPU memory. + - Call the normal GPU transform function followed by simplify. + - If requested try to remove trivial kernels. + - If specified, set the `gpu_block_size` parameters of the Maps to the given value. + + Args: + sdfg: The SDFG that should be processed. + try_removing_trivial_maps: Try to get rid of trivial maps by incorporating them. + use_gpu_storage: Assume that the non global memory is already on the GPU. + gpu_block_size: Set to true when the SDFG array arguments are already allocated + on GPU global memory. This will avoid the data copy from host to GPU memory. + + Notes: + The function might modify the order of the iteration variables of some + maps. + In addition it might fuse Maps together that should not be fused. To prevent + that you should set `try_removing_trivial_maps` to `False`. + + Todo: + - Solve the fusing problem. + - Currently only one block size for all maps is given, add more options. + """ + + # You need guru level or above to use these arguments. + gpu_launch_factor: Optional[int] = kwargs.pop("gpu_launch_factor", None) + gpu_launch_bounds: Optional[int] = kwargs.pop("gpu_launch_bounds", None) + assert ( + len(kwargs) == 0 + ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" + + # Turn all global arrays (which we identify as input) into GPU memory. + # This way the GPU transformation will not create this copying stuff. + if use_gpu_storage: + for desc in sdfg.arrays.values(): + if isinstance(desc, dace.data.Array) and not desc.transient: + desc.storage = dace.dtypes.StorageType.GPU_Global + + # Now turn it into a GPU SDFG + sdfg.apply_gpu_transformations( + validate=validate, + validate_all=validate_all, + simplify=False, + ) + # The documentation recommends to run simplify afterwards + gtx_transformations.gt_simplify(sdfg) + + if try_removing_trivial_maps: + # A Tasklet, outside of a Map, that writes into an array on GPU can not work + # `sdfg.appyl_gpu_transformations()` puts Map around it (if said Tasklet + # would write into a Scalar that then goes into a GPU Map, nothing would + # happen. So we might end up with lot of these trivial Maps, that results + # in a single kernel launch. To prevent this we will try to fuse them. + # NOTE: The current implementation has a bug, because promotion and fusion + # are two different steps. Because of this the function will implicitly + # fuse everything together it can find. + # TODO(phimuell): Fix the issue described above. + sdfg.apply_transformations_once_everywhere( + gtx_transformations.SerialMapPromoterGPU(), + validate=False, + validate_all=False, + ) + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion( + only_toplevel_maps=True, + ), + validate=validate, + validate_all=validate_all, + ) + + # Set the GPU block size if it is known. + if gpu_block_size is not None: + gt_set_gpu_blocksize( + sdfg=sdfg, + gpu_block_size=gpu_block_size, + gpu_launch_bounds=gpu_launch_bounds, + gpu_launch_factor=gpu_launch_factor, + ) + + return sdfg + + +def gt_set_gpu_blocksize( + sdfg: dace.SDFG, + gpu_block_size: Optional[Sequence[int | str] | str], + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, +) -> Any: + """Set the block sizes of GPU Maps. + + Args: + sdfg: The SDFG to process. + gpu_block_size: The block size to use. + gpu_launch_bounds: The launch bounds to use. + gpu_launch_factor: The launch factor to use. + """ + xform = GPUSetBlockSize( + block_size=gpu_block_size, + launch_bounds=gpu_launch_bounds, + launch_factor=gpu_launch_factor, + ) + return sdfg.apply_transformations_once_everywhere([xform]) + + +def _gpu_block_parser( + self: "GPUSetBlockSize", + val: Any, +) -> None: + """Used by the setter of `GPUSetBlockSize.block_size`.""" + org_val = val + if isinstance(val, tuple): + pass + elif isinstance(val, list): + val = tuple(val) + elif isinstance(val, str): + val = tuple(x.strip() for x in val.split(",")) + else: + raise TypeError( + f"Does not know how to transform '{type(val).__name__}' into a proper GPU block size." + ) + if len(val) == 1: + val = (*val, 1, 1) + elif len(val) == 2: + val = (*val, 1) + elif len(val) != 3: + raise ValueError(f"Can not parse block size '{org_val}': wrong length") + try: + val = [int(x) for x in val] + except ValueError: + raise TypeError( + f"Currently only block sizes convertible to int are supported, you passed '{val}'." + ) from None + self._block_size = val + + +def _gpu_block_getter( + self: "GPUSetBlockSize", +) -> tuple[int, int, int]: + """Used as getter in the `GPUSetBlockSize.block_size` property.""" + assert isinstance(self._block_size, (tuple, list)) and len(self._block_size) == 3 + assert all(isinstance(x, int) for x in self._block_size) + return tuple(self._block_size) + + +@dace_properties.make_properties +class GPUSetBlockSize(dace_transformation.SingleStateTransformation): + """Sets the GPU block size on GPU Maps. + + It is also possible to set the launch bound. + + Args: + block_size: The block size that should be used. + launch_bounds: The value for the launch bound that should be used. + launch_factor: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number. + + Todo: + Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. + """ + + block_size = dace_properties.Property( + dtype=None, + allow_none=False, + default=(32, 1, 1), + setter=_gpu_block_parser, + getter=_gpu_block_getter, + desc="Size of the block size a GPU Map should have.", + ) + + launch_bounds = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property of the map.", + ) + + map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + block_size: Sequence[int | str] | str | None = None, + launch_bounds: int | str | None = None, + launch_factor: int | None = None, + ) -> None: + super().__init__() + if block_size is not None: + self.block_size = block_size + + if launch_factor is not None: + assert launch_bounds is None + self.launch_bounds = str( + int(launch_factor) * self.block_size[0] * self.block_size[1] * self.block_size[2] + ) + elif launch_bounds is None: + self.launch_bounds = None + elif isinstance(launch_bounds, (str, int)): + self.launch_bounds = str(launch_bounds) + else: + raise TypeError( + f"Does not know how to parse '{launch_bounds}' as 'launch_bounds' argument." + ) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the block size can be set. + + The function tests: + - If the block size of the map is already set. + - If the map is at global scope. + - If if the schedule of the map is correct. + """ + + scope = graph.scope_dict() + if scope[self.map_entry] is not None: + return False + if self.map_entry.map.schedule not in dace.dtypes.GPU_SCHEDULES: + return False + if self.map_entry.map.gpu_block_size is not None: + return False + return True + + def apply( + self, + graph: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> None: + """Modify the map as requested.""" + self.map_entry.map.gpu_block_size = self.block_size + if self.launch_bounds is not None: # Note empty string has a meaning in DaCe + self.map_entry.map.gpu_launch_bounds = self.launch_bounds + + +@dace_properties.make_properties +class SerialMapPromoterGPU(dace_transformation.SingleStateTransformation): + """Serial Map promoter for empty Maps in case of trivial Maps. + + In CPU mode a Tasklet can be outside of a map, however, this is not + possible in GPU mode. For this reason DaCe wraps such Tasklets in a + trivial Map. + This transformation will look for such Maps and promote them, such + that they can be fused with downstream maps. + + Note: + This transformation must be run after the GPU Transformation. + + Todo: + - The transformation assumes that the upper Map is a trivial Tasklet. + Which should be the majority of all cases. + - Combine this transformation such that it can do serial fusion on its own. + """ + + # Pattern Matching + map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the promotion is possible. + + The function tests: + - If the top map is a trivial map. + - If a valid partition exists that can be fused at all. + """ + from .map_fusion_serial import SerialMapFusion + + map_exit_1: dace_nodes.MapExit = self.map_exit1 + map_1: dace_nodes.Map = map_exit_1.map + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + + # Check if the first map is trivial. + if len(map_1.params) != 1: + return False + if map_1.range.num_elements() != 1: + return False + + # Check if it is a GPU map + if map_1.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + + # Check if the partition exists, if not promotion to fusing is pointless. + # TODO(phimuell): Find the proper way of doing it. + serial_fuser = SerialMapFusion(only_toplevel_maps=True) + output_partition = serial_fuser.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map Promoting. + + The function essentially copies the parameters and the ranges from the + bottom map to the top one. + """ + map_1: dace_nodes.Map = self.map_exit1.map + map_2: dace_nodes.Map = self.map_entry2.map + + map_1.params = copy.deepcopy(map_2.params) + map_1.range = copy.deepcopy(map_2.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py new file mode 100644 index 0000000000..1e8ded1c1b --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/k_blocking.py @@ -0,0 +1,437 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import copy +import functools +from typing import Any, Optional, Union + +import dace +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes +from dace.transformation import helpers as dace_helpers + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util + + +@dace_properties.make_properties +class KBlocking(dace_transformation.SingleStateTransformation): + """Applies k-Blocking with separation on a Map. + + This transformation takes a multidimensional Map and performs blocking on a + dimension, that is commonly called "k", but identified with `block_dim`. + + All dimensions except `k` are unaffected by this transformation. In the outer + Map will be replace the `k` range, currently `k = 0:N`, with + `__coarse_k = 0:N:B`, where `N` is the original size of the range and `B` + is the block size, passed as `blocking_size`. The transformation also handles the + case if `N % B != 0`. + The transformation will then create an inner sequential map with + `k = __coarse_k:(__coarse_k + B)`. + + However, before the split the transformation examines all adjacent nodes of + the original Map. If a node does not depend on `k`, then the node will be + put between the two maps, thus its content will only be computed once. + + The function will also change the name of the outer map, it will append + `_blocked` to it. + """ + + blocking_size = dace_properties.Property( + dtype=int, + allow_none=True, + desc="Size of the inner k Block.", + ) + block_dim = dace_properties.Property( + dtype=str, + allow_none=True, + desc="Which dimension should be blocked (must be an exact match).", + ) + + map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + blocking_size: Optional[int] = None, + block_dim: Optional[Union[gtx_common.Dimension, str]] = None, + ) -> None: + super().__init__() + if isinstance(block_dim, str): + pass + elif isinstance(block_dim, gtx_common.Dimension): + block_dim = gtx_dace_fieldview_util.get_map_variable(block_dim) + if block_dim is not None: + self.block_dim = block_dim + if blocking_size is not None: + self.blocking_size = blocking_size + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the map can be blocked. + + The test involves: + - Toplevel map. + - The map shall not be serial. + - The block dimension must be present (exact match). + - The map range must have step size of 1. + - The partition must exists (see `partition_map_output()`). + """ + if self.block_dim is None: + raise ValueError("The blocking dimension was not specified.") + elif self.blocking_size is None: + raise ValueError("The blocking size was not specified.") + + map_entry: dace_nodes.MapEntry = self.map_entry + map_params: list[str] = map_entry.map.params + map_range: dace_subsets.Range = map_entry.map.range + block_var: str = self.block_dim + + scope = graph.scope_dict() + if scope[map_entry] is not None: + return False + if block_var not in map_entry.map.params: + return False + if map_entry.map.schedule == dace.dtypes.ScheduleType.Sequential: + return False + if map_range[map_params.index(block_var)][2] != 1: + return False + if self.partition_map_output(map_entry, block_var, graph, sdfg) is None: + return False + + return True + + def apply( + self, + graph: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> None: + """Creates a blocking map. + + Performs the operation described in the doc string. + """ + outer_entry: dace_nodes.MapEntry = self.map_entry + outer_exit: dace_nodes.MapExit = graph.exit_node(outer_entry) + outer_map: dace_nodes.Map = outer_entry.map + map_range: dace_subsets.Range = outer_entry.map.range + map_params: list[str] = outer_entry.map.params + + # This is the name of the iterator we coarsen + block_var: str = self.block_dim + block_idx = map_params.index(block_var) + + # This is the name of the iterator that we use in the outer map for the + # blocked dimension + coarse_block_var = "__coarse_" + block_var + + # Now compute the partitions of the nodes. + independent_nodes, dependent_nodes = self.partition_map_output( # type: ignore[misc] # Guaranteed to be not `None`. + outer_entry, block_var, graph, sdfg + ) + + # Generate the sequential inner map + rng_start = map_range[block_idx][0] + rng_stop = map_range[block_idx][1] + inner_label = f"inner_{outer_map.label}" + inner_range = { + block_var: dace_subsets.Range.from_string( + f"({coarse_block_var} * {self.blocking_size} + {rng_start}):min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" + ) + } + inner_entry, inner_exit = graph.add_map( + name=inner_label, + ndrange=inner_range, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + # TODO(phimuell): Investigate if we want to prevent unrolling here + + # Now we modify the properties of the outer map. + coarse_block_range = dace_subsets.Range.from_string( + f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" + ).ranges[0] + outer_map.params[block_idx] = coarse_block_var + outer_map.range[block_idx] = coarse_block_range + outer_map.label = f"{outer_map.label}_blocked" + + # Contains the independent nodes that are already relocated. + relocated_nodes: set[dace_nodes.Node] = set() + + # Now we iterate over all the output edges of the outer map and rewire them. + # Note that this only handles the entry of the Map. + for out_edge in list(graph.out_edges(outer_entry)): + edge_dst: dace_nodes.Node = out_edge.dst + + if edge_dst in dependent_nodes: + # This is the simple case as we just have to rewire the edge + # and make a connection between the outer and inner map. + assert not out_edge.data.is_empty() + edge_conn: str = out_edge.src_conn[4:] + + # Must be before the handling of the modification below + # Note that this will remove the original edge from the SDFG. + dace_helpers.redirect_edge( + state=graph, + edge=out_edge, + new_src=inner_entry, + new_src_conn="OUT_" + edge_conn, + ) + + # In a valid SDFG only one edge can go into an input connector of a Map. + if "IN_" + edge_conn in inner_entry.in_connectors: + # We have found this edge multiple times already. + # To ensure that there is no error, we will create a new + # Memlet that reads the whole array. + piping_edge = next(graph.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) + data_name = piping_edge.data.data + piping_edge.data = dace.Memlet.from_array( + data_name, sdfg.arrays[data_name], piping_edge.data.wcr + ) + + else: + # This is the first time we found this connection. + # so we just create the edge. + graph.add_edge( + outer_entry, + "OUT_" + edge_conn, + inner_entry, + "IN_" + edge_conn, + copy.deepcopy(out_edge.data), + ) + inner_entry.add_in_connector("IN_" + edge_conn) + inner_entry.add_out_connector("OUT_" + edge_conn) + continue + + elif edge_dst in relocated_nodes: + # The node was already fully handled in the `else` clause. + continue + + else: + # Relocate the node and make the reconnection. + # Different from the dependent case we will handle all the edges + # of the node in one go. + relocated_nodes.add(edge_dst) + + # In order to be useful we have to temporarily store the data the + # independent node generates + assert graph.out_degree(edge_dst) == 1 # TODO(phimuell): Lift + if isinstance(edge_dst, dace_nodes.AccessNode): + # The independent node is an access node, so we can use it directly. + caching_node: dace_nodes.AccessNode = edge_dst + else: + # The dependent node is not an access node. For now we will + # just use the next node, with some restriction. + # TODO(phimuell): create an access node in this case instead. + caching_node = next(iter(graph.out_edges(edge_dst))).dst + assert graph.in_degree(caching_node) == 1 + assert isinstance(caching_node, dace_nodes.AccessNode) + + # Now rewire the Memlets that leave the caching node to go through + # new inner Map. + for consumer_edge in list(graph.out_edges(caching_node)): + new_map_conn = inner_entry.next_connector() + dace_helpers.redirect_edge( + state=graph, + edge=consumer_edge, + new_dst=inner_entry, + new_dst_conn="IN_" + new_map_conn, + ) + graph.add_edge( + inner_entry, + "OUT_" + new_map_conn, + consumer_edge.dst, + consumer_edge.dst_conn, + copy.deepcopy(consumer_edge.data), + ) + inner_entry.add_in_connector("IN_" + new_map_conn) + inner_entry.add_out_connector("OUT_" + new_map_conn) + continue + + # Handle the Map exits + # This is simple reconnecting, there would be possibilities for improvements + # but we do not use them for now. + for out_edge in list(graph.in_edges(outer_exit)): + edge_conn = out_edge.dst_conn[3:] + dace_helpers.redirect_edge( + state=graph, + edge=out_edge, + new_dst=inner_exit, + new_dst_conn="IN_" + edge_conn, + ) + graph.add_edge( + inner_exit, + "OUT_" + edge_conn, + outer_exit, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + inner_exit.add_in_connector("IN_" + edge_conn) + inner_exit.add_out_connector("OUT_" + edge_conn) + + # TODO(phimuell): Use a less expensive method. + dace.sdfg.propagation.propagate_memlets_state(sdfg, graph) + + def partition_map_output( + self, + map_entry: dace_nodes.MapEntry, + block_param: str, + state: SDFGState, + sdfg: SDFG, + ) -> tuple[set[dace_nodes.Node], set[dace_nodes.Node]] | None: + """Partition the outputs of the Map. + + The partition will only look at the direct intermediate outputs of the + Map. The outputs will be two sets, defined as: + - The independent outputs `\mathcal{I}`: + These are output nodes, whose output does not depend on the blocked + dimension. These nodes can be relocated between the outer and inner map. + - The dependent output `\mathcal{D}`: + These are the output nodes, whose output depend on the blocked dimension. + Thus they can not be relocated between the two maps, but will remain + inside the inner scope. + + In case the function fails to compute the partition `None` is returned. + + Args: + map_entry: The map entry node. + block_param: The Map variable that should be blocked. + state: The state on which we operate. + sdfg: The SDFG in which we operate on. + + Note: + - Currently this function only considers the input Memlets and the + `used_symbol` properties of a Tasklet. + - Furthermore only the first level is inspected. + """ + block_independent: set[dace_nodes.Node] = set() # `\mathcal{I}` + block_dependent: set[dace_nodes.Node] = set() # `\mathcal{D}` + + # Find all nodes that are adjacent to the map entry. + nodes_to_partition: set[dace_nodes.Node] = {edge.dst for edge in state.out_edges(map_entry)} + + # Now we examine every node and assign them to one of the sets. + # Note that this is only tentative and we will later inspect the + # outputs of the independent node and reevaluate their classification. + for node in nodes_to_partition: + # Filter out all nodes that we can not (yet) handle. + if not isinstance(node, (dace_nodes.Tasklet, dace_nodes.AccessNode)): + return None + + # Check if we have a strange Tasklet or if it uses the symbol inside it. + if isinstance(node, dace_nodes.Tasklet): + if node.side_effects: + return None + if block_param in node.free_symbols: + block_dependent.add(node) + continue + + # An independent node can (for now) only have one output. + # TODO(phimuell): Lift this restriction. + if state.out_degree(node) != 1: + block_dependent.add(node) + continue + + # Now we have to understand how the node generates its data. + # For this we have to look at all the edges that feed information to it. + edges: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = list(state.in_edges(node)) + + # If all edges are empty, i.e. they are only needed to keep the + # node inside the scope, consider it as independent. However, they have + # to be associated to the outer map. + if all(edge.data.is_empty() for edge in edges): + if not all(edge.src is map_entry for edge in edges): + return None + block_independent.add(node) + continue + + # Currently we do not allow that a node has a mix of empty and non + # empty Memlets, it is all or nothing. + if any(edge.data.is_empty() for edge in edges): + return None + + # If the node gets information from other nodes than the map entry + # we classify it as a dependent node. But there can be situations where + # the node could still be independent, for example if it is connected + # to a independent node, then it could be independent itself. + # TODO(phimuell): Consider independent node as "equal" to the map. + if any(edge.src is not map_entry for edge in edges): + block_dependent.add(node) + continue + + # Now we have to look at the edges individually. + # If this loop ends normally, i.e. it goes into its `else` + # clause then we classify the node as independent. + for edge in edges: + memlet: dace.Memlet = edge.data + src_subset: dace_subsets.Subset | None = memlet.src_subset + dst_subset: dace_subsets.Subset | None = memlet.dst_subset + edge_desc: dace.data.Data = sdfg.arrays[memlet.data] + edge_desc_size = functools.reduce(lambda a, b: a * b, edge_desc.shape) + + if memlet.is_empty(): + # Empty Memlets do not impose any restrictions. + continue + if memlet.num_elements() == edge_desc_size: + # The whole source array is consumed, which is not a problem. + continue + + # Now we have to look at the source and destination set of the Memlet. + subsets_to_inspect: list[dace_subsets.Subset] = [] + if dst_subset is not None: + subsets_to_inspect.append(dst_subset) + if src_subset is not None: + subsets_to_inspect.append(src_subset) + + # If a subset needs the block variable then the node is not + # independent from the block variable. + if any(block_param in subset.free_symbols for subset in subsets_to_inspect): + break + else: + # The loop ended normally, thus we did not found anything that made us + # _think_ that the node is _not_ an independent node. We will later + # also inspect the output, which might reclassify the node + block_independent.add(node) + + # If the node is not independent then it must be dependent, my dear Watson. + if node not in block_independent: + block_dependent.add(node) + + # We now make a last screening of the independent nodes. + # TODO(phimuell): Make an iterative process to find the maximal set. + for independent_node in list(block_independent): + if isinstance(independent_node, dace_nodes.AccessNode): + if state.in_degree(independent_node) != 1: + block_independent.discard(independent_node) + block_dependent.add(independent_node) + continue + for out_edge in state.out_edges(independent_node): + if ( + (not isinstance(out_edge.dst, dace_nodes.AccessNode)) + or (state.in_degree(out_edge.dst) != 1) + or (out_edge.dst.desc(sdfg).lifetime != dace.dtypes.AllocationLifetime.Scope) + ): + block_independent.discard(independent_node) + block_dependent.add(independent_node) + break + + assert not block_dependent.intersection(block_independent) + assert (len(block_independent) + len(block_dependent)) == len(nodes_to_partition) + + return (block_independent, block_dependent) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py new file mode 100644 index 0000000000..e8433f5cea --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -0,0 +1,565 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements Helper functionaliyies for map fusion""" + +import functools +import itertools +from typing import Any, Optional, Sequence, Union + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import ( + SDFG, + SDFGState, + graph as dace_graph, + nodes as dace_nodes, + validation as dace_validation, +) +from dace.transformation import helpers as dace_helpers + +from gt4py.next.program_processors.runners.dace_fieldview.transformations import util + + +@dace_properties.make_properties +class MapFusionHelper(dace_transformation.SingleStateTransformation): + """Contains common part of the fusion for parallel and serial Map fusion. + + The transformation assumes that the SDFG obeys the principals outlined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + The main advantage of this structure is, that it is rather easy to determine + if a transient is used anywhere else. This check, performed by + `is_interstate_transient()`. It is further speeded up by cashing some computation, + thus such an object should not be used after interstate optimizations were applied + to the SDFG. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + """ + + only_toplevel_maps = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are in the top level.", + ) + only_inner_maps = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + shared_transients = dace_properties.DictProperty( + key_type=SDFG, + value_type=set[str], + default=None, + allow_none=True, + desc="Maps SDFGs to the set of array transients that can not be removed. " + "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", + ) + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if only_toplevel_maps is not None: + self.only_toplevel_maps = bool(only_toplevel_maps) + if only_inner_maps is not None: + self.only_inner_maps = bool(only_inner_maps) + self.shared_transients = {} + + @classmethod + def expressions(cls) -> bool: + raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + + def can_be_fused( + self, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Performs basic checks if the maps can be fused. + + This function only checks constrains that are common between serial and + parallel map fusion process, which includes: + - The scope of the maps. + - The scheduling of the maps. + - The map parameters. + + However, for performance reasons, the function does not check if the node + decomposition exists. + + Args: + map_entry_1: The entry of the first (in serial case the top) map. + map_exit_2: The entry of the second (in serial case the bottom) map. + graph: The SDFGState in which the maps are located. + sdfg: The SDFG itself. + permissive: Currently unused. + """ + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + + # Ensure that both have the same schedule + if map_entry_1.map.schedule != map_entry_2.map.schedule: + return False + + # Fusing is only possible if the two entries are in the same scope. + scope = graph.scope_dict() + if scope[map_entry_1] != scope[map_entry_2]: + return False + elif self.only_inner_maps: + if scope[map_entry_1] is None: + return False + elif self.only_toplevel_maps: + if scope[map_entry_1] is not None: + return False + # TODO(phimuell): Figuring out why this is here. + elif util.is_nested_sdfg(sdfg): + return False + + # We will now check if there exists a "remapping" that we can use. + if not self.map_parameter_compatible( + map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg + ): + return False + + return True + + @staticmethod + def relocate_nodes( + from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], + to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + ) -> None: + """Move the connectors and edges from `from_node` to `to_nodes` node. + + This function will only rewire the edges, it does not remove the nodes + themselves. Furthermore, this function should be called twice per Map, + once for the entry and then for the exit. + While it does not remove the node themselves if guarantees that the + `from_node` has degree zero. + + Args: + from_node: Node from which the edges should be removed. + to_node: Node to which the edges should reconnect. + state: The state in which the operation happens. + sdfg: The SDFG that is modified. + """ + + # Now we relocate empty Memlets, from the `from_node` to the `to_node` + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) + + # We now ensure that there is only one empty Memlet from the `to_node` to any other node. + # Although it is allowed, we try to prevent it. + empty_targets: set[dace_nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): + if empty_edge.dst in empty_targets: + state.remove_edge(empty_edge) + empty_targets.add(empty_edge.dst) + + # We now determine which edges we have to migrate, for this we are looking at + # the incoming edges, because this allows us also to detect dynamic map ranges. + for edge_to_move in list(state.in_edges(from_node)): + assert isinstance(edge_to_move.dst_conn, str) + + if not edge_to_move.dst_conn.startswith("IN_"): + # Dynamic Map Range + # The connector name simply defines a variable name that is used, + # inside the Map scope to define a variable. We handle it directly. + dmr_symbol = edge_to_move.dst_conn + + # TODO(phimuell): Check if the symbol is really unused in the target scope. + if dmr_symbol in to_node.in_connectors: + raise NotImplementedError( + f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" + f" to '{to_node}', but the symbol is already known there, but the" + " renaming is not implemented." + ) + if not to_node.add_in_connector(dmr_symbol, force=False): + raise RuntimeError( # Might fail because of out connectors. + f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." + ) + dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + from_node.remove_in_connector(dmr_symbol) + + # There is no other edge that we have to consider, so we just end here + continue + + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + dace_helpers.redirect_edge( + state, e, new_src=to_node, new_src_conn="OUT_" + new_conn + ) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) + + # Check if we succeeded. + if state.out_degree(from_node) != 0: + raise dace_validation.InvalidSDFGError( + f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + if state.in_degree(from_node) != 0: + raise dace_validation.InvalidSDFGError( + f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", + sdfg, + sdfg.node_id(state), + ) + assert len(from_node.in_connectors) == 0 + assert len(from_node.out_connectors) == 0 + + @staticmethod + def map_parameter_compatible( + map_1: dace_nodes.Map, + map_2: dace_nodes.Map, + state: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> bool: + """Checks if the parameters of `map_1` are compatible with `map_2`. + + The check follows the following rules: + - The names of the map variables must be the same, i.e. no renaming + is performed. + - The ranges must be the same. + """ + range_1: dace_subsets.Range = map_1.range + params_1: Sequence[str] = map_1.params + range_2: dace_subsets.Range = map_2.range + params_2: Sequence[str] = map_2.params + + # The maps are only fuseable if we have an exact match in the parameter names + # this is because we do not do any renaming. This is in accordance with the + # rules. + if set(params_1) != set(params_2): + return False + + # Maps the name of a parameter to the dimension index + param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} + param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + + # To fuse the two maps the ranges must have the same ranges + for pname in params_1: + idx_1 = param_dim_map_1[pname] + idx_2 = param_dim_map_2[pname] + # TODO(phimuell): do we need to call simplify? + if range_1[idx_1] != range_2[idx_2]: + return False + + return True + + def is_interstate_transient( + self, + transient: Union[str, dace_nodes.AccessNode], + sdfg: dace.SDFG, + state: dace.SDFGState, + ) -> bool: + """Tests if `transient` is an interstate transient, an can not be removed. + + Essentially this function checks if a transient might be needed in a + different state in the SDFG, because it transmit information from + one state to the other. + If only the name of the data container is passed the function will + first look for an corresponding access node. + + The set of these "interstate transients" is computed once per SDFG. + The result is then cached internally for later reuse. + + Args: + transient: The transient that should be checked. + sdfg: The SDFG containing the array. + state: If given the state the node is located in. + + Note: + This function build upon the structure of the SDFG that is outlined + in the HackMD document. + """ + + # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + # the set of such transients is partially given by all source access dace_nodes. + # Because of rule 3 we also include all scalars in this set, as an over + # approximation. Furthermore, because simplify might violate rule 3, + # we also include the sink dace_nodes. + + # See if we have already computed the set + if sdfg in self.shared_transients: + shared_sdfg_transients: set[str] = self.shared_transients[sdfg] + else: + # SDFG is not known so we have to compute the set. + shared_sdfg_transients = set() + for state_to_scan in sdfg.all_states(): + # TODO(phimuell): Use `all_nodes_recursive()` once it is available. + shared_sdfg_transients.update( + [ + node.data + for node in itertools.chain( + state_to_scan.source_nodes(), state_to_scan.sink_nodes() + ) + if isinstance(node, dace_nodes.AccessNode) + and sdfg.arrays[node.data].transient + ] + ) + self.shared_transients[sdfg] = shared_sdfg_transients + + if isinstance(transient, str): + name = transient + matching_access_nodes = [node for node in state.data_nodes() if node.data == name] + # Rule 8: There is only one access node per state for data. + assert len(matching_access_nodes) == 1 + transient = matching_access_nodes[0] + else: + assert isinstance(transient, dace_nodes.AccessNode) + name = transient.data + + desc: dace_data.Data = sdfg.arrays[name] + if not desc.transient: + return True + if isinstance(desc, dace_data.Scalar): + return True # Scalars can not be removed by fusion anyway. + + # Rule 8: If degree larger than one then it is used within the state. + if state.out_degree(transient) > 1: + return True + + # Now we check if it is used in a different state. + return name in shared_sdfg_transients + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: dace_nodes.MapExit, + map_entry_2: dace_nodes.MapEntry, + ) -> Union[ + tuple[ + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + set[dace_graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: set[dace_nodes.Node] = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: dace_nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # Now let's look at all nodes that are downstream of the intermediate node. + # This, among other things, will tell us, how we have to handle this node. + downstream_nodes = util.all_nodes_between( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ) + + # If `downstream_nodes` is `None` this means that `map_entry_2` was never + # reached, thus `intermediate_node` does not enter the second map and + # the node is a pure output node. + if downstream_nodes is None: + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # In case the intermediate has more than one entry, all must come from the + # first map, otherwise we can not fuse them. Currently we restrict this + # even further by saying that it has only one incoming Memlet. + if state.in_degree(intermediate_node) != 1: + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + inner_collector_edges = list( + state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) + ) + if len(inner_collector_edges) > 1: + return None + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, dace_nodes.AccessNode): + return None + intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) + if isinstance(intermediate_desc, dace_data.View): + return None + + # There are some restrictions we have on intermediate dace_nodes. The first one + # is that we do not allow WCR, this is because they need special handling + # which is currently not implement (the DaCe transformation has this + # restriction as well). The second one is that we can reduce the + # intermediate node and only feed a part into the second map, consider + # the case `b = a + 1; return b + 2`, where we have arrays. In this + # example only a single element must be available to the second map. + # However, this is hard to check so we will make a simplification. + # First, we will not check it at the producer, but at the consumer point. + # There we assume if the consumer does _not consume the whole_ + # intermediate array, then we can decompose the intermediate, by setting + # the map iteration index to zero and recover the shape, see + # implementation in the actual fusion routine. + # This is an assumption that is in most cases correct, but not always. + # However, doing it correctly is extremely complex. + for _, produce_edge in util.find_upstream_producers(state, out_edge): + if produce_edge.data.wcr is not None: + return None + + if len(downstream_nodes) == 0: + # There is nothing between intermediate node and the entry of the + # second map, thus the edge belongs either in `\mathbb{S}` or + # `\mathbb{E}`. + + # This is a very special situation, i.e. the access node has many + # different connections to the second map entry, this is a special + # case that we do not handle. + # TODO(phimuell): Handle this case. + if state.out_degree(intermediate_node) != 1: + return None + + # Certain nodes need more than one element as input. As explained + # above, in this situation we assume that we can naturally decompose + # them iff the node does not consume that whole intermediate. + # Furthermore, it can not be a dynamic map range or a library node. + intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) + consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) + for consumer_node, feed_edge in consumers: + # TODO(phimuell): Improve this approximation. + if ( + intermediate_size != 1 + ) and feed_edge.data.num_elements() == intermediate_size: + return None + if consumer_node is map_entry_2: # Dynamic map range. + return None + if isinstance(consumer_node, dace_nodes.LibraryNode): + # TODO(phimuell): Allow some library dace_nodes. + return None + + # Note that "remove" has a special meaning here, regardless of the + # output of the check function, from within the second map we remove + # the intermediate, it has more the meaning of "do we need to + # reconstruct it after the second map again?" + if self.is_interstate_transient(intermediate_node, sdfg, state): + shared_outputs.add(out_edge) + else: + exclusive_outputs.add(out_edge) + continue + + else: + # There is not only a single connection from the intermediate node to + # the second map, but the intermediate has more connections, thus + # the node might belong to the shared output. Of the many different + # possibilities, we only consider a single case: + # - The intermediate has a single connection to the second map, that + # fulfills the restriction outlined above. + # - All other connections have no connection to the second map. + found_second_entry = False + intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) + for edge in state.out_edges(intermediate_node): + if edge.dst is map_entry_2: + if found_second_entry: # The second map was found again. + return None + found_second_entry = True + consumers = util.find_downstream_consumers(state=state, begin=edge) + for consumer_node, feed_edge in consumers: + if feed_edge.data.num_elements() == intermediate_size: + return None + if consumer_node is map_entry_2: # Dynamic map range + return None + if isinstance(consumer_node, dace_nodes.LibraryNode): + # TODO(phimuell): Allow some library dace_nodes. + return None + else: + # Ensure that there is no path that leads to the second map. + after_intermdiate_node = util.all_nodes_between( + graph=state, begin=edge.dst, end=map_entry_2 + ) + if after_intermdiate_node is not None: + return None + # If we are here, then we know that the node is a shared output + shared_outputs.add(out_edge) + continue + + assert exclusive_outputs or shared_outputs or pure_outputs + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) + return (pure_outputs, exclusive_outputs, shared_outputs) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py new file mode 100644 index 0000000000..3f21f205b0 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py @@ -0,0 +1,127 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the parallel map fusing transformation.""" + +from typing import Any, Optional, Union + +import dace +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview.transformations import ( + map_fusion_helper, + util, +) + + +@dace_properties.make_properties +class ParallelMapFusion(map_fusion_helper.MapFusionHelper): + """The `ParallelMapFusion` transformation allows to merge two parallel maps together. + + The `SerialMapFusion` transformation is only able to handle maps that are sequential, + however, this transformation is able to fuse _any_ maps that are not sequential + and are in the same scope. + + Args: + only_if_common_ancestor: Only perform fusion if both Maps share at least one + node as direct ancestor. This will increase the locality of the merge. + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + + Note: + This transformation only matches the entry nodes of the Map, but will also + modify the exit nodes of the Map. + """ + + map_entry1 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + only_if_common_ancestor = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps share a node as parent.", + ) + + def __init__( + self, + only_if_common_ancestor: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if only_if_common_ancestor is not None: + self.only_if_common_ancestor = only_if_common_ancestor + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + # This just matches _any_ two Maps inside a state. + state = dace_graph.OrderedMultiDiConnectorGraph() + state.add_nodes_from([cls.map_entry1, cls.map_entry2]) + return [state] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """The transformation is applicable.""" + map_entry_1: dace_nodes.MapEntry = self.map_entry1 + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + + # Check the structural properties of the maps, this will also ensure that + # the two maps are in the same scope. + if not self.can_be_fused( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + graph=graph, + sdfg=sdfg, + permissive=permissive, + ): + return False + + # Since the match expression matches any twp Maps, we have to ensure that + # the maps are parallel. The `can_be_fused()` function already verified + # if they are in the same scope. + if not util.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + return False + + # Test if they have they share a node as direct ancestor. + if self.only_if_common_ancestor: + # This assumes that there is only one access node per data container in the state. + ancestors_1: set[dace_nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} + if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map fusing. + + Essentially, the function relocate all edges from the nodes forming the second + Map to the corresponding nodes of the first Map. Afterwards the nodes of the + second Map are removed. + """ + assert self.map_parameter_compatible(self.map_entry1.map, self.map_entry2.map, graph, sdfg) + + map_entry_1: dace_nodes.MapEntry = self.map_entry1 + map_exit_1: dace_nodes.MapExit = graph.exit_node(map_entry_1) + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + map_exit_2: dace_nodes.MapExit = graph.exit_node(map_entry_2) + + for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): + self.relocate_nodes( + from_node=from_node, + to_node=to_node, + state=graph, + sdfg=sdfg, + ) + # The relocate function does not remove the node, so we must do it. + graph.remove_node(from_node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py new file mode 100644 index 0000000000..7cbf59813e --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py @@ -0,0 +1,477 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the serial map fusing transformation.""" + +import copy +from typing import Any, Union + +import dace +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + subsets as dace_subsets, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, graph as dace_graph, nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper + + +@dace_properties.make_properties +class SerialMapFusion(map_fusion_helper.MapFusionHelper): + """Specialized replacement for the map fusion transformation that is provided by DaCe. + + As its name is indicating this transformation is only able to handle Maps that + are in sequence. Compared to the native DaCe transformation, this one is able + to handle more complex cases of connection between the maps. In that sense, it + is much more similar to DaCe's `SubgraphFusion` transformation. + + Things that are improved, compared to the native DaCe implementation: + - Nested Maps. + - Temporary arrays and the correct propagation of their Memlets. + - Top Maps that have multiple outputs. + + Conceptually this transformation removes the exit of the first or upper map + and the entry of the lower or second map and then rewrites the connections + appropriately. + + This transformation assumes that an SDFG obeys the structure that is outlined + [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that + reason it is not true replacement of the native DaCe transformation. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + + Notes: + - This transformation modifies more nodes than it matches! + """ + + map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - The `can_be_fused()` of the base succeed, which checks some basic constraints. + - The decomposition exists and at least one of the intermediate sets + is not empty. + """ + assert isinstance(self.map_exit1, dace_nodes.MapExit) + assert isinstance(self.map_entry2, dace_nodes.MapEntry) + map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + + # This essentially test the structural properties of the two Maps. + if not self.can_be_fused( + map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=self.map_exit1, + map_entry_2=self.map_entry2, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here, this is a known behaviour in DaCe. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit1, dace_nodes.MapExit) + assert isinstance(self.map_entry2, dace_nodes.MapEntry) + assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) + + map_exit_1: dace_nodes.MapExit = self.map_exit1 + map_entry_2: dace_nodes.MapEntry = self.map_entry2 + map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) + map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(map_exit_1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=map_exit_1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=map_entry_2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + @staticmethod + def handle_intermediate_set( + intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + map_exit_1: dace_nodes.MapExit, + map_entry_2: dace_nodes.MapEntry, + map_exit_2: dace_nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + + Args: + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_2: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + + Todo: + Rewrite using `MemletTree`. + """ + + # Essentially this function removes the AccessNode between the two maps. + # However, we still need some temporary memory that we can use, which is + # just much smaller, i.e. a scalar. But all Memlets inside the second map + # assumes that the intermediate memory has the bigger shape. + # To fix that we will create this replacement dict that will replace all + # occurrences of the iteration variables of the second map with zero. + # Note that this is still not enough as the dimensionality might be different. + memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: dace_nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # That are known to cause some troubles, so we will now remove them. + squeezed_dims: list[int] = [] # These are the dimensions we removed. + new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + # Order of checks is important! + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dace_dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) + + # New we will reroute the output Memlet, thus it will no longer pass + # through the Map exit but through the newly created intermediate. + # we will delete the previous edge later. + pre_exit_memlet: dace.Memlet = pre_exit_edge.data + new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) + + # We might operate on a different array, but the check below, ensures + # that we do not change the direction of the Memlet. + assert pre_exit_memlet.data == inter_name + new_pre_exit_memlet.data = new_inter_name + + # Now we have to modify the subset of the Memlet. + # Before the subset of the Memlet was dependent on the Map variables, + # however, this is no longer the case, as we removed them. This change + # has to be reflected in the Memlet. + # NOTE: Assert above ensures that the below is correct. + new_pre_exit_memlet.replace(memlet_repl) + if is_scalar: + new_pre_exit_memlet.subset = "0" + new_pre_exit_memlet.other_subset = None + else: + new_pre_exit_memlet.subset.pop(squeezed_dims) + + # Now we create the new edge between the producer and the new output + # (the new intermediate node). We will remove the old edge further down. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We just have handled the last Memlet, but we must actually handle the + # whole producer side, i.e. the scope of the top Map. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): + producer_edge = producer_tree.edge + + # Ensure the correctness of the rerouting below. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + + # Will not change the direction, because of test above! + producer_edge.data.data = new_inter_name + producer_edge.data.replace(memlet_repl) + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # The create the first Memlet to transmit information, within + # the second map, we do this again by copying and modifying + # the original Memlet. + # NOTE: Test above is important to ensure the direction of the + # Memlet and the correctness of the code below. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.replace(memlet_repl) + new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. + + # Now remove the old edge, that started the second map entry. + # Also add the new edge that started at the new intermediate. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now we do subset modification to ensure that nothing failed. + if is_scalar: + new_inner_memlet.src_subset = "0" + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now clean the Memlets of that tree to use the new intermediate node. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): + consumer_edge = consumer_tree.edge + assert consumer_edge.data.data == inter_name + consumer_edge.data.data = new_inter_name + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.subset is not None: + consumer_edge.data.subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. + # We will now delete the edges that brought the data. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + new_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert new_exit_memlet.data == inter_name + new_exit_memlet.subset = pre_exit_edge.data.dst_subset + new_exit_memlet.other_subset = ( + "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) + ) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + new_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py new file mode 100644 index 0000000000..f7d447fdc6 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -0,0 +1,123 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Sequence, Union + +import dace +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util + + +@dace_properties.make_properties +class MapIterationOrder(dace_transformation.SingleStateTransformation): + """Modify the order of the iteration variables. + + The iteration order, while irrelevant from an SDFG point of view, is highly + relevant in code, and the fastest varying index ("inner most loop" in CPU or + "x block dimension" in GPU) should be associated with the stride 1 dimension + of the array. + This transformation will reorder the map indexes such that this is the case. + + While the place of the leading dimension is clearly defined, the order of the + other loop indexes, after this transformation is unspecified. + + Args: + leading_dim: A GT4Py dimension object that identifies the dimension that + is supposed to have stride 1. + + Note: + The transformation does follow the rules outlines [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) + + Todo: + - Extend that different dimensions can be specified to be leading + dimensions, with some priority mechanism. + - Maybe also process the parameters to bring them in a canonical order. + """ + + leading_dim = dace_properties.Property( + dtype=str, + allow_none=True, + desc="Dimension that should become the leading dimension.", + ) + + map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + leading_dim: Optional[Union[gtx_common.Dimension, str]] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if isinstance(leading_dim, gtx_common.Dimension): + self.leading_dim = gtx_dace_fieldview_util.get_map_variable(leading_dim) + elif leading_dim is not None: + self.leading_dim = leading_dim + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the map can be reordered. + + Essentially the function checks if the selected dimension is inside the map, + and if so, if it is on the right place. + """ + + if self.leading_dim is None: + return False + map_entry: dace_nodes.MapEntry = self.map_entry + map_params: Sequence[str] = map_entry.map.params + map_var: str = self.leading_dim + + if map_var not in map_params: + return False + if map_params[-1] == map_var: # Already at the correct location + return False + return True + + def apply( + self, + graph: Union[SDFGState, SDFG], + sdfg: SDFG, + ) -> None: + """Performs the actual parameter reordering. + + The function will make the map variable, that corresponds to + `self.leading_dim` the last map variable (this is given by the structure of + DaCe's code generator). + """ + map_entry: dace_nodes.MapEntry = self.map_entry + map_params: list[str] = map_entry.map.params + map_var: str = self.leading_dim + + # This implementation will just swap the variable that is currently the last + # with the one that should be the last. + dst_idx = -1 + src_idx = map_params.index(map_var) + + for to_process in [ + map_entry.map.params, + map_entry.map.range.ranges, + map_entry.map.range.tile_sizes, + ]: + assert isinstance(to_process, list) + src_val = to_process[src_idx] + dst_val = to_process[dst_idx] + to_process[dst_idx] = src_val + to_process[src_idx] = dst_val diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py new file mode 100644 index 0000000000..4bba23e6e2 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -0,0 +1,379 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Mapping, Optional, Sequence, Union + +import dace +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import SDFG, SDFGState, nodes as dace_nodes + + +__all__ = [ + "SerialMapPromoter", +] + + +@dace_properties.make_properties +class BaseMapPromoter(dace_transformation.SingleStateTransformation): + """Base transformation to add certain missing dimension to a map. + + By adding certain dimension to a Map, it might became possible to use the Map + in more transformations. This class acts as a base and the actual matching and + checking must be implemented by a concrete implementation. + But it provides some basic check functionality and the actual promotion logic. + + The transformation operates on two Maps, first the "source map". This map + describes the Map that should be used as template. The second one is "map to + promote". After the transformation the "map to promote" will have the same + map parameter as the "source map" has. + + In order to properly work, the parameters of "source map" must be a strict + superset of the ones of "map to promote". Furthermore, this transformation + builds upon the structure defined [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). + Thus it only checks the name of the parameters. + + To influence what to promote the user must implement the `map_to_promote()` + and `source_map()` function. They have to return the map entry node. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + promote_vertical: If `True` promote vertical dimensions; `True` by default. + promote_local: If `True` promote local dimensions; `True` by default. + promote_horizontal: If `True` promote horizontal dimensions; `False` by default. + promote_all: Do not impose any restriction on what to promote. The only + reasonable value is `True` or `None`. + + Note: + This ignores tiling. + """ + + only_toplevel_maps = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are on the top level.", + ) + only_inner_maps = dace_properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", + ) + promote_vertical = dace_properties.Property( + dtype=bool, + default=True, + desc="If `True` promote vertical dimensions.", + ) + promote_local = dace_properties.Property( + dtype=bool, + default=True, + desc="If `True` promote local dimensions.", + ) + promote_horizontal = dace_properties.Property( + dtype=bool, + default=False, + desc="If `True` promote horizontal dimensions.", + ) + promote_all = dace_properties.Property( + dtype=bool, + default=False, + desc="If `True` perform any promotion. Takes precedence over all other selectors.", + ) + + def map_to_promote( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> dace_nodes.MapEntry: + """Returns the map entry that should be promoted.""" + raise NotImplementedError(f"{type(self).__name__} must implement 'map_to_promote'.") + + def source_map( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> dace_nodes.MapEntry: + """Returns the map entry that is used as source/template.""" + raise NotImplementedError(f"{type(self).__name__} must implement 'source_map'.") + + @classmethod + def expressions(cls) -> Any: + raise TypeError("You must implement 'expressions' yourself.") + + def __init__( + self, + only_inner_maps: Optional[bool] = None, + only_toplevel_maps: Optional[bool] = None, + promote_local: Optional[bool] = None, + promote_vertical: Optional[bool] = None, + promote_horizontal: Optional[bool] = None, + promote_all: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if only_inner_maps is not None: + self.only_inner_maps = bool(only_inner_maps) + if only_toplevel_maps is not None: + self.only_toplevel_maps = bool(only_toplevel_maps) + if promote_local is not None: + self.promote_local = bool(promote_local) + if promote_vertical is not None: + self.promote_vertical = bool(promote_vertical) + if promote_horizontal is not None: + self.promote_horizontal = bool(promote_horizontal) + if promote_all is not None: + self.promote_all = bool(promote_all) + self.promote_horizontal = False + self.promote_vertical = False + self.promote_local = False + if only_inner_maps and only_toplevel_maps: + raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") + if not ( + self.promote_local + or self.promote_vertical + or self.promote_horizontal + or self.promote_all + ): + raise ValueError( + "You must select at least one class of dimension that should be promoted." + ) + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Perform some basic structural tests on the map. + + A subclass should call this function before checking anything else. If a + subclass has not called this function, the behaviour will be undefined. + The function checks: + - If the map to promote is in the right scope. + - If the parameter of the second map are compatible with each other. + - If a dimension would be promoted that should not. + """ + map_to_promote_entry: dace_nodes.MapEntry = self.map_to_promote(state=graph, sdfg=sdfg) + map_to_promote: dace_nodes.Map = map_to_promote_entry.map + source_map_entry: dace_nodes.MapEntry = self.source_map(state=graph, sdfg=sdfg) + source_map: dace_nodes.Map = source_map_entry.map + + # Test the scope of the promotee. + # Because of the nature of the transformation, it is not needed that the + # two maps are in the same scope. However, they should be in the same state + # to ensure that the symbols are the same and all. But this is guaranteed by + # the nature of this transformation (single state). + if self.only_inner_maps or self.only_toplevel_maps: + scopeDict: Mapping[dace_nodes.Node, Union[dace_nodes.Node, None]] = graph.scope_dict() + if self.only_inner_maps and (scopeDict[map_to_promote_entry] is None): + return False + if self.only_toplevel_maps and (scopeDict[map_to_promote_entry] is not None): + return False + + # Test if the map ranges are compatible with each other. + missing_map_parameters: list[str] | None = self.missing_map_params( + map_to_promote=map_to_promote, + source_map=source_map, + be_strict=True, + ) + if not missing_map_parameters: + return False + + # We now know which dimensions we have to add to the promotee map. + # Now we must test if we are also allowed to make that promotion in the first place. + if not self.promote_all: + dimension_identifier: list[str] = [] + if self.promote_local: + dimension_identifier.append("__gtx_localdim") + if self.promote_vertical: + dimension_identifier.append("__gtx_vertical") + if self.promote_horizontal: + dimension_identifier.append("__gtx_horizontal") + if not dimension_identifier: + return False + for missing_map_param in missing_map_parameters: + if not any( + missing_map_param.endswith(dim_identifier) + for dim_identifier in dimension_identifier + ): + return False + + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the actual Map promoting. + + Add all parameters that `self.source_map` has but `self.map_to_promote` + lacks to `self.map_to_promote` the range of these new dimensions is taken + from the source map. + The order of the parameters the Map has after the promotion is unspecific. + """ + map_to_promote: dace_nodes.Map = self.map_to_promote(state=graph, sdfg=sdfg).map + source_map: dace_nodes.Map = self.source_map(state=graph, sdfg=sdfg).map + source_params: Sequence[str] = source_map.params + source_ranges: dace_subsets.Range = source_map.range + + missing_params: Sequence[str] = self.missing_map_params( # type: ignore[assignment] # Will never be `None` + map_to_promote=map_to_promote, + source_map=source_map, + be_strict=False, + ) + + # Maps the map parameter of the source map to its index, i.e. which map + # parameter it is. + map_source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_params)} + + promoted_params = list(map_to_promote.params) + promoted_ranges = list(map_to_promote.range.ranges) + + for missing_param in missing_params: + promoted_params.append(missing_param) + promoted_ranges.append(source_ranges[map_source_param_to_idx[missing_param]]) + + # Now update the map properties + # This action will also remove the tiles + map_to_promote.range = dace_subsets.Range(promoted_ranges) + map_to_promote.params = promoted_params + + def missing_map_params( + self, + map_to_promote: dace_nodes.Map, + source_map: dace_nodes.Map, + be_strict: bool = True, + ) -> list[str] | None: + """Returns the parameter that are missing in the map that should be promoted. + + The returned sequence is empty if they are already have the same parameters. + The function will return `None` is promoting is not possible. + + Args: + map_to_promote: The map that should be promoted. + source_map: The map acting as template. + be_strict: Ensure that the ranges that are already there are correct. + """ + source_params_set: set[str] = set(source_map.params) + curr_params_set: set[str] = set(map_to_promote.params) + + # The promotion can only work if the source map's parameters + # if a superset of the ones the map that should be promoted is. + if not source_params_set.issuperset(curr_params_set): + return None + + if be_strict: + # Check if the parameters that are already in the map to promote have + # the same range as in the source map. + source_ranges: dace_subsets.Range = source_map.range + curr_ranges: dace_subsets.Range = map_to_promote.range + curr_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(map_to_promote.params)} + source_param_to_idx: dict[str, int] = {p: i for i, p in enumerate(source_map.params)} + for param_to_check in curr_params_set: + curr_range = curr_ranges[curr_param_to_idx[param_to_check]] + source_range = source_ranges[source_param_to_idx[param_to_check]] + if curr_range != source_range: + return None + return list(source_params_set - curr_params_set) + + +@dace_properties.make_properties +class SerialMapPromoter(BaseMapPromoter): + """Promote a map such that it can be fused serially. + + A condition for fusing serial Maps is that they cover the same range. This + transformation is able to promote a Map, i.e. adding the missing dimensions, + such that the maps can be fused. + For more information see the `BaseMapPromoter` class. + + Notes: + The transformation does not perform the fusing on its one. + + Todo: + The map should do the fusing on its own directly. + """ + + # Pattern Matching + exit_first_map = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + entry_second_map = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + + @classmethod + def expressions(cls) -> Any: + """Get the match expressions. + + The function generates two match expressions. The first match describes + the case where the top map must be promoted, while the second case is + the second/lower map must be promoted. + """ + return [ + dace.sdfg.utils.node_path_graph( + cls.exit_first_map, cls.access_node, cls.entry_second_map + ), + dace.sdfg.utils.node_path_graph( + cls.exit_first_map, cls.access_node, cls.entry_second_map + ), + ] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the Maps really can be fused.""" + from .map_fusion_serial import SerialMapFusion + + if not super().can_be_applied(graph, expr_index, sdfg, permissive): + return False + + # Check if the partition exists, if not promotion to fusing is pointless. + # TODO(phimuell): Find the proper way of doing it. + serial_fuser = SerialMapFusion(only_toplevel_maps=True) + output_partition = serial_fuser.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=self.exit_first_map, + map_entry_2=self.entry_second_map, + ) + if output_partition is None: + return False + + return True + + def map_to_promote( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> dace_nodes.MapEntry: + if self.expr_index == 0: + # The first the top map will be promoted. + return state.entry_node(self.exit_first_map) + assert self.expr_index == 1 + + # The second map will be promoted. + return self.entry_second_map + + def source_map( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> dace_nodes.MapEntry: + """Returns the map entry that is used as source/template.""" + if self.expr_index == 0: + # The first the top map will be promoted, so the second map is the source. + return self.entry_second_map + assert self.expr_index == 1 + + # The second map will be promoted, so the first is used as source + return state.entry_node(self.exit_first_map) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py new file mode 100644 index 0000000000..f91311261a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -0,0 +1,193 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Common functionality for the transformations/optimization pipeline.""" + +from typing import Iterable, Union + +import dace +from dace.sdfg import graph as dace_graph, nodes as dace_nodes + + +def is_nested_sdfg( + sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], +) -> bool: + """Tests if `sdfg` is a NestedSDFG.""" + if isinstance(sdfg, dace.SDFGState): + sdfg = sdfg.parent + if isinstance(sdfg, dace_nodes.NestedSDFG): + return True + elif isinstance(sdfg, dace.SDFG): + if sdfg.parent_nsdfg_node is not None: + return True + return False + else: + raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") + + +def all_nodes_between( + graph: dace.SDFG | dace.SDFGState, + begin: dace_nodes.Node, + end: dace_nodes.Node, + reverse: bool = False, +) -> set[dace_nodes.Node] | None: + """Find all nodes that are reachable from `begin` but bound by `end`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end`, this edge is ignored. It will thus found any node that is reachable + from `begin` by a path that does not involve `end`. The returned set will + never contain `end` nor `begin`. In case `end` is never found the function + will return `None`. + + If `reverse` is set to `True` the function will start exploring at `end` and + follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The terminator node of the DFS. + reverse: Perform a backward DFS. + + Notes: + - The returned set will also contain the nodes of path that starts at + `begin` and ends at a node that is not `end`. + """ + + def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: + if reverse: + return (edge.src for edge in graph.in_edges(node)) + return (edge.dst for edge in graph.out_edges(node)) + + if reverse: + begin, end = end, begin + + to_visit: list[dace_nodes.Node] = [begin] + seen: set[dace_nodes.Node] = set() + found_end: bool = False + + while len(to_visit) > 0: + n: dace_nodes.Node = to_visit.pop() + if n == end: + found_end = True + continue + elif n in seen: + continue + seen.add(n) + to_visit.extend(next_nodes(n)) + + if not found_end: + return None + + seen.discard(begin) + return seen + + +def is_parallel( + graph: dace.SDFG | dace.SDFGState, + node1: dace_nodes.Node, + node2: dace_nodes.Node, +) -> bool: + """Tests if `node1` and `node2` are parallel. + + The nodes are parallel if `node2` can not be reached from `node1` and vice versa. + + Args: + graph: The graph to traverse. + node1: The first node to check. + node2: The second node to check. + """ + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if all_nodes_between(graph=graph, begin=node1, end=node2) is not None: + return False + elif all_nodes_between(graph=graph, begin=node2, end=node1) is not None: + return False + return True + + +def find_downstream_consumers( + state: dace.SDFGState, + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + only_tasklets: bool = False, + reverse: bool = False, +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: + """Find all downstream connectors of `begin`. + + A consumer, in for this function, is any node that is neither an entry nor + an exit node. The function returns a set of pairs, the first element is the + node that acts as consumer and the second is the edge that leads to it. + By setting `only_tasklets` the nodes the function finds are only Tasklets. + + To find this set the function starts a search at `begin`, however, it is also + possible to pass an edge as `begin`. + If `reverse` is `True` the function essentially finds the producers that are + upstream. + + Args: + state: The state in which to look for the consumers. + begin: The initial node that from which the search starts. + only_tasklets: Return only Tasklets. + reverse: Follow the reverse direction. + """ + if isinstance(begin, dace_graph.MultiConnectorEdge): + to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] + elif reverse: + to_visit = list(state.in_edges(begin)) + else: + to_visit = list(state.out_edges(begin)) + seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + + while len(to_visit) != 0: + curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() + next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst + + if curr_edge in seen: + continue + seen.add(curr_edge) + + if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): + if reverse: + target_conn = curr_edge.src_conn[4:] + new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) + else: + # In forward mode a Map entry could also mean the definition of a + # dynamic map range. + if (not curr_edge.dst_conn.startswith("IN_")) and isinstance( + next_node, dace_nodes.MapEntry + ): + # This edge defines a dynamic map range, which is a consumer + if not only_tasklets: + found.add((next_node, curr_edge)) + continue + target_conn = curr_edge.dst_conn[3:] + new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) + to_visit.extend(new_edges) + del new_edges + else: + if only_tasklets and (not isinstance(next_node, dace_nodes.Tasklet)): + continue + found.add((next_node, curr_edge)) + + return found + + +def find_upstream_producers( + state: dace.SDFGState, + begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], + only_tasklets: bool = False, +) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: + """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" + return find_downstream_consumers( + state=state, + begin=begin, + only_tasklets=only_tasklets, + reverse=True, + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py new file mode 100644 index 0000000000..abf4c3e24c --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/__init__.py @@ -0,0 +1,8 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py new file mode 100644 index 0000000000..72e76a63e2 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/conftest.py @@ -0,0 +1,30 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Sequence, Union, overload, Literal, Generator + +import pytest +import dace +import copy +import numpy as np +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +@pytest.fixture(autouse=True) +def _set_dace_settings() -> Generator[None, None, None]: + """Customizes DaCe settings during the tests.""" + with dace.config.temporary_config(): + dace.Config.set("optimizer", "match_exception", value=True) + dace.Config.set("compiler", "allow_view_arguments", value=True) + yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py new file mode 100644 index 0000000000..91d76ebd39 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_k_blocking.py @@ -0,0 +1,144 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Callable +import dace +import copy +import numpy as np + +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np.ndarray]]: + """Creates a simple SDFG. + + The k blocking transformation can be applied to the SDFG, however no node + can be taken out. This is because how it is constructed. However, applying + some simplistic transformations this can be done. + """ + sdfg = dace.SDFG("only_dependent") + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + _, a = sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + _, b = sdfg.add_array("b", ("N",), dace.float64, transient=False) + _, c = sdfg.add_array("c", ("N", "M"), dace.float64, transient=False) + state.add_mapped_tasklet( + name="comp", + map_ranges=dict(i=f"0:N", j=f"0:M"), + inputs=dict(__in0=dace.Memlet("a[i, j]"), __in1=dace.Memlet("b[i]")), + outputs=dict(__out=dace.Memlet("c[i, j]")), + code="__out = __in0 + __in1", + external_edges=True, + ) + return sdfg, lambda a, b: a + b.reshape((-1, 1)) + + +def test_only_dependent(): + """Just applying the transformation to the SDFG. + + Because all of nodes (which is only a Tasklet) inside the map scope are + "dependent", see the transformation for explanation of terminology, the + transformation will only add an inner map. + """ + sdfg, reff = _get_simple_sdfg() + + N, M = 100, 10 + a = np.random.rand(N, M) + b = np.random.rand(N) + c = np.zeros_like(a) + ref = reff(a, b) + + # Apply the transformation + sdfg.apply_transformations_repeated( + gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + validate=True, + validate_all=True, + ) + + assert len(sdfg.states()) == 1 + state = sdfg.states()[0] + source_nodes = state.source_nodes() + assert len(source_nodes) == 2 + assert all(isinstance(x, dace_nodes.AccessNode) for x in source_nodes) + source_node = source_nodes[0] # Unspecific which one it is, but it does not matter. + assert state.out_degree(source_node) == 1 + outer_map: dace_nodes.MapEntry = next(iter(state.out_edges(source_node))).dst + assert isinstance(outer_map, dace_nodes.MapEntry) + assert state.in_degree(outer_map) == 2 + assert state.out_degree(outer_map) == 2 + assert len(outer_map.map.params) == 2 + assert "j" not in outer_map.map.params + assert all(isinstance(x.dst, dace_nodes.MapEntry) for x in state.out_edges(outer_map)) + inner_map: dace_nodes.MapEntry = next(iter(state.out_edges(outer_map))).dst + assert len(inner_map.map.params) == 1 + assert inner_map.map.params[0] == "j" + assert inner_map.map.schedule == dace.dtypes.ScheduleType.Sequential + + sdfg(a=a, b=b, c=c, N=N, M=M) + assert np.allclose(ref, c) + + +def test_intermediate_access_node(): + """Test the lifting out, version "AccessNode". + + The Tasklet of the SDFG generated by `_get_simple_sdfg()` has to be inside the + inner most loop because one of its input Memlet depends on `j`. However, + one of its input, `b[i]` does not. Instead of connecting `b` directly with the + Tasklet, this test will store `b[i]` inside a temporary inside the Map. + This access node is independent of `j` and can thus be moved out of the inner + most scope. + """ + sdfg, reff = _get_simple_sdfg() + + N, M = 100, 10 + a = np.random.rand(N, M) + b = np.random.rand(N) + c = np.zeros_like(a) + ref = reff(a, b) + + # Now make a small modification is such that the transformation does something. + state = sdfg.states()[0] + sdfg.add_scalar("tmp", dace.float64, transient=True) + + tmp = state.add_access("tmp") + edge = next( + e for e in state.edges() if isinstance(e.src, dace_nodes.MapEntry) and e.data.data == "b" + ) + state.add_edge(edge.src, edge.src_conn, tmp, None, copy.deepcopy(edge.data)) + state.add_edge(tmp, None, edge.dst, edge.dst_conn, dace.Memlet("tmp[0]")) + state.remove_edge(edge) + + # Test if after the modification the SDFG still works + sdfg(a=a, b=b, c=c, N=N, M=M) + assert np.allclose(ref, c) + + # Apply the transformation. + sdfg.apply_transformations_repeated( + gtx_transformations.KBlocking(blocking_size=10, block_dim="j"), + validate=True, + validate_all=True, + ) + + # Inspect if the SDFG was modified correctly. + # We only inspect `tmp` which now has to be between the two maps. + assert state.in_degree(tmp) == 1 + assert state.out_degree(tmp) == 1 + top_node = next(iter(state.in_edges(tmp))).src + bottom_node = next(iter(state.out_edges(tmp))).dst + assert isinstance(top_node, dace_nodes.MapEntry) + assert isinstance(bottom_node, dace_nodes.MapEntry) + assert bottom_node is not top_node + + c[:] = 0 + sdfg(a=a, b=b, c=c, N=N, M=M) + assert np.allclose(ref, c) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion_parallel.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion_parallel.py new file mode 100644 index 0000000000..be525eab93 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion_parallel.py @@ -0,0 +1,203 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Sequence, Union, Literal, overload + +import pytest +import dace +import copy +import numpy as np +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) +from . import util + + +def _count_toplevel_maps(sdfg: dace.SDFG) -> int: + topLevelMaps = 0 + for state in sdfg.nodes(): + scope = state.scope_dict() + for maps in filter(lambda n: isinstance(n, dace_nodes.MapEntry), state.nodes()): + if scope[maps] is None: + topLevelMaps += 1 + return topLevelMaps + + +def _make_parallel_sdfg_1( + N_a: str | int, + N_b: str | int | None = None, + N_c: str | int | None = None, + it_var_b: str | None = None, + it_var_c: str | None = None, +) -> dace.SDFG: + """Create the "parallel_1_sdfg". + + This is a simple SDFG with two parallel Map, it has one input `a`. Further, + it has the two outputs `b` and `c` defined as `b := a + 2.` and `c := a + 4.`. + The array may have different length, but the size of `b` and `c` must be smaller + or equal the size of `a`. + By using `it_var_{b,c}` it is possible to control which iteration variables should + be used by the Maps. If `it_var_b` is not given it defaults to `__i0` and if + `it_var_c` is not given it defaults to the value given to `_it_var_b`. + + Args: + N_a: The length of array `a`, must be the largest. + N_b: The length of array `b`, if not given equals to `N_a`. + N_c: The length of array `c`, if not given equals to `N_a`. + it_var_b: The iteration variable used by the Map handling `b`. + it_var_c: The iteration variable used by the Map handling `c`. + """ + + if N_b is None: + N_b = N_a + if N_c is None: + N_c = N_a + if it_var_b is None: + it_var_b = "__i0" + if it_var_c is None: + it_var_c = it_var_b + + shapes = {"a": N_a, "b": N_b, "c": N_c} + + sdfg = dace.SDFG("parallel_1_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in ["a", "b", "c"]: + sdfg.add_array( + name=name, + shape=(shapes[name],), + dtype=dace.float64, + transient=False, + ) + a = state.add_access("a") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[(it_var_b, f"0:{shapes['b']}")], + inputs={"__in0": dace.Memlet(f"a[{it_var_b}]")}, + code="__out = __in0 + 2.0", + outputs={"__out": dace.Memlet(f"b[{it_var_b}]")}, + input_nodes={"a": a}, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="second_computation", + map_ranges=[(it_var_c, f"0:{shapes['c']}")], + input_nodes={"a": a}, + inputs={"__in0": dace.Memlet(f"a[{it_var_c}]")}, + code="__out = __in0 + 4.0", + outputs={"__out": dace.Memlet(f"c[{it_var_c}]")}, + external_edges=True, + ) + + return sdfg + + +def test_simple_fusing() -> None: + """Tests a simple case of parallel map fusion. + + The parallel maps have the same sizes and same iteration bounds and variables. + This means that the transformation applies. + """ + + # The size of the request. + N = 10 + sdfg = _make_parallel_sdfg_1(N_a="N_a") + assert _count_toplevel_maps(sdfg) == 2 + + # Now run the optimization + sdfg.apply_transformations_repeated([gtx_transformations.ParallelMapFusion], validate_all=True) + assert _count_toplevel_maps(sdfg) == 1, f"Expected that the two maps were fused." + + # Now run the SDFG to check if the code is still valid. + a = np.random.rand(N) + b = np.zeros_like(a) + c = np.zeros_like(a) + + # Compute the reference solution. + ref_b = a + 2 + ref_c = a + 4 + + # Now calling the SDFG. + sdfg(a=a, b=b, c=c, N_a=N) + + assert np.allclose(ref_b, b), f"Computation of 'b' failed." + assert np.allclose(ref_c, c), f"Computation of 'c' failed." + + +def test_non_fusable() -> None: + """Tests a case where the bounds did not match. + + This can never be fused. + """ + N = 10 + N_a, N_b, N_c = N, N, N - 1 + + sdfg = _make_parallel_sdfg_1(N_a="N_a", N_b="N_a", N_c="N_c") + assert _count_toplevel_maps(sdfg) == 2 + + # Now run the optimization, which will not succeed. + sdfg.apply_transformations_repeated([gtx_transformations.ParallelMapFusion], validate_all=True) + assert _count_toplevel_maps(sdfg) == 2, f"Expected that the two maps could not be fused." + + # Testing if it still runs as expected. + a = np.random.rand(N_a) + b = np.zeros(N_b) + c = np.zeros(N_c) + + # Compute the reference solution. + ref_b = a[0:N_b] + 2 + ref_c = a[0:N_c] + 4 + + # Now calling the SDFG. + sdfg(a=a, b=b, c=c, N_a=N_a, N_c=N_c) + + assert np.allclose(ref_b, b), f"Computation of 'b' failed." + assert np.allclose(ref_c, c), f"Computation of 'c' failed." + + +@pytest.mark.xfail(reason="Renaming of iteration variables is not implemented.") +def test_renaming_fusing() -> None: + """Tests if the renaming works. + + The two parallel maps are technically fusable, but the iteration variables + are different, thus a renaming must be done. + + Note: + The renaming feature is currently not implemented, so this test will + (currently) fail. + """ + + # The size of the request. + N = 10 + sdfg = _make_parallel_sdfg_1(N_a="N_a", it_var_b="__itB", it_var_c="__itC") + assert _count_toplevel_maps(sdfg) == 2 + + # Now run the optimization + sdfg.apply_transformations_repeated([gtx_transformations.ParallelMapFusion], validate_all=True) + assert _count_toplevel_maps(sdfg) == 1, f"Expected that the two maps were fused." + + # Now run the SDFG to check if the code is still valid. + a = np.random.rand(N) + b = np.zeros_like(a) + c = np.zeros_like(a) + + # Compute the reference solution. + ref_b = a + 2 + ref_c = a + 4 + + # Now calling the SDFG. + sdfg(a=a, b=b, c=c, N_a=N) + + assert np.allclose(ref_b, b), f"Computation of 'b' failed." + assert np.allclose(ref_c, c), f"Computation of 'c' failed." diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion_serial.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion_serial.py new file mode 100644 index 0000000000..8d8a108765 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_map_fusion_serial.py @@ -0,0 +1,456 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Sequence, Union, Literal, overload + +import pytest +import dace +import copy +import numpy as np +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) +from . import util + + +def _make_serial_sdfg_1( + N: str | int, +) -> dace.SDFG: + """Create the "serial_1_sdfg". + + This is an SDFG with a single state containing two maps. It has the input + `a` and the output `b`, each two dimensional arrays, with size `0:N`. + The first map adds 1 to the input and writes it into `tmp`. The second map + adds another 3 to `tmp` and writes it back inside `b`. + + Args: + N: The size of the arrays. + """ + shape = (N, N) + sdfg = dace.SDFG("serial_1_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in ["a", "b", "tmp"]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + tmp = state.add_access("tmp") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in0 + 1.0", + outputs={"__out": dace.Memlet("tmp[__i0, __i1]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="second_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp": tmp}, + inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, + code="__out = __in0 + 3.0", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + + return sdfg + + +def _make_serial_sdfg_2( + N: str | int, +) -> dace.SDFG: + """Create the "serial_2_sdfg". + + The generated SDFG uses `a` and input and has two outputs `b := a + 4` and + `c := a - 4`. There is a top map with a single Single Tasklet, that has + two outputs, the first one computes `a + 1` and stores that in `tmp_1`. + The second output computes `a - 1` and stores it `tmp_2`. + Below the top map are two (parallel) map, one compute `b := tmp_1 + 3`, while + the other compute `c := tmp_2 - 3`. This means that there are two map fusions. + The main important thing is that, the second map fusion will involve a pure + fusion (because the processing order is indeterministic, one does not know + which one in advance). + + Args: + N: The size of the arrays. + """ + shape = (N, N) + sdfg = dace.SDFG("serial_2_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in ["a", "b", "c", "tmp_1", "tmp_2"]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["tmp_1"].transient = True + sdfg.arrays["tmp_2"].transient = True + tmp_1 = state.add_access("tmp_1") + tmp_2 = state.add_access("tmp_2") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, + code="__out0 = __in0 + 1.0\n__out1 = __in0 - 1.0", + outputs={ + "__out0": dace.Memlet("tmp_1[__i0, __i1]"), + "__out1": dace.Memlet("tmp_2[__i0, __i1]"), + }, + output_nodes={ + "tmp_1": tmp_1, + "tmp_2": tmp_2, + }, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp_1": tmp_1}, + inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, + code="__out = __in0 + 3.0", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + state.add_mapped_tasklet( + name="second_computation", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp_2": tmp_2}, + inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, + code="__out = __in0 - 3.0", + outputs={"__out": dace.Memlet("c[__i0, __i1]")}, + external_edges=True, + ) + + return sdfg + + +def _make_serial_sdfg_3( + N_input: str | int, + N_output: str | int, +) -> dace.SDFG: + """Creates a serial SDFG that has an indirect access Tasklet in the second map. + + The SDFG has three inputs `a`, `b` and `idx`. The first two are 1 dimensional + arrays, and the second is am array containing integers. + The top map computes `a + b` and stores that in `tmp`. + The second map then uses the elements of `idx` to make indirect accesses into + `tmp`, which are stored inside `c`. + + Args: + N_input: The length of `a` and `b`. + N_output: The length of `c` and `idx`. + """ + input_shape = (N_input,) + output_shape = (N_output,) + + sdfg = dace.SDFG("serial_3_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name, shape in [ + ("a", input_shape), + ("b", input_shape), + ("c", output_shape), + ("idx", output_shape), + ("tmp", input_shape), + ]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.int32 if name == "idx" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + tmp = state.add_access("tmp") + + state.add_mapped_tasklet( + name="first_computation", + map_ranges=[("__i0", f"0:{N_input}")], + inputs={ + "__in0": dace.Memlet("a[__i0]"), + "__in1": dace.Memlet("b[__i0]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + ) + + state.add_mapped_tasklet( + name="indirect_access", + map_ranges=[("__i0", f"0:{N_output}")], + input_nodes={"tmp": tmp}, + inputs={ + "__index": dace.Memlet("idx[__i0]"), + "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), + }, + code="__out = __array[__index]", + outputs={"__out": dace.Memlet("c[__i0]")}, + external_edges=True, + ) + + return sdfg + + +def test_exclusive_itermediate(): + """Tests if the exclusive intermediate branch works.""" + N = 10 + sdfg = _make_serial_sdfg_1(N) + + # Now apply the optimizations. + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert "tmp" not in sdfg.arrays + + # Test if the intermediate is a scalar + intermediate_nodes: list[dace_nodes.Node] = [ + node + for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + if node.data not in ["a", "b"] + ] + assert len(intermediate_nodes) == 1 + assert all(isinstance(node.desc(sdfg), dace.data.Scalar) for node in intermediate_nodes) + + a = np.random.rand(N, N) + b = np.empty_like(a) + ref = a + 4.0 + sdfg(a=a, b=b) + + assert np.allclose(b, ref) + + +def test_shared_itermediate(): + """Tests the shared intermediate path. + + The function uses the `_make_serial_sdfg_1()` SDFG. However, it promotes `tmp` + to a global, and it thus became a shared intermediate, i.e. will survive. + """ + N = 10 + sdfg = _make_serial_sdfg_1(N) + sdfg.arrays["tmp"].transient = False + + # Now apply the optimizations. + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert "tmp" in sdfg.arrays + + # Test if the intermediate is a scalar + intermediate_nodes: list[dace_nodes.Node] = [ + node + for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + if node.data not in ["a", "b", "tmp"] + ] + assert len(intermediate_nodes) == 1 + assert all(isinstance(node.desc(sdfg), dace.data.Scalar) for node in intermediate_nodes) + + a = np.random.rand(N, N) + b = np.empty_like(a) + tmp = np.empty_like(a) + + ref_b = a + 4.0 + ref_tmp = a + 1.0 + sdfg(a=a, b=b, tmp=tmp) + + assert np.allclose(b, ref_b) + assert np.allclose(tmp, ref_tmp) + + +def test_pure_output_node(): + """Tests the path of a pure intermediate.""" + N = 10 + sdfg = _make_serial_sdfg_2(N) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + + # The first fusion will only bring it down to two maps. + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + + a = np.random.rand(N, N) + b = np.empty_like(a) + c = np.empty_like(a) + ref_b = a + 4.0 + ref_c = a - 4.0 + sdfg(a=a, b=b, c=c) + + assert np.allclose(b, ref_b) + assert np.allclose(c, ref_c) + + +def test_array_intermediate(): + """Tests the correct working if we have more than scalar intermediate. + + The test used `_make_serial_sdfg_1()` to get an SDFG and then call `MapExpansion`. + Map fusion is then called only outer maps, thus the intermediate node, must + be an array. + """ + N = 10 + sdfg = _make_serial_sdfg_1(N) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 4 + + # Now perform the fusion + sdfg.apply_transformations( + gtx_transformations.SerialMapFusion(only_toplevel_maps=True), + validate=True, + validate_all=True, + ) + map_entries = util._count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + + scope = next(iter(sdfg.states())).scope_dict() + assert len(map_entries) == 3 + top_maps = [map_entry for map_entry in map_entries if scope[map_entry] is None] + assert len(top_maps) == 1 + top_map = top_maps[0] + assert sum(scope[map_entry] is top_map for map_entry in map_entries) == 2 + + # Find the access node that is the new intermediate node. + inner_access_nodes: list[dace_nodes.AccessNode] = [ + node + for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + if scope[node] is not None + ] + assert len(inner_access_nodes) == 1 + inner_access_node = inner_access_nodes[0] + inner_desc: dace.data.Data = inner_access_node.desc(sdfg) + assert inner_desc.shape == (N,) + + a = np.random.rand(N, N) + b = np.empty_like(a) + ref_b = a + 4.0 + sdfg(a=a, b=b) + + assert np.allclose(ref_b, b) + + +def test_interstate_transient(): + """Tests if an interstate transient is handled properly. + + This function uses the SDFG generated by `_make_serial_sdfg_2()`. It adds a second + state to SDFG in which `tmp_1` is read from and the result is written in `d` (new + variable). Thus `tmp_1` can not be removed. + """ + N = 10 + sdfg = _make_serial_sdfg_2(N) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert sdfg.number_of_nodes() == 1 + + # Now add the new state and the new output. + sdfg.add_datadesc("d", copy.deepcopy(sdfg.arrays["b"])) + head_state = next(iter(sdfg.states())) + new_state = sdfg.add_state_after(head_state) + + new_state.add_mapped_tasklet( + name="first_computation_second_state", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, + code="__out = __in0 + 9.0", + outputs={"__out": dace.Memlet("d[__i0, __i1]")}, + external_edges=True, + ) + + # Now apply the transformation + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert "tmp_1" in sdfg.arrays + assert "tmp_2" not in sdfg.arrays + assert sdfg.number_of_nodes() == 2 + assert util._count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert util._count_nodes(new_state, dace_nodes.MapEntry) == 1 + + a = np.random.rand(N, N) + b = np.empty_like(a) + c = np.empty_like(a) + d = np.empty_like(a) + ref_b = a + 4.0 + ref_c = a - 4.0 + ref_d = a + 10.0 + + sdfg(a=a, b=b, c=c, d=d) + assert np.allclose(ref_b, b) + assert np.allclose(ref_c, c) + assert np.allclose(ref_d, d) + + +def test_indirect_access(): + """Tests if indirect accesses are handled. + + Indirect accesses, a Tasklet dereferences the array, can not be fused, because + the array is accessed by the Tasklet. + """ + N_input = 100 + N_output = 1000 + a = np.random.rand(N_input) + b = np.random.rand(N_input) + c = np.empty(N_output) + idx = np.random.randint(low=0, high=N_input, size=N_output, dtype=np.int32) + sdfg = _make_serial_sdfg_3(N_input=N_input, N_output=N_output) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + + def _ref(a, b, idx): + tmp = a + b + return tmp[idx] + + ref = _ref(a, b, idx) + + sdfg(a=a, b=b, idx=idx, c=c) + assert np.allclose(ref, c) + + # Now "apply" the transformation + sdfg.apply_transformations_repeated( + gtx_transformations.SerialMapFusion(), + validate=True, + validate_all=True, + ) + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + + c[:] = -1.0 + sdfg(a=a, b=b, idx=idx, c=c) + assert np.allclose(ref, c) + + +def test_indirect_access_2(): + # TODO(phimuell): Index should be computed and that map should be fusable. + pass diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py new file mode 100644 index 0000000000..5c9f555582 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/test_serial_map_promoter.py @@ -0,0 +1,89 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Callable +import dace +import copy +import numpy as np + +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) +from . import util + + +def test_serial_map_promotion(): + """Tests the serial Map promotion transformation.""" + N = 10 + shape_1d = (N,) + shape_2d = (N, N) + sdfg = dace.SDFG("serial_promotable") + state = sdfg.add_state(is_start_block=True) + + # 1D Arrays + for name in ["a", "tmp"]: + sdfg.add_array( + name=name, + shape=shape_1d, + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + # 2D Arrays + for name in ["b", "c"]: + sdfg.add_array( + name=name, + shape=shape_2d, + dtype=dace.float64, + transient=False, + ) + tmp = state.add_access("tmp") + + _, map_entry_1d, _ = state.add_mapped_tasklet( + name="one_d_map", + map_ranges=[("__i0", f"0:{N}")], + inputs={"__in0": dace.Memlet("a[__i0]")}, + code="__out = __in0 + 1.0", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={"tmp": tmp}, + external_edges=True, + ) + + _, map_entry_2d, _ = state.add_mapped_tasklet( + name="two_d_map", + map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], + input_nodes={"tmp": tmp}, + inputs={"__in0": dace.Memlet("tmp[__i0]"), "__in1": dace.Memlet("b[__i0, __i1]")}, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("c[__i0, __i1]")}, + external_edges=True, + ) + + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert len(map_entry_1d.map.params) == 1 + assert len(map_entry_2d.map.params) == 2 + + sdfg.view() + # Now apply the promotion + sdfg.apply_transformations( + gtx_transformations.SerialMapPromoter( + promote_all=True, + ), + validate=True, + validate_all=True, + ) + + sdfg.view() + + assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert len(map_entry_1d.map.params) == 2 + assert len(map_entry_2d.map.params) == 2 + assert set(map_entry_1d.map.params) == set(map_entry_2d.map.params) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py new file mode 100644 index 0000000000..739582d5d9 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace/transformation_tests/util.py @@ -0,0 +1,59 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Union, Literal, overload + +import dace +from dace.sdfg import nodes as dace_nodes +from dace.transformation import dataflow as dace_dataflow + +__all__ = [ + "_count_nodes", +] + + +@overload +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: Literal[False], +) -> int: ... + + +@overload +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: Literal[True], +) -> list[dace_nodes.Node]: ... + + +def _count_nodes( + graph: Union[dace.SDFG, dace.SDFGState], + node_type: tuple[type, ...] | type, + return_nodes: bool = False, +) -> Union[int, list[dace_nodes.Node]]: + """Counts the number of nodes in of a particular type in `graph`. + + If `graph` is an SDFGState then only count the nodes inside this state, + but if `graph` is an SDFG count in all states. + + Args: + graph: The graph to scan. + node_type: The type or sequence of types of nodes to look for. + """ + + states = graph.states() if isinstance(graph, dace.SDFG) else [graph] + found_nodes: list[dace_nodes.Node] = [] + for state_nodes in states: + for node in state_nodes.nodes(): + if isinstance(node, node_type): + found_nodes.append(node) + if return_nodes: + return found_nodes + return len(found_nodes)