diff --git a/pyop2/compilation.py b/pyop2/compilation.py index ecca43187..32f743c2f 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -44,8 +44,7 @@ from packaging.version import Version, InvalidVersion -from pyop2.mpi import MPI, collective, COMM_WORLD -from pyop2.mpi import dup_comm, get_compilation_comm, set_compilation_comm +from pyop2 import mpi from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -59,7 +58,7 @@ def _check_hashes(x, y, datatype): return False -_check_op = MPI.Op.Create(_check_hashes, commute=True) +_check_op = mpi.MPI.Op.Create(_check_hashes, commute=True) _compiler = None @@ -148,53 +147,6 @@ def sniff_compiler(exe): return compiler -@collective -def compilation_comm(comm): - """Get a communicator for compilation. - - :arg comm: The input communicator. - :returns: A communicator used for compilation (may be smaller) - """ - # Should we try and do node-local compilation? 
- if not configuration["node_local_compilation"]: - return comm - retcomm = get_compilation_comm(comm) - if retcomm is not None: - debug("Found existing compilation communicator") - return retcomm - if MPI.VERSION >= 3: - debug("Creating compilation communicator using MPI_Split_type") - retcomm = comm.Split_type(MPI.COMM_TYPE_SHARED) - debug("Finished creating compilation communicator using MPI_Split_type") - set_compilation_comm(comm, retcomm) - return retcomm - debug("Creating compilation communicator using MPI_Split + filesystem") - import tempfile - if comm.rank == 0: - if not os.path.exists(configuration["cache_dir"]): - os.makedirs(configuration["cache_dir"], exist_ok=True) - tmpname = tempfile.mkdtemp(prefix="rank-determination-", - dir=configuration["cache_dir"]) - else: - tmpname = None - tmpname = comm.bcast(tmpname, root=0) - if tmpname is None: - raise CompilationError("Cannot determine sharedness of filesystem") - # Touch file - debug("Made tmpdir %s" % tmpname) - with open(os.path.join(tmpname, str(comm.rank)), "wb"): - pass - comm.barrier() - import glob - ranks = sorted(int(os.path.basename(name)) - for name in glob.glob("%s/[0-9]*" % tmpname)) - debug("Creating compilation communicator using filesystem colors") - retcomm = comm.Split(color=min(ranks), key=comm.rank) - debug("Finished creating compilation communicator using filesystem colors") - set_compilation_comm(comm, retcomm) - return retcomm - - class Compiler(ABC): """A compiler for shared libraries. @@ -210,7 +162,7 @@ class Compiler(ABC): :arg cpp: Should we try and use the C++ compiler instead of the C compiler?. :kwarg comm: Optional communicator to compile the code on - (defaults to COMM_WORLD). + (defaults to pyop2.mpi.COMM_WORLD). 
""" _name = "unknown" @@ -226,16 +178,24 @@ class Compiler(ABC): _debugflags = () def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): + # Get compiler version ASAP since it is used in __repr__ + self.sniff_compiler_version() + self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) self._cpp = cpp self._debug = configuration["debug"] - # Ensure that this is an internal communicator. - comm = dup_comm(comm or COMM_WORLD) - self.comm = compilation_comm(comm) - self.sniff_compiler_version() + # Compilation communicators are reference counted on the PyOP2 comm + self.pcomm = mpi.internal_comm(comm) + self.comm = mpi.compilation_comm(self.pcomm) + + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) + if hasattr(self, "pcomm"): + mpi.decref(self.pcomm) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" @@ -313,7 +273,7 @@ def expandWl(ldflags): else: yield flag - @collective + @mpi.collective def get_so(self, jitmodule, extension): """Build a shared library and load it @@ -591,7 +551,7 @@ class AnonymousCompiler(Compiler): _name = "Unknown" -@collective +@mpi.collective def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. @@ -608,7 +568,7 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), :arg restype: The return type of the function (optional, pass ``None`` for ``void``). :kwarg comm: Optional communicator to compile the code on (only - rank 0 compiles code) (defaults to COMM_WORLD). + rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). 
""" from pyop2.global_kernel import GlobalKernel @@ -639,6 +599,7 @@ def __init__(self, code, argtypes): exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe) dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) + if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) diff --git a/pyop2/logger.py b/pyop2/logger.py index fb6532746..2e58e3446 100644 --- a/pyop2/logger.py +++ b/pyop2/logger.py @@ -40,6 +40,7 @@ handler = logging.StreamHandler() logger.addHandler(handler) + debug = logger.debug info = logger.info warning = logger.warning diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 1ee16c11d..66fa10f88 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,16 +37,28 @@ from petsc4py import PETSc from mpi4py import MPI # noqa import atexit +import gc +import glob +import os +import tempfile + +from pyop2.configuration import configuration +from pyop2.exceptions import CompilationError +from pyop2.logger import warning, debug, logger, DEBUG from pyop2.utils import trim -__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "dup_comm") +__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref", "temp_internal_comm") # These are user-level communicators, we never send any messages on # them inside PyOP2. -COMM_WORLD = PETSc.COMM_WORLD.tompi4py() +COMM_WORLD = PETSc.COMM_WORLD.tompi4py().Dup() +COMM_WORLD.Set_name("PYOP2_COMM_WORLD") -COMM_SELF = PETSc.COMM_SELF.tompi4py() +COMM_SELF = PETSc.COMM_SELF.tompi4py().Dup() +COMM_SELF.Set_name("PYOP2_COMM_SELF") + +PYOP2_FINALIZED = False # Exposition: # @@ -90,13 +102,22 @@ # outstanding duplicated communicators. +def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. 
+ """) + fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra + return fn + + def delcomm_outer(comm, keyval, icomm): """Deleter for internal communicator, removes reference to outer comm. :arg comm: Outer communicator. :arg keyval: The MPI keyval, should be ``innercomm_keyval``. :arg icomm: The inner communicator, should have a reference to - ``comm`. + ``comm``. """ if keyval != innercomm_keyval: raise ValueError("Unexpected keyval") @@ -118,63 +139,260 @@ def delcomm_outer(comm, keyval, icomm): # Outer communicator attribute (attaches user comm to inner communicator) outercomm_keyval = MPI.Comm.Create_keyval() +# Comm used for compilation, stashed on the internal communicator +compilationcomm_keyval = MPI.Comm.Create_keyval() + # List of internal communicators, must be freed at exit. dupped_comms = [] -def dup_comm(comm_in=None): - """Given a communicator return a communicator for internal use. +def is_pyop2_comm(comm): + """Returns ``True`` if ``comm`` is a PyOP2 communicator, + ``False`` if ``comm`` is another communicator. + Raises exception if ``comm`` is not a communicator. - :arg comm_in: Communicator to duplicate. If not provided, - defaults to COMM_WORLD. 
- - :returns: An mpi4py communicator.""" - if comm_in is None: - comm_in = COMM_WORLD - if isinstance(comm_in, PETSc.Comm): - comm_in = comm_in.tompi4py() - elif not isinstance(comm_in, MPI.Comm): - raise ValueError("Don't know how to dup a %r" % type(comm_in)) - if comm_in == MPI.COMM_NULL: - return comm_in - refcount = comm_in.Get_attr(refcount_keyval) - if refcount is not None: - # Passed an existing PyOP2 comm, return it - comm_out = comm_in - refcount[0] += 1 + :arg comm: Communicator to query + """ + global PYOP2_FINALIZED + if isinstance(comm, PETSc.Comm): + ispyop2comm = False + elif comm == MPI.COMM_NULL: + if not PYOP2_FINALIZED: + raise ValueError("Communicator passed to is_pyop2_comm() is COMM_NULL") + else: + ispyop2comm = True + elif isinstance(comm, MPI.Comm): + ispyop2comm = bool(comm.Get_attr(refcount_keyval)) else: - # Check if communicator has an embedded PyOP2 comm. - comm_out = comm_in.Get_attr(innercomm_keyval) - if comm_out is None: - # Haven't seen this comm before, duplicate it. - comm_out = comm_in.Dup() - comm_in.Set_attr(innercomm_keyval, comm_out) - comm_out.Set_attr(outercomm_keyval, comm_in) - # Refcount - comm_out.Set_attr(refcount_keyval, [1]) - # Remember we need to destroy it. 
- dupped_comms.append(comm_out) + raise ValueError(f"Argument passed to is_pyop2_comm() is a {type(comm)}, which is not a recognised comm type") + return ispyop2comm + + +def pyop2_comm_status(): + """ Prints the reference counts for all comms PyOP2 has duplicated + """ + status_string = 'PYOP2 Communicator reference counts:\n' + status_string += '| Communicator name | Count |\n' + status_string += '==================================================\n' + for comm in dupped_comms: + if comm == MPI.COMM_NULL: + null = 'COMM_NULL' + status_string += f'| {null:39}| {0:5d} |\n' else: - refcount = comm_out.Get_attr(refcount_keyval) + refcount = comm.Get_attr(refcount_keyval)[0] if refcount is None: - raise ValueError("Inner comm without a refcount") - refcount[0] += 1 - return comm_out + refcount = -999 + status_string += f'| {comm.name:39}| {refcount:5d} |\n' + return status_string -# Comm used for compilation, stashed on the internal communicator -compilationcomm_keyval = MPI.Comm.Create_keyval() +class temp_internal_comm: + """ Use a PyOP2 internal communicator and + increment and decrement the internal comm. + :arg comm: Any communicator + """ + def __init__(self, comm): + self.user_comm = comm + self.internal_comm = internal_comm(self.user_comm) + + def __del__(self): + decref(self.internal_comm) + + def __enter__(self): + """ Returns an internal comm that will be safely decref'd + when the context manager is destroyed + + :returns pyop2_comm: A PyOP2 internal communicator + """ + return self.internal_comm + + def __exit__(self, exc_type, exc_value, traceback): + pass + + +def internal_comm(comm): + """ Creates an internal comm from the user comm. 
+ If comm is None, create an internal communicator from COMM_WORLD + :arg comm: A communicator or None + + :returns pyop2_comm: A PyOP2 internal communicator + """ + # Parse inputs + if comm is None: + # None will be the default when creating most objects + comm = COMM_WORLD + elif isinstance(comm, PETSc.Comm): + comm = comm.tompi4py() + + # Check for invalid inputs + if comm == MPI.COMM_NULL: + raise ValueError("MPI_COMM_NULL passed to internal_comm()") + elif not isinstance(comm, MPI.Comm): + raise ValueError("Don't know how to dup a %r" % type(comm)) + + # Handle a valid input + if is_pyop2_comm(comm): + incref(comm) + pyop2_comm = comm + else: + pyop2_comm = dup_comm(comm) + return pyop2_comm + + +def incref(comm): + """ Increment communicator reference count + """ + assert is_pyop2_comm(comm) + refcount = comm.Get_attr(refcount_keyval) + refcount[0] += 1 + + +def decref(comm): + """ Decrement communicator reference count + """ + if not PYOP2_FINALIZED: + assert is_pyop2_comm(comm) + refcount = comm.Get_attr(refcount_keyval) + refcount[0] -= 1 + if refcount[0] == 0: + free_comm(comm) + elif comm != MPI.COMM_NULL: + free_comm(comm) + + +def dup_comm(comm_in): + """Given a communicator return a communicator for internal use. + + :arg comm_in: Communicator to duplicate + + :returns internal_comm: An internal (PyOP2) communicator.""" + assert not is_pyop2_comm(comm_in) + + # Check if communicator has an embedded PyOP2 comm. + internal_comm = comm_in.Get_attr(innercomm_keyval) + if internal_comm is None: + # Haven't seen this comm before, duplicate it. + internal_comm = comm_in.Dup() + comm_in.Set_attr(innercomm_keyval, internal_comm) + internal_comm.Set_attr(outercomm_keyval, comm_in) + # Name + internal_comm.Set_name(f"{comm_in.name or comm_in.py2f()}_DUP") + # Refcount + internal_comm.Set_attr(refcount_keyval, [0]) + incref(internal_comm) + # Remember we need to destroy it. 
+ dupped_comms.append(internal_comm) + elif is_pyop2_comm(internal_comm): + # Inner comm is a PyOP2 comm, return it + incref(internal_comm) + else: + raise ValueError("Inner comm is not a PyOP2 comm") + return internal_comm + + +@collective +def create_split_comm(comm): + """ Create a split communicator based on either shared memory access + if using MPI >= 3, or shared local disk access if using MPI < 3. + Used internally for creating compilation communicators + + :arg comm: A communicator to split + + :return split_comm: A split communicator + """ + if MPI.VERSION >= 3: + debug("Creating compilation communicator using MPI_Split_type") + split_comm = comm.Split_type(MPI.COMM_TYPE_SHARED) + debug("Finished creating compilation communicator using MPI_Split_type") + else: + debug("Creating compilation communicator using MPI_Split + filesystem") + if comm.rank == 0: + if not os.path.exists(configuration["cache_dir"]): + os.makedirs(configuration["cache_dir"], exist_ok=True) + tmpname = tempfile.mkdtemp(prefix="rank-determination-", + dir=configuration["cache_dir"]) + else: + tmpname = None + tmpname = comm.bcast(tmpname, root=0) + if tmpname is None: + raise CompilationError("Cannot determine sharedness of filesystem") + # Touch file + debug("Made tmpdir %s" % tmpname) + with open(os.path.join(tmpname, str(comm.rank)), "wb"): + pass + comm.barrier() + ranks = sorted(int(os.path.basename(name)) + for name in glob.glob("%s/[0-9]*" % tmpname)) + debug("Creating compilation communicator using filesystem colors") + split_comm = comm.Split(color=min(ranks), key=comm.rank) + debug("Finished creating compilation communicator using filesystem colors") + # Name + split_comm.Set_name(f"{comm.name or comm.py2f()}_COMPILATION") + # Refcount + split_comm.Set_attr(refcount_keyval, [0]) + incref(split_comm) + return split_comm def get_compilation_comm(comm): return comm.Get_attr(compilationcomm_keyval) -def set_compilation_comm(comm, inner): - comm.Set_attr(compilationcomm_keyval, 
inner) +def set_compilation_comm(comm, comp_comm): + """Stash the compilation communicator (``comp_comm``) on the + PyOP2 communicator ``comm`` + + :arg comm: A PyOP2 Communicator + :arg comp_comm: The compilation communicator + """ + if not is_pyop2_comm(comm): + raise ValueError("Compilation communicator must be stashed on a PyOP2 comm") + + # Check if the compilation communicator is already set + old_comp_comm = comm.Get_attr(compilationcomm_keyval) + if old_comp_comm is not None: + if not is_pyop2_comm(old_comp_comm): + raise ValueError("Compilation communicator is not a PyOP2 comm, something is very broken!") + else: + decref(old_comp_comm) + + if not is_pyop2_comm(comp_comm): + raise ValueError( + "Communicator used for compilation communicator must be a PyOP2 communicator.\n" + "Use pyop2.mpi.dup_comm() to create a PyOP2 comm from an existing comm.") + else: + # Stash `comp_comm` as an attribute on `comm` + comm.Set_attr(compilationcomm_keyval, comp_comm) + +@collective +def compilation_comm(comm): + """Get a communicator for compilation. -def free_comm(comm, remove=True): + :arg comm: The input communicator, must be a PyOP2 comm. + :returns: A communicator used for compilation (may be smaller) + """ + if not is_pyop2_comm(comm): + raise ValueError("Compilation communicator is not a PyOP2 comm") + # Should we try and do node-local compilation? + if configuration["node_local_compilation"]: + retcomm = get_compilation_comm(comm) + if retcomm is not None: + debug("Found existing compilation communicator") + debug(f"{retcomm.name}") + else: + retcomm = create_split_comm(comm) + set_compilation_comm(comm, retcomm) + # Add to list of known duplicated comms + debug(f"Appending compiler comm {retcomm.name} to list of comms") + dupped_comms.append(retcomm) + else: + retcomm = comm + incref(retcomm) + return retcomm + + +def free_comm(comm): """Free an internal communicator. :arg comm: The communicator to free. 
@@ -183,21 +401,8 @@ def free_comm(comm, remove=True): This only actually calls MPI_Comm_free once the refcount drops to zero. """ - if comm == MPI.COMM_NULL: - return - refcount = comm.Get_attr(refcount_keyval) - if refcount is None: - # Not a PyOP2 communicator, check for an embedded comm. - comm = comm.Get_attr(innercomm_keyval) - if comm is None: - raise ValueError("Trying to destroy communicator not known to PyOP2") - refcount = comm.Get_attr(refcount_keyval) - if refcount is None: - raise ValueError("Inner comm without a refcount") - - refcount[0] -= 1 - - if refcount[0] == 0: + if comm != MPI.COMM_NULL: + assert is_pyop2_comm(comm) ocomm = comm.Get_attr(outercomm_keyval) if ocomm is not None: icomm = ocomm.Get_attr(innercomm_keyval) @@ -206,44 +411,58 @@ def free_comm(comm, remove=True): else: ocomm.Delete_attr(innercomm_keyval) del icomm - if remove: - # Only do this if not called from free_comms. + try: dupped_comms.remove(comm) + except ValueError: + debug(f"{comm.name} is not in list of known comms, probably already freed") + debug(f"Known comms are {[d.name for d in dupped_comms if d != MPI.COMM_NULL]}") compilation_comm = get_compilation_comm(comm) - if compilation_comm is not None: - compilation_comm.Free() + if compilation_comm == MPI.COMM_NULL: + comm.Delete_attr(compilationcomm_keyval) + elif compilation_comm is not None: + free_comm(compilation_comm) + comm.Delete_attr(compilationcomm_keyval) comm.Free() + else: + warning('Attempt to free MPI_COMM_NULL') @atexit.register -def free_comms(): +def _free_comms(): """Free all outstanding communicators.""" + global PYOP2_FINALIZED + PYOP2_FINALIZED = True + if logger.level > DEBUG: + debug = lambda string: None + else: + debug = lambda string: print(string) + debug("PyOP2 Finalizing") + # Collect garbage as it may hold on to communicator references + debug("Calling gc.collect()") + gc.collect() + debug(pyop2_comm_status()) + debug(f"Freeing comms in list (length {len(dupped_comms)})") while 
dupped_comms: - c = dupped_comms.pop() + c = dupped_comms[-1] refcount = c.Get_attr(refcount_keyval) - for _ in range(refcount[0]): - free_comm(c, remove=False) + debug(f"Freeing {c.name}, which has refcount {refcount[0]}") + free_comm(c) for kv in [refcount_keyval, innercomm_keyval, outercomm_keyval, compilationcomm_keyval]: MPI.Comm.Free_keyval(kv) + COMM_WORLD.Free() + COMM_SELF.Free() def hash_comm(comm): """Return a hashable identifier for a communicator.""" - # dup_comm returns a persistent internal communicator so we can - # use its id() as the hash since this is stable between invocations. - return id(dup_comm(comm)) - - -def collective(fn): - extra = trim(""" - This function is logically collective over MPI ranks, it is an - error to call it on fewer than all the ranks in MPI communicator. - """) - fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra - return fn + if not is_pyop2_comm(comm): + raise ValueError("`comm` passed to `hash_comm()` must be a PyOP2 communicator") + # `comm` must be a PyOP2 communicator so we can use its id() + # as the hash and this is stable between invocations. 
+ return id(comm) # Install an exception hook to MPI Abort if an exception isn't caught diff --git a/pyop2/op2.py b/pyop2/op2.py index 1fe7f9d8a..726168e79 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -69,6 +69,9 @@ _initialised = False +# set the log level +set_log_level(configuration['log_level']) + def initialised(): """Check whether PyOP2 has been yet initialised but not yet finalised.""" @@ -101,7 +104,6 @@ def init(**kwargs): configuration.reconfigure(**kwargs) set_log_level(configuration['log_level']) - _initialised = True diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 0ba340ee4..ac78e6bda 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -150,11 +150,12 @@ def __init__(self, global_knl, iterset, arguments): self.global_kernel = global_knl self.iterset = iterset + self.comm = mpi.internal_comm(iterset.comm) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) - @property - def comm(self): - return self.iterset.comm + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) @property def local_kernel(self): @@ -454,8 +455,7 @@ def _check_frozen_access_modes(cls, local_knl, arguments): "Dats with frozen halos must always be accessed with the same access mode" ) - @classmethod - def prepare_reduced_globals(cls, arguments, global_knl): + def prepare_reduced_globals(self, arguments, global_knl): """Swap any :class:`GlobalParloopArg` instances that are INC'd into with zeroed replacements. 
@@ -465,9 +465,9 @@ def prepare_reduced_globals(cls, arguments, global_knl): """ arguments = list(arguments) reduced_globals = {} - for i, (lk_arg, gk_arg, pl_arg) in enumerate(cls.zip_arguments(global_knl, arguments)): + for i, (lk_arg, gk_arg, pl_arg) in enumerate(self.zip_arguments(global_knl, arguments)): if isinstance(gk_arg, GlobalKernelArg) and lk_arg.access == Access.INC: - tmp = Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype) + tmp = Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype, comm=self.comm) reduced_globals[tmp] = pl_arg arguments[i] = GlobalParloopArg(tmp) diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 03df1937b..615a2f82c 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -81,13 +81,17 @@ def __init__(self, dataset, data=None, dtype=None, name=None): EmptyDataMixin.__init__(self, data, dtype, self._shape) self._dataset = dataset - self.comm = dataset.comm + self.comm = mpi.internal_comm(dataset.comm) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) self._halo_frozen = False self._frozen_access_mode = None + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) + @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) @@ -768,7 +772,7 @@ def what(x): if not all(d.dtype == self._dats[0].dtype for d in self._dats): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. 
MixedSet) - self.comm = self._dats[0].comm + self.comm = mpi.internal_comm(self._dats[0].comm) @property def dat_version(self): diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 635b130e3..4e114032a 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -29,12 +29,19 @@ def __init__(self, iter_set, dim=1, name=None): return if isinstance(iter_set, Subset): raise NotImplementedError("Deriving a DataSet from a Subset is unsupported") + self.comm = mpi.internal_comm(iter_set.comm) self._set = iter_set self._dim = utils.as_tuple(dim, numbers.Integral) self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True + def __del__(self): + # Cannot use hasattr here, since we define `__getattr__` + # This causes infinite recursion when looked up! + if "comm" in self.__dict__: + mpi.decref(self.comm) + @classmethod def _process_args(cls, *args, **kwargs): return (args[0], ) + args, kwargs @@ -59,7 +66,6 @@ def __setstate__(self, d): def __getattr__(self, name): """Returns a Set specific attribute.""" value = getattr(self.set, name) - setattr(self, name, value) return value def __getitem__(self, idx): @@ -202,10 +208,13 @@ class GlobalDataSet(DataSet): def __init__(self, global_): """ :param global_: The :class:`Global` on which this object is based.""" - + if self._initialized: + return self._global = global_ + self.comm = mpi.internal_comm(global_.comm) self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) + self._initialized = True @classmethod def _cache_key(cls, *args): @@ -227,11 +236,6 @@ def name(self): """Returns the name of the data set.""" return self._global._name - @utils.cached_property - def comm(self): - """Return the communicator on which the set is defined.""" - return self._global.comm - @utils.cached_property def set(self): """Returns the parent set of the data set.""" @@ -371,6 +375,13 @@ def __init__(self, arg, dims=None): if self._initialized: return 
self._dsets = arg + try: + # Try to choose the comm to be the same as the first set + # of the MixedDataSet + comm = self._process_args(arg, dims)[0][0].comm + except AttributeError: + comm = None + self.comm = mpi.internal_comm(comm) self._initialized = True @classmethod diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 883d99914..427118431 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -44,16 +44,20 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): self.__init__(dim._dim, None, dtype=dim.dtype, name="copy_of_%s" % dim.name, comm=dim.comm) dim.copy(self) - return - self._dim = utils.as_tuple(dim, int) - self._cdim = np.prod(self._dim).item() - EmptyDataMixin.__init__(self, data, dtype, self._dim) - self._buf = np.empty(self.shape, dtype=self.dtype) - self._name = name or "global_#x%x" % id(self) - self.comm = comm - # Object versioning setup - petsc_counter = (self.comm and self.dtype == PETSc.ScalarType) - VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + else: + self._dim = utils.as_tuple(dim, int) + self._cdim = np.prod(self._dim).item() + EmptyDataMixin.__init__(self, data, dtype, self._dim) + self._buf = np.empty(self.shape, dtype=self.dtype) + self._name = name or "global_#x%x" % id(self) + self.comm = mpi.internal_comm(comm) + # Object versioning setup + petsc_counter = (comm and self.dtype == PETSc.ScalarType) + VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) @utils.cached_property def _kernel_args_(self): diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 7eedbdc50..91224d52a 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -10,6 +10,7 @@ exceptions as ex, utils ) +from pyop2 import mpi from pyop2.types.set import GlobalSet, MixedSet, Set @@ -35,7 +36,7 @@ class Map: def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None): self._iterset = iterset 
self._toset = toset - self.comm = toset.comm + self.comm = mpi.internal_comm(toset.comm) self._arity = arity self._values = utils.verify_reshape(values, dtypes.IntType, (iterset.total_size, arity), allow_none=True) @@ -52,6 +53,10 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o # A cache for objects built on top of this map self._cache = {} + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) + @utils.cached_property def _kernel_args_(self): return (self._values.ctypes.data, ) @@ -195,6 +200,7 @@ def __init__(self, map_, permutation): if isinstance(map_, ComposedMap): raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing") self.map_ = map_ + self.comm = mpi.internal_comm(map_.comm) self.permutation = np.asarray(permutation, dtype=Map.dtype) assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all() @@ -245,7 +251,7 @@ def __init__(self, *maps_, name=None): raise ex.MapTypeError("frommap.arity must be 1") self._iterset = maps_[-1].iterset self._toset = maps_[0].toset - self.comm = self._toset.comm + self.comm = mpi.internal_comm(self._toset.comm) self._arity = maps_[0].arity # Don't call super().__init__() to avoid calling verify_reshape() self._values = None @@ -309,7 +315,7 @@ def __init__(self, maps): raise ex.MapTypeError("All maps needs to share a communicator") if len(comms) == 0: raise ex.MapTypeError("Don't know how to make communicator") - self.comm = comms[0] + self.comm = mpi.internal_comm(comms[0]) self._initialized = True @classmethod diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index de89b1421..aefd77de1 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -68,11 +68,11 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._o_nnz = None self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size self._ncols = None if isinstance(dsets[1], 
GlobalDataSet) else self._cmaps[0].toset.size - self.lcomm = dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm - self.rcomm = dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm + self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm) + self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm) else: - self.lcomm = self._rmaps[0].comm - self.rcomm = self._cmaps[0].comm + self.lcomm = mpi.internal_comm(self._rmaps[0].comm) + self.rcomm = mpi.internal_comm(self._cmaps[0].comm) rset, cset = self.dsets # All rmaps and cmaps have the same data set - just use the first. @@ -93,10 +93,8 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self.lcomm != self.rcomm: raise ValueError("Haven't thought hard enough about different left and right communicators") - self.comm = self.lcomm - + self.comm = mpi.internal_comm(self.lcomm) self._name = name or "sparsity_#x%x" % id(self) - self.iteration_regions = iteration_regions # If the Sparsity is defined on MixedDataSets, we need to build each # block separately @@ -131,6 +129,14 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._blocks = [[self]] self._initialized = True + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) + if hasattr(self, "lcomm"): + mpi.decref(self.lcomm) + if hasattr(self, "rcomm"): + mpi.decref(self.rcomm) + _cache = {} @classmethod @@ -363,6 +369,10 @@ class SparsityBlock(Sparsity): This class only implements the properties necessary to infer its shape. 
It does not provide arrays of non zero fill.""" def __init__(self, parent, i, j): + # Protect against re-initialization when retrieved from cache + if self._initialized: + return + self._dsets = (parent.dsets[0][i], parent.dsets[1][j]) self._rmaps = tuple(m.split[i] for m in parent.rmaps) self._cmaps = tuple(m.split[j] for m in parent.cmaps) @@ -373,10 +383,11 @@ def __init__(self, parent, i, j): self._dims = tuple([tuple([parent.dims[i][j]])]) self._blocks = [[self]] self.iteration_regions = parent.iteration_regions - self.lcomm = self.dsets[0].comm - self.rcomm = self.dsets[1].comm + self.lcomm = mpi.internal_comm(self.dsets[0].comm) + self.rcomm = mpi.internal_comm(self.dsets[1].comm) # TODO: think about lcomm != rcomm - self.comm = self.lcomm + self.comm = mpi.internal_comm(self.lcomm) + self._initialized = True @classmethod def _process_args(cls, *args, **kwargs): @@ -434,14 +445,22 @@ class AbstractMat(DataCarrier, abc.ABC): ('name', str, ex.NameTypeError)) def __init__(self, sparsity, dtype=None, name=None): self._sparsity = sparsity - self.lcomm = sparsity.lcomm - self.rcomm = sparsity.rcomm - self.comm = sparsity.comm + self.lcomm = mpi.internal_comm(sparsity.lcomm) + self.rcomm = mpi.internal_comm(sparsity.rcomm) + self.comm = mpi.internal_comm(sparsity.comm) dtype = dtype or dtypes.ScalarType self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) self.assembly_state = Mat.ASSEMBLED + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) + if hasattr(self, "lcomm"): + mpi.decref(self.lcomm) + if hasattr(self, "rcomm"): + mpi.decref(self.rcomm) + @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path, lgmaps=None, unroll_map=False): from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg @@ -939,7 +958,7 @@ def __init__(self, parent, i, j): colis = cset.local_ises[j] self.handle = parent.handle.getLocalSubMatrix(isrow=rowis, iscol=colis) - self.comm = parent.comm + self.comm = 
mpi.internal_comm(parent.comm) self.local_to_global_maps = self.handle.getLGMap() @property @@ -1094,7 +1113,8 @@ def mult(self, mat, x, y): a[0] = x.array_r else: x.array_r - x.comm.tompi4py().bcast(a) + with mpi.temp_internal_comm(x.comm) as comm: + comm.bcast(a) return y.scale(a) else: return v.pointwiseMult(x, y) @@ -1110,7 +1130,8 @@ def multTranspose(self, mat, x, y): a[0] = x.array_r else: x.array_r - x.comm.tompi4py().bcast(a) + with mpi.temp_internal_comm(x.comm) as comm: + comm.bcast(a) y.scale(a) else: v.pointwiseMult(x, y) @@ -1134,7 +1155,8 @@ def multTransposeAdd(self, mat, x, y, z): a[0] = x.array_r else: x.array_r - x.comm.tompi4py().bcast(a) + with mpi.temp_internal_comm(x.comm) as comm: + comm.bcast(a) if y == z: # Last two arguments are aliased. tmp = y.duplicate() diff --git a/pyop2/types/set.py b/pyop2/types/set.py index fed118b1c..1f6ea30c8 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -65,7 +65,7 @@ def _wrapper_cache_key_(self): @utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError), ('name', str, ex.NameTypeError)) def __init__(self, size, name=None, halo=None, comm=None): - self.comm = mpi.dup_comm(comm) + self.comm = mpi.internal_comm(comm) if isinstance(size, numbers.Integral): size = [size] * 3 size = utils.as_tuple(size, numbers.Integral, 3) @@ -78,6 +78,12 @@ def __init__(self, size, name=None, halo=None, comm=None): # A cache of objects built on top of this set self._cache = {} + def __del__(self): + # Cannot use hasattr here, since child classes define `__getattr__` + # This causes infinite recursion when looked up! + if "comm" in self.__dict__: + mpi.decref(self.comm) + @utils.cached_property def core_size(self): """Core set size. 
Owned elements not touching halo elements.""" @@ -219,7 +225,7 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - self.comm = mpi.dup_comm(comm) + self.comm = mpi.internal_comm(comm) self._cache = {} @utils.cached_property @@ -304,6 +310,7 @@ class ExtrudedSet(Set): @utils.validate_type(('parent', Set, TypeError)) def __init__(self, parent, layers, extruded_periodic=False): self._parent = parent + self.comm = mpi.internal_comm(parent.comm) try: layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2)) self.constant_layers = False @@ -341,7 +348,6 @@ def _wrapper_cache_key_(self): def __getattr__(self, name): """Returns a :class:`Set` specific attribute.""" value = getattr(self._parent, name) - setattr(self, name, value) return value def __contains__(self, set): @@ -385,6 +391,8 @@ class Subset(ExtrudedSet): @utils.validate_type(('superset', Set, TypeError), ('indices', (list, tuple, np.ndarray), TypeError)) def __init__(self, superset, indices): + self.comm = mpi.internal_comm(superset.comm) + # sort and remove duplicates indices = np.unique(indices) if isinstance(superset, Subset): @@ -420,7 +428,6 @@ def _argtypes_(self): def __getattr__(self, name): """Returns a :class:`Set` specific attribute.""" value = getattr(self._superset, name) - setattr(self, name, value) return value def __pow__(self, e): @@ -528,9 +535,13 @@ def __init__(self, sets): assert all(s is None or isinstance(s, GlobalSet) or ((s.layers == self._sets[0].layers).all() if s.layers is not None else True) for s in sets), \ "All components of a MixedSet must have the same number of layers." # TODO: do all sets need the same communicator? 
- self.comm = functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets)) + self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) self._initialized = True + def __del__(self): + if self._initialized and hasattr(self, "comm"): + mpi.decref(self.comm) + @utils.cached_property def _kernel_args_(self): raise NotImplementedError diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index ff103bfd2..f175bc76f 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -540,10 +540,11 @@ def myfunc(arg): """Example function to cache the outputs of.""" return {arg} - @staticmethod - def collective_key(*args): + def collective_key(self, *args): """Return a cache key suitable for use when collective over a communicator.""" - return mpi.COMM_SELF, cachetools.keys.hashkey(*args) + # Explicitly `mpi.decref(self.comm)` in any test that uses this comm + self.comm = mpi.internal_comm(mpi.COMM_SELF) + return self.comm, cachetools.keys.hashkey(*args) @pytest.fixture def cache(cls): @@ -580,6 +581,7 @@ def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir) assert obj1 == obj2 and obj1 is not obj2 assert len(cache) == 2 assert len(os.listdir(cachedir.name)) == 1 + mpi.decref(self.comm) def test_decorator_disk_cache_reuses_results(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) diff --git a/test/unit/test_matrices.py b/test/unit/test_matrices.py index a84ea1aac..34b467e21 100644 --- a/test/unit/test_matrices.py +++ b/test/unit/test_matrices.py @@ -822,7 +822,7 @@ def mat(self, request, msparsity, non_nest_mixed_sparsity): def test_mat_starts_assembled(self, mat): assert mat.assembly_state is op2.Mat.ASSEMBLED for m in mat: - assert mat.assembly_state is op2.Mat.ASSEMBLED + assert m.assembly_state is op2.Mat.ASSEMBLED def test_after_set_local_state_is_insert(self, mat): mat[0, 
0].set_local_diagonal_entries([0])