feat[next][dace]: DaCe support for temporaries (#1351)
Temporaries are implemented in the DaCe backend as transient arrays. This PR adds the extraction of temporaries and the generation of the corresponding transient arrays in the SDFG representation.
edopao authored Feb 6, 2024
1 parent e462a2e commit 6509dd9
Showing 3 changed files with 124 additions and 10 deletions.
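
For context, a transient array in DaCe is a buffer that lives only inside the SDFG: it is allocated and freed by the generated code rather than passed in by the caller, which is what makes it a natural fit for ITIR temporaries. A minimal standalone sketch (not code from this PR; names are illustrative):

import dace

sdfg = dace.SDFG("transient_example")
# Non-transient arrays are part of the SDFG signature and supplied by the caller.
sdfg.add_array("inp", shape=[10], dtype=dace.float64)
# Transient arrays are allocated internally by the generated code.
sdfg.add_array("tmp", shape=[10], dtype=dace.float64, transient=True)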
@@ -22,6 +22,7 @@
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.sdfg import utils as sdutils
from dace.transformation.auto import auto_optimize as autoopt
from dace.transformation.interstate import RefineNestedAccess

import gt4py.next.allocators as next_allocators
import gt4py.next.iterator.ir as itir
@@ -71,7 +72,7 @@ def preprocess_program(
lift_mode: itir_transforms.LiftMode,
unroll_reduce: bool = False,
):
-    return itir_transforms.apply_common_transforms(
+    node = itir_transforms.apply_common_transforms(
program,
common_subexpression_elimination=False,
force_inline_lambda_args=True,
@@ -80,6 +81,21 @@
unroll_reduce=unroll_reduce,
)

if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries):
fencil_definition = node.fencil
tmps = node.tmps

elif isinstance(node, itir.FencilDefinition):
fencil_definition = node
tmps = []

else:
raise TypeError(
f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(program).__name__}'."
)

return fencil_definition, tmps
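
The branching above normalizes the transform output: when temporary extraction is active, apply_common_transforms returns a FencilWithTemporaries wrapper rather than a bare FencilDefinition, and preprocess_program unpacks either case into a (fencil, temporaries) pair. A hedged usage sketch, assuming the LiftMode enum exposes a FORCE_TEMPORARIES member in addition to FORCE_INLINE:

from gt4py.next.iterator import transforms as itir_transforms

# With FORCE_INLINE every lift is inlined, so `tmps` comes back empty; with
# FORCE_TEMPORARIES (assumed member) lifted expressions are extracted and
# `tmps` holds one entry per transient to create in the SDFG.
# `program` and `offset_provider` are as in run_dace_iterator below.
fencil, tmps = preprocess_program(
    program, offset_provider, itir_transforms.LiftMode.FORCE_TEMPORARIES
)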


def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
@@ -160,6 +176,7 @@ def get_stride_args(
def get_cache_id(
build_type: str,
build_for_gpu: bool,
lift_mode: itir_transforms.LiftMode,
program: itir.FencilDefinition,
arg_types: Sequence[ts.TypeSpec],
column_axis: Optional[common.Dimension],
@@ -185,6 +202,7 @@ def offset_invariants(offset):
for arg in (
build_type,
build_for_gpu,
lift_mode,
program,
*arg_types,
column_axis,
@@ -272,17 +290,17 @@ def build_sdfg_from_itir(
sdfg.validate()
return sdfg

-    # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
-    # `lift_more` to `FORCE_INLINE` mode.
-    lift_mode = itir_transforms.LiftMode.FORCE_INLINE
arg_types = [type_translation.from_value(arg) for arg in args]

# visit ITIR and generate SDFG
-    program = preprocess_program(program, offset_provider, lift_mode)
-    sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
+    program, tmps = preprocess_program(program, offset_provider, lift_mode)
+    sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
sdfg = sdfg_genenerator.visit(program)
if sdfg is None:
raise RuntimeError(f"Visit failed for program {program.id}.")
elif tmps:
# This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets
sdfg.apply_transformations_repeated(RefineNestedAccess)

for nested_sdfg in sdfg.all_sdfgs_recursive():
if not nested_sdfg.debuginfo:
@@ -338,7 +356,9 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):

arg_types = [type_translation.from_value(arg) for arg in args]

-    cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider)
+    cache_id = get_cache_id(
+        build_type, on_gpu, lift_mode, program, arg_types, column_axis, offset_provider
+    )
if build_cache is not None and cache_id in build_cache:
# retrieve SDFG program from build cache
sdfg_program = build_cache[cache_id]
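
Adding lift_mode to the cache key is not cosmetic: the same fencil preprocessed under different lift modes produces different SDFGs, so a key that ignores it could hand back a binary compiled for the wrong mode. A generic sketch of this style of keying (hypothetical helper, not the PR's get_cache_id):

import hashlib

def make_cache_key(*parts) -> str:
    # Every input that influences code generation must contribute to the key;
    # omitting one (as lift_mode was omitted before this change) risks reusing
    # a stale compiled program.
    digest = hashlib.sha256()
    for part in parts:
        digest.update(repr(part).encode())
    return digest.hexdigest()

key = make_cache_key("debug", False, "FORCE_INLINE", "my_fencil")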
@@ -19,8 +19,12 @@
import gt4py.eve as eve
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
from gt4py.next.common import NeighborTable
-from gt4py.next.iterator import ir as itir, type_inference as itir_typing
-from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
+from gt4py.next.iterator import (
+    ir as itir,
+    transforms as itir_transforms,
+    type_inference as itir_typing,
+)
+from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.type_system import type_specifications as ts, type_translation

from .itir_to_tasklet import (
@@ -36,6 +40,7 @@
from .utility import (
add_mapped_nested_sdfg,
as_dace_type,
as_scalar_type,
connectivity_identifier,
create_memlet_at,
create_memlet_full,
@@ -44,6 +49,7 @@
flatten_list,
get_sorted_dims,
map_nested_sdfg_symbols,
new_array_symbols,
unique_name,
unique_var_name,
)
@@ -154,12 +160,14 @@ def __init__(
self,
param_types: list[ts.TypeSpec],
offset_provider: dict[str, NeighborTable],
tmps: list[itir_transforms.global_tmps.Temporary],
column_axis: Optional[Dimension] = None,
):
self.param_types = param_types
self.column_axis = column_axis
self.offset_provider = offset_provider
self.storage_types = {}
self.tmps = tmps

def add_storage(
self,
@@ -189,6 +197,70 @@ def add_storage(
raise NotImplementedError()
self.storage_types[name] = type_

def add_storage_for_temporaries(
self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG
) -> dict[str, str]:
symbol_map: dict[str, TaskletExpr] = {}
# The shape of temporary arrays might be defined based on scalar values passed as program arguments.
# Here we collect these values in a symbol map.
tmp_ids = set(tmp.id for tmp in self.tmps)
for sym in node_params:
if sym.id not in tmp_ids and sym.kind != "Iterator":
name_ = str(sym.id)
type_ = self.storage_types[name_]
assert isinstance(type_, ts.ScalarType)
symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_))

tmp_symbols: dict[str, str] = {}
for tmp in self.tmps:
tmp_name = str(tmp.id)

# We visit the domain of the temporary field, passing the set of available symbols.
assert isinstance(tmp.domain, itir.FunCall)
self.node_types.update(itir_typing.infer_all(tmp.domain))
domain_ctx = Context(program_sdfg, defs_state, symbol_map)
tmp_domain = self._visit_domain(tmp.domain, domain_ctx)

# We build the FieldType for this temporary array.
dims: list[Dimension] = []
for dim, _ in tmp_domain:
dims.append(
Dimension(
value=dim,
kind=(
DimensionKind.VERTICAL
if self.column_axis is not None and self.column_axis.value == dim
else DimensionKind.HORIZONTAL
),
)
)
assert isinstance(tmp.dtype, str)
type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype))
self.storage_types[tmp_name] = type_

# N.B.: skip generation of symbolic strides and just let dace assign default strides, for now.
# Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics
# to assign optimal stride values.
tmp_shape, _ = new_array_symbols(tmp_name, len(dims))
tmp_offset = [
dace.symbol(unique_name(f"{tmp_name}_offset{i}")) for i in range(len(dims))
]
_, tmp_array = program_sdfg.add_array(
tmp_name, tmp_shape, as_dace_type(type_.dtype), offset=tmp_offset, transient=True
)

# Loop through all dimensions to visit the symbolic expressions for array shape and offset.
# These expressions are later mapped to interstate symbols.
for (_, (begin, end)), offset_sym, shape_sym in zip(
tmp_domain,
tmp_array.offset,
tmp_array.shape,
):
tmp_symbols[str(offset_sym)] = f"0 - {begin.value}"
tmp_symbols[str(shape_sym)] = f"{end.value} - {begin.value}"

return tmp_symbols

def get_output_nodes(
self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState
) -> dict[str, dace.nodes.AccessNode]:
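
The offset/shape bookkeeping at the end of add_storage_for_temporaries is easier to see in isolation: a dimension whose domain spans [begin, end) yields a transient of shape end - begin with offset 0 - begin, so that indexing with the original domain index lands on element 0 when i == begin. A minimal standalone sketch of the same pattern (symbol names are illustrative):

import dace

sdfg = dace.SDFG("tmp_offset_example")
size = dace.symbol("tmp_0_size_0")
offset = dace.symbol("tmp_0_offset0")
# With offset = -begin, an access tmp_0[i] for i in [begin, end) addresses
# elements [0, end - begin), matching the symbols assigned on the first
# interstate edge (offset = "0 - begin", size = "end - begin").
sdfg.add_array("tmp_0", shape=[size], dtype=dace.float64, offset=[offset], transient=True)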
@@ -204,7 +276,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
def visit_FencilDefinition(self, node: itir.FencilDefinition):
program_sdfg = dace.SDFG(name=node.id)
program_sdfg.debuginfo = dace_debuginfo(node)
-        last_state = program_sdfg.add_state("program_entry", True)
+        entry_state = program_sdfg.add_state("program_entry", is_start_block=True)
self.node_types = itir_typing.infer_all(node)

# Filter neighbor tables from offset providers.
@@ -214,6 +286,20 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
for param, type_ in zip(node.params, self.param_types):
self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)

if self.tmps:
tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg)
# on the first interstate edge define symbols for shape and offsets of temporary arrays
last_state = program_sdfg.add_state("init_symbols_for_temporaries")
program_sdfg.add_edge(
entry_state,
last_state,
dace.InterstateEdge(
assignments=tmp_symbols,
),
)
else:
last_state = entry_state

# Add connectivities as SDFG storages.
for offset, offset_provider in neighbor_tables.items():
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
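
Binding the temporaries' shape and offset symbols on the first interstate edge works because symbols assigned on an edge are visible in every state that executes afterwards, which is exactly what the init_symbols_for_temporaries state above relies on. A minimal standalone sketch of the pattern (values are illustrative):

import dace

sdfg = dace.SDFG("interstate_symbols")
entry = sdfg.add_state("entry", is_start_block=True)
body = sdfg.add_state("body")
# `N` becomes available to all states reached after this edge and can be
# used, e.g., as a transient array size.
sdfg.add_edge(entry, body, dace.InterstateEdge(assignments={"N": "10 - 2"}))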
@@ -51,6 +51,14 @@ def as_dace_type(type_: ts.ScalarType):
raise ValueError(f"Scalar type '{type_}' not supported.")


def as_scalar_type(typestr: str) -> ts.ScalarType:
try:
kind = getattr(ts.ScalarKind, typestr.upper())
except AttributeError:
raise ValueError(f"Data type {typestr} not supported.")
return ts.ScalarType(kind)


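A quick note on the helper above: the type string is matched case-insensitively against the members of ts.ScalarKind, so, for example:

as_scalar_type("float64")     # ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
as_scalar_type("int32")       # ts.ScalarType(kind=ts.ScalarKind.INT32)
as_scalar_type("quaternion")  # raises ValueError: unsupported type string
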
def filter_neighbor_tables(offset_provider: dict[str, Any]):
return {
offset: table
...