Skip to content

Commit

Permalink
Fix problem with struct reads on interstate edges (#1512)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexandros Nikolaos Ziogas <[email protected]>
Co-authored-by: Tal Ben-Nun <[email protected]>
Co-authored-by: alexnick83 <[email protected]>
  • Loading branch information
4 people authored Nov 6, 2024
1 parent 1554421 commit 163366d
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 3 deletions.
10 changes: 10 additions & 0 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,16 @@ def _Name(self, t: ast.Name):
# Replace values with their code-generated names (for example, persistent arrays)
desc = self.sdfg.arrays[t.id]
self.write(ptr(t.id, desc, self.sdfg, self.codegen))

def _Attribute(self, t: ast.Attribute):
from dace.frontend.python.astutils import rname
name = rname(t)
if name not in self.sdfg.arrays:
return super()._Attribute(t)

# Replace values with their code-generated names (for example, persistent arrays)
desc = self.sdfg.arrays[name]
self.write(ptr(name, desc, self.sdfg, self.codegen))

def _Subscript(self, t: ast.Subscript):
from dace.frontend.python.astutils import subscript_to_slice
Expand Down
8 changes: 6 additions & 2 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,12 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype
defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer
self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef)
self.allocate_array(sdfg, cfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v,
function_stream, declaration_stream, allocation_stream)
if isinstance(v, data.Scalar):
# NOTE: Scalar members are already defined in the struct definition.
self._dispatcher.defined_vars.add(f"{name}->{k}", defined_type, ctypedef)
else:
self.allocate_array(sdfg, cfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v,
function_stream, declaration_stream, allocation_stream)
return
if isinstance(nodedesc, data.View):
return self.allocate_view(sdfg, cfg, dfg, state_id, node, function_stream, declaration_stream,
Expand Down
9 changes: 8 additions & 1 deletion dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,16 @@ class Data:
Examples: Arrays, Streams, custom arrays (e.g., sparse matrices).
"""

def _transient_setter(self, value):
self._transient = value
if isinstance(self, Structure):
for _, v in self.members.items():
if isinstance(v, Data):
v.transient = value

dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses)
shape = ShapeProperty(default=[])
transient = Property(dtype=bool, default=False)
transient = Property(dtype=bool, default=False, setter=_transient_setter)
storage = EnumProperty(dtype=dtypes.StorageType, desc="Storage location", default=dtypes.StorageType.Default)
lifetime = EnumProperty(dtype=dtypes.AllocationLifetime,
desc='Data allocation span',
Expand Down
4 changes: 4 additions & 0 deletions dace/transformation/auto/auto_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,10 @@ def make_transients_persistent(sdfg: SDFG,
not_persistent.add(dnode.data)
continue
desc = dnode.desc(nsdfg)
# Only convert what is not a member of a non-persistent struct.
if (dnode.root_data != dnode.data and
nsdfg.arrays[dnode.root_data].lifetime != dtypes.AllocationLifetime.Persistent):
continue
# Only convert arrays and scalars that are not registers
if not desc.transient or type(desc) not in {dt.Array, dt.Scalar}:
not_persistent.add(dnode.data)
Expand Down
62 changes: 62 additions & 0 deletions tests/sdfg/data/structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,67 @@ def test_direct_read_nested_structure():
assert np.allclose(B, ref)


def test_read_struct_member_interstate_edge():
sdfg = dace.SDFG('test_read_struct_member_interstate_edge')

struct_data = dace.data.Structure({
'start': dace.data.Scalar(dace.int32),
'stop': dace.data.Scalar(dace.int32),
}, 't_indices')
struct_data.transient = True
sdfg.add_datadesc('indices', struct_data)
sdfg.add_array('A', [20], dace.int32, transient=False)
_, desc_start_in = sdfg.add_scalar('start_in', dace.int32, transient=False)
_, desc_stop_in = sdfg.add_scalar('stop_in', dace.int32, transient=False)
_, v_start = sdfg.add_view('v_start', [1], dace.int32)
_, v_stop = sdfg.add_view('v_stop', [1], dace.int32)

init = sdfg.add_state('init', is_start_block=True)
init2 = sdfg.add_state('init', is_start_block=True)
guard = sdfg.add_state('guard')
body = sdfg.add_state('guard')
tail = sdfg.add_state('guard')
sdfg.add_edge(init, init2, dace.InterstateEdge())
sdfg.add_edge(init2, guard, dace.InterstateEdge(assignments={'i': 'indices.start'}))
sdfg.add_edge(guard, body, dace.InterstateEdge(condition='i <= indices.stop'))
sdfg.add_edge(body, guard, dace.InterstateEdge(assignments={'i': '(i + 1)'}))
sdfg.add_edge(guard, tail, dace.InterstateEdge(condition='not (i <= indices.stop)'))

in_start_access = init.add_access('start_in')
t1 = init.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1')
start_access = init.add_access('v_start')
ind_access1 = init.add_access('indices')
init.add_edge(in_start_access, None, t1, 'i1', dace.Memlet.from_array('start_in', desc_start_in))
init.add_edge(t1, 'o1', start_access, None, dace.Memlet.from_array('v_start', v_start))
init.add_edge(start_access, 'views', ind_access1, None,
dace.Memlet.from_array('indices.start', struct_data.members['start']))

in_stop_access = init2.add_access('stop_in')
t2 = init2.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1')
stop_access = init2.add_access('v_stop')
ind_access2 = init2.add_access('indices')
init2.add_edge(in_stop_access, None, t2, 'i1', dace.Memlet.from_array('stop_in', desc_stop_in))
init2.add_edge(t2, 'o1', stop_access, None, dace.Memlet.from_array('v_stop', v_stop))
init2.add_edge(stop_access, 'views', ind_access2, None,
dace.Memlet.from_array('indices.stop', struct_data.members['stop']))

t3 = body.add_tasklet('t3', {}, {'o1'}, 'o1 = i')
a_access = body.add_access('A')
body.add_edge(t3, 'o1', a_access, None, dace.Memlet('A[i]'))

sdfg.validate()

arr = np.zeros((20,), dtype=np.int32)
arr_validate = np.zeros((20,), dtype=np.int32)
for i in range(11):
arr_validate[i] = i

c_sdfg = sdfg.compile()
c_sdfg(A=arr, start_in=0, stop_in=10)

assert np.allclose(arr, arr_validate)


if __name__ == "__main__":
test_read_structure()
test_write_structure()
Expand All @@ -552,3 +613,4 @@ def test_direct_read_nested_structure():
test_direct_read_structure()
test_direct_read_nested_structure()
test_direct_read_structure_loops()
test_read_struct_member_interstate_edge()

0 comments on commit 163366d

Please sign in to comment.