From 78759b56b537930a5fd3d4bdd64048960765adf1 Mon Sep 17 00:00:00 2001
From: Christos Kotsalos
Date: Tue, 16 Apr 2024 16:08:17 +0200
Subject: [PATCH] Distributed Compilation as an option to DaCe Program (#1555)

Option to activate/deactivate Distributed Compilation.

This small PR is based on the following comment (DAPP/DaCe Mattermost channel):

_I have an unexpected behaviour in DaCe distributed compilation. Currently, if you have an MPI program, distributed compilation is the default behaviour (as seen in [this file](https://github.com/spcl/dace/blob/master/dace/frontend/python/parser.py#L452)). I was expecting that after loading the compiled SDFG, every rank would perform symbol specialization. However, this is not the case: every rank uses the compiled SDFG from rank 0, which specializes its symbols with the values corresponding to rank 0. Therefore, the compiled SDFG loaded by all the other ranks is wrong (its symbols are not specialized with the values of the correct rank). To validate this behaviour, I de-activated the distributed compilation and set `dace.config.Config.set("cache", value="unique")`. Indeed, this approach works without any issue. Is there a way to change this unexpected behaviour, i.e. to keep distributed compilation as the default but have every rank perform symbol specialization? To give a bit more context, I am generating an SDFG that uses closures heavily, i.e. all the gt4py fields are defined externally to the SDFG (could that be an issue?)_
---
 dace/frontend/python/interface.py | 7 ++++++-
 dace/frontend/python/parser.py    | 8 +++++---
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py
index 69e650beaa..ecd0b164d6 100644
--- a/dace/frontend/python/interface.py
+++ b/dace/frontend/python/interface.py
@@ -42,6 +42,7 @@ def program(f: F,
             recreate_sdfg: bool = True,
             regenerate_code: bool = True,
             recompile: bool = True,
+            distributed_compilation: bool = False,
             constant_functions=False,
             **kwargs) -> Callable[..., parser.DaceProgram]:
     """
@@ -60,6 +61,9 @@ def program(f: F,
                       it.
     :param recompile: Whether to recompile the code. If False, the library in the build folder will be used
                       if it exists, without recompiling it.
+    :param distributed_compilation: Whether to compile the code from rank 0, and broadcast it to all the other ranks.
+                                    If False, every rank performs the compilation. In this case, make sure to check the
+                                    ``cache`` configuration entry such that no caching or clashes can happen between different MPI processes.
     :param constant_functions: If True, assumes all external functions that do not depend on
                                internal variables are constant.
                                This will hardcode their return values into the
@@ -78,7 +82,8 @@ def program(f: F,
                               constant_functions,
                               recreate_sdfg=recreate_sdfg,
                               regenerate_code=regenerate_code,
-                              recompile=recompile)
+                              recompile=recompile,
+                              distributed_compilation=distributed_compilation)


 function = program

diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py
index 14377c4fe2..34cb8fb4ad 100644
--- a/dace/frontend/python/parser.py
+++ b/dace/frontend/python/parser.py
@@ -151,6 +151,7 @@ def __init__(self,
                  recreate_sdfg: bool = True,
                  regenerate_code: bool = True,
                  recompile: bool = True,
+                 distributed_compilation: bool = False,
                  method: bool = False):
         from dace.codegen import compiled_sdfg  # Avoid import loops

@@ -171,6 +172,7 @@ def __init__(self,
         self.recreate_sdfg = recreate_sdfg
         self.regenerate_code = regenerate_code
         self.recompile = recompile
+        self.distributed_compilation = distributed_compilation
         self.global_vars = _get_locals_and_globals(f)
         self.signature = inspect.signature(f)

@@ -449,12 +451,12 @@ def __call__(self, *args, **kwargs):
             sdfg.simplify()

         with hooks.invoke_sdfg_call_hooks(sdfg) as sdfg:
-            if not mpi4py:
+            if self.distributed_compilation and mpi4py:
+                binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate)
+            else:
                 # Compile SDFG (note: this is done after symbol inference due to shape
                 # altering transformations such as Vectorization)
                 binaryobj = sdfg.compile(validate=self.validate)
-            else:
-                binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate)

             # Recreate key and add to cache
             cachekey = self._cache.make_key(argtypes, specified, self.closure_array_keys, self.closure_constant_keys,
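
For reference, a minimal usage sketch of the new flag (not part of the patch; the program body, array sizes, and rank-dependent symbol values are illustrative):

```python
# Illustrative sketch: per-rank compilation with the new flag.
# With distributed_compilation=False (the default introduced by this
# patch), every MPI rank compiles the SDFG itself and specializes
# symbols with its own values. A unique cache avoids build-folder
# clashes between ranks, as noted in the parameter documentation.
import dace
import numpy as np
from mpi4py import MPI

dace.config.Config.set('cache', value='unique')

N = dace.symbol('N')


@dace.program(distributed_compilation=False)
def scale(x: dace.float64[N]):
    x[:] = 2.0 * x


rank = MPI.COMM_WORLD.Get_rank()
x = np.ones(10 + rank)  # per-rank array size, so N differs per rank
scale(x)                # each rank compiles with its own value of N
```

Passing `distributed_compilation=True` restores the previous behaviour: rank 0 compiles once and the binary is broadcast to all other ranks via `distributed_compile` over `MPI.COMM_WORLD`.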