Skip to content

Commit

Permalink
Fix constant propagation failing due to invalid topological sort (#1589)
Browse files Browse the repository at this point in the history
Constant propagation fails for certain graph structures due to an issue
with `dace.sdfg.graph.Graph.topological_sort`. This is related to #1560.
  • Loading branch information
phschaad authored Jun 14, 2024
1 parent d6f481a commit 59120ae
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 33 deletions.
62 changes: 35 additions & 27 deletions dace/sdfg/analysis/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dace.symbolic import pystr_to_symbolic
import networkx as nx
import sympy as sp
from typing import Dict, Iterator, List, Set
from typing import Dict, Iterator, List, Optional, Set


def acyclic_dominance_frontier(sdfg: SDFG, idom=None) -> Dict[SDFGState, Set[SDFGState]]:
Expand Down Expand Up @@ -67,7 +67,7 @@ def back_edges(sdfg: SDFG,
return [e for e in sdfg.edges() if e.dst in alldoms[e.src]]


def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]:
def state_parent_tree(sdfg: SDFG, loopexits: Optional[Dict[SDFGState, SDFGState]] = None) -> Dict[SDFGState, SDFGState]:
"""
Computes an upward-pointing tree of each state, pointing to the "parent
state" it belongs to (in terms of structured control flow). More formally,
Expand All @@ -81,7 +81,7 @@ def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]:
"""
idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
alldoms = all_dominators(sdfg, idom)
loopexits: Dict[SDFGState, SDFGState] = defaultdict(lambda: None)
loopexits = loopexits if loopexits is not None else defaultdict(lambda: None)

# First, annotate loops
for be in back_edges(sdfg, idom, alldoms):
Expand All @@ -94,10 +94,9 @@ def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]:
in_edges = sdfg.in_edges(guard)
out_edges = sdfg.out_edges(guard)

# A loop guard has two or more incoming edges (1 increment and
# n init, all identical), and exactly two outgoing edges (loop and
# exit loop).
if len(in_edges) < 2 or len(out_edges) != 2:
# A loop guard has at least one incoming edges (the backedge, performing the increment), and exactly two
# outgoing edges (loop and exit loop).
if len(in_edges) < 1 or len(out_edges) != 2:
continue

# The outgoing edges must be negations of one another.
Expand Down Expand Up @@ -193,7 +192,7 @@ def cond_b(parent, child):

# Step up
for state in step_up:
if parents[state] is not None:
if parents[state] is not None and parents[parents[state]] is not None:
parents[state] = parents[parents[state]]

return parents
Expand All @@ -204,7 +203,8 @@ def _stateorder_topological_sort(sdfg: SDFG,
ptree: Dict[SDFGState, SDFGState],
branch_merges: Dict[SDFGState, SDFGState],
stop: SDFGState = None,
visited: Set[SDFGState] = None) -> Iterator[SDFGState]:
visited: Set[SDFGState] = None,
loopexits: Optional[Dict[SDFGState, SDFGState]] = None) -> Iterator[SDFGState]:
"""
Helper function for ``stateorder_topological_sort``.
Expand All @@ -217,6 +217,8 @@ def _stateorder_topological_sort(sdfg: SDFG,
:return: Generator that yields states in state-order from ``start`` to
``stop``.
"""
loopexits = loopexits if loopexits is not None else defaultdict(lambda: None)

# Traverse states in custom order
visited = visited or set()
stack = [start]
Expand All @@ -235,20 +237,21 @@ def _stateorder_topological_sort(sdfg: SDFG,
continue
elif len(oe) == 2: # Loop or branch
# If loop, traverse body, then exit
if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node:
for s in _stateorder_topological_sort(sdfg, oe[0].dst, ptree, branch_merges, stop=node,
visited=visited):
yield s
visited.add(s)
stack.append(oe[1].dst)
continue
elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node:
for s in _stateorder_topological_sort(sdfg, oe[1].dst, ptree, branch_merges, stop=node,
visited=visited):
yield s
visited.add(s)
stack.append(oe[0].dst)
continue
if node in loopexits:
if oe[0].dst == loopexits[node]:
for s in _stateorder_topological_sort(sdfg, oe[1].dst, ptree, branch_merges, stop=node,
visited=visited, loopexits=loopexits):
yield s
visited.add(s)
stack.append(oe[0].dst)
continue
elif oe[1].dst == loopexits[node]:
for s in _stateorder_topological_sort(sdfg, oe[0].dst, ptree, branch_merges, stop=node,
visited=visited, loopexits=loopexits):
yield s
visited.add(s)
stack.append(oe[1].dst)
continue
# Otherwise, passthrough to branch
# Branch
if node in branch_merges:
Expand All @@ -259,7 +262,7 @@ def _stateorder_topological_sort(sdfg: SDFG,
# Otherwise (e.g., with return/break statements), traverse through each branch,
# stopping at the end of the current tree level.
mergestate = next(e.dst for e in sdfg.out_edges(stop) if ptree[e.dst] != stop)
except StopIteration:
except (StopIteration, KeyError):
# If that fails, simply traverse branches in arbitrary order
mergestate = stop

Expand All @@ -272,7 +275,8 @@ def _stateorder_topological_sort(sdfg: SDFG,
ptree,
branch_merges,
stop=mergestate,
visited=visited):
visited=visited,
loopexits=loopexits):
yield s
visited.add(s)
stack.append(mergestate)
Expand All @@ -288,11 +292,13 @@ def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]:
:return: Generator that yields states in state-order.
"""
# Get parent states
ptree = state_parent_tree(sdfg)
loopexits: Dict[SDFGState, SDFGState] = defaultdict(lambda: None)
ptree = state_parent_tree(sdfg, loopexits)

# Annotate branches
branch_merges: Dict[SDFGState, SDFGState] = {}
adf = acyclic_dominance_frontier(sdfg)
ipostdom = sdutil.postdominators(sdfg)
for state in sdfg.nodes():
oedges = sdfg.out_edges(state)
# Skip if not branch
Expand All @@ -311,5 +317,7 @@ def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]:
common_frontier |= frontier
if len(common_frontier) == 1:
branch_merges[state] = next(iter(common_frontier))
elif len(common_frontier) > 1 and ipostdom[state] in common_frontier:
branch_merges[state] = ipostdom[state]

yield from _stateorder_topological_sort(sdfg, sdfg.start_state, ptree, branch_merges)
yield from _stateorder_topological_sort(sdfg, sdfg.start_state, ptree, branch_merges, loopexits=loopexits)
3 changes: 2 additions & 1 deletion dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
from dataclasses import dataclass
from dace.frontend.python import astutils
from dace.sdfg.analysis import cfg
from dace.sdfg.sdfg import InterstateEdge
from dace.sdfg import nodes, utils as sdutil
from dace.transformation import pass_pipeline as ppl
Expand Down Expand Up @@ -192,7 +193,7 @@ def _add_nested_datanames(name: str, desc: data.Structure):
result[start_state].update(initial_symbols)

# Traverse SDFG topologically
for state in optional_progressbar(sdfg.topological_sort(start_state), 'Collecting constants',
for state in optional_progressbar(cfg.stateorder_topological_sort(sdfg), 'Collecting constants',
sdfg.number_of_nodes(), self.progress):
# NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary
# when the start-state is a scope's guard instead of a special initialization state, i.e., when the start-
Expand Down
16 changes: 11 additions & 5 deletions dace/transformation/passes/prune_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@properties.make_properties
class RemoveUnusedSymbols(ppl.Pass):
"""
Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``).
Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``) and interstate edges.
Also includes uses in Tasklets of all languages.
"""

Expand All @@ -30,7 +30,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:

def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]:
"""
Propagates constants throughout the SDFG.
Removes unused symbols from the SDFG.
:param sdfg: The SDFG to modify.
:param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
Expand All @@ -41,13 +41,19 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]:
"""
result: Set[str] = set()

symbols_to_consider = self.symbols or set(sdfg.symbols.keys())
repository_symbols_to_consider = self.symbols or set(sdfg.symbols.keys())

# Compute used symbols
used_symbols = self.used_symbols(sdfg)

# Remove unused symbols
for sym in symbols_to_consider - used_symbols:
# Remove unused symbols from interstate edge assignments.
for isedge in sdfg.all_interstate_edges():
edge_symbols_to_consider = set(isedge.data.assignments.keys())
for sym in edge_symbols_to_consider - used_symbols:
del isedge.data.assignments[sym]

# Remove unused symbols from the SDFG's symbols repository.
for sym in repository_symbols_to_consider - used_symbols:
if sym in sdfg.symbols:
sdfg.remove_symbol(sym)
result.add(sym)
Expand Down
32 changes: 32 additions & 0 deletions tests/passes/constant_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,37 @@ def test_for_with_external_init_nested_start_with_guard():
assert np.allclose(val1, ref)


def test_skip_branch():
sdfg = dace.SDFG('skip_branch')
sdfg.add_symbol('k', dace.int32)
sdfg.add_array('__return', (1,), dace.int32)
init = sdfg.add_state('init')
if_guard = sdfg.add_state('if_guard')
if_state = sdfg.add_state('if_state')
if_end = sdfg.add_state('if_end')
sdfg.add_edge(init, if_guard, dace.InterstateEdge(assignments=dict(j=0)))
sdfg.add_edge(if_guard, if_end, dace.InterstateEdge('k<0'))
sdfg.add_edge(if_guard, if_state, dace.InterstateEdge('not (k<0)', assignments=dict(j=1)))
sdfg.add_edge(if_state, if_end, dace.InterstateEdge())
ret_a = if_end.add_access('__return')
tasklet = if_end.add_tasklet('c1', {}, {'o1'}, 'o1 = j')
if_end.add_edge(tasklet, 'o1', ret_a, None, dace.Memlet('__return[0]'))

sdfg.validate()

rval_1 = sdfg(k=-1)
assert (rval_1[0] == 0)
rval_2 = sdfg(k=1)
assert (rval_2[0] == 1)

ConstantPropagation().apply_pass(sdfg, {})

rval_1 = sdfg(k=-1)
assert (rval_1[0] == 0)
rval_2 = sdfg(k=1)
assert (rval_2[0] == 1)


if __name__ == '__main__':
test_simple_constants()
test_nested_constants()
Expand All @@ -452,3 +483,4 @@ def test_for_with_external_init_nested_start_with_guard():
test_for_with_external_init()
test_for_with_external_init_nested()
test_for_with_external_init_nested_start_with_guard()
test_skip_branch()

0 comments on commit 59120ae

Please sign in to comment.