Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: SDFGConvertible Program for dace_fieldview backend #1742

Merged
merged 27 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4c41f02
[wip] start adding sdfg convertible Program for dace-fieldview
DropD Oct 29, 2024
481af34
add SDFGConvertible Program replacement to dace_fieldview
DropD Nov 19, 2024
ad0a7b2
improve generate_sdfg args, remove old Program replacement
DropD Nov 19, 2024
307b537
Merge branch 'main' into gtir-sdfg-convertible
DropD Nov 19, 2024
7c4d197
turn auto_optimize back on in __sdfg__
DropD Nov 19, 2024
3c56151
disable `auto_opt` once again in `__sdfg__`
DropD Nov 19, 2024
c5b4c43
support only CUDA device type in dace_fieldview Program
DropD Nov 20, 2024
835c115
[wip] bring back extra sdfg attributes for halo placement
DropD Nov 22, 2024
c126cb2
partially add halo exchange helper attrs with tests
DropD Nov 26, 2024
0b90583
add dace/gt4py type parsing crosscheck
DropD Nov 26, 2024
8c1f9f4
Merge branch 'main' into gtir-sdfg-convertible
DropD Nov 26, 2024
77c250b
fix dace_fieldview.program tests
DropD Nov 26, 2024
f8ec8f5
refactor extractors and remove debuginfo warning
DropD Dec 3, 2024
5831591
Merge branch 'main' into gtir-sdfg-convertible
DropD Dec 3, 2024
19c76ad
work around gtir transforms requiring connectivity tables
DropD Dec 17, 2024
b2931a5
Merge remote-tracking branch 'origin/main' into gtir-sdfg-convertible
edopao Dec 18, 2024
fd7f472
Add visitor for Literal node
edopao Dec 18, 2024
71e54fd
Fix for attribute rename itir -> gtir on latest main
edopao Dec 18, 2024
99030d0
Fix error on main (is_field_allocator_factory_for -> is_field_allocat…
edopao Dec 18, 2024
439ec43
cover addition in the program tests
DropD Jan 6, 2025
7853745
Merge branch 'gtir-sdfg-convertible-node-visitor-regtest' into gtir-s…
DropD Jan 6, 2025
6c6a4e3
clean up dace-program tests and extractors
DropD Jan 6, 2025
f799bc0
cleanup orchestration tests
DropD Jan 6, 2025
9c884e3
Fix attribute dtype error
edopao Jan 13, 2025
3ecaab4
fix mypy error
edopao Jan 13, 2025
8976caa
fix stride problem
edopao Jan 13, 2025
d89383d
fix stride problem (1)
edopao Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ class Program:

definition_stage: ffront_stages.ProgramDefinition
backend: Optional[next_backend.Backend]
connectivities: Optional[common.OffsetProviderType] = None
connectivities: Optional[common.OffsetProvider] = (
None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
)
Comment on lines +83 to +85
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@havogt I had to change this to satisfy mypy. Now it is also consistent with CompileTimeArgs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense


@classmethod
def from_function(
Expand Down Expand Up @@ -304,7 +306,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs:


try:
from gt4py.next.program_processors.runners.dace_iterator import Program
from gt4py.next.program_processors.runners.dace_fieldview.program import Program
DropD marked this conversation as resolved.
Show resolved Hide resolved
except ImportError:
pass

Expand Down
72 changes: 72 additions & 0 deletions src/gt4py/next/iterator/transforms/extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py import eve
from gt4py.next.iterator import ir as itir
from gt4py.next.type_system import type_specifications as ts


class SymbolNameSetExtractor(eve.NodeVisitor):
    """Base visitor that collects a set of symbol names from a GTIR program."""

    def visit_Literal(self, node: itir.Literal) -> set[str]:
        # Literals carry no symbol names.
        return set()

    # NOTE(review): eve.NodeVisitor conventionally dispatches unhandled nodes to
    # ``generic_visit`` — confirm whether the name ``generic_visitor`` is intentional.
    def generic_visitor(self, node: itir.Node) -> set[str]:
        collected: set[str] = set()
        for child_value in eve.trees.iter_children_values(node):
            collected |= self.visit(child_value)
        return collected

    def visit_Node(self, node: itir.Node) -> set[str]:
        # Fallback: nodes without a dedicated handler contribute no names.
        return set()

    def visit_Program(self, node: itir.Program) -> set[str]:
        # Union of the names collected from every top-level statement.
        return set().union(*(self.visit(stmt) for stmt in node.body))

    def visit_IfStmt(self, node: itir.IfStmt) -> set[str]:
        # Names from both branches contribute to the result.
        branches = [*node.true_branch, *node.false_branch]
        return set().union(*(self.visit(stmt) for stmt in branches))

    def visit_Temporary(self, node: itir.Temporary) -> set[str]:
        return set()

    def visit_SymRef(self, node: itir.SymRef) -> set[str]:
        return {str(node.id)}

    @classmethod
    def only_fields(cls, program: itir.Program) -> set[str]:
        """Collect symbol names, restricted to the program's field-typed parameters."""
        field_param_names = {
            str(param.id) for param in program.params if isinstance(param.type, ts.FieldType)
        }
        return cls().visit(program) & field_param_names


class InputNamesExtractor(SymbolNameSetExtractor):
    """Extract the set of symbol names passed into field operators within a program."""

    def visit_SetAt(self, node: itir.SetAt) -> set[str]:
        # Inputs appear on the right-hand side of the assignment.
        return self.visit(node.expr)

    def visit_FunCall(self, node: itir.FunCall) -> set[str]:
        # Union of the names found in all call arguments.
        return set().union(*(self.visit(arg) for arg in node.args))


class OutputNamesExtractor(SymbolNameSetExtractor):
    """Extract the set of symbol names written to within a program."""

    def visit_SetAt(self, node: itir.SetAt) -> set[str]:
        # Outputs appear as the target of the assignment.
        return self.visit(node.target)
248 changes: 248 additions & 0 deletions src/gt4py/next/program_processors/runners/dace_fieldview/program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import collections
import dataclasses
import itertools
import typing
from typing import Any, ClassVar, Optional, Sequence

import dace
import numpy as np

from gt4py.next import backend as next_backend, common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import ir as itir, transforms as itir_transforms
from gt4py.next.iterator.transforms import extractors as extractors
from gt4py.next.otf import arguments, recipes, toolchain
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible):
    """Extension of GT4Py Program implementing the SDFGConvertible interface via GTIR."""

    # Cache of SDFG-related data shared between `__sdfg__` and `__sdfg_closure__`
    # (currently holds the SDFG's array descriptors under the key "arrays").
    sdfg_closure_cache: dict[str, Any] = dataclasses.field(default_factory=dict)
    # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs,
    # there is no name mangling of the connectivity tables used across the nested SDFGs
    # since they share the same memory address.
    connectivity_tables_data_descriptors: ClassVar[
        dict[str, dace.data.Array]
    ] = {}  # symbolically defined

    def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG:
        """Generate and return the SDFG for this program (SDFGConvertible protocol).

        Requires a DaCe backend. Runs the GTIR fieldview transforms separately
        (they need runtime offset-provider info), then translates the program to
        an SDFG with those transforms disabled in the workflow, and finally
        attaches halo-exchange metadata as dynamic SDFG attributes.

        Raises:
            ValueError: if the configured backend is not a DaCe backend.
        """
        if (self.backend is None) or "dace" not in self.backend.name.lower():
            raise ValueError("The SDFG can be generated only for the DaCe backend.")

        # Merge the user-provided connectivities with the implicitly defined ones.
        offset_provider: common.OffsetProvider = {
            **(self.connectivities or {}),
            **self._implicit_offset_provider,
        }
        column_axis = kwargs.get("column_axis", None)

        # TODO(ricoh): connectivity tables required here for now.
        gtir_stage = typing.cast(next_backend.Transforms, self.backend.transforms).past_to_itir(
            toolchain.CompilableProgram(
                data=self.past_stage,
                args=arguments.CompileTimeArgs(
                    args=tuple(p.type for p in self.past_stage.past_node.params),
                    kwargs={},
                    column_axis=column_axis,
                    offset_provider=offset_provider,
                ),
            )
        )
        program = gtir_stage.data
        program = itir_transforms.apply_fieldview_transforms(  # run the transforms separately because they require the runtime info
            program, offset_provider=offset_provider
        )
        # `gtir_stage` is frozen; `object.__setattr__` is the deliberate escape hatch.
        object.__setattr__(
            gtir_stage,
            "data",
            program,
        )
        object.__setattr__(
            gtir_stage.args, "offset_provider", gtir_stage.args.offset_provider_type
        )  # TODO(ricoh): currently this is circumventing the frozenness of CompileTimeArgs
        # in order to isolate DaCe from the runtime tables in connectivities.offset_provider.
        # These are needed at the time of writing for mandatory GTIR passes.
        # Remove this as soon as Program does not expect connectivity tables anymore.

        # Sanity check: DaCe's parsed view of the arguments must agree with the
        # GT4Py parameter types.
        _crosscheck_dace_parsing(
            dace_parsed_args=[*args, *kwargs.values()],
            gt4py_program_args=[p.type for p in program.params],
        )

        compile_workflow = typing.cast(
            recipes.OTFCompileWorkflow,
            self.backend.executor
            if not hasattr(self.backend.executor, "step")
            else self.backend.executor.step,
        )  # We know which backend we are using, but we don't know if the compile workflow is cached.
        # TODO(ricoh): switch 'itir_transforms_off=True' because we ran them separately previously
        # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with
        # the other parts of the workaround when possible.
        sdfg = dace.SDFG.from_json(
            compile_workflow.translation.replace(itir_transforms_off=True)(gtir_stage).source_code
        )

        # Keep the array descriptors for `__sdfg_closure__`, which is called afterwards.
        self.sdfg_closure_cache["arrays"] = sdfg.arrays

        # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields,
        # offset_providers_per_input_field. Add them as dynamic attributes to the SDFG
        field_params = {
            str(param.id): param for param in program.params if isinstance(param.type, ts.FieldType)
        }

        def single_horizontal_dim_per_field(
            fields: typing.Iterable[itir.Sym],
        ) -> typing.Iterator[tuple[str, common.Dimension]]:
            # Yield (field name, horizontal dimension) only for fields with
            # exactly one horizontal dimension.
            for field in fields:
                assert isinstance(field.type, ts.FieldType)
                horizontal_dims = [
                    dim for dim in field.type.dims if dim.kind is common.DimensionKind.HORIZONTAL
                ]
                # do nothing for fields with multiple horizontal dimensions
                # or without horizontal dimensions
                # this is only meant for use with unstructured grids
                if len(horizontal_dims) == 1:
                    yield str(field.id), horizontal_dims[0]

        input_fields = (
            field_params[name] for name in extractors.InputNamesExtractor.only_fields(program)
        )
        sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields))

        output_fields = (
            field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program)
        )
        sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields))

        # TODO (ricoh): bring back sdfg.offset_providers_per_input_field.
        # A starting point would be to use the "trace_shifts" pass on GTIR
        # and associate the extracted shifts with each input field.
        # Analogous to the version in `runners.dace_iterator.__init__`, which
        # was removed when merging #1742.

        return sdfg

    def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]:
        """
        Return the closure arrays of the SDFG represented by this object
        as a mapping between array name and the corresponding value.

        The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols.
        The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that
        the offset providers are not part of GT4Py Program's arguments.
        Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method.
        """
        closure_dict: dict[str, Any] = {}

        if self.connectivities:
            symbols = {}
            # Connectivities that actually carry a neighbor table.
            with_table = [
                name for name, conn in self.connectivities.items() if common.is_neighbor_table(conn)
            ]
            # Restrict to tables that appear as arrays in the SDFG generated by `__sdfg__`.
            in_arrays_with_id = [
                (name, conn_id)
                for name in with_table
                if (conn_id := dace_utils.connectivity_identifier(name))
                in self.sdfg_closure_cache["arrays"]
            ]
            in_arrays = (name for name, _ in in_arrays_with_id)
            # Each (table, axis) pair gets its own size and stride symbols (2D tables).
            name_axis = list(itertools.product(in_arrays, [0, 1]))

            def size_symbol_name(name: str, axis: int) -> str:
                return dace_utils.field_size_symbol_name(
                    dace_utils.connectivity_identifier(name), axis
                )

            connectivity_tables_size_symbols = {
                (sname := size_symbol_name(name, axis)): dace.symbol(sname)
                for name, axis in name_axis
            }

            def stride_symbol_name(name: str, axis: int) -> str:
                return dace_utils.field_stride_symbol_name(
                    dace_utils.connectivity_identifier(name), axis
                )

            connectivity_table_stride_symbols = {
                (sname := stride_symbol_name(name, axis)): dace.symbol(sname)
                for name, axis in name_axis
            }

            symbols = connectivity_tables_size_symbols | connectivity_table_stride_symbols

            # Define the storage location (e.g. CPU, GPU) of the connectivity tables
            # NOTE(review): the storage of the first table found is used for all of
            # them — presumably all connectivity tables share one storage; confirm.
            if "storage" not in self.connectivity_tables_data_descriptors:
                for _, conn_id in in_arrays_with_id:
                    self.connectivity_tables_data_descriptors["storage"] = self.sdfg_closure_cache[
                        "arrays"
                    ][conn_id].storage
                    break

            # Build the closure dictionary
            for name, conn_id in in_arrays_with_id:
                if conn_id not in self.connectivity_tables_data_descriptors:
                    conn = self.connectivities[name]
                    assert common.is_neighbor_table(conn)
                    self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array(
                        dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type),
                        shape=[
                            symbols[dace_utils.field_size_symbol_name(conn_id, 0)],
                            symbols[dace_utils.field_size_symbol_name(conn_id, 1)],
                        ],
                        strides=[
                            symbols[dace_utils.field_stride_symbol_name(conn_id, 0)],
                            symbols[dace_utils.field_stride_symbol_name(conn_id, 1)],
                        ],
                        storage=Program.connectivity_tables_data_descriptors["storage"],
                    )
                closure_dict[conn_id] = self.connectivity_tables_data_descriptors[conn_id]

        return closure_dict

    def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]:
        # (argument names, constant names) per dace's SDFGConvertible protocol;
        # all program parameters are arguments, there are no constants.
        return [p.id for p in self.past_stage.past_node.params], []


def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None:
    """Check that DaCe's parsed view of the arguments agrees with the GT4Py parameter types.

    Args:
        dace_parsed_args: argument descriptors/values as seen by the DaCe frontend.
        gt4py_program_args: the corresponding GT4Py program parameter types.

    Raises:
        ValueError: if a parsed argument does not fall into any known category.
    """
    # strict=False: dace does not see implicit size args, so lengths may differ.
    for parsed_arg, program_arg in zip(dace_parsed_args, gt4py_program_args, strict=False):
        # The isinstance checks below mirror the original match-case order;
        # ordering matters because bool is a subclass of int and
        # OrderedDict is a subclass of dict.
        if isinstance(parsed_arg, dace.data.Scalar):
            assert parsed_arg.dtype == dace_utils.as_dace_type(program_arg)
        elif isinstance(parsed_arg, (bool, np.bool_)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind == ts.ScalarKind.BOOL
        elif isinstance(parsed_arg, (int, np.integer)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64]
        elif isinstance(parsed_arg, (float, np.floating)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64]
        elif isinstance(parsed_arg, (str, np.str_)):
            assert isinstance(program_arg, ts.ScalarType)
            assert program_arg.kind == ts.ScalarKind.STRING
        elif isinstance(parsed_arg, dace.data.Array):
            assert isinstance(program_arg, ts.FieldType)
            assert isinstance(program_arg.dtype, ts.ScalarType)
            assert len(parsed_arg.shape) == len(program_arg.dims)
            assert parsed_arg.dtype == dace_utils.as_dace_type(program_arg.dtype)
        elif isinstance(parsed_arg, (dace.data.Structure, dict, collections.OrderedDict)):
            # offset provider
            pass
        else:
            raise ValueError(f"Unresolved case for {parsed_arg} (==, !=) {program_arg}")
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class DaCeTranslator(
):
device_type: core_defs.DeviceType
auto_optimize: bool
itir_transforms_off: bool = False

def _language_settings(self) -> languages.LanguageSettings:
return languages.LanguageSettings(
Expand All @@ -51,7 +52,8 @@ def generate_sdfg(
auto_opt: bool,
on_gpu: bool,
) -> dace.SDFG:
ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider)
if not self.itir_transforms_off:
ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider)
sdfg = gtir_sdfg.build_sdfg_from_gtir(
ir, offset_provider_type=common.offset_provider_to_type(offset_provider)
)
Expand Down
Loading
Loading