From e6440a687c4ea851b8661d18d9490604b116d440 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 11 Oct 2024 11:17:09 -0700 Subject: [PATCH] Fix race conditions in Constant Propagation and Reference-To-View (#1679) * Fixes a case where constant propagation would cause an inter-state edge assignment race condition * Fixes reference-to-view disconnecting a state graph and causing a race condition * More informative error message in code generation for copy dispatching --- dace/codegen/dispatcher.py | 5 ++ .../passes/constant_propagation.py | 16 +++++- .../passes/reference_reduction.py | 37 ++++++++----- tests/passes/constant_propagation_test.py | 55 +++++++++++++++++++ tests/sdfg/reference_test.py | 47 ++++++++++++++++ 5 files changed, 144 insertions(+), 16 deletions(-) diff --git a/dace/codegen/dispatcher.py b/dace/codegen/dispatcher.py index 3ac9e097f8..9bec33b4ef 100644 --- a/dace/codegen/dispatcher.py +++ b/dace/codegen/dispatcher.py @@ -598,6 +598,8 @@ def dispatch_copy(self, src_node: nodes.Node, dst_node: nodes.Node, edge: MultiC cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, function_stream: CodeIOStream, output_stream: CodeIOStream) -> None: """ Dispatches a code generator for a memory copy operation. """ + if edge.data.is_empty(): + return state = cfg.state(state_id) target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state) if target is None: @@ -616,6 +618,9 @@ def dispatch_output_definition(self, src_node: nodes.Node, dst_node: nodes.Node, """ state = cfg.state(state_id) target = self.get_copy_dispatcher(src_node, dst_node, edge, sdfg, state) + if target is None: + raise ValueError( + f'Could not dispatch copy code generator for {src_node} -> {dst_node} in state {state.label}') # Dispatch self._used_targets.add(target) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index b2c3df3ce8..bfa0928415 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -177,7 +177,7 @@ def _add_nested_datanames(name: str, desc: data.Structure): # TODO: How are we handling this? pass arrays.add(f'{name}.{k}') - + for name, desc in sdfg.arrays.items(): if isinstance(desc, data.Structure): _add_nested_datanames(name, desc) @@ -222,6 +222,20 @@ def _add_nested_datanames(name: str, desc: data.Structure): else: assignments[aname] = aval + for edge in sdfg.out_edges(state): + for aname, aval in assignments.items(): + # If the specific replacement would result in the value + # being both used and reassigned on the same inter-state + # edge, remove it from consideration. + replacements = symbolic.free_symbols_and_functions(aval) + used_in_assignments = { + k + for k, v in edge.data.assignments.items() if aname in symbolic.free_symbols_and_functions(v) + } + reassignments = replacements & edge.data.assignments.keys() + if reassignments and (used_in_assignments - reassignments): + assignments[aname] = _UnknownValue + if state not in result: # Condition may evaluate to False when state is the start-state result[state] = {} redo |= self._propagate(result[state], assignments) diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index 5bee098c55..dc5ae1eb7d 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -166,21 +166,28 @@ def remove_refsets( affected_nodes = set() for e in state.in_edges_by_connector(node, 'set'): # This is a reference set edge. Consider scope and neighbors and remove set - edges_to_remove.add(e) - affected_nodes.add(e.src) - affected_nodes.add(e.dst) - - # If source node does not have any other neighbors, it can be removed - if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)): - nodes_to_remove.add(e.src) - # If set reference does not have any other neighbors, it can be removed - if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)): - nodes_to_remove.add(node) - - # If in a scope, ensure reference node will not be disconnected - scope = state.entry_node(node) - if scope is not None and node not in nodes_to_remove: - edges_to_add.append((scope, None, node, None, Memlet())) + if state.out_degree(e.dst) == 0: + edges_to_remove.add(e) + affected_nodes.add(e.src) + affected_nodes.add(e.dst) + + # If source node does not have any other neighbors, it can be removed + if all(ee is e or ee.data.is_empty() for ee in state.all_edges(e.src)): + nodes_to_remove.add(e.src) + # If set reference does not have any other neighbors, it can be removed + if all(ee is e or ee.data.is_empty() for ee in state.all_edges(node)): + nodes_to_remove.add(node) + + # If in a scope, ensure reference node will not be disconnected + scope = state.entry_node(node) + if scope is not None and node not in nodes_to_remove: + edges_to_add.append((scope, None, node, None, Memlet())) + else: # Node has other neighbors, modify edge to become an empty memlet instead + e.dst_conn = None + e.dst.remove_in_connector('set') + e.data = Memlet() + + # Modify the state graph as necessary for e in edges_to_remove: diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index 3420403b49..acb1033554 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -573,6 +573,59 @@ def test_dependency_change(): assert a[0] == ref +@pytest.mark.parametrize('extra_state', (False, True)) +def test_dependency_change_same_edge(extra_state): + """ + Tests a regression in constant propagation that stems from a variable's + dependency being set in the same edge where the pre-propagated symbol was + also a right-hand side expression. In this case, ``i61`` is incorrectly + propagated to ``i60`` and ``i17`` is set to ``i61``, which is also updated + on the same inter-state edge. + """ + + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int64) + sdfg.add_array('a', [1], dace.int64) + sdfg.add_scalar('cont', dace.int64, transient=True) + init = sdfg.add_state() + entry = sdfg.add_state('entry') + body = sdfg.add_state('body') + latch = sdfg.add_state('latch') + final = sdfg.add_state('final') + + sdfg.add_edge(init, entry, dace.InterstateEdge(assignments=dict(i60='0'))) + sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12'))) + sdfg.add_edge(body, final, dace.InterstateEdge('cont')) + sdfg.add_edge(body, latch, dace.InterstateEdge('not cont', dict(i60='i61'))) + if not extra_state: + sdfg.add_edge(latch, body, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12'))) + else: + # Test that the multi-value definition is not propagated to following edges + extra = sdfg.add_state('extra') + sdfg.add_edge(latch, extra, dace.InterstateEdge(assignments=dict(i61='i60 + 1', i17='i60 * 12'))) + sdfg.add_edge(extra, body, dace.InterstateEdge(assignments=dict(i18='i60 + i61'))) + + t = body.add_tasklet('add', {'inp'}, {'out', 'c'}, 'out = inp + i17; c = i61 == 10') + body.add_edge(body.add_read('a'), None, t, 'inp', dace.Memlet('a[0]')) + body.add_edge(t, 'out', body.add_write('a'), None, dace.Memlet('a[0]')) + body.add_edge(t, 'c', body.add_write('cont'), None, dace.Memlet('cont[0]')) + + ConstantPropagation().apply_pass(sdfg, {}) + + sdfg.validate() + + # Python code equivalent of the above SDFG + ref = 0 + i60 = 0 + for i60 in range(0, 10): + i17 = i60 * 12 + ref += i17 + + a = np.zeros([1], np.int64) + sdfg(a=a) + assert a[0] == ref + + if __name__ == '__main__': test_simple_constants() test_nested_constants() @@ -592,3 +645,5 @@ def test_dependency_change(): test_for_with_external_init_nested_start_with_guard() test_skip_branch() test_dependency_change() + test_dependency_change_same_edge(False) + test_dependency_change_same_edge(True) diff --git a/tests/sdfg/reference_test.py b/tests/sdfg/reference_test.py index 6c4d1eda1f..d712c653c9 100644 --- a/tests/sdfg/reference_test.py +++ b/tests/sdfg/reference_test.py @@ -7,6 +7,7 @@ from dace.transformation.passes.reference_reduction import ReferenceToView import numpy as np import pytest +import networkx as nx def test_unset_reference(): @@ -636,6 +637,51 @@ def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate): assert np.allclose(B, ref) +def test_ref2view_reconnection(): + """ + Tests a regression in which ReferenceToView disconnects an existing weakly-connected state + and thus creating a race condition. + """ + sdfg = dace.SDFG('reftest') + sdfg.add_array('A', [2], dace.float64) + sdfg.add_array('B', [1], dace.float64) + sdfg.add_reference('ref', [1], dace.float64) + + state = sdfg.add_state() + a2 = state.add_access('A') + ref = state.add_access('ref') + b = state.add_access('B') + + t2 = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') + state.add_edge(ref, None, t2, 'inp', dace.Memlet('ref[0]')) + state.add_edge(t2, 'out', b, None, dace.Memlet('B[0]')) + state.add_edge(a2, None, ref, 'set', dace.Memlet('A[1]')) + + t1 = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') + a1 = state.add_access('A') + state.add_edge(a1, None, t1, 'inp', dace.Memlet('A[1]')) + state.add_edge(t1, 'out', a2, None, dace.Memlet('A[1]')) + + # Test correctness before pass + A = np.random.rand(2) + B = np.random.rand(1) + ref = (A[1] + 2) + sdfg(A=A, B=B) + assert np.allclose(B, ref) + + # Test reference-to-view + result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {}) + assert result['ReferenceToView'] == {'ref'} + + # Pass should not break order + assert len(list(nx.weakly_connected_components(state.nx))) == 1 + + # Test correctness after pass + ref = (A[1] + 2) + sdfg(A=A, B=B) + assert np.allclose(B, ref) + + if __name__ == '__main__': test_unset_reference() test_reference_branch() @@ -662,3 +708,4 @@ def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate): test_ref2view_refset_in_scope(False, True) test_ref2view_refset_in_scope(True, False) test_ref2view_refset_in_scope(True, True) + test_ref2view_reconnection()