From 9945f48c9ff8b7deeb218cfe1b8cdbebd923be60 Mon Sep 17 00:00:00 2001 From: BenWeber42 Date: Tue, 24 Sep 2024 22:36:03 +0200 Subject: [PATCH] Fix array indirection to memlet subset promotion (#1406) The current solution is rather hacky. I want to run the tests first to see the impacts of this change. Additionally, there is no test yet, because validation doesn't catch the erroneous SDFG yet. Overall, it's not clear currently how to solve the issue and the PR might change as we progress... --- .../transformation/passes/scalar_to_symbol.py | 3 +- tests/passes/scalar_to_symbol_test.py | 46 +++++++++++++++++-- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 8b4f2a9be3..a0cb08ea0c 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -323,6 +323,7 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle def visit_Subscript(self, node: ast.Subscript) -> Any: # Convert subscript to symbol name + node = self.generic_visit(node) node_name = astutils.rname(node) if node_name in self.in_edges: self.latest[node_name] += 1 @@ -346,7 +347,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: return ast.copy_location(ast.Name(id=new_name, ctx=ast.Store()), node) else: self.do_not_remove.add(node_name) - return self.generic_visit(node) + return node def _range_is_promotable(subset: subsets.Range, defined: Set[str]) -> bool: diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 140ec105f7..7fdfbdf737 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -1,14 +1,12 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar to symbol promotion functionality. """ import dace from dace.transformation.passes import scalar_to_symbol -from dace.sdfg.state import SDFGState from dace.transformation import transformation as xf, interstate as isxf from dace.transformation.interstate import loop_detection as ld -from dace import registry -from dace.transformation import helpers as xfh import collections +from sympy import core as sympy_core import numpy as np import pytest @@ -692,6 +690,45 @@ def test_ternary_expression(compile_time_evaluatable): sdfg.compile() +def test_double_index_bug(): + + sdfg = dace.SDFG('test_') + state = sdfg.add_state() + + sdfg.add_array('A', shape=(10, ), dtype=dace.float64) + sdfg.add_array('table', shape=(10, 2), dtype=dace.int64) + sdfg.add_array('B', shape=(10, ), dtype=dace.float64) + sdfg.add_scalar('idx', dace.int64, transient=True) + idx_node = state.add_access('idx') + set_tlet = state.add_tasklet('set_idx', code="_idx=0", inputs={}, outputs={"_idx"}) + state.add_mapped_tasklet('map', + map_ranges={'i': "0:10"}, + inputs={ + 'inp': dace.Memlet("A[0:10]"), + '_idx': dace.Memlet('idx[0]'), + 'indices': dace.Memlet('table[0:10, 0:2]') + }, + code="out = inp[indices[i,_idx]]", + outputs={'out': dace.Memlet("B[i]")}, + external_edges=True, + input_nodes={'idx': idx_node}) + + state.add_edge(set_tlet, '_idx', idx_node, None, dace.Memlet('idx[0]')) + + sdfg.simplify() + + # Check that `indices` (which is an array) is not used in a memlet subset + for state in sdfg.states(): + for memlet in state.edges(): + subset = memlet.data.subset + if not isinstance(subset, dace.subsets.Range): + continue + for range in subset.ranges: + for part in range: + for sympy_node in sympy_core.preorder_traversal(part): + assert getattr(sympy_node, "name", None) != "indices" + + if __name__ == '__main__': test_find_promotable() test_promote_simple() @@ -715,3 +752,4 @@ def test_ternary_expression(compile_time_evaluatable): test_dynamic_mapind() test_ternary_expression(False) test_ternary_expression(True) + test_double_index_bug()