diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 90ef506bcd..420346ca88 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -752,7 +752,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: return self.generic_visit(node) # Then query for the right value - if isinstance(node.value, ast.Dict): + if isinstance(node.value, ast.Dict): # Dict for k, v in zip(node.value.keys, node.value.values): try: gkey = astutils.evalnode(k, self.globals) @@ -760,8 +760,20 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: continue if gkey == gslice: return self._visit_potential_constant(v, True) - else: # List or Tuple - return self._visit_potential_constant(node.value.elts[gslice], True) + elif isinstance(node.value, (ast.List, ast.Tuple)): # List & Tuple + # Loop over the list if slicing makes it a list + if isinstance(node.value.elts[gslice], List): + visited_list = astutils.copy_tree(node.value) + visited_list.elts.clear() + for v in node.value.elts[gslice]: + visited_cst = self._visit_potential_constant(v, True) + visited_list.elts.append(visited_cst) + node.value = visited_list + return node + else: + return self._visit_potential_constant(node.value.elts[gslice], True) + else: # Catch-all + return self._visit_potential_constant(node, True) return self._visit_potential_constant(node, True) diff --git a/tests/python_frontend/unroll_test.py b/tests/python_frontend/unroll_test.py index 98c81156a0..bf2b1e7c91 100644 --- a/tests/python_frontend/unroll_test.py +++ b/tests/python_frontend/unroll_test.py @@ -169,6 +169,52 @@ def tounroll(A: dace.float64[3]): assert np.allclose(a, np.array([1, 2, 3])) +def test_list_global_enumerate(): + tracer_variables = ["vapor", "rain", "nope"] + + @dace.program + def enumerate_parsing( + A, + tracers: dace.compiletime, # Dict[str, np.float64] + ): + for i, q in enumerate(tracer_variables[0:2]): + tracers[q][:] = A # type:ignore + + a = np.ones([3]) + q = { + "vapor": np.zeros([3]), + "rain": np.zeros([3]), + "nope": np.zeros([3]), + } + enumerate_parsing(a, q) + assert np.allclose(q["vapor"], np.array([1, 1, 1])) + assert np.allclose(q["rain"], np.array([1, 1, 1])) + assert np.allclose(q["nope"], np.array([0, 0, 0])) + + +def test_tuple_global_enumerate(): + tracer_variables = ("vapor", "rain", "nope") + + @dace.program + def enumerate_parsing( + A, + tracers: dace.compiletime, # Dict[str, np.float64] + ): + for i, q in enumerate(tracer_variables[0:2]): + tracers[q][:] = A # type:ignore + + a = np.ones([3]) + q = { + "vapor": np.zeros([3]), + "rain": np.zeros([3]), + "nope": np.zeros([3]), + } + enumerate_parsing(a, q) + assert np.allclose(q["vapor"], np.array([1, 1, 1])) + assert np.allclose(q["rain"], np.array([1, 1, 1])) + assert np.allclose(q["nope"], np.array([0, 0, 0])) + + def test_tuple_elements_zip(): a1 = [2, 3, 4] a2 = (4, 5, 6)