Skip to content

Commit

Permalink
[dace] Minor edit (2)
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Oct 10, 2023
1 parent 554f89a commit 8727628
Showing 1 changed file with 24 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,12 @@ def get_stride_args(
return stride_args


# TODO implement some hash of the FencilDefinition node as cache key, currently using the 'id' string
_build_cache: Dict[str, Tuple[dace.SDFG, CompiledSDFG]] = {}
_build_cache_cpu: Dict[int, Tuple[dace.SDFG, CompiledSDFG]] = {}
_build_cache_gpu: Dict[int, Tuple[dace.SDFG, CompiledSDFG]] = {}


def get_program_id(program: itir.FencilDefinition) -> int:
return hash(str(program))


@program_executor
Expand All @@ -105,14 +109,15 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
auto_optimize = kwargs.get("auto_optimize", False)
build_type = kwargs.get("build_type", "RelWithDebInfo")
run_on_gpu = kwargs.get("run_on_gpu", False)
use_build_cache = kwargs.get("use_build_cache", False)
build_cache = kwargs.get("build_cache", None)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
offset_provider = kwargs["offset_provider"]
neighbor_tables = filter_neighbor_tables(offset_provider)

if program.id in _build_cache:
sdfg, sdfg_func = _build_cache[program.id]
program_id = get_program_id(program)
if build_cache is not None and program_id in build_cache:
sdfg, sdfg_func = build_cache[program_id]
else:
program = preprocess_program(program, offset_provider)
arg_types = [type_translation.from_value(arg) for arg in args]
Expand Down Expand Up @@ -140,8 +145,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:
dace.config.Config.set("compiler", "build_type", value=build_type)
dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
sdfg_func = sdfg.compile(validate=False)
if use_build_cache:
_build_cache[program.id] = (sdfg, sdfg_func)
if build_cache is not None:
build_cache[program_id] = (sdfg, sdfg_func)

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
Expand Down Expand Up @@ -193,7 +198,12 @@ def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
@program_executor
def run_dace_cpu_cached(program: itir.FencilDefinition, *args, **kwargs) -> None:
run_dace_iterator(
program, *args, **kwargs, build_type=_build_type, run_on_gpu=False, use_build_cache=True
program,
*args,
**kwargs,
build_cache=_build_cache_cpu,
build_type=_build_type,
run_on_gpu=False,
)


Expand All @@ -205,5 +215,10 @@ def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
@program_executor
def run_dace_gpu_cached(program: itir.FencilDefinition, *args, **kwargs) -> None:
run_dace_iterator(
program, *args, **kwargs, build_type=_build_type, run_on_gpu=True, use_build_cache=True
program,
*args,
**kwargs,
build_cache=_build_cache_gpu,
build_type=_build_type,
run_on_gpu=True,
)

0 comments on commit 8727628

Please sign in to comment.