Skip to content

Commit

Permalink
[dace] Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Oct 20, 2023
1 parent e65cbb2 commit 326748a
Showing 1 changed file with 16 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
from gt4py.next.iterator.transforms import global_tmps
from gt4py.next.type_system import type_specifications as ts, type_translation

Expand Down Expand Up @@ -130,9 +130,12 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset
self.storage_types[name] = type_

def generate_temporaries(
self, node_params, defs_state: dace.SDFGState, program_sdfg: dace.SDFG
self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG
) -> dict[str, IteratorExpr | ValueExpr | SymbolExpr]:
"""Create a table of symbols which are used to define array shape, stride and offset for temporaries."""
symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr] = {}
# The shape of temporary arrays might be defined based on the shape of other input/output fields.
# Therefore, here we collect the symbols used to define data-field parameters that are not temporaries.
for sym in node_params:
if all([sym.id != tmp.id for tmp in self.tmps]) and sym.kind != "Iterator":
name_ = str(sym.id)
Expand All @@ -148,6 +151,8 @@ def generate_temporaries(
self.node_types.update(itir_typing.infer_all(tmp.domain))
domain_ctx = Context(program_sdfg, defs_state, symbol_map)
tmp_domain = self._visit_domain(tmp.domain, domain_ctx)

# First build the FieldType for the temporary array
dims: list[Dimension] = []
for dim, _ in tmp_domain:
dims.append(
Expand All @@ -164,6 +169,7 @@ def generate_temporaries(
type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype))
self.add_storage(program_sdfg, tmp_name, type_)

# Then we retrieve the data container...
tmp_array = program_sdfg.arrays[tmp_name]
tmp_array.transient = True

Expand All @@ -182,6 +188,7 @@ def generate_temporaries(
None,
dace.Memlet.simple(stride_var, "0"),
)
# ...and loop through all dimensions to initialize the parameters for array offset/shape/stride
for (_, (begin, end)), offset_sym, shape_sym, stride_sym in reversed(
list(
zip(
Expand Down Expand Up @@ -226,6 +233,7 @@ def generate_temporaries(
dace.Memlet.simple(shape_var, "0"),
)

# The transient scalars containing the array parameters are later mapped to interstate symbols
tmp_symbols[str(offset_sym)] = offset_var
tmp_symbols[str(stride_sym)] = stride_var
tmp_symbols[str(shape_sym)] = stride_var = shape_var
Expand Down Expand Up @@ -749,6 +757,12 @@ def _check_shift_offsets_are_literals(node: itir.StencilClosure):

@staticmethod
def _replace_ssa_identifiers(closure: itir.StencilClosure):
"""
Replace unicode symbols in function arguments with suffix identifiers based on the number of characters.
For example, 'z_gammaᐞ0ᐞ3' is renamed to 'z_gamma7_0'
Unicode symbols are not accepted in DaCe connectors, because C/C++ code does not have support.
"""
_UNIQUE_NAME_SEPARATOR = "ᐞ"

def __replace_in_expression(fun: itir.FunCall, p_old: str, p_new: str):
Expand Down

0 comments on commit 326748a

Please sign in to comment.