Merge remote-tracking branch 'upstream/main' into embedded_field_scan
havogt committed Nov 29, 2023
2 parents 62a6ec8 + 91307b1 commit 13fe506
Showing 17 changed files with 380 additions and 334 deletions.
4 changes: 3 additions & 1 deletion docs/development/ADRs/0015-Test_Exclusion_Matrices.md
@@ -47,10 +47,12 @@ by calling `next_tests.get_processor_id()`, which returns the so-called processor id.
The following backend processors are defined:

```python
DACE = "dace_iterator.run_dace_iterator"
DACE_CPU = "dace_iterator.run_dace_cpu"
DACE_GPU = "dace_iterator.run_dace_gpu"
GTFN_CPU = "otf_compile_executor.run_gtfn"
GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative"
GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries"
GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu"
```

Following the previous example, the GTFN backend with temporaries does not yet support dynamic offsets in ITIR:
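A minimal sketch of what such an exclusion entry could look like; the matrix name, the `XFAIL` action label, and the message constant are illustrative assumptions, not the exact contents of the test suite:

```python
XFAIL = "xfail"  # hypothetical action label
UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend"

# Hypothetical matrix: maps a backend processor id to the feature markers
# it cannot run, together with the action and message to apply.
BACKEND_SKIP_TEST_MATRIX = {
    GTFN_CPU_WITH_TEMPORARIES: [
        ("uses_dynamic_offsets", XFAIL, UNSUPPORTED_MESSAGE),
    ],
}
```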
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -332,12 +332,14 @@ markers = [
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_floordiv: tests that require backend support for floor division',
'uses_if_stmts: tests that require backend support for if-statements',
'uses_index_fields: tests that require backend support for index fields',
'uses_lift_expressions: tests that require backend support for lift expressions',
'uses_negative_modulo: tests that require backend support for modulo on negative numbers',
'uses_origin: tests that require backend support for domain origin',
'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for reduction with only sparse fields',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
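For context, markers like these are applied per test so the exclusion machinery can xfail or skip them on backends lacking the feature. A hedged sketch (the test name and body are hypothetical):

```python
import pytest


# Hypothetical test: the marker lets the test-exclusion machinery xfail or
# skip it on backend processors that lack floor-division support.
@pytest.mark.uses_floordiv
def test_floordiv_on_fields():
    ...
```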
11 changes: 10 additions & 1 deletion src/gt4py/next/ffront/decorator.py
@@ -547,6 +547,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None
operator_attributes: Optional[dict[str, Any]] = None
_program_cache: dict = dataclasses.field(default_factory=dict)

@classmethod
def from_function(
@@ -616,6 +617,13 @@ def as_program(
# of arg and kwarg types
# TODO(tehrengruber): check foast operator has no out argument that clashes
# with the out argument of the program we generate here.
hash_ = eve_utils.content_hash(
(tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items()))
)
try:
return self._program_cache[hash_]
except KeyError:
pass

loc = self.foast_node.location
param_sym_uids = eve_utils.UIDGenerator() # use a new UID generator to allow caching
@@ -669,12 +677,13 @@ def as_program(
untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars)
past_node = ProgramTypeDeduction.apply(untyped_past_node)

return Program(
self._program_cache[hash_] = Program(
past_node=past_node,
closure_vars=closure_vars,
backend=self.backend,
grid_type=self.grid_type,
)
return self._program_cache[hash_]

def __call__(
self,
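The change above memoizes `as_program` on a content hash of the argument and keyword-argument types, so repeated calls with the same type signature reuse the generated `Program`. A standalone sketch of the same pattern, with a simplified stand-in for `eve_utils.content_hash` (names here are illustrative, not gt4py API):

```python
import hashlib
from typing import Any


def content_hash(obj: Any) -> str:
    # Simplified stand-in for eve_utils.content_hash: hash the repr.
    return hashlib.sha256(repr(obj).encode()).hexdigest()


class ProgramBuilder:
    def __init__(self) -> None:
        self._program_cache: dict[str, str] = {}

    def as_program(self, arg_types: tuple, kwarg_types: dict) -> str:
        key = content_hash((arg_types, tuple(kwarg_types.items())))
        try:
            # Cache hit: reuse the program generated for these types.
            return self._program_cache[key]
        except KeyError:
            pass
        # Cache miss: in gt4py this is the expensive lowering step; a
        # formatted string stands in for the generated Program here.
        program = f"program({arg_types}, {kwarg_types})"
        self._program_cache[key] = program
        return program
```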
48 changes: 26 additions & 22 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -12,6 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import hashlib
import warnings
from typing import Any, Mapping, Optional, Sequence

import dace
@@ -22,11 +23,11 @@
import gt4py.next.allocators as next_allocators
import gt4py.next.iterator.ir as itir
import gt4py.next.program_processors.otf_compile_executor as otf_exec
import gt4py.next.program_processors.processor_interface as ppi
from gt4py.next.common import Dimension, Domain, UnitRange, is_field
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
from gt4py.next.type_system import type_specifications as ts, type_translation

from .itir_to_sdfg import ItirToSDFG
@@ -94,10 +95,26 @@ def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
return {name.id: convert_arg(arg) for name, arg in zip(params, args)}


def _ensure_is_on_device(
connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType
) -> np.typing.NDArray:
if device == dace.dtypes.DeviceType.GPU:
if not isinstance(connectivity_arg, cp.ndarray):
warnings.warn(
"Copying connectivity to device. For performance make sure connectivity is provided on device."
)
return cp.asarray(connectivity_arg)
return connectivity_arg


def get_connectivity_args(
neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]]
neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]],
device: dace.dtypes.DeviceType,
) -> dict[str, Any]:
return {connectivity_identifier(offset): table.table for offset, table in neighbor_tables}
return {
connectivity_identifier(offset): _ensure_is_on_device(table.table, device)
for offset, table in neighbor_tables
}
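A self-contained restatement of the copy-to-device pattern introduced above, with an optional `cupy` import so the sketch also runs on CPU-only machines (the function name is illustrative, not the gt4py API):

```python
import warnings

import numpy as np

try:
    import cupy as cp  # optional: only needed for GPU execution
except ImportError:
    cp = None


def ensure_on_device(array: np.ndarray, on_gpu: bool):
    """Return `array` on the requested device, warning when a copy is needed."""
    if on_gpu:
        if cp is None:
            raise RuntimeError("GPU execution requires cupy")
        if not isinstance(array, cp.ndarray):
            warnings.warn(
                "Copying connectivity to device; for performance, provide it on device."
            )
            return cp.asarray(array)
    return array
```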


def get_shape_args(
@@ -167,7 +184,6 @@ def get_cache_id(
return m.hexdigest()


@program_executor
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
# build parameters
auto_optimize = kwargs.get("auto_optimize", False)
@@ -182,6 +198,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU
neighbor_tables = filter_neighbor_tables(offset_provider)

cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
@@ -192,26 +209,16 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
else:
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu)
sdfg = sdfg_genenerator.visit(program)
sdfg.simplify()

# set array storage for GPU execution
if run_on_gpu:
device = dace.DeviceType.GPU
sdfg._name = f"{sdfg.name}_gpu"
for _, _, array in sdfg.arrays_recursive():
if not array.transient:
array.storage = dace.dtypes.StorageType.GPU_Global
else:
device = dace.DeviceType.CPU

# run DaCe auto-optimization heuristics
if auto_optimize:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols)
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu)

# compile SDFG and retrieve SDFG program
sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"
Expand All @@ -226,7 +233,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables)
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
@@ -254,7 +261,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
sdfg_program(**expected_args)


@program_executor
def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
run_dace_iterator(
program,
@@ -267,13 +273,12 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:


run_dace_cpu = otf_exec.OTFBackend(
executor=_run_dace_cpu,
executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"),
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
)

if cp:

@program_executor
def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
run_dace_iterator(
program,
@@ -286,12 +291,11 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:

else:

@program_executor
def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
raise RuntimeError("Missing `cupy` dependency for GPU execution.")


run_dace_gpu = otf_exec.OTFBackend(
executor=_run_dace_gpu,
executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"),
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
)
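With the `program_executor` decoration now folded into the `OTFBackend` construction (and given an explicit name), `run_dace_cpu` and `run_dace_gpu` remain drop-in backends. A hedged usage sketch; `my_op` and its fields are assumed to come from an existing `gt4py.next` program:

```python
from gt4py.next.program_processors.runners.dace_iterator import run_dace_cpu

# `my_op`, `inp`, and `out` are hypothetical: a field operator and fields
# defined elsewhere. `with_backend` rebinds the operator to this backend.
my_op.with_backend(run_dace_cpu)(inp, out=out, offset_provider={})
```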