Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: DaCe support for field arguments with domain offset #1348

Merged
merged 11 commits into from
Oct 16, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,27 @@
import numpy as np

import gt4py.next.iterator.ir as itir
from gt4py.next import common
from gt4py.next.common import Domain, UnitRange, is_field
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
from gt4py.next.type_system import type_translation

from .itir_to_sdfg import ItirToSDFG
from .utility import connectivity_identifier, filter_neighbor_tables
from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims


def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]:
sorted_dims = get_sorted_dims(domain.dims)
return [domain.ranges[dim_index] for dim_index, _ in sorted_dims]


def convert_arg(arg: Any):
if common.is_field(arg):
sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value)
if is_field(arg):
sorted_dims = get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
dim_indices = [dim[0] for dim in sorted_dims]
dim_indices = [dim_index for dim_index, _ in sorted_dims]
assert isinstance(arg.ndarray, np.ndarray)
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
return arg
Expand Down Expand Up @@ -69,6 +74,17 @@ def get_shape_args(
}


def get_offset_args(
arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
) -> Mapping[str, int]:
return {
str(sym): -drange.start
for param, arg in zip(params, args)
if is_field(arg)
for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
}


def get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
Expand Down Expand Up @@ -103,7 +119,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg.arrays, program.params, args)

sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"

Expand All @@ -113,7 +130,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_stirdes,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,17 @@ def __init__(
self.offset_provider = offset_provider
self.storage_types = {}

def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec):
def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
if isinstance(type_, ts.FieldType):
shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
offset = (
[dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
if has_offset
else None
)
dtype = as_dace_type(type_.dtype)
sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)
sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype)
elif isinstance(type_, ts.ScalarType):
sdfg.add_symbol(name, as_dace_type(type_))
else:
Expand All @@ -134,7 +139,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
self.add_storage(program_sdfg, connectivity_identifier(offset), type_)
self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)

# Create a nested SDFG for all stencil closures.
for closure in node.closures:
Expand Down Expand Up @@ -285,8 +290,8 @@ def visit_StencilClosure(
closure_sdfg.add_array(
nsdfg_output_name,
dtype=output_descriptor.dtype,
shape=(array_table[output_name].shape[scan_dim_index],),
strides=(array_table[output_name].strides[scan_dim_index],),
shape=(output_descriptor.shape[scan_dim_index],),
strides=(output_descriptor.strides[scan_dim_index],),
transient=True,
)

Expand Down Expand Up @@ -527,6 +532,7 @@ def _visit_scan_stencil_closure(
data_name,
shape=(array_table[node.output.id].shape[scan_dim_index],),
strides=(array_table[node.output.id].strides[scan_dim_index],),
offset=(array_table[node.output.id].offset[scan_dim_index],),
dtype=array_table[node.output.id].dtype,
)
lambda_state.add_memlet_path(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from typing import Any
from typing import Any, Sequence

import dace

from gt4py.next import Dimension
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]):
return dace.Memlet(data=source_identifier, subset=subset)


def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]:
return sorted(enumerate(dims), key=lambda v: v[1].value)


def map_nested_sdfg_symbols(
parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet]
) -> dict[str, str]:
Expand Down