feat[next-dace]: Add support for GPU execution (#1347)
This PR adds support for GPU execution in the DaCe backend. It also introduces a build cache that maps each visited ITIR program to the corresponding compiled DaCe program.
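As a rough usage sketch of the new entry points (not part of this commit; `my_fencil`, `inp`, `out`, and `offset_provider` are hypothetical placeholders for an existing ITIR fencil and its arguments):

    from gt4py.next.program_processors.runners.dace_iterator import (
        run_dace_cpu,
        run_dace_gpu,
    )

    # CPU execution: the compiled SDFG is cached in _build_cache_cpu, keyed by
    # the ITIR program, the argument types, and the column axis.
    run_dace_cpu(my_fencil, inp, out, offset_provider=offset_provider)

    # GPU execution: non-transient arrays are placed in GPU_Global storage and
    # the compiled SDFG is cached separately in _build_cache_gpu.
    run_dace_gpu(my_fencil, inp, out, offset_provider=offset_provider)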
edopao authored Oct 16, 2023
1 parent 0d821b1 commit 6c69398
Showing 1 changed file with 91 additions and 10 deletions.
101 changes: 91 additions & 10 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -16,6 +16,8 @@

import dace
import numpy as np
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.transformation.auto import auto_optimize as autoopt

import gt4py.next.iterator.ir as itir
from gt4py.next import common
@@ -29,6 +31,14 @@
from .utility import connectivity_identifier, filter_neighbor_tables


""" Default build configuration in DaCe backend """
_build_type = "Release"
# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins
_cpu_args = (
"-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label"
)


def convert_arg(arg: Any):
if common.is_field(arg):
sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value)
@@ -85,17 +95,67 @@ def get_stride_args(
return stride_args


_build_cache_cpu: dict[int, CompiledSDFG] = {}
_build_cache_gpu: dict[int, CompiledSDFG] = {}


def get_cache_id(*cache_args) -> int:
return sum([hash(str(arg)) for arg in cache_args])


@program_executor
def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
# build parameters
auto_optimize = kwargs.get("auto_optimize", False)
build_type = kwargs.get("build_type", "RelWithDebInfo")
run_on_gpu = kwargs.get("run_on_gpu", False)
build_cache = kwargs.get("build_cache", None)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
offset_provider = kwargs["offset_provider"]
neighbor_tables = filter_neighbor_tables(offset_provider)

program = preprocess_program(program, offset_provider)
arg_types = [type_translation.from_value(arg) for arg in args]
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg: dace.SDFG = sdfg_genenerator.visit(program)
sdfg.simplify()
neighbor_tables = filter_neighbor_tables(offset_provider)

cache_id = get_cache_id(program, *arg_types, column_axis)
if build_cache is not None and cache_id in build_cache:
# retrieve SDFG program from build cache
sdfg_program = build_cache[cache_id]
sdfg = sdfg_program.sdfg
else:
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg = sdfg_genenerator.visit(program)
sdfg.simplify()

# set array storage for GPU execution
if run_on_gpu:
device = dace.DeviceType.GPU
sdfg._name = f"{sdfg.name}_gpu"
for _, _, array in sdfg.arrays_recursive():
if not array.transient:
array.storage = dace.dtypes.StorageType.GPU_Global
else:
device = dace.DeviceType.CPU

# run DaCe auto-optimization heuristics
if auto_optimize:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols)

# compile SDFG and retrieve SDFG program
sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"
with dace.config.temporary_config():
dace.config.Config.set("compiler", "build_type", value=build_type)
dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
sdfg_program = sdfg.compile(validate=False)

# store SDFG program in build cache
if build_cache is not None:
build_cache[cache_id] = sdfg_program

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
@@ -105,8 +165,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args)

sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"

all_args = {
**dace_args,
**dace_conn_args,
@@ -120,9 +178,32 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}

with dace.config.temporary_config():
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
dace.config.Config.set("compiler", "build_type", value="Debug")
dace.config.Config.set("compiler", "cpu", "args", value="-O0")
dace.config.Config.set("frontend", "check_args", value=True)
sdfg(**expected_args)
sdfg_program(**expected_args)


@program_executor
def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
run_dace_iterator(
program,
*args,
**kwargs,
build_cache=_build_cache_cpu,
build_type=_build_type,
run_on_gpu=False,
)


@program_executor
def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
run_dace_iterator(
program,
*args,
**kwargs,
build_cache=_build_cache_gpu,
build_type=_build_type,
run_on_gpu=True,
)
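For finer control, the generic executor can also be called directly with the keyword arguments read at the top of `run_dace_iterator` (a minimal sketch under the same placeholder assumptions as above; `custom_cache` is likewise hypothetical):

    from dace.codegen.compiled_sdfg import CompiledSDFG

    from gt4py.next.program_processors.runners.dace_iterator import run_dace_iterator

    # A user-managed cache mapping cache ids to compiled SDFG programs.
    custom_cache: dict[int, CompiledSDFG] = {}

    run_dace_iterator(
        my_fencil,
        inp,
        out,
        offset_provider=offset_provider,
        auto_optimize=True,           # apply DaCe auto-optimization heuristics
        build_type="RelWithDebInfo",  # default when not set by run_dace_cpu/run_dace_gpu
        run_on_gpu=False,
        build_cache=custom_cache,     # reuse the compiled SDFG across calls
    )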
