From d0db188db5f9d544c3c857ad0a5b32ad290c01ff Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 26 Mar 2024 12:22:55 +0100 Subject: [PATCH] Fix bug in map_fusion transformation (#1553) Four-lines bugfix and associated test case for map_fusion transformation. Without this change, the test would fail in SDFG validation with error: `dace.sdfg.validation.InvalidSDFGEdgeError: Memlet data does not match source or destination data nodes) (at state state, edge __s0_n1None_n3IN_T[0] (V:None -> numeric:_inp))` --------- Co-authored-by: alexnick83 <31545860+alexnick83@users.noreply.github.com> --- dace/transformation/dataflow/map_fusion.py | 6 ++++ tests/transformations/mapfusion_test.py | 38 ++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 9a0dd0e313..186ea32acc 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -481,6 +481,12 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None) local_node = edge.src src_connector = edge.src_conn + # update edge data in case source or destination is a scalar access node + test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)] + for new_data in test_data: + if isinstance(sdfg.arrays[new_data], data.Scalar): + edge.data.data = new_data + # If destination of edge leads to multiple destinations, redirect all through an access node. if other_edges: # NOTE: If a new local node was already created, reuse it. diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 653fb9d120..724c8c97ee 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -163,6 +163,43 @@ def test_fusion_with_transient(): assert np.allclose(A, expected) +def test_fusion_with_transient_scalar(): + N = 10 + K = 4 + + def build_sdfg(): + sdfg = dace.SDFG("map_fusion_with_transient_scalar") + state = sdfg.add_state() + sdfg.add_array("A", (N,K), dace.float64) + sdfg.add_array("B", (N,), dace.float64) + sdfg.add_array("T", (N,), dace.float64, transient=True) + t_node = state.add_access("T") + sdfg.add_scalar("V", dace.float64, transient=True) + v_node = state.add_access("V") + + me1, mx1 = state.add_map("map1", dict(i=f"0:{N}")) + tlet1 = state.add_tasklet("select", {"_v"}, {"_out"}, f"_out = _v[i, {K-1}]") + state.add_memlet_path(state.add_access("A"), me1, tlet1, dst_conn="_v", memlet=dace.Memlet.from_array("A", sdfg.arrays["A"])) + state.add_edge(tlet1, "_out", v_node, None, dace.Memlet("V[0]")) + state.add_memlet_path(v_node, mx1, t_node, memlet=dace.Memlet("T[i]")) + + me2, mx2 = state.add_map("map2", dict(j=f"0:{N}")) + tlet2 = state.add_tasklet("numeric", {"_inp"}, {"_out"}, f"_out = _inp + 1") + state.add_memlet_path(t_node, me2, tlet2, dst_conn="_inp", memlet=dace.Memlet("T[j]")) + state.add_memlet_path(tlet2, mx2, state.add_access("B"), src_conn="_out", memlet=dace.Memlet("B[j]")) + + return sdfg + + sdfg = build_sdfg() + sdfg.apply_transformations(MapFusion) + + A = np.random.rand(N, K) + B = np.repeat(np.nan, N) + sdfg(A=A, B=B) + + assert np.allclose(B, (A[:, K-1] + 1)) + + def test_fusion_with_inverted_indices(): @dace.program @@ -278,6 +315,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_multiple_fusions() test_fusion_chain() test_fusion_with_transient() + test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0()