From e427617520af5d2b0d52e7dc41e23a7273f48051 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 4 Jan 2024 09:15:39 -0800 Subject: [PATCH] Fix redefinition of interstate edge type in code generator (#1490) --- dace/codegen/targets/framecode.py | 5 +++-- tests/codegen/symbol_arguments_test.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 0db4062976..7b6df55132 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -887,8 +887,9 @@ def generate_code(self, # NOTE: NestedSDFGs frequently contain tautologies in their symbol mapping, e.g., `'i': i`. Do not # redefine the symbols in such cases. - if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping - and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName]) == str(isvarName)): + # Additionally, do not redefine a symbol with its type if it was already defined + # as part of the function's arguments + if not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping: continue isvar = data.Scalar(isvarType) callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg) diff --git a/tests/codegen/symbol_arguments_test.py b/tests/codegen/symbol_arguments_test.py index 3ca89ddd06..557c42f8c1 100644 --- a/tests/codegen/symbol_arguments_test.py +++ b/tests/codegen/symbol_arguments_test.py @@ -48,7 +48,21 @@ def tester(A: dace.float64[N, N]): assert 'N' in sdfg.arglist() +def test_nested_sdfg_redefinition(): + sdfg = dace.SDFG('tester') + nsdfg = dace.SDFG('nester') + state = sdfg.add_state() + nnode = state.add_nested_sdfg(nsdfg, None, {}, {}, symbol_mapping=dict(sym=0)) + + nstate = nsdfg.add_state() + nstate.add_tasklet('nothing', {}, {}, 'a = sym') + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments=dict(sym=1))) + sdfg.compile() + + if __name__ == '__main__': test_global_sizes() test_global_sizes_used() test_global_sizes_multidim() + test_nested_sdfg_redefinition()