From ba1587ecc2b9a0a914fbb472922b1123a2c4a1ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:03:27 +0100 Subject: [PATCH] Updated and fixed the MapExpansion transformation. (#1532) Before the transformation ignored the tiling parameter of the range, this was fixed. The transformation is now also able to limit the expansion. Thus it is possible to only expand the first k dimensions, the remaining dimensions will be kept in a multi dimensional map. This is a feature that will be need in GT4Py. --- dace/transformation/dataflow/map_expansion.py | 87 ++++++++++++++++--- tests/transformations/map_expansion_test.py | 53 ++++++++++- 2 files changed, 126 insertions(+), 14 deletions(-) diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 9d89ec7c09..8bc14213b0 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -6,7 +6,7 @@ import copy import dace from dace import dtypes, subsets, symbolic -from dace.properties import EnumProperty, make_properties +from dace.properties import EnumProperty, make_properties, Property from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedMultiDiConnectorGraph @@ -18,8 +18,9 @@ class MapExpansion(pm.SingleStateTransformation): """ Implements the map-expansion pattern. - Map-expansion takes an N-dimensional map and expands it to N - unidimensional maps. + Map-expansion takes an N-dimensional map and expands it. + It will generate the k nested unidimensional map and a (N-k)-dimensional inner most map. + If k is not specified all maps are expanded. New edges abide by the following rules: 1. If there are no edges coming from the outside, use empty memlets @@ -33,6 +34,11 @@ class MapExpansion(pm.SingleStateTransformation): dtype=dtypes.ScheduleType, default=dtypes.ScheduleType.Sequential, allow_none=True) + expansion_limit = Property(desc="How many unidimensional maps will be creaed, known as k. " + "If None, the default no limit is in place.", + dtype=int, + allow_none=True, + default=None) @classmethod def expressions(cls): @@ -43,22 +49,77 @@ def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG # includes an N-dimensional map, with N greater than one. return self.map_entry.map.get_param_num() > 1 + def generate_new_maps(self, + current_map: nodes.Map): + if self.expansion_limit is None: + full_expand = True + elif isinstance(self.expansion_limit, int): + full_expand = False + if self.expansion_limit <= 0: # These are invalid, so we make a full expansion + full_expand = True + elif (self.map_entry.map.get_param_num() - self.expansion_limit) <= 1: + full_expand = True + else: + raise TypeError(f"Does not know how to handle type {type(self.expansion_limit).__name__}") + + inner_schedule = self.inner_schedule or current_map.schedule + if full_expand: + new_maps = [ + nodes.Map( + current_map.label + '_' + str(param), [param], + subsets.Range([param_range]), + schedule=inner_schedule if dim != 0 else current_map.schedule) + for dim, param, param_range in zip(range(len(current_map.params)), current_map.params, current_map.range) + ] + for i, new_map in enumerate(new_maps): + new_map.range.tile_sizes[0] = current_map.range.tile_sizes[i] + + else: + k = self.expansion_limit + new_maps: list[nodes.Map] = [] + + # Unidimensional maps + for dim in range(0, k): + dim_param = current_map.params[dim] + dim_range = current_map.range.ranges[dim] + dim_tile = current_map.range.tile_sizes[dim] + new_maps.append( + nodes.Map( + current_map.label + '_' + str(dim_param), + [dim_param], + subsets.Range([dim_range]), + schedule=inner_schedule if dim != 0 else current_map.schedule )) + new_maps[-1].range.tile_sizes[0] = dim_tile + + # Multidimensional maps + mdim_params = current_map.params[k:] + mdim_ranges = current_map.range.ranges[k:] + mdim_tiles = current_map.range.tile_sizes[k:] + new_maps.append( + nodes.Map( + current_map.label, # The original name + mdim_params, + mdim_ranges, + schedule=inner_schedule )) + new_maps[-1].range.tile_sizes = mdim_tiles + return new_maps + def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): # Extract the map and its entry and exit nodes. map_entry = self.map_entry map_exit = graph.exit_node(map_entry) current_map = map_entry.map - # Create new maps - inner_schedule = self.inner_schedule or current_map.schedule - new_maps = [ - nodes.Map(current_map.label + '_' + str(param), [param], - subsets.Range([param_range]), - schedule=inner_schedule) - for param, param_range in zip(current_map.params[1:], current_map.range[1:]) - ] - current_map.params = [current_map.params[0]] - current_map.range = subsets.Range([current_map.range[0]]) + # Generate the new maps that we should use. + new_maps = self.generate_new_maps(current_map) + + if not new_maps: # No changes should be made -> noops + return + + # Reuse the map that is already existing for the first one. + current_map.params = new_maps[0].params + current_map.range = new_maps[0].range + new_maps.pop(0) # Create new map entries and exits entries = [nodes.MapEntry(new_map) for new_map in new_maps] diff --git a/tests/transformations/map_expansion_test.py b/tests/transformations/map_expansion_test.py index 1f9a97f810..6e4b965ba2 100644 --- a/tests/transformations/map_expansion_test.py +++ b/tests/transformations/map_expansion_test.py @@ -73,7 +73,7 @@ def toexpand(B: dace.float64[4, 4]): continue # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet - if sdfg.start_state.entry_node(node) is None: + if state.entry_node(node) is None: assert state.in_degree(node) == 0 assert state.out_degree(node) == 1 assert len(node.out_connectors) == 0 @@ -113,7 +113,58 @@ def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]): print('Difference:', diff2) assert (diff <= 1e-5) and (diff2 <= 1e-5) +def test_expand_with_limits(): + @dace.program + def expansion(A: dace.float32[20, 30, 5]): + @dace.map + def mymap(i: _[0:20], j: _[0:30], k: _[0:5]): + a << A[i, j, k] + b >> A[i, j, k] + b = a * 2 + + A = np.random.rand(20, 30, 5).astype(np.float32) + expected = A.copy() + expected *= 2 + + sdfg = expansion.to_sdfg() + sdfg.simplify() + sdfg(A=A) + diff = np.linalg.norm(A - expected) + print('Difference (before transformation):', diff) + + sdfg.apply_transformations(MapExpansion, options=dict(expansion_limit=1)) + + map_entries = set() + state = sdfg.start_state + for node in state.nodes(): + if not isinstance(node, dace.nodes.MapEntry): + continue + + if state.entry_node(node) is None: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + assert len(node.map.range.ranges) == 1 + assert node.map.range.ranges[0][1] - node.map.range.ranges[0][0] + 1 == 20 + else: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + assert len(node.map.range.ranges) == 2 + assert list(map(lambda x: x[1] - x[0] + 1, node.map.range.ranges)) == [30, 5] + + map_entries.add(node) + + sdfg(A=A) + expected *= 2 + diff2 = np.linalg.norm(A - expected) + print('Difference:', diff2) + assert (diff <= 1e-5) and (diff2 <= 1e-5) + assert len(map_entries) == 2 + + if __name__ == '__main__': test_expand_with_inputs() test_expand_without_inputs() test_expand_without_dynamic_inputs() + test_expand_with_limits()