diff --git a/dace/transformation/dataflow/stream_transient.py b/dace/transformation/dataflow/stream_transient.py index b8c0f5820c..d4df0b6855 100644 --- a/dace/transformation/dataflow/stream_transient.py +++ b/dace/transformation/dataflow/stream_transient.py @@ -6,7 +6,6 @@ from dace.symbolic import symstr import warnings -from numpy.core.numeric import outer from dace import data, dtypes, registry, symbolic, subsets from dace.frontend.operations import detect_reduction_type from dace.properties import SymbolicProperty, make_properties, Property diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 9006ae3c10..b2c3df3ce8 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -214,7 +214,10 @@ def _add_nested_datanames(name: str, desc: data.Structure): for aname, aval in constants.items(): # If something was assigned more than once (to a different value), it's not a constant - if aname in assignments and aval != assignments[aname]: + # If a symbol appearing in the replacing expression of a constant is modified, + # the constant is not valid anymore + if ((aname in assignments and aval != assignments[aname]) or + symbolic.free_symbols_and_functions(aval) & edge.data.assignments.keys()): assignments[aname] = _UnknownValue else: assignments[aname] = aval diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index d9cd798f0c..727ec5555b 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -1092,7 +1092,7 @@ def _subgraph_transformation_extract_sdfg_arg(*args) -> SDFG: raise TypeError('Unrecognized graph type "%s"' % type(subgraph).__name__) -def single_level_sdfg_only(cls: ppl.Pass): +def single_level_sdfg_only(cls: PassT) -> PassT: for function_name in ['apply_pass', 'apply_to']: _make_function_blocksafe(cls, function_name, lambda *args: args[1]) diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index 89b7e7ed5c..3420403b49 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import pytest import dace @@ -359,8 +359,8 @@ def test_for_with_conditional_assignment(): sdfg.add_symbol('i', dace.int64) sdfg.add_symbol('check', dace.bool) sdfg.add_symbol('__tmp1', dace.bool) - sdfg.add_array('__return', {1,}, dace.bool) - sdfg.add_array('in_arr', {N,}, dace.bool) + sdfg.add_array('__return', {1}, dace.bool) + sdfg.add_array('in_arr', {N}, dace.bool) init = sdfg.add_state('init') guard = sdfg.add_state('guard') @@ -473,7 +473,7 @@ def test_for_with_external_init_nested_start_with_guard(): def test_skip_branch(): sdfg = dace.SDFG('skip_branch') sdfg.add_symbol('k', dace.int32) - sdfg.add_array('__return', (1,), dace.int32) + sdfg.add_array('__return', (1, ), dace.int32) init = sdfg.add_state('init') if_guard = sdfg.add_state('if_guard') if_state = sdfg.add_state('if_state') @@ -501,6 +501,78 @@ def test_skip_branch(): assert (rval_2[0] == 1) +def test_dependency_change(): + """ + Tests a regression in constant propagation that stems from a variable's + dependency being set in the same edge where the pre-propagated symbol was + also a right-hand side expression. The original SDFG is semantically-sound, + but the propagated one may update ``t`` to be ``t + `` + instead of the older ``irev``. + """ + + sdfg = dace.SDFG('tester') + sdfg.add_symbol('N', dace.int64) + sdfg.add_array('a', [1], dace.int64) + init = sdfg.add_state() + entry = sdfg.add_state('entry') + body = sdfg.add_state('body') + body2 = sdfg.add_state('body2') + exiting = sdfg.add_state('exiting') + latch = sdfg.add_state('latch') + final = sdfg.add_state('final') + + sdfg.add_edge(init, entry, dace.InterstateEdge(assignments=dict(i='0', t='0', irev='2500'))) + sdfg.add_edge(entry, body, dace.InterstateEdge()) + sdfg.add_edge( + body, body2, + dace.InterstateEdge(assignments=dict(t_next='(t + irev)', + irev_next='(irev + (- 1))', + i_next='i + 1'), )) + sdfg.add_edge( + body2, exiting, + dace.InterstateEdge(assignments=dict(cont='i_next == 2500'), )) + sdfg.add_edge(exiting, final, dace.InterstateEdge('cont')) + sdfg.add_edge(exiting, latch, dace.InterstateEdge('not cont', dict( + irev='irev_next', + i='i_next', + ))) + sdfg.add_edge(latch, body, dace.InterstateEdge(assignments=dict(t='t_next'))) + + t = body.add_tasklet('add', {'inp'}, {'out'}, 'out = inp + t') + body.add_edge(body.add_read('a'), None, t, 'inp', dace.Memlet('a[0]')) + body.add_edge(t, 'out', body.add_write('a'), None, dace.Memlet('a[0]')) + + ConstantPropagation().apply_pass(sdfg, {}) + + # Python code equivalent of the above SDFG + ref = 0 + + i = 0 + t = 0 + irev = 2500 + while True: + # body + ref += t + + # exiting state + t_next = t + irev + irev_next = (irev + (-1)) + i_next = i + 1 + cont = (i_next == 2500) + if not cont: + irev = irev_next + i = i_next + # + t = t_next + continue + else: + break + + a = np.zeros([1], np.int64) + sdfg(a=a) + assert a[0] == ref + + if __name__ == '__main__': test_simple_constants() test_nested_constants() @@ -519,3 +591,4 @@ def test_skip_branch(): test_for_with_external_init_nested() test_for_with_external_init_nested_start_with_guard() test_skip_branch() + test_dependency_change()