Skip to content

Commit

Permalink
Fix array indirection to memlet subset promotion (#1406)
Browse files Browse the repository at this point in the history
The current solution is rather hacky. I want to run the tests first to
see the impacts of this change.
Additionally, there is no test yet, because validation doesn't catch the
erroneous SDFG yet.

Overall, it's not clear currently how to solve the issue and the PR
might change as we progress...
  • Loading branch information
BenWeber42 authored Sep 24, 2024
1 parent 7df09c7 commit 9945f48
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
3 changes: 2 additions & 1 deletion dace/transformation/passes/scalar_to_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle

def visit_Subscript(self, node: ast.Subscript) -> Any:
# Convert subscript to symbol name
node = self.generic_visit(node)
node_name = astutils.rname(node)
if node_name in self.in_edges:
self.latest[node_name] += 1
Expand All @@ -346,7 +347,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
return ast.copy_location(ast.Name(id=new_name, ctx=ast.Store()), node)
else:
self.do_not_remove.add(node_name)
return self.generic_visit(node)
return node


def _range_is_promotable(subset: subsets.Range, defined: Set[str]) -> bool:
Expand Down
46 changes: 42 additions & 4 deletions tests/passes/scalar_to_symbol_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Tests the scalar to symbol promotion functionality. """
import dace
from dace.transformation.passes import scalar_to_symbol
from dace.sdfg.state import SDFGState
from dace.transformation import transformation as xf, interstate as isxf
from dace.transformation.interstate import loop_detection as ld
from dace import registry
from dace.transformation import helpers as xfh

import collections
from sympy import core as sympy_core
import numpy as np
import pytest

Expand Down Expand Up @@ -692,6 +690,45 @@ def test_ternary_expression(compile_time_evaluatable):
sdfg.compile()


def test_double_index_bug():

sdfg = dace.SDFG('test_')
state = sdfg.add_state()

sdfg.add_array('A', shape=(10, ), dtype=dace.float64)
sdfg.add_array('table', shape=(10, 2), dtype=dace.int64)
sdfg.add_array('B', shape=(10, ), dtype=dace.float64)
sdfg.add_scalar('idx', dace.int64, transient=True)
idx_node = state.add_access('idx')
set_tlet = state.add_tasklet('set_idx', code="_idx=0", inputs={}, outputs={"_idx"})
state.add_mapped_tasklet('map',
map_ranges={'i': "0:10"},
inputs={
'inp': dace.Memlet("A[0:10]"),
'_idx': dace.Memlet('idx[0]'),
'indices': dace.Memlet('table[0:10, 0:2]')
},
code="out = inp[indices[i,_idx]]",
outputs={'out': dace.Memlet("B[i]")},
external_edges=True,
input_nodes={'idx': idx_node})

state.add_edge(set_tlet, '_idx', idx_node, None, dace.Memlet('idx[0]'))

sdfg.simplify()

# Check that `indices` (which is an array) is not used in a memlet subset
for state in sdfg.states():
for memlet in state.edges():
subset = memlet.data.subset
if not isinstance(subset, dace.subsets.Range):
continue
for range in subset.ranges:
for part in range:
for sympy_node in sympy_core.preorder_traversal(part):
assert getattr(sympy_node, "name", None) != "indices"


if __name__ == '__main__':
test_find_promotable()
test_promote_simple()
Expand All @@ -715,3 +752,4 @@ def test_ternary_expression(compile_time_evaluatable):
test_dynamic_mapind()
test_ternary_expression(False)
test_ternary_expression(True)
test_double_index_bug()

0 comments on commit 9945f48

Please sign in to comment.