diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 7c05b3ea38..9006ae3c10 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -124,7 +124,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = result = {k: v for k, v in result.items() if k not in fsyms} for sym in result: if sym in sdfg.symbols: - # Remove from symbol repository and nested SDFG symbol mapipng + # Remove from symbol repository and nested SDFG symbol mapping sdfg.remove_symbol(sym) result = set(result.keys()) @@ -184,62 +184,44 @@ def _add_nested_datanames(name: str, desc: data.Structure): # Process: # * Collect constants in topologically ordered states - # * If unvisited state has one incoming edge - propagate symbols forward and edge assignments - # * If unvisited state has more than one incoming edge, consider all paths (use reverse DFS on unvisited paths) + # * Propagate forward symbols forward and edge assignments # * If value is ambiguous (not the same), set value to UNKNOWN + # * Repeat until no update is performed start_state = sdfg.start_state if initial_symbols: result[start_state] = {} result[start_state].update(initial_symbols) - # Traverse SDFG topologically - for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants', - sdfg.number_of_nodes(), self.progress): - # NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary - # when the start-state is a scope's guard instead of a special initialization state, i.e., when the start- - # state has incoming edges that may involve the initial symbols. See also: - # `tests.passes.constant_propagation_test.test_for_with_external_init_nested_start_with_guard`` - if state in result and state is not start_state: - continue - - # Get predecessors - in_edges = sdfg.in_edges(state) - if len(in_edges) == 1: # Special case, propagate as-is - if state not in result: # Condition evaluates to False when state is the start-state + redo = True + while redo: + redo = False + # Traverse SDFG topologically + for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants', + sdfg.number_of_nodes(), self.progress): + + # Get predecessors + in_edges = sdfg.in_edges(state) + assignments = {} + for edge in in_edges: + # If source was already visited, use its propagated constants + constants: Dict[str, Any] = {} + if edge.src in result: + constants.update(result[edge.src]) + + # Update constants with incoming edge + self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) + + for aname, aval in constants.items(): + # If something was assigned more than once (to a different value), it's not a constant + if aname in assignments and aval != assignments[aname]: + assignments[aname] = _UnknownValue + else: + assignments[aname] = aval + + if state not in result: # Condition may evaluate to False when state is the start-state result[state] = {} - - # First the prior state - if in_edges[0].src in result: # Condition evaluates to False when state is the start-state - self._propagate(result[state], result[in_edges[0].src]) - - # Then assignments on the incoming edge - self._propagate(result[state], self._data_independent_assignments(in_edges[0].data, arrays)) - continue - - # More than one incoming edge: may require reversed traversal - assignments = {} - for edge in in_edges: - # If source was already visited, use its propagated constants - constants: Dict[str, Any] = {} - if edge.src in result: - constants.update(result[edge.src]) - else: # Otherwise, reverse DFS to find constants until a visited state - constants = self._constants_from_unvisited_state(sdfg, edge.src, arrays, result) - - # Update constants with incoming edge - self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) - - for aname, aval in constants.items(): - # If something was assigned more than once (to a different value), it's not a constant - if aname in assignments and aval != assignments[aname]: - assignments[aname] = _UnknownValue - else: - assignments[aname] = aval - - if state not in result: # Condition may evaluate to False when state is the start-state - result[state] = {} - self._propagate(result[state], assignments) + redo |= self._propagate(result[state], assignments) return result @@ -272,22 +254,16 @@ def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[SDFGState, Dict[str, An return symbols_in_data, symbols_in_data_with_multiple_values - def _propagate(self, symbols: Dict[str, Any], new_symbols: Dict[str, Any], backward: bool = False): + def _propagate(self, symbols: Dict[str, Any], new_symbols: Dict[str, Any]) -> bool: """ Updates symbols dictionary in-place with new symbols, propagating existing ones within. :param symbols: The symbols dictionary to update. :param new_symbols: The new symbols to include (and propagate ``symbols`` into). - :param backward: If True, assumes symbol back-propagation (i.e., only update keys in symbols if newer). + :return: True if symbols was modified, False otherwise """ if not new_symbols: - return - # If propagating backwards, ensure symbols are only added if they are not overridden - if backward: - for k, v in new_symbols.items(): - if k not in symbols: - symbols[k] = v - return + return False repl = {k: v for k, v in symbols.items() if v is not _UnknownValue} @@ -314,8 +290,11 @@ def _replace_assignment(v, assignment): k: _replace_assignment(v, {k}) if v is not _UnknownValue else _UnknownValue for k, v in new_symbols.items() } + original_symbols = symbols.copy() symbols.update(propagated_symbols) + return original_symbols != symbols + def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str]) -> Dict[str, Any]: """ Return symbol assignments that only depend on other symbols and constants, rather than data descriptors. @@ -324,30 +303,3 @@ def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str]) k: v if (not (symbolic.free_symbols_and_functions(v) & arrays)) else _UnknownValue for k, v in edge.assignments.items() } - - def _constants_from_unvisited_state(self, sdfg: SDFG, state: SDFGState, arrays: Set[str], - existing_constants: Dict[SDFGState, Dict[str, Any]]) -> Dict[str, Any]: - """ - Collects constants from an unvisited state, traversing backwards until reaching states that do have - collected constants. - """ - result: Dict[str, Any] = {} - - for parent, node in sdutil.dfs_conditional(sdfg, - sources=[state], - reverse=True, - condition=lambda p, c: c not in existing_constants, - yield_parent=True): - # Skip first node - if parent is None: - continue - - # Get connecting edge (reversed) - edge = sdfg.edges_between(node, parent)[0] - - # If node already has propagated constants, update dictionary and stop traversal - self._propagate(result, self._data_independent_assignments(edge.data, arrays), True) - if node in existing_constants: - self._propagate(result, existing_constants[node], True) - - return result diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index 5e7d3b0bac..89b7e7ed5c 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -352,6 +352,40 @@ def test_for_with_external_init(): assert np.allclose(val1, ref) +def test_for_with_conditional_assignment(): + N = dace.symbol('N') + + sdfg = dace.SDFG('for_with_conditional_assignment') + sdfg.add_symbol('i', dace.int64) + sdfg.add_symbol('check', dace.bool) + sdfg.add_symbol('__tmp1', dace.bool) + sdfg.add_array('__return', {1,}, dace.bool) + sdfg.add_array('in_arr', {N,}, dace.bool) + + init = sdfg.add_state('init') + guard = sdfg.add_state('guard') + condition = sdfg.add_state('condition') + if_branch = sdfg.add_state('if_branch') + else_branch = sdfg.add_state('else_branch') + out = sdfg.add_state('out') + + sdfg.add_edge(init, guard, dace.InterstateEdge(None, {'i': '0', 'check': 'False'})) + sdfg.add_edge(guard, condition, dace.InterstateEdge('(i < N)', {'__tmp1': 'in_arr[i]'})) + sdfg.add_edge(condition, if_branch, dace.InterstateEdge('__tmp1')) + sdfg.add_edge(if_branch, else_branch, dace.InterstateEdge(None, {'check': 'False'})) + sdfg.add_edge(condition, else_branch, dace.InterstateEdge('(not __tmp1)', {'check': 'True'})) + sdfg.add_edge(else_branch, guard, dace.InterstateEdge(None, {'i': '(i + 1)'})) + sdfg.add_edge(guard, out, dace.InterstateEdge('(not (i < N))')) + + a = out.add_write('__return') + t = out.add_tasklet('tasklet', {}, {'__out'}, '__out = check') + out.add_edge(t, '__out', a, None, dace.Memlet('__return[0]')) + sdfg.validate() + + ConstantPropagation().apply_pass(sdfg, {}) + assert t.code.as_string == '__out = check' + + def test_for_with_external_init_nested(): N = dace.symbol('N') @@ -481,6 +515,7 @@ def test_skip_branch(): test_allocation_varying(False) test_allocation_varying(True) test_for_with_external_init() + test_for_with_conditional_assignment() test_for_with_external_init_nested() test_for_with_external_init_nested_start_with_guard() test_skip_branch()