From 608aa805bbe0670cf20e6a1649a4b921144831d8 Mon Sep 17 00:00:00 2001 From: alexnick83 <31545860+alexnick83@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:45:05 +0100 Subject: [PATCH] Fixes for structures nested in (nested) struct-arrays (#1534) --- dace/codegen/targets/cpu.py | 8 ++-- dace/codegen/targets/framecode.py | 7 ++- dace/dtypes.py | 6 ++- dace/sdfg/sdfg.py | 4 +- dace/sdfg/validation.py | 17 +++++++ tests/sdfg/data/container_array_test.py | 59 +++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 7 deletions(-) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index a2bc44caea..0d153fb332 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -19,6 +19,7 @@ from dace.sdfg import (ScopeSubgraphView, SDFG, scope_contains_scope, is_array_stream_view, NodeNotExpandedError, dynamic_map_inputs, local_transients) from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga, is_in_scope +from dace.sdfg.validation import validate_memlet_data from typing import Union from dace.codegen.targets import fpga @@ -40,7 +41,7 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): _visit_structure(v, args, f'{prefix}->{k}') elif isinstance(v, data.ContainerArray): _visit_structure(v.stype, args, f'{prefix}->{k}') - elif isinstance(v, data.Data): + if isinstance(v, data.Data): args[f'{prefix}->{k}'] = v # Keeps track of generated connectors, so we know how to access them in nested scopes @@ -620,6 +621,7 @@ def copy_memory( callsite_stream, ) + def _emit_copy( self, sdfg, @@ -637,9 +639,9 @@ def _emit_copy( orig_vconn = vconn # Determine memlet directionality - if isinstance(src_node, nodes.AccessNode) and memlet.data == src_node.data: + if isinstance(src_node, nodes.AccessNode) and validate_memlet_data(memlet.data, src_node.data): write = True - elif isinstance(dst_node, nodes.AccessNode) and memlet.data == dst_node.data: + elif isinstance(dst_node, nodes.AccessNode) and validate_memlet_data(memlet.data, dst_node.data): write = False elif isinstance(src_node, nodes.CodeNode) and isinstance(dst_node, nodes.CodeNode): # Code->Code copy (not read nor write) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 7587f84f54..c1abf82b69 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -155,6 +155,8 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: if arr is not None: datatypes.add(arr.dtype) + emitted = set() + def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: if isinstance(dtype, dtypes.pointer): wrote_something = _emit_definitions(dtype._typeclass, wrote_something) @@ -164,7 +166,10 @@ def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: if hasattr(dtype, 'emit_definition'): if not wrote_something: global_stream.write("", sdfg) - global_stream.write(dtype.emit_definition(), sdfg) + if dtype not in emitted: + global_stream.write(dtype.emit_definition(), sdfg) + wrote_something = True + emitted.add(dtype) return wrote_something # Emit unique definitions diff --git a/dace/dtypes.py b/dace/dtypes.py index f3f27368a5..76e6db8397 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1449,8 +1449,10 @@ def validate_name(name): return False if name in {'True', 'False', 'None'}: return False - if namere.match(name) is None: - return False + tokens = name.split('.') + for token in tokens: + if namere.match(token) is None: + return False return True diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index adcaacaf27..5017a6ff86 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -728,7 +728,9 @@ def replace_dict(self, # Replace in arrays and symbols (if a variable name) if replace_keys: - for name, new_name in repldict.items(): + # Filter out nested data names, as we cannot and do not want to replace names in nested data descriptors + repldict_filtered = {k: v for k, v in repldict.items() if '.' not in k} + for name, new_name in repldict_filtered.items(): if validate_name(new_name): _replace_dict_keys(self._arrays, name, new_name) _replace_dict_keys(self.symbols, name, new_name) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 299ffc96fa..660e45e574 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -981,3 +981,20 @@ def __str__(self): locinfo += f'\nInvalid SDFG saved for inspection in {os.path.abspath(self.path)}' return f'{self.message} (at state {state.label}{edgestr}){locinfo}' + + +def validate_memlet_data(memlet_data: str, access_data: str) -> bool: + """ Validates that the src/dst access node data matches the memlet data. + + :param memlet_data: The data of the memlet. + :param access_data: The data of the access node. + :return: True if the memlet data matches the access node data. + """ + if memlet_data == access_data: + return True + if memlet_data is None or access_data is None: + return False + access_tokens = access_data.split('.') + memlet_tokens = memlet_data.split('.') + mem_root = '.'.join(memlet_tokens[:len(access_tokens)]) + return mem_root == access_data diff --git a/tests/sdfg/data/container_array_test.py b/tests/sdfg/data/container_array_test.py index 7685361d0f..091bb487d8 100644 --- a/tests/sdfg/data/container_array_test.py +++ b/tests/sdfg/data/container_array_test.py @@ -258,8 +258,67 @@ def test_two_levels(): assert np.allclose(ref, B[0]) +def test_multi_nested_containers(): + + M, N = dace.symbol('M'), dace.symbol('N') + sdfg = dace.SDFG('tester') + float_desc = dace.data.Scalar(dace.float32) + E_desc = dace.data.Structure({'F': dace.float32[N], 'G':float_desc}, 'InnerStruct') + B_desc = dace.data.ContainerArray(E_desc, [M]) + A_desc = dace.data.Structure({'B': B_desc, 'C': dace.float32[M], 'D': float_desc}, 'OuterStruct') + sdfg.add_datadesc('A', A_desc) + sdfg.add_datadesc_view('vB', B_desc) + sdfg.add_datadesc_view('vE', E_desc) + sdfg.add_array('out', [M, N], dace.float32) + + state = sdfg.add_state() + rA = state.add_read('A') + vB = state.add_access('vB') + vE = state.add_access('vE') + wout = state.add_write('out') + + me, mx = state.add_map('outer_product', dict(i='0:M', j='0:N')) + tasklet = state.add_tasklet('outer_product', {'__in_A_B_E_F', '__in_A_B_E_G', '__in_A_C', '__in_A_D'}, {'__out'}, + '__out = (__in_A_B_E_F + __in_A_B_E_G) * (__in_A_C + __in_A_D)') + + state.add_edge(rA, None, vB, 'views', dace.Memlet('A.B')) + state.add_memlet_path(vB, me, vE, dst_conn='views', memlet=dace.Memlet('vB[i]')) + state.add_edge(vE, None, tasklet, '__in_A_B_E_F', dace.Memlet('vE.F[j]')) + state.add_edge(vE, None, tasklet, '__in_A_B_E_G', dace.Memlet(data='vE.G', subset='0')) + state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_C', memlet=dace.Memlet('A.C[i]')) + state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_D', memlet=dace.Memlet(data='A.D', subset='0')) + state.add_memlet_path(tasklet, mx, wout, src_conn='__out', memlet=dace.Memlet('out[i, j]')) + + c_data = np.arange(5, dtype=np.float32) + f_data = np.arange(5 * 3, dtype=np.float32).reshape(5, 3) + + e_class = E_desc.dtype._typeclass.as_ctypes() + b_obj = [] + b_data = np.ndarray((5, ), dtype=ctypes.c_void_p) + for i in range(5): + f_obj = f_data[i].__array_interface__['data'][0] + e_obj = e_class(F=f_obj, G=ctypes.c_float(0.1)) + b_obj.append(e_obj) # NOTE: This is needed to keep the object alive ... + b_data[i] = ctypes.addressof(e_obj) + a_dace = A_desc.dtype._typeclass.as_ctypes()(B=b_data.__array_interface__['data'][0], + C=c_data.__array_interface__['data'][0], + D=ctypes.c_float(0.2)) + + + + + out_dace = np.empty((5, 3), dtype=np.float32) + ref = np.empty((5, 3), dtype=np.float32) + for i in range(5): + ref[i] = (f_data[i] + 0.1) * (c_data[i] + 0.2) + + sdfg(A=a_dace, out=out_dace, M=5, N=3) + assert np.allclose(out_dace, ref) + + if __name__ == '__main__': test_read_struct_array() test_write_struct_array() test_jagged_container_array() test_two_levels() + test_multi_nested_containers()