diff --git a/dace/transformation/dataflow/map_dim_shuffle.py b/dace/transformation/dataflow/map_dim_shuffle.py index 7b4114b188..ad17a5ddac 100644 --- a/dace/transformation/dataflow/map_dim_shuffle.py +++ b/dace/transformation/dataflow/map_dim_shuffle.py @@ -27,18 +27,19 @@ def expressions(cls): return [sdutil.node_path_graph(cls.map_entry)] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + map_entry: nodes.MapEntry = self.map_entry + if self.parameters is None: + return False + if len(self.parameters) != len(map_entry.map.params): + return False + if set(self.parameters) != set(map_entry.map.params): + return False return True def apply(self, graph: SDFGState, sdfg: SDFG): - map_entry = self.map_entry - if self.parameters is None: - return - - if set(self.parameters) != set(map_entry.map.params): - return + map_entry: nodes.MapEntry = self.map_entry + new_map_order: list[int] = [map_entry.map.params.index(param) for param in self.parameters] - map_entry.range.ranges = [ - r for list_param in self.parameters for map_param, r in zip(map_entry.map.params, map_entry.range.ranges) - if list_param == map_param - ] - map_entry.map.params = self.parameters + map_entry.range.ranges = [map_entry.range.ranges[new_pos] for new_pos in new_map_order] + map_entry.range.tile_sizes = [map_entry.range.tile_sizes[new_pos] for new_pos in new_map_order] + map_entry.map.params = [map_entry.map.params[new_pos] for new_pos in new_map_order] diff --git a/tests/transformations/map_dim_shuffle_test.py b/tests/transformations/map_dim_shuffle_test.py index e0eb3f4311..1d9c73e5a2 100644 --- a/tests/transformations/map_dim_shuffle_test.py +++ b/tests/transformations/map_dim_shuffle_test.py @@ -36,6 +36,9 @@ def test_map_dim_shuffle(): sdfg(A=A, B=B) assert np.allclose(B, expected) + assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i"]}) == 0 + assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i", "l"]}) == 0 + if __name__ == '__main__': test_map_dim_shuffle()