From 19aa9836100e521f661a0ee500fa89bc83a7ae42 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 4 Oct 2023 13:18:34 +0200 Subject: [PATCH 01/10] [dace] Remove re-ordering of data layout --- .../runners/dace_iterator/__init__.py | 6 +- .../runners/dace_iterator/itir_to_sdfg.py | 65 ++++++++++--------- .../runners/dace_iterator/itir_to_tasklet.py | 17 ++--- .../runners/dace_iterator/utility.py | 13 ++-- 4 files changed, 50 insertions(+), 51 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..37e59367c1 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -31,11 +31,7 @@ def convert_arg(arg: Any): if common.is_field(arg): - sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) - ndim = len(sorted_dims) - dim_indices = [dim[0] for dim in sorted_dims] - assert isinstance(arg.ndarray, np.ndarray) - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) + return arg.ndarray return arg diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 56031d8555..3e2330023c 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -246,7 +246,7 @@ def visit_StencilClosure( ) access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) - memlet = create_memlet_at(out_name, ("0",)) + memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: @@ -274,7 +274,7 @@ def visit_StencilClosure( transient_to_arg_name_mapping[nsdfg_output_name] = output_name # scan operator should always be the first function call in a closure if is_scan(node.stencil): - nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure( + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( node, closure_sdfg.arrays, closure_domain, nsdfg_output_name ) results = [nsdfg_output_name] @@ -292,15 +292,16 @@ def visit_StencilClosure( output_memlet = create_memlet_at( output_name, - tuple( - f"i_{dim}" - if f"i_{dim}" in map_domain + self.storage_types[output_name], + { + dim: f"i_{dim}" + if f"i_{dim}" in map_ranges else f"0:{output_descriptor.shape[scan_dim_index]}" for dim, _ in closure_domain - ), + }, ) else: - nsdfg, map_domain, results = self._visit_parallel_stencil_closure( + nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) assert len(results) == 1 @@ -313,7 +314,11 @@ def visit_StencilClosure( transient=True, ) - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) + output_memlet = create_memlet_at( + output_name, + self.storage_types[output_name], + {dim: f"i_{dim}" for dim, _ in closure_domain}, + ) input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} @@ -325,7 +330,7 @@ def visit_StencilClosure( nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, - map_ranges=map_domain or {"__dummy": "0"}, + map_ranges=map_ranges or {"__dummy": "0"}, inputs=array_mapping, outputs=output_mapping, 
symbol_mapping=symbol_mapping, @@ -341,10 +346,10 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet(data=memlet.data, subset=output_subset), + dace.Memlet.simple(memlet.data, output_subset), ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset + inner_memlet = dace.Memlet.simple( + memlet.data, output_subset, other_subset_str=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) @@ -360,7 +365,7 @@ def visit_StencilClosure( None, map_entry, b.value.data, - create_memlet_at(b.value.data, ("0",)), + dace.Memlet.simple(b.value.data, "0"), ) return closure_sdfg @@ -390,12 +395,12 @@ def _visit_scan_stencil_closure( connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value if not dim == scan_dim: - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" else: scan_lb_str = lb_str scan_ub_str = ub_str @@ -481,29 +486,29 @@ def _visit_scan_stencil_closure( "__result", carry_node1, None, - dace.Memlet(data=f"{scan_carry_name}", subset="0"), + dace.Memlet.simple(scan_carry_name, "0"), ) carry_node2 = lambda_state.add_access(scan_carry_name) lambda_state.add_memlet_path( carry_node2, scan_inner_node, - memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"), + memlet=dace.Memlet.simple(scan_carry_name, "0"), src_conn=None, dst_conn=lambda_carry_name, ) # connect access nodes to lambda inputs for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - data_subset = ( - ", ".join([f"i_{dim}" for dim, _ in closure_domain]) - if isinstance(self.storage_types[data_name], ts.FieldType) - else "0" - ) + if isinstance(self.storage_types[data_name], ts.FieldType): + index = {dim: f"i_{dim}" for dim, _ in closure_domain} + memlet = create_memlet_at(data_name, self.storage_types[data_name], index) + else: + memlet = dace.Memlet.simple(data_name, "0") lambda_state.add_memlet_path( lambda_state.add_access(data_name), scan_inner_node, - memlet=dace.Memlet(data=f"{data_name}", subset=data_subset), + memlet=memlet, src_conn=None, dst_conn=inner_name, ) @@ -532,7 +537,7 @@ def _visit_scan_stencil_closure( lambda_state.add_memlet_path( scan_inner_node, lambda_state.add_access(data_name), - memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"), + memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, ) @@ -544,10 +549,10 @@ def _visit_scan_stencil_closure( lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"), + memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) - return scan_sdfg, map_domain, scan_dim_index + return scan_sdfg, map_ranges, scan_dim_index def _visit_parallel_stencil_closure( self, @@ -562,11 +567,11 @@ def _visit_parallel_stencil_closure( conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = 
lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" # Create an SDFG for the tasklet that computes a single item of the output domain. index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} @@ -583,7 +588,7 @@ def _visit_parallel_stencil_closure( self.node_types, ) - return context.body, map_domain, [r.value.data for r in results] + return context.body, map_ranges, [r.value.data for r in results] def _visit_domain( self, node: itir.FunCall, context: Context @@ -606,7 +611,7 @@ def _visit_domain( ub = translator.visit(upper_bound)[0] bounds.append((dimension.value, (lb, ub))) - return tuple(sorted(bounds, key=lambda item: item[0])) + return tuple(bounds) @staticmethod def _check_no_lifts(node: itir.StencilClosure): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e7a598d9a..c4ee342080 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -34,7 +34,6 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, - create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, @@ -595,9 +594,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - array_index = [ + flat_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in sorted(iterator.dimensions) + for dim in iterator.dimensions ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices @@ -608,7 +607,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: name="deref", inputs=set(internals), outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", ) for arg, internal in zip(args, internals): @@ -630,12 +629,10 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] else: - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions ] - - args = [ValueExpr(iterator.field, int), *flat_index] + args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") @@ -849,7 +846,7 @@ def _visit_reduce(self, node: itir.FunCall): p.apply_pass(lambda_context.body, {}) input_memlets = [ - create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) ] output_memlet = dace.Memlet.simple(result_name, "0") @@ -928,7 +925,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = create_memlet_at(result_access.data, ("0",)) + memlet = 
dace.Memlet.simple(result_access.data, "0") self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..3134e793e3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -11,8 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any +from typing import Any, cast import dace @@ -49,12 +48,14 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): bounds = [(0, size) for size in source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) -def create_memlet_at(source_identifier: str, index: tuple[str, ...]): - subset = ", ".join(index) - return dace.Memlet(data=source_identifier, subset=subset) +def create_memlet_at(source_identifier: str, storage_type: ts.TypeSpec, index: dict[str, str]): + field_type = cast(ts.FieldType, storage_type) + field_index = [index[dim.value] for dim in field_type.dims] + subset = ", ".join(field_index) + return dace.Memlet.simple(source_identifier, subset) def map_nested_sdfg_symbols( From 6850597b05a3c0937d578afc5128715f4ab32961 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 10 Oct 2023 16:06:08 +0200 Subject: [PATCH 02/10] [dace] Add support for field arguments with offset Required by following icon4py stencils: - TestMoVelocityAdvectionStencil03 - TestMoVelocityAdvectionStencil02VnIe --- .../runners/dace_iterator/__init__.py | 17 +++++++++++++++-- .../runners/dace_iterator/itir_to_sdfg.py | 16 +++++++++++++--- .../runners/dace_iterator/itir_to_tasklet.py | 5 ++++- .../runners/dace_iterator/utility.py | 9 ++++++++- 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..aae28f600f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -69,6 +69,17 @@ def get_shape_args( } +def get_offset_args( + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] +) -> Mapping[str, int]: + return { + str(sym): -drange.start + for param, arg in zip(params, args) + if common.is_field(arg) + for sym, drange in zip(arrays[param.id].offset, arg.domain.ranges) + } + + def get_stride_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: @@ -103,7 +114,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" @@ -113,7 +125,8 @@ def run_dace_iterator(program: 
itir.FencilDefinition, *args, **kwargs) -> None: **dace_shapes, **dace_conn_shapes, **dace_strides, - **dace_conn_stirdes, + **dace_conn_strides, + **dace_offsets, } expected_args = { key: value diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 56031d8555..f283597fff 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -105,12 +105,19 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec): + def add_storage( + self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, zero_offset: bool = False + ): if isinstance(type_, ts.FieldType): shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + offset = ( + None + if zero_offset + else [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -134,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, zero_offset=True) # Create a nested SDFG for all stencil closures. 
for closure in node.closures: @@ -229,6 +236,7 @@ def visit_StencilClosure( name, shape=array_table[name].shape, strides=array_table[name].strides, + offset=array_table[name].offset, dtype=array_table[name].dtype, ) @@ -428,6 +436,7 @@ def _visit_scan_stencil_closure( name, shape=array_table[name].shape, strides=array_table[name].strides, + offset=array_table[name].offset, dtype=array_table[name].dtype, ) else: @@ -527,6 +536,7 @@ def _visit_scan_stencil_closure( data_name, shape=(array_table[node.output.id].shape[scan_dim_index],), strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), dtype=array_table[node.output.id].dtype, ) lambda_state.add_memlet_path( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e7a598d9a..80db92ff8b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -990,9 +990,12 @@ def closure_to_tasklet_sdfg( stride = [ dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim) ] + offset = [ + dace.symbol(f"{unique_var_name()}_offset{i}", dtype=dace.int64) for i in range(ndim) + ] dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=stride, dtype=dtype) + body.add_array(name, shape=shape, strides=stride, offset=offset, dtype=dtype) field = state.add_access(name) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..085d132ad5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -47,7 +47,10 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): - bounds = [(0, size) for size in source_array.shape] + bounds = [ + (f"-{offset}", f"{size}-{offset}") + for offset, size in zip(source_array.offset, source_array.shape) + ] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) return dace.Memlet(data=source_identifier, subset=subset) @@ -73,6 +76,10 @@ def map_nested_sdfg_symbols( for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): if isinstance(param_stride, dace.symbol): symbol_mapping[str(param_stride)] = str(arg_stride) + assert len(arg_array.offset) == len(param_array.offset) + for arg_offset, param_offset in zip(arg_array.offset, param_array.offset): + if isinstance(param_offset, dace.symbol): + symbol_mapping[str(param_offset)] = str(arg_offset) else: assert arg.subset.num_elements() == 1 for sym in nested_sdfg.free_symbols: From 397c3762748b710a9e5cb8bfbdf97a7d15095747 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Oct 2023 11:22:54 +0200 Subject: [PATCH 03/10] [dace] Add support for tuple arguments --- .../runners/dace_iterator/__init__.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 42a4657d91..d5565071c3 100644 --- 
a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -15,7 +15,6 @@ from typing import Any, Mapping, Sequence import dace -import numpy as np import gt4py.next.iterator.ir as itir from gt4py.next import common @@ -31,8 +30,8 @@ def convert_arg(arg: Any): if common.is_field(arg): - return arg.ndarray - return arg + return (arg.ndarray, arg.domain) + return (arg, None) def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]): @@ -45,8 +44,25 @@ def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[ return program +def expand_tuple_arg(name: str, arg: tuple) -> dict[str, Any]: + t = {} + for idx, member_arg in enumerate(arg): + member_name = f"{name}_{idx}" + if isinstance(member_arg, tuple): + t.update(expand_tuple_arg(member_name, member_arg)) + else: + t[member_name] = convert_arg(member_arg) + return t + + def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: - return {name.id: convert_arg(arg) for name, arg in zip(params, args)} + t = {} + for param, arg in zip(params, args): + if isinstance(arg, tuple): + t.update(expand_tuple_arg(param.id, arg)) + else: + t[param.id] = convert_arg(arg) + return t def get_connectivity_args( @@ -66,13 +82,12 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] + arrays: Mapping[str, dace.data.Array], field_domains: Mapping[str, Any] ) -> Mapping[str, int]: return { str(sym): -drange.start - for param, arg in zip(params, args) - if common.is_field(arg) - for sym, drange in zip(arrays[param.id].offset, arg.domain.ranges) + for name, domain in field_domains.items() + for sym, drange in zip(arrays[name].offset, domain.ranges) } @@ -105,18 +120,22 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg.simplify() dace_args = get_args(program.params, args) - dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + # domain is only set for field arguments + dace_field_args = {n: v for n, (v, d) in dace_args.items() if d} + dace_field_domains = {n: d for n, (v, d) in dace_args.items() if d} + dace_scalar_args = {n: v for n, (v, d) in dace_args.items() if d is None} dace_conn_args = get_connectivity_args(neighbor_tables) dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, program.params, args) + dace_offsets = get_offset_args(sdfg.arrays, dace_field_domains) sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" all_args = { - **dace_args, + **dace_field_args, + **dace_scalar_args, **dace_conn_args, **dace_shapes, **dace_conn_shapes, From 405dad600480f00e42bdf3a728fe8f906b81584f Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Oct 2023 15:10:13 +0200 Subject: [PATCH 04/10] [dace] Add offset to scan range --- .../runners/dace_iterator/itir_to_sdfg.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 78350eb94b..42e8786c14 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -290,21 +290,15 @@ def visit_StencilClosure( _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] output_subset = f"{scan_lb.value}:{scan_ub.value}" - closure_sdfg.add_array( - nsdfg_output_name, - dtype=output_descriptor.dtype, - shape=(array_table[output_name].shape[scan_dim_index],), - strides=(array_table[output_name].strides[scan_dim_index],), - transient=True, - ) - + scan_offset = output_descriptor.offset[scan_dim_index] + scan_shape = output_descriptor.shape[scan_dim_index] output_memlet = create_memlet_at( output_name, self.storage_types[output_name], { dim: f"i_{dim}" if f"i_{dim}" in map_ranges - else f"0:{output_descriptor.shape[scan_dim_index]}" + else f"{scan_offset}:{scan_offset}+{scan_shape}" for dim, _ in closure_domain }, ) @@ -316,12 +310,6 @@ def visit_StencilClosure( output_subset = "0" - closure_sdfg.add_scalar( - nsdfg_output_name, - dtype=output_descriptor.dtype, - transient=True, - ) - output_memlet = create_memlet_at( output_name, self.storage_types[output_name], From 3ae9478d9ab9d3aab3c55b10a6c8d0da781a3891 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Oct 2023 16:19:07 +0200 Subject: [PATCH 05/10] Revert "[dace] Add offset to scan range" This reverts commit 405dad600480f00e42bdf3a728fe8f906b81584f. --- .../runners/dace_iterator/itir_to_sdfg.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 42e8786c14..78350eb94b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -290,15 +290,21 @@ def visit_StencilClosure( _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] output_subset = f"{scan_lb.value}:{scan_ub.value}" - scan_offset = output_descriptor.offset[scan_dim_index] - scan_shape = output_descriptor.shape[scan_dim_index] + closure_sdfg.add_array( + nsdfg_output_name, + dtype=output_descriptor.dtype, + shape=(array_table[output_name].shape[scan_dim_index],), + strides=(array_table[output_name].strides[scan_dim_index],), + transient=True, + ) + output_memlet = create_memlet_at( output_name, self.storage_types[output_name], { dim: f"i_{dim}" if f"i_{dim}" in map_ranges - else f"{scan_offset}:{scan_offset}+{scan_shape}" + else f"0:{output_descriptor.shape[scan_dim_index]}" for dim, _ in closure_domain }, ) @@ -310,6 +316,12 @@ def visit_StencilClosure( output_subset = "0" + closure_sdfg.add_scalar( + nsdfg_output_name, + dtype=output_descriptor.dtype, + transient=True, + ) + output_memlet = create_memlet_at( output_name, self.storage_types[output_name], From 8f2190c230dc9b86175c20900bb6a4a573f3e4d2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 11 Oct 2023 16:20:36 +0200 Subject: [PATCH 06/10] [dace] Fix - do not propagate offset to nested-SDFG --- .../runners/dace_iterator/itir_to_sdfg.py | 6 ++---- .../runners/dace_iterator/itir_to_tasklet.py | 5 +---- .../program_processors/runners/dace_iterator/utility.py | 9 +-------- 3 files changed, 4 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 78350eb94b..263d6cdc36 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -236,7 +236,6 @@ def visit_StencilClosure( name, shape=array_table[name].shape, strides=array_table[name].strides, - offset=array_table[name].offset, dtype=array_table[name].dtype, ) @@ -293,8 +292,8 @@ def visit_StencilClosure( closure_sdfg.add_array( nsdfg_output_name, dtype=output_descriptor.dtype, - shape=(array_table[output_name].shape[scan_dim_index],), - strides=(array_table[output_name].strides[scan_dim_index],), + shape=(output_descriptor.shape[scan_dim_index],), + strides=(output_descriptor.strides[scan_dim_index],), transient=True, ) @@ -441,7 +440,6 @@ def _visit_scan_stencil_closure( name, shape=array_table[name].shape, strides=array_table[name].strides, - offset=array_table[name].offset, dtype=array_table[name].dtype, ) else: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 0b8464b1d8..c4ee342080 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -987,12 +987,9 @@ def closure_to_tasklet_sdfg( stride = [ dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim) ] - offset = [ - dace.symbol(f"{unique_var_name()}_offset{i}", dtype=dace.int64) for i in range(ndim) - ] dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=stride, offset=offset, dtype=dtype) + body.add_array(name, shape=shape, strides=stride, dtype=dtype) field = state.add_access(name) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index e4f4ad5f3b..3134e793e3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -46,10 +46,7 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): - bounds = [ - (f"-{offset}", f"{size}-{offset}") - for offset, size in zip(source_array.offset, source_array.shape) - ] + bounds = [(0, size) for size in source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) return dace.Memlet.simple(source_identifier, subset) @@ -77,10 +74,6 @@ def map_nested_sdfg_symbols( for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): if isinstance(param_stride, dace.symbol): symbol_mapping[str(param_stride)] = str(arg_stride) - assert len(arg_array.offset) == len(param_array.offset) - for arg_offset, param_offset in zip(arg_array.offset, param_array.offset): - if isinstance(param_offset, dace.symbol): - symbol_mapping[str(param_offset)] = str(arg_offset) else: assert arg.subset.num_elements() == 1 for sym in nested_sdfg.free_symbols: From e7e0064a254c11ed49552dd5114c0a338b366680 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Oct 2023 13:53:54 +0200 Subject: [PATCH 07/10] Revert "[dace] Add support for tuple arguments" This reverts commit 397c3762748b710a9e5cb8bfbdf97a7d15095747. 
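For reference, the functionality dropped by this revert is the recursive flattening of tuple arguments into one flat entry per member. A minimal standalone sketch of that naming scheme, independent of the DaCe runner (the helper name and the example values below are illustrative only):

    from typing import Any

    def flatten_tuple_arg(name: str, arg: Any) -> dict[str, Any]:
        # Non-tuple members are returned under their own name unchanged;
        # nested tuples expand recursively into "<name>_<index>" entries.
        if not isinstance(arg, tuple):
            return {name: arg}
        flat: dict[str, Any] = {}
        for idx, member in enumerate(arg):
            flat.update(flatten_tuple_arg(f"{name}_{idx}", member))
        return flat

    # A nested tuple argument expands into three flat entries.
    assert flatten_tuple_arg("a", (1, (2, 3))) == {"a_0": 1, "a_1_0": 2, "a_1_1": 3}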
--- .../runners/dace_iterator/__init__.py | 41 +++++-------------- 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index d5565071c3..42a4657d91 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -15,6 +15,7 @@ from typing import Any, Mapping, Sequence import dace +import numpy as np import gt4py.next.iterator.ir as itir from gt4py.next import common @@ -30,8 +31,8 @@ def convert_arg(arg: Any): if common.is_field(arg): - return (arg.ndarray, arg.domain) - return (arg, None) + return arg.ndarray + return arg def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[str, Any]): @@ -44,25 +45,8 @@ def preprocess_program(program: itir.FencilDefinition, offset_provider: Mapping[ return program -def expand_tuple_arg(name: str, arg: tuple) -> dict[str, Any]: - t = {} - for idx, member_arg in enumerate(arg): - member_name = f"{name}_{idx}" - if isinstance(member_arg, tuple): - t.update(expand_tuple_arg(member_name, member_arg)) - else: - t[member_name] = convert_arg(member_arg) - return t - - def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: - t = {} - for param, arg in zip(params, args): - if isinstance(arg, tuple): - t.update(expand_tuple_arg(param.id, arg)) - else: - t[param.id] = convert_arg(arg) - return t + return {name.id: convert_arg(arg) for name, arg in zip(params, args)} def get_connectivity_args( @@ -82,12 +66,13 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], field_domains: Mapping[str, Any] + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] ) -> Mapping[str, int]: return { str(sym): -drange.start - for name, domain in field_domains.items() - for sym, drange in zip(arrays[name].offset, domain.ranges) + for param, arg in zip(params, args) + if common.is_field(arg) + for sym, drange in zip(arrays[param.id].offset, arg.domain.ranges) } @@ -120,22 +105,18 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: sdfg.simplify() dace_args = get_args(program.params, args) - # domain is only set for field arguments - dace_field_args = {n: v for n, (v, d) in dace_args.items() if d} - dace_field_domains = {n: d for n, (v, d) in dace_args.items() if d} - dace_scalar_args = {n: v for n, (v, d) in dace_args.items() if d is None} + dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_connectivity_args(neighbor_tables) dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, dace_field_domains) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" all_args = { - **dace_field_args, - **dace_scalar_args, + **dace_args, **dace_conn_args, **dace_shapes, **dace_conn_shapes, From c3bc4a2ee78e29d57fa2102e014bdb72712257d8 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Oct 2023 13:54:48 +0200 Subject: [PATCH 08/10] Revert "[dace] Remove re-ordering of data layout" This reverts commit 19aa9836100e521f661a0ee500fa89bc83a7ae42. 
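This revert restores the canonicalization of field buffers on the Python side: dimensions are sorted by name and the ndarray axes are permuted with np.moveaxis before the buffer is handed to the SDFG. A minimal sketch of that conversion path (the function name and the 2-D example are illustrative, not runner API; plain strings stand in for gt4py Dimension objects):

    import numpy as np

    def to_canonical_layout(ndarray: np.ndarray, dim_names: list[str]) -> np.ndarray:
        # Same permutation logic as the restored convert_arg: enumerate the
        # dimensions, sort them by name, and feed the index order to moveaxis.
        sorted_dims = sorted(enumerate(dim_names), key=lambda v: v[1])
        dim_indices = [index for index, _ in sorted_dims]
        return np.moveaxis(ndarray, range(ndarray.ndim), dim_indices)

    # For a 2-D field with axes named ("K", "C"), the axes come out as ("C", "K").
    field_buffer = np.zeros((3, 5))      # axis 0 -> "K" (3), axis 1 -> "C" (5)
    assert to_canonical_layout(field_buffer, ["K", "C"]).shape == (5, 3)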
--- .../runners/dace_iterator/__init__.py | 6 +- .../runners/dace_iterator/itir_to_sdfg.py | 65 +++++++++---------- .../runners/dace_iterator/itir_to_tasklet.py | 17 +++-- .../runners/dace_iterator/utility.py | 13 ++-- 4 files changed, 51 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 42a4657d91..aae28f600f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -31,7 +31,11 @@ def convert_arg(arg: Any): if common.is_field(arg): - return arg.ndarray + sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) + ndim = len(sorted_dims) + dim_indices = [dim[0] for dim in sorted_dims] + assert isinstance(arg.ndarray, np.ndarray) + return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 263d6cdc36..1c018f6485 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -253,7 +253,7 @@ def visit_StencilClosure( ) access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) - memlet = dace.Memlet.simple(out_name, "0") + memlet = create_memlet_at(out_name, ("0",)) closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: @@ -281,7 +281,7 @@ def visit_StencilClosure( transient_to_arg_name_mapping[nsdfg_output_name] = output_name # scan operator should always be the first function call in a closure if is_scan(node.stencil): - nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( + nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure( node, closure_sdfg.arrays, closure_domain, nsdfg_output_name ) results = [nsdfg_output_name] @@ -299,16 +299,15 @@ def visit_StencilClosure( output_memlet = create_memlet_at( output_name, - self.storage_types[output_name], - { - dim: f"i_{dim}" - if f"i_{dim}" in map_ranges + tuple( + f"i_{dim}" + if f"i_{dim}" in map_domain else f"0:{output_descriptor.shape[scan_dim_index]}" for dim, _ in closure_domain - }, + ), ) else: - nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( + nsdfg, map_domain, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) assert len(results) == 1 @@ -321,11 +320,7 @@ def visit_StencilClosure( transient=True, ) - output_memlet = create_memlet_at( - output_name, - self.storage_types[output_name], - {dim: f"i_{dim}" for dim, _ in closure_domain}, - ) + output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} @@ -337,7 +332,7 @@ def visit_StencilClosure( nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, - map_ranges=map_ranges or {"__dummy": "0"}, + map_ranges=map_domain or {"__dummy": "0"}, inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, @@ -353,10 +348,10 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet.simple(memlet.data, output_subset), + 
dace.Memlet(data=memlet.data, subset=output_subset), ) - inner_memlet = dace.Memlet.simple( - memlet.data, output_subset, other_subset_str=memlet.subset + inner_memlet = dace.Memlet( + data=memlet.data, subset=output_subset, other_subset=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) @@ -372,7 +367,7 @@ def visit_StencilClosure( None, map_entry, b.value.data, - dace.Memlet.simple(b.value.data, "0"), + create_memlet_at(b.value.data, ("0",)), ) return closure_sdfg @@ -402,12 +397,12 @@ def _visit_scan_stencil_closure( connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} + map_domain = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value if not dim == scan_dim: - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" else: scan_lb_str = lb_str scan_ub_str = ub_str @@ -493,29 +488,29 @@ def _visit_scan_stencil_closure( "__result", carry_node1, None, - dace.Memlet.simple(scan_carry_name, "0"), + dace.Memlet(data=f"{scan_carry_name}", subset="0"), ) carry_node2 = lambda_state.add_access(scan_carry_name) lambda_state.add_memlet_path( carry_node2, scan_inner_node, - memlet=dace.Memlet.simple(scan_carry_name, "0"), + memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"), src_conn=None, dst_conn=lambda_carry_name, ) # connect access nodes to lambda inputs for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - if isinstance(self.storage_types[data_name], ts.FieldType): - index = {dim: f"i_{dim}" for dim, _ in closure_domain} - memlet = create_memlet_at(data_name, self.storage_types[data_name], index) - else: - memlet = dace.Memlet.simple(data_name, "0") + data_subset = ( + ", ".join([f"i_{dim}" for dim, _ in closure_domain]) + if isinstance(self.storage_types[data_name], ts.FieldType) + else "0" + ) lambda_state.add_memlet_path( lambda_state.add_access(data_name), scan_inner_node, - memlet=memlet, + memlet=dace.Memlet(data=f"{data_name}", subset=data_subset), src_conn=None, dst_conn=inner_name, ) @@ -545,7 +540,7 @@ def _visit_scan_stencil_closure( lambda_state.add_memlet_path( scan_inner_node, lambda_state.add_access(data_name), - memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), + memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, ) @@ -557,10 +552,10 @@ def _visit_scan_stencil_closure( lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), + memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"), ) - return scan_sdfg, map_ranges, scan_dim_index + return scan_sdfg, map_domain, scan_dim_index def _visit_parallel_stencil_closure( self, @@ -575,11 +570,11 @@ def _visit_parallel_stencil_closure( conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} + map_domain = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" + 
map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" # Create an SDFG for the tasklet that computes a single item of the output domain. index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} @@ -596,7 +591,7 @@ def _visit_parallel_stencil_closure( self.node_types, ) - return context.body, map_ranges, [r.value.data for r in results] + return context.body, map_domain, [r.value.data for r in results] def _visit_domain( self, node: itir.FunCall, context: Context @@ -619,7 +614,7 @@ def _visit_domain( ub = translator.visit(upper_bound)[0] bounds.append((dimension.value, (lb, ub))) - return tuple(bounds) + return tuple(sorted(bounds, key=lambda item: item[0])) @staticmethod def _check_no_lifts(node: itir.StencilClosure): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index c4ee342080..2e7a598d9a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -34,6 +34,7 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, + create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, @@ -594,9 +595,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - flat_index = [ + array_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in iterator.dimensions + for dim in sorted(iterator.dimensions) ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices @@ -607,7 +608,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: name="deref", inputs=set(internals), outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", + code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", ) for arg, internal in zip(args, internals): @@ -629,10 +630,12 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] else: + sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) flat_index = [ - ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions + ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in iterator.dimensions ] - args = [ValueExpr(iterator.field, iterator.dtype), *flat_index] + + args = [ValueExpr(iterator.field, int), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") @@ -846,7 +849,7 @@ def _visit_reduce(self, node: itir.FunCall): p.apply_pass(lambda_context.body, {}) input_memlets = [ - dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) + create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) ] output_memlet = dace.Memlet.simple(result_name, "0") @@ -925,7 +928,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet.simple(result_access.data, "0") + memlet = create_memlet_at(result_access.data, ("0",)) self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] diff --git 
a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 3134e793e3..889a1ab150 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -11,7 +11,8 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, cast + +from typing import Any import dace @@ -48,14 +49,12 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): bounds = [(0, size) for size in source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet.simple(source_identifier, subset) + return dace.Memlet(data=source_identifier, subset=subset) -def create_memlet_at(source_identifier: str, storage_type: ts.TypeSpec, index: dict[str, str]): - field_type = cast(ts.FieldType, storage_type) - field_index = [index[dim.value] for dim in field_type.dims] - subset = ", ".join(field_index) - return dace.Memlet.simple(source_identifier, subset) +def create_memlet_at(source_identifier: str, index: tuple[str, ...]): + subset = ", ".join(index) + return dace.Memlet(data=source_identifier, subset=subset) def map_nested_sdfg_symbols( From 5a198dfed000cc3a22f93d6de6aaccbd6bc63709 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Oct 2023 15:34:06 +0200 Subject: [PATCH 09/10] [dace] Maintain canonical representation of field domain Keep alphabetical order of dimensions in field domain. --- .../runners/dace_iterator/__init__.py | 15 ++++++++++++--- .../runners/dace_iterator/itir_to_sdfg.py | 12 +++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index aae28f600f..13eab0b74b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -29,11 +29,20 @@ from .utility import connectivity_identifier, filter_neighbor_tables +def get_sorted_dims(dims: Sequence[common.Dimension]) -> Sequence[tuple[int, common.Dimension]]: + return sorted(enumerate(dims), key=lambda v: v[1].value) + + +def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.UnitRange]: + sorted_dims = get_sorted_dims(domain.dims) + return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] + + def convert_arg(arg: Any): if common.is_field(arg): - sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) + sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) - dim_indices = [dim[0] for dim in sorted_dims] + dim_indices = [dim_index for dim_index, _ in sorted_dims] assert isinstance(arg.ndarray, np.ndarray) return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -76,7 +85,7 @@ def get_offset_args( str(sym): -drange.start for param, arg in zip(params, args) if common.is_field(arg) - for sym, drange in zip(arrays[param.id].offset, arg.domain.ranges) + for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) } diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 1c018f6485..e350672061 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ 
b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -105,16 +105,14 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage( - self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, zero_offset: bool = False - ): + def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] offset = ( - None - if zero_offset - else [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + if has_offset + else None ) dtype = as_dace_type(type_.dtype) sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) @@ -141,7 +139,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_, zero_offset=True) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) # Create a nested SDFG for all stencil closures. for closure in node.closures: From d560231619786f929495fd31c8c5ed2d3fc259e2 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 13 Oct 2023 15:43:36 +0200 Subject: [PATCH 10/10] [dace] Make utility get_sorted_dims --- .../runners/dace_iterator/__init__.py | 14 +++++--------- .../runners/dace_iterator/utility.py | 7 ++++++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 13eab0b74b..85de0dc5f4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -18,7 +18,7 @@ import numpy as np import gt4py.next.iterator.ir as itir -from gt4py.next import common +from gt4py.next.common import Domain, UnitRange, is_field from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache @@ -26,20 +26,16 @@ from gt4py.next.type_system import type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import connectivity_identifier, filter_neighbor_tables +from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims -def get_sorted_dims(dims: Sequence[common.Dimension]) -> Sequence[tuple[int, common.Dimension]]: - return sorted(enumerate(dims), key=lambda v: v[1].value) - - -def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.UnitRange]: +def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] def convert_arg(arg: Any): - if common.is_field(arg): + if is_field(arg): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] @@ -84,7 +80,7 @@ def get_offset_args( return { str(sym): -drange.start for param, arg in zip(params, args) - if common.is_field(arg) + if is_field(arg) for sym, drange in 
zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) } diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..e12a20e8ad 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,10 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any +from typing import Any, Sequence import dace +from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]): return dace.Memlet(data=source_identifier, subset=subset) +def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: + return sorted(enumerate(dims), key=lambda v: v[1].value) + + def map_nested_sdfg_symbols( parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] ) -> dict[str, str]:
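Taken together with patch 09, the argument path that survives this series is: sort each field's dimensions alphabetically, reorder the buffer accordingly, and pass one array-offset symbol per dimension equal to the negated lower bound of the corresponding domain range. A minimal standalone sketch of the offset computation only (the Range class, dimension names, and bounds below are made up for illustration; the real code reads gt4py domain ranges via get_sorted_dim_ranges):

    from dataclasses import dataclass

    @dataclass
    class Range:
        start: int
        stop: int

    # Hypothetical field domain: dimensions in declaration order, with
    # nonzero lower bounds as produced by a restricted field domain.
    dims = ["K", "C"]
    ranges = {"K": Range(1, 81), "C": Range(3, 19)}

    # One offset per dimension, taken in canonical (alphabetical) order,
    # matching how the patch pairs arrays[...].offset symbols with the
    # sorted domain ranges in get_offset_args.
    sorted_dims = sorted(dims)
    offsets = [-ranges[d].start for d in sorted_dims]
    assert offsets == [-3, -1]          # offsets for ("C", "K")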