From a57877f069a21ec4830ca1edd2d501340a1eca59 Mon Sep 17 00:00:00 2001 From: Christos Kotsalos Date: Thu, 21 Mar 2024 12:47:34 +0100 Subject: [PATCH] Add support for distributed compilation in DaceProgram (#1551) --- dace/frontend/python/parser.py | 15 ++++++++++++--- dace/sdfg/utils.py | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 1b6817a7d0..14377c4fe2 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -21,6 +21,12 @@ except ImportError: from typing_compat import get_origin, get_args +try: + import mpi4py + from dace.sdfg.utils import distributed_compile +except ImportError: + mpi4py = None + ArgTypes = Dict[str, Data] @@ -443,9 +449,12 @@ def __call__(self, *args, **kwargs): sdfg.simplify() with hooks.invoke_sdfg_call_hooks(sdfg) as sdfg: - # Compile SDFG (note: this is done after symbol inference due to shape - # altering transformations such as Vectorization) - binaryobj = sdfg.compile(validate=self.validate) + if not mpi4py: + # Compile SDFG (note: this is done after symbol inference due to shape + # altering transformations such as Vectorization) + binaryobj = sdfg.compile(validate=self.validate) + else: + binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate) # Recreate key and add to cache cachekey = self._cache.make_key(argtypes, specified, self.closure_array_keys, self.closure_constant_keys, diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index e211b50904..7311f4f028 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1355,7 +1355,7 @@ def load_precompiled_sdfg(folder: str): csdfg.ReloadableDLL(os.path.join(folder, 'build', f'lib{sdfg.name}.{suffix}'), sdfg.name)) -def distributed_compile(sdfg: SDFG, comm) -> csdfg.CompiledSDFG: +def distributed_compile(sdfg: SDFG, comm, validate: bool = True) -> csdfg.CompiledSDFG: """ Compiles an SDFG in rank 0 of MPI communicator ``comm``. Then, the compiled SDFG is loaded in all other ranks. @@ -1371,7 +1371,7 @@ def distributed_compile(sdfg: SDFG, comm) -> csdfg.CompiledSDFG: # Rank 0 compiles SDFG. if rank == 0: - func = sdfg.compile() + func = sdfg.compile(validate=validate) folder = sdfg.build_folder # Broadcasts build folder.