diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index 9021a79439..b8d8739a7e 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -5,7 +5,7 @@ from dace.symbolic import pystr_to_symbolic import networkx as nx import sympy as sp -from typing import Dict, Iterator, List, Set +from typing import Dict, Iterator, List, Optional, Set def acyclic_dominance_frontier(sdfg: SDFG, idom=None) -> Dict[SDFGState, Set[SDFGState]]: @@ -67,7 +67,7 @@ def back_edges(sdfg: SDFG, return [e for e in sdfg.edges() if e.dst in alldoms[e.src]] -def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]: +def state_parent_tree(sdfg: SDFG, loopexits: Optional[Dict[SDFGState, SDFGState]] = None) -> Dict[SDFGState, SDFGState]: """ Computes an upward-pointing tree of each state, pointing to the "parent state" it belongs to (in terms of structured control flow). More formally, @@ -81,7 +81,7 @@ def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]: """ idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) alldoms = all_dominators(sdfg, idom) - loopexits: Dict[SDFGState, SDFGState] = defaultdict(lambda: None) + loopexits = loopexits if loopexits is not None else defaultdict(lambda: None) # First, annotate loops for be in back_edges(sdfg, idom, alldoms): @@ -94,10 +94,9 @@ def state_parent_tree(sdfg: SDFG) -> Dict[SDFGState, SDFGState]: in_edges = sdfg.in_edges(guard) out_edges = sdfg.out_edges(guard) - # A loop guard has two or more incoming edges (1 increment and - # n init, all identical), and exactly two outgoing edges (loop and - # exit loop). - if len(in_edges) < 2 or len(out_edges) != 2: + # A loop guard has at least one incoming edges (the backedge, performing the increment), and exactly two + # outgoing edges (loop and exit loop). + if len(in_edges) < 1 or len(out_edges) != 2: continue # The outgoing edges must be negations of one another. @@ -193,7 +192,7 @@ def cond_b(parent, child): # Step up for state in step_up: - if parents[state] is not None: + if parents[state] is not None and parents[parents[state]] is not None: parents[state] = parents[parents[state]] return parents @@ -204,7 +203,8 @@ def _stateorder_topological_sort(sdfg: SDFG, ptree: Dict[SDFGState, SDFGState], branch_merges: Dict[SDFGState, SDFGState], stop: SDFGState = None, - visited: Set[SDFGState] = None) -> Iterator[SDFGState]: + visited: Set[SDFGState] = None, + loopexits: Optional[Dict[SDFGState, SDFGState]] = None) -> Iterator[SDFGState]: """ Helper function for ``stateorder_topological_sort``. @@ -217,6 +217,8 @@ def _stateorder_topological_sort(sdfg: SDFG, :return: Generator that yields states in state-order from ``start`` to ``stop``. """ + loopexits = loopexits if loopexits is not None else defaultdict(lambda: None) + # Traverse states in custom order visited = visited or set() stack = [start] @@ -235,20 +237,21 @@ def _stateorder_topological_sort(sdfg: SDFG, continue elif len(oe) == 2: # Loop or branch # If loop, traverse body, then exit - if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node: - for s in _stateorder_topological_sort(sdfg, oe[0].dst, ptree, branch_merges, stop=node, - visited=visited): - yield s - visited.add(s) - stack.append(oe[1].dst) - continue - elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node: - for s in _stateorder_topological_sort(sdfg, oe[1].dst, ptree, branch_merges, stop=node, - visited=visited): - yield s - visited.add(s) - stack.append(oe[0].dst) - continue + if node in loopexits: + if oe[0].dst == loopexits[node]: + for s in _stateorder_topological_sort(sdfg, oe[1].dst, ptree, branch_merges, stop=node, + visited=visited, loopexits=loopexits): + yield s + visited.add(s) + stack.append(oe[0].dst) + continue + elif oe[1].dst == loopexits[node]: + for s in _stateorder_topological_sort(sdfg, oe[0].dst, ptree, branch_merges, stop=node, + visited=visited, loopexits=loopexits): + yield s + visited.add(s) + stack.append(oe[1].dst) + continue # Otherwise, passthrough to branch # Branch if node in branch_merges: @@ -259,7 +262,7 @@ def _stateorder_topological_sort(sdfg: SDFG, # Otherwise (e.g., with return/break statements), traverse through each branch, # stopping at the end of the current tree level. mergestate = next(e.dst for e in sdfg.out_edges(stop) if ptree[e.dst] != stop) - except StopIteration: + except (StopIteration, KeyError): # If that fails, simply traverse branches in arbitrary order mergestate = stop @@ -272,7 +275,8 @@ def _stateorder_topological_sort(sdfg: SDFG, ptree, branch_merges, stop=mergestate, - visited=visited): + visited=visited, + loopexits=loopexits): yield s visited.add(s) stack.append(mergestate) @@ -288,11 +292,13 @@ def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]: :return: Generator that yields states in state-order. """ # Get parent states - ptree = state_parent_tree(sdfg) + loopexits: Dict[SDFGState, SDFGState] = defaultdict(lambda: None) + ptree = state_parent_tree(sdfg, loopexits) # Annotate branches branch_merges: Dict[SDFGState, SDFGState] = {} adf = acyclic_dominance_frontier(sdfg) + ipostdom = sdutil.postdominators(sdfg) for state in sdfg.nodes(): oedges = sdfg.out_edges(state) # Skip if not branch @@ -311,5 +317,7 @@ def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]: common_frontier |= frontier if len(common_frontier) == 1: branch_merges[state] = next(iter(common_frontier)) + elif len(common_frontier) > 1 and ipostdom[state] in common_frontier: + branch_merges[state] = ipostdom[state] - yield from _stateorder_topological_sort(sdfg, sdfg.start_state, ptree, branch_merges) + yield from _stateorder_topological_sort(sdfg, sdfg.start_state, ptree, branch_merges, loopexits=loopexits) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 53b03c52e0..50aac77ae4 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -3,6 +3,7 @@ import ast from dataclasses import dataclass from dace.frontend.python import astutils +from dace.sdfg.analysis import cfg from dace.sdfg.sdfg import InterstateEdge from dace.sdfg import nodes, utils as sdutil from dace.transformation import pass_pipeline as ppl @@ -192,7 +193,7 @@ def _add_nested_datanames(name: str, desc: data.Structure): result[start_state].update(initial_symbols) # Traverse SDFG topologically - for state in optional_progressbar(sdfg.topological_sort(start_state), 'Collecting constants', + for state in optional_progressbar(cfg.stateorder_topological_sort(sdfg), 'Collecting constants', sdfg.number_of_nodes(), self.progress): # NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary # when the start-state is a scope's guard instead of a special initialization state, i.e., when the start- diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index bff2e1350b..336ac4b428 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -13,7 +13,7 @@ @properties.make_properties class RemoveUnusedSymbols(ppl.Pass): """ - Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``). + Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``) and interstate edges. Also includes uses in Tasklets of all languages. """ @@ -30,7 +30,7 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]: """ - Propagates constants throughout the SDFG. + Removes unused symbols from the SDFG. :param sdfg: The SDFG to modify. :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass @@ -41,13 +41,19 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]: """ result: Set[str] = set() - symbols_to_consider = self.symbols or set(sdfg.symbols.keys()) + repository_symbols_to_consider = self.symbols or set(sdfg.symbols.keys()) # Compute used symbols used_symbols = self.used_symbols(sdfg) - # Remove unused symbols - for sym in symbols_to_consider - used_symbols: + # Remove unused symbols from interstate edge assignments. + for isedge in sdfg.all_interstate_edges(): + edge_symbols_to_consider = set(isedge.data.assignments.keys()) + for sym in edge_symbols_to_consider - used_symbols: + del isedge.data.assignments[sym] + + # Remove unused symbols from the SDFG's symbols repository. + for sym in repository_symbols_to_consider - used_symbols: if sym in sdfg.symbols: sdfg.remove_symbol(sym) result.add(sym) diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index c41f1fc4a6..5e7d3b0bac 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -436,6 +436,37 @@ def test_for_with_external_init_nested_start_with_guard(): assert np.allclose(val1, ref) +def test_skip_branch(): + sdfg = dace.SDFG('skip_branch') + sdfg.add_symbol('k', dace.int32) + sdfg.add_array('__return', (1,), dace.int32) + init = sdfg.add_state('init') + if_guard = sdfg.add_state('if_guard') + if_state = sdfg.add_state('if_state') + if_end = sdfg.add_state('if_end') + sdfg.add_edge(init, if_guard, dace.InterstateEdge(assignments=dict(j=0))) + sdfg.add_edge(if_guard, if_end, dace.InterstateEdge('k<0')) + sdfg.add_edge(if_guard, if_state, dace.InterstateEdge('not (k<0)', assignments=dict(j=1))) + sdfg.add_edge(if_state, if_end, dace.InterstateEdge()) + ret_a = if_end.add_access('__return') + tasklet = if_end.add_tasklet('c1', {}, {'o1'}, 'o1 = j') + if_end.add_edge(tasklet, 'o1', ret_a, None, dace.Memlet('__return[0]')) + + sdfg.validate() + + rval_1 = sdfg(k=-1) + assert (rval_1[0] == 0) + rval_2 = sdfg(k=1) + assert (rval_2[0] == 1) + + ConstantPropagation().apply_pass(sdfg, {}) + + rval_1 = sdfg(k=-1) + assert (rval_1[0] == 0) + rval_2 = sdfg(k=1) + assert (rval_2[0] == 1) + + if __name__ == '__main__': test_simple_constants() test_nested_constants() @@ -452,3 +483,4 @@ def test_for_with_external_init_nested_start_with_guard(): test_for_with_external_init() test_for_with_external_init_nested() test_for_with_external_init_nested_start_with_guard() + test_skip_branch()