Skip to content

Commit

Permalink
feat[next]: Add DaCe support for field arguments with domain offset (#…
Browse files Browse the repository at this point in the history
…1348)

This PR adds support in DaCe backend for field arguments with domain offset. This feature is required by icon4py stencils.
  • Loading branch information
edopao authored Oct 16, 2023
1 parent d07104d commit 45a6e6d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,20 @@
from dace.transformation.auto import auto_optimize as autoopt

import gt4py.next.iterator.ir as itir
from gt4py.next import common
from gt4py.next.common import Domain, UnitRange, is_field
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
from gt4py.next.type_system import type_translation

from .itir_to_sdfg import ItirToSDFG
from .utility import connectivity_identifier, filter_neighbor_tables
from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims


def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]:
sorted_dims = get_sorted_dims(domain.dims)
return [domain.ranges[dim_index] for dim_index, _ in sorted_dims]


""" Default build configuration in DaCe backend """
Expand All @@ -40,10 +45,10 @@


def convert_arg(arg: Any):
if common.is_field(arg):
sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value)
if is_field(arg):
sorted_dims = get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
dim_indices = [dim[0] for dim in sorted_dims]
dim_indices = [dim_index for dim_index, _ in sorted_dims]
assert isinstance(arg.ndarray, np.ndarray)
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
return arg
Expand Down Expand Up @@ -79,6 +84,17 @@ def get_shape_args(
}


def get_offset_args(
arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
) -> Mapping[str, int]:
return {
str(sym): -drange.start
for param, arg in zip(params, args)
if is_field(arg)
for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
}


def get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
Expand Down Expand Up @@ -163,15 +179,17 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg.arrays, program.params, args)

all_args = {
**dace_args,
**dace_conn_args,
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_stirdes,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,17 @@ def __init__(
self.offset_provider = offset_provider
self.storage_types = {}

def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec):
def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
if isinstance(type_, ts.FieldType):
shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
offset = (
[dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
if has_offset
else None
)
dtype = as_dace_type(type_.dtype)
sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)
sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype)
elif isinstance(type_, ts.ScalarType):
sdfg.add_symbol(name, as_dace_type(type_))
else:
Expand All @@ -136,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
self.add_storage(program_sdfg, connectivity_identifier(offset), type_)
self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)

# Create a nested SDFG for all stencil closures.
for closure in node.closures:
Expand Down Expand Up @@ -287,8 +292,8 @@ def visit_StencilClosure(
closure_sdfg.add_array(
nsdfg_output_name,
dtype=output_descriptor.dtype,
shape=(array_table[output_name].shape[scan_dim_index],),
strides=(array_table[output_name].strides[scan_dim_index],),
shape=(output_descriptor.shape[scan_dim_index],),
strides=(output_descriptor.strides[scan_dim_index],),
transient=True,
)

Expand Down Expand Up @@ -528,6 +533,7 @@ def _visit_scan_stencil_closure(
data_name,
shape=(array_table[node.output.id].shape[scan_dim_index],),
strides=(array_table[node.output.id].strides[scan_dim_index],),
offset=(array_table[node.output.id].offset[scan_dim_index],),
dtype=array_table[node.output.id].dtype,
)
lambda_state.add_memlet_path(
Expand Down

0 comments on commit 45a6e6d

Please sign in to comment.