Skip to content

Commit

Permalink
feat[next][dace]: make canonical representation of field domain optio…
Browse files Browse the repository at this point in the history
…nal (GridTools#1476)

Baseline implementation of DaCe backend was reordering the dimensions in
field domain based on alphabetical order. This is the canonical
representation of field domain, and provides the advantage of not
requiring regenerating the SDFG for different memory layouts of field
arguments. Besides, the code for accessing a field is simple, because
all field domains are assumed to follow the same layout.

However, the canonical representation poses an obstacle to the
realization of module-level SDFGs, because it requires an additional
conversion step of all array arguments before calling the SDFG.
Therefore, we make the canonical representation optional. Note that this
change should not have any performance impact, because the real memory
layout of field arrays is not modified.
  • Loading branch information
edopao authored Feb 29, 2024
1 parent 77a205b commit ae9c203
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 98 deletions.
96 changes: 58 additions & 38 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,38 +42,37 @@
cp = None


def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRange]:
assert common.Domain.is_finite(domain)
sorted_dims = get_sorted_dims(domain.dims)
return [domain.ranges[dim_index] for dim_index, _ in sorted_dims]


""" Default build configuration in DaCe backend """
_build_type = "Release"


def convert_arg(arg: Any, sdfg_param: str):
if common.is_field(arg):
# field domain offsets are not supported
non_zero_offsets = [
(dim, dim_range)
for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges)
if dim_range.start != 0
]
if non_zero_offsets:
dim, dim_range = non_zero_offsets[0]
raise RuntimeError(
f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}."
)
sorted_dims = get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
dim_indices = [dim_index for dim_index, _ in sorted_dims]
if isinstance(arg.ndarray, np.ndarray):
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
else:
assert cp is not None and isinstance(arg.ndarray, cp.ndarray)
return cp.moveaxis(arg.ndarray, range(ndim), dim_indices)
return arg
_default_on_gpu = False
_default_use_field_canonical_representation = False


def convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool):
if not common.is_field(arg):
return arg
# field domain offsets are not supported
non_zero_offsets = [
(dim, dim_range)
for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges)
if dim_range.start != 0
]
if non_zero_offsets:
dim, dim_range = non_zero_offsets[0]
raise RuntimeError(
f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}."
)
if not use_field_canonical_representation:
return arg.ndarray
# the canonical representation requires alphabetical ordering of the dimensions in field domain definition
sorted_dims = get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
dim_indices = [dim_index for dim_index, _ in sorted_dims]
if isinstance(arg.ndarray, np.ndarray):
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
else:
assert cp is not None and isinstance(arg.ndarray, cp.ndarray)
return cp.moveaxis(arg.ndarray, range(ndim), dim_indices)


def preprocess_program(
Expand Down Expand Up @@ -107,9 +106,14 @@ def preprocess_program(
return fencil_definition, tmps


def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
def get_args(
sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool
) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
return {sdfg_param: convert_arg(arg, sdfg_param) for sdfg_param, arg in zip(sdfg_params, args)}
return {
sdfg_param: convert_arg(arg, sdfg_param, use_field_canonical_representation)
for sdfg_param, arg in zip(sdfg_params, args)
}


def _ensure_is_on_device(
Expand Down Expand Up @@ -162,8 +166,13 @@ def get_stride_args(
raise ValueError(
f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)."
)
stride_args[str(sym)] = stride

if isinstance(sym, dace.symbol):
assert sym.name not in stride_args
stride_args[str(sym)] = stride
elif sym != stride:
raise RuntimeError(
f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}."
)
return stride_args


Expand Down Expand Up @@ -221,12 +230,15 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) ->
sdfg: The SDFG for which we want to get the arguments.
"""
offset_provider = kwargs["offset_provider"]
on_gpu = kwargs.get("on_gpu", False)
on_gpu = kwargs.get("on_gpu", _default_on_gpu)
use_field_canonical_representation = kwargs.get(
"use_field_canonical_representation", _default_use_field_canonical_representation
)

neighbor_tables = filter_neighbor_tables(offset_provider)
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

dace_args = get_args(sdfg, args)
dace_args = get_args(sdfg, args, use_field_canonical_representation)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
Expand Down Expand Up @@ -261,6 +273,7 @@ def build_sdfg_from_itir(
load_sdfg_from_file: bool = False,
cache_id: Optional[str] = None,
save_sdfg: bool = True,
use_field_canonical_representation: bool = True,
) -> dace.SDFG:
"""Translate a Fencil into an SDFG.
Expand All @@ -275,6 +288,7 @@ def build_sdfg_from_itir(
load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only.
cache_id: The id of the cache entry, used to disambiguate stored sdfgs.
save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`.
use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically.
Notes:
Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored.
Expand All @@ -292,7 +306,9 @@ def build_sdfg_from_itir(

# visit ITIR and generate SDFG
program, tmps = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
sdfg_genenerator = ItirToSDFG(
arg_types, offset_provider, tmps, use_field_canonical_representation, column_axis
)
sdfg = sdfg_genenerator.visit(program)
if sdfg is None:
raise RuntimeError(f"Visit failed for program {program.id}.")
Expand Down Expand Up @@ -343,9 +359,12 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
build_cache = kwargs.get("build_cache", None)
compiler_args = kwargs.get("compiler_args", None) # `None` will take default.
build_type = kwargs.get("build_type", "RelWithDebInfo")
on_gpu = kwargs.get("on_gpu", False)
on_gpu = kwargs.get("on_gpu", _default_on_gpu)
auto_optimize = kwargs.get("auto_optimize", True)
lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE)
use_field_canonical_representation = kwargs.get(
"use_field_canonical_representation", _default_use_field_canonical_representation
)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
offset_provider = kwargs["offset_provider"]
Expand Down Expand Up @@ -374,6 +393,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
load_sdfg_from_file=load_sdfg_from_file,
cache_id=cache_id,
save_sdfg=save_sdfg,
use_field_canonical_representation=use_field_canonical_representation,
)

sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"
Expand Down
Loading

0 comments on commit ae9c203

Please sign in to comment.