-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next][dace]: DaCe support for temporaries #1351
Changes from 30 commits
3c38362
8fd0fda
1020248
6a46682
7b251f8
e2ff6f6
2380efb
5f10d3d
32052b6
842a340
e65cbb2
326748a
fec7ae4
fa7482d
81f3f74
240de21
22041e4
112a594
b8a5db7
da795ef
c387b38
6c5668f
4187ee1
67fc2cf
3e80907
55250db
8e203e8
15a693f
1a7da6a
3529244
9aa3629
8b74107
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,8 +19,12 @@ | |
import gt4py.eve as eve | ||
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing | ||
from gt4py.next.common import NeighborTable | ||
from gt4py.next.iterator import ir as itir, type_inference as itir_typing | ||
from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef | ||
from gt4py.next.iterator import ( | ||
ir as itir, | ||
transforms as itir_transforms, | ||
type_inference as itir_typing, | ||
) | ||
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef | ||
from gt4py.next.type_system import type_specifications as ts, type_translation | ||
|
||
from .itir_to_tasklet import ( | ||
|
@@ -36,6 +40,7 @@ | |
from .utility import ( | ||
add_mapped_nested_sdfg, | ||
as_dace_type, | ||
as_scalar_type, | ||
connectivity_identifier, | ||
create_memlet_at, | ||
create_memlet_full, | ||
|
@@ -44,6 +49,7 @@ | |
flatten_list, | ||
get_sorted_dims, | ||
map_nested_sdfg_symbols, | ||
new_array_symbols, | ||
unique_name, | ||
unique_var_name, | ||
) | ||
|
@@ -154,12 +160,14 @@ def __init__( | |
self, | ||
param_types: list[ts.TypeSpec], | ||
offset_provider: dict[str, NeighborTable], | ||
tmps: list[itir_transforms.global_tmps.Temporary], | ||
column_axis: Optional[Dimension] = None, | ||
): | ||
self.param_types = param_types | ||
self.column_axis = column_axis | ||
self.offset_provider = offset_provider | ||
self.storage_types = {} | ||
self.tmps = tmps | ||
|
||
def add_storage( | ||
self, | ||
|
@@ -189,6 +197,103 @@ def add_storage( | |
raise NotImplementedError() | ||
self.storage_types[name] = type_ | ||
|
||
def generate_temporaries( | ||
self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG | ||
) -> dict[str, TaskletExpr]: | ||
"""Create a table of symbols which are used to define array shape, stride and offset for temporaries.""" | ||
symbol_map: dict[str, TaskletExpr] = {} | ||
# The shape of temporary arrays might be defined based on the shape of other input/output fields. | ||
# Therefore, here we collect the symbols used to define data-field parameters that are not temporaries. | ||
for sym in node_params: | ||
if all([sym.id != tmp.id for tmp in self.tmps]) and sym.kind != "Iterator": | ||
edopao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name_ = str(sym.id) | ||
type_ = self.storage_types[name_] | ||
assert isinstance(type_, ts.ScalarType) | ||
symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_)) | ||
|
||
symbol_dtype = dace.int64 | ||
tmp_symbols: dict[str, TaskletExpr] = {} | ||
for tmp in self.tmps: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You spend ~70 lines on filling |
||
tmp_name = str(tmp.id) | ||
assert isinstance(tmp.domain, itir.FunCall) | ||
self.node_types.update(itir_typing.infer_all(tmp.domain)) | ||
domain_ctx = Context(program_sdfg, defs_state, symbol_map) | ||
tmp_domain = self._visit_domain(tmp.domain, domain_ctx) | ||
|
||
# First build the FieldType for this temporary array | ||
dims: list[Dimension] = [] | ||
for dim, _ in tmp_domain: | ||
dims.append( | ||
Dimension( | ||
value=dim, | ||
kind=( | ||
DimensionKind.VERTICAL | ||
if self.column_axis is not None and self.column_axis.value == dim | ||
else DimensionKind.HORIZONTAL | ||
), | ||
) | ||
) | ||
assert isinstance(tmp.dtype, str) | ||
type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype)) | ||
self.storage_types[tmp_name] = type_ | ||
|
||
# N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. | ||
# Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics | ||
# to assign optimal stride values. | ||
tmp_shape, _ = new_array_symbols(tmp_name, len(dims)) | ||
tmp_offset = [ | ||
dace.symbol(unique_name(f"{tmp_name}_offset{i}")) for i in range(len(dims)) | ||
] | ||
_, tmp_array = program_sdfg.add_array( | ||
tmp_name, tmp_shape, as_dace_type(type_.dtype), offset=tmp_offset, transient=True | ||
) | ||
|
||
# Loop through all dimensions to initialize the array parameters for shape and offsets | ||
for (_, (begin, end)), offset_sym, shape_sym in zip( | ||
edopao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tmp_domain, | ||
tmp_array.offset, | ||
tmp_array.shape, | ||
): | ||
offset_tasklet = defs_state.add_tasklet( | ||
"offset", | ||
code=f"__result = - {begin.value}", | ||
inputs={}, | ||
outputs={"__result"}, | ||
) | ||
offset_var = unique_var_name() | ||
program_sdfg.add_scalar(offset_var, symbol_dtype, transient=True) | ||
offset_node = defs_state.add_access(offset_var) | ||
defs_state.add_edge( | ||
offset_tasklet, | ||
"__result", | ||
offset_node, | ||
None, | ||
dace.Memlet.simple(offset_var, "0"), | ||
) | ||
|
||
shape_tasklet = defs_state.add_tasklet( | ||
"shape", | ||
code=f"__result = {end.value} - {begin.value}", | ||
inputs={}, | ||
outputs={"__result"}, | ||
) | ||
shape_var = unique_var_name() | ||
program_sdfg.add_scalar(shape_var, symbol_dtype, transient=True) | ||
shape_node = defs_state.add_access(shape_var) | ||
defs_state.add_edge( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to my understanding, you should be able to put even complex expressions in assignments on interstate edge. |
||
shape_tasklet, | ||
"__result", | ||
shape_node, | ||
None, | ||
dace.Memlet.simple(shape_var, "0"), | ||
) | ||
|
||
# The transient scalars containing the array parameters are later mapped to interstate symbols | ||
tmp_symbols[str(offset_sym)] = offset_var | ||
tmp_symbols[str(shape_sym)] = shape_var | ||
|
||
return tmp_symbols | ||
|
||
def get_output_nodes( | ||
self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState | ||
) -> dict[str, dace.nodes.AccessNode]: | ||
|
@@ -204,7 +309,7 @@ def get_output_nodes( | |
def visit_FencilDefinition(self, node: itir.FencilDefinition): | ||
program_sdfg = dace.SDFG(name=node.id) | ||
program_sdfg.debuginfo = dace_debuginfo(node) | ||
last_state = program_sdfg.add_state("program_entry", True) | ||
entry_state = program_sdfg.add_state("program_entry", True) | ||
self.node_types = itir_typing.infer_all(node) | ||
|
||
# Filter neighbor tables from offset providers. | ||
|
@@ -214,6 +319,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): | |
for param, type_ in zip(node.params, self.param_types): | ||
self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables) | ||
|
||
if self.tmps: | ||
philip-paul-mueller marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tmp_symbols = self.generate_temporaries(node.params, entry_state, program_sdfg) | ||
|
||
# Add connectivities as SDFG storages. | ||
for offset, offset_provider in neighbor_tables.items(): | ||
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype) | ||
|
@@ -231,6 +339,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): | |
) | ||
|
||
# Create a nested SDFG for all stencil closures. | ||
last_state = entry_state | ||
for closure in node.closures: | ||
# Translate the closure and its stencil's body to an SDFG. | ||
closure_sdfg, input_names, output_names = self.visit( | ||
|
@@ -269,6 +378,11 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): | |
access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) | ||
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) | ||
|
||
if self.tmps: | ||
# on the first interstate edge define symbols for shape and offsets of temporary arrays | ||
inter_state_edge = program_sdfg.out_edges(entry_state)[0] | ||
inter_state_edge.data.assignments.update(tmp_symbols) | ||
|
||
# Create the call signature for the SDFG. | ||
# Only the arguments requiered by the Fencil, i.e. `node.params` are added as poitional arguments. | ||
# The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure about the name.
What confuses me is that
self.tmps
are the temporaries, or not?And later you call this function only if
self.tmps
is not empty, which seems a bit contradictory.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, function name changed to
add_storage_for_temporaries