diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..aae28f600f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -69,6 +69,17 @@ def get_shape_args( } +def get_offset_args( + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] +) -> Mapping[str, int]: + return { + str(sym): -drange.start + for param, arg in zip(params, args) + if common.is_field(arg) + for sym, drange in zip(arrays[param.id].offset, arg.domain.ranges) + } + + def get_stride_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: @@ -103,7 +114,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" @@ -113,7 +125,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: **dace_shapes, **dace_conn_shapes, **dace_strides, - **dace_conn_stirdes, + **dace_conn_strides, + **dace_offsets, } expected_args = { key: value diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 56031d8555..f283597fff 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -105,12 +105,19 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec): + def add_storage( + self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, zero_offset: bool = False + ): if isinstance(type_, ts.FieldType): shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + offset = ( + None + if zero_offset + else [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -134,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, zero_offset=True) # Create a nested SDFG for all stencil closures. for closure in node.closures: @@ -229,6 +236,7 @@ def visit_StencilClosure( name, shape=array_table[name].shape, strides=array_table[name].strides, + offset=array_table[name].offset, dtype=array_table[name].dtype, ) @@ -428,6 +436,7 @@ def _visit_scan_stencil_closure( name, shape=array_table[name].shape, strides=array_table[name].strides, + offset=array_table[name].offset, dtype=array_table[name].dtype, ) else: @@ -527,6 +536,7 @@ def _visit_scan_stencil_closure( data_name, shape=(array_table[node.output.id].shape[scan_dim_index],), strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), dtype=array_table[node.output.id].dtype, ) lambda_state.add_memlet_path( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e7a598d9a..80db92ff8b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -990,9 +990,12 @@ def closure_to_tasklet_sdfg( stride = [ dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim) ] + offset = [ + dace.symbol(f"{unique_var_name()}_offset{i}", dtype=dace.int64) for i in range(ndim) + ] dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=stride, dtype=dtype) + body.add_array(name, shape=shape, strides=stride, offset=offset, dtype=dtype) field = state.add_access(name) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..085d132ad5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -47,7 +47,10 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): - bounds = [(0, size) for size in source_array.shape] + bounds = [ + (f"-{offset}", f"{size}-{offset}") + for offset, size in zip(source_array.offset, source_array.shape) + ] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) return dace.Memlet(data=source_identifier, subset=subset) @@ -73,6 +76,10 @@ def map_nested_sdfg_symbols( for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): if isinstance(param_stride, dace.symbol): symbol_mapping[str(param_stride)] = str(arg_stride) + assert len(arg_array.offset) == len(param_array.offset) + for arg_offset, param_offset in zip(arg_array.offset, param_array.offset): + if isinstance(param_offset, dace.symbol): + symbol_mapping[str(param_offset)] = str(arg_offset) else: assert arg.subset.num_elements() == 1 for sym in nested_sdfg.free_symbols: