From b1a7f8a6ea76f913a0bf8b32de5bc416697218fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Tue, 27 Feb 2024 15:05:23 +0100 Subject: [PATCH] Updated and fixed the MapDimShuffle tranformation. (#1531) The reordering of the dimension was quite strange, it is now much more cleaner. Furthermore, the `tile_sizes` parameter of the map range was ignored, i.e. not changed. Furthermore, the parameters are now correclty copied. Before it was just assigned, thus if later the parameter list of the map was again changed this effect would propagate to _all_ maps that where treated. --- .../dataflow/map_dim_shuffle.py | 23 ++++++++++--------- tests/transformations/map_dim_shuffle_test.py | 3 +++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/dace/transformation/dataflow/map_dim_shuffle.py b/dace/transformation/dataflow/map_dim_shuffle.py index 7b4114b188..ad17a5ddac 100644 --- a/dace/transformation/dataflow/map_dim_shuffle.py +++ b/dace/transformation/dataflow/map_dim_shuffle.py @@ -27,18 +27,19 @@ def expressions(cls): return [sdutil.node_path_graph(cls.map_entry)] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + map_entry: nodes.MapEntry = self.map_entry + if self.parameters is None: + return False + if len(self.parameters) != len(map_entry.map.params): + return False + if set(self.parameters) != set(map_entry.map.params): + return False return True def apply(self, graph: SDFGState, sdfg: SDFG): - map_entry = self.map_entry - if self.parameters is None: - return - - if set(self.parameters) != set(map_entry.map.params): - return + map_entry: nodes.MapEntry = self.map_entry + new_map_order: list[int] = [map_entry.map.params.index(param) for param in self.parameters] - map_entry.range.ranges = [ - r for list_param in self.parameters for map_param, r in zip(map_entry.map.params, map_entry.range.ranges) - if list_param == map_param - ] - map_entry.map.params = self.parameters + map_entry.range.ranges = [map_entry.range.ranges[new_pos] for new_pos in new_map_order] + map_entry.range.tile_sizes = [map_entry.range.tile_sizes[new_pos] for new_pos in new_map_order] + map_entry.map.params = [map_entry.map.params[new_pos] for new_pos in new_map_order] diff --git a/tests/transformations/map_dim_shuffle_test.py b/tests/transformations/map_dim_shuffle_test.py index e0eb3f4311..1d9c73e5a2 100644 --- a/tests/transformations/map_dim_shuffle_test.py +++ b/tests/transformations/map_dim_shuffle_test.py @@ -36,6 +36,9 @@ def test_map_dim_shuffle(): sdfg(A=A, B=B) assert np.allclose(B, expected) + assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i"]}) == 0 + assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i", "l"]}) == 0 + if __name__ == '__main__': test_map_dim_shuffle()