diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index c34c829c31..86942874d1 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1036,6 +1036,16 @@ def _Name(self, t: ast.Name): # Replace values with their code-generated names (for example, persistent arrays) desc = self.sdfg.arrays[t.id] self.write(ptr(t.id, desc, self.sdfg, self.codegen)) + + def _Attribute(self, t: ast.Attribute): + from dace.frontend.python.astutils import rname + name = rname(t) + if name not in self.sdfg.arrays: + return super()._Attribute(t) + + # Replace values with their code-generated names (for example, persistent arrays) + desc = self.sdfg.arrays[name] + self.write(ptr(name, desc, self.sdfg, self.codegen)) def _Subscript(self, t: ast.Subscript): from dace.frontend.python.astutils import subscript_to_slice diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 51daaa432b..160855458d 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -406,8 +406,12 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef) - self.allocate_array(sdfg, cfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, - function_stream, declaration_stream, allocation_stream) + if isinstance(v, data.Scalar): + # NOTE: Scalar members are already defined in the struct definition. + self._dispatcher.defined_vars.add(f"{name}->{k}", defined_type, ctypedef) + else: + self.allocate_array(sdfg, cfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, + function_stream, declaration_stream, allocation_stream) return if isinstance(nodedesc, data.View): return self.allocate_view(sdfg, cfg, dfg, state_id, node, function_stream, declaration_stream, diff --git a/dace/data.py b/dace/data.py index a07fe42083..9749411fe6 100644 --- a/dace/data.py +++ b/dace/data.py @@ -167,9 +167,16 @@ class Data: Examples: Arrays, Streams, custom arrays (e.g., sparse matrices). """ + def _transient_setter(self, value): + self._transient = value + if isinstance(self, Structure): + for _, v in self.members.items(): + if isinstance(v, Data): + v.transient = value + dtype = TypeClassProperty(default=dtypes.int32, choices=dtypes.Typeclasses) shape = ShapeProperty(default=[]) - transient = Property(dtype=bool, default=False) + transient = Property(dtype=bool, default=False, setter=_transient_setter) storage = EnumProperty(dtype=dtypes.StorageType, desc="Storage location", default=dtypes.StorageType.Default) lifetime = EnumProperty(dtype=dtypes.AllocationLifetime, desc='Data allocation span', diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 7bced3bec9..0a1371ae68 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -474,6 +474,10 @@ def make_transients_persistent(sdfg: SDFG, not_persistent.add(dnode.data) continue desc = dnode.desc(nsdfg) + # Only convert what is not a member of a non-persistent struct. + if (dnode.root_data != dnode.data and + nsdfg.arrays[dnode.root_data].lifetime != dtypes.AllocationLifetime.Persistent): + continue # Only convert arrays and scalars that are not registers if not desc.transient or type(desc) not in {dt.Array, dt.Scalar}: not_persistent.add(dnode.data) diff --git a/tests/sdfg/data/structure_test.py b/tests/sdfg/data/structure_test.py index 55e3a936a7..f2ab0e51a2 100644 --- a/tests/sdfg/data/structure_test.py +++ b/tests/sdfg/data/structure_test.py @@ -543,6 +543,67 @@ def test_direct_read_nested_structure(): assert np.allclose(B, ref) +def test_read_struct_member_interstate_edge(): + sdfg = dace.SDFG('test_read_struct_member_interstate_edge') + + struct_data = dace.data.Structure({ + 'start': dace.data.Scalar(dace.int32), + 'stop': dace.data.Scalar(dace.int32), + }, 't_indices') + struct_data.transient = True + sdfg.add_datadesc('indices', struct_data) + sdfg.add_array('A', [20], dace.int32, transient=False) + _, desc_start_in = sdfg.add_scalar('start_in', dace.int32, transient=False) + _, desc_stop_in = sdfg.add_scalar('stop_in', dace.int32, transient=False) + _, v_start = sdfg.add_view('v_start', [1], dace.int32) + _, v_stop = sdfg.add_view('v_stop', [1], dace.int32) + + init = sdfg.add_state('init', is_start_block=True) + init2 = sdfg.add_state('init', is_start_block=True) + guard = sdfg.add_state('guard') + body = sdfg.add_state('guard') + tail = sdfg.add_state('guard') + sdfg.add_edge(init, init2, dace.InterstateEdge()) + sdfg.add_edge(init2, guard, dace.InterstateEdge(assignments={'i': 'indices.start'})) + sdfg.add_edge(guard, body, dace.InterstateEdge(condition='i <= indices.stop')) + sdfg.add_edge(body, guard, dace.InterstateEdge(assignments={'i': '(i + 1)'})) + sdfg.add_edge(guard, tail, dace.InterstateEdge(condition='not (i <= indices.stop)')) + + in_start_access = init.add_access('start_in') + t1 = init.add_tasklet('t1', {'i1'}, {'o1'}, 'o1 = i1') + start_access = init.add_access('v_start') + ind_access1 = init.add_access('indices') + init.add_edge(in_start_access, None, t1, 'i1', dace.Memlet.from_array('start_in', desc_start_in)) + init.add_edge(t1, 'o1', start_access, None, dace.Memlet.from_array('v_start', v_start)) + init.add_edge(start_access, 'views', ind_access1, None, + dace.Memlet.from_array('indices.start', struct_data.members['start'])) + + in_stop_access = init2.add_access('stop_in') + t2 = init2.add_tasklet('t2', {'i1'}, {'o1'}, 'o1 = i1') + stop_access = init2.add_access('v_stop') + ind_access2 = init2.add_access('indices') + init2.add_edge(in_stop_access, None, t2, 'i1', dace.Memlet.from_array('stop_in', desc_stop_in)) + init2.add_edge(t2, 'o1', stop_access, None, dace.Memlet.from_array('v_stop', v_stop)) + init2.add_edge(stop_access, 'views', ind_access2, None, + dace.Memlet.from_array('indices.stop', struct_data.members['stop'])) + + t3 = body.add_tasklet('t3', {}, {'o1'}, 'o1 = i') + a_access = body.add_access('A') + body.add_edge(t3, 'o1', a_access, None, dace.Memlet('A[i]')) + + sdfg.validate() + + arr = np.zeros((20,), dtype=np.int32) + arr_validate = np.zeros((20,), dtype=np.int32) + for i in range(11): + arr_validate[i] = i + + c_sdfg = sdfg.compile() + c_sdfg(A=arr, start_in=0, stop_in=10) + + assert np.allclose(arr, arr_validate) + + if __name__ == "__main__": test_read_structure() test_write_structure() @@ -552,3 +613,4 @@ def test_direct_read_nested_structure(): test_direct_read_structure() test_direct_read_nested_structure() test_direct_read_structure_loops() + test_read_struct_member_interstate_edge()