Skip to content

Commit

Permalink
Fixed edge consolidation (#1546)
Browse files Browse the repository at this point in the history
The main issue was a typo were `src` was accessed instead of `dst`.
  • Loading branch information
philip-paul-mueller authored Mar 11, 2024
1 parent ff6e064 commit 5f9233e
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,11 +585,14 @@ def consolidate_edges_scope(state: SDFGState, scope_node: Union[nd.EntryNode, nd
conn_to_remove = prefix + conn[offset:]
remove_outer_connector(conn_to_remove)
if isinstance(scope_node, nd.EntryNode):
out_edge = next(ed for ed in outer_edges(scope_node) if ed.dst_conn == target_conn)
edge_to_remove = next(ed for ed in outer_edges(scope_node) if ed.dst_conn == conn_to_remove)
out_edges = [ed for ed in outer_edges(scope_node) if ed.dst_conn == target_conn]
edges_to_remove = [ed for ed in outer_edges(scope_node) if ed.dst_conn == conn_to_remove]
else:
out_edge = next(ed for ed in outer_edges(scope_node) if ed.src_conn == target_conn)
edge_to_remove = next(ed for ed in outer_edges(scope_node) if ed.src_conn == conn_to_remove)
out_edges = [ed for ed in outer_edges(scope_node) if ed.src_conn == target_conn]
edges_to_remove = [ed for ed in outer_edges(scope_node) if ed.src_conn == conn_to_remove]
assert len(edges_to_remove) == 1 and len(out_edges) == 1
edge_to_remove = edges_to_remove[0]
out_edge = out_edges[0]
out_edge.data.subset = sbs.union(out_edge.data.subset, edge_to_remove.data.subset)

# Check if dangling connectors have been created and remove them,
Expand Down Expand Up @@ -627,9 +630,9 @@ def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge):
e = curedge.edge
state.remove_edge(e)
if inwards:
neighbors = [neighbor for neighbor in state.out_edges(e.src) if e.src_conn == neighbor.src_conn]
neighbors = [] if not e.src_conn else [neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn)]
else:
neighbors = [neighbor for neighbor in state.in_edges(e.dst) if e.dst_conn == neighbor.dst_conn]
neighbors = [] if not e.dst_conn else [neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn)]
if len(neighbors) > 0: # There are still edges connected, leave as-is
break

Expand All @@ -641,7 +644,7 @@ def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge):
else:
if e.dst_conn:
e.dst.remove_in_connector(e.dst_conn)
e.src.remove_out_connector('OUT' + e.dst_conn[2:])
e.dst.remove_out_connector('OUT' + e.dst_conn[2:])

# Continue traversing upwards
curedge = curedge.parent
Expand Down

0 comments on commit 5f9233e

Please sign in to comment.