Skip to content

Commit

Permalink
Distributed Compilation as an option to DaCe Program (#1555)
Browse files Browse the repository at this point in the history
Option to activate/deactivate Distributed Compilation.

This small PR is based on the following comment (DAPP/DaCe Mattermost
channel):
_I have an unexpected behaviour in DaCe distributed compilation.
Currently, if you have an MPI program, distributed compilation is the
default behaviour (as seen in [this
file](https://github.com/spcl/dace/blob/master/dace/frontend/python/parser.py#L452)).
I was expecting that after the loading of the compiled sdfg every rank
would do symbol specialization.
Although, this is not the case, i.e. every rank uses the compiled sdfg
from rank 0, which specializes its symbols with the values corresponding
to rank 0. Therefore, the compiled sdfg loaded by all the other ranks
use a wrong sdfg (symbols are not specialized with the values of the
correct rank).
To validate this behaviour, I have de-activated the distributed
compilation and set `dace.config.Config.set("cache", value="unique")`.
Indeed, this approach works without any issue.
Is there a way to change this unexpected behaviour, i.e. to have by
default the distributed compilation but every rank to perform symbol
specialization.
To give a bit more context, I am generating an sdfg that uses closures
heavily, i.e. all the gt4py fields are defined externally to the sdfg
(could that be an issue)?_
  • Loading branch information
kotsaloscv authored Apr 16, 2024
1 parent 888fd2d commit 78759b5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
7 changes: 6 additions & 1 deletion dace/frontend/python/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def program(f: F,
recreate_sdfg: bool = True,
regenerate_code: bool = True,
recompile: bool = True,
distributed_compilation: bool = False,
constant_functions=False,
**kwargs) -> Callable[..., parser.DaceProgram]:
"""
Expand All @@ -60,6 +61,9 @@ def program(f: F,
it.
:param recompile: Whether to recompile the code. If False, the library in the build folder will be used if it exists,
without recompiling it.
:param distributed_compilation: Whether to compile the code from rank 0, and broadcast it to all the other ranks.
If False, every rank performs the compilation. In this case, make sure to check the ``cache`` configuration entry
such that no caching or clashes can happen between different MPI processes.
:param constant_functions: If True, assumes all external functions that do
not depend on internal variables are constant.
This will hardcode their return values into the
Expand All @@ -78,7 +82,8 @@ def program(f: F,
constant_functions,
recreate_sdfg=recreate_sdfg,
regenerate_code=regenerate_code,
recompile=recompile)
recompile=recompile,
distributed_compilation=distributed_compilation)


function = program
Expand Down
8 changes: 5 additions & 3 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(self,
recreate_sdfg: bool = True,
regenerate_code: bool = True,
recompile: bool = True,
distributed_compilation: bool = False,
method: bool = False):
from dace.codegen import compiled_sdfg # Avoid import loops

Expand All @@ -171,6 +172,7 @@ def __init__(self,
self.recreate_sdfg = recreate_sdfg
self.regenerate_code = regenerate_code
self.recompile = recompile
self.distributed_compilation = distributed_compilation

self.global_vars = _get_locals_and_globals(f)
self.signature = inspect.signature(f)
Expand Down Expand Up @@ -449,12 +451,12 @@ def __call__(self, *args, **kwargs):
sdfg.simplify()

with hooks.invoke_sdfg_call_hooks(sdfg) as sdfg:
if not mpi4py:
if self.distributed_compilation and mpi4py:
binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate)
else:
# Compile SDFG (note: this is done after symbol inference due to shape
# altering transformations such as Vectorization)
binaryobj = sdfg.compile(validate=self.validate)
else:
binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate)

# Recreate key and add to cache
cachekey = self._cache.make_key(argtypes, specified, self.closure_array_keys, self.closure_constant_keys,
Expand Down

0 comments on commit 78759b5

Please sign in to comment.