diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 56035aba12..c8c1f21202 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -459,42 +459,36 @@ def visit_HorizontalExecution( write_memlets=write_memlets, ) - if next(dcir_node.walk_values().if_isinstance(dcir.IndexAccess).iterator, None) is not None: - """ - Special case of tasklet performing array access. The memlet should pass the full array shape - (no slicing) and the tasklet code should use all explicit indexes for array access. - """ - for memlet in [*read_memlets, *write_memlets]: - field_decl = global_ctx.library_node.field_decls[memlet.field] - # calculate array subset from original memlet - memlet_subset = make_dace_subset( - global_ctx.library_node.access_infos[memlet.field], - memlet.access_info, - field_decl.data_dims, - ) - # ensure grid access on single point - memlet_data_index = [ - dcir.Literal(value=str(r[0]), dtype=common.DataType.INT32) - for r, size in zip(memlet_subset, memlet_subset.size()) - if size == 1 - ] - # loop through assignment statements in the tasklet body - tasklet_subset_size = 0 + for memlet in [*read_memlets, *write_memlets]: + array_ndims = len(global_ctx.arrays[memlet.field].shape) + field_decl = global_ctx.library_node.field_decls[memlet.field] + # calculate array subset from original memlet + memlet_subset = make_dace_subset( + global_ctx.library_node.access_infos[memlet.field], + memlet.access_info, + field_decl.data_dims, + ) + # index values for single-point grid access + memlet_data_index = [ + dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32) + for dim_range, dim_size in zip(memlet_subset, memlet_subset.size()) + if dim_size == 1 + ] + if len(memlet_data_index) < array_ndims: + """ + Special case of tasklet performing array access. The memlet should pass the full array shape + (no slicing) and the tasklet code should use all explicit indexes for array access. + """ for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): if access_node.data_index and access_node.name == memlet.connector: - if tasklet_subset_size != 0: - assert len(access_node.data_index) == tasklet_subset_size - else: - tasklet_subset_size = len(access_node.data_index) for idx in reversed(memlet_data_index): access_node.data_index.insert(0, idx) - # reshape memlet if tasklet accessed the endpoint array with partial index - if tasklet_subset_size != 0: - # ensure that memlet symbols used for array subset are defined in context - for sym in memlet.access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) - # set full shape on memlet - memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + assert len(access_node.data_index) == array_ndims + # ensure that memlet symbols used for array indexing are defined in context + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym) + # set full shape on memlet + memlet.access_info = global_ctx.library_node.access_infos[memlet.field] for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop()