-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next][dace]: lowering of scan to SDFG #1776
base: main
Are you sure you want to change the base?
Changes from all commits
df1847a
89ca8f7
f22eb64
c26d906
ba0a9ba
ac7acf8
877d81e
8baf6d1
784b573
de9c9de
14e66e8
fcfaf72
79204ee
5fe461a
a4bde3a
c75a8e4
6f72cac
397acae
acf5ac0
a706b27
c22cfc8
59e0ed5
792a8eb
61985f7
aa236a2
9bdc75b
746f9d8
0d894ff
440a474
500590b
c56e062
5d5992a
c167def
eb17345
8b163da
d15213a
55811dc
8f0e515
f01d291
62e1648
72e8830
39aeb20
de4a80e
45f9927
3fe538b
ee62266
4b0ac60
f701605
a19019f
c03492c
310fcce
4b487ea
4cf66e7
ab7ee5f
82cf491
a0dbea5
9128ffb
f940c4e
cc0777b
462f3c5
d9218b6
9d7e722
1ddd6fe
95e0007
f1b7a3f
50ad620
983022c
e7b1afb
df7bd0c
363ab59
9c19d32
9cad1f7
2700f53
a20d3c0
b5ff462
d326d3b
252f348
49f8172
2d6dfc0
6124c6d
45bcf97
23b0baa
ff05880
43ec33c
d43153a
2e82bd5
07e6a5c
4bf145b
2b03bb4
40c225d
419a386
cc9801b
360baae
f2396c4
a0c37cb
45c69ec
0f9043b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union | ||
|
||
import dace | ||
from dace.sdfg import utils as dace_sdfg_utils | ||
|
||
from gt4py import eve | ||
from gt4py.eve import concepts | ||
|
@@ -111,6 +112,21 @@ def get_symbol_type(self, symbol_name: str) -> ts.DataType: | |
"""Retrieve the GT4Py type of a symbol used in the SDFG.""" | ||
... | ||
|
||
@abc.abstractmethod | ||
def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: | ||
"""Check if the given dimension is the column dimension.""" | ||
... | ||
|
||
@abc.abstractmethod | ||
def nested_context( | ||
self, | ||
sdfg: dace.SDFG, | ||
global_symbols: dict[str, ts.DataType], | ||
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], | ||
) -> SDFGBuilder: | ||
"""Create a new empty context, useful to build a nested SDFG.""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would expand on this a little bit more, as handling neste SDFG are probably the most important if not the only application of this function. |
||
... | ||
|
||
@abc.abstractmethod | ||
def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: | ||
"""Visit a node of the GT4Py IR.""" | ||
|
@@ -149,15 +165,6 @@ def _collect_symbols_in_domain_expressions( | |
) | ||
|
||
|
||
def _get_tuple_type(data: tuple[gtir_builtin_translators.FieldopResult, ...]) -> ts.TupleType: | ||
""" | ||
Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. | ||
""" | ||
return ts.TupleType( | ||
types=[_get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] | ||
) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): | ||
"""Provides translation capability from a GTIR program to a DaCe SDFG. | ||
|
@@ -173,6 +180,7 @@ class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): | |
""" | ||
|
||
offset_provider_type: gtx_common.OffsetProviderType | ||
column_dim: Optional[gtx_common.Dimension] | ||
global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) | ||
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = dataclasses.field( | ||
default_factory=lambda: {} | ||
|
@@ -199,6 +207,25 @@ def make_field( | |
def get_symbol_type(self, symbol_name: str) -> ts.DataType: | ||
return self.global_symbols[symbol_name] | ||
|
||
def is_column_dimension(self, dim: gtx_common.Dimension) -> bool: | ||
assert self.column_dim | ||
return dim == self.column_dim | ||
|
||
def nested_context( | ||
self, | ||
sdfg: dace.SDFG, | ||
global_symbols: dict[str, ts.DataType], | ||
field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]], | ||
) -> SDFGBuilder: | ||
Comment on lines
+214
to
+219
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would also modify all places where you create nested SDFG such that they use this function. |
||
nsdfg_builder = GTIRToSDFG( | ||
self.offset_provider_type, self.column_dim, global_symbols, field_offsets | ||
) | ||
nsdfg_params = [ | ||
gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() | ||
] | ||
nsdfg_builder._add_sdfg_params(sdfg, node_params=nsdfg_params, symbolic_arguments=None) | ||
return nsdfg_builder | ||
|
||
def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: | ||
nsdfg_list = [ | ||
nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) | ||
|
@@ -277,10 +304,11 @@ def _add_storage( | |
""" | ||
if isinstance(gt_type, ts.TupleType): | ||
tuple_fields = [] | ||
for tname, ttype in dace_gtir_utils.get_tuple_fields(name, gt_type, flatten=True): | ||
for sym in dace_gtir_utils.flatten_tuple_fields(name, gt_type): | ||
assert isinstance(sym.type, ts.DataType) | ||
tuple_fields.extend( | ||
self._add_storage( | ||
sdfg, symbolic_arguments, tname, ttype, transient, tuple_name=name | ||
sdfg, symbolic_arguments, sym.id, sym.type, transient, tuple_name=name | ||
) | ||
) | ||
return tuple_fields | ||
|
@@ -379,7 +407,7 @@ def _add_sdfg_params( | |
self, | ||
sdfg: dace.SDFG, | ||
node_params: Sequence[gtir.Sym], | ||
symbolic_arguments: set[str], | ||
symbolic_arguments: Optional[set[str]], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does accepting |
||
) -> list[str]: | ||
""" | ||
Helper function to add storage for node parameters and connectivity tables. | ||
|
@@ -389,6 +417,9 @@ def _add_sdfg_params( | |
except when they are listed in 'symbolic_arguments', in which case they | ||
will be represented in the SDFG as DaCe symbols. | ||
""" | ||
if symbolic_arguments is None: | ||
symbolic_arguments = set() | ||
|
||
# add non-transient arrays and/or SDFG symbols for the program arguments | ||
sdfg_args = [] | ||
for param in node_params: | ||
|
@@ -436,7 +467,7 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: | |
assert len(self.field_offsets) == 0 | ||
|
||
sdfg = dace.SDFG(node.id) | ||
sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) | ||
sdfg.debuginfo = dace_utils.debug_info(node) | ||
|
||
# DaCe requires C-compatible strings for the names of data containers, | ||
# such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name | ||
|
@@ -619,32 +650,23 @@ def visit_Lambda( | |
(str(param.id), arg) for param, arg in zip(node.params, args, strict=True) | ||
] | ||
|
||
def flatten_tuples( | ||
name: str, | ||
arg: gtir_builtin_translators.FieldopResult, | ||
) -> list[tuple[str, gtir_builtin_translators.FieldopData]]: | ||
if isinstance(arg, tuple): | ||
tuple_type = _get_tuple_type(arg) | ||
tuple_field_names = [ | ||
arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) | ||
] | ||
tuple_args = zip(tuple_field_names, arg, strict=True) | ||
return list( | ||
itertools.chain(*[flatten_tuples(fname, farg) for fname, farg in tuple_args]) | ||
) | ||
else: | ||
return [(name, arg)] | ||
|
||
lambda_arg_nodes = dict( | ||
itertools.chain(*[flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) | ||
itertools.chain( | ||
*[ | ||
gtir_builtin_translators.flatten_tuples(pname, arg) | ||
for pname, arg in lambda_args_mapping | ||
] | ||
) | ||
) | ||
|
||
# inherit symbols from parent scope but eventually override with local symbols | ||
lambda_symbols = { | ||
sym: self.global_symbols[sym] | ||
for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) | ||
} | { | ||
pname: _get_tuple_type(arg) if isinstance(arg, tuple) else arg.gt_type | ||
pname: gtir_builtin_translators.get_tuple_type(arg) | ||
if isinstance(arg, tuple) | ||
else arg.gt_type | ||
for pname, arg in lambda_args_mapping | ||
} | ||
|
||
|
@@ -659,12 +681,12 @@ def get_field_domain_offset( | |
elif field_domain_offset := self.field_offsets.get(p_name, None): | ||
return {p_name: field_domain_offset} | ||
elif isinstance(p_type, ts.TupleType): | ||
p_fields = dace_gtir_utils.get_tuple_fields(p_name, p_type, flatten=True) | ||
tsyms = dace_gtir_utils.flatten_tuple_fields(p_name, p_type) | ||
return functools.reduce( | ||
lambda field_offsets, field: ( | ||
field_offsets | get_field_domain_offset(field[0], field[1]) | ||
lambda field_offsets, sym: ( | ||
field_offsets | get_field_domain_offset(sym.id, sym.type) # type: ignore[arg-type] | ||
), | ||
p_fields, | ||
tsyms, | ||
{}, | ||
) | ||
return {} | ||
|
@@ -676,7 +698,7 @@ def get_field_domain_offset( | |
|
||
# lower let-statement lambda node as a nested SDFG | ||
lambda_translator = GTIRToSDFG( | ||
self.offset_provider_type, lambda_symbols, lambda_field_offsets | ||
self.offset_provider_type, self.column_dim, lambda_symbols, lambda_field_offsets | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here you should use your nested context I guess. |
||
) | ||
nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) | ||
nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) | ||
|
@@ -853,6 +875,7 @@ def visit_SymRef( | |
def build_sdfg_from_gtir( | ||
ir: gtir.Program, | ||
offset_provider_type: gtx_common.OffsetProviderType, | ||
column_dim: Optional[gtx_common.Dimension] = None, | ||
) -> dace.SDFG: | ||
""" | ||
Receives a GTIR program and lowers it to a DaCe SDFG. | ||
|
@@ -863,15 +886,19 @@ def build_sdfg_from_gtir( | |
Args: | ||
ir: The GTIR program node to be lowered to SDFG | ||
offset_provider_type: The definitions of offset providers used by the program node | ||
column_dim: Vertical dimension used for scan expressions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would call that thing |
||
|
||
Returns: | ||
An SDFG in the DaCe canonical form (simplified) | ||
""" | ||
|
||
ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) | ||
ir = ir_prune_casts.PruneCasts().visit(ir) | ||
sdfg_genenerator = GTIRToSDFG(offset_provider_type) | ||
sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_dim) | ||
sdfg = sdfg_genenerator.visit(ir) | ||
assert isinstance(sdfg, dace.SDFG) | ||
|
||
# TODO(edopao): remove inlining when DaCe transformations support LoopRegion construct | ||
dace_sdfg_utils.inline_loop_blocks(sdfg) | ||
|
||
return sdfg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If a dimension is a column, should not depend on the SDFG builder.
Why is it an abstract method, I would have guessed it is more a static one or does life in some type traits package.