Skip to content

Commit

Permalink
[dace] Add support for field arguments with offset
Browse files Browse the repository at this point in the history
Required by following icon4py stencils:
- TestMoVelocityAdvectionStencil03
- TestMoVelocityAdvectionStencil02VnIe
  • Loading branch information
edopao committed Oct 10, 2023
1 parent 0d821b1 commit 6850597
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ def get_shape_args(
}


def get_offset_args(
arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
) -> Mapping[str, int]:
return {
str(sym): -drange.start
for param, arg in zip(params, args)
if common.is_field(arg)
for sym, drange in zip(arrays[param.id].offset, arg.domain.ranges)
}


def get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
Expand Down Expand Up @@ -103,7 +114,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg.arrays, program.params, args)

sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"

Expand All @@ -113,7 +125,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_stirdes,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,19 @@ def __init__(
self.offset_provider = offset_provider
self.storage_types = {}

def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec):
def add_storage(
self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, zero_offset: bool = False
):
if isinstance(type_, ts.FieldType):
shape = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
offset = (
None
if zero_offset
else [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))]
)
dtype = as_dace_type(type_.dtype)
sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)
sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype)
elif isinstance(type_, ts.ScalarType):
sdfg.add_symbol(name, as_dace_type(type_))
else:
Expand All @@ -134,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
self.add_storage(program_sdfg, connectivity_identifier(offset), type_)
self.add_storage(program_sdfg, connectivity_identifier(offset), type_, zero_offset=True)

# Create a nested SDFG for all stencil closures.
for closure in node.closures:
Expand Down Expand Up @@ -229,6 +236,7 @@ def visit_StencilClosure(
name,
shape=array_table[name].shape,
strides=array_table[name].strides,
offset=array_table[name].offset,
dtype=array_table[name].dtype,
)

Expand Down Expand Up @@ -428,6 +436,7 @@ def _visit_scan_stencil_closure(
name,
shape=array_table[name].shape,
strides=array_table[name].strides,
offset=array_table[name].offset,
dtype=array_table[name].dtype,
)
else:
Expand Down Expand Up @@ -527,6 +536,7 @@ def _visit_scan_stencil_closure(
data_name,
shape=(array_table[node.output.id].shape[scan_dim_index],),
strides=(array_table[node.output.id].strides[scan_dim_index],),
offset=(array_table[node.output.id].offset[scan_dim_index],),
dtype=array_table[node.output.id].dtype,
)
lambda_state.add_memlet_path(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,9 +990,12 @@ def closure_to_tasklet_sdfg(
stride = [
dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim)
]
offset = [
dace.symbol(f"{unique_var_name()}_offset{i}", dtype=dace.int64) for i in range(ndim)
]
dims = [dim.value for dim in ty.dims]
dtype = as_dace_type(ty.dtype)
body.add_array(name, shape=shape, strides=stride, dtype=dtype)
body.add_array(name, shape=shape, strides=stride, offset=offset, dtype=dtype)
field = state.add_access(name)
indices = {dim: idx_accesses[dim] for dim in domain.keys()}
symbol_map[name] = IteratorExpr(field, indices, dtype, dims)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def connectivity_identifier(name: str):


def create_memlet_full(source_identifier: str, source_array: dace.data.Array):
bounds = [(0, size) for size in source_array.shape]
bounds = [
(f"-{offset}", f"{size}-{offset}")
for offset, size in zip(source_array.offset, source_array.shape)
]
subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds)
return dace.Memlet(data=source_identifier, subset=subset)

Expand All @@ -73,6 +76,10 @@ def map_nested_sdfg_symbols(
for arg_stride, param_stride in zip(arg_array.strides, param_array.strides):
if isinstance(param_stride, dace.symbol):
symbol_mapping[str(param_stride)] = str(arg_stride)
assert len(arg_array.offset) == len(param_array.offset)
for arg_offset, param_offset in zip(arg_array.offset, param_array.offset):
if isinstance(param_offset, dace.symbol):
symbol_mapping[str(param_offset)] = str(arg_offset)
else:
assert arg.subset.num_elements() == 1
for sym in nested_sdfg.free_symbols:
Expand Down

0 comments on commit 6850597

Please sign in to comment.