diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 8b4f2a9be3..a0cb08ea0c 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -323,6 +323,7 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle def visit_Subscript(self, node: ast.Subscript) -> Any: # Convert subscript to symbol name + node = self.generic_visit(node) node_name = astutils.rname(node) if node_name in self.in_edges: self.latest[node_name] += 1 @@ -346,7 +347,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: return ast.copy_location(ast.Name(id=new_name, ctx=ast.Store()), node) else: self.do_not_remove.add(node_name) - return self.generic_visit(node) + return node def _range_is_promotable(subset: subsets.Range, defined: Set[str]) -> bool: diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 140ec105f7..7fdfbdf737 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -1,14 +1,12 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar to symbol promotion functionality. """ import dace from dace.transformation.passes import scalar_to_symbol -from dace.sdfg.state import SDFGState from dace.transformation import transformation as xf, interstate as isxf from dace.transformation.interstate import loop_detection as ld -from dace import registry -from dace.transformation import helpers as xfh import collections +from sympy import core as sympy_core import numpy as np import pytest @@ -692,6 +690,45 @@ def test_ternary_expression(compile_time_evaluatable): sdfg.compile() +def test_double_index_bug(): + + sdfg = dace.SDFG('test_') + state = sdfg.add_state() + + sdfg.add_array('A', shape=(10, ), dtype=dace.float64) + sdfg.add_array('table', shape=(10, 2), dtype=dace.int64) + sdfg.add_array('B', shape=(10, ), dtype=dace.float64) + sdfg.add_scalar('idx', dace.int64, transient=True) + idx_node = state.add_access('idx') + set_tlet = state.add_tasklet('set_idx', code="_idx=0", inputs={}, outputs={"_idx"}) + state.add_mapped_tasklet('map', + map_ranges={'i': "0:10"}, + inputs={ + 'inp': dace.Memlet("A[0:10]"), + '_idx': dace.Memlet('idx[0]'), + 'indices': dace.Memlet('table[0:10, 0:2]') + }, + code="out = inp[indices[i,_idx]]", + outputs={'out': dace.Memlet("B[i]")}, + external_edges=True, + input_nodes={'idx': idx_node}) + + state.add_edge(set_tlet, '_idx', idx_node, None, dace.Memlet('idx[0]')) + + sdfg.simplify() + + # Check that `indices` (which is an array) is not used in a memlet subset + for state in sdfg.states(): + for memlet in state.edges(): + subset = memlet.data.subset + if not isinstance(subset, dace.subsets.Range): + continue + for range in subset.ranges: + for part in range: + for sympy_node in sympy_core.preorder_traversal(part): + assert getattr(sympy_node, "name", None) != "indices" + + if __name__ == '__main__': test_find_promotable() test_promote_simple() @@ -715,3 +752,4 @@ def test_ternary_expression(compile_time_evaluatable): test_dynamic_mapind() test_ternary_expression(False) test_ternary_expression(True) + test_double_index_bug()