From 7210cb686aa6fa330d274fe7c09aed415a137852 Mon Sep 17 00:00:00 2001 From: Luigi Fusco Date: Sun, 8 Sep 2024 02:15:28 +0200 Subject: [PATCH] fix missed exploration of edges in constant propagation (#1635) There is a bug related to the missed exploration of interstate edges during constant propagation in case a loop body has a conditional assignment. The reverse DFS yields parent-node pairs and analyzes only the edge connecting the two. The DFS will yield a certain node only once, while the assumption in the code is that the uniqueness is enforced on the parent-node pair. This results in only one outgoing interstate edge per body state being visited, leading to mistakes in the common case of conditional assignments (which result in two outgoing edges performing different assignments). If the visited edge does not perform an assignment or assigns the initialization value, the symbol will be wrongly interpreted as a constant and replaced in downstream states. A short reproducing example is: ```python N = dace.symbol('N', dace.int64) @dace.program def program(in_arr: dace.bool[N], arr: dace.bool[N]): check = False for i in range(N): if in_arr[i]: check = True else: check = False for i in dace.map[0:N]: arr[i] = check sdfg = program.to_sdfg(simplify=True) sdfg.save('bug.sdfg') # "arr[i] = check" will be replaced by "arr[i] = False" ``` The fix makes sure all interstate edges are visited at least once. 
--- .../passes/constant_propagation.py | 122 ++++++------------ tests/passes/constant_propagation_test.py | 35 +++++ 2 files changed, 72 insertions(+), 85 deletions(-) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 7c05b3ea38..9006ae3c10 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -124,7 +124,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = result = {k: v for k, v in result.items() if k not in fsyms} for sym in result: if sym in sdfg.symbols: - # Remove from symbol repository and nested SDFG symbol mapipng + # Remove from symbol repository and nested SDFG symbol mapping sdfg.remove_symbol(sym) result = set(result.keys()) @@ -184,62 +184,44 @@ def _add_nested_datanames(name: str, desc: data.Structure): # Process: # * Collect constants in topologically ordered states - # * If unvisited state has one incoming edge - propagate symbols forward and edge assignments - # * If unvisited state has more than one incoming edge, consider all paths (use reverse DFS on unvisited paths) + # * Propagate symbols and edge assignments forward # * If value is ambiguous (not the same), set value to UNKNOWN + # * Repeat until no update is performed start_state = sdfg.start_state if initial_symbols: result[start_state] = {} result[start_state].update(initial_symbols) - # Traverse SDFG topologically - for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants', - sdfg.number_of_nodes(), self.progress): - # NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary - # when the start-state is a scope's guard instead of a special initialization state, i.e., when the start- - # state has incoming edges that may involve the initial symbols. 
See also: - # `tests.passes.constant_propagation_test.test_for_with_external_init_nested_start_with_guard`` - if state in result and state is not start_state: - continue - - # Get predecessors - in_edges = sdfg.in_edges(state) - if len(in_edges) == 1: # Special case, propagate as-is - if state not in result: # Condition evaluates to False when state is the start-state + redo = True + while redo: + redo = False + # Traverse SDFG topologically + for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants', + sdfg.number_of_nodes(), self.progress): + + # Get predecessors + in_edges = sdfg.in_edges(state) + assignments = {} + for edge in in_edges: + # If source was already visited, use its propagated constants + constants: Dict[str, Any] = {} + if edge.src in result: + constants.update(result[edge.src]) + + # Update constants with incoming edge + self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) + + for aname, aval in constants.items(): + # If something was assigned more than once (to a different value), it's not a constant + if aname in assignments and aval != assignments[aname]: + assignments[aname] = _UnknownValue + else: + assignments[aname] = aval + + if state not in result: # Condition may evaluate to False when state is the start-state result[state] = {} - - # First the prior state - if in_edges[0].src in result: # Condition evaluates to False when state is the start-state - self._propagate(result[state], result[in_edges[0].src]) - - # Then assignments on the incoming edge - self._propagate(result[state], self._data_independent_assignments(in_edges[0].data, arrays)) - continue - - # More than one incoming edge: may require reversed traversal - assignments = {} - for edge in in_edges: - # If source was already visited, use its propagated constants - constants: Dict[str, Any] = {} - if edge.src in result: - constants.update(result[edge.src]) - else: # Otherwise, reverse DFS to find constants until 
a visited state - constants = self._constants_from_unvisited_state(sdfg, edge.src, arrays, result) - - # Update constants with incoming edge - self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) - - for aname, aval in constants.items(): - # If something was assigned more than once (to a different value), it's not a constant - if aname in assignments and aval != assignments[aname]: - assignments[aname] = _UnknownValue - else: - assignments[aname] = aval - - if state not in result: # Condition may evaluate to False when state is the start-state - result[state] = {} - self._propagate(result[state], assignments) + redo |= self._propagate(result[state], assignments) return result @@ -272,22 +254,16 @@ def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[SDFGState, Dict[str, An return symbols_in_data, symbols_in_data_with_multiple_values - def _propagate(self, symbols: Dict[str, Any], new_symbols: Dict[str, Any], backward: bool = False): + def _propagate(self, symbols: Dict[str, Any], new_symbols: Dict[str, Any]) -> bool: """ Updates symbols dictionary in-place with new symbols, propagating existing ones within. :param symbols: The symbols dictionary to update. :param new_symbols: The new symbols to include (and propagate ``symbols`` into). - :param backward: If True, assumes symbol back-propagation (i.e., only update keys in symbols if newer). 
+ :return: True if symbols was modified, False otherwise """ if not new_symbols: - return - # If propagating backwards, ensure symbols are only added if they are not overridden - if backward: - for k, v in new_symbols.items(): - if k not in symbols: - symbols[k] = v - return + return False repl = {k: v for k, v in symbols.items() if v is not _UnknownValue} @@ -314,8 +290,11 @@ def _replace_assignment(v, assignment): k: _replace_assignment(v, {k}) if v is not _UnknownValue else _UnknownValue for k, v in new_symbols.items() } + original_symbols = symbols.copy() symbols.update(propagated_symbols) + return original_symbols != symbols + def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str]) -> Dict[str, Any]: """ Return symbol assignments that only depend on other symbols and constants, rather than data descriptors. @@ -324,30 +303,3 @@ def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str]) k: v if (not (symbolic.free_symbols_and_functions(v) & arrays)) else _UnknownValue for k, v in edge.assignments.items() } - - def _constants_from_unvisited_state(self, sdfg: SDFG, state: SDFGState, arrays: Set[str], - existing_constants: Dict[SDFGState, Dict[str, Any]]) -> Dict[str, Any]: - """ - Collects constants from an unvisited state, traversing backwards until reaching states that do have - collected constants. 
- """ - result: Dict[str, Any] = {} - - for parent, node in sdutil.dfs_conditional(sdfg, - sources=[state], - reverse=True, - condition=lambda p, c: c not in existing_constants, - yield_parent=True): - # Skip first node - if parent is None: - continue - - # Get connecting edge (reversed) - edge = sdfg.edges_between(node, parent)[0] - - # If node already has propagated constants, update dictionary and stop traversal - self._propagate(result, self._data_independent_assignments(edge.data, arrays), True) - if node in existing_constants: - self._propagate(result, existing_constants[node], True) - - return result diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index 5e7d3b0bac..89b7e7ed5c 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -352,6 +352,40 @@ def test_for_with_external_init(): assert np.allclose(val1, ref) +def test_for_with_conditional_assignment(): + N = dace.symbol('N') + + sdfg = dace.SDFG('for_with_conditional_assignment') + sdfg.add_symbol('i', dace.int64) + sdfg.add_symbol('check', dace.bool) + sdfg.add_symbol('__tmp1', dace.bool) + sdfg.add_array('__return', {1,}, dace.bool) + sdfg.add_array('in_arr', {N,}, dace.bool) + + init = sdfg.add_state('init') + guard = sdfg.add_state('guard') + condition = sdfg.add_state('condition') + if_branch = sdfg.add_state('if_branch') + else_branch = sdfg.add_state('else_branch') + out = sdfg.add_state('out') + + sdfg.add_edge(init, guard, dace.InterstateEdge(None, {'i': '0', 'check': 'False'})) + sdfg.add_edge(guard, condition, dace.InterstateEdge('(i < N)', {'__tmp1': 'in_arr[i]'})) + sdfg.add_edge(condition, if_branch, dace.InterstateEdge('__tmp1')) + sdfg.add_edge(if_branch, else_branch, dace.InterstateEdge(None, {'check': 'False'})) + sdfg.add_edge(condition, else_branch, dace.InterstateEdge('(not __tmp1)', {'check': 'True'})) + sdfg.add_edge(else_branch, guard, dace.InterstateEdge(None, {'i': '(i + 1)'})) + 
sdfg.add_edge(guard, out, dace.InterstateEdge('(not (i < N))')) + + a = out.add_write('__return') + t = out.add_tasklet('tasklet', {}, {'__out'}, '__out = check') + out.add_edge(t, '__out', a, None, dace.Memlet('__return[0]')) + sdfg.validate() + + ConstantPropagation().apply_pass(sdfg, {}) + assert t.code.as_string == '__out = check' + + def test_for_with_external_init_nested(): N = dace.symbol('N') @@ -481,6 +515,7 @@ def test_skip_branch(): test_allocation_varying(False) test_allocation_varying(True) test_for_with_external_init() + test_for_with_conditional_assignment() test_for_with_external_init_nested() test_for_with_external_init_nested_start_with_guard() test_skip_branch()