Skip to content

Commit

Permalink
Bug in constant propagation with multiple constants (#1658)
Browse files Browse the repository at this point in the history
Propagating multiple constants symbolically at the same time is not a
good idea if propagated symbol A can change a value that affects
propagated symbol B. This PR adds a failing test and hopefully a fix.

@luigifusco @phschaad

---------

Co-authored-by: Luigi Fusco <[email protected]>
  • Loading branch information
tbennun and luigifusco authored Sep 18, 2024
1 parent b0699ed commit 829687c
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 7 deletions.
1 change: 0 additions & 1 deletion dace/transformation/dataflow/stream_transient.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dace.symbolic import symstr
import warnings

from numpy.core.numeric import outer
from dace import data, dtypes, registry, symbolic, subsets
from dace.frontend.operations import detect_reduction_type
from dace.properties import SymbolicProperty, make_properties, Property
Expand Down
5 changes: 4 additions & 1 deletion dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ def _add_nested_datanames(name: str, desc: data.Structure):

for aname, aval in constants.items():
# If something was assigned more than once (to a different value), it's not a constant
if aname in assignments and aval != assignments[aname]:
# If a symbol appearing in the replacing expression of a constant is modified,
# the constant is not valid anymore
if ((aname in assignments and aval != assignments[aname]) or
symbolic.free_symbols_and_functions(aval) & edge.data.assignments.keys()):
assignments[aname] = _UnknownValue
else:
assignments[aname] = aval
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def _subgraph_transformation_extract_sdfg_arg(*args) -> SDFG:
raise TypeError('Unrecognized graph type "%s"' % type(subgraph).__name__)


def single_level_sdfg_only(cls: ppl.Pass):
def single_level_sdfg_only(cls: PassT) -> PassT:

for function_name in ['apply_pass', 'apply_to']:
_make_function_blocksafe(cls, function_name, lambda *args: args[1])
Expand Down
81 changes: 77 additions & 4 deletions tests/passes/constant_propagation_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.

import pytest
import dace
Expand Down Expand Up @@ -359,8 +359,8 @@ def test_for_with_conditional_assignment():
sdfg.add_symbol('i', dace.int64)
sdfg.add_symbol('check', dace.bool)
sdfg.add_symbol('__tmp1', dace.bool)
sdfg.add_array('__return', {1,}, dace.bool)
sdfg.add_array('in_arr', {N,}, dace.bool)
sdfg.add_array('__return', {1}, dace.bool)
sdfg.add_array('in_arr', {N}, dace.bool)

init = sdfg.add_state('init')
guard = sdfg.add_state('guard')
Expand Down Expand Up @@ -473,7 +473,7 @@ def test_for_with_external_init_nested_start_with_guard():
def test_skip_branch():
sdfg = dace.SDFG('skip_branch')
sdfg.add_symbol('k', dace.int32)
sdfg.add_array('__return', (1,), dace.int32)
sdfg.add_array('__return', (1, ), dace.int32)
init = sdfg.add_state('init')
if_guard = sdfg.add_state('if_guard')
if_state = sdfg.add_state('if_state')
Expand Down Expand Up @@ -501,6 +501,78 @@ def test_skip_branch():
assert (rval_2[0] == 1)


def test_dependency_change():
"""
Tests a regression in constant propagation that stems from a variable's
dependency being set in the same edge where the pre-propagated symbol was
also a right-hand side expression. The original SDFG is semantically-sound,
but the propagated one may update ``t`` to be ``t + <modified irev>``
instead of the older ``irev``.
"""

sdfg = dace.SDFG('tester')
sdfg.add_symbol('N', dace.int64)
sdfg.add_array('a', [1], dace.int64)
init = sdfg.add_state()
entry = sdfg.add_state('entry')
body = sdfg.add_state('body')
body2 = sdfg.add_state('body2')
exiting = sdfg.add_state('exiting')
latch = sdfg.add_state('latch')
final = sdfg.add_state('final')

sdfg.add_edge(init, entry, dace.InterstateEdge(assignments=dict(i='0', t='0', irev='2500')))
sdfg.add_edge(entry, body, dace.InterstateEdge())
sdfg.add_edge(
body, body2,
dace.InterstateEdge(assignments=dict(t_next='(t + irev)',
irev_next='(irev + (- 1))',
i_next='i + 1'), ))
sdfg.add_edge(
body2, exiting,
dace.InterstateEdge(assignments=dict(cont='i_next == 2500'), ))
sdfg.add_edge(exiting, final, dace.InterstateEdge('cont'))
sdfg.add_edge(exiting, latch, dace.InterstateEdge('not cont', dict(
irev='irev_next',
i='i_next',
)))
sdfg.add_edge(latch, body, dace.InterstateEdge(assignments=dict(t='t_next')))

t = body.add_tasklet('add', {'inp'}, {'out'}, 'out = inp + t')
body.add_edge(body.add_read('a'), None, t, 'inp', dace.Memlet('a[0]'))
body.add_edge(t, 'out', body.add_write('a'), None, dace.Memlet('a[0]'))

ConstantPropagation().apply_pass(sdfg, {})

# Python code equivalent of the above SDFG
ref = 0

i = 0
t = 0
irev = 2500
while True:
# body
ref += t

# exiting state
t_next = t + irev
irev_next = (irev + (-1))
i_next = i + 1
cont = (i_next == 2500)
if not cont:
irev = irev_next
i = i_next
#
t = t_next
continue
else:
break

a = np.zeros([1], np.int64)
sdfg(a=a)
assert a[0] == ref


if __name__ == '__main__':
test_simple_constants()
test_nested_constants()
Expand All @@ -519,3 +591,4 @@ def test_skip_branch():
test_for_with_external_init_nested()
test_for_with_external_init_nested_start_with_guard()
test_skip_branch()
test_dependency_change()

0 comments on commit 829687c

Please sign in to comment.