diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..25609b1035 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -16,6 +16,8 @@ import dace import numpy as np +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir from gt4py.next import common @@ -29,6 +31,14 @@ from .utility import connectivity_identifier, filter_neighbor_tables +""" Default build configuration in DaCe backend """ +_build_type = "Release" +# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins +_cpu_args = ( + "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" +) + + def convert_arg(arg: Any): if common.is_field(arg): sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) @@ -85,17 +95,67 @@ def get_stride_args( return stride_args +_build_cache_cpu: dict[int, CompiledSDFG] = {} +_build_cache_gpu: dict[int, CompiledSDFG] = {} + + +def get_cache_id(*cache_args) -> int: + return sum([hash(str(arg)) for arg in cache_args]) + + @program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: + # build parameters + auto_optimize = kwargs.get("auto_optimize", False) + build_type = kwargs.get("build_type", "RelWithDebInfo") + run_on_gpu = kwargs.get("run_on_gpu", False) + build_cache = kwargs.get("build_cache", None) + # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] - neighbor_tables = filter_neighbor_tables(offset_provider) - program = preprocess_program(program, offset_provider) arg_types = [type_translation.from_value(arg) for arg in args] - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg: dace.SDFG = sdfg_genenerator.visit(program) - sdfg.simplify() + neighbor_tables = filter_neighbor_tables(offset_provider) + + cache_id = get_cache_id(program, *arg_types, column_axis) + if build_cache is not None and cache_id in build_cache: + # retrieve SDFG program from build cache + sdfg_program = build_cache[cache_id] + sdfg = sdfg_program.sdfg + else: + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # set array storage for GPU execution + if run_on_gpu: + device = dace.DeviceType.GPU + sdfg._name = f"{sdfg.name}_gpu" + for _, _, array in sdfg.arrays_recursive(): + if not array.transient: + array.storage = dace.dtypes.StorageType.GPU_Global + else: + device = dace.DeviceType.CPU + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. + symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + + # compile SDFG and retrieve SDFG program + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "build_type", value=build_type) + dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + sdfg_program = sdfg.compile(validate=False) + + # store SDFG program in build cache + if build_cache is not None: + build_cache[cache_id] = sdfg_program dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} @@ -105,8 +165,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_strides = get_stride_args(sdfg.arrays, dace_field_args) dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" - all_args = { **dace_args, **dace_conn_args, @@ -120,9 +178,32 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: for key, value in all_args.items() if key in sdfg.signature_arglist(with_types=False) } + with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("compiler", "build_type", value="Debug") - dace.config.Config.set("compiler", "cpu", "args", value="-O0") dace.config.Config.set("frontend", "check_args", value=True) - sdfg(**expected_args) + sdfg_program(**expected_args) + + +@program_executor +def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_cpu, + build_type=_build_type, + run_on_gpu=False, + ) + + +@program_executor +def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_gpu, + build_type=_build_type, + run_on_gpu=True, + )