Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: DaCe support for temporaries #1351

Merged
merged 32 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3c38362
[dace] Lift support with ITIR temporaries
edopao Oct 13, 2023
8fd0fda
[dace] Replace unicode symbols from SSA ITIR pass
edopao Oct 17, 2023
1020248
Merge branch 'GridTools:main' into dace-temporaries
edopao Oct 17, 2023
6a46682
[dace] Disable simplify for SDFG with temporaries
edopao Oct 18, 2023
7b251f8
[dace] Bugfix for array offset
edopao Oct 18, 2023
e2ff6f6
[dace] Fix for array offset in simplify pass
edopao Oct 19, 2023
2380efb
Revert "[dace] Bugfix for array offset"
edopao Oct 19, 2023
5f10d3d
[dace] Fix for ITIR temporary pass
edopao Oct 19, 2023
32052b6
[dace] Fix type-check error
edopao Oct 19, 2023
842a340
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Oct 20, 2023
e65cbb2
[dace] Fix code-style errors
edopao Oct 20, 2023
326748a
[dace] Add comments
edopao Oct 20, 2023
fec7ae4
[dace] Improve string to enum conversion
edopao Oct 23, 2023
fa7482d
Merge branch 'GridTools:main' into dace-temporaries
edopao Oct 31, 2023
81f3f74
[dace] Add kwarg for lift_mode
edopao Nov 2, 2023
240de21
[dace] Force temporaries for LIFT
edopao Nov 7, 2023
22041e4
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Nov 23, 2023
112a594
[dace] Minor edit
edopao Nov 23, 2023
b8a5db7
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Nov 23, 2023
da795ef
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Dec 5, 2023
c387b38
[dace] Re-added simplify step
edopao Dec 5, 2023
6c5668f
[dace] Conform with error message guidelines
edopao Dec 11, 2023
4187ee1
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Dec 13, 2023
67fc2cf
[dace] Enable lift_mode arg
edopao Dec 13, 2023
3e80907
[dace] Fix formatting
edopao Dec 14, 2023
55250db
Merge remote-tracking branch 'origin/main' into dace-temporaries
edopao Feb 2, 2024
8e203e8
Minor edit
edopao Feb 2, 2024
15a693f
Fix error (lift_mode was ignored)
edopao Feb 5, 2024
1a7da6a
Minor edit
edopao Feb 5, 2024
3529244
[dace] Remove pass for SSA identifiers
edopao Feb 5, 2024
9aa3629
[dace] Review comments
edopao Feb 5, 2024
8b74107
[dace] Remove tasklet for symbolic expressions
edopao Feb 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import numpy as np
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.transformation.auto import auto_optimize as autoopt
from dace.transformation.interstate import RefineNestedAccess

import gt4py.next.iterator.ir as itir
from gt4py.next.common import Dimension, Domain, UnitRange, is_field
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms, global_tmps
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
from gt4py.next.type_system import type_specifications as ts, type_translation
Expand Down Expand Up @@ -54,11 +55,13 @@ def convert_arg(arg: Any):
return arg


def preprocess_program(
    program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode
):
    """Lower the fencil through the common ITIR transformation pipeline.

    Args:
        program: The fencil definition to transform.
        offset_provider: Mapping from offset names to their providers
            (e.g. neighbor tables / connectivities).
        lift_mode: Controls how `lift` expressions are handled (inlined or
            converted to temporaries); forwarded to `apply_common_transforms`.

    Returns:
        The transformed program. With `LiftMode.FORCE_TEMPORARIES` the caller
        observes a `global_tmps.FencilWithTemporaries` result; otherwise a
        plain fencil.
    """
    program = apply_common_transforms(
        program,
        offset_provider=offset_provider,
        lift_mode=lift_mode,
        # NOTE(review): CSE disabled here — presumably unsupported by this
        # backend lowering; confirm before enabling.
        common_subexpression_elimination=False,
    )
    return program
Expand Down Expand Up @@ -161,10 +164,22 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
sdfg_program = build_cache[cache_id]
sdfg = sdfg_program.sdfg
else:
program = preprocess_program(program, offset_provider, LiftMode.FORCE_INLINE)
if all([ItirToSDFG._check_no_lifts(node) for node in program.closures]):
tmps = []
else:
program_with_tmps: global_tmps.FencilWithTemporaries = preprocess_program(
program, offset_provider, LiftMode.FORCE_TEMPORARIES
)
program = program_with_tmps.fencil
tmps = program_with_tmps.tmps
edopao marked this conversation as resolved.
Show resolved Hide resolved

# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
sdfg = sdfg_genenerator.visit(program)
if tmps:
# This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets
sdfg.apply_transformations_repeated(RefineNestedAccess)
sdfg.simplify()

# set array storage for GPU execution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.iterator.transforms import global_tmps
from gt4py.next.type_system import type_specifications as ts, type_translation

from .itir_to_tasklet import (
Expand All @@ -34,6 +35,7 @@
from .utility import (
add_mapped_nested_sdfg,
as_dace_type,
as_scalar_type,
connectivity_identifier,
create_memlet_at,
create_memlet_full,
Expand Down Expand Up @@ -101,12 +103,14 @@ def __init__(
self,
param_types: list[ts.TypeSpec],
offset_provider: dict[str, NeighborTableOffsetProvider],
tmps: list[global_tmps.Temporary],
column_axis: Optional[Dimension] = None,
):
self.param_types = param_types
self.column_axis = column_axis
self.offset_provider = offset_provider
self.storage_types = {}
self.tmps = tmps

def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
if isinstance(type_, ts.FieldType):
Expand All @@ -125,16 +129,127 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset
raise NotImplementedError()
self.storage_types[name] = type_

def generate_temporaries(
    self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG
) -> dict[str, IteratorExpr | ValueExpr | SymbolExpr]:
    """Allocate transient arrays for ITIR temporaries and build their symbol table.

    For every temporary in `self.tmps` this adds a transient array to
    `program_sdfg` and emits tasklets in `defs_state` that compute the array's
    offset, shape and stride as transient scalars.

    Args:
        node_params: Parameters of the fencil; non-temporary scalar parameters
            may appear in the temporaries' domain expressions.
        defs_state: State in which the parameter-computing tasklets are placed.
        program_sdfg: The SDFG that receives the transient arrays.

    Returns:
        Mapping from array-parameter symbol name (offset/shape/stride) to the
        name of the transient scalar holding its value; the caller installs
        these as assignments on the first interstate edge so they become
        SDFG symbols.
    """
    symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] = {}
    # The shape of temporary arrays might be defined based on the shape of other
    # input/output fields. Therefore, collect the symbols used to define
    # data-field parameters that are not temporaries and not iterators.
    for sym in node_params:
        if all(sym.id != tmp.id for tmp in self.tmps) and sym.kind != "Iterator":
            name_ = str(sym.id)
            type_ = self.storage_types[name_]
            assert isinstance(type_, ts.ScalarType)
            symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_))

    symbol_dtype = dace.int64
    tmp_symbols: dict[str, IteratorExpr | ValueExpr | SymbolExpr] = {}
    for tmp in self.tmps:
        tmp_name = str(tmp.id)
        assert isinstance(tmp.domain, itir.FunCall)
        self.node_types.update(itir_typing.infer_all(tmp.domain))
        domain_ctx = Context(program_sdfg, defs_state, symbol_map)
        tmp_domain = self._visit_domain(tmp.domain, domain_ctx)

        # First build the FieldType for the temporary array.
        dims: list[Dimension] = []
        for dim, _ in tmp_domain:
            dims.append(
                Dimension(
                    value=dim,
                    kind=(
                        DimensionKind.VERTICAL
                        if self.column_axis is not None and self.column_axis.value == dim
                        else DimensionKind.HORIZONTAL
                    ),
                )
            )
        assert isinstance(tmp.dtype, str)
        type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype))
        self.add_storage(program_sdfg, tmp_name, type_)

        # Then retrieve the data container and mark it transient: the
        # temporary lives only inside this SDFG.
        tmp_array = program_sdfg.arrays[tmp_name]
        tmp_array.transient = True

        # The innermost dimension's stride is 1; outer strides are chained
        # below by reusing the previous dimension's shape scalar.
        stride_var = unique_var_name()
        program_sdfg.add_scalar(stride_var, symbol_dtype, transient=True)
        stride_node = defs_state.add_access(stride_var)
        defs_state.add_edge(
            defs_state.add_tasklet(
                "stride",
                code="__result = 1",
                inputs={},
                outputs={"__result"},
            ),
            "__result",
            stride_node,
            None,
            dace.Memlet.simple(stride_var, "0"),
        )
        # Loop through all dimensions (innermost first) to initialize the
        # parameters for array offset/shape/stride.
        for (_, (begin, end)), offset_sym, shape_sym, stride_sym in reversed(
            list(
                zip(
                    tmp_domain,
                    tmp_array.offset,
                    tmp_array.shape,
                    tmp_array.strides,
                )
            )
        ):
            offset_tasklet = defs_state.add_tasklet(
                "offset",
                code=f"__result = - {begin.value}",
                inputs={},
                outputs={"__result"},
            )
            offset_var = unique_var_name()
            program_sdfg.add_scalar(offset_var, symbol_dtype, transient=True)
            offset_node = defs_state.add_access(offset_var)
            defs_state.add_edge(
                offset_tasklet,
                "__result",
                offset_node,
                None,
                dace.Memlet.simple(offset_var, "0"),
            )

            shape_tasklet = defs_state.add_tasklet(
                "shape",
                code=f"__result = {end.value} - {begin.value}",
                inputs={},
                outputs={"__result"},
            )
            shape_var = unique_var_name()
            program_sdfg.add_scalar(shape_var, symbol_dtype, transient=True)
            shape_node = defs_state.add_access(shape_var)
            defs_state.add_edge(
                shape_tasklet,
                "__result",
                shape_node,
                None,
                dace.Memlet.simple(shape_var, "0"),
            )

            # The transient scalars containing the array parameters are later
            # mapped to interstate symbols. The current shape scalar also
            # becomes the next (outer) dimension's stride.
            # NOTE(review): for more than two dimensions this sets the outer
            # stride to the previous shape only, not the cumulative product of
            # inner shapes — confirm this is intended.
            tmp_symbols[str(offset_sym)] = offset_var
            tmp_symbols[str(stride_sym)] = stride_var
            tmp_symbols[str(shape_sym)] = stride_var = shape_var

    return tmp_symbols

def get_output_nodes(
    self, closure: itir.StencilClosure, context: Context
) -> dict[str, dace.nodes.AccessNode]:
    """Collect the SDFG access nodes written by a closure, keyed by data-container name."""
    codegen = PythonTaskletCodegen(self.offset_provider, context, self.node_types)
    result: dict[str, dace.nodes.AccessNode] = {}
    for expr in flatten_list(codegen.visit(closure.output)):
        access_node = expr.value
        result[access_node.data] = access_node
    return result

def visit_FencilDefinition(self, node: itir.FencilDefinition):
def visit_FencilDefinition(self, node: itir.FencilDefinition, **kargs):
program_sdfg = dace.SDFG(name=node.id)
last_state = program_sdfg.add_state("program_entry")
entry_state = program_sdfg.add_state("program_entry")
self.node_types = itir_typing.infer_all(node)

# Filter neighbor tables from offset providers.
Expand All @@ -144,6 +259,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
for param, type_ in zip(node.params, self.param_types):
self.add_storage(program_sdfg, str(param.id), type_)

if self.tmps:
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
tmp_symbols = self.generate_temporaries(node.params, entry_state, program_sdfg)

# Add connectivities as SDFG storages.
for offset, table in neighbor_tables:
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
Expand All @@ -152,7 +270,10 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)

# Create a nested SDFG for all stencil closures.
last_state = entry_state
for closure in node.closures:
ItirToSDFG._replace_ssa_identifiers(closure)

# Translate the closure and its stencil's body to an SDFG.
closure_sdfg, input_names, output_names = self.visit(
closure, array_table=program_sdfg.arrays
Expand Down Expand Up @@ -189,6 +310,11 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
access_node = last_state.add_access(inner_name)
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

if self.tmps:
# on the first interstate edge define symbols for shape/stride/offsets of temporary arrays
inter_state_edge = program_sdfg.out_edges(entry_state)[0]
inter_state_edge.data.assignments.update(tmp_symbols)

program_sdfg.validate()
return program_sdfg

Expand Down Expand Up @@ -628,3 +754,33 @@ def _check_shift_offsets_are_literals(node: itir.StencilClosure):
if not all(isinstance(arg, (itir.Literal, itir.OffsetLiteral)) for arg in shift.args):
return False
return True

@staticmethod
def _replace_ssa_identifiers(closure: itir.StencilClosure):
    """
    Replace unicode separators in lambda parameter names with length-based suffixes.

    Each 'ᐞ<suffix>' segment produced by the SSA pass is rewritten to
    '<current length>_<suffix>': for example, 'z_gammaᐞ0' is renamed to
    'z_gamma7_0' (and 'z_gammaᐞ0ᐞ3' to 'z_gamma7_010_3').
    Unicode symbols are not accepted in DaCe connectors, because C/C++ code
    does not have support.
    """
    _UNIQUE_NAME_SEPARATOR = "ᐞ"

    def __replace_in_expression(fun: itir.FunCall, p_old: str, p_new: str):
        # Recursively rename every reference to `p_old` in the expression tree.
        for arg in fun.args:
            if isinstance(arg, itir.FunCall):
                __replace_in_expression(arg, p_old, p_new)
            elif isinstance(arg, itir.SymRef) and arg.id == p_old:
                arg.id = eve.SymbolRef(p_new)

    if isinstance(closure.stencil, itir.Lambda):
        for p in closure.stencil.params:
            p_old = str(p.id)
            parts = p_old.split(_UNIQUE_NAME_SEPARATOR)
            p_new = parts[0]
            for suffix in parts[1:]:
                p_new += f"{len(p_new)}_{suffix}"
            if p_new != p_old:
                assert isinstance(closure.stencil.expr, itir.FunCall)
                __replace_in_expression(closure.stencil.expr, p_old, p_new)
                p.id = eve.SymbolName(p_new)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def as_dace_type(type_: ts.ScalarType):
raise ValueError(f"scalar type {type_} not supported")


def as_scalar_type(dtype: str) -> ts.ScalarType:
    """Return the GT4Py scalar type named by `dtype` (case-insensitive).

    Args:
        dtype: Name of a `ts.ScalarKind` member, e.g. "float64".

    Raises:
        ValueError: If `dtype` does not name a supported scalar kind.
    """
    try:
        kind = getattr(ts.ScalarKind, dtype.upper())
    except AttributeError as ex:
        # Chain the original error so the lookup failure stays diagnosable.
        raise ValueError(f"Data type {dtype} not supported.") from ex
    return ts.ScalarType(kind)
edopao marked this conversation as resolved.
Show resolved Hide resolved


def filter_neighbor_tables(offset_provider: dict[str, Any]):
return [
(offset, table)
Expand Down