feat[next][dace]: Several features for icon4py-backend integration (GridTools#1525)

This PR includes several small features, all related to the integration of the
ITIR-DaCe backend in the icon4py workflow:

- Separate the generation of the SDFG from the workflow stage for ITIR
translation. This is needed both for the Liskov bindings in icon-dsl and
for DaCe orchestration.
- Remove the translation of StridedNeighborOffsetProvider (as in the GTFN
backend), because it is not supported by the Liskov bindings.
- Add ITIR-DaCe CPU/GPU backends that run the temporaries pass on ITIR (a
minimal selection sketch follows this list).
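For illustration, here is a minimal, hedged sketch of how one of the new temporaries-enabled backends could be selected for a gt4py.next program. The field operator, the dimension, and the decorator usage are assumptions for the example, not part of this commit:

    import gt4py.next as gtx
    from gt4py.next.program_processors.runners.dace import run_dace_cpu_with_temporaries

    IDim = gtx.Dimension("IDim")  # hypothetical dimension, for the example only

    @gtx.field_operator
    def copy(a: gtx.Field[[IDim], gtx.float64]) -> gtx.Field[[IDim], gtx.float64]:
        return a

    # The new backend is passed like any other gt4py.next backend.
    @gtx.program(backend=run_dace_cpu_with_temporaries)
    def copy_program(a: gtx.Field[[IDim], gtx.float64], out: gtx.Field[[IDim], gtx.float64]):
        copy(a, out=out)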
edopao authored Apr 10, 2024
1 parent 609a5c9 commit 705530c
Showing 7 changed files with 89 additions and 64 deletions.
6 changes: 6 additions & 0 deletions src/gt4py/next/program_processors/runners/dace.py
@@ -83,5 +83,11 @@ class Params:


run_dace_cpu = DaCeBackendFactory(cached=True, auto_optimize=True)
run_dace_cpu_with_temporaries = DaCeBackendFactory(
cached=True, auto_optimize=True, use_temporaries=True
)

run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True)
run_dace_gpu_with_temporaries = DaCeBackendFactory(
gpu=True, cached=True, auto_optimize=True, use_temporaries=True
)
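As a side note, the same factory could in principle be used to derive further variants; the flag combination below is an assumption for illustration only and is not added by this commit:

    # Hypothetical extra variant: temporaries enabled, DaCe auto-optimization disabled.
    run_dace_cpu_with_temporaries_no_autoopt = DaCeBackendFactory(
        cached=True, auto_optimize=False, use_temporaries=True
    )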
@@ -14,7 +14,7 @@
import warnings
from inspect import currentframe, getframeinfo
from pathlib import Path
from typing import Any, Mapping, Optional, Sequence
from typing import Any, Callable, Mapping, Optional, Sequence

import dace
import numpy as np
@@ -24,10 +24,10 @@
import gt4py.next.iterator.ir as itir
from gt4py.next import common
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.type_system import type_translation
from gt4py.next.type_system import type_specifications as ts

from .itir_to_sdfg import ItirToSDFG
from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims
from .utility import connectivity_identifier, filter_connectivities, get_sorted_dims


try:
@@ -67,6 +67,10 @@ def preprocess_program(
program: itir.FencilDefinition,
offset_provider: Mapping[str, Any],
lift_mode: itir_transforms.LiftMode,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
temporary_extraction_heuristics: Optional[
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None,
unroll_reduce: bool = False,
):
node = itir_transforms.apply_common_transforms(
@@ -75,6 +79,8 @@
force_inline_lambda_args=True,
lift_mode=lift_mode,
offset_provider=offset_provider,
symbolic_domain_sizes=symbolic_domain_sizes,
temporary_extraction_heuristics=temporary_extraction_heuristics,
unroll_reduce=unroll_reduce,
)
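The new temporary_extraction_heuristics parameter expects a callable of shape Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]: given a stencil closure, it returns a predicate that decides which expressions get extracted into temporaries. The predicate below is purely illustrative (it marks applied lifts) and is not taken from this commit:

    import gt4py.next.iterator.ir as itir

    def extract_lifted_calls(closure: itir.StencilClosure):
        # Return a predicate over the expressions of this closure; here it matches
        # applied lifts, i.e. FunCall(fun=FunCall(fun=SymRef(id="lift"), ...), ...).
        def predicate(expr: itir.Expr) -> bool:
            return (
                isinstance(expr, itir.FunCall)
                and isinstance(expr.fun, itir.FunCall)
                and isinstance(expr.fun.fun, itir.SymRef)
                and expr.fun.fun.id == "lift"
            )
        return predicate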

@@ -180,7 +186,10 @@ def get_sdfg_args(
"""
offset_provider = kwargs["offset_provider"]

neighbor_tables = filter_neighbor_tables(offset_provider)
neighbor_tables: dict[str, common.NeighborTable] = {}
for offset, connectivity in filter_connectivities(offset_provider).items():
assert isinstance(connectivity, common.NeighborTable)
neighbor_tables[offset] = connectivity
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

dace_args = get_args(sdfg, args, use_field_canonical_representation)
@@ -211,12 +220,16 @@

def build_sdfg_from_itir(
program: itir.FencilDefinition,
*args,
arg_types: list[ts.TypeSpec],
offset_provider: dict[str, Any],
auto_optimize: bool = False,
on_gpu: bool = False,
column_axis: Optional[common.Dimension] = None,
lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
temporary_extraction_heuristics: Optional[
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
] = None,
load_sdfg_from_file: bool = False,
save_sdfg: bool = True,
use_field_canonical_representation: bool = True,
@@ -225,12 +238,13 @@
Args:
program: The Fencil that should be translated.
*args: Arguments for which the fencil should be called.
arg_types: Types of the arguments passed to the fencil.
offset_provider: The set of offset providers that should be used.
auto_optimize: Apply DaCe's `auto_optimize` heuristic.
on_gpu: Performs the translation for GPU, defaults to `False`.
column_axis: The column axis to be used, defaults to `None`.
lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`.
symbolic_domain_sizes: Used for the generation of Liskov bindings when temporaries are enabled.
load_sdfg_from_file: Allows reading the SDFG from file instead of generating it; for debugging only.
save_sdfg: If `True` (the default), the SDFG is stored as a file and can be loaded later; this allows skipping the lowering step, provided `load_sdfg_from_file` is set to `True`.
use_field_canonical_representation: If `True`, assume that the fields' dimensions are sorted alphabetically.
@@ -245,10 +259,10 @@
sdfg.validate()
return sdfg

arg_types = [type_translation.from_value(arg) for arg in args]

# visit ITIR and generate SDFG
program, tmps = preprocess_program(program, offset_provider, lift_mode)
program, tmps = preprocess_program(
program, offset_provider, lift_mode, symbolic_domain_sizes, temporary_extraction_heuristics
)
sdfg_genenerator = ItirToSDFG(
arg_types, offset_provider, tmps, use_field_canonical_representation, column_axis
)
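With the changed signature, callers now pass the argument types explicitly instead of the runtime argument values. A hedged sketch of a direct call, assuming `program`, `args`, and `offset_provider` already exist in the caller's context, and using illustrative dimension-to-symbol names for symbolic_domain_sizes:

    from gt4py.next.type_system import type_translation

    arg_types = [type_translation.from_value(arg) for arg in args]
    sdfg = build_sdfg_from_itir(
        program,
        arg_types=arg_types,
        offset_provider=offset_provider,
        auto_optimize=True,
        symbolic_domain_sizes={"Cell": "num_cells", "Edge": "num_edges"},  # illustrative
    )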
@@ -19,14 +19,14 @@

import gt4py.eve as eve
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
from gt4py.next.common import NeighborTable
from gt4py.next.common import Connectivity
from gt4py.next.iterator import (
ir as itir,
transforms as itir_transforms,
type_inference as itir_typing,
)
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.type_system import type_specifications as ts, type_translation
from gt4py.next.type_system import type_specifications as ts, type_translation as tt

from .itir_to_tasklet import (
Context,
@@ -44,10 +44,10 @@
as_scalar_type,
connectivity_identifier,
dace_debuginfo,
filter_neighbor_tables,
filter_connectivities,
flatten_list,
get_sorted_dims,
get_used_neighbor_tables,
get_used_connectivities,
map_nested_sdfg_symbols,
new_array_symbols,
unique_name,
@@ -119,7 +119,7 @@ def _make_array_shape_and_strides(
"""
dtype = dace.int32
sorted_dims = get_sorted_dims(dims) if sort_dims else list(enumerate(dims))
neighbor_tables = filter_neighbor_tables(offset_provider)
neighbor_tables = filter_connectivities(offset_provider)
shape = [
(
neighbor_tables[dim.value].max_neighbors
@@ -163,7 +163,7 @@ class ItirToSDFG(eve.NodeVisitor):
def __init__(
self,
param_types: list[ts.TypeSpec],
offset_provider: dict[str, NeighborTable],
offset_provider: dict[str, Connectivity | Dimension],
tmps: list[itir_transforms.global_tmps.Temporary],
use_field_canonical_representation: bool,
column_axis: Optional[Dimension] = None,
@@ -292,7 +292,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
self.node_types = itir_typing.infer_all(node)

# Filter neighbor tables from offset providers.
neighbor_tables = get_used_neighbor_tables(node, self.offset_provider)
neighbor_tables = get_used_connectivities(node, self.offset_provider)

# Add program parameters as SDFG storages.
for param, type_ in zip(node.params, self.param_types):
@@ -312,7 +312,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):

# Add connectivities as SDFG storages.
for offset, offset_provider in neighbor_tables.items():
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
scalar_kind = tt.get_scalar_kind(offset_provider.index_type)
local_dim = Dimension(offset, kind=DimensionKind.LOCAL)
type_ = ts.FieldType(
[offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
@@ -382,7 +382,7 @@ def visit_StencilClosure(
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True)

input_names = [str(inp.id) for inp in node.inputs]
neighbor_tables = get_used_neighbor_tables(node, self.offset_provider)
neighbor_tables = get_used_connectivities(node, self.offset_provider)
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state)
@@ -574,7 +574,7 @@ def _visit_scan_stencil_closure(
)

assert isinstance(node.output, SymRef)
neighbor_tables = get_used_neighbor_tables(node, self.offset_provider)
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

@@ -737,7 +737,7 @@ def _visit_parallel_stencil_closure(
tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ...
],
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
neighbor_tables = get_used_neighbor_tables(node, self.offset_provider)
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

@@ -22,10 +22,9 @@

import gt4py.eve.codegen
from gt4py import eve
from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing
from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value
from gt4py.next import Dimension, type_inference as next_typing
from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.ir import FunCall, Lambda
from gt4py.next.iterator.type_inference import Val
from gt4py.next.type_system import type_specifications as ts
@@ -36,7 +35,7 @@
connectivity_identifier,
dace_debuginfo,
flatten_list,
get_used_neighbor_tables,
get_used_connectivities,
map_nested_sdfg_symbols,
new_array_symbols,
unique_name,
@@ -184,7 +183,7 @@ def _visit_lift_in_neighbors_reduction(
transformer: "PythonTaskletCodegen",
node: itir.FunCall,
node_args: Sequence[IteratorExpr | list[ValueExpr]],
offset_provider: NeighborTableOffsetProvider,
offset_provider: Connectivity,
map_entry: dace.nodes.MapEntry,
map_exit: dace.nodes.MapExit,
neighbor_index_node: dace.nodes.AccessNode,
@@ -229,7 +228,7 @@ def _visit_lift_in_neighbors_reduction(
assert isinstance(y, ValueExpr)
input_nodes[x] = y.value

neighbor_tables = get_used_neighbor_tables(node.args[0], transformer.offset_provider)
neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider)
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

parent_sdfg = transformer.context.body
@@ -328,7 +327,7 @@ def builtin_neighbors(
offset_dim = offset_literal.value
assert isinstance(offset_dim, str)
offset_provider = transformer.offset_provider[offset_dim]
if not isinstance(offset_provider, NeighborTableOffsetProvider):
if not isinstance(offset_provider, Connectivity):
raise NotImplementedError(
"Neighbor reduction only implemented for connectivity based on neighbor tables."
)
@@ -917,7 +916,7 @@ def visit_Lambda(
]:
func_name = f"lambda_{abs(hash(node)):x}"
neighbor_tables = (
get_used_neighbor_tables(node, self.offset_provider) if use_neighbor_tables else {}
get_used_connectivities(node, self.offset_provider) if use_neighbor_tables else {}
)
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

@@ -1070,7 +1069,7 @@ def _visit_call(self, node: itir.FunCall):
store, self.context.body.arrays[store]
)

neighbor_tables = get_used_neighbor_tables(node.fun, self.offset_provider)
neighbor_tables = get_used_connectivities(node.fun, self.offset_provider)
for offset in neighbor_tables.keys():
var = connectivity_identifier(offset)
nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var])
@@ -1216,7 +1215,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
offset_node = self.visit(tail[1])[0]
assert offset_node.dtype in dace.dtypes.INTEGER_TYPES

if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider):
if isinstance(self.offset_provider[offset_dim], Connectivity):
offset_provider = self.offset_provider[offset_dim]
connectivity = self.context.state.add_access(
connectivity_identifier(offset_dim), debuginfo=di
@@ -1225,20 +1224,12 @@
shifted_dim = offset_provider.origin_axis.value
target_dim = offset_provider.neighbor_axis.value
args = [
ValueExpr(connectivity, offset_provider.table.dtype),
ValueExpr(connectivity, _INDEX_DTYPE),
ValueExpr(iterator.indices[shifted_dim], offset_node.dtype),
offset_node,
]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]}[{internals[1]}, {internals[2]}]"
elif isinstance(self.offset_provider[offset_dim], StridedNeighborOffsetProvider):
offset_provider = self.offset_provider[offset_dim]

shifted_dim = offset_provider.origin_axis.value
target_dim = offset_provider.neighbor_axis.value
args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}"
else:
assert isinstance(self.offset_provider[offset_dim], Dimension)

@@ -19,7 +19,7 @@
import gt4py.next.iterator.ir as itir
from gt4py import eve
from gt4py.next import Dimension
from gt4py.next.common import NeighborTable
from gt4py.next.common import Connectivity
from gt4py.next.type_system import type_specifications as ts


@@ -60,20 +60,20 @@ def as_scalar_type(typestr: str) -> ts.ScalarType:
return ts.ScalarType(kind)


def filter_neighbor_tables(offset_provider: Mapping[str, Any]) -> dict[str, NeighborTable]:
def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Connectivity]:
return {
offset: table
for offset, table in offset_provider.items()
if isinstance(table, NeighborTable)
if isinstance(table, Connectivity)
}


def get_used_neighbor_tables(
def get_used_connectivities(
node: itir.Node, offset_provider: Mapping[str, Any]
) -> dict[str, NeighborTable]:
neighbor_tables = filter_neighbor_tables(offset_provider)
) -> dict[str, Connectivity]:
connectivities = filter_connectivities(offset_provider)
offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value"))
return {offset: neighbor_tables[offset] for offset in offset_dims if offset in neighbor_tables}
return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities}


def connectivity_identifier(name: str) -> str:
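To make the renamed helpers concrete, a small hedged usage sketch: the dimensions, the neighbor table, and the NeighborTableOffsetProvider construction are illustrative assumptions (NeighborTableOffsetProvider satisfies the Connectivity interface), and filter_connectivities is assumed to be imported from the utility module above:

    import numpy as np
    from gt4py.next import Dimension
    from gt4py.next.iterator.embedded import NeighborTableOffsetProvider

    Edge = Dimension("Edge")
    Vertex = Dimension("Vertex")
    e2v = NeighborTableOffsetProvider(
        np.array([[0, 1], [1, 2]], dtype=np.int32),  # neighbor table: edge -> vertices
        Edge,  # origin axis
        Vertex,  # neighbor axis
        2,  # max_neighbors
        has_skip_values=False,
    )
    offset_provider = {"E2V": e2v, "KOff": Dimension("KDim")}

    # Cartesian offsets (plain Dimension values) are filtered out:
    assert list(filter_connectivities(offset_provider)) == ["E2V"]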