Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: DaCe support for temporaries #1351

Merged
merged 32 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3c38362
[dace] Lift support with ITIR temporaries
edopao Oct 13, 2023
8fd0fda
[dace] Replace unicode symbols from SSA ITIR pass
edopao Oct 17, 2023
1020248
Merge branch 'GridTools:main' into dace-temporaries
edopao Oct 17, 2023
6a46682
[dace] Disable simplify for SDFG with temporaries
edopao Oct 18, 2023
7b251f8
[dace] Bugfix for array offset
edopao Oct 18, 2023
e2ff6f6
[dace] Fix for array offset in simplify pass
edopao Oct 19, 2023
2380efb
Revert "[dace] Bugfix for array offset"
edopao Oct 19, 2023
5f10d3d
[dace] Fix for ITIR temporary pass
edopao Oct 19, 2023
32052b6
[dace] Fix type-check error
edopao Oct 19, 2023
842a340
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Oct 20, 2023
e65cbb2
[dace] Fix code-stile errors
edopao Oct 20, 2023
326748a
[dace] Add comments
edopao Oct 20, 2023
fec7ae4
[dace] Improve string to enum conversion
edopao Oct 23, 2023
fa7482d
Merge branch 'GridTools:main' into dace-temporaries
edopao Oct 31, 2023
81f3f74
[dace] Add kwarg for lift_mode
edopao Nov 2, 2023
240de21
[dace] Force temporaries for LIFT
edopao Nov 7, 2023
22041e4
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Nov 23, 2023
112a594
[dace] Minor edit
edopao Nov 23, 2023
b8a5db7
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Nov 23, 2023
da795ef
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Dec 5, 2023
c387b38
[dace] Re-added simplify step
edopao Dec 5, 2023
6c5668f
[dace] Conform with error message guidelines
edopao Dec 11, 2023
4187ee1
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Dec 13, 2023
67fc2cf
[dace] Enable lift_mode arg
edopao Dec 13, 2023
3e80907
[dace] Fix formatting
edopao Dec 14, 2023
55250db
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Feb 2, 2024
8e203e8
Minor edit
edopao Feb 2, 2024
15a693f
Fix error (lift_mode was ignored)
edopao Feb 5, 2024
1a7da6a
Minor edit
edopao Feb 5, 2024
3529244
[dace] Remove pass for SSA identifiers
edopao Feb 5, 2024
9aa3629
[dace] Review comments
edopao Feb 5, 2024
8b74107
[dace] Remove tasklet for symbolic expressions
edopao Feb 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.sdfg import utils as sdutils
from dace.transformation.auto import auto_optimize as autoopt
from dace.transformation.interstate import RefineNestedAccess

import gt4py.next.allocators as next_allocators
import gt4py.next.iterator.ir as itir
def preprocess_program(
    program: itir.FencilDefinition,
    offset_provider: Mapping[str, Any],
    lift_mode: itir_transforms.LiftMode,
    unroll_reduce: bool = False,
):
    """Run the common ITIR transformation pipeline and split out temporaries.

    Args:
        program: The fencil definition to transform.
        offset_provider: Mapping from offset name to its provider (connectivity
            or dimension), forwarded to the transformation pipeline.
        lift_mode: Controls how lifted stencil calls are handled; with
            temporaries enabled the pipeline may return a
            `FencilWithTemporaries` node.
        unroll_reduce: Whether to unroll reductions during preprocessing.

    Returns:
        A tuple `(fencil_definition, tmps)` where `tmps` is the (possibly
        empty) list of temporaries created by the transformation pass.

    Raises:
        TypeError: If the transformation pipeline returns an unexpected node type.
    """
    node = itir_transforms.apply_common_transforms(
        program,
        common_subexpression_elimination=False,
        force_inline_lambda_args=True,
        lift_mode=lift_mode,
        offset_provider=offset_provider,
        unroll_reduce=unroll_reduce,
    )

    # `apply_common_transforms` returns a `FencilWithTemporaries` wrapper only
    # when the temporary-extraction pass actually created temporaries.
    if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries):
        fencil_definition = node.fencil
        tmps = node.tmps

    elif isinstance(node, itir.FencilDefinition):
        fencil_definition = node
        tmps = []

    else:
        # Bug fix: report the type of the transformed `node`, not of the input
        # `program` — `program` is always a `FencilDefinition` here, so the old
        # message could never show the offending type.
        raise TypeError(
            f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'."
        )

    return fencil_definition, tmps


def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
Expand Down Expand Up @@ -160,6 +176,7 @@ def get_stride_args(
def get_cache_id(
build_type: str,
build_for_gpu: bool,
lift_mode: itir_transforms.LiftMode,
program: itir.FencilDefinition,
arg_types: Sequence[ts.TypeSpec],
column_axis: Optional[common.Dimension],
Expand All @@ -185,6 +202,7 @@ def offset_invariants(offset):
for arg in (
build_type,
build_for_gpu,
lift_mode,
program,
*arg_types,
column_axis,
Expand Down Expand Up @@ -272,17 +290,17 @@ def build_sdfg_from_itir(
sdfg.validate()
return sdfg

    # TODO(edopao): As a temporary fix until temporaries are supported in the DaCe backend, force
    # `lift_mode` to `FORCE_INLINE` mode.
lift_mode = itir_transforms.LiftMode.FORCE_INLINE
arg_types = [type_translation.from_value(arg) for arg in args]

# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
program, tmps = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
sdfg = sdfg_genenerator.visit(program)
if sdfg is None:
raise RuntimeError(f"Visit failed for program {program.id}.")
    elif tmps:
        # This pass is needed to avoid transformation errors during SDFG inlining, because temporary arrays use offsets
        sdfg.apply_transformations_repeated(RefineNestedAccess)

for nested_sdfg in sdfg.all_sdfgs_recursive():
if not nested_sdfg.debuginfo:
Expand Down Expand Up @@ -338,7 +356,9 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):

arg_types = [type_translation.from_value(arg) for arg in args]

cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider)
cache_id = get_cache_id(
build_type, on_gpu, lift_mode, program, arg_types, column_axis, offset_provider
)
if build_cache is not None and cache_id in build_cache:
# retrieve SDFG program from build cache
sdfg_program = build_cache[cache_id]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import gt4py.eve as eve
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
from gt4py.next.common import NeighborTable
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
from gt4py.next.iterator import (
ir as itir,
transforms as itir_transforms,
type_inference as itir_typing,
)
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.type_system import type_specifications as ts, type_translation

from .itir_to_tasklet import (
Expand All @@ -36,6 +40,7 @@
from .utility import (
add_mapped_nested_sdfg,
as_dace_type,
as_scalar_type,
connectivity_identifier,
create_memlet_at,
create_memlet_full,
Expand All @@ -44,6 +49,7 @@
flatten_list,
get_sorted_dims,
map_nested_sdfg_symbols,
new_array_symbols,
unique_name,
unique_var_name,
)
def __init__(
    self,
    param_types: list[ts.TypeSpec],
    offset_provider: dict[str, NeighborTable],
    tmps: list[itir_transforms.global_tmps.Temporary],
    column_axis: Optional[Dimension] = None,
):
    """Initialize the ITIR-to-SDFG translator.

    Args:
        param_types: Types of the fencil parameters, in parameter order.
        offset_provider: Mapping from offset name to its neighbor table.
        tmps: Temporaries produced by the ITIR `global_tmps` pass; may be empty.
        column_axis: Vertical dimension of the computation, if any.
    """
    self.param_types = param_types
    self.column_axis = column_axis
    self.offset_provider = offset_provider
    # Mapping from data name to its GT4Py type; populated via `add_storage`
    # and extended with temporaries during SDFG construction.
    self.storage_types = {}
    self.tmps = tmps

def add_storage(
self,
Expand Down Expand Up @@ -189,6 +197,103 @@ def add_storage(
raise NotImplementedError()
self.storage_types[name] = type_

def generate_temporaries(
    self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG
) -> dict[str, TaskletExpr]:
    """Create a table of symbols which are used to define array shape, stride and offset for temporaries.

    For each temporary in `self.tmps`, this adds a transient array to
    `program_sdfg` and emits tasklets in `defs_state` that compute the
    per-dimension shape and offset from the temporary's ITIR domain.

    Args:
        node_params: Fencil parameters; their scalar (non-temporary,
            non-iterator) entries seed the symbol map used to evaluate
            the temporaries' domain expressions.
        defs_state: SDFG state that receives the shape/offset tasklets.
        program_sdfg: Top-level SDFG to which the temporary arrays are added.

    Returns:
        Mapping from interstate-symbol name (array shape/offset symbol) to the
        name of the transient scalar holding its value.
        NOTE(review): the values stored are variable names (`str`), although
        the annotation says `TaskletExpr` — confirm the intended value type.
    """
    symbol_map: dict[str, TaskletExpr] = {}
    # The shape of temporary arrays might be defined based on the shape of other input/output fields.
    # Therefore, here we collect the symbols used to define data-field parameters that are not temporaries.
    for sym in node_params:
        if all([sym.id != tmp.id for tmp in self.tmps]) and sym.kind != "Iterator":
            name_ = str(sym.id)
            type_ = self.storage_types[name_]
            assert isinstance(type_, ts.ScalarType)
            symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_))

    symbol_dtype = dace.int64
    # For each temporary: build its field type, add a transient array to the
    # SDFG, and compute shape/offset scalars from the ITIR domain expression.
    tmp_symbols: dict[str, TaskletExpr] = {}
    for tmp in self.tmps:
        tmp_name = str(tmp.id)
        assert isinstance(tmp.domain, itir.FunCall)
        self.node_types.update(itir_typing.infer_all(tmp.domain))
        domain_ctx = Context(program_sdfg, defs_state, symbol_map)
        tmp_domain = self._visit_domain(tmp.domain, domain_ctx)

        # First build the FieldType for this temporary array
        dims: list[Dimension] = []
        for dim, _ in tmp_domain:
            dims.append(
                Dimension(
                    value=dim,
                    kind=(
                        DimensionKind.VERTICAL
                        if self.column_axis is not None and self.column_axis.value == dim
                        else DimensionKind.HORIZONTAL
                    ),
                )
            )
        assert isinstance(tmp.dtype, str)
        type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype))
        self.storage_types[tmp_name] = type_

        # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now.
        # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics
        # to assign optimal stride values.
        tmp_shape, _ = new_array_symbols(tmp_name, len(dims))
        tmp_offset = [
            dace.symbol(unique_name(f"{tmp_name}_offset{i}")) for i in range(len(dims))
        ]
        _, tmp_array = program_sdfg.add_array(
            tmp_name, tmp_shape, as_dace_type(type_.dtype), offset=tmp_offset, transient=True
        )

        # Loop through all dimensions to initialize the array parameters for shape and offsets
        for (_, (begin, end)), offset_sym, shape_sym in zip(
            tmp_domain,
            tmp_array.offset,
            tmp_array.shape,
        ):
            # The offset is the negated domain start, so that array index 0
            # corresponds to the domain's `begin` along this dimension.
            offset_tasklet = defs_state.add_tasklet(
                "offset",
                code=f"__result = - {begin.value}",
                inputs={},
                outputs={"__result"},
            )
            offset_var = unique_var_name()
            program_sdfg.add_scalar(offset_var, symbol_dtype, transient=True)
            offset_node = defs_state.add_access(offset_var)
            defs_state.add_edge(
                offset_tasklet,
                "__result",
                offset_node,
                None,
                dace.Memlet.simple(offset_var, "0"),
            )

            # The shape is the domain extent (`end - begin`) along this dimension.
            shape_tasklet = defs_state.add_tasklet(
                "shape",
                code=f"__result = {end.value} - {begin.value}",
                inputs={},
                outputs={"__result"},
            )
            shape_var = unique_var_name()
            program_sdfg.add_scalar(shape_var, symbol_dtype, transient=True)
            shape_node = defs_state.add_access(shape_var)
            defs_state.add_edge(
                shape_tasklet,
                "__result",
                shape_node,
                None,
                dace.Memlet.simple(shape_var, "0"),
            )

            # The transient scalars containing the array parameters are later mapped to interstate symbols
            tmp_symbols[str(offset_sym)] = offset_var
            tmp_symbols[str(shape_sym)] = shape_var

    return tmp_symbols

def get_output_nodes(
self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState
) -> dict[str, dace.nodes.AccessNode]:
Expand All @@ -204,7 +309,7 @@ def get_output_nodes(
def visit_FencilDefinition(self, node: itir.FencilDefinition):
program_sdfg = dace.SDFG(name=node.id)
program_sdfg.debuginfo = dace_debuginfo(node)
last_state = program_sdfg.add_state("program_entry", True)
entry_state = program_sdfg.add_state("program_entry", True)
self.node_types = itir_typing.infer_all(node)

# Filter neighbor tables from offset providers.
Expand All @@ -214,6 +319,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
for param, type_ in zip(node.params, self.param_types):
self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)

if self.tmps:
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
tmp_symbols = self.generate_temporaries(node.params, entry_state, program_sdfg)

# Add connectivities as SDFG storages.
for offset, offset_provider in neighbor_tables.items():
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
Expand All @@ -231,6 +339,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
)

# Create a nested SDFG for all stencil closures.
last_state = entry_state
for closure in node.closures:
# Translate the closure and its stencil's body to an SDFG.
closure_sdfg, input_names, output_names = self.visit(
Expand Down Expand Up @@ -269,6 +378,11 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo)
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

if self.tmps:
# on the first interstate edge define symbols for shape and offsets of temporary arrays
inter_state_edge = program_sdfg.out_edges(entry_state)[0]
inter_state_edge.data.assignments.update(tmp_symbols)

    # Create the call signature for the SDFG.
    # Only the arguments required by the Fencil, i.e. `node.params`, are added as positional arguments.
    # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keyword-only arguments.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ def as_dace_type(type_: ts.ScalarType):
raise ValueError(f"Scalar type '{type_}' not supported.")


def as_scalar_type(typestr: str) -> ts.ScalarType:
    """Return the GT4Py scalar type named by `typestr` (e.g. ``"float64"``).

    The lookup is case-insensitive: `typestr` is upper-cased and matched
    against the members of `ts.ScalarKind`.

    Raises:
        ValueError: If `typestr` does not name a supported scalar kind.
    """
    try:
        kind = getattr(ts.ScalarKind, typestr.upper())
    except AttributeError as ex:
        # Quote the offending value, matching the message style of
        # `as_dace_type`, and chain the lookup failure for easier debugging.
        raise ValueError(f"Data type '{typestr}' not supported.") from ex
    return ts.ScalarType(kind)


def filter_neighbor_tables(offset_provider: dict[str, Any]):
return {
offset: table
Expand Down
Loading