Skip to content

Commit

Permalink
Fixed MapExpansion Transformation (#1743)
Browse files Browse the repository at this point in the history
The transformation did not consider maps that are connected with
dependency edges.
Now, the transformation skips dependency edges if non-dependency edges
exist.
Otherwise, it connects the inner and outer maps with a single dependency
edge to maintain connectivity.
  • Loading branch information
Berke-Ates authored Dec 17, 2024
1 parent 859ab2f commit a517699
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
7 changes: 6 additions & 1 deletion dace/transformation/dataflow/map_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,15 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data))
graph.remove_edge(edge)

if graph.in_degree(map_entry) == 0:
if graph.in_degree(map_entry) == 0 or all(
e.dst_conn is None or not e.dst_conn.startswith("IN_")
for e in graph.in_edges(map_entry)
):
graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet())
else:
for edge in graph.in_edges(map_entry):
if edge.dst_conn is None:
continue
if not edge.dst_conn.startswith("IN_"):
continue

Expand Down
33 changes: 33 additions & 0 deletions tests/transformations/map_expansion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,41 @@ def mymap(i: _[0:20], j: _[0:30], k: _[0:5]):
assert len(map_entries) == 2


def test_expand_with_dependency_edges():

@dace.program
def expansion(A: dace.float32[2], B: dace.float32[2, 2, 2]):
for i in dace.map[0:2]:
A[i] = i

for j, k in dace.map[0:2, 0:2]:
B[i, j, k] = i * j + k

sdfg = expansion.to_sdfg()
sdfg.simplify()
sdfg.validate()

# If dependency edges are handled correctly, this should not raise an exception
try:
num_app = sdfg.apply_transformations_repeated(MapExpansion)
except Exception as e:
assert False, f"MapExpansion failed: {str(e)}"
assert num_app == 1
sdfg.validate()

A = np.random.rand(2).astype(np.float32)
B = np.random.rand(2, 2, 2).astype(np.float32)
sdfg(A=A, B=B)

A_expected = np.array([0, 1], dtype=np.float32)
B_expected = np.array([[[0, 1], [0, 1]], [[0, 1], [1, 2]]], dtype=np.float32)
assert np.all(A == A_expected)
assert np.all(B == B_expected)


if __name__ == '__main__':
test_expand_with_inputs()
test_expand_without_inputs()
test_expand_without_dynamic_inputs()
test_expand_with_limits()
test_expand_with_dependency_edges()

0 comments on commit a517699

Please sign in to comment.