-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[next]: SDFGConvertible Program for dace_fieldview backend #1742
Merged
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
4c41f02
[wip] start adding sdfg convertible Program for dace-fieldview
DropD 481af34
add SDFGConvertible Program replacement to dace_fieldview
DropD ad0a7b2
improve generate_sdfg args, remove old Program replacement
DropD 307b537
Merge branch 'main' into gtir-sdfg-convertible
DropD 7c4d197
turn auto_optimize back on in __sdfg__
DropD 3c56151
disable `auto_opt` once again in `__sdfg__`
DropD c5b4c43
support only CUDA device type in dace_fieldview Program
DropD 835c115
[wip] bring back extra sdfg attributes for halo placement
DropD c126cb2
partially add halo exchange helper attrs with tests
DropD 0b90583
add dace/gt4py type parsing crosscheck
DropD 8c1f9f4
Merge branch 'main' into gtir-sdfg-convertible
DropD 77c250b
fix dace_fieldview.program tests
DropD f8ec8f5
refactor extractors and remove debuginfo warning
DropD 5831591
Merge branch 'main' into gtir-sdfg-convertible
DropD 19c76ad
work around gtir transforms requiring connectivity tables
DropD b2931a5
Merge remote-tracking branch 'origin/main' into gtir-sdfg-convertible
edopao fd7f472
Add visitor for Literal node
edopao 71e54fd
Fix for attribute rename itir -> gtir on latest main
edopao 99030d0
Fix error on main (is_field_allocator_factory_for -> is_field_allocat…
edopao 439ec43
cover addition in the program tests
DropD 7853745
Merge branch 'gtir-sdfg-convertible-node-visitor-regtest' into gtir-s…
DropD 6c6a4e3
clean up dace-program tests and extractors
DropD f799bc0
cleanup orchestration tests
DropD 9c884e3
Fix attribute dtype error
edopao 3ecaab4
fix mypy error
edopao 8976caa
fix stride problem
edopao d89383d
fix stride problem (1)
edopao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# Please, refer to the LICENSE file in the root directory. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from gt4py import eve | ||
from gt4py.next.iterator import ir as itir | ||
from gt4py.next.type_system import type_specifications as ts | ||
|
||
|
||
class SymbolNameSetExtractor(eve.NodeVisitor):
    """Collect a set of symbol names from a GTIR tree.

    Subclasses narrow which part of a statement is traversed (e.g. only the
    read side or only the write side of a ``SetAt``); this base class defines
    the traversal over program structure and the leaf cases.
    """

    def visit_Literal(self, node: itir.Literal) -> set[str]:
        # Literals carry no symbol references.
        return set()

    # NOTE(review): this is never referenced within this class and its name
    # differs from eve's ``generic_visit`` fallback — confirm it is actually
    # dispatched before relying on it.
    def generic_visitor(self, node: itir.Node) -> set[str]:
        child_name_sets = (self.visit(child) for child in eve.trees.iter_children_values(node))
        return set().union(*child_name_sets)

    def visit_Node(self, node: itir.Node) -> set[str]:
        # Catch-all: node types without an explicit handler contribute no names.
        return set()

    def visit_Program(self, node: itir.Program) -> set[str]:
        # Union the names found in every top-level statement of the program.
        return set().union(*(self.visit(stmt) for stmt in node.body))

    def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
        # Both branches may reference symbols; collect from each.
        branch_stmts = [*node.true_branch, *node.false_branch]
        return set().union(*(self.visit(stmt) for stmt in branch_stmts))

    def visit_Temporary(self, node: itir.Temporary) -> set[str]:
        return set()

    def visit_SymRef(self, node: itir.SymRef) -> set[str]:
        # A symbol reference is itself the single name collected.
        return {str(node.id)}

    @classmethod
    def only_fields(cls, program: itir.Program) -> set[str]:
        """Return only those collected names that are field parameters of ``program``."""
        field_param_names = {
            str(param.id) for param in program.params if isinstance(param.type, ts.FieldType)
        }
        return cls().visit(program) & field_param_names
|
||
|
||
class InputNamesExtractor(SymbolNameSetExtractor):
    """Extract the set of symbol names passed into field operators within a program."""

    def visit_SetAt(self, node: itir.SetAt) -> set[str]:
        # Inputs appear only on the right-hand-side expression of a set_at.
        return self.visit(node.expr)

    def visit_FunCall(self, node: itir.FunCall) -> set[str]:
        # Collect names from every call argument; the callee itself is not visited.
        return set().union(*(self.visit(arg) for arg in node.args))
|
||
|
||
class OutputNamesExtractor(SymbolNameSetExtractor):
    """Extract the set of symbol names written to within a program"""

    def visit_SetAt(self, node: itir.SetAt) -> set[str]:
        # Only the assignment target of a set_at counts as written to;
        # the expression side is handled by InputNamesExtractor instead.
        return self.visit(node.target)
248 changes: 248 additions & 0 deletions
248
src/gt4py/next/program_processors/runners/dace_fieldview/program.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# Please, refer to the LICENSE file in the root directory. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
import collections | ||
import dataclasses | ||
import itertools | ||
import typing | ||
from typing import Any, ClassVar, Optional, Sequence | ||
|
||
import dace | ||
import numpy as np | ||
|
||
from gt4py.next import backend as next_backend, common | ||
from gt4py.next.ffront import decorator | ||
from gt4py.next.iterator import ir as itir, transforms as itir_transforms | ||
from gt4py.next.iterator.transforms import extractors as extractors | ||
from gt4py.next.otf import arguments, recipes, toolchain | ||
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils | ||
from gt4py.next.type_system import type_specifications as ts | ||
|
||
|
||
@dataclasses.dataclass(frozen=True)
class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible):
    """Extension of GT4Py Program implementing the SDFGConvertible interface via GTIR."""

    # Populated by ``__sdfg__`` (with the generated SDFG's array descriptors) and
    # read back by ``__sdfg_closure__``; DaCe calls the two methods in that order.
    sdfg_closure_cache: dict[str, Any] = dataclasses.field(default_factory=dict)
    # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs,
    # there is no name mangling of the connectivity tables used across the nested SDFGs
    # since they share the same memory address.
    connectivity_tables_data_descriptors: ClassVar[
        dict[str, dace.data.Array]
    ] = {}  # symbolically defined

    def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG:
        """Lower this program to an SDFG for embedding into a DaCe-orchestrated program.

        Also attaches halo-exchange metadata (``gt4py_program_input_fields`` /
        ``gt4py_program_output_fields``) as dynamic attributes on the SDFG and
        caches the SDFG's array descriptors for ``__sdfg_closure__``.

        Raises:
            ValueError: if this program's backend is not a DaCe backend.
        """
        if (self.backend is None) or "dace" not in self.backend.name.lower():
            raise ValueError("The SDFG can be generated only for the DaCe backend.")

        # Merge user-supplied connectivities with the implicit offset providers.
        offset_provider: common.OffsetProvider = {
            **(self.connectivities or {}),
            **self._implicit_offset_provider,
        }
        column_axis = kwargs.get("column_axis", None)

        # TODO(ricoh): connectivity tables required here for now.
        gtir_stage = typing.cast(next_backend.Transforms, self.backend.transforms).past_to_itir(
            toolchain.CompilableProgram(
                data=self.past_stage,
                args=arguments.CompileTimeArgs(
                    args=tuple(p.type for p in self.past_stage.past_node.params),
                    kwargs={},
                    column_axis=column_axis,
                    offset_provider=offset_provider,
                ),
            )
        )
        program = gtir_stage.data
        program = itir_transforms.apply_fieldview_transforms(  # run the transforms separately because they require the runtime info
            program, offset_provider=offset_provider
        )
        # ``gtir_stage`` is frozen, hence the ``object.__setattr__`` workaround
        # to store the transformed program back into the stage.
        object.__setattr__(
            gtir_stage,
            "data",
            program,
        )
        object.__setattr__(
            gtir_stage.args, "offset_provider", gtir_stage.args.offset_provider_type
        )  # TODO(ricoh): currently this is circumventing the frozenness of CompileTimeArgs
        # in order to isolate DaCe from the runtime tables in connectivities.offset_provider.
        # These are needed at the time of writing for mandatory GTIR passes.
        # Remove this as soon as Program does not expect connectivity tables anymore.

        # Validate that the argument types DaCe parsed agree with GTIR's parameter types.
        _crosscheck_dace_parsing(
            dace_parsed_args=[*args, *kwargs.values()],
            gt4py_program_args=[p.type for p in program.params],
        )

        compile_workflow = typing.cast(
            recipes.OTFCompileWorkflow,
            self.backend.executor
            if not hasattr(self.backend.executor, "step")
            else self.backend.executor.step,
        )  # We know which backend we are using, but we don't know if the compile workflow is cached.
        # TODO(ricoh): switch 'itir_transforms_off=True' because we ran them separately previously
        # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with
        # the other parts of the workaround when possible.
        sdfg = dace.SDFG.from_json(
            compile_workflow.translation.replace(itir_transforms_off=True)(gtir_stage).source_code
        )

        # Cache the array descriptors; ``__sdfg_closure__`` needs them later.
        self.sdfg_closure_cache["arrays"] = sdfg.arrays

        # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields,
        # offset_providers_per_input_field. Add them as dynamic attributes to the SDFG
        field_params = {
            str(param.id): param for param in program.params if isinstance(param.type, ts.FieldType)
        }

        def single_horizontal_dim_per_field(
            fields: typing.Iterable[itir.Sym],
        ) -> typing.Iterator[tuple[str, common.Dimension]]:
            # Yield (field name, horizontal dimension) only for fields with exactly
            # one horizontal dimension.
            for field in fields:
                assert isinstance(field.type, ts.FieldType)
                horizontal_dims = [
                    dim for dim in field.type.dims if dim.kind is common.DimensionKind.HORIZONTAL
                ]
                # do nothing for fields with multiple horizontal dimensions
                # or without horizontal dimensions
                # this is only meant for use with unstructured grids
                if len(horizontal_dims) == 1:
                    yield str(field.id), horizontal_dims[0]

        input_fields = (
            field_params[name] for name in extractors.InputNamesExtractor.only_fields(program)
        )
        sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields))

        output_fields = (
            field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program)
        )
        sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields))

        # TODO (ricoh): bring back sdfg.offset_providers_per_input_field.
        # A starting point would be to use the "trace_shifts" pass on GTIR
        # and associate the extracted shifts with each input field.
        # Analogous to the version in `runners.dace_iterator.__init__`, which
        # was removed when merging #1742.

        return sdfg

    def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]:
        """
        Return the closure arrays of the SDFG represented by this object
        as a mapping between array name and the corresponding value.

        The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols.
        The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that
        the offset providers are not part of GT4Py Program's arguments.
        Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method.
        """
        closure_dict: dict[str, Any] = {}

        if self.connectivities:
            symbols = {}
            # Connectivities that actually carry a neighbor table ...
            with_table = [
                name for name, conn in self.connectivities.items() if common.is_neighbor_table(conn)
            ]
            # ... restricted to those whose identifier appears among the SDFG's arrays
            # (cached by ``__sdfg__``).
            in_arrays_with_id = [
                (name, conn_id)
                for name in with_table
                if (conn_id := dace_utils.connectivity_identifier(name))
                in self.sdfg_closure_cache["arrays"]
            ]
            in_arrays = (name for name, _ in in_arrays_with_id)
            # Every (connectivity, axis) pair; tables are 2-dimensional (axes 0 and 1).
            name_axis = list(itertools.product(in_arrays, [0, 1]))

            def size_symbol_name(name: str, axis: int) -> str:
                # Canonical DaCe symbol name for a table's size along ``axis``.
                return dace_utils.field_size_symbol_name(
                    dace_utils.connectivity_identifier(name), axis
                )

            connectivity_tables_size_symbols = {
                (sname := size_symbol_name(name, axis)): dace.symbol(sname)
                for name, axis in name_axis
            }

            def stride_symbol_name(name: str, axis: int) -> str:
                # Canonical DaCe symbol name for a table's stride along ``axis``.
                return dace_utils.field_stride_symbol_name(
                    dace_utils.connectivity_identifier(name), axis
                )

            connectivity_table_stride_symbols = {
                (sname := stride_symbol_name(name, axis)): dace.symbol(sname)
                for name, axis in name_axis
            }

            symbols = connectivity_tables_size_symbols | connectivity_table_stride_symbols

            # Define the storage location (e.g. CPU, GPU) of the connectivity tables
            if "storage" not in self.connectivity_tables_data_descriptors:
                # Take the storage of the first available table; assumes all
                # connectivity tables share the same storage location.
                for _, conn_id in in_arrays_with_id:
                    self.connectivity_tables_data_descriptors["storage"] = self.sdfg_closure_cache[
                        "arrays"
                    ][conn_id].storage
                    break

            # Build the closure dictionary
            for name, conn_id in in_arrays_with_id:
                if conn_id not in self.connectivity_tables_data_descriptors:
                    conn = self.connectivities[name]
                    assert common.is_neighbor_table(conn)
                    # Describe the table symbolically: shape and strides are the
                    # DaCe symbols defined above, not concrete runtime values.
                    self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
                        dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type),
                        shape=[
                            symbols[dace_utils.field_size_symbol_name(conn_id, 0)],
                            symbols[dace_utils.field_size_symbol_name(conn_id, 1)],
                        ],
                        strides=[
                            symbols[dace_utils.field_stride_symbol_name(conn_id, 0)],
                            symbols[dace_utils.field_stride_symbol_name(conn_id, 1)],
                        ],
                        storage=Program.connectivity_tables_data_descriptors["storage"],
                    )
                closure_dict[conn_id] = self.connectivity_tables_data_descriptors[conn_id]

        return closure_dict

    def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]:
        """Return (argument names, constant names) as expected by SDFGConvertible."""
        # No compile-time constants are exposed, hence the empty second sequence.
        return [p.id for p in self.past_stage.past_node.params], []
|
||
|
||
def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None:
    """Assert that DaCe's parsed argument types agree with GT4Py's program argument types.

    Raises:
        ValueError: for an argument kind this check does not know how to compare.
        AssertionError: when the two views of an argument disagree.
    """
    # DaCe does not see GT4Py's implicit size arguments, hence strict=False.
    for parsed_arg, program_arg in zip(dace_parsed_args, gt4py_program_args, strict=False):
        if isinstance(parsed_arg, dace.data.Scalar):
            assert parsed_arg.dtype == dace_utils.as_dace_type(program_arg)
        # NOTE: bool must be tested before int, since bool is a subclass of int.
        elif isinstance(parsed_arg, (bool, np.bool_)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind == ts.ScalarKind.BOOL
        elif isinstance(parsed_arg, (int, np.integer)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64]
        elif isinstance(parsed_arg, (float, np.floating)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64]
        elif isinstance(parsed_arg, (str, np.str_)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind == ts.ScalarKind.STRING
        elif isinstance(parsed_arg, dace.data.Array):
            assert isinstance(program_arg, ts.FieldType)
            assert isinstance(program_arg.dtype, ts.ScalarType)
            assert len(parsed_arg.shape) == len(program_arg.dims)
            assert parsed_arg.dtype == dace_utils.as_dace_type(program_arg.dtype)
        elif isinstance(parsed_arg, (dace.data.Structure, dict, collections.OrderedDict)):
            # offset provider
            pass
        else:
            raise ValueError(f"Unresolved case for {parsed_arg} (==, !=) {program_arg}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@havogt I had to change this to satisfy mypy. Now it is also consistent with
CompileTimeArgs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense