Skip to content

Commit

Permalink
fix missed exploration of edges in constant propagation (#1635)
Browse files Browse the repository at this point in the history
There is a bug related to the missed exploration of interstate edges
during constant propagation in case a loop body has a conditional
assignment. The reverse DFS yields parent-node pairs and analyzes only
the edge connecting the two. The DFS will yield a certain node only
once, while the assumption in the code is that the uniqueness is
enforced on the parent-node pair. This results in only one outgoing
interstate edge per body state being visited, leading to mistakes in the
common case of conditional assignments (which result in two outgoing
edges performing different assignments). If the visited edge does not
perform an assignment or assigns the initialization value, the symbol
will be wrongly interpreted as a constant and replaced in downstream
states.

A short reproducing example is:
```python
N = dace.symbol('N', dace.int64)

@dace.program
def program(in_arr: dace.bool[N], arr: dace.bool[N]):
    check = False
    for i in range(N):
        if in_arr[i]:
            check = True
        else:
            check = False
    for i in dace.map[0:N]:
        arr[i] = check

sdfg = program.to_sdfg(simplify=True)
sdfg.save('bug.sdfg')

# "arr[i] = check" will be replaced by "arr[i] = False"
```

The fix makes sure all interstate edges are visited at least once.
  • Loading branch information
luigifusco authored Sep 8, 2024
1 parent 0a2c55a commit 7210cb6
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 85 deletions.
122 changes: 37 additions & 85 deletions dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] =
result = {k: v for k, v in result.items() if k not in fsyms}
for sym in result:
if sym in sdfg.symbols:
# Remove from symbol repository and nested SDFG symbol mapipng
# Remove from symbol repository and nested SDFG symbol mapping
sdfg.remove_symbol(sym)

result = set(result.keys())
Expand Down Expand Up @@ -184,62 +184,44 @@ def _add_nested_datanames(name: str, desc: data.Structure):

# Process:
# * Collect constants in topologically ordered states
# * If unvisited state has one incoming edge - propagate symbols forward and edge assignments
# * If unvisited state has more than one incoming edge, consider all paths (use reverse DFS on unvisited paths)
# * Propagate forward symbols forward and edge assignments
# * If value is ambiguous (not the same), set value to UNKNOWN
# * Repeat until no update is performed

start_state = sdfg.start_state
if initial_symbols:
result[start_state] = {}
result[start_state].update(initial_symbols)

# Traverse SDFG topologically
for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants',
sdfg.number_of_nodes(), self.progress):
# NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary
# when the start-state is a scope's guard instead of a special initialization state, i.e., when the start-
# state has incoming edges that may involve the initial symbols. See also:
# `tests.passes.constant_propagation_test.test_for_with_external_init_nested_start_with_guard``
if state in result and state is not start_state:
continue

# Get predecessors
in_edges = sdfg.in_edges(state)
if len(in_edges) == 1: # Special case, propagate as-is
if state not in result: # Condition evaluates to False when state is the start-state
redo = True
while redo:
redo = False
# Traverse SDFG topologically
for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants',
sdfg.number_of_nodes(), self.progress):

# Get predecessors
in_edges = sdfg.in_edges(state)
assignments = {}
for edge in in_edges:
# If source was already visited, use its propagated constants
constants: Dict[str, Any] = {}
if edge.src in result:
constants.update(result[edge.src])

# Update constants with incoming edge
self._propagate(constants, self._data_independent_assignments(edge.data, arrays))

for aname, aval in constants.items():
# If something was assigned more than once (to a different value), it's not a constant
if aname in assignments and aval != assignments[aname]:
assignments[aname] = _UnknownValue
else:
assignments[aname] = aval

if state not in result: # Condition may evaluate to False when state is the start-state
result[state] = {}

# First the prior state
if in_edges[0].src in result: # Condition evaluates to False when state is the start-state
self._propagate(result[state], result[in_edges[0].src])

# Then assignments on the incoming edge
self._propagate(result[state], self._data_independent_assignments(in_edges[0].data, arrays))
continue

# More than one incoming edge: may require reversed traversal
assignments = {}
for edge in in_edges:
# If source was already visited, use its propagated constants
constants: Dict[str, Any] = {}
if edge.src in result:
constants.update(result[edge.src])
else: # Otherwise, reverse DFS to find constants until a visited state
constants = self._constants_from_unvisited_state(sdfg, edge.src, arrays, result)

# Update constants with incoming edge
self._propagate(constants, self._data_independent_assignments(edge.data, arrays))

for aname, aval in constants.items():
# If something was assigned more than once (to a different value), it's not a constant
if aname in assignments and aval != assignments[aname]:
assignments[aname] = _UnknownValue
else:
assignments[aname] = aval

if state not in result: # Condition may evaluate to False when state is the start-state
result[state] = {}
self._propagate(result[state], assignments)
redo |= self._propagate(result[state], assignments)

return result

Expand Down Expand Up @@ -272,22 +254,16 @@ def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[SDFGState, Dict[str, An

return symbols_in_data, symbols_in_data_with_multiple_values

def _propagate(self, symbols: Dict[str, Any], new_symbols: Dict[str, Any], backward: bool = False):
def _propagate(self, symbols: Dict[str, Any], new_symbols: Dict[str, Any]) -> bool:
"""
Updates symbols dictionary in-place with new symbols, propagating existing ones within.
:param symbols: The symbols dictionary to update.
:param new_symbols: The new symbols to include (and propagate ``symbols`` into).
:param backward: If True, assumes symbol back-propagation (i.e., only update keys in symbols if newer).
:return: True if symbols was modified, False otherwise
"""
if not new_symbols:
return
# If propagating backwards, ensure symbols are only added if they are not overridden
if backward:
for k, v in new_symbols.items():
if k not in symbols:
symbols[k] = v
return
return False

repl = {k: v for k, v in symbols.items() if v is not _UnknownValue}

Expand All @@ -314,8 +290,11 @@ def _replace_assignment(v, assignment):
k: _replace_assignment(v, {k}) if v is not _UnknownValue else _UnknownValue
for k, v in new_symbols.items()
}
original_symbols = symbols.copy()
symbols.update(propagated_symbols)

return original_symbols != symbols

def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str]) -> Dict[str, Any]:
"""
Return symbol assignments that only depend on other symbols and constants, rather than data descriptors.
Expand All @@ -324,30 +303,3 @@ def _data_independent_assignments(self, edge: InterstateEdge, arrays: Set[str])
k: v if (not (symbolic.free_symbols_and_functions(v) & arrays)) else _UnknownValue
for k, v in edge.assignments.items()
}

def _constants_from_unvisited_state(self, sdfg: SDFG, state: SDFGState, arrays: Set[str],
existing_constants: Dict[SDFGState, Dict[str, Any]]) -> Dict[str, Any]:
"""
Collects constants from an unvisited state, traversing backwards until reaching states that do have
collected constants.
"""
result: Dict[str, Any] = {}

for parent, node in sdutil.dfs_conditional(sdfg,
sources=[state],
reverse=True,
condition=lambda p, c: c not in existing_constants,
yield_parent=True):
# Skip first node
if parent is None:
continue

# Get connecting edge (reversed)
edge = sdfg.edges_between(node, parent)[0]

# If node already has propagated constants, update dictionary and stop traversal
self._propagate(result, self._data_independent_assignments(edge.data, arrays), True)
if node in existing_constants:
self._propagate(result, existing_constants[node], True)

return result
35 changes: 35 additions & 0 deletions tests/passes/constant_propagation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,40 @@ def test_for_with_external_init():
assert np.allclose(val1, ref)


def test_for_with_conditional_assignment():
N = dace.symbol('N')

sdfg = dace.SDFG('for_with_conditional_assignment')
sdfg.add_symbol('i', dace.int64)
sdfg.add_symbol('check', dace.bool)
sdfg.add_symbol('__tmp1', dace.bool)
sdfg.add_array('__return', {1,}, dace.bool)
sdfg.add_array('in_arr', {N,}, dace.bool)

init = sdfg.add_state('init')
guard = sdfg.add_state('guard')
condition = sdfg.add_state('condition')
if_branch = sdfg.add_state('if_branch')
else_branch = sdfg.add_state('else_branch')
out = sdfg.add_state('out')

sdfg.add_edge(init, guard, dace.InterstateEdge(None, {'i': '0', 'check': 'False'}))
sdfg.add_edge(guard, condition, dace.InterstateEdge('(i < N)', {'__tmp1': 'in_arr[i]'}))
sdfg.add_edge(condition, if_branch, dace.InterstateEdge('__tmp1'))
sdfg.add_edge(if_branch, else_branch, dace.InterstateEdge(None, {'check': 'False'}))
sdfg.add_edge(condition, else_branch, dace.InterstateEdge('(not __tmp1)', {'check': 'True'}))
sdfg.add_edge(else_branch, guard, dace.InterstateEdge(None, {'i': '(i + 1)'}))
sdfg.add_edge(guard, out, dace.InterstateEdge('(not (i < N))'))

a = out.add_write('__return')
t = out.add_tasklet('tasklet', {}, {'__out'}, '__out = check')
out.add_edge(t, '__out', a, None, dace.Memlet('__return[0]'))
sdfg.validate()

ConstantPropagation().apply_pass(sdfg, {})
assert t.code.as_string == '__out = check'


def test_for_with_external_init_nested():

N = dace.symbol('N')
Expand Down Expand Up @@ -481,6 +515,7 @@ def test_skip_branch():
test_allocation_varying(False)
test_allocation_varying(True)
test_for_with_external_init()
test_for_with_conditional_assignment()
test_for_with_external_init_nested()
test_for_with_external_init_nested_start_with_guard()
test_skip_branch()

0 comments on commit 7210cb6

Please sign in to comment.