Skip to content

Commit

Permalink
Fix bug in map_fusion transformation (#1553)
Browse files Browse the repository at this point in the history
Four-lines bugfix and associated test case for map_fusion
transformation.

Without this change, the test would fail in SDFG validation with error:
`dace.sdfg.validation.InvalidSDFGEdgeError: Memlet data does not match
source or destination data nodes) (at state state, edge
__s0_n1None_n3IN_T[0] (V:None -> numeric:_inp))`

---------

Co-authored-by: alexnick83 <[email protected]>
  • Loading branch information
edopao and alexnick83 authored Mar 26, 2024
1 parent da0cde2 commit d0db188
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
6 changes: 6 additions & 0 deletions dace/transformation/dataflow/map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,12 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None)
local_node = edge.src
src_connector = edge.src_conn

# update edge data in case source or destination is a scalar access node
test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)]
for new_data in test_data:
if isinstance(sdfg.arrays[new_data], data.Scalar):
edge.data.data = new_data

# If destination of edge leads to multiple destinations, redirect all through an access node.
if other_edges:
# NOTE: If a new local node was already created, reuse it.
Expand Down
38 changes: 38 additions & 0 deletions tests/transformations/mapfusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,43 @@ def test_fusion_with_transient():
assert np.allclose(A, expected)


def test_fusion_with_transient_scalar():
N = 10
K = 4

def build_sdfg():
sdfg = dace.SDFG("map_fusion_with_transient_scalar")
state = sdfg.add_state()
sdfg.add_array("A", (N,K), dace.float64)
sdfg.add_array("B", (N,), dace.float64)
sdfg.add_array("T", (N,), dace.float64, transient=True)
t_node = state.add_access("T")
sdfg.add_scalar("V", dace.float64, transient=True)
v_node = state.add_access("V")

me1, mx1 = state.add_map("map1", dict(i=f"0:{N}"))
tlet1 = state.add_tasklet("select", {"_v"}, {"_out"}, f"_out = _v[i, {K-1}]")
state.add_memlet_path(state.add_access("A"), me1, tlet1, dst_conn="_v", memlet=dace.Memlet.from_array("A", sdfg.arrays["A"]))
state.add_edge(tlet1, "_out", v_node, None, dace.Memlet("V[0]"))
state.add_memlet_path(v_node, mx1, t_node, memlet=dace.Memlet("T[i]"))

me2, mx2 = state.add_map("map2", dict(j=f"0:{N}"))
tlet2 = state.add_tasklet("numeric", {"_inp"}, {"_out"}, f"_out = _inp + 1")
state.add_memlet_path(t_node, me2, tlet2, dst_conn="_inp", memlet=dace.Memlet("T[j]"))
state.add_memlet_path(tlet2, mx2, state.add_access("B"), src_conn="_out", memlet=dace.Memlet("B[j]"))

return sdfg

sdfg = build_sdfg()
sdfg.apply_transformations(MapFusion)

A = np.random.rand(N, K)
B = np.repeat(np.nan, N)
sdfg(A=A, B=B)

assert np.allclose(B, (A[:, K-1] + 1))


def test_fusion_with_inverted_indices():

@dace.program
Expand Down Expand Up @@ -278,6 +315,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3
test_multiple_fusions()
test_fusion_chain()
test_fusion_with_transient()
test_fusion_with_transient_scalar()
test_fusion_with_inverted_indices()
test_fusion_with_empty_memlet()
test_fusion_with_nested_sdfg_0()
Expand Down

0 comments on commit d0db188

Please sign in to comment.