diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 9d89ec7c09..8bc14213b0 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -6,7 +6,7 @@ import copy import dace from dace import dtypes, subsets, symbolic -from dace.properties import EnumProperty, make_properties +from dace.properties import EnumProperty, make_properties, Property from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedMultiDiConnectorGraph @@ -18,8 +18,9 @@ class MapExpansion(pm.SingleStateTransformation): """ Implements the map-expansion pattern. - Map-expansion takes an N-dimensional map and expands it to N - unidimensional maps. + Map-expansion takes an N-dimensional map and expands it. + It will generate the k nested unidimensional map and a (N-k)-dimensional inner most map. + If k is not specified all maps are expanded. New edges abide by the following rules: 1. If there are no edges coming from the outside, use empty memlets @@ -33,6 +34,11 @@ class MapExpansion(pm.SingleStateTransformation): dtype=dtypes.ScheduleType, default=dtypes.ScheduleType.Sequential, allow_none=True) + expansion_limit = Property(desc="How many unidimensional maps will be creaed, known as k. " + "If None, the default no limit is in place.", + dtype=int, + allow_none=True, + default=None) @classmethod def expressions(cls): @@ -43,22 +49,77 @@ def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG # includes an N-dimensional map, with N greater than one. return self.map_entry.map.get_param_num() > 1 + def generate_new_maps(self, + current_map: nodes.Map): + if self.expansion_limit is None: + full_expand = True + elif isinstance(self.expansion_limit, int): + full_expand = False + if self.expansion_limit <= 0: # These are invalid, so we make a full expansion + full_expand = True + elif (self.map_entry.map.get_param_num() - self.expansion_limit) <= 1: + full_expand = True + else: + raise TypeError(f"Does not know how to handle type {type(self.expansion_limit).__name__}") + + inner_schedule = self.inner_schedule or current_map.schedule + if full_expand: + new_maps = [ + nodes.Map( + current_map.label + '_' + str(param), [param], + subsets.Range([param_range]), + schedule=inner_schedule if dim != 0 else current_map.schedule) + for dim, param, param_range in zip(range(len(current_map.params)), current_map.params, current_map.range) + ] + for i, new_map in enumerate(new_maps): + new_map.range.tile_sizes[0] = current_map.range.tile_sizes[i] + + else: + k = self.expansion_limit + new_maps: list[nodes.Map] = [] + + # Unidimensional maps + for dim in range(0, k): + dim_param = current_map.params[dim] + dim_range = current_map.range.ranges[dim] + dim_tile = current_map.range.tile_sizes[dim] + new_maps.append( + nodes.Map( + current_map.label + '_' + str(dim_param), + [dim_param], + subsets.Range([dim_range]), + schedule=inner_schedule if dim != 0 else current_map.schedule )) + new_maps[-1].range.tile_sizes[0] = dim_tile + + # Multidimensional maps + mdim_params = current_map.params[k:] + mdim_ranges = current_map.range.ranges[k:] + mdim_tiles = current_map.range.tile_sizes[k:] + new_maps.append( + nodes.Map( + current_map.label, # The original name + mdim_params, + mdim_ranges, + schedule=inner_schedule )) + new_maps[-1].range.tile_sizes = mdim_tiles + return new_maps + def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): # Extract the map and its entry and exit nodes. map_entry = self.map_entry map_exit = graph.exit_node(map_entry) current_map = map_entry.map - # Create new maps - inner_schedule = self.inner_schedule or current_map.schedule - new_maps = [ - nodes.Map(current_map.label + '_' + str(param), [param], - subsets.Range([param_range]), - schedule=inner_schedule) - for param, param_range in zip(current_map.params[1:], current_map.range[1:]) - ] - current_map.params = [current_map.params[0]] - current_map.range = subsets.Range([current_map.range[0]]) + # Generate the new maps that we should use. + new_maps = self.generate_new_maps(current_map) + + if not new_maps: # No changes should be made -> noops + return + + # Reuse the map that is already existing for the first one. + current_map.params = new_maps[0].params + current_map.range = new_maps[0].range + new_maps.pop(0) # Create new map entries and exits entries = [nodes.MapEntry(new_map) for new_map in new_maps] diff --git a/tests/transformations/map_expansion_test.py b/tests/transformations/map_expansion_test.py index 1f9a97f810..6e4b965ba2 100644 --- a/tests/transformations/map_expansion_test.py +++ b/tests/transformations/map_expansion_test.py @@ -73,7 +73,7 @@ def toexpand(B: dace.float64[4, 4]): continue # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet - if sdfg.start_state.entry_node(node) is None: + if state.entry_node(node) is None: assert state.in_degree(node) == 0 assert state.out_degree(node) == 1 assert len(node.out_connectors) == 0 @@ -113,7 +113,58 @@ def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]): print('Difference:', diff2) assert (diff <= 1e-5) and (diff2 <= 1e-5) +def test_expand_with_limits(): + @dace.program + def expansion(A: dace.float32[20, 30, 5]): + @dace.map + def mymap(i: _[0:20], j: _[0:30], k: _[0:5]): + a << A[i, j, k] + b >> A[i, j, k] + b = a * 2 + + A = np.random.rand(20, 30, 5).astype(np.float32) + expected = A.copy() + expected *= 2 + + sdfg = expansion.to_sdfg() + sdfg.simplify() + sdfg(A=A) + diff = np.linalg.norm(A - expected) + print('Difference (before transformation):', diff) + + sdfg.apply_transformations(MapExpansion, options=dict(expansion_limit=1)) + + map_entries = set() + state = sdfg.start_state + for node in state.nodes(): + if not isinstance(node, dace.nodes.MapEntry): + continue + + if state.entry_node(node) is None: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + assert len(node.map.range.ranges) == 1 + assert node.map.range.ranges[0][1] - node.map.range.ranges[0][0] + 1 == 20 + else: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + assert len(node.map.range.ranges) == 2 + assert list(map(lambda x: x[1] - x[0] + 1, node.map.range.ranges)) == [30, 5] + + map_entries.add(node) + + sdfg(A=A) + expected *= 2 + diff2 = np.linalg.norm(A - expected) + print('Difference:', diff2) + assert (diff <= 1e-5) and (diff2 <= 1e-5) + assert len(map_entries) == 2 + + if __name__ == '__main__': test_expand_with_inputs() test_expand_without_inputs() test_expand_without_dynamic_inputs() + test_expand_with_limits()