diff --git a/dace/transformation/dataflow/copy_to_map.py b/dace/transformation/dataflow/copy_to_map.py index 5b4260ad55..9c4dbce627 100644 --- a/dace/transformation/dataflow/copy_to_map.py +++ b/dace/transformation/dataflow/copy_to_map.py @@ -1,12 +1,13 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -from dace import dtypes, symbolic, data, subsets, Memlet +from dace import dtypes, symbolic, data, subsets, Memlet, properties from dace.sdfg.scope import is_devicelevel_gpu from dace.transformation import transformation as xf from dace.sdfg import SDFGState, SDFG, nodes, utils as sdutil from typing import Tuple +import itertools - +@properties.make_properties class CopyToMap(xf.SingleStateTransformation): """ Converts an access node -> access node copy into a map. Useful for generating manual code and @@ -14,6 +15,10 @@ class CopyToMap(xf.SingleStateTransformation): """ a = xf.PatternNode(nodes.AccessNode) b = xf.PatternNode(nodes.AccessNode) + ignore_strides = properties.Property( + default=False, + desc='Ignore the stride of the data container; Defaults to `False`.', + ) @classmethod def expressions(cls): @@ -31,7 +36,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi if isinstance(self.b.desc(sdfg), data.View): if sdutil.get_view_node(graph, self.b) == self.a: return False - if self.a.desc(sdfg).strides == self.b.desc(sdfg).strides: + if (not self.ignore_strides) and self.a.desc(sdfg).strides == self.b.desc(sdfg).strides: + return False + # Ensures that the edge goes from `a` -> `b`. + if not any(edge.dst is self.b for edge in graph.out_edges(self.a)): return False return True @@ -62,31 +70,69 @@ def delinearize_linearize(self, desc: data.Array, copy_shape: Tuple[symbolic.Sym return subsets.Range([(ind, ind, 1) for ind in cur_index]) def apply(self, state: SDFGState, sdfg: SDFG): - adesc = self.a.desc(sdfg) - bdesc = self.b.desc(sdfg) - edge = state.edges_between(self.a, self.b)[0] + avnode = self.a + av = avnode.data + adesc = avnode.desc(sdfg) + bvnode = self.b + bv = bvnode.data + bdesc = bvnode.desc(sdfg) + + edge = state.edges_between(avnode, bvnode)[0] + src_subset = edge.data.get_src_subset(edge, state) + if src_subset is None: + src_subset = subsets.Range.from_array(adesc) + src_subset_size = src_subset.size() + red_src_subset_size = tuple(s for s in src_subset_size if s != 1) + + dst_subset = edge.data.get_dst_subset(edge, state) + if dst_subset is None: + dst_subset = subsets.Range.from_array(bdesc) + dst_subset_size = dst_subset.size() + red_dst_subset_size = tuple(s for s in dst_subset_size if s != 1) if len(adesc.shape) >= len(bdesc.shape): - copy_shape = edge.data.get_src_subset(edge, state).size() + copy_shape = src_subset_size copy_a = True else: - copy_shape = edge.data.get_dst_subset(edge, state).size() + copy_shape = dst_subset_size copy_a = False - maprange = {f'__i{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)} - - av = self.a.data - bv = self.b.data - avnode = self.a - bvnode = self.b - - # Linearize and delinearize to get index expression for other side - if copy_a: - a_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))] - b_index = self.delinearize_linearize(bdesc, copy_shape, edge.data.get_dst_subset(edge, state)) + if tuple(src_subset_size) == tuple(dst_subset_size): + # The two subsets have exactly the same shape, so we can just copying with an offset. + # We use another index variables for the tests only. + maprange = {f'__j{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)} + a_index = [symbolic.pystr_to_symbolic(f'__j{i} + ({src_subset[i][0]})') for i in range(len(copy_shape))] + b_index = [symbolic.pystr_to_symbolic(f'__j{i} + ({dst_subset[i][0]})') for i in range(len(copy_shape))] + elif red_src_subset_size == red_dst_subset_size and (len(red_dst_subset_size) > 0): + # If we remove all size 1 dimensions that the two subsets have the same size. + # This is essentially the memlet `a[0:10, 2, 0:10] -> 0:10, 10:20` + # We use another index variable only for the tests but we would have to + # recreate the index anyways. + maprange = {f'__j{i}': (0, s - 1, 1) for i, s in enumerate(red_src_subset_size)} + cnt = itertools.count(0) + a_index = [ + symbolic.pystr_to_symbolic(f'{src_subset[i][0]}') + if s == 1 + else symbolic.pystr_to_symbolic(f'__j{next(cnt)} + ({src_subset[i][0]})') + for i, s in enumerate(src_subset_size) + ] + cnt = itertools.count(0) + b_index = [ + symbolic.pystr_to_symbolic(f'{dst_subset[i][0]}') + if s == 1 + else symbolic.pystr_to_symbolic(f'__j{next(cnt)} + ({dst_subset[i][0]})') + for i, s in enumerate(dst_subset_size) + ] else: - a_index = self.delinearize_linearize(adesc, copy_shape, edge.data.get_src_subset(edge, state)) - b_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))] + # We have to delinearize and linearize + # We use another index variable for the tests. + maprange = {f'__i{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)} + if copy_a: + a_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))] + b_index = self.delinearize_linearize(bdesc, copy_shape, edge.data.get_dst_subset(edge, state)) + else: + a_index = self.delinearize_linearize(adesc, copy_shape, edge.data.get_src_subset(edge, state)) + b_index = [symbolic.pystr_to_symbolic(f'__i{i}') for i in range(len(copy_shape))] a_subset = subsets.Range([(ind, ind, 1) for ind in a_index]) b_subset = subsets.Range([(ind, ind, 1) for ind in b_index]) @@ -101,7 +147,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): schedule = dtypes.ScheduleType.GPU_Device # Add copy map - t, _, _ = state.add_mapped_tasklet('copy', + t, _, _ = state.add_mapped_tasklet(f'copy_{av}_{bv}', maprange, dict(__inp=Memlet(data=av, subset=a_subset)), '__out = __inp', diff --git a/tests/transformations/copy_to_map_test.py b/tests/transformations/copy_to_map_test.py index 2b237d84d5..a0931fa1b8 100644 --- a/tests/transformations/copy_to_map_test.py +++ b/tests/transformations/copy_to_map_test.py @@ -4,6 +4,8 @@ import copy import pytest import numpy as np +import re +from typing import Tuple, Optional def _copy_to_map(storage: dace.StorageType): @@ -102,9 +104,165 @@ def test_preprocess(): assert np.allclose(out, inp) +def _perform_non_lin_delin_test( + sdfg: dace.SDFG, +) -> bool: + """Performs test for the special case CopyToMap that bypasses linearizing and delinearaziong. + """ + assert sdfg.number_of_nodes() == 1 + state: dace.SDFGState = sdfg.states()[0] + assert state.number_of_nodes() == 2 + assert state.number_of_edges() == 1 + assert all(isinstance(node, dace.nodes.AccessNode) for node in state.nodes()) + sdfg.validate() + + a = np.random.rand(*sdfg.arrays["a"].shape) + b_unopt = np.random.rand(*sdfg.arrays["b"].shape) + b_opt = b_unopt.copy() + sdfg(a=a, b=b_unopt) + + nb_runs = sdfg.apply_transformations_repeated(CopyToMap, validate=True, options={"ignore_strides": True}) + assert nb_runs == 1, f"Expected 1 application, but {nb_runs} were performed." + + # Now looking for the tasklet and checking if the memlets follows the expected + # simple pattern. + tasklet: dace.nodes.Tasklet = next(iter([node for node in state.nodes() if isinstance(node, dace.nodes.Tasklet)])) + pattern: re.Pattern = re.compile(r"(__j[0-9])|(__j[0-9]+\s*\+\s*[0-9]+)|([0-9]+)") + + assert state.in_degree(tasklet) == 1 + assert state.out_degree(tasklet) == 1 + in_edge = next(iter(state.in_edges(tasklet))) + out_edge = next(iter(state.out_edges(tasklet))) + + assert all(pattern.fullmatch(str(idxs[0]).strip()) for idxs in in_edge.data.src_subset), f"IN: {in_edge.data.src_subset}" + assert all(pattern.fullmatch(str(idxs[0]).strip()) for idxs in out_edge.data.dst_subset), f"OUT: {out_edge.data.dst_subset}" + + # Now call it again after the optimization. + sdfg(a=a, b=b_opt) + assert np.allclose(b_unopt, b_opt) + + return True + +def _make_non_lin_delin_sdfg( + shape_a: Tuple[int, ...], + shape_b: Optional[Tuple[int, ...]] = None +) -> Tuple[dace.SDFG, dace.SDFGState, dace.nodes.AccessNode, dace.nodes.AccessNode]: + + if shape_b is None: + shape_b = shape_a + + sdfg = dace.SDFG("bypass1") + state = sdfg.add_state(is_start_block=True) + + ac = [] + for name, shape in [('a', shape_a), ('b', shape_b)]: + sdfg.add_array( + name=name, + shape=shape, + dtype=dace.float64, + transient=False, + ) + ac.append(state.add_access(name)) + + return sdfg, state, ac[0], ac[1] + + +def test_non_lin_delin_1(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((10, 10)) + state.add_nedge( + a, + b, + dace.Memlet("a[0:10, 0:10] -> [0:10, 0:10]"), + ) + _perform_non_lin_delin_test(sdfg) + +def test_non_lin_delin_2(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((10, 10), (100, 100)) + state.add_nedge( + a, + b, + dace.Memlet("a[0:10, 0:10] -> [50:60, 40:50]"), + ) + _perform_non_lin_delin_test(sdfg) + + +def test_non_lin_delin_3(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((100, 100), (100, 100)) + state.add_nedge( + a, + b, + dace.Memlet("a[1:11, 20:30] -> [50:60, 40:50]"), + ) + _perform_non_lin_delin_test(sdfg) + + +def test_non_lin_delin_4(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((100, 4, 100), (100, 100)) + state.add_nedge( + a, + b, + dace.Memlet("a[1:11, 2, 20:30] -> [50:60, 40:50]"), + ) + _perform_non_lin_delin_test(sdfg) + + +def test_non_lin_delin_5(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((100, 4, 100), (100, 10, 100)) + state.add_nedge( + a, + b, + dace.Memlet("a[1:11, 2, 20:30] -> [50:60, 4, 40:50]"), + ) + _perform_non_lin_delin_test(sdfg) + + +def test_non_lin_delin_6(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((100, 100), (100, 10, 100)) + state.add_nedge( + a, + b, + dace.Memlet("a[1:11, 20:30] -> [50:60, 4, 40:50]"), + ) + _perform_non_lin_delin_test(sdfg) + + +def test_non_lin_delin_7(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((10, 10), (20, 20)) + state.add_nedge( + a, + b, + dace.Memlet("b[5:15, 6:16]"), + ) + _perform_non_lin_delin_test(sdfg) + + +def test_non_lin_delin_8(): + sdfg, state, a, b = _make_non_lin_delin_sdfg((20, 20), (10, 10)) + state.add_nedge( + a, + b, + dace.Memlet("a[5:15, 6:16]"), + ) + _perform_non_lin_delin_test(sdfg) + + if __name__ == '__main__': + test_non_lin_delin_1() + test_non_lin_delin_2() + test_non_lin_delin_3() + test_non_lin_delin_4() + test_non_lin_delin_5() + test_non_lin_delin_6() + test_non_lin_delin_7() + test_non_lin_delin_8() + test_copy_to_map() - test_copy_to_map_gpu() test_flatten_to_map() - test_flatten_to_map_gpu() - test_preprocess() + try: + import cupy + test_copy_to_map_gpu() + test_flatten_to_map_gpu() + test_preprocess() + except ModuleNotFoundError as E: + if "'cupy'" not in str(E): + raise