From 877def61c291a8db78c36dd58bb0158e3be380ca Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 8 Aug 2024 12:58:53 +0100 Subject: [PATCH 01/38] Remove hash_comm and add per-comm caches --- pyop2/caching.py | 23 +++++++++++++---------- pyop2/mpi.py | 21 ++++++++------------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 0f036212f..ef72799f6 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -41,7 +41,7 @@ import cachetools from pyop2.configuration import configuration -from pyop2.mpi import hash_comm +from pyop2.mpi import comm_cache_keyval from pyop2.utils import cached_property @@ -240,8 +240,7 @@ def cache_key(self): .. note:: If you intend to use this decorator to cache things that are collective across a communicator then you must include the communicator as part of - the cache key. Since communicators are themselves not hashable you should - use :func:`pyop2.mpi.hash_comm`. + the cache key. You should also make sure to use unbounded caches as otherwise some ranks may evict results leading to deadlocks. @@ -266,37 +265,41 @@ def decorator(func): def wrapper(*args, **kwargs): if collective: comm, disk_key = key(*args, **kwargs) - disk_key = _as_hexdigest(disk_key) - k = hash_comm(comm), disk_key + k = _as_hexdigest(disk_key) + local_cache = comm.Get_attr(comm_cache_keyval) + if local_cache is None: + local_cache = {} + comm.Set_attr(comm_cache_keyval, local_cache) else: k = _as_hexdigest(key(*args, **kwargs)) + local_cache = cache # first try the in-memory cache try: - return cache[k] + return local_cache[k] except KeyError: pass # then try to retrieve from disk if collective: if comm.rank == 0: - v = _disk_cache_get(cachedir, disk_key) + v = _disk_cache_get(cachedir, k) comm.bcast(v, root=0) else: v = comm.bcast(None, root=0) else: v = _disk_cache_get(cachedir, k) if v is not None: - return cache.setdefault(k, v) + return local_cache.setdefault(k, v) # if all else fails call func and populate the caches v = func(*args, **kwargs) if collective: if comm.rank == 0: - _disk_cache_set(cachedir, disk_key, v) + _disk_cache_set(cachedir, k, v) else: _disk_cache_set(cachedir, k, v) - return cache.setdefault(k, v) + return local_cache.setdefault(k, v) return wrapper return decorator diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 554155f20..0237433cb 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -227,6 +227,7 @@ def delcomm_outer(comm, keyval, icomm): innercomm_keyval = MPI.Comm.Create_keyval(delete_fn=delcomm_outer) outercomm_keyval = MPI.Comm.Create_keyval() compilationcomm_keyval = MPI.Comm.Create_keyval(delete_fn=delcomm_outer) +comm_cache_keyval = MPI.Comm.Create_keyval() def is_pyop2_comm(comm): @@ -539,22 +540,16 @@ def _free_comms(): debug(f"Freeing {comm.name}, with index {key}, which has refcount {refcount[0]}") comm.Free() del _DUPED_COMM_DICT[key] - for kv in [refcount_keyval, - innercomm_keyval, - outercomm_keyval, - compilationcomm_keyval]: + for kv in [ + refcount_keyval, + innercomm_keyval, + outercomm_keyval, + compilationcomm_keyval, + comm_cache_keyval + ]: MPI.Comm.Free_keyval(kv) -def hash_comm(comm): - """Return a hashable identifier for a communicator.""" - if not is_pyop2_comm(comm): - raise PyOP2CommError("`comm` passed to `hash_comm()` must be a PyOP2 communicator") - # `comm` must be a PyOP2 communicator so we can use its id() - # as the hash and this is stable between invocations. 
- return id(comm) - - # Install an exception hook to MPI Abort if an exception isn't caught # see: https://groups.google.com/d/msg/mpi4py/me2TFzHmmsQ/sSF99LE0t9QJ if COMM_WORLD.size > 1: From 767a21f7428f554a0970e86899869b49e2c145cc Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 8 Aug 2024 12:59:56 +0100 Subject: [PATCH 02/38] Update the TestDiskCachedDecorator test for collective --- test/unit/test_caching.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 40c4256fb..2ffaaf6f9 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -39,6 +39,7 @@ import numpy from pyop2 import op2, mpi from pyop2.caching import disk_cached +from pyop2.mpi import comm_cache_keyval def _seed(): @@ -555,20 +556,23 @@ def test_decorator_in_memory_cache_reuses_results(self, cache, cachedir): assert len(cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir): + def test_decorator_collective_uses_different_in_memory_caches(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) collective_func = disk_cached(cache, cachedir.name, self.collective_key, collective=True)(self.myfunc) + # obj1 should be cached on the comm cache and not the self.cache obj1 = collective_func("input1") - assert len(cache) == 1 + comm_cache = self.comm.Get_attr(comm_cache_keyval) + assert len(cache) == 0 + assert len(comm_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - # The new entry should have a different in-memory key since the communicator - # is not included but the same key on disk. + # obj2 should be cached on the self.cache and not the comm cache obj2 = decorated_func("input1") assert obj1 == obj2 and obj1 is not obj2 - assert len(cache) == 2 + assert len(cache) == 1 + assert len(comm_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 def test_decorator_disk_cache_reuses_results(self, cache, cachedir): From e35a589e5fed32902207b505d9033755c4d3f96b Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 8 Aug 2024 14:26:55 +0100 Subject: [PATCH 03/38] Refactor disk_cached wrapper and add warning --- pyop2/caching.py | 71 ++++++++++++++++++++++++--------------- test/unit/test_caching.py | 5 +-- 2 files changed, 46 insertions(+), 30 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index ef72799f6..e0d575054 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -33,12 +33,12 @@ """Provides common base classes for cached objects.""" +import cachetools import hashlib import os -from pathlib import Path import pickle - -import cachetools +from pathlib import Path +from warnings import warn from pyop2.configuration import configuration from pyop2.mpi import comm_cache_keyval @@ -261,45 +261,60 @@ def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=Fa if cachedir is None: cachedir = configuration["cache_dir"] + if collective and cache is not None: + warn( + "Global cache for collective disk cached call will not be used. " + "Pass `None` as the first argument" + ) + def decorator(func): - def wrapper(*args, **kwargs): - if collective: + if not collective: + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory then disk cache + before falling back on calling the function and populating the + caches. 
+ """ + k = _as_hexdigest(key(*args, **kwargs)) + try: + v = cache[k] + except KeyError: + v = _disk_cache_get(cachedir, k) + + if v is None: + v = func(*args, **kwargs) + _disk_cache_set(cachedir, k, v) + return cache.setdefault(k, v) + + else: # Collective + def wrapper(*args, **kwargs): + """ Same as above, but in parallel over `comm` + """ comm, disk_key = key(*args, **kwargs) k = _as_hexdigest(disk_key) + + # Fetch the per-comm cache and set it up if not present local_cache = comm.Get_attr(comm_cache_keyval) if local_cache is None: local_cache = {} comm.Set_attr(comm_cache_keyval, local_cache) - else: - k = _as_hexdigest(key(*args, **kwargs)) - local_cache = cache - - # first try the in-memory cache - try: - return local_cache[k] - except KeyError: - pass - # then try to retrieve from disk - if collective: + # Grab value from rank 0 memory/disk cache and broadcast result if comm.rank == 0: - v = _disk_cache_get(cachedir, k) + try: + v = local_cache[k] + except KeyError: + v = _disk_cache_get(cachedir, k) comm.bcast(v, root=0) else: v = comm.bcast(None, root=0) - else: - v = _disk_cache_get(cachedir, k) - if v is not None: + + if v is None: + v = func(*args, **kwargs) + # Only write to the disk cache on rank 0 + if comm.rank == 0: + _disk_cache_set(cachedir, k, v) return local_cache.setdefault(k, v) - # if all else fails call func and populate the caches - v = func(*args, **kwargs) - if collective: - if comm.rank == 0: - _disk_cache_set(cachedir, k, v) - else: - _disk_cache_set(cachedir, k, v) - return local_cache.setdefault(k, v) return wrapper return decorator diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 2ffaaf6f9..e78ff728d 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -558,8 +558,9 @@ def test_decorator_in_memory_cache_reuses_results(self, cache, cachedir): def test_decorator_collective_uses_different_in_memory_caches(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) - collective_func = disk_cached(cache, cachedir.name, self.collective_key, - collective=True)(self.myfunc) + collective_func = disk_cached( + None, cachedir.name, self.collective_key, collective=True + )(self.myfunc) # obj1 should be cached on the comm cache and not the self.cache obj1 = collective_func("input1") From 64f561674fc2eefd2d85e7126bfe81e2986bd641 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 8 Aug 2024 18:31:12 +0100 Subject: [PATCH 04/38] Refactor memory/disk caches in order to remove use of id() in GlobalKernel --- pyop2/caching.py | 76 ++++++++++++++++++++++++++++++++++-------- pyop2/global_kernel.py | 21 ++++++------ 2 files changed, 74 insertions(+), 23 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index e0d575054..ac2fbe6a2 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -41,7 +41,8 @@ from warnings import warn from pyop2.configuration import configuration -from pyop2.mpi import comm_cache_keyval +from pyop2.logger import debug +from pyop2.mpi import comm_cache_keyval, COMM_WORLD from pyop2.utils import cached_property @@ -247,6 +248,54 @@ def cache_key(self): """ +def default_parallel_hashkey(comm, *args, **kwargs): + return comm, cachetools.keys.hashkey(*args, **kwargs) + + +def parallel_memory_only_cache(key=default_parallel_hashkey): + """Decorator for wrapping a function to be called over a communiucator in a + cache that stores values in memory. + + :arg key: Callable returning the cache key for the function inputs. 
This + function must return a 2-tuple where the first entry is the + communicator to be collective over and the second is the key. This is + required to ensure that deadlocks do not occur when using different + subcommunicators. + """ + def decorator(func): + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. + """ + comm, mem_key = key(*args, **kwargs) + k = _as_hexdigest(mem_key) + + # Fetch the per-comm cache or set it up if not present + local_cache = comm.Get_attr(comm_cache_keyval) + if local_cache is None: + local_cache = {} + comm.Set_attr(comm_cache_keyval, local_cache) + + # Grab value from rank 0 memory cache and broadcast result + if comm.rank == 0: + + v = local_cache.get(k) + if v is None: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + comm.bcast(v, root=0) + else: + v = comm.bcast(None, root=0) + + if v is None: + v = func(*args, **kwargs) + return local_cache.setdefault(k, v) + + return wrapper + return decorator + + def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): """Decorator for wrapping a function in a cache that stores values in memory and to disk. @@ -277,33 +326,34 @@ def wrapper(*args, **kwargs): k = _as_hexdigest(key(*args, **kwargs)) try: v = cache[k] + debug(f'Serial: {k} memory cache hit') except KeyError: + debug(f'Serial: {k} memory cache miss') v = _disk_cache_get(cachedir, k) + if v is not None: + debug(f'Serial: {k} disk cache hit') if v is None: + debug(f'Serial: {k} disk cache miss') v = func(*args, **kwargs) _disk_cache_set(cachedir, k, v) return cache.setdefault(k, v) else: # Collective + @parallel_memory_only_cache(key=key) def wrapper(*args, **kwargs): """ Same as above, but in parallel over `comm` """ comm, disk_key = key(*args, **kwargs) k = _as_hexdigest(disk_key) - # Fetch the per-comm cache and set it up if not present - local_cache = comm.Get_attr(comm_cache_keyval) - if local_cache is None: - local_cache = {} - comm.Set_attr(comm_cache_keyval, local_cache) - - # Grab value from rank 0 memory/disk cache and broadcast result + # Grab value from rank 0 disk cache and broadcast result if comm.rank == 0: - try: - v = local_cache[k] - except KeyError: - v = _disk_cache_get(cachedir, k) + v = _disk_cache_get(cachedir, k) + if v is not None: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache hit') + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache miss') comm.bcast(v, root=0) else: v = comm.bcast(None, root=0) @@ -313,7 +363,7 @@ def wrapper(*args, **kwargs): # Only write to the disk cache on rank 0 if comm.rank == 0: _disk_cache_set(cachedir, k, v) - return local_cache.setdefault(k, v) + return v return wrapper return decorator diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 536d717e9..79fbcaeee 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -11,7 +11,7 @@ from petsc4py import PETSc from pyop2 import compilation, mpi -from pyop2.caching import Cached +from pyop2.caching import Cached, parallel_memory_only_cache from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -334,24 +334,25 @@ def __init__(self, local_kernel, arguments, *, 
self._initialized = True + @staticmethod + def _call_key(self, comm, *args): + return comm, (0,) + @mpi.collective + @parallel_memory_only_cache(key=_call_key) def __call__(self, comm, *args): """Execute the compiled kernel. :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ - # If the communicator changes then we cannot safely use the in-memory - # function cache. Note here that we are not using dup_comm to get a - # stable communicator id because we will already be using the internal one. - key = id(comm) - try: - func = self._func_cache[key] - except KeyError: - func = self.compile(comm) - self._func_cache[key] = func + func = self.compile(comm) func(*args) + # This method has to return _something_ for the `@parallel_memory_only_cache` + # to function correctly + return 0 + @property def _wrapper_name(self): import warnings From 9a31c3da6879dd48d16646e4ff8350f02cc03cb6 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 8 Aug 2024 21:58:35 +0100 Subject: [PATCH 05/38] Rethink memory only cache for non-broadcastable values --- pyop2/caching.py | 56 +++++++++++++++++++++++++++++++++++--- pyop2/compilation.py | 62 +++++++++++++++++++++++++----------------- pyop2/global_kernel.py | 15 ++-------- 3 files changed, 91 insertions(+), 42 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index ac2fbe6a2..1fd6b876a 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -248,13 +248,17 @@ def cache_key(self): """ -def default_parallel_hashkey(comm, *args, **kwargs): +def default_parallel_hashkey(*args, **kwargs): + comm = kwargs.get('comm') return comm, cachetools.keys.hashkey(*args, **kwargs) def parallel_memory_only_cache(key=default_parallel_hashkey): - """Decorator for wrapping a function to be called over a communiucator in a - cache that stores values in memory. + """Memory only cache decorator. + + Decorator for wrapping a function to be called over a communiucator in a + cache that stores broadcastable values in memory. If the value is found in + the cache of rank 0 it is broadcast to all other ranks. :arg key: Callable returning the cache key for the function inputs. This function must return a 2-tuple where the first entry is the @@ -278,7 +282,6 @@ def wrapper(*args, **kwargs): # Grab value from rank 0 memory cache and broadcast result if comm.rank == 0: - v = local_cache.get(k) if v is None: debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') @@ -296,6 +299,51 @@ def wrapper(*args, **kwargs): return decorator +def parallel_memory_only_cache_no_broadcast(key=default_parallel_hashkey): + """Memory only cache decorator. + + Decorator for wrapping a function to be called over a communiucator in a + cache that stores non-broadcastable values in memory, for instance function + pointers. If the value is not present on all ranks, all ranks repeat the + work. + + :arg key: Callable returning the cache key for the function inputs. This + function must return a 2-tuple where the first entry is the + communicator to be collective over and the second is the key. This is + required to ensure that deadlocks do not occur when using different + subcommunicators. + """ + def decorator(func): + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. 
+ """ + comm, mem_key = key(*args, **kwargs) + k = _as_hexdigest(mem_key) + + # Fetch the per-comm cache or set it up if not present + local_cache = comm.Get_attr(comm_cache_keyval) + if local_cache is None: + local_cache = {} + comm.Set_attr(comm_cache_keyval, local_cache) + + # Grab value from all ranks memory cache and vote + v = local_cache.get(k) + if v is None: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + all_present = comm.allgather(bool(v)) + + # If not present in the cache of all ranks, recompute on all ranks + if not min(all_present): + v = func(*args, **kwargs) + return local_cache.setdefault(k, v) + + return wrapper + return decorator + + def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): """Decorator for wrapping a function in a cache that stores values in memory and to disk. diff --git a/pyop2/compilation.py b/pyop2/compilation.py index f4a1af36a..80a5b4ccc 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -42,9 +42,11 @@ import shlex from hashlib import md5 from packaging.version import Version, InvalidVersion +from textwrap import dedent from pyop2 import mpi +from pyop2.caching import parallel_memory_only_cache_no_broadcast from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -317,36 +319,42 @@ def get_so(self, jitmodule, extension): dirpart, basename = basename[:2], basename[2:] cachedir = os.path.join(cachedir, dirpart) pid = os.getpid() - cname = os.path.join(cachedir, "%s_p%d.%s" % (basename, pid, extension)) - oname = os.path.join(cachedir, "%s_p%d.o" % (basename, pid)) - soname = os.path.join(cachedir, "%s.so" % basename) + cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}") + oname = os.path.join(cachedir, f"{basename}_p{pid}.o") + soname = os.path.join(cachedir, f"{basename}.so") # Link into temporary file, then rename to shared library # atomically (avoiding races). - tmpname = os.path.join(cachedir, "%s_p%d.so.tmp" % (basename, pid)) + tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") if configuration['check_src_hashes'] or configuration['debug']: matching = self.comm.allreduce(basename, op=_check_op) if matching != basename: # Dump all src code to disk for debugging output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, "src-rank%d.c" % self.comm.rank) + srcfile = os.path.join(output, f"src-rank{self.comm.rank}.{extension}") if self.comm.rank == 0: os.makedirs(output, exist_ok=True) self.comm.barrier() with open(srcfile, "w") as f: f.write(jitmodule.code_to_compile) self.comm.barrier() - raise CompilationError("Generated code differs across ranks (see output in %s)" % output) + raise CompilationError(f"Generated code differs across ranks (see output in {output})") + + # Check whether this shared object already written to disk try: - # Are we in the cache? 
- return ctypes.CDLL(soname) + dll = ctypes.CDLL(soname) except OSError: - # No, let's go ahead and build + dll = None + got_dll = bool(dll) + all_dll = self.comm.allgather(got_dll) + + # If the library is not loaded _on all ranks_ build it + if not min(all_dll): if self.comm.rank == 0: # No need to do this on all ranks os.makedirs(cachedir, exist_ok=True) - logfile = os.path.join(cachedir, "%s_p%d.log" % (basename, pid)) - errfile = os.path.join(cachedir, "%s_p%d.err" % (basename, pid)) + logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") + errfile = os.path.join(cachedir, f"{basename}_p{pid}.err") with progress(INFO, 'Compiling wrapper'): with open(cname, "w") as f: f.write(jitmodule.code_to_compile) @@ -356,7 +364,7 @@ def get_so(self, jitmodule, extension): + compiler_flags \ + ('-o', tmpname, cname) \ + self.ldflags - debug('Compilation command: %s', ' '.join(cc)) + debug(f"Compilation command: {' '.join(cc)}") with open(logfile, "w") as log, open(errfile, "w") as err: log.write("Compilation command:\n") log.write(" ".join(cc)) @@ -371,11 +379,12 @@ def get_so(self, jitmodule, extension): else: subprocess.check_call(cc, stderr=err, stdout=log) except subprocess.CalledProcessError as e: - raise CompilationError( - """Command "%s" return error status %d. -Unable to compile code -Compile log in %s -Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile)) + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. + Unable to compile code + Compile log in {logfile} + Compile errors in {errfile} + """)) else: cc = (compiler,) \ + compiler_flags \ @@ -384,8 +393,8 @@ def get_so(self, jitmodule, extension): ld = tuple(shlex.split(self.ld)) \ + ('-o', tmpname, oname) \ + tuple(self.expandWl(self.ldflags)) - debug('Compilation command: %s', ' '.join(cc)) - debug('Link command: %s', ' '.join(ld)) + debug(f"Compilation command: {' '.join(cc)}", ) + debug(f"Link command: {' '.join(ld)}") with open(logfile, "a") as log, open(errfile, "a") as err: log.write("Compilation command:\n") log.write(" ".join(cc)) @@ -409,17 +418,19 @@ def get_so(self, jitmodule, extension): subprocess.check_call(cc, stderr=err, stdout=log) subprocess.check_call(ld, stderr=err, stdout=log) except subprocess.CalledProcessError as e: - raise CompilationError( - """Command "%s" return error status %d. -Unable to compile code -Compile log in %s -Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile)) + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. + Unable to compile code + Compile log in {logfile} + Compile errors in {errfile} + """)) # Atomically ensure soname exists os.rename(tmpname, soname) # Wait for compilation to complete self.comm.barrier() # Load resulting library - return ctypes.CDLL(soname) + dll = ctypes.CDLL(soname) + return dll class MacClangCompiler(Compiler): @@ -547,6 +558,7 @@ class AnonymousCompiler(Compiler): @mpi.collective +@parallel_memory_only_cache_no_broadcast() def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. 
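A usage sketch (illustrative only, not part of the patch): the no-broadcast variant applied above exists for values that cannot be pickled and broadcast, such as the ctypes function pointers returned by `load`; the key callable hands back the communicator together with a hashable key.

    import cachetools
    from pyop2.caching import parallel_memory_only_cache_no_broadcast

    def pointer_key(*args, **kwargs):
        # the key callable must return (comm, hashable_key)
        return kwargs.get("comm"), cachetools.keys.hashkey(*args)

    @parallel_memory_only_cache_no_broadcast(key=pointer_key)
    def get_function_pointer(code, *, comm=None):
        # every rank compiles and dlopens; the pointer is cached on comm
        ...
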
diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 79fbcaeee..8f35038ff 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -11,7 +11,7 @@ from petsc4py import PETSc from pyop2 import compilation, mpi -from pyop2.caching import Cached, parallel_memory_only_cache +from pyop2.caching import Cached from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -329,30 +329,19 @@ def __init__(self, local_kernel, arguments, *, self._iteration_region = iteration_region self._pass_layer_arg = pass_layer_arg - # Cache for stashing the compiled code - self._func_cache = {} - self._initialized = True - @staticmethod - def _call_key(self, comm, *args): - return comm, (0,) - @mpi.collective - @parallel_memory_only_cache(key=_call_key) def __call__(self, comm, *args): """Execute the compiled kernel. :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ + # It is unnecessary to cache this call as it is cached in pyop2/compilation.py func = self.compile(comm) func(*args) - # This method has to return _something_ for the `@parallel_memory_only_cache` - # to function correctly - return 0 - @property def _wrapper_name(self): import warnings From 3f030af3f1a21cf453ed5a066ede00e5569fa267 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 12 Aug 2024 15:18:41 +0100 Subject: [PATCH 06/38] New caching implementation, WIP needs tidying and better tests --- pyop2/caching.py | 229 ++++++++++++++++++++++++++++++++++++++ pyop2/compilation.py | 15 ++- pyop2/global_kernel.py | 31 +----- test/unit/test_caching.py | 12 +- 4 files changed, 258 insertions(+), 29 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 1fd6b876a..eb90f2258 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -39,6 +39,7 @@ import pickle from pathlib import Path from warnings import warn +from functools import wraps from pyop2.configuration import configuration from pyop2.logger import debug @@ -234,6 +235,234 @@ def cache_key(self): return self._key +class _CacheMiss: + pass + + +CACHE_MISS = _CacheMiss() + + +class _CacheKey: + def __init__(self, key_value): + self.value = key_value + + +class DiskCachedObject: + def __new__(cls, *args, **kwargs): + if isinstance(args[0], _CacheKey): + return super().__new__(cls) + comm, disk_key = cls._key(*args, **kwargs) + k = _as_hexdigest((disk_key, cls.__qualname__)) + if comm.rank == 0: + value = _disk_cache_get(cls._cachedir, k) + if value is None: + value = CACHE_MISS + id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + if value is CACHE_MISS: + debug(id_str + f'Disk cache miss for {cls.__qualname__}({args}{kwargs})') + else: + debug(id_str + f'Disk cache hit for {cls.__qualname__}({args}{kwargs})') + # TODO: Add communication tags to avoid cross-broadcasting + comm.bcast(value, root=0) + else: + value = comm.bcast(CACHE_MISS, root=0) + if isinstance(value, _CacheMiss): + # We might have the CACHE_MISS from rank 0 and + # `(value is CACHE_MISS) == False` which is confusing, + # so we set it back to the local value + value = CACHE_MISS + + if value is CACHE_MISS: + # We can't call the constructor as `cls(*args, **kwargs)` since we + # would call `__new__` and recurse infinitely. 
The solution is to + # create a new object and pass that to `__init__` + value = object.__new__(cls) + value.__init__(*args, **kwargs) + value._cache_key = _CacheKey(k) + if comm.rank == 0: + # Only write to the disk cache on rank 0 + _disk_cache_set(cls._cachedir, k, value) + + debug("Disk cache modifying init") + cls.__init__ = cls._skip_init(cls.__init__) + return value + + @classmethod + def _skip_init(cls, init): + """This function allows a class to skip it's init method""" + def restore_init(*args, **kwargs): + debug("Disk reset init") + cls.__init__ = init + return restore_init + + def __init_subclass__(cls, cachedir=None, key=None, **kwargs): + if cachedir is None or key is None: + raise TypeError( + f"A `cache` and a `key` are required to subclass {__class__.__name__}.\n" + "Try declaring your subclass as follows:\n" + f"\tclass {cls.__name__}({cls.__bases__[0].__name__}, cachedir=my_cache, key=my_key)" + ) + super().__init_subclass__(**kwargs) + cls._cachedir = cachedir + cls._key = key + + def __getnewargs__(self): + return (self._cache_key, ) + + +# TODO: Implement this... +# class MemoryCachedObject: +# def __new__(cls, *args, **kwargs): +# k = cls._key(*args, **kwargs), cls.__qualname__ +# value = cls._cache.get(k, CACHE_MISS) +# if value is CACHE_MISS: +# print(f'Cache miss for {cls.__qualname__}({args}{kwargs})') +# # We can't call the constructor as `cls(*args, **kwargs)` since we +# # would call `__new__` and recurse infinitely. The solution is to +# # create a new object and pass that to `__init__` +# value = object.__new__(cls) +# value.__init__(*args, **kwargs) +# cls._cache[k] = value +# else: +# print(f'Cache hit for {cls.__qualname__}({args}{kwargs})') +# cls.__init__ = _skip_init(cls, cls.__init__) +# return value +# +# def __init_subclass__(cls, cache=None, key=None, **kwargs): +# if cache is None or key is None: +# raise TypeError( +# f"A `cache` and a `key` are required to subclass {__class__.__name__}.\n" +# "Try declaring your subclass as follows:\n" +# f"\tclass {cls.__name__}({cls.__bases__[0].__name__}, cache=my_cache, key=my_key)" +# ) +# super().__init_subclass__(**kwargs) +# cls._cache = cache +# cls._key = key + + +class MemoryAndDiskCachedObject(DiskCachedObject, cachedir="", key=""): + def __new__(cls, *args, **kwargs): + if isinstance(args[0], _CacheKey): + return super().__new__(cls, *args) + comm, disk_key = cls._key(*args, **kwargs) + # Throw the qualified name into the key as a string so the memory cache + # can be debugged (a little bit) by a human. This shouldn't really be + # necessary, but some classes do not implement a proper repr. 
+ k = _as_hexdigest((disk_key, cls.__qualname__)), cls.__qualname__ + + # Fetch the per-comm cache or set it up if not present + # from pyop2.mpi import COMM_WORLD, comm_cache_keyval + local_cache = comm.Get_attr(comm_cache_keyval) + if local_cache is None: + local_cache = {} + comm.Set_attr(comm_cache_keyval, local_cache) + + id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + if comm.rank == 0: + value = local_cache.get(k, CACHE_MISS) + if value is CACHE_MISS: + debug(id_str + f'Memory cache miss for {cls.__qualname__}({args}{kwargs})') + else: + debug(id_str + f'Memory cache hit for {cls.__qualname__}({args}{kwargs})') + # TODO: Add communication tags to avoid cross-broadcasting + comm.bcast(value, root=0) + else: + value = comm.bcast(CACHE_MISS, root=0) + if isinstance(value, _CacheMiss): + # We might have the CACHE_MISS from rank 0 and + # `(value is CACHE_MISS) == False` which is confusing, + # so we set it back to the local value + value = CACHE_MISS + + if value is CACHE_MISS: + # TODO: Fix comment + # We can call the constructor as `cls(*args, **kwargs)` here since we + # are subclassing `DiskCachedObject` and _want_ to call __new__ in + # case the object is in the disk cache. + value = super().__new__(cls, *args, **kwargs) + # Regardless whether the object was disk cached, init has already + # been called here + local_cache[k] = value + + return value + + def __init_subclass__(cls, cachedir=None, key=None, **kwargs): + if cachedir is None or key is None: + raise TypeError( + f"A `cache` and a `key` are required to subclass {__class__.__name__}.\n" + "Try declaring your subclass as follows:\n" + f"\tclass {cls.__name__}({cls.__bases__[0].__name__}, cache=my_cache, key=my_key)" + ) + super().__init_subclass__(cachedir=cachedir, key=key, **kwargs) + + def __getnewargs__(self): + return (self._cache_key, ) + + +# TODO: Remove class wrapper, this was a bad idea +def disk_cache(cachedir, key): + def decorator(orig_obj): + if isinstance(orig_obj, type(lambda: None)): + # Cached function wrapper + @wraps(orig_obj) + def _wrapper(*args, **kwargs): + comm, disk_key = key(*args, **kwargs) + k = _as_hexdigest((disk_key, orig_obj.__qualname__)) + if comm.rank == 0: + # Only read from disk on rank 0 + value = _disk_cache_get(cachedir, k) + id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + if value is CACHE_MISS: + debug(id_str + f"Disk cache miss for {orig_obj.__qualname__}({args}{kwargs})") + else: + debug(id_str + f'Disk cache hit for {orig_obj.__qualname__}({args}{kwargs})') + comm.bcast(value, root=0) + else: + value = comm.bcast(CACHE_MISS, root=0) + + if value is CACHE_MISS: + value = orig_obj(*args, **kwargs) + # Only write to the disk cache on rank 0 + if comm.rank == 0: + _disk_cache_set(cachedir, k, value) + return value + elif isinstance(orig_obj, type(object)): + # Cached object wrapper + @wraps(orig_obj, updated=()) + class _wrapper(orig_obj): + def __new__(cls, *args, **kwargs): + comm, disk_key = key(*args, **kwargs) + k = _as_hexdigest((disk_key, orig_obj.__qualname__)) + if comm.rank == 0: + # Only read from disk on rank 0 + value = _disk_cache_get(cachedir, k) + if value is None: + value = CACHE_MISS + id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + if value is CACHE_MISS: + debug(id_str + f'Disk cache miss for {orig_obj.__qualname__}({args}{kwargs})') + else: + debug(id_str + f'Disk cache hit for {orig_obj.__qualname__}({args}{kwargs})') + comm.bcast(value, root=0) + else: + value = 
comm.bcast(CACHE_MISS, root=0) + + if value is CACHE_MISS: + # We can't call the constructor as `orig_obj(*args, **kwargs)` + # since we might be subclassing another cached object. The + # solution is to create a new object and pass it to `__init__` + value = object.__new__(orig_obj) + orig_obj.__init__(value, *args, **kwargs) + if comm.rank == 0: + # Only write to the disk cache on rank 0 + _disk_cache_set(cachedir, k, value) + return value + else: + raise ValueError("Unknown object passed to decorator") + return _wrapper + return decorator + + cached = cachetools.cached """Cache decorator for functions. See the cachetools documentation for more information. diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 80a5b4ccc..b648413b6 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -40,6 +40,7 @@ import sys import ctypes import shlex +import cachetools from hashlib import md5 from packaging.version import Version, InvalidVersion from textwrap import dedent @@ -557,8 +558,20 @@ class AnonymousCompiler(Compiler): _name = "Unknown" +def load_hashkey(*args, **kwargs): + from pyop2.global_kernel import GlobalKernel + if isinstance(args[0], str): + code_hash = md5(args[0].encode()).hexdigest() + elif isinstance(args[0], GlobalKernel): + code_hash = md5(str(args[0].cache_key).encode()).hexdigest() + else: + pass # This will raise an error in load + comm = kwargs.get('comm') + return comm, cachetools.keys.hashkey(code_hash, *args[1:], **kwargs) + + @mpi.collective -@parallel_memory_only_cache_no_broadcast() +@parallel_memory_only_cache_no_broadcast(key=load_hashkey) def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 8f35038ff..4ffee512d 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -1,7 +1,6 @@ import collections.abc import ctypes from dataclasses import dataclass -import itertools import os from typing import Optional, Tuple @@ -11,7 +10,6 @@ from petsc4py import PETSc from pyop2 import compilation, mpi -from pyop2.caching import Cached from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -247,7 +245,7 @@ def pack(self): return MatPack -class GlobalKernel(Cached): +class GlobalKernel: """Class representing the generated code for the global computation. :param local_kernel: :class:`pyop2.LocalKernel` instance representing the @@ -271,22 +269,6 @@ class GlobalKernel(Cached): :param pass_layer_arg: Should the wrapper pass the current layer into the kernel (as an `int`). Only makes sense for indirect extruded iteration. 
""" - - _cache = {} - - @classmethod - def _cache_key(cls, local_knl, arguments, **kwargs): - key = [cls, local_knl.cache_key, - *kwargs.items(), configuration["simd_width"]] - - key.extend([a.cache_key for a in arguments]) - - counter = itertools.count() - seen_maps = collections.defaultdict(lambda: next(counter)) - key.extend([seen_maps[m] for a in arguments for m in a.maps]) - - return tuple(key) - def __init__(self, local_kernel, arguments, *, extruded=False, extruded_periodic=False, @@ -294,9 +276,6 @@ def __init__(self, local_kernel, arguments, *, subset=False, iteration_region=None, pass_layer_arg=False): - if self._initialized: - return - if not len(local_kernel.accesses) == len(arguments): raise ValueError( "Number of arguments passed to the local and global kernels" @@ -319,7 +298,11 @@ def __init__(self, local_kernel, arguments, *, raise ValueError( "Cannot request constant_layers argument for non-extruded iteration" ) - + self.cache_key = ( + local_kernel.cache_key, *[a.cache_key for a in arguments], + extruded, extruded_periodic, constant_layers, subset, + iteration_region, pass_layer_arg, configuration["simd_width"] + ) self.local_kernel = local_kernel self.arguments = arguments self._extruded = extruded @@ -329,8 +312,6 @@ def __init__(self, local_kernel, arguments, *, self._iteration_region = iteration_region self._pass_layer_arg = pass_layer_arg - self._initialized = True - @mpi.collective def __call__(self, comm, *args): """Execute the compiled kernel. diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index e78ff728d..007c5b7c5 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -39,7 +39,6 @@ import numpy from pyop2 import op2, mpi from pyop2.caching import disk_cached -from pyop2.mpi import comm_cache_keyval def _seed(): @@ -285,7 +284,14 @@ class TestGeneratedCodeCache: Generated Code Cache Tests. """ - cache = op2.GlobalKernel._cache + @property + def cache(self): + int_comm = mpi.internal_comm(mpi.COMM_WORLD, self) + _cache = int_comm.Get_attr(mpi.comm_cache_keyval) + if _cache is None: + _cache = {} + mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache) + return _cache @pytest.fixture def a(cls, diterset): @@ -564,7 +570,7 @@ def test_decorator_collective_uses_different_in_memory_caches(self, cache, cache # obj1 should be cached on the comm cache and not the self.cache obj1 = collective_func("input1") - comm_cache = self.comm.Get_attr(comm_cache_keyval) + comm_cache = self.comm.Get_attr(mpi.comm_cache_keyval) assert len(cache) == 0 assert len(comm_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 From 6aa6554d9d679b411dc3b077e41d9551187b5fdd Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 12 Aug 2024 22:08:14 +0100 Subject: [PATCH 07/38] WIP --- pyop2/caching.py | 10 +- test/unit/test_updated_caching.py | 158 ++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 test/unit/test_updated_caching.py diff --git a/pyop2/caching.py b/pyop2/caching.py index eb90f2258..7522c6978 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -501,7 +501,7 @@ def wrapper(*args, **kwargs): on calling the function and populating the cache. """ comm, mem_key = key(*args, **kwargs) - k = _as_hexdigest(mem_key) + k = _as_hexdigest(mem_key), func.__qualname__ # Fetch the per-comm cache or set it up if not present local_cache = comm.Get_attr(comm_cache_keyval) @@ -548,7 +548,7 @@ def wrapper(*args, **kwargs): on calling the function and populating the cache. 
""" comm, mem_key = key(*args, **kwargs) - k = _as_hexdigest(mem_key) + k = _as_hexdigest(mem_key), func.__qualname__ # Fetch the per-comm cache or set it up if not present local_cache = comm.Get_attr(comm_cache_keyval) @@ -573,6 +573,7 @@ def wrapper(*args, **kwargs): return decorator +# TODO: Change call signature def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): """Decorator for wrapping a function in a cache that stores values in memory and to disk. @@ -650,6 +651,11 @@ def _as_hexdigest(key): return hashlib.md5(str(key).encode()).hexdigest() +def clear_memory_cache(comm): + if comm.Get_attr(comm_cache_keyval) is not None: + comm.Set_attr(comm_cache_keyval, {}) + + def _disk_cache_get(cachedir, key): """Retrieve a value from the disk cache. diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py new file mode 100644 index 000000000..bdd6583bd --- /dev/null +++ b/test/unit/test_updated_caching.py @@ -0,0 +1,158 @@ +import pytest +from tempfile import gettempdir +from functools import partial + +from pyop2.caching import ( + disk_cached, + parallel_memory_only_cache, + parallel_memory_only_cache_no_broadcast, + DiskCachedObject, + MemoryAndDiskCachedObject, + default_parallel_hashkey, + clear_memory_cache +) +from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval + + +# For new disk_cached API +disk_cached = partial(disk_cached, None, key=default_parallel_hashkey, collective=True) + + +class StateIncrement: + """Simple class for keeping track of the number of times executed + """ + def __init__(self): + self._count = 0 + + def __call__(self): + self._count += 1 + return self._count + + @property + def value(self): + return self._count + + +def twople(x): + return (x, )*2 + + +def threeple(x): + return (x, )*3 + + +def n_comms(n): + return [MPI.COMM_WORLD]*n + + +def n_ops(n): + return [MPI.SUM]*n + + +# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_cached +def function_factory(state, decorator, f, **kwargs): + def custom_function(x, comm=COMM_WORLD): + state() + return f(x) + + return decorator(**kwargs)(custom_function) + + +# parent_class = DiskCachedObject, MemoryAndDiskCachedObject +# f(x) = x**2, x**3 +def object_factory(state, parent_class, f, **kwargs): + class CustomObject(parent_class, **kwargs): + def __init__(self, x, comm=COMM_WORLD): + state() + self.x = f(x) + + return CustomObject + + +@pytest.fixture +def state(): + return StateIncrement() + + +@pytest.fixture +def unique_tempdir(): + """This allows us to run with a different tempdir for each test that + requires one""" + return gettempdir() + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (parallel_memory_only_cache, twople), + (parallel_memory_only_cache_no_broadcast, n_comms), + (disk_cached, twople) +]) +def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] is disk_cached: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(2, comm=COMM_WORLD) + assert second == uncached_function(2) + assert second is first + assert state.value == 1 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (parallel_memory_only_cache, twople), 
+ (parallel_memory_only_cache_no_broadcast, n_comms), + (disk_cached, twople) +]) +def test_function_args_different(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] is disk_cached: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(3, comm=COMM_WORLD) + assert second == uncached_function(3) + assert state.value == 2 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize("decorator, uncached_function", [ + (parallel_memory_only_cache, twople), + (parallel_memory_only_cache_no_broadcast, n_comms), + (disk_cached, twople) +]) +def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] is disk_cached: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + for ii in range(10): + color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED + comm12 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank < 2: + _ = cached_function(2, comm=comm12) + comm12.Free() + + color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED + comm23 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank > 0: + _ = cached_function(2, comm=comm23) + comm23.Free() + + clear_memory_cache(COMM_WORLD) From 6dec230cc3b3390a293b8e97b0cd0012ff74aceb Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 13 Aug 2024 17:31:18 +0100 Subject: [PATCH 08/38] Just notes --- pyop2/caching.py | 74 ++++++++++++++++++++++++++++++++++++++++++ pyop2/global_kernel.py | 1 + 2 files changed, 75 insertions(+) diff --git a/pyop2/caching.py b/pyop2/caching.py index 7522c6978..2ee71fe9f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -481,6 +481,80 @@ def default_parallel_hashkey(*args, **kwargs): comm = kwargs.get('comm') return comm, cachetools.keys.hashkey(*args, **kwargs) +#### connor bits + +# ~ def pcache(comm_seeker, key=None, cache_factory=dict): + + # ~ comm = comm_seeker() + # ~ cache = cache_factory() + +# ~ @pcache(cachetools.LRUCache) + +@pcache(DiskCache) + +@pcache(MemDiskCache) + +@pcache(MemCache) + +mem_cache = pcache(cache_factory=cachetools.LRUCache) +disk_cache = mem_cache(cache_factory=DiskCache) + +# ~ @pcache(comm_seeker=lambda obj, *_, **_: obj.comm, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) + + +# ~ @pmemcache + +# ~ @pmemdiskcache + +# ~ class ParallelObject(ABC): + # ~ @abc.abstractproperty + # ~ def _comm(self): + # ~ pass + +# ~ class MyObj(ParallelObject): + + # ~ @pcached_property # assumes that obj has a "comm" attr + # ~ @pcached_property(lambda self: self.comm) + # ~ def myproperty(self): + # ~ ... + + +# ~ def pcached_property(): + # ~ def wrapper(self): + # ~ assert isinstance(self, ParallelObject) + # ~ ... 
+ + +# ~ from futils.mpi import ParallelObject + +# ~ from futils.cache import pcached_property + +# ~ from footils.cache import * + +# footils == firedrake utils + +# * parallel cached property +# * memcache / cache / cached +# * diskonlycache / disk_only_cached +# * memdiskcache / diskcache / disk_cached +# * memcache_no_bcast / broadcast=False + +# ~ parallel_cached_property = parallel_cache(lambda self: self._comm, key=lambda self: ()) + +# ~ @time +# ~ @timed +# ~ def myslowfunc(): + # ~ .. + +# ~ my_fast_fun = cache(my_slow_fn) + +#### + + +# TODO: +# Implement an @parallel_cached_property decorator function + + def parallel_memory_only_cache(key=default_parallel_hashkey): """Memory only cache decorator. diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 4ffee512d..454ca414f 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -354,6 +354,7 @@ def builder(self): builder.add_argument(arg) return builder + # TODO: Wrap with parallel_cached_property @cached_property def code_to_compile(self): """Return the C/C++ source code as a string.""" From a7451fc1948778fbdca6972429f1558f10962347 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Wed, 14 Aug 2024 13:42:53 +0100 Subject: [PATCH 09/38] Remove cached(object) and lint --- pyop2/caching.py | 175 +++++++++--------------------- test/unit/test_updated_caching.py | 4 +- 2 files changed, 52 insertions(+), 127 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 2ee71fe9f..b5285408c 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -44,7 +44,6 @@ from pyop2.configuration import configuration from pyop2.logger import debug from pyop2.mpi import comm_cache_keyval, COMM_WORLD -from pyop2.utils import cached_property def report_cache(typ): @@ -162,79 +161,6 @@ def make_obj(): return obj -class Cached(object): - - """Base class providing global caching of objects. Derived classes need to - implement classmethods :meth:`_process_args` and :meth:`_cache_key` - and define a class attribute :attr:`_cache` of type :class:`dict`. - - .. warning:: - The derived class' :meth:`__init__` is still called if the object is - retrieved from cache. If that is not desired, derived classes can set - a flag indicating whether the constructor has already been called and - immediately return from :meth:`__init__` if the flag is set. Otherwise - the object will be re-initialized even if it was returned from cache! - """ - - def __new__(cls, *args, **kwargs): - args, kwargs = cls._process_args(*args, **kwargs) - key = cls._cache_key(*args, **kwargs) - - def make_obj(): - obj = super(Cached, cls).__new__(cls) - obj._key = key - obj._initialized = False - # obj.__init__ will be called twice when constructing - # something not in the cache. The first time here, with - # the canonicalised args, the second time directly in the - # subclass. But that one should hit the cache and return - # straight away. - obj.__init__(*args, **kwargs) - return obj - - # Don't bother looking in caches if we're not meant to cache - # this object. - if key is None: - return make_obj() - try: - return cls._cache_lookup(key) - except (KeyError, IOError): - obj = make_obj() - cls._cache_store(key, obj) - return obj - - @classmethod - def _cache_lookup(cls, key): - return cls._cache[key] - - @classmethod - def _cache_store(cls, key, val): - cls._cache[key] = val - - @classmethod - def _process_args(cls, *args, **kwargs): - """Pre-processes the arguments before they are being passed to - :meth:`_cache_key` and the constructor. 
- - :rtype: *must* return a :class:`list` of *args* and a - :class:`dict` of *kwargs*""" - return args, kwargs - - @classmethod - def _cache_key(cls, *args, **kwargs): - """Compute the cache key given the preprocessed constructor arguments. - - :rtype: Cache key to use or ``None`` if the object is not to be cached - - .. note:: The cache key must be hashable.""" - return tuple(args) + tuple([(k, v) for k, v in kwargs.items()]) - - @cached_property - def cache_key(self): - """Cache key.""" - return self._key - - class _CacheMiss: pass @@ -481,56 +407,56 @@ def default_parallel_hashkey(*args, **kwargs): comm = kwargs.get('comm') return comm, cachetools.keys.hashkey(*args, **kwargs) -#### connor bits - -# ~ def pcache(comm_seeker, key=None, cache_factory=dict): - - # ~ comm = comm_seeker() - # ~ cache = cache_factory() - -# ~ @pcache(cachetools.LRUCache) - -@pcache(DiskCache) - -@pcache(MemDiskCache) - -@pcache(MemCache) - -mem_cache = pcache(cache_factory=cachetools.LRUCache) -disk_cache = mem_cache(cache_factory=DiskCache) - -# ~ @pcache(comm_seeker=lambda obj, *_, **_: obj.comm, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) - - -# ~ @pmemcache - -# ~ @pmemdiskcache - -# ~ class ParallelObject(ABC): - # ~ @abc.abstractproperty - # ~ def _comm(self): - # ~ pass - -# ~ class MyObj(ParallelObject): - - # ~ @pcached_property # assumes that obj has a "comm" attr - # ~ @pcached_property(lambda self: self.comm) - # ~ def myproperty(self): - # ~ ... - - -# ~ def pcached_property(): - # ~ def wrapper(self): - # ~ assert isinstance(self, ParallelObject) - # ~ ... - - -# ~ from futils.mpi import ParallelObject - -# ~ from futils.cache import pcached_property - -# ~ from footils.cache import * - +# ### Some notes from Connor: +# +# def pcache(comm_seeker, key=None, cache_factory=dict): +# +# comm = comm_seeker() +# cache = cache_factory() +# +# @pcache(cachetools.LRUCache) +# +# @pcache(DiskCache) +# +# @pcache(MemDiskCache) +# +# @pcache(MemCache) +# +# mem_cache = pcache(cache_factory=cachetools.LRUCache) +# disk_cache = mem_cache(cache_factory=DiskCache) +# +# @pcache(comm_seeker=lambda obj, *_, **_: obj.comm, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) +# +# +# @pmemcache +# +# @pmemdiskcache +# +# class ParallelObject(ABC): +# @abc.abstractproperty +# def _comm(self): +# pass +# +# class MyObj(ParallelObject): +# +# @pcached_property # assumes that obj has a "comm" attr +# @pcached_property(lambda self: self.comm) +# def myproperty(self): +# ... +# +# +# def pcached_property(): +# def wrapper(self): +# assert isinstance(self, ParallelObject) +# ... +# +# +# from futils.mpi import ParallelObject +# +# from futils.cache import pcached_property +# +# from footils.cache import * +# # footils == firedrake utils # * parallel cached property @@ -555,7 +481,6 @@ def default_parallel_hashkey(*args, **kwargs): # Implement an @parallel_cached_property decorator function - def parallel_memory_only_cache(key=default_parallel_hashkey): """Memory only cache decorator. 
diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index bdd6583bd..e3d901b4e 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -2,7 +2,7 @@ from tempfile import gettempdir from functools import partial -from pyop2.caching import ( +from pyop2.caching import ( # noqa: F401 disk_cached, parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, @@ -11,7 +11,7 @@ default_parallel_hashkey, clear_memory_cache ) -from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval +from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval # noqa: F401 # For new disk_cached API From c1a48c9857af33132318e6eb3b02740cd5caaec6 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Wed, 14 Aug 2024 18:46:12 +0100 Subject: [PATCH 10/38] WIP, need to fix remaining cache tests --- pyop2/caching.py | 604 +++++++++--------------------- pyop2/compilation.py | 4 +- test/unit/test_caching.py | 45 ++- test/unit/test_updated_caching.py | 162 ++++---- 4 files changed, 278 insertions(+), 537 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index b5285408c..e9b5a1160 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -37,9 +37,10 @@ import hashlib import os import pickle +from collections.abc import MutableMapping from pathlib import Path -from warnings import warn -from functools import wraps +from warnings import warn # noqa F401 +from functools import wraps, partial from pyop2.configuration import configuration from pyop2.logger import debug @@ -173,239 +174,184 @@ def __init__(self, key_value): self.value = key_value -class DiskCachedObject: - def __new__(cls, *args, **kwargs): - if isinstance(args[0], _CacheKey): - return super().__new__(cls) - comm, disk_key = cls._key(*args, **kwargs) - k = _as_hexdigest((disk_key, cls.__qualname__)) - if comm.rank == 0: - value = _disk_cache_get(cls._cachedir, k) - if value is None: - value = CACHE_MISS - id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " - if value is CACHE_MISS: - debug(id_str + f'Disk cache miss for {cls.__qualname__}({args}{kwargs})') - else: - debug(id_str + f'Disk cache hit for {cls.__qualname__}({args}{kwargs})') - # TODO: Add communication tags to avoid cross-broadcasting - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - if isinstance(value, _CacheMiss): - # We might have the CACHE_MISS from rank 0 and - # `(value is CACHE_MISS) == False` which is confusing, - # so we set it back to the local value - value = CACHE_MISS - - if value is CACHE_MISS: - # We can't call the constructor as `cls(*args, **kwargs)` since we - # would call `__new__` and recurse infinitely. 
The solution is to - # create a new object and pass that to `__init__` - value = object.__new__(cls) - value.__init__(*args, **kwargs) - value._cache_key = _CacheKey(k) - if comm.rank == 0: - # Only write to the disk cache on rank 0 - _disk_cache_set(cls._cachedir, k, value) - - debug("Disk cache modifying init") - cls.__init__ = cls._skip_init(cls.__init__) - return value +def _as_hexdigest(*args): + hash_ = hashlib.md5() + for a in args: + hash_.update(str(a).encode()) + return hash_.hexdigest() - @classmethod - def _skip_init(cls, init): - """This function allows a class to skip it's init method""" - def restore_init(*args, **kwargs): - debug("Disk reset init") - cls.__init__ = init - return restore_init - - def __init_subclass__(cls, cachedir=None, key=None, **kwargs): - if cachedir is None or key is None: - raise TypeError( - f"A `cache` and a `key` are required to subclass {__class__.__name__}.\n" - "Try declaring your subclass as follows:\n" - f"\tclass {cls.__name__}({cls.__bases__[0].__name__}, cachedir=my_cache, key=my_key)" - ) - super().__init_subclass__(**kwargs) - cls._cachedir = cachedir - cls._key = key - - def __getnewargs__(self): - return (self._cache_key, ) - - -# TODO: Implement this... -# class MemoryCachedObject: -# def __new__(cls, *args, **kwargs): -# k = cls._key(*args, **kwargs), cls.__qualname__ -# value = cls._cache.get(k, CACHE_MISS) -# if value is CACHE_MISS: -# print(f'Cache miss for {cls.__qualname__}({args}{kwargs})') -# # We can't call the constructor as `cls(*args, **kwargs)` since we -# # would call `__new__` and recurse infinitely. The solution is to -# # create a new object and pass that to `__init__` -# value = object.__new__(cls) -# value.__init__(*args, **kwargs) -# cls._cache[k] = value -# else: -# print(f'Cache hit for {cls.__qualname__}({args}{kwargs})') -# cls.__init__ = _skip_init(cls, cls.__init__) -# return value -# -# def __init_subclass__(cls, cache=None, key=None, **kwargs): -# if cache is None or key is None: -# raise TypeError( -# f"A `cache` and a `key` are required to subclass {__class__.__name__}.\n" -# "Try declaring your subclass as follows:\n" -# f"\tclass {cls.__name__}({cls.__bases__[0].__name__}, cache=my_cache, key=my_key)" -# ) -# super().__init_subclass__(**kwargs) -# cls._cache = cache -# cls._key = key - - -class MemoryAndDiskCachedObject(DiskCachedObject, cachedir="", key=""): - def __new__(cls, *args, **kwargs): - if isinstance(args[0], _CacheKey): - return super().__new__(cls, *args) - comm, disk_key = cls._key(*args, **kwargs) - # Throw the qualified name into the key as a string so the memory cache - # can be debugged (a little bit) by a human. This shouldn't really be - # necessary, but some classes do not implement a proper repr. 
- k = _as_hexdigest((disk_key, cls.__qualname__)), cls.__qualname__ - - # Fetch the per-comm cache or set it up if not present - # from pyop2.mpi import COMM_WORLD, comm_cache_keyval - local_cache = comm.Get_attr(comm_cache_keyval) - if local_cache is None: - local_cache = {} - comm.Set_attr(comm_cache_keyval, local_cache) - - id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " - if comm.rank == 0: - value = local_cache.get(k, CACHE_MISS) - if value is CACHE_MISS: - debug(id_str + f'Memory cache miss for {cls.__qualname__}({args}{kwargs})') - else: - debug(id_str + f'Memory cache hit for {cls.__qualname__}({args}{kwargs})') - # TODO: Add communication tags to avoid cross-broadcasting - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - if isinstance(value, _CacheMiss): - # We might have the CACHE_MISS from rank 0 and - # `(value is CACHE_MISS) == False` which is confusing, - # so we set it back to the local value - value = CACHE_MISS - - if value is CACHE_MISS: - # TODO: Fix comment - # We can call the constructor as `cls(*args, **kwargs)` here since we - # are subclassing `DiskCachedObject` and _want_ to call __new__ in - # case the object is in the disk cache. - value = super().__new__(cls, *args, **kwargs) - # Regardless whether the object was disk cached, init has already - # been called here - local_cache[k] = value +def clear_memory_cache(comm): + if comm.Get_attr(comm_cache_keyval) is not None: + comm.Set_attr(comm_cache_keyval, {}) + + +class DictLikeDiskAccess(MutableMapping): + def __init__(self, cachedir): + """ + + :arg cachedir: The cache directory. + """ + self.cachedir = cachedir + self._keys = set() + + def __getitem__(self, key): + """Retrieve a value from the disk cache. + + :arg key: The cache key, a 2-tuple of strings. + :returns: The cached object if found. + """ + filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1]) + try: + with self.open(filepath, "rb") as fh: + value = self.read(fh) + except FileNotFoundError: + raise KeyError("File not on disk, cache miss") return value - def __init_subclass__(cls, cachedir=None, key=None, **kwargs): - if cachedir is None or key is None: - raise TypeError( - f"A `cache` and a `key` are required to subclass {__class__.__name__}.\n" - "Try declaring your subclass as follows:\n" - f"\tclass {cls.__name__}({cls.__bases__[0].__name__}, cache=my_cache, key=my_key)" - ) - super().__init_subclass__(cachedir=cachedir, key=key, **kwargs) + def __setitem__(self, key, value): + """Store a new value in the disk cache. + + :arg key: The cache key, a 2-tuple of strings. + :arg value: The new item to store in the cache. 
+ """ + self._keys.add(key) + k1, k2 = key[0][:2], key[0][2:] + key[1] + basedir = Path(self.cachedir, k1) + basedir.mkdir(parents=True, exist_ok=True) + + tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") + filepath = basedir.joinpath(k2) + with self.open(tempfile, "wb") as fh: + self.write(fh, value) + tempfile.rename(filepath) + + def __delitem__(self, key): + raise ValueError(f"Cannot remove items from {self.__class__.__name__}") + + def keys(self): + return self._keys + + def __iter__(self): + for k in self._keys: + yield k + + def __len__(self): + return len(self._keys) + + def __repr__(self): + return "{" + " ".join(f"{k}: {v}" for k, v in self.items()) + "}" + + def open(self, *args, **kwargs): + return open(*args, **kwargs) + + def read(self, filehandle): + return pickle.load(filehandle) + + def write(self, filehandle, value): + pickle.dump(value, filehandle) + + +def default_comm_fetcher(*args, **kwargs): + return kwargs.get("comm") + + +default_parallel_hashkey = cachetools.keys.hashkey + - def __getnewargs__(self): - return (self._cache_key, ) +def parallel_cache(hashkey=default_parallel_hashkey, comm_fetcher=default_comm_fetcher, cache_factory=None, broadcast=True): + """Memory only cache decorator. + Decorator for wrapping a function to be called over a communiucator in a + cache that stores broadcastable values in memory. If the value is found in + the cache of rank 0 it is broadcast to all other ranks. -# TODO: Remove class wrapper, this was a bad idea -def disk_cache(cachedir, key): - def decorator(orig_obj): - if isinstance(orig_obj, type(lambda: None)): - # Cached function wrapper - @wraps(orig_obj) - def _wrapper(*args, **kwargs): - comm, disk_key = key(*args, **kwargs) - k = _as_hexdigest((disk_key, orig_obj.__qualname__)) + :arg key: Callable returning the cache key for the function inputs. This + function must return a 2-tuple where the first entry is the + communicator to be collective over and the second is the key. This is + required to ensure that deadlocks do not occur when using different + subcommunicators. + """ + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. 
+ """ + comm = comm_fetcher(*args, **kwargs) + k = hashkey(*args, **kwargs) + key = _as_hexdigest(k), func.__qualname__ + + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + + if broadcast: + # Grab value from rank 0 memory cache and broadcast result if comm.rank == 0: - # Only read from disk on rank 0 - value = _disk_cache_get(cachedir, k) - id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " - if value is CACHE_MISS: - debug(id_str + f"Disk cache miss for {orig_obj.__qualname__}({args}{kwargs})") + value = local_cache.get(key, CACHE_MISS) + if value is None: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') else: - debug(id_str + f'Disk cache hit for {orig_obj.__qualname__}({args}{kwargs})') + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + # TODO: Add communication tags to avoid cross-broadcasting comm.bcast(value, root=0) else: value = comm.bcast(CACHE_MISS, root=0) - + if isinstance(value, _CacheMiss): + # We might have the CACHE_MISS from rank 0 and + # `(value is CACHE_MISS) == False` which is confusing, + # so we set it back to the local value + value = CACHE_MISS + else: + # Grab value from all ranks cache and broadcast cache hit/miss + value = local_cache.get(key, CACHE_MISS) if value is CACHE_MISS: - value = orig_obj(*args, **kwargs) - # Only write to the disk cache on rank 0 - if comm.rank == 0: - _disk_cache_set(cachedir, k, value) - return value - elif isinstance(orig_obj, type(object)): - # Cached object wrapper - @wraps(orig_obj, updated=()) - class _wrapper(orig_obj): - def __new__(cls, *args, **kwargs): - comm, disk_key = key(*args, **kwargs) - k = _as_hexdigest((disk_key, orig_obj.__qualname__)) - if comm.rank == 0: - # Only read from disk on rank 0 - value = _disk_cache_get(cachedir, k) - if value is None: - value = CACHE_MISS - id_str = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " - if value is CACHE_MISS: - debug(id_str + f'Disk cache miss for {orig_obj.__qualname__}({args}{kwargs})') - else: - debug(id_str + f'Disk cache hit for {orig_obj.__qualname__}({args}{kwargs})') - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - - if value is CACHE_MISS: - # We can't call the constructor as `orig_obj(*args, **kwargs)` - # since we might be subclassing another cached object. 
The - # solution is to create a new object and pass it to `__init__` - value = object.__new__(orig_obj) - orig_obj.__init__(value, *args, **kwargs) - if comm.rank == 0: - # Only write to the disk cache on rank 0 - _disk_cache_set(cachedir, k, value) - return value - else: - raise ValueError("Unknown object passed to decorator") - return _wrapper + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') + cache_hit = False + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + cache_hit = True + all_present = comm.allgather(cache_hit) + + # If not present in the cache of all ranks we need to recompute on all ranks + if not min(all_present): + value = CACHE_MISS + + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper return decorator -cached = cachetools.cached -"""Cache decorator for functions. See the cachetools documentation for more -information. +# A small collection of default simple caches +class DEFAULT_CACHE(dict): + pass + + +memory_cache = partial(parallel_cache, cache_factory=lambda: DEFAULT_CACHE()) -.. note:: - If you intend to use this decorator to cache things that are collective - across a communicator then you must include the communicator as part of - the cache key. - You should also make sure to use unbounded caches as otherwise some ranks - may evict results leading to deadlocks. -""" +def disk_only_cache(*args, cachedir=configuration["cache_dir"], **kwargs): + return parallel_cache(*args, **kwargs, cache_factory=lambda: DictLikeDiskAccess(cachedir)) -def default_parallel_hashkey(*args, **kwargs): - comm = kwargs.get('comm') - return comm, cachetools.keys.hashkey(*args, **kwargs) +def memory_and_disk_cache(*args, cachedir=configuration["cache_dir"], **kwargs): + def decorator(func): + return memory_cache(*args, **kwargs)(disk_only_cache(*args, cachedir=cachedir, **kwargs)(func)) + return decorator + # ### Some notes from Connor: # @@ -464,225 +410,19 @@ def default_parallel_hashkey(*args, **kwargs): # * diskonlycache / disk_only_cached # * memdiskcache / diskcache / disk_cached # * memcache_no_bcast / broadcast=False - -# ~ parallel_cached_property = parallel_cache(lambda self: self._comm, key=lambda self: ()) - -# ~ @time -# ~ @timed -# ~ def myslowfunc(): - # ~ .. - -# ~ my_fast_fun = cache(my_slow_fn) - +# +# parallel_cached_property = parallel_cache(lambda self: self._comm, key=lambda self: ()) +# +# @time +# @timed +# def myslowfunc(): +# +# .. +# +# my_fast_fun = cache(my_slow_fn) +# #### # TODO: # Implement an @parallel_cached_property decorator function - - -def parallel_memory_only_cache(key=default_parallel_hashkey): - """Memory only cache decorator. - - Decorator for wrapping a function to be called over a communiucator in a - cache that stores broadcastable values in memory. If the value is found in - the cache of rank 0 it is broadcast to all other ranks. - - :arg key: Callable returning the cache key for the function inputs. This - function must return a 2-tuple where the first entry is the - communicator to be collective over and the second is the key. This is - required to ensure that deadlocks do not occur when using different - subcommunicators. - """ - def decorator(func): - def wrapper(*args, **kwargs): - """ Extract the key and then try the memory cache before falling back - on calling the function and populating the cache. 
- """ - comm, mem_key = key(*args, **kwargs) - k = _as_hexdigest(mem_key), func.__qualname__ - - # Fetch the per-comm cache or set it up if not present - local_cache = comm.Get_attr(comm_cache_keyval) - if local_cache is None: - local_cache = {} - comm.Set_attr(comm_cache_keyval, local_cache) - - # Grab value from rank 0 memory cache and broadcast result - if comm.rank == 0: - v = local_cache.get(k) - if v is None: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') - else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') - comm.bcast(v, root=0) - else: - v = comm.bcast(None, root=0) - - if v is None: - v = func(*args, **kwargs) - return local_cache.setdefault(k, v) - - return wrapper - return decorator - - -def parallel_memory_only_cache_no_broadcast(key=default_parallel_hashkey): - """Memory only cache decorator. - - Decorator for wrapping a function to be called over a communiucator in a - cache that stores non-broadcastable values in memory, for instance function - pointers. If the value is not present on all ranks, all ranks repeat the - work. - - :arg key: Callable returning the cache key for the function inputs. This - function must return a 2-tuple where the first entry is the - communicator to be collective over and the second is the key. This is - required to ensure that deadlocks do not occur when using different - subcommunicators. - """ - def decorator(func): - def wrapper(*args, **kwargs): - """ Extract the key and then try the memory cache before falling back - on calling the function and populating the cache. - """ - comm, mem_key = key(*args, **kwargs) - k = _as_hexdigest(mem_key), func.__qualname__ - - # Fetch the per-comm cache or set it up if not present - local_cache = comm.Get_attr(comm_cache_keyval) - if local_cache is None: - local_cache = {} - comm.Set_attr(comm_cache_keyval, local_cache) - - # Grab value from all ranks memory cache and vote - v = local_cache.get(k) - if v is None: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') - else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') - all_present = comm.allgather(bool(v)) - - # If not present in the cache of all ranks, recompute on all ranks - if not min(all_present): - v = func(*args, **kwargs) - return local_cache.setdefault(k, v) - - return wrapper - return decorator - - -# TODO: Change call signature -def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): - """Decorator for wrapping a function in a cache that stores values in memory and to disk. - - :arg cache: The in-memory cache, usually a :class:`dict`. - :arg cachedir: The location of the cache directory. Defaults to ``PYOP2_CACHE_DIR``. - :arg key: Callable returning the cache key for the function inputs. If ``collective`` - is ``True`` then this function must return a 2-tuple where the first entry is the - communicator to be collective over and the second is the key. This is required to ensure - that deadlocks do not occur when using different subcommunicators. - :arg collective: If ``True`` then cache lookup is done collectively over a communicator. - """ - if cachedir is None: - cachedir = configuration["cache_dir"] - - if collective and cache is not None: - warn( - "Global cache for collective disk cached call will not be used. 
" - "Pass `None` as the first argument" - ) - - def decorator(func): - if not collective: - def wrapper(*args, **kwargs): - """ Extract the key and then try the memory then disk cache - before falling back on calling the function and populating the - caches. - """ - k = _as_hexdigest(key(*args, **kwargs)) - try: - v = cache[k] - debug(f'Serial: {k} memory cache hit') - except KeyError: - debug(f'Serial: {k} memory cache miss') - v = _disk_cache_get(cachedir, k) - if v is not None: - debug(f'Serial: {k} disk cache hit') - - if v is None: - debug(f'Serial: {k} disk cache miss') - v = func(*args, **kwargs) - _disk_cache_set(cachedir, k, v) - return cache.setdefault(k, v) - - else: # Collective - @parallel_memory_only_cache(key=key) - def wrapper(*args, **kwargs): - """ Same as above, but in parallel over `comm` - """ - comm, disk_key = key(*args, **kwargs) - k = _as_hexdigest(disk_key) - - # Grab value from rank 0 disk cache and broadcast result - if comm.rank == 0: - v = _disk_cache_get(cachedir, k) - if v is not None: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache hit') - else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache miss') - comm.bcast(v, root=0) - else: - v = comm.bcast(None, root=0) - - if v is None: - v = func(*args, **kwargs) - # Only write to the disk cache on rank 0 - if comm.rank == 0: - _disk_cache_set(cachedir, k, v) - return v - - return wrapper - return decorator - - -def _as_hexdigest(key): - return hashlib.md5(str(key).encode()).hexdigest() - - -def clear_memory_cache(comm): - if comm.Get_attr(comm_cache_keyval) is not None: - comm.Set_attr(comm_cache_keyval, {}) - - -def _disk_cache_get(cachedir, key): - """Retrieve a value from the disk cache. - - :arg cachedir: The cache directory. - :arg key: The cache key (must be a string). - :returns: The cached object if found, else ``None``. - """ - filepath = Path(cachedir, key[:2], key[2:]) - try: - with open(filepath, "rb") as f: - return pickle.load(f) - except FileNotFoundError: - return None - - -def _disk_cache_set(cachedir, key, value): - """Store a new value in the disk cache. - - :arg cachedir: The cache directory. - :arg key: The cache key (must be a string). - :arg value: The new item to store in the cache. - """ - k1, k2 = key[:2], key[2:] - basedir = Path(cachedir, k1) - basedir.mkdir(parents=True, exist_ok=True) - - tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") - filepath = basedir.joinpath(k2) - with open(tempfile, "wb") as f: - pickle.dump(value, f) - tempfile.rename(filepath) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index b648413b6..e7f4fd857 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -47,7 +47,7 @@ from pyop2 import mpi -from pyop2.caching import parallel_memory_only_cache_no_broadcast +from pyop2.caching import memory_cache from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -571,7 +571,7 @@ def load_hashkey(*args, **kwargs): @mpi.collective -@parallel_memory_only_cache_no_broadcast(key=load_hashkey) +@memory_cache(hashkey=load_hashkey, broadcast=False) def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. 
diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 007c5b7c5..bd549ba89 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -38,7 +38,7 @@ import cachetools import numpy from pyop2 import op2, mpi -from pyop2.caching import disk_cached +from pyop2.caching import memory_and_disk_cache def _seed(): @@ -537,34 +537,41 @@ def myfunc(arg): """Example function to cache the outputs of.""" return {arg} - def collective_key(self, *args): - """Return a cache key suitable for use when collective over a communicator.""" - self.comm = mpi.internal_comm(mpi.COMM_SELF, self) - return self.comm, cachetools.keys.hashkey(*args) + def comm_fetcher(self, *args): + """Communicator returning function.""" + return mpi.internal_comm(mpi.COMM_WORLD, self) - @pytest.fixture - def cache(cls): - return {} + def hash_key(self, *args): + """Hash key suitable for caching""" + return cachetools.keys.hashkey(*args) @pytest.fixture def cachedir(cls): return tempfile.TemporaryDirectory() - def test_decorator_in_memory_cache_reuses_results(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + def test_decorator_in_memory_cache_reuses_results(self, cachedir): + decorated_func = memory_and_disk_cache( + comm_fetcher=self.comm_fetcher, + cachedir=cachedir.name + )(self.myfunc) obj1 = decorated_func("input1") - assert len(cache) == 1 + caches = self.comm_fetcher().Get_attr(mpi.comm_cache_keyval) + mem_cache = caches["DEFAULT_CACHE"] + disk_cache = caches["DictLikeDiskAccess"] + assert len(mem_cache) == 1 + assert len(disk_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 obj2 = decorated_func("input1") assert obj1 is obj2 - assert len(cache) == 1 + assert len(mem_cache) == 1 + assert len(disk_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_collective_uses_different_in_memory_caches(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) - collective_func = disk_cached( + def test_decorator_collective_uses_different_in_memory_caches(self, cachedir): + decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) + collective_func = memory_and_disk_cache( None, cachedir.name, self.collective_key, collective=True )(self.myfunc) @@ -582,8 +589,8 @@ def test_decorator_collective_uses_different_in_memory_caches(self, cache, cache assert len(comm_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_disk_cache_reuses_results(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + def test_decorator_disk_cache_reuses_results(self, cachedir): + decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) obj1 = decorated_func("input1") cache.clear() @@ -592,8 +599,8 @@ def test_decorator_disk_cache_reuses_results(self, cache, cachedir): assert len(cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_cache_misses(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + def test_decorator_cache_misses(self, cachedir): + decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) obj1 = decorated_func("input1") obj2 = decorated_func("input2") diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index e3d901b4e..f9a6aa437 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -3,21 +3,15 @@ from functools import partial from pyop2.caching import ( # noqa: F401 - 
disk_cached, - parallel_memory_only_cache, - parallel_memory_only_cache_no_broadcast, - DiskCachedObject, - MemoryAndDiskCachedObject, + disk_only_cache, + memory_cache, + memory_and_disk_cache, default_parallel_hashkey, clear_memory_cache ) from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval # noqa: F401 -# For new disk_cached API -disk_cached = partial(disk_cached, None, key=default_parallel_hashkey, collective=True) - - class StateIncrement: """Simple class for keeping track of the number of times executed """ @@ -81,78 +75,78 @@ def unique_tempdir(): return gettempdir() -@pytest.mark.parametrize("decorator, uncached_function", [ - (parallel_memory_only_cache, twople), - (parallel_memory_only_cache_no_broadcast, n_comms), - (disk_cached, twople) -]) -def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): - if request.node.callspec.params["decorator"] is disk_cached: - kwargs = {"cachedir": tmpdir} - else: - kwargs = {} - - cached_function = function_factory(state, decorator, uncached_function, **kwargs) - assert state.value == 0 - first = cached_function(2, comm=COMM_WORLD) - assert first == uncached_function(2) - assert state.value == 1 - second = cached_function(2, comm=COMM_WORLD) - assert second == uncached_function(2) - assert second is first - assert state.value == 1 - - clear_memory_cache(COMM_WORLD) - - -@pytest.mark.parametrize("decorator, uncached_function", [ - (parallel_memory_only_cache, twople), - (parallel_memory_only_cache_no_broadcast, n_comms), - (disk_cached, twople) -]) -def test_function_args_different(request, state, decorator, uncached_function, tmpdir): - if request.node.callspec.params["decorator"] is disk_cached: - kwargs = {"cachedir": tmpdir} - else: - kwargs = {} - - cached_function = function_factory(state, decorator, uncached_function, **kwargs) - assert state.value == 0 - first = cached_function(2, comm=COMM_WORLD) - assert first == uncached_function(2) - assert state.value == 1 - second = cached_function(3, comm=COMM_WORLD) - assert second == uncached_function(3) - assert state.value == 2 - - clear_memory_cache(COMM_WORLD) - - -@pytest.mark.parallel(nprocs=3) -@pytest.mark.parametrize("decorator, uncached_function", [ - (parallel_memory_only_cache, twople), - (parallel_memory_only_cache_no_broadcast, n_comms), - (disk_cached, twople) -]) -def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): - if request.node.callspec.params["decorator"] is disk_cached: - kwargs = {"cachedir": tmpdir} - else: - kwargs = {} - - cached_function = function_factory(state, decorator, uncached_function, **kwargs) - assert state.value == 0 - for ii in range(10): - color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED - comm12 = COMM_WORLD.Split(color=color) - if COMM_WORLD.rank < 2: - _ = cached_function(2, comm=comm12) - comm12.Free() - - color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED - comm23 = COMM_WORLD.Split(color=color) - if COMM_WORLD.rank > 0: - _ = cached_function(2, comm=comm23) - comm23.Free() - - clear_memory_cache(COMM_WORLD) +# ~ @pytest.mark.parametrize("decorator, uncached_function", [ + # ~ (parallel_memory_only_cache, twople), + # ~ (parallel_memory_only_cache_no_broadcast, n_comms), + # ~ (disk_cached, twople) +# ~ ]) +# ~ def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): + # ~ if request.node.callspec.params["decorator"] is disk_cached: + # ~ kwargs = {"cachedir": tmpdir} + # ~ else: + # ~ kwargs = {} + + # ~ cached_function = 
function_factory(state, decorator, uncached_function, **kwargs) + # ~ assert state.value == 0 + # ~ first = cached_function(2, comm=COMM_WORLD) + # ~ assert first == uncached_function(2) + # ~ assert state.value == 1 + # ~ second = cached_function(2, comm=COMM_WORLD) + # ~ assert second == uncached_function(2) + # ~ assert second is first + # ~ assert state.value == 1 + + # ~ clear_memory_cache(COMM_WORLD) + + +# ~ @pytest.mark.parametrize("decorator, uncached_function", [ + # ~ (parallel_memory_only_cache, twople), + # ~ (parallel_memory_only_cache_no_broadcast, n_comms), + # ~ (disk_cached, twople) +# ~ ]) +# ~ def test_function_args_different(request, state, decorator, uncached_function, tmpdir): + # ~ if request.node.callspec.params["decorator"] is disk_cached: + # ~ kwargs = {"cachedir": tmpdir} + # ~ else: + # ~ kwargs = {} + + # ~ cached_function = function_factory(state, decorator, uncached_function, **kwargs) + # ~ assert state.value == 0 + # ~ first = cached_function(2, comm=COMM_WORLD) + # ~ assert first == uncached_function(2) + # ~ assert state.value == 1 + # ~ second = cached_function(3, comm=COMM_WORLD) + # ~ assert second == uncached_function(3) + # ~ assert state.value == 2 + + # ~ clear_memory_cache(COMM_WORLD) + + +# ~ @pytest.mark.parallel(nprocs=3) +# ~ @pytest.mark.parametrize("decorator, uncached_function", [ + # ~ (parallel_memory_only_cache, twople), + # ~ (parallel_memory_only_cache_no_broadcast, n_comms), + # ~ (disk_cached, twople) +# ~ ]) +# ~ def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): + # ~ if request.node.callspec.params["decorator"] is disk_cached: + # ~ kwargs = {"cachedir": tmpdir} + # ~ else: + # ~ kwargs = {} + + # ~ cached_function = function_factory(state, decorator, uncached_function, **kwargs) + # ~ assert state.value == 0 + # ~ for ii in range(10): + # ~ color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED + # ~ comm12 = COMM_WORLD.Split(color=color) + # ~ if COMM_WORLD.rank < 2: + # ~ _ = cached_function(2, comm=comm12) + # ~ comm12.Free() + + # ~ color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED + # ~ comm23 = COMM_WORLD.Split(color=color) + # ~ if COMM_WORLD.rank > 0: + # ~ _ = cached_function(2, comm=comm23) + # ~ comm23.Free() + + # ~ clear_memory_cache(COMM_WORLD) From d1bacf0bc82322d103125cd256df0a0e0ca9ca0a Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Wed, 14 Aug 2024 23:04:15 +0100 Subject: [PATCH 11/38] Round things out and fix remaining tests --- pyop2/caching.py | 142 ++++++++---------------- pyop2/compilation.py | 3 +- requirements-git.txt | 1 + test/unit/test_caching.py | 102 +++++++++-------- test/unit/test_updated_caching.py | 175 ++++++++++++++---------------- 5 files changed, 180 insertions(+), 243 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index e9b5a1160..c9f9f5edc 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -40,13 +40,14 @@ from collections.abc import MutableMapping from pathlib import Path from warnings import warn # noqa F401 -from functools import wraps, partial +from functools import wraps from pyop2.configuration import configuration from pyop2.logger import debug -from pyop2.mpi import comm_cache_keyval, COMM_WORLD +from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval +# TODO: Remove this? Rewrite? 
def report_cache(typ): """Report the size of caches of type ``typ`` @@ -169,11 +170,6 @@ class _CacheMiss: CACHE_MISS = _CacheMiss() -class _CacheKey: - def __init__(self, key_value): - self.value = key_value - - def _as_hexdigest(*args): hash_ = hashlib.md5() for a in args: @@ -193,7 +189,6 @@ def __init__(self, cachedir): :arg cachedir: The cache directory. """ self.cachedir = cachedir - self._keys = set() def __getitem__(self, key): """Retrieve a value from the disk cache. @@ -215,7 +210,6 @@ def __setitem__(self, key, value): :arg key: The cache key, a 2-tuple of strings. :arg value: The new item to store in the cache. """ - self._keys.add(key) k1, k2 = key[0][:2], key[0][2:] + key[1] basedir = Path(self.cachedir, k1) basedir.mkdir(parents=True, exist_ok=True) @@ -229,18 +223,14 @@ def __setitem__(self, key, value): def __delitem__(self, key): raise ValueError(f"Cannot remove items from {self.__class__.__name__}") - def keys(self): - return self._keys - def __iter__(self): - for k in self._keys: - yield k + raise ValueError(f"Cannot iterate over keys in {self.__class__.__name__}") def __len__(self): - return len(self._keys) + raise ValueError(f"Cannot query length of {self.__class__.__name__}") def __repr__(self): - return "{" + " ".join(f"{k}: {v}" for k, v in self.items()) + "}" + return f"{self.__class__.__name__}(cachedir={self.cachedir})" def open(self, *args, **kwargs): return open(*args, **kwargs) @@ -253,13 +243,41 @@ def write(self, filehandle, value): def default_comm_fetcher(*args, **kwargs): - return kwargs.get("comm") + comms = filter( + lambda arg: isinstance(arg, MPI.Comm), + args + tuple(kwargs.values()) + ) + try: + comm = next(comms) + except StopIteration: + raise TypeError("No comms found in args or kwargs") + return comm + + +def default_parallel_hashkey(*args, **kwargs): + """ We now want to actively remove any comms from args and kwargs to get the same disk cache key + """ + hash_args = tuple(filter( + lambda arg: not isinstance(arg, MPI.Comm), + args + )) + hash_kwargs = dict(filter( + lambda arg: not isinstance(arg[1], MPI.Comm), + kwargs.items() + )) + return cachetools.keys.hashkey(*hash_args, **hash_kwargs) -default_parallel_hashkey = cachetools.keys.hashkey +class DEFAULT_CACHE(dict): + pass -def parallel_cache(hashkey=default_parallel_hashkey, comm_fetcher=default_comm_fetcher, cache_factory=None, broadcast=True): +def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + broadcast=True +): """Memory only cache decorator. 
Decorator for wrapping a function to be called over a communiucator in a @@ -336,11 +354,7 @@ def wrapper(*args, **kwargs): # A small collection of default simple caches -class DEFAULT_CACHE(dict): - pass - - -memory_cache = partial(parallel_cache, cache_factory=lambda: DEFAULT_CACHE()) +memory_cache = parallel_cache def disk_only_cache(*args, cachedir=configuration["cache_dir"], **kwargs): @@ -352,77 +366,9 @@ def decorator(func): return memory_cache(*args, **kwargs)(disk_only_cache(*args, cachedir=cachedir, **kwargs)(func)) return decorator - -# ### Some notes from Connor: -# -# def pcache(comm_seeker, key=None, cache_factory=dict): -# -# comm = comm_seeker() -# cache = cache_factory() -# -# @pcache(cachetools.LRUCache) -# -# @pcache(DiskCache) -# -# @pcache(MemDiskCache) -# -# @pcache(MemCache) -# -# mem_cache = pcache(cache_factory=cachetools.LRUCache) -# disk_cache = mem_cache(cache_factory=DiskCache) -# -# @pcache(comm_seeker=lambda obj, *_, **_: obj.comm, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) -# -# -# @pmemcache -# -# @pmemdiskcache -# -# class ParallelObject(ABC): -# @abc.abstractproperty -# def _comm(self): -# pass -# -# class MyObj(ParallelObject): -# -# @pcached_property # assumes that obj has a "comm" attr -# @pcached_property(lambda self: self.comm) -# def myproperty(self): -# ... -# -# -# def pcached_property(): -# def wrapper(self): -# assert isinstance(self, ParallelObject) -# ... -# -# -# from futils.mpi import ParallelObject -# -# from futils.cache import pcached_property -# -# from footils.cache import * -# -# footils == firedrake utils - -# * parallel cached property -# * memcache / cache / cached -# * diskonlycache / disk_only_cached -# * memdiskcache / diskcache / disk_cached -# * memcache_no_bcast / broadcast=False -# -# parallel_cached_property = parallel_cache(lambda self: self._comm, key=lambda self: ()) -# -# @time -# @timed -# def myslowfunc(): -# -# .. -# -# my_fast_fun = cache(my_slow_fn) -# -#### - - -# TODO: -# Implement an @parallel_cached_property decorator function +# TODO: (Wishlist) +# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) +# * Add some sort of cache reporting +# * Add some sort of cache statistics +# * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method +# * Add some docstrings and maybe some exposition! 
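With this version of caching.py the caller no longer supplies a (comm, key) pair: default_comm_fetcher picks the first MPI communicator found in the arguments to select the per-communicator cache, and default_parallel_hashkey filters communicators out of the arguments so the same disk key is produced whichever communicator is passed. A hypothetical usage sketch of the composed decorator defined above, with a made-up function, arguments and cache directory (the return value must be picklable because broadcast defaults to True):

from mpi4py import MPI
from pyop2.caching import memory_and_disk_cache

@memory_and_disk_cache(cachedir="/tmp/pyop2-cache")  # illustrative location
def expensive_setup(spec, *, comm):
    # Deterministic collective work; `comm` is found by default_comm_fetcher
    # and excluded from the key by default_parallel_hashkey.
    return {"spec": spec, "ranks": comm.size}

first = expensive_setup("mass-matrix", comm=MPI.COMM_WORLD)   # computed and cached
second = expensive_setup("mass-matrix", comm=MPI.COMM_WORLD)  # memory cache hit
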
diff --git a/pyop2/compilation.py b/pyop2/compilation.py index e7f4fd857..5ab5a9a6e 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -566,8 +566,7 @@ def load_hashkey(*args, **kwargs): code_hash = md5(str(args[0].cache_key).encode()).hexdigest() else: pass # This will raise an error in load - comm = kwargs.get('comm') - return comm, cachetools.keys.hashkey(code_hash, *args[1:], **kwargs) + return cachetools.keys.hashkey(code_hash, *args[1:], **kwargs) @mpi.collective diff --git a/requirements-git.txt b/requirements-git.txt index d6f3d2182..a8f7fb67f 100644 --- a/requirements-git.txt +++ b/requirements-git.txt @@ -1 +1,2 @@ git+https://github.com/firedrakeproject/loopy.git@main#egg=loopy +git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index bd549ba89..d34859cc2 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -35,10 +35,9 @@ import os import pytest import tempfile -import cachetools import numpy from pyop2 import op2, mpi -from pyop2.caching import memory_and_disk_cache +from pyop2.caching import DEFAULT_CACHE, memory_and_disk_cache, clear_memory_cache def _seed(): @@ -289,9 +288,9 @@ def cache(self): int_comm = mpi.internal_comm(mpi.COMM_WORLD, self) _cache = int_comm.Get_attr(mpi.comm_cache_keyval) if _cache is None: - _cache = {} + _cache = {'DEFAULT_CACHE': DEFAULT_CACHE()} mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache) - return _cache + return _cache['DEFAULT_CACHE'] @pytest.fixture def a(cls, diterset): @@ -533,79 +532,86 @@ def test_sparsities_different_ordered_map_tuple_cached(self, m1, m2, ds2): class TestDiskCachedDecorator: @staticmethod - def myfunc(arg): + def myfunc(arg, comm): """Example function to cache the outputs of.""" return {arg} - def comm_fetcher(self, *args): - """Communicator returning function.""" - return mpi.internal_comm(mpi.COMM_WORLD, self) - - def hash_key(self, *args): - """Hash key suitable for caching""" - return cachetools.keys.hashkey(*args) + @pytest.fixture + def comm(self): + """This fixture provides a temporary comm so that each test gets it's own + communicator and that caches are cleaned on free.""" + temporary_comm = mpi.COMM_WORLD.Dup() + temporary_comm.name = "pytest temporary COMM_WORLD" + with mpi.temp_internal_comm(temporary_comm) as comm: + yield comm + temporary_comm.Free() @pytest.fixture def cachedir(cls): return tempfile.TemporaryDirectory() - def test_decorator_in_memory_cache_reuses_results(self, cachedir): + def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm): decorated_func = memory_and_disk_cache( - comm_fetcher=self.comm_fetcher, cachedir=cachedir.name )(self.myfunc) - obj1 = decorated_func("input1") - caches = self.comm_fetcher().Get_attr(mpi.comm_cache_keyval) - mem_cache = caches["DEFAULT_CACHE"] - disk_cache = caches["DictLikeDiskAccess"] + obj1 = decorated_func("input1", comm=comm) + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] assert len(mem_cache) == 1 - assert len(disk_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - obj2 = decorated_func("input1") + obj2 = decorated_func("input1", comm=comm) assert obj1 is obj2 assert len(mem_cache) == 1 - assert len(disk_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_collective_uses_different_in_memory_caches(self, cachedir): - decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) - collective_func = memory_and_disk_cache( - None, cachedir.name, 
self.collective_key, collective=True + def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cachedir, comm): + comm_world_func = memory_and_disk_cache( + cachedir=cachedir.name )(self.myfunc) - # obj1 should be cached on the comm cache and not the self.cache - obj1 = collective_func("input1") - comm_cache = self.comm.Get_attr(mpi.comm_cache_keyval) - assert len(cache) == 0 - assert len(comm_cache) == 1 - assert len(os.listdir(cachedir.name)) == 1 - - # obj2 should be cached on the self.cache and not the comm cache - obj2 = decorated_func("input1") - assert obj1 == obj2 and obj1 is not obj2 - assert len(cache) == 1 - assert len(comm_cache) == 1 - assert len(os.listdir(cachedir.name)) == 1 - - def test_decorator_disk_cache_reuses_results(self, cachedir): + temporary_comm = mpi.COMM_SELF.Dup() + temporary_comm.name = "pytest temporary COMM_SELF" + with mpi.temp_internal_comm(temporary_comm) as comm_self: + comm_self_func = memory_and_disk_cache( + cachedir=cachedir.name + )(self.myfunc) + + # obj1 should be cached on the COMM_WORLD cache + obj1 = comm_world_func("input1", comm=comm) + comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + assert len(comm_world_cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + # obj2 should be cached on the COMM_SELF cache + obj2 = comm_self_func("input1", comm=comm_self) + comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + assert obj1 == obj2 and obj1 is not obj2 + assert len(comm_world_cache) == 1 + assert len(comm_self_cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + temporary_comm.Free() + + def test_decorator_disk_cache_reuses_results(self, cachedir, comm): decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) - obj1 = decorated_func("input1") - cache.clear() - obj2 = decorated_func("input1") + obj1 = decorated_func("input1", comm=comm) + clear_memory_cache(comm) + obj2 = decorated_func("input1", comm=comm) + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] assert obj1 == obj2 and obj1 is not obj2 - assert len(cache) == 1 + assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_cache_misses(self, cachedir): + def test_decorator_cache_misses(self, cachedir, comm): decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) - obj1 = decorated_func("input1") - obj2 = decorated_func("input2") + obj1 = decorated_func("input1", comm=comm) + obj2 = decorated_func("input2", comm=comm) + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] assert obj1 != obj2 - assert len(cache) == 2 + assert len(mem_cache) == 2 assert len(os.listdir(cachedir.name)) == 2 diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index f9a6aa437..2c8ee53bf 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -1,5 +1,4 @@ import pytest -from tempfile import gettempdir from functools import partial from pyop2.caching import ( # noqa: F401 @@ -43,7 +42,7 @@ def n_ops(n): return [MPI.SUM]*n -# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_cached +# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_only_cached def function_factory(state, decorator, f, **kwargs): def custom_function(x, comm=COMM_WORLD): state() @@ -52,101 +51,87 @@ def custom_function(x, comm=COMM_WORLD): return decorator(**kwargs)(custom_function) -# parent_class = DiskCachedObject, 
MemoryAndDiskCachedObject -# f(x) = x**2, x**3 -def object_factory(state, parent_class, f, **kwargs): - class CustomObject(parent_class, **kwargs): - def __init__(self, x, comm=COMM_WORLD): - state() - self.x = f(x) - - return CustomObject - - @pytest.fixture def state(): return StateIncrement() -@pytest.fixture -def unique_tempdir(): - """This allows us to run with a different tempdir for each test that - requires one""" - return gettempdir() - - -# ~ @pytest.mark.parametrize("decorator, uncached_function", [ - # ~ (parallel_memory_only_cache, twople), - # ~ (parallel_memory_only_cache_no_broadcast, n_comms), - # ~ (disk_cached, twople) -# ~ ]) -# ~ def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): - # ~ if request.node.callspec.params["decorator"] is disk_cached: - # ~ kwargs = {"cachedir": tmpdir} - # ~ else: - # ~ kwargs = {} - - # ~ cached_function = function_factory(state, decorator, uncached_function, **kwargs) - # ~ assert state.value == 0 - # ~ first = cached_function(2, comm=COMM_WORLD) - # ~ assert first == uncached_function(2) - # ~ assert state.value == 1 - # ~ second = cached_function(2, comm=COMM_WORLD) - # ~ assert second == uncached_function(2) - # ~ assert second is first - # ~ assert state.value == 1 - - # ~ clear_memory_cache(COMM_WORLD) - - -# ~ @pytest.mark.parametrize("decorator, uncached_function", [ - # ~ (parallel_memory_only_cache, twople), - # ~ (parallel_memory_only_cache_no_broadcast, n_comms), - # ~ (disk_cached, twople) -# ~ ]) -# ~ def test_function_args_different(request, state, decorator, uncached_function, tmpdir): - # ~ if request.node.callspec.params["decorator"] is disk_cached: - # ~ kwargs = {"cachedir": tmpdir} - # ~ else: - # ~ kwargs = {} - - # ~ cached_function = function_factory(state, decorator, uncached_function, **kwargs) - # ~ assert state.value == 0 - # ~ first = cached_function(2, comm=COMM_WORLD) - # ~ assert first == uncached_function(2) - # ~ assert state.value == 1 - # ~ second = cached_function(3, comm=COMM_WORLD) - # ~ assert second == uncached_function(3) - # ~ assert state.value == 2 - - # ~ clear_memory_cache(COMM_WORLD) - - -# ~ @pytest.mark.parallel(nprocs=3) -# ~ @pytest.mark.parametrize("decorator, uncached_function", [ - # ~ (parallel_memory_only_cache, twople), - # ~ (parallel_memory_only_cache_no_broadcast, n_comms), - # ~ (disk_cached, twople) -# ~ ]) -# ~ def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): - # ~ if request.node.callspec.params["decorator"] is disk_cached: - # ~ kwargs = {"cachedir": tmpdir} - # ~ else: - # ~ kwargs = {} - - # ~ cached_function = function_factory(state, decorator, uncached_function, **kwargs) - # ~ assert state.value == 0 - # ~ for ii in range(10): - # ~ color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED - # ~ comm12 = COMM_WORLD.Split(color=color) - # ~ if COMM_WORLD.rank < 2: - # ~ _ = cached_function(2, comm=comm12) - # ~ comm12.Free() - - # ~ color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED - # ~ comm23 = COMM_WORLD.Split(color=color) - # ~ if COMM_WORLD.rank > 0: - # ~ _ = cached_function(2, comm=comm23) - # ~ comm23.Free() - - # ~ clear_memory_cache(COMM_WORLD) +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (partial(memory_cache, broadcast=False), n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): + if 
request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(2, comm=COMM_WORLD) + assert second == uncached_function(2) + if request.node.callspec.params["decorator"] is not disk_only_cache: + assert second is first + assert state.value == 1 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (partial(memory_cache, broadcast=False), n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_args_different(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(3, comm=COMM_WORLD) + assert second == uncached_function(3) + assert state.value == 2 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (partial(memory_cache, broadcast=False), n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + for ii in range(10): + color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED + comm12 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank < 2: + _ = cached_function(2, comm=comm12) + comm12.Free() + + color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED + comm23 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank > 0: + _ = cached_function(2, comm=comm23) + comm23.Free() + + clear_memory_cache(COMM_WORLD) From a7f43d429c42d0be15f2ce8c74802449ba655645 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 11:30:37 +0100 Subject: [PATCH 12/38] CI debugging --- .github/workflows/ci.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 788186ac9..40e937986 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: # Don't immediately kill all if one Python version fails fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11', '3.12'] env: CC: mpicc PETSC_DIR: ${{ github.workspace }}/petsc @@ -58,7 +58,7 @@ jobs: working-directory: ${{ env.PETSC_DIR }}/src/binding/petsc4py run: | python -m pip install --upgrade pip - python -m pip install --upgrade wheel 'cython<3' numpy + python -m pip install --upgrade wheel cython numpy python -m pip install --no-deps . 
- name: Checkout PyOP2 @@ -83,6 +83,12 @@ jobs: working-directory: PyOP2 run: make lint + - name: Check MPI + shell: bash + working-directory: PyOP2 + run: mpiexec -n 4 python3 -c "from mpi4py import MPI; print(MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)" + timeout-minutes: 10 + - name: Run tests shell: bash working-directory: PyOP2 From 783ab12a286ed5a6af7fda4f0d9e39bb94d83c2e Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 12:05:00 +0100 Subject: [PATCH 13/38] CI debugging, ssh in? --- .github/workflows/ci.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 40e937986..9b3cbac82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,6 +28,13 @@ jobs: timeout-minutes: 60 steps: + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 + with: + detached: true + timeout-minutes: 20 + limit-access-to-actor: true + - name: Install system dependencies shell: bash run: | From 3d13c4f0ffb024fc8e4ba64b43aaeab83ae0ef2d Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 15:00:23 +0100 Subject: [PATCH 14/38] CI debugging, a fix? --- .github/workflows/ci.yml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b3cbac82..64846e4ed 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,8 +32,8 @@ jobs: uses: mxschmitt/action-tmate@v3 with: detached: true - timeout-minutes: 20 limit-access-to-actor: true + timeout-minutes: 20 - name: Install system dependencies shell: bash @@ -90,16 +90,13 @@ jobs: working-directory: PyOP2 run: make lint - - name: Check MPI - shell: bash - working-directory: PyOP2 - run: mpiexec -n 4 python3 -c "from mpi4py import MPI; print(MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)" - timeout-minutes: 10 - - name: Run tests shell: bash working-directory: PyOP2 - run: pytest --tb=native --timeout=480 --timeout-method=thread -o faulthandler_timeout=540 -v test + run: | + # Running parallel test cases separately works around a bug in pytest-mpi + pytest -k "not parallel" --tb=native --timeout=480 --timeout-method=thread -o faulthandler_timeout=540 -v test + mpiexec -n 3 pytest -k "parallel[3]" --tb=native --timeout=480 --timeout-method=thread -o faulthandler_timeout=540 -v test timeout-minutes: 10 - name: Build documentation From 756e29e5184b5d26b35c3d7d767f2cccce005013 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 15:43:43 +0100 Subject: [PATCH 15/38] CI debugging, a fix for 3.12 --- .github/workflows/ci.yml | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 64846e4ed..6081d5b7b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: with: detached: true limit-access-to-actor: true - timeout-minutes: 20 + timeout-minutes: 15 - name: Install system dependencies shell: bash @@ -73,7 +73,7 @@ jobs: with: path: PyOP2 - - name: Install PyOP2 + - name: Install PyOP2 dependencies shell: bash working-directory: PyOP2 run: | @@ -85,6 +85,19 @@ jobs: python -m pip install -U pytest-timeout python -m pip install . + - name: Install PyOP2 (Python <3.12) + if: ${{ matrix.python-version != '3.12' }} + shell: bash + working-directory: PyOP2 + run: python -m pip install . 
+ + # Not sure if this is a bug in setuptools or something PyOP2 is doing wrong + - name: Install PyOP2 (Python == 3.12) + if: ${{ matrix.python-version != '3.12' }} + shell: bash + working-directory: PyOP2 + run: python setup.py install + - name: Run linting shell: bash working-directory: PyOP2 From f3decc4acb63b1048fe367b74fb7fcc24c22192e Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 16:05:32 +0100 Subject: [PATCH 16/38] . --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6081d5b7b..84c1072eb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -83,7 +83,6 @@ jobs: python -m pip install pulp python -m pip install -U flake8 python -m pip install -U pytest-timeout - python -m pip install . - name: Install PyOP2 (Python <3.12) if: ${{ matrix.python-version != '3.12' }} From 723468bde682150f81ffb94ca17d3ec9f7868e6a Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 16:31:28 +0100 Subject: [PATCH 17/38] . --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 84c1072eb..3b38b23c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -92,7 +92,7 @@ jobs: # Not sure if this is a bug in setuptools or something PyOP2 is doing wrong - name: Install PyOP2 (Python == 3.12) - if: ${{ matrix.python-version != '3.12' }} + if: ${{ matrix.python-version == '3.12' }} shell: bash working-directory: PyOP2 run: python setup.py install From 474058d1370b4ce70264de013934a386e7077786 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 16:54:19 +0100 Subject: [PATCH 18/38] . --- .github/workflows/ci.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3b38b23c2..709ba422a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -95,7 +95,9 @@ jobs: if: ${{ matrix.python-version == '3.12' }} shell: bash working-directory: PyOP2 - run: python setup.py install + run: | + python -m pip install -U setuptools + python setup.py install - name: Run linting shell: bash From aab930aca2d39f3ed0b6e2ad49e651157ee2997d Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 15 Aug 2024 17:22:12 +0100 Subject: [PATCH 19/38] Remove custom implementation of @cached_property, use the Python(>=3.8) one --- pyop2/utils.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/pyop2/utils.py b/pyop2/utils.py index 11b4ead5b..2f26741e1 100644 --- a/pyop2/utils.py +++ b/pyop2/utils.py @@ -40,29 +40,12 @@ from decorator import decorator import argparse +from functools import cached_property # noqa: F401 + from pyop2.exceptions import DataTypeError, DataValueError from pyop2.configuration import configuration -class cached_property(object): - - '''A read-only @property that is only evaluated once. 
The value is cached
-    on the object itself rather than the function or class; this should prevent
-    memory leakage.'''
-
-    def __init__(self, fget, doc=None):
-        self.fget = fget
-        self.__doc__ = doc or fget.__doc__
-        self.__name__ = fget.__name__
-        self.__module__ = fget.__module__
-
-    def __get__(self, obj, cls):
-        if obj is None:
-            return self
-        obj.__dict__[self.__name__] = result = self.fget(obj)
-        return result
-
-
 def as_tuple(item, type=None, length=None, allow_none=False):
     # Empty list if we get passed None
     if item is None:

From f5122496f7673c68b6c2e22167c295cc27e6de0f Mon Sep 17 00:00:00 2001
From: Jack Betteridge
Date: Thu, 15 Aug 2024 21:17:54 +0100
Subject: [PATCH 20/38] Add a serial cache with the same interface

---
 pyop2/caching.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pyop2/caching.py b/pyop2/caching.py
index c9f9f5edc..e2e4f399e 100644
--- a/pyop2/caching.py
+++ b/pyop2/caching.py
@@ -357,6 +357,10 @@ def wrapper(*args, **kwargs):
 memory_cache = parallel_cache
 
 
+def serial_cache(hashkey, cache_factory=lambda: DEFAULT_CACHE()):
+    return cachetools.cached(key=hashkey, cache=cache_factory())
+
+
 def disk_only_cache(*args, cachedir=configuration["cache_dir"], **kwargs):
     return parallel_cache(*args, **kwargs, cache_factory=lambda: DictLikeDiskAccess(cachedir))
 

From c12194c7254262a5518439f12ec0398eadcc2b74 Mon Sep 17 00:00:00 2001
From: Jack Betteridge
Date: Fri, 16 Aug 2024 21:21:23 +0100
Subject: [PATCH 21/38] Caching fixes and tweaks

---
 pyop2/caching.py | 15 +++++++++------
 pyop2/compilation.py | 4 ++--
 pyop2/global_kernel.py | 8 +++++++-
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/pyop2/caching.py b/pyop2/caching.py
index e2e4f399e..2bc72525c 100644
--- a/pyop2/caching.py
+++ b/pyop2/caching.py
@@ -173,6 +173,9 @@ class _CacheMiss:
 def _as_hexdigest(*args):
     hash_ = hashlib.md5()
     for a in args:
+        # TODO: Remove or edit this check!
+ if isinstance(a, MPI.Comm) or isinstance(a, cachetools.keys._HashedTuple): + breakpoint() hash_.update(str(a).encode()) return hash_.hexdigest() @@ -298,7 +301,7 @@ def wrapper(*args, **kwargs): """ comm = comm_fetcher(*args, **kwargs) k = hashkey(*args, **kwargs) - key = _as_hexdigest(k), func.__qualname__ + key = _as_hexdigest(*k), func.__qualname__ # Fetch the per-comm cache_collection or set it up if not present # A collection is required since different types of cache can be set up on the same comm @@ -317,10 +320,10 @@ def wrapper(*args, **kwargs): # Grab value from rank 0 memory cache and broadcast result if comm.rank == 0: value = local_cache.get(key, CACHE_MISS) - if value is None: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') + if value is CACHE_MISS: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') # TODO: Add communication tags to avoid cross-broadcasting comm.bcast(value, root=0) else: @@ -334,10 +337,10 @@ def wrapper(*args, **kwargs): # Grab value from all ranks cache and broadcast cache hit/miss value = local_cache.get(key, CACHE_MISS) if value is CACHE_MISS: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') cache_hit = False else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') cache_hit = True all_present = comm.allgather(cache_hit) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 5ab5a9a6e..f7b1a93d5 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -47,7 +47,7 @@ from pyop2 import mpi -from pyop2.caching import memory_cache +from pyop2.caching import memory_cache, default_parallel_hashkey from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -566,7 +566,7 @@ def load_hashkey(*args, **kwargs): code_hash = md5(str(args[0].cache_key).encode()).hexdigest() else: pass # This will raise an error in load - return cachetools.keys.hashkey(code_hash, *args[1:], **kwargs) + return default_parallel_hashkey(code_hash, *args[1:], **kwargs) @mpi.collective diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 454ca414f..3c8cde430 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import os from typing import Optional, Tuple +import itertools import loopy as lp import numpy as np @@ -298,8 +299,13 @@ def __init__(self, local_kernel, arguments, *, raise ValueError( "Cannot request constant_layers argument for non-extruded iteration" ) + + counter = itertools.count() + seen_maps = collections.defaultdict(lambda: next(counter)) self.cache_key = ( - local_kernel.cache_key, *[a.cache_key for a in arguments], + local_kernel.cache_key, + *[a.cache_key for a in arguments], + *[seen_maps[m] for a in arguments for m in a.maps], extruded, extruded_periodic, constant_layers, subset, iteration_region, 
pass_layer_arg, configuration["simd_width"] ) From c74c7bb1aa796cffb47b1f7b7a2550f7dafa2f48 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 16 Aug 2024 21:37:45 +0100 Subject: [PATCH 22/38] Lint --- pyop2/compilation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index f7b1a93d5..16ef97357 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -40,7 +40,6 @@ import sys import ctypes import shlex -import cachetools from hashlib import md5 from packaging.version import Version, InvalidVersion from textwrap import dedent From f52449ce6f3a62816415a0894269f7954f9ecff9 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 19 Aug 2024 18:16:24 +0100 Subject: [PATCH 23/38] Add instrumentation to caches --- pyop2/caching.py | 291 +++++++++++++++++++++++++++----------- pyop2/configuration.py | 4 +- pyop2/op2.py | 9 +- test/unit/test_caching.py | 25 ++-- 4 files changed, 227 insertions(+), 102 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 2bc72525c..65954beda 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -32,7 +32,6 @@ # OF THE POSSIBILITY OF SUCH DAMAGE. """Provides common base classes for cached objects.""" - import cachetools import hashlib import os @@ -40,44 +39,27 @@ from collections.abc import MutableMapping from pathlib import Path from warnings import warn # noqa F401 -from functools import wraps +from collections import defaultdict +from itertools import count +from functools import partial, wraps from pyop2.configuration import configuration from pyop2.logger import debug -from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval +from pyop2.mpi import ( + MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm +) -# TODO: Remove this? Rewrite? -def report_cache(typ): - """Report the size of caches of type ``typ`` - - :arg typ: A class of cached object. For example - :class:`ObjectCached` or :class:`Cached`. - """ - from collections import defaultdict - from inspect import getmodule - from gc import get_objects - typs = defaultdict(lambda: 0) - n = 0 - for x in get_objects(): - if isinstance(x, typ): - typs[type(x)] += 1 - n += 1 - if n == 0: - print("\nNo %s objects in caches" % typ.__name__) - return - print("\n%d %s objects in caches" % (n, typ.__name__)) - print("Object breakdown") - print("================") - for k, v in typs.iteritems(): - mod = getmodule(k) - if mod is not None: - name = "%s.%s" % (mod.__name__, k.__name__) - else: - name = k.__name__ - print('%s: %d' % (name, v)) +# Caches created here are registered as a tuple of +# (creation_index, comm, comm.name, function, cache) +# in _KNOWN_CACHES +_CACHE_CIDX = count() +_KNOWN_CACHES = [] +# Flag for outputting information at the end of testing (do not abuse!) +_running_on_ci = bool(os.environ.get('PYOP2_CI_TESTS')) +# FIXME: (Later) Remove ObjectCached class ObjectCached(object): """Base class for objects that should be cached on another object. 
@@ -163,6 +145,95 @@ def make_obj(): return obj +def cache_stats(comm=None, comm_name=None, alive=True, function=None, cache_type=None): + caches = _KNOWN_CACHES + if comm is not None: + with temp_internal_comm(comm) as icomm: + cache_collection = icomm.Get_attr(comm_cache_keyval) + if cache_collection is None: + print(f"Communicator {icomm.name} as no associated caches") + comm_name = icomm.name + if comm_name is not None: + caches = filter(lambda c: c[2] == comm_name, caches) + if alive: + caches = filter(lambda c: c[1] != MPI.COMM_NULL, caches) + if function is not None: + if isinstance(function, str): + caches = filter(lambda c: function in c[3].__qualname__, caches) + else: + caches = filter(lambda c: c[3] is function, caches) + if cache_type is not None: + if isinstance(cache_type, str): + caches = filter(lambda c: cache_type in c[4].__qualname__, caches) + else: + caches = filter(lambda c: isinstance(c[4], cache_type), caches) + return [*caches] + + +def get_stats(cache): + hit = miss = size = maxsize = -1 + if isinstance(cache, cachetools.Cache): + size = cache.currsize + maxsize = cache.maxsize + if hasattr(cache, "instrument__"): + hit = cache.hit + miss = cache.miss + if size is None: + try: + size = len(cache) + except NotImplementedError: + pass + if maxsize is None: + try: + maxsize = cache.max_size + except AttributeError: + pass + return hit, miss, size, maxsize + + +def print_cache_stats(*args, **kwargs): + data = defaultdict(lambda: defaultdict(list)) + for entry in cache_stats(*args, **kwargs): + ecid, ecomm, ecomm_name, efunction, ecache = entry + active = (ecomm != MPI.COMM_NULL) + data[(ecomm_name, active)][ecache.__class__.__name__].append( + (ecid, efunction.__module__, efunction.__name__, ecache) + ) + + tab = " " + hline = "-"*120 + col = (90, 27) + stats_col = (6, 6, 6, 6) + stats = ("hit", "miss", "size", "max") + no_stats = "|".join(" "*ii for ii in stats_col) + print(hline) + print(f"|{'Cache':^{col[0]}}|{'Stats':^{col[1]}}|") + subtitles = "|".join(f"{st:^{w}}" for st, w in zip(stats, stats_col)) + print("|" + " "*col[0] + f"|{subtitles:{col[1]}}|") + print(hline) + for ecomm, cachedict in data.items(): + active = "Active" if ecomm[1] else "Freed" + comm_title = f"{ecomm[0]} ({active})" + print(f"|{comm_title:{col[0]}}|{no_stats}|") + for ecache, function_list in cachedict.items(): + cache_title = f"{tab}{ecache}" + print(f"|{cache_title:{col[0]}}|{no_stats}|") + try: + loc = function_list[0][-1].cachedir + except AttributeError: + loc = "Memory" + cache_location = f"{tab} ↳ {loc!s}" + if len(str(loc)) < col[0] - 5: + print(f"|{cache_location:{col[0]}}|{no_stats}|") + else: + print(f"|{cache_location:78}|") + for entry in function_list: + function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}" + stats = "|".join(f"{s:{w}}" for s, w in zip(get_stats(entry[3]), stats_col)) + print(f"|{function_title:{col[0]}}|{stats:{col[1]}}|") + print(hline) + + class _CacheMiss: pass @@ -180,11 +251,6 @@ def _as_hexdigest(*args): return hash_.hexdigest() -def clear_memory_cache(comm): - if comm.Get_attr(comm_cache_keyval) is not None: - comm.Set_attr(comm_cache_keyval, {}) - - class DictLikeDiskAccess(MutableMapping): def __init__(self, cachedir): """ @@ -224,17 +290,21 @@ def __setitem__(self, key, value): tempfile.rename(filepath) def __delitem__(self, key): - raise ValueError(f"Cannot remove items from {self.__class__.__name__}") + raise NotImplementedError(f"Cannot remove items from {self.__class__.__name__}") def __iter__(self): - raise ValueError(f"Cannot 
iterate over keys in {self.__class__.__name__}") + raise NotImplementedError(f"Cannot iterate over keys in {self.__class__.__name__}") def __len__(self): - raise ValueError(f"Cannot query length of {self.__class__.__name__}") + raise NotImplementedError(f"Cannot query length of {self.__class__.__name__}") def __repr__(self): return f"{self.__class__.__name__}(cachedir={self.cachedir})" + def __eq__(self, other): + # Instances are the same if they have the same cachedir + return self.cachedir == other.cachedir + def open(self, *args, **kwargs): return open(*args, **kwargs) @@ -271,10 +341,47 @@ def default_parallel_hashkey(*args, **kwargs): return cachetools.keys.hashkey(*hash_args, **hash_kwargs) +def instrument(cls): + @wraps(cls, updated=()) + class _wrapper(cls): + instrument__ = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hit = 0 + self.miss = 0 + + def get(self, key, default=None): + value = super().get(key, default) + if value is default: + self.miss += 1 + else: + self.hit += 1 + return value + + def __getitem__(self, key): + try: + value = super().__getitem__(key) + self.hit += 1 + except KeyError as e: + self.miss += 1 + raise e + return value + return _wrapper + + class DEFAULT_CACHE(dict): pass +# Examples of how to instrument and use different default caches: +# - DEFAULT_CACHE = instrument(DEFAULT_CACHE) +# - DEFAULT_CACHE = instrument(cachetools.LRUCache) +# - DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100) +EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100) +# - DictLikeDiskAccess = instrument(DictLikeDiskAccess) + + def parallel_cache( hashkey=default_parallel_hashkey, comm_fetcher=default_comm_fetcher, @@ -299,54 +406,62 @@ def wrapper(*args, **kwargs): """ Extract the key and then try the memory cache before falling back on calling the function and populating the cache. 
""" - comm = comm_fetcher(*args, **kwargs) k = hashkey(*args, **kwargs) key = _as_hexdigest(*k), func.__qualname__ - - # Fetch the per-comm cache_collection or set it up if not present - # A collection is required since different types of cache can be set up on the same comm - cache_collection = comm.Get_attr(comm_cache_keyval) - if cache_collection is None: - cache_collection = {} - comm.Set_attr(comm_cache_keyval, cache_collection) - # If this kind of cache is already present on the - # cache_collection, get it, otherwise create it - local_cache = cache_collection.setdefault( - (cf := cache_factory()).__class__.__name__, - cf - ) - - if broadcast: - # Grab value from rank 0 memory cache and broadcast result - if comm.rank == 0: + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [k[1:] for k in _KNOWN_CACHES]: + _KNOWN_CACHES.append((next(_CACHE_CIDX), comm, comm.name, func, local_cache)) + + if broadcast: + # Grab value from rank 0 memory cache and broadcast result + if comm.rank == 0: + value = local_cache.get(key, CACHE_MISS) + if value is CACHE_MISS: + debug( + f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + f"{k} {local_cache.__class__.__name__} cache miss" + ) + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') + # TODO: Add communication tags to avoid cross-broadcasting + comm.bcast(value, root=0) + else: + value = comm.bcast(CACHE_MISS, root=0) + if isinstance(value, _CacheMiss): + # We might have the CACHE_MISS from rank 0 and + # `(value is CACHE_MISS) == False` which is confusing, + # so we set it back to the local value + value = CACHE_MISS + else: + # Grab value from all ranks cache and broadcast cache hit/miss value = local_cache.get(key, CACHE_MISS) if value is CACHE_MISS: debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') + cache_hit = False else: debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - # TODO: Add communication tags to avoid cross-broadcasting - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - if isinstance(value, _CacheMiss): - # We might have the CACHE_MISS from rank 0 and - # `(value is CACHE_MISS) == False` which is confusing, - # so we set it back to the local value - value = CACHE_MISS - else: - # Grab value from all ranks cache and broadcast cache hit/miss - value = local_cache.get(key, CACHE_MISS) - if value is CACHE_MISS: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') - cache_hit = False - else: - 
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - cache_hit = True - all_present = comm.allgather(cache_hit) + cache_hit = True + all_present = comm.allgather(cache_hit) - # If not present in the cache of all ranks we need to recompute on all ranks - if not min(all_present): - value = CACHE_MISS + # If not present in the cache of all ranks we need to recompute on all ranks + if not min(all_present): + value = CACHE_MISS if value is CACHE_MISS: value = func(*args, **kwargs) @@ -356,6 +471,12 @@ def wrapper(*args, **kwargs): return decorator +def clear_memory_cache(comm): + with temp_internal_comm(comm) as icomm: + if icomm.Get_attr(comm_cache_keyval) is not None: + icomm.Set_attr(comm_cache_keyval, {}) + + # A small collection of default simple caches memory_cache = parallel_cache @@ -374,8 +495,12 @@ def decorator(func): return decorator # TODO: (Wishlist) -# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) -# * Add some sort of cache reporting -# * Add some sort of cache statistics +# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) ✓ +# * Add some sort of cache reporting ✓ +# * Add some sort of cache statistics ✓ # * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method +# * Systematic investigation into cache sizes/types for Firedrake +# - Is a mem cache needed for DLLs? +# - Is LRUCache better than a simple dict? (memory profile test suite) +# - What is the optimal maxsize? # * Add some docstrings and maybe some exposition! diff --git a/pyop2/configuration.py b/pyop2/configuration.py index 29717718c..ff3721a6f 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -108,8 +108,8 @@ class Configuration(dict): ("PYOP2_NODE_LOCAL_COMPILATION", bool, True), "no_fork_available": ("PYOP2_NO_FORK_AVAILABLE", bool, False), - "print_cache_size": - ("PYOP2_PRINT_CACHE_SIZE", bool, False), + "print_cache_info": + ("PYOP2_CACHE_INFO", bool, False), "matnest": ("PYOP2_MATNEST", bool, True), "block_sparsity": diff --git a/pyop2/op2.py b/pyop2/op2.py index 85788eafa..35e5649f4 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -112,11 +112,10 @@ def init(**kwargs): @collective def exit(): """Exit OP2 and clean up""" - if configuration['print_cache_size'] and COMM_WORLD.rank == 0: - from caching import report_cache, Cached, ObjectCached - print('**** PyOP2 cache sizes at exit ****') - report_cache(typ=ObjectCached) - report_cache(typ=Cached) + if configuration['print_cache_info'] and COMM_WORLD.rank == 0: + from pyop2.caching import print_cache_stats + print(f"{' PyOP2 cache sizes on rank 0 at exit ':*^120}") + print_cache_stats(alive=False) configuration.reset() global _initialised _initialised = False diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index d34859cc2..6ab909b29 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -45,6 +45,7 @@ def _seed(): nelems = 8 +default_cache_name = DEFAULT_CACHE().__class__.__name__ @pytest.fixture @@ -286,11 +287,11 @@ class TestGeneratedCodeCache: @property def cache(self): int_comm = mpi.internal_comm(mpi.COMM_WORLD, self) - _cache = int_comm.Get_attr(mpi.comm_cache_keyval) - if _cache is None: - _cache = {'DEFAULT_CACHE': DEFAULT_CACHE()} - mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache) - return 
_cache['DEFAULT_CACHE'] + _cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval) + if _cache_collection is None: + _cache_collection = {default_cache_name: DEFAULT_CACHE()} + mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache_collection) + return _cache_collection[default_cache_name] @pytest.fixture def a(cls, diterset): @@ -541,7 +542,7 @@ def comm(self): """This fixture provides a temporary comm so that each test gets it's own communicator and that caches are cleaned on free.""" temporary_comm = mpi.COMM_WORLD.Dup() - temporary_comm.name = "pytest temporary COMM_WORLD" + temporary_comm.name = "pytest temp COMM_WORLD" with mpi.temp_internal_comm(temporary_comm) as comm: yield comm temporary_comm.Free() @@ -556,7 +557,7 @@ def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm): )(self.myfunc) obj1 = decorated_func("input1", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 @@ -571,7 +572,7 @@ def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cach )(self.myfunc) temporary_comm = mpi.COMM_SELF.Dup() - temporary_comm.name = "pytest temporary COMM_SELF" + temporary_comm.name = "pytest temp COMM_SELF" with mpi.temp_internal_comm(temporary_comm) as comm_self: comm_self_func = memory_and_disk_cache( cachedir=cachedir.name @@ -579,13 +580,13 @@ def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cach # obj1 should be cached on the COMM_WORLD cache obj1 = comm_world_func("input1", comm=comm) - comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert len(comm_world_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 # obj2 should be cached on the COMM_SELF cache obj2 = comm_self_func("input1", comm=comm_self) - comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 assert len(comm_world_cache) == 1 assert len(comm_self_cache) == 1 @@ -599,7 +600,7 @@ def test_decorator_disk_cache_reuses_results(self, cachedir, comm): obj1 = decorated_func("input1", comm=comm) clear_memory_cache(comm) obj2 = decorated_func("input1", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 @@ -609,7 +610,7 @@ def test_decorator_cache_misses(self, cachedir, comm): obj1 = decorated_func("input1", comm=comm) obj2 = decorated_func("input2", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert obj1 != obj2 assert len(mem_cache) == 2 assert len(os.listdir(cachedir.name)) == 2 From 0c2b08880eb6f688a9c986ebd2b8db0f43fe6937 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 19 Aug 2024 18:37:02 +0100 Subject: [PATCH 24/38] Remove get_so from Compiler base class --- pyop2/compilation.py | 320 ++++++++++++++++++++++--------------------- 1 file changed, 162 insertions(+), 158 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 16ef97357..e0e58a15e 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ 
-275,163 +275,6 @@ def sniff_compiler_version(self, cpp=False): def bugfix_cflags(self): return () - @staticmethod - def expandWl(ldflags): - """Generator to expand the `-Wl` compiler flags for use as linker flags - :arg ldflags: linker flags for a compiler command - """ - for flag in ldflags: - if flag.startswith('-Wl'): - for f in flag.lstrip('-Wl')[1:].split(','): - yield f - else: - yield flag - - @mpi.collective - def get_so(self, jitmodule, extension): - """Build a shared library and load it - - :arg jitmodule: The JIT Module which can generate the code to compile. - :arg extension: extension of the source file (c, cpp). - Returns a :class:`ctypes.CDLL` object of the resulting shared - library.""" - - # C or C++ - if self._cpp: - compiler = self.cxx - compiler_flags = self.cxxflags - else: - compiler = self.cc - compiler_flags = self.cflags - - # Determine cache key - hsh = md5(str(jitmodule.cache_key).encode()) - hsh.update(compiler.encode()) - if self.ld: - hsh.update(self.ld.encode()) - hsh.update("".join(compiler_flags).encode()) - hsh.update("".join(self.ldflags).encode()) - - basename = hsh.hexdigest() - - cachedir = configuration['cache_dir'] - - dirpart, basename = basename[:2], basename[2:] - cachedir = os.path.join(cachedir, dirpart) - pid = os.getpid() - cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}") - oname = os.path.join(cachedir, f"{basename}_p{pid}.o") - soname = os.path.join(cachedir, f"{basename}.so") - # Link into temporary file, then rename to shared library - # atomically (avoiding races). - tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") - - if configuration['check_src_hashes'] or configuration['debug']: - matching = self.comm.allreduce(basename, op=_check_op) - if matching != basename: - # Dump all src code to disk for debugging - output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, f"src-rank{self.comm.rank}.{extension}") - if self.comm.rank == 0: - os.makedirs(output, exist_ok=True) - self.comm.barrier() - with open(srcfile, "w") as f: - f.write(jitmodule.code_to_compile) - self.comm.barrier() - raise CompilationError(f"Generated code differs across ranks (see output in {output})") - - # Check whether this shared object already written to disk - try: - dll = ctypes.CDLL(soname) - except OSError: - dll = None - got_dll = bool(dll) - all_dll = self.comm.allgather(got_dll) - - # If the library is not loaded _on all ranks_ build it - if not min(all_dll): - if self.comm.rank == 0: - # No need to do this on all ranks - os.makedirs(cachedir, exist_ok=True) - logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") - errfile = os.path.join(cachedir, f"{basename}_p{pid}.err") - with progress(INFO, 'Compiling wrapper'): - with open(cname, "w") as f: - f.write(jitmodule.code_to_compile) - # Compiler also links - if not self.ld: - cc = (compiler,) \ - + compiler_flags \ - + ('-o', tmpname, cname) \ - + self.ldflags - debug(f"Compilation command: {' '.join(cc)}") - with open(logfile, "w") as log, open(errfile, "w") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - cmd = " ".join(cc) - status = os.system(cmd) - if status != 0: - raise subprocess.CalledProcessError(status, cmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError(dedent(f""" - Command "{e.cmd}" return error status 
{e.returncode}. - Unable to compile code - Compile log in {logfile} - Compile errors in {errfile} - """)) - else: - cc = (compiler,) \ - + compiler_flags \ - + ('-c', '-o', oname, cname) - # Extract linker specific "cflags" from ldflags - ld = tuple(shlex.split(self.ld)) \ - + ('-o', tmpname, oname) \ - + tuple(self.expandWl(self.ldflags)) - debug(f"Compilation command: {' '.join(cc)}", ) - debug(f"Link command: {' '.join(ld)}") - with open(logfile, "a") as log, open(errfile, "a") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - log.write("Link command:\n") - log.write(" ".join(ld)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - ld += ["2>>", errfile, ">>", logfile] - cccmd = " ".join(cc) - ldcmd = " ".join(ld) - status = os.system(cccmd) - if status != 0: - raise subprocess.CalledProcessError(status, cccmd) - status = os.system(ldcmd) - if status != 0: - raise subprocess.CalledProcessError(status, ldcmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - subprocess.check_call(ld, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError(dedent(f""" - Command "{e.cmd}" return error status {e.returncode}. - Unable to compile code - Compile log in {logfile} - Compile errors in {errfile} - """)) - # Atomically ensure soname exists - os.rename(tmpname, soname) - # Wait for compilation to complete - self.comm.barrier() - # Load resulting library - dll = ctypes.CDLL(soname) - return dll - class MacClangCompiler(Compiler): """A compiler for building a shared library on Mac systems.""" @@ -616,7 +459,9 @@ def __init__(self, code, argtypes): else: exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe, comm) - dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) + + compiler_instance = compiler(cppargs, ldargs, cpp=cpp, comm=comm) + dll = make_so(compiler_instance, code, extension) if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) @@ -627,6 +472,165 @@ def __init__(self, code, argtypes): return fn +def expandWl(ldflags): + """Generator to expand the `-Wl` compiler flags for use as linker flags + :arg ldflags: linker flags for a compiler command + """ + for flag in ldflags: + if flag.startswith('-Wl'): + for f in flag.lstrip('-Wl')[1:].split(','): + yield f + else: + yield flag + + +@mpi.collective +def make_so(compiler, jitmodule, extension): + """Build a shared library and load it + + :arg compiler: The compiler to use to create the shared library. + :arg jitmodule: The JIT Module which can generate the code to compile. + :arg extension: extension of the source file (c, cpp). 
+ Returns a :class:`ctypes.CDLL` object of the resulting shared + library.""" + + # C or C++ + if compiler._cpp: + exe = compiler.cxx + compiler_flags = compiler.cxxflags + else: + exe = compiler.cc + compiler_flags = compiler.cflags + + # Determine cache key + hsh = md5(str(jitmodule.cache_key).encode()) + hsh.update(exe.encode()) + if compiler.ld: + hsh.update(compiler.ld.encode()) + hsh.update("".join(compiler_flags).encode()) + hsh.update("".join(compiler.ldflags).encode()) + + basename = hsh.hexdigest() + + cachedir = configuration['cache_dir'] + + dirpart, basename = basename[:2], basename[2:] + cachedir = os.path.join(cachedir, dirpart) + pid = os.getpid() + cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}") + oname = os.path.join(cachedir, f"{basename}_p{pid}.o") + soname = os.path.join(cachedir, f"{basename}.so") + # Link into temporary file, then rename to shared library + # atomically (avoiding races). + tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") + + if configuration['check_src_hashes'] or configuration['debug']: + matching = compiler.comm.allreduce(basename, op=_check_op) + if matching != basename: + # Dump all src code to disk for debugging + output = os.path.join(configuration["cache_dir"], "mismatching-kernels") + srcfile = os.path.join(output, f"src-rank{compiler.comm.rank}.{extension}") + if compiler.comm.rank == 0: + os.makedirs(output, exist_ok=True) + compiler.comm.barrier() + with open(srcfile, "w") as f: + f.write(jitmodule.code_to_compile) + compiler.comm.barrier() + raise CompilationError(f"Generated code differs across ranks (see output in {output})") + + # Check whether this shared object already written to disk + try: + dll = ctypes.CDLL(soname) + except OSError: + dll = None + got_dll = bool(dll) + all_dll = compiler.comm.allgather(got_dll) + + # If the library is not loaded _on all ranks_ build it + if not min(all_dll): + if compiler.comm.rank == 0: + # No need to do this on all ranks + os.makedirs(cachedir, exist_ok=True) + logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") + errfile = os.path.join(cachedir, f"{basename}_p{pid}.err") + with progress(INFO, 'Compiling wrapper'): + with open(cname, "w") as f: + f.write(jitmodule.code_to_compile) + # Compiler also links + if not compiler.ld: + cc = (exe,) \ + + compiler_flags \ + + ('-o', tmpname, cname) \ + + compiler.ldflags + debug(f"Compilation command: {' '.join(cc)}") + with open(logfile, "w") as log, open(errfile, "w") as err: + log.write("Compilation command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + try: + if configuration['no_fork_available']: + cc += ["2>", errfile, ">", logfile] + cmd = " ".join(cc) + status = os.system(cmd) + if status != 0: + raise subprocess.CalledProcessError(status, cmd) + else: + subprocess.check_call(cc, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. 
+ Unable to compile code + Compile log in {logfile} + Compile errors in {errfile} + """)) + else: + cc = (exe,) \ + + compiler_flags \ + + ('-c', '-o', oname, cname) + # Extract linker specific "cflags" from ldflags + ld = tuple(shlex.split(compiler.ld)) \ + + ('-o', tmpname, oname) \ + + tuple(expandWl(compiler.ldflags)) + debug(f"Compilation command: {' '.join(cc)}", ) + debug(f"Link command: {' '.join(ld)}") + with open(logfile, "a") as log, open(errfile, "a") as err: + log.write("Compilation command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + log.write("Link command:\n") + log.write(" ".join(ld)) + log.write("\n\n") + try: + if configuration['no_fork_available']: + cc += ["2>", errfile, ">", logfile] + ld += ["2>>", errfile, ">>", logfile] + cccmd = " ".join(cc) + ldcmd = " ".join(ld) + status = os.system(cccmd) + if status != 0: + raise subprocess.CalledProcessError(status, cccmd) + status = os.system(ldcmd) + if status != 0: + raise subprocess.CalledProcessError(status, ldcmd) + else: + subprocess.check_call(cc, stderr=err, stdout=log) + subprocess.check_call(ld, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. + Unable to compile code + Compile log in {logfile} + Compile errors in {errfile} + """)) + # Atomically ensure soname exists + os.rename(tmpname, soname) + # Wait for compilation to complete + compiler.comm.barrier() + # Load resulting library + dll = ctypes.CDLL(soname) + return dll + + def _add_profiling_events(dll, events): """ If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. From 0cf63f91b60a8e53dd38b385c5eb96c4ca01a13b Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 19 Aug 2024 19:01:38 +0100 Subject: [PATCH 25/38] Remove comm from Compiler class --- pyop2/compilation.py | 105 ++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index e0e58a15e..aec360913 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -43,6 +43,7 @@ from hashlib import md5 from packaging.version import Version, InvalidVersion from textwrap import dedent +from functools import partial from pyop2 import mpi @@ -87,6 +88,36 @@ def set_default_compiler(compiler): ) +def sniff_compiler_version(compiler, cpp=False): + """Attempt to determine the compiler version number. + + :arg compiler: Instance of compiler to sniff the version of + :arg cpp: If set to True will use the C++ compiler rather than + the C compiler to determine the version number. + """ + # Note: + # Sniffing the compiler version for very large numbers of + # MPI ranks is expensive, ensure this is only run on rank 0 + exe = compiler.cxx if cpp else compiler.cc + version = None + # `-dumpversion` is not sufficient to get the whole version string (for some compilers), + # but other compilers do not implement `-dumpfullversion`! + for dumpstring in ["-dumpfullversion", "-dumpversion"]: + try: + output = subprocess.run( + [exe, dumpstring], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + version = Version(output) + break + except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): + continue + return version + + def sniff_compiler(exe, comm=mpi.COMM_WORLD): """Obtain the correct compiler class by calling the compiler executable. 
@@ -153,6 +184,11 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD): else: compiler = AnonymousCompiler + # Now try and get a version number + temp = Compiler() + version = sniff_compiler_version(temp) + compiler = partial(compiler, version=version) + return comm.bcast(compiler, 0) @@ -186,9 +222,8 @@ class Compiler(ABC): _optflags = () _debugflags = () - def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): - # Set compiler version ASAP since it is used in __repr__ - self.version = None + def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, version=None): + self.version = version self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) @@ -196,11 +231,6 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co self._cpp = cpp self._debug = configuration["debug"] - # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm, self) - self.comm = mpi.compilation_comm(self.pcomm, self) - self.sniff_compiler_version() - def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" @@ -242,35 +272,6 @@ def ldflags(self): ldflags += tuple(shlex.split(configuration["ldflags"])) return ldflags - def sniff_compiler_version(self, cpp=False): - """Attempt to determine the compiler version number. - - :arg cpp: If set to True will use the C++ compiler rather than - the C compiler to determine the version number. - """ - # Note: - # Sniffing the compiler version for very large numbers of - # MPI ranks is expensive - exe = self.cxx if cpp else self.cc - version = None - if self.comm.rank == 0: - # `-dumpversion` is not sufficient to get the whole version string (for some compilers), - # but other compilers do not implement `-dumpfullversion`! - for dumpstring in ["-dumpfullversion", "-dumpversion"]: - try: - output = subprocess.run( - [exe, dumpstring], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - encoding="utf-8" - ).stdout - version = Version(output) - break - except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): - continue - self.version = self.comm.bcast(version, 0) - @property def bugfix_cflags(self): return () @@ -460,8 +461,8 @@ def __init__(self, code, argtypes): exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe, comm) - compiler_instance = compiler(cppargs, ldargs, cpp=cpp, comm=comm) - dll = make_so(compiler_instance, code, extension) + compiler_instance = compiler(cppargs, ldargs, cpp=cpp) + dll = make_so(compiler_instance, code, extension, comm) if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) @@ -485,14 +486,18 @@ def expandWl(ldflags): @mpi.collective -def make_so(compiler, jitmodule, extension): +def make_so(compiler, jitmodule, extension, comm): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. :arg jitmodule: The JIT Module which can generate the code to compile. :arg extension: extension of the source file (c, cpp). + :arg comm: Communicator over which to perform compilation. 
Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" + # Compilation communicators are reference counted on the PyOP2 comm + pcomm = mpi.internal_comm(comm, compiler) + comm = mpi.compilation_comm(pcomm, compiler) # C or C++ if compiler._cpp: @@ -510,9 +515,9 @@ def make_so(compiler, jitmodule, extension): hsh.update("".join(compiler_flags).encode()) hsh.update("".join(compiler.ldflags).encode()) - basename = hsh.hexdigest() + basename = hsh.hexdigest() # This is hash key - cachedir = configuration['cache_dir'] + cachedir = configuration['cache_dir'] # This is cachedir dirpart, basename = basename[:2], basename[2:] cachedir = os.path.join(cachedir, dirpart) @@ -525,17 +530,17 @@ def make_so(compiler, jitmodule, extension): tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") if configuration['check_src_hashes'] or configuration['debug']: - matching = compiler.comm.allreduce(basename, op=_check_op) + matching = comm.allreduce(basename, op=_check_op) if matching != basename: # Dump all src code to disk for debugging output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, f"src-rank{compiler.comm.rank}.{extension}") - if compiler.comm.rank == 0: + srcfile = os.path.join(output, f"src-rank{comm.rank}.{extension}") + if comm.rank == 0: os.makedirs(output, exist_ok=True) - compiler.comm.barrier() + comm.barrier() with open(srcfile, "w") as f: f.write(jitmodule.code_to_compile) - compiler.comm.barrier() + comm.barrier() raise CompilationError(f"Generated code differs across ranks (see output in {output})") # Check whether this shared object already written to disk @@ -544,11 +549,11 @@ def make_so(compiler, jitmodule, extension): except OSError: dll = None got_dll = bool(dll) - all_dll = compiler.comm.allgather(got_dll) + all_dll = comm.allgather(got_dll) # If the library is not loaded _on all ranks_ build it if not min(all_dll): - if compiler.comm.rank == 0: + if comm.rank == 0: # No need to do this on all ranks os.makedirs(cachedir, exist_ok=True) logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") @@ -625,7 +630,7 @@ def make_so(compiler, jitmodule, extension): # Atomically ensure soname exists os.rename(tmpname, soname) # Wait for compilation to complete - compiler.comm.barrier() + comm.barrier() # Load resulting library dll = ctypes.CDLL(soname) return dll From 7ca74eb9f29c2eb6775dd0ecfa461a1a0963324b Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 20 Aug 2024 00:20:46 +0100 Subject: [PATCH 26/38] Use caching.disk_only_cache for make_so --- pyop2/caching.py | 9 +- pyop2/compilation.py | 293 ++++++++++++++++-------------- test/unit/test_updated_caching.py | 3 + 3 files changed, 161 insertions(+), 144 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 65954beda..a80cc767b 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -267,7 +267,7 @@ def __getitem__(self, key): """ filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1]) try: - with self.open(filepath, "rb") as fh: + with self.open(filepath, mode="rb") as fh: value = self.read(fh) except FileNotFoundError: raise KeyError("File not on disk, cache miss") @@ -285,7 +285,7 @@ def __setitem__(self, key, value): tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") filepath = basedir.joinpath(k2) - with self.open(tempfile, "wb") as fh: + with self.open(tempfile, mode="wb") as fh: self.write(fh, value) tempfile.rename(filepath) @@ -359,6 +359,8 @@ def get(self, key, default=None): self.hit += 1 return value + # JBTODO: 
Only instrument get, since we have to use get and get item in wrapper + # OR... find away around the hack in compilation.py def __getitem__(self, key): try: value = super().__getitem__(key) @@ -465,7 +467,8 @@ def wrapper(*args, **kwargs): if value is CACHE_MISS: value = func(*args, **kwargs) - return local_cache.setdefault(key, value) + local_cache[key] = value + return local_cache[key] return wrapper return decorator diff --git a/pyop2/compilation.py b/pyop2/compilation.py index aec360913..2e1954239 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -44,10 +44,12 @@ from packaging.version import Version, InvalidVersion from textwrap import dedent from functools import partial +from pathlib import Path +from contextlib import contextmanager from pyop2 import mpi -from pyop2.caching import memory_cache, default_parallel_hashkey +from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -204,10 +206,8 @@ class Compiler(ABC): (optional, prepended to any flags specified as the ldflags configuration option). The environment variable ``PYOP2_LDFLAGS`` can also be used to extend these options. - :arg cpp: Should we try and use the C++ compiler instead of the C - compiler?. - :kwarg comm: Optional communicator to compile the code on - (defaults to pyop2.mpi.COMM_WORLD). + :arg version: (Optional) usually sniffed by loader. + :arg debug: Whether to use debugging compiler flags. """ _name = "unknown" @@ -222,17 +222,22 @@ class Compiler(ABC): _optflags = () _debugflags = () - def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, version=None): - self.version = version - + def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), version=None, debug=False): self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) - - self._cpp = cpp - self._debug = configuration["debug"] + self._version = version + self._debug = debug def __repr__(self): - return f"<{self._name} compiler, version {self.version or 'unknown'}>" + string = f"{self.__class__.__name__}(" + string += f"extra_compiler_flags={self._extra_compiler_flags}, " + string += f"extra_linker_flags={self._extra_linker_flags}, " + string += f"version={self._version!r}, " + string += f"debug={self._debug})" + return string + + def __str__(self): + return f"<{self._name} compiler, version {self._version or 'unknown'}>" @property def cc(self): @@ -319,7 +324,7 @@ class LinuxGnuCompiler(Compiler): @property def bugfix_cflags(self): """Flags to work around bugs in compilers.""" - ver = self.version + ver = self._version cflags = () if Version("4.8.0") <= ver < Version("4.9.0"): # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61068 @@ -448,21 +453,21 @@ def __init__(self, code, argtypes): else: raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) - cpp = (extension == "cpp") global _compiler if _compiler: # Use the global compiler if it has been set compiler = _compiler else: # Sniff compiler from executable - if cpp: + if extension == "cpp": exe = configuration["cxx"] or "mpicxx" else: exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe, comm) - compiler_instance = compiler(cppargs, ldargs, cpp=cpp) - dll = make_so(compiler_instance, code, extension, comm) + debug = configuration["debug"] + compiler_instance = compiler(cppargs, ldargs, 
debug=debug) + dll = _make_so_wrapper(compiler_instance, code, extension, comm) if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) @@ -485,157 +490,163 @@ def expandWl(ldflags): yield flag +from pyop2.caching import DictLikeDiskAccess + + +class CompilerDiskAccess(DictLikeDiskAccess): + @contextmanager + def open(self, *args, **kwargs): + # In the parent class the `open` method is called by `read` as: + # open(filename, mode="rb") + # and the `write` method as: + # open(tempname, mode="wb") + # Here we bypass this and just return the filename (pathlib.Path object) + # letting the read and write methods handle file opening. + if args[0].suffix: + # Writing: drop PID and extension + args[0].touch() + filename = args[0].with_name(args[0].name.split('_p')[0]) + else: + # Reading: Add extension + filename = args[0].with_suffix(".so") + yield filename + + def write(self, *args, **kwargs): + filename = args[0] + compiler, jitmodule, extension, comm = args[1] + _legacy_make_so(compiler, jitmodule, filename, extension, comm) + + def read(self, filename): + try: + return _legacy_load_so(filename) + except OSError as e: + raise FileNotFoundError(e) + + +def _make_so_hashkey(compiler, jitmodule, extension, comm): + if extension == "cpp": + exe = compiler.cxx + compiler_flags = compiler.cxxflags + else: + exe = compiler.cc + compiler_flags = compiler.cflags + return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key) + + +@mpi.collective +@parallel_cache( + hashkey=_make_so_hashkey, + cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir']), + broadcast=False +) +def _make_so_wrapper(compiler, jitmodule, extension, comm): + # The creation of the shared library is handled by the `write` method of + # `CompilerDiskAccess` above. + # JBTODO: This is a bit of a hack... + return (compiler, jitmodule, extension, comm) + + @mpi.collective -def make_so(compiler, jitmodule, extension, comm): +def _legacy_make_so(compiler, jitmodule, filename, extension, comm): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. :arg jitmodule: The JIT Module which can generate the code to compile. + :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. 
Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" # Compilation communicators are reference counted on the PyOP2 comm - pcomm = mpi.internal_comm(comm, compiler) - comm = mpi.compilation_comm(pcomm, compiler) + icomm = mpi.internal_comm(comm, compiler) + ccomm = mpi.compilation_comm(icomm, compiler) # C or C++ - if compiler._cpp: + if extension == "cpp": exe = compiler.cxx compiler_flags = compiler.cxxflags else: exe = compiler.cc compiler_flags = compiler.cflags - # Determine cache key - hsh = md5(str(jitmodule.cache_key).encode()) - hsh.update(exe.encode()) - if compiler.ld: - hsh.update(compiler.ld.encode()) - hsh.update("".join(compiler_flags).encode()) - hsh.update("".join(compiler.ldflags).encode()) - - basename = hsh.hexdigest() # This is hash key - - cachedir = configuration['cache_dir'] # This is cachedir - - dirpart, basename = basename[:2], basename[2:] - cachedir = os.path.join(cachedir, dirpart) + base = filename.name + path = filename.parent pid = os.getpid() - cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}") - oname = os.path.join(cachedir, f"{basename}_p{pid}.o") - soname = os.path.join(cachedir, f"{basename}.so") - # Link into temporary file, then rename to shared library - # atomically (avoiding races). - tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") + cname = filename.with_name(f"{base}_p{pid}.{extension}") + oname = filename.with_name(f"{base}_p{pid}.o") + # Link into temporary file, then rename to shared library atomically (avoiding races). + tempname = filename.with_stem(f"{base}_p{pid}.so") + soname = filename.with_suffix(".so") if configuration['check_src_hashes'] or configuration['debug']: - matching = comm.allreduce(basename, op=_check_op) - if matching != basename: + # Reconstruct hash from filename + hashval = "".join(filename.parts[-2:]) + matching = ccomm.allreduce(hashval, op=_check_op) + if matching != hashval: # Dump all src code to disk for debugging - output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, f"src-rank{comm.rank}.{extension}") - if comm.rank == 0: - os.makedirs(output, exist_ok=True) - comm.barrier() - with open(srcfile, "w") as f: - f.write(jitmodule.code_to_compile) - comm.barrier() + output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels") + srcfile = output.with_name(f"src-rank{comm.rank}.{extension}") + if ccomm.rank == 0: + output.mkdir(exist_ok=True) + ccomm.barrier() + with open(srcfile, "w") as fh: + fh.write(jitmodule.code_to_compile) + ccomm.barrier() raise CompilationError(f"Generated code differs across ranks (see output in {output})") - # Check whether this shared object already written to disk - try: - dll = ctypes.CDLL(soname) - except OSError: - dll = None - got_dll = bool(dll) - all_dll = comm.allgather(got_dll) - - # If the library is not loaded _on all ranks_ build it - if not min(all_dll): - if comm.rank == 0: - # No need to do this on all ranks - os.makedirs(cachedir, exist_ok=True) - logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") - errfile = os.path.join(cachedir, f"{basename}_p{pid}.err") - with progress(INFO, 'Compiling wrapper'): - with open(cname, "w") as f: - f.write(jitmodule.code_to_compile) - # Compiler also links - if not compiler.ld: - cc = (exe,) \ - + compiler_flags \ - + ('-o', tmpname, cname) \ - + compiler.ldflags - debug(f"Compilation command: {' '.join(cc)}") - with open(logfile, "w") as log, open(errfile, "w") as err: - log.write("Compilation command:\n") - 
log.write(" ".join(cc)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - cmd = " ".join(cc) - status = os.system(cmd) - if status != 0: - raise subprocess.CalledProcessError(status, cmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError(dedent(f""" - Command "{e.cmd}" return error status {e.returncode}. - Unable to compile code - Compile log in {logfile} - Compile errors in {errfile} - """)) - else: - cc = (exe,) \ - + compiler_flags \ - + ('-c', '-o', oname, cname) - # Extract linker specific "cflags" from ldflags - ld = tuple(shlex.split(compiler.ld)) \ - + ('-o', tmpname, oname) \ - + tuple(expandWl(compiler.ldflags)) - debug(f"Compilation command: {' '.join(cc)}", ) - debug(f"Link command: {' '.join(ld)}") - with open(logfile, "a") as log, open(errfile, "a") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - log.write("Link command:\n") - log.write(" ".join(ld)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - ld += ["2>>", errfile, ">>", logfile] - cccmd = " ".join(cc) - ldcmd = " ".join(ld) - status = os.system(cccmd) - if status != 0: - raise subprocess.CalledProcessError(status, cccmd) - status = os.system(ldcmd) - if status != 0: - raise subprocess.CalledProcessError(status, ldcmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - subprocess.check_call(ld, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError(dedent(f""" - Command "{e.cmd}" return error status {e.returncode}. - Unable to compile code - Compile log in {logfile} - Compile errors in {errfile} - """)) - # Atomically ensure soname exists - os.rename(tmpname, soname) - # Wait for compilation to complete - comm.barrier() - # Load resulting library - dll = ctypes.CDLL(soname) + # Compile on compilation communicator (ccomm) rank 0 + if comm.rank == 0: + logfile = path.with_name(f"{base}_p{pid}.log") + errfile = path.with_name(f"{base}_p{pid}.err") + with progress(INFO, 'Compiling wrapper'): + with open(cname, "w") as fh: + fh.write(jitmodule.code_to_compile) + # Compiler also links + if not compiler.ld: + cc = (exe,) + compiler_flags + ('-o', str(tempname), str(cname)) + compiler.ldflags + _run(cc, logfile, errfile) + else: + cc = (exe,) + compiler_flags + ('-c', '-o', oname, cname) + _run(cc, logfile, errfile) + # Extract linker specific "cflags" from ldflags + ld = tuple(shlex.split(compiler.ld)) + ('-o', str(tempname), str(oname)) + tuple(expandWl(compiler.ldflags)) + _run(ld, logfile, errfile) + # Atomically ensure soname exists + tempname.rename(soname) + # Wait for compilation to complete + ccomm.barrier() + + +def _legacy_load_so(filename): + # Load library + dll = ctypes.CDLL(filename) return dll +def _run(cc, logfile, errfile): + debug(f"Compilation command: {' '.join(cc)}") + try: + if configuration['no_fork_available']: + cc += ("2>", str(errfile), ">", str(logfile)) + cmd = " ".join(cc) + status = os.system(cmd) + if status != 0: + raise subprocess.CalledProcessError(status, cmd) + else: + with open(logfile, "w") as log, open(errfile, "w") as err: + log.write("Compilation command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + subprocess.check_call(cc, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. 
+ Unable to compile code + Compile log in {logfile!s} + Compile errors in {errfile!s} + """)) + + def _add_profiling_events(dll, events): """ If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index 2c8ee53bf..93af9f46c 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -115,12 +115,15 @@ def test_function_args_different(request, state, decorator, uncached_function, t ]) def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + # In parallel different ranks can get different tempdirs, we just want one + tmpdir = COMM_WORLD.bcast(tmpdir, root=0) kwargs = {"cachedir": tmpdir} else: kwargs = {} cached_function = function_factory(state, decorator, uncached_function, **kwargs) assert state.value == 0 + for ii in range(10): color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED comm12 = COMM_WORLD.Split(color=color) From 91c25b8428b844a43a777f650d943d16fc12a152 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 20 Aug 2024 18:04:50 +0100 Subject: [PATCH 27/38] A better solution for wrapping make_so in a disk cache --- pyop2/caching.py | 27 +++++---- pyop2/compilation.py | 126 +++++++++++++++++++---------------------- pyop2/global_kernel.py | 2 +- 3 files changed, 72 insertions(+), 83 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index a80cc767b..662cfec1f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -244,7 +244,7 @@ class _CacheMiss: def _as_hexdigest(*args): hash_ = hashlib.md5() for a in args: - # TODO: Remove or edit this check! + # JBTODO: Remove or edit this check! if isinstance(a, MPI.Comm) or isinstance(a, cachetools.keys._HashedTuple): breakpoint() hash_.update(str(a).encode()) @@ -252,12 +252,14 @@ def _as_hexdigest(*args): class DictLikeDiskAccess(MutableMapping): - def __init__(self, cachedir): + def __init__(self, cachedir, extension=".pickle"): """ :arg cachedir: The cache directory. + :arg extension: Optional extension to use for written files. """ self.cachedir = cachedir + self.extension = extension def __getitem__(self, key): """Retrieve a value from the disk cache. 
@@ -267,7 +269,7 @@ def __getitem__(self, key): """ filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1]) try: - with self.open(filepath, mode="rb") as fh: + with self.open(filepath.with_suffix(self.extension), mode="rb") as fh: value = self.read(fh) except FileNotFoundError: raise KeyError("File not on disk, cache miss") @@ -287,7 +289,7 @@ def __setitem__(self, key, value): filepath = basedir.joinpath(k2) with self.open(tempfile, mode="wb") as fh: self.write(fh, value) - tempfile.rename(filepath) + tempfile.rename(filepath.with_suffix(self.extension)) def __delitem__(self, key): raise NotImplementedError(f"Cannot remove items from {self.__class__.__name__}") @@ -299,11 +301,11 @@ def __len__(self): raise NotImplementedError(f"Cannot query length of {self.__class__.__name__}") def __repr__(self): - return f"{self.__class__.__name__}(cachedir={self.cachedir})" + return f"{self.__class__.__name__}(cachedir={self.cachedir}, extension={self.extension})" def __eq__(self, other): # Instances are the same if they have the same cachedir - return self.cachedir == other.cachedir + return (self.cachedir == other.cachedir and self.extension == other.extension) def open(self, *args, **kwargs): return open(*args, **kwargs) @@ -359,8 +361,6 @@ def get(self, key, default=None): self.hit += 1 return value - # JBTODO: Only instrument get, since we have to use get and get item in wrapper - # OR... find away around the hack in compilation.py def __getitem__(self, key): try: value = super().__getitem__(key) @@ -441,7 +441,7 @@ def wrapper(*args, **kwargs): ) else: debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - # TODO: Add communication tags to avoid cross-broadcasting + # JBTODO: Add communication tags to avoid cross-broadcasting comm.bcast(value, root=0) else: value = comm.bcast(CACHE_MISS, root=0) @@ -467,8 +467,7 @@ def wrapper(*args, **kwargs): if value is CACHE_MISS: value = func(*args, **kwargs) - local_cache[key] = value - return local_cache[key] + return local_cache.setdefault(key, value) return wrapper return decorator @@ -497,13 +496,13 @@ def decorator(func): return memory_cache(*args, **kwargs)(disk_only_cache(*args, cachedir=cachedir, **kwargs)(func)) return decorator -# TODO: (Wishlist) +# JBTODO: (Wishlist) # * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) ✓ # * Add some sort of cache reporting ✓ # * Add some sort of cache statistics ✓ -# * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method +# * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method ✓ # * Systematic investigation into cache sizes/types for Firedrake -# - Is a mem cache needed for DLLs? +# - Is a mem cache needed for DLLs? No # - Is LRUCache better than a simple dict? (memory profile test suite) # - What is the optimal maxsize? # * Add some docstrings and maybe some exposition! 
diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 2e1954239..c4ebe9306 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -46,10 +46,12 @@ from functools import partial from pathlib import Path from contextlib import contextmanager +from tempfile import gettempdir +from itertools import cycle from pyop2 import mpi -from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey +from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, DictLikeDiskAccess from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -417,6 +419,7 @@ def load_hashkey(*args, **kwargs): return default_parallel_hashkey(code_hash, *args[1:], **kwargs) +# JBTODO: This should not be memory cached @mpi.collective @memory_cache(hashkey=load_hashkey, broadcast=False) def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), @@ -467,7 +470,12 @@ def __init__(self, code, argtypes): debug = configuration["debug"] compiler_instance = compiler(cppargs, ldargs, debug=debug) - dll = _make_so_wrapper(compiler_instance, code, extension, comm) + if configuration['check_src_hashes'] or configuration['debug']: + check_source_hashes(compiler_instance, code, extension, comm) + # This call is cached on disk + so_name = make_so(compiler_instance, code, extension, comm) + # This call is cached in memory by the OS + dll = ctypes.CDLL(so_name) if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) @@ -490,37 +498,18 @@ def expandWl(ldflags): yield flag -from pyop2.caching import DictLikeDiskAccess - - class CompilerDiskAccess(DictLikeDiskAccess): @contextmanager - def open(self, *args, **kwargs): - # In the parent class the `open` method is called by `read` as: - # open(filename, mode="rb") - # and the `write` method as: - # open(tempname, mode="wb") - # Here we bypass this and just return the filename (pathlib.Path object) - # letting the read and write methods handle file opening. 
- if args[0].suffix: - # Writing: drop PID and extension - args[0].touch() - filename = args[0].with_name(args[0].name.split('_p')[0]) - else: - # Reading: Add extension - filename = args[0].with_suffix(".so") + def open(self, filename, *args, **kwargs): yield filename - def write(self, *args, **kwargs): - filename = args[0] - compiler, jitmodule, extension, comm = args[1] - _legacy_make_so(compiler, jitmodule, filename, extension, comm) + def write(self, filename, value): + shutil.copy(value, filename) def read(self, filename): - try: - return _legacy_load_so(filename) - except OSError as e: - raise FileNotFoundError(e) + if not filename.exists(): + raise FileNotFoundError("File not on disk, cache miss") + return filename def _make_so_hashkey(compiler, jitmodule, extension, comm): @@ -533,21 +522,33 @@ def _make_so_hashkey(compiler, jitmodule, extension, comm): return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key) +def check_source_hashes(compiler, jitmodule, extension, comm): + # Reconstruct hash from filename + hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm)) + with mpi.temp_internal_comm(comm) as icomm: + matching = icomm.allreduce(hashval, op=_check_op) + if matching != hashval: + # Dump all src code to disk for debugging + output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels") + srcfile = output.with_name(f"src-rank{icomm.rank}.{extension}") + if icomm.rank == 0: + output.mkdir(exist_ok=True) + icomm.barrier() + with open(srcfile, "w") as fh: + fh.write(jitmodule.code_to_compile) + icomm.barrier() + raise CompilationError(f"Generated code differs across ranks (see output in {output})") + + +FILE_CYCLER = cycle(f"{ii:02x}" for ii in range(256)) + + @mpi.collective @parallel_cache( hashkey=_make_so_hashkey, - cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir']), - broadcast=False + cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so"), ) -def _make_so_wrapper(compiler, jitmodule, extension, comm): - # The creation of the shared library is handled by the `write` method of - # `CompilerDiskAccess` above. - # JBTODO: This is a bit of a hack... - return (compiler, jitmodule, extension, comm) - - -@mpi.collective -def _legacy_make_so(compiler, jitmodule, filename, extension, comm): +def make_so(compiler, jitmodule, extension, comm, filename=None): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. @@ -555,8 +556,17 @@ def _legacy_make_so(compiler, jitmodule, filename, extension, comm): :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. 
+ :arg filename: Optional Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" + if filename is None: + tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") + tempdir.mkdir(exist_ok=True) + filename = tempdir.joinpath(f"foo{next(FILE_CYCLER)}.c") + else: + filename = Path(filename).absolute() + filename.parent.mkdir(exist_ok=True) + # Compilation communicators are reference counted on the PyOP2 comm icomm = mpi.internal_comm(comm, compiler) ccomm = mpi.compilation_comm(icomm, compiler) @@ -578,26 +588,10 @@ def _legacy_make_so(compiler, jitmodule, filename, extension, comm): tempname = filename.with_stem(f"{base}_p{pid}.so") soname = filename.with_suffix(".so") - if configuration['check_src_hashes'] or configuration['debug']: - # Reconstruct hash from filename - hashval = "".join(filename.parts[-2:]) - matching = ccomm.allreduce(hashval, op=_check_op) - if matching != hashval: - # Dump all src code to disk for debugging - output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels") - srcfile = output.with_name(f"src-rank{comm.rank}.{extension}") - if ccomm.rank == 0: - output.mkdir(exist_ok=True) - ccomm.barrier() - with open(srcfile, "w") as fh: - fh.write(jitmodule.code_to_compile) - ccomm.barrier() - raise CompilationError(f"Generated code differs across ranks (see output in {output})") - # Compile on compilation communicator (ccomm) rank 0 if comm.rank == 0: - logfile = path.with_name(f"{base}_p{pid}.log") - errfile = path.with_name(f"{base}_p{pid}.err") + logfile = path.joinpath(f"{base}_p{pid}.log") + errfile = path.joinpath(f"{base}_p{pid}.err") with progress(INFO, 'Compiling wrapper'): with open(cname, "w") as fh: fh.write(jitmodule.code_to_compile) @@ -610,31 +604,27 @@ def _legacy_make_so(compiler, jitmodule, filename, extension, comm): _run(cc, logfile, errfile) # Extract linker specific "cflags" from ldflags ld = tuple(shlex.split(compiler.ld)) + ('-o', str(tempname), str(oname)) + tuple(expandWl(compiler.ldflags)) - _run(ld, logfile, errfile) + _run(ld, logfile, errfile, step="Linker", filemode="a") # Atomically ensure soname exists tempname.rename(soname) # Wait for compilation to complete ccomm.barrier() + return soname -def _legacy_load_so(filename): - # Load library - dll = ctypes.CDLL(filename) - return dll - - -def _run(cc, logfile, errfile): - debug(f"Compilation command: {' '.join(cc)}") +def _run(cc, logfile, errfile, step="Compilation", filemode="w"): + debug(f"{step} command: {' '.join(cc)}") try: if configuration['no_fork_available']: - cc += ("2>", str(errfile), ">", str(logfile)) + redirect = ">" if filemode == "w" else ">>" + cc += (f"2{redirect}", str(errfile), redirect, str(logfile)) cmd = " ".join(cc) status = os.system(cmd) if status != 0: raise subprocess.CalledProcessError(status, cmd) else: - with open(logfile, "w") as log, open(errfile, "w") as err: - log.write("Compilation command:\n") + with open(logfile, filemode) as log, open(errfile, filemode) as err: + log.write(f"{step} command:\n") log.write(" ".join(cc)) log.write("\n\n") subprocess.check_call(cc, stderr=err, stdout=log) diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 3c8cde430..d9108119f 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -360,13 +360,13 @@ def builder(self): builder.add_argument(arg) return builder - # TODO: Wrap with parallel_cached_property @cached_property def code_to_compile(self): """Return the C/C++ source code as a string.""" from pyop2.codegen.rep2loopy import generate wrapper = 
generate(self.builder) + # JBTODO: Expensive? Can this be wrapped with a cache? code = lp.generate_code_v2(wrapper) if self.local_kernel.cpp: From 919523ab3cbd4b10393dc58c541935b9c4cee252 Mon Sep 17 00:00:00 2001 From: "David A. Ham" Date: Wed, 21 Aug 2024 16:35:42 +0100 Subject: [PATCH 28/38] Update pyop2/caching.py Co-authored-by: Connor Ward --- pyop2/caching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 662cfec1f..41908f00f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -392,7 +392,7 @@ def parallel_cache( ): """Memory only cache decorator. - Decorator for wrapping a function to be called over a communiucator in a + Decorator for wrapping a function to be called over a communicator in a cache that stores broadcastable values in memory. If the value is found in the cache of rank 0 it is broadcast to all other ranks. From b1ecd928130598fadcac3838b7171fb4db708060 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 22 Aug 2024 18:30:50 +0100 Subject: [PATCH 29/38] WIP: may have reintroduced deadlocks --- pyop2/compilation.py | 39 ++++++++++++++++++++++++++++++--------- pyop2/global_kernel.py | 17 +++++++++++------ test/unit/test_caching.py | 2 +- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index c4ebe9306..6105a445a 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -55,6 +55,7 @@ from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError +import pyop2.global_kernel from petsc4py import PETSc @@ -420,6 +421,7 @@ def load_hashkey(*args, **kwargs): # JBTODO: This should not be memory cached +# ...benchmarking disagrees with my assessment @mpi.collective @memory_cache(hashkey=load_hashkey, broadcast=False) def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), @@ -440,8 +442,6 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), :kwarg comm: Optional communicator to compile the code on (only rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). 
""" - from pyop2.global_kernel import GlobalKernel - if isinstance(jitmodule, str): class StrCode(object): def __init__(self, code, argtypes): @@ -451,7 +451,7 @@ def __init__(self, code, argtypes): # cache key self.argtypes = argtypes code = StrCode(jitmodule, argtypes) - elif isinstance(jitmodule, GlobalKernel): + elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): code = jitmodule else: raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) @@ -477,7 +477,7 @@ def __init__(self, code, argtypes): # This call is cached in memory by the OS dll = ctypes.CDLL(so_name) - if isinstance(jitmodule, GlobalKernel): + if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) fn = getattr(dll, fn_name) @@ -511,6 +511,13 @@ def read(self, filename): raise FileNotFoundError("File not on disk, cache miss") return filename + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return self[key] + def _make_so_hashkey(compiler, jitmodule, extension, comm): if extension == "cpp": @@ -546,7 +553,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm): @mpi.collective @parallel_cache( hashkey=_make_so_hashkey, - cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so"), + cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so") ) def make_so(compiler, jitmodule, extension, comm, filename=None): """Build a shared library and load it @@ -560,12 +567,15 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" if filename is None: - tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") - tempdir.mkdir(exist_ok=True) - filename = tempdir.joinpath(f"foo{next(FILE_CYCLER)}.c") + # JBTODO: Remove this directory at some point? + pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") + tempdir = pyop2_tempdir.joinpath(f"{os.getpid()}") + # ~ tempdir = Path(mkdtemp(dir=pyop2_tempdir.joinpath(f"{os.getpid()}"))) + # This path + filename should be unique + filename = tempdir.joinpath("foo.c") else: + pyop2_tempdir = None filename = Path(filename).absolute() - filename.parent.mkdir(exist_ok=True) # Compilation communicators are reference counted on the PyOP2 comm icomm = mpi.internal_comm(comm, compiler) @@ -590,6 +600,11 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): # Compile on compilation communicator (ccomm) rank 0 if comm.rank == 0: + if pyop2_tempdir is None: + filename.parent.mkdir(exist_ok=True) + else: + pyop2_tempdir.mkdir(exist_ok=True) + tempdir.mkdir(exist_ok=True) logfile = path.joinpath(f"{base}_p{pid}.log") errfile = path.joinpath(f"{base}_p{pid}.err") with progress(INFO, 'Compiling wrapper'): @@ -612,6 +627,12 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): return soname +# JBTODO: Probably don't want to do this if we fail to compile... 
+# ~ @atexit +# ~ def _cleanup_tempdir(): + # ~ pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") + + def _run(cc, logfile, errfile, step="Compilation", filemode="w"): debug(f"{step} command: {' '.join(cc)}") try: diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index d9108119f..7e313a5e8 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -10,7 +10,8 @@ import pytools from petsc4py import PETSc -from pyop2 import compilation, mpi +from pyop2 import mpi +from pyop2.compilation import load from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -397,11 +398,15 @@ def compile(self, comm): + tuple(self.local_kernel.ldargs) ) - return compilation.load(self, extension, self.name, - cppargs=cppargs, - ldargs=ldargs, - restype=ctypes.c_int, - comm=comm) + return load( + self, + extension, + self.name, + cppargs=cppargs, + ldargs=ldargs, + restype=ctypes.c_int, + comm=comm + ) @cached_property def argtypes(self): diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 6ab909b29..e335ec680 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -290,7 +290,7 @@ def cache(self): _cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval) if _cache_collection is None: _cache_collection = {default_cache_name: DEFAULT_CACHE()} - mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache_collection) + int_comm.Set_attr(mpi.comm_cache_keyval, _cache_collection) return _cache_collection[default_cache_name] @pytest.fixture From c6225f04710b07da5bfef74734dc7f93eb00370c Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 23 Aug 2024 00:12:20 +0100 Subject: [PATCH 30/38] WIP: Fixed the deadlock --- pyop2/compilation.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 6105a445a..69b6f44d0 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -48,6 +48,7 @@ from contextlib import contextmanager from tempfile import gettempdir from itertools import cycle +from uuid import uuid4 from pyop2 import mpi @@ -194,7 +195,7 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD): version = sniff_compiler_version(temp) compiler = partial(compiler, version=version) - return comm.bcast(compiler, 0) + return comm.bcast(compiler, root=0) class Compiler(ABC): @@ -568,11 +569,13 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): library.""" if filename is None: # JBTODO: Remove this directory at some point? + # Directory must be unique per user for shared machines pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") - tempdir = pyop2_tempdir.joinpath(f"{os.getpid()}") - # ~ tempdir = Path(mkdtemp(dir=pyop2_tempdir.joinpath(f"{os.getpid()}"))) + # A UUID should ensure we have a unique path + uuid = uuid4().hex + tempdir = pyop2_tempdir.joinpath(f"{uuid[:2]}") # This path + filename should be unique - filename = tempdir.joinpath("foo.c") + filename = tempdir.joinpath(f"{uuid[2:]}.{extension}") else: pyop2_tempdir = None filename = Path(filename).absolute() @@ -589,6 +592,7 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): exe = compiler.cc compiler_flags = compiler.cflags + # JBTODO: Do we still need to worry about atomic file renaming in this function? 
base = filename.name path = filename.parent pid = os.getpid() @@ -599,7 +603,7 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): soname = filename.with_suffix(".so") # Compile on compilation communicator (ccomm) rank 0 - if comm.rank == 0: + if ccomm.rank == 0: if pyop2_tempdir is None: filename.parent.mkdir(exist_ok=True) else: @@ -622,9 +626,8 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): _run(ld, logfile, errfile, step="Linker", filemode="a") # Atomically ensure soname exists tempname.rename(soname) - # Wait for compilation to complete - ccomm.barrier() - return soname + + return ccomm.bcast(soname, root=0) # JBTODO: Probably don't want to do this if we fail to compile... From 13bc76c11170a28b9bb33a6466ae337ccbcb000e Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 23 Aug 2024 16:01:56 +0100 Subject: [PATCH 31/38] Add event decorators --- pyop2/caching.py | 6 ++++++ pyop2/compilation.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/pyop2/caching.py b/pyop2/caching.py index 41908f00f..ab48c966f 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -48,6 +48,7 @@ from pyop2.mpi import ( MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm ) +from petsc4py import PETSc # Caches created here are registered as a tuple of @@ -403,6 +404,7 @@ def parallel_cache( subcommunicators. """ def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") @wraps(func) def wrapper(*args, **kwargs): """ Extract the key and then try the memory cache before falling back @@ -430,6 +432,10 @@ def wrapper(*args, **kwargs): if (comm, comm.name, func, local_cache) not in [k[1:] for k in _KNOWN_CACHES]: _KNOWN_CACHES.append((next(_CACHE_CIDX), comm, comm.name, func, local_cache)) + # JBTODO: Replace everything below here with: + # value = local_cache.get(key, CACHE_MISS) + # and add an optional PYOP2_SPMD_STRICT environment variable + if broadcast: # Grab value from rank 0 memory cache and broadcast result if comm.rank == 0: diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 69b6f44d0..a17272a3f 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -425,6 +425,7 @@ def load_hashkey(*args, **kwargs): # ...benchmarking disagrees with my assessment @mpi.collective @memory_cache(hashkey=load_hashkey, broadcast=False) +@PETSc.Log.EventDecorator() def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. 
@@ -556,6 +557,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm): hashkey=_make_so_hashkey, cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so") ) +@PETSc.Log.EventDecorator() def make_so(compiler, jitmodule, extension, comm, filename=None): """Build a shared library and load it From 9bdc4186508bf0a5043f07ee96b2721279e0a7de Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 23 Aug 2024 19:34:26 +0100 Subject: [PATCH 32/38] WIP: Fixing and tidying --- pyop2/caching.py | 20 ++++++--- pyop2/compilation.py | 94 +++++++++++++++++++++--------------------- pyop2/configuration.py | 1 + pyop2/exceptions.py | 10 +++++ pyop2/mpi.py | 2 + 5 files changed, 73 insertions(+), 54 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index ab48c966f..56b30096e 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -36,6 +36,7 @@ import hashlib import os import pickle +import weakref from collections.abc import MutableMapping from pathlib import Path from warnings import warn # noqa F401 @@ -44,6 +45,7 @@ from functools import partial, wraps from pyop2.configuration import configuration +from pyop2.exceptions import CachingError, HashError # noqa: F401 from pyop2.logger import debug from pyop2.mpi import ( MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm @@ -245,9 +247,8 @@ class _CacheMiss: def _as_hexdigest(*args): hash_ = hashlib.md5() for a in args: - # JBTODO: Remove or edit this check! - if isinstance(a, MPI.Comm) or isinstance(a, cachetools.keys._HashedTuple): - breakpoint() + if isinstance(a, MPI.Comm): + raise HashError("Communicators cannot be hashed, caching will be broken!") hash_.update(str(a).encode()) return hash_.hexdigest() @@ -385,6 +386,8 @@ class DEFAULT_CACHE(dict): # - DictLikeDiskAccess = instrument(DictLikeDiskAccess) +# JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT +# environment variable. def parallel_cache( hashkey=default_parallel_hashkey, comm_fetcher=default_comm_fetcher, @@ -429,8 +432,13 @@ def wrapper(*args, **kwargs): local_cache = cache_collection[cf.__class__.__name__] # If this is a new cache or function add it to the list of known caches - if (comm, comm.name, func, local_cache) not in [k[1:] for k in _KNOWN_CACHES]: - _KNOWN_CACHES.append((next(_CACHE_CIDX), comm, comm.name, func, local_cache)) + if (comm, comm.name, func, weakref.ref(local_cache)) not in [c[1:] for c in _KNOWN_CACHES]: + # JBTODO: When a comm is freed we will not hold a ref to the cache, + # but we should have a finalizer that extracts the stats before the object + # is deleted. + _KNOWN_CACHES.append( + (next(_CACHE_CIDX), comm, comm.name, func, weakref.ref(local_cache)) + ) # JBTODO: Replace everything below here with: # value = local_cache.get(key, CACHE_MISS) @@ -508,7 +516,7 @@ def decorator(func): # * Add some sort of cache statistics ✓ # * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method ✓ # * Systematic investigation into cache sizes/types for Firedrake -# - Is a mem cache needed for DLLs? No +# - Is a mem cache needed for DLLs? ~~No~~ Yes!! # - Is LRUCache better than a simple dict? (memory profile test suite) # - What is the optimal maxsize? # * Add some docstrings and maybe some exposition! 
diff --git a/pyop2/compilation.py b/pyop2/compilation.py index a17272a3f..9f2263d62 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -47,7 +47,6 @@ from pathlib import Path from contextlib import contextmanager from tempfile import gettempdir -from itertools import cycle from uuid import uuid4 @@ -69,6 +68,8 @@ def _check_hashes(x, y, datatype): _check_op = mpi.MPI.Op.Create(_check_hashes, commute=True) _compiler = None +# Directory must be unique per user for shared machines +MEM_TMP_DIR = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") def set_default_compiler(compiler): @@ -421,8 +422,6 @@ def load_hashkey(*args, **kwargs): return default_parallel_hashkey(code_hash, *args[1:], **kwargs) -# JBTODO: This should not be memory cached -# ...benchmarking disagrees with my assessment @mpi.collective @memory_cache(hashkey=load_hashkey, broadcast=False) @PETSc.Log.EventDecorator() @@ -476,7 +475,7 @@ def __init__(self, code, argtypes): check_source_hashes(compiler_instance, code, extension, comm) # This call is cached on disk so_name = make_so(compiler_instance, code, extension, comm) - # This call is cached in memory by the OS + # This call might be cached in memory by the OS (system dependent) dll = ctypes.CDLL(so_name) if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): @@ -532,6 +531,14 @@ def _make_so_hashkey(compiler, jitmodule, extension, comm): def check_source_hashes(compiler, jitmodule, extension, comm): + """A check to see whether code generated on all ranks is identical. + + :arg compiler: The compiler to use to create the shared library. + :arg jitmodule: The JIT Module which can generate the code to compile. + :arg filename: The filename of the library to create. + :arg extension: extension of the source file (c, cpp). + :arg comm: Communicator over which to perform compilation. + """ # Reconstruct hash from filename hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm)) with mpi.temp_internal_comm(comm) as icomm: @@ -549,9 +556,6 @@ def check_source_hashes(compiler, jitmodule, extension, comm): raise CompilationError(f"Generated code differs across ranks (see output in {output})") -FILE_CYCLER = cycle(f"{ii:02x}" for ii in range(256)) - - @mpi.collective @parallel_cache( hashkey=_make_so_hashkey, @@ -570,16 +574,14 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" if filename is None: - # JBTODO: Remove this directory at some point? - # Directory must be unique per user for shared machines - pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") # A UUID should ensure we have a unique path uuid = uuid4().hex - tempdir = pyop2_tempdir.joinpath(f"{uuid[:2]}") + # Taking the first two characters avoids using excessive filesystem inodes + tempdir = MEM_TMP_DIR.joinpath(f"{uuid[:2]}") # This path + filename should be unique filename = tempdir.joinpath(f"{uuid[2:]}.{extension}") else: - pyop2_tempdir = None + tempdir = None filename = Path(filename).absolute() # Compilation communicators are reference counted on the PyOP2 comm @@ -594,8 +596,8 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): exe = compiler.cc compiler_flags = compiler.cflags - # JBTODO: Do we still need to worry about atomic file renaming in this function? - base = filename.name + # TODO: Do we still need to worry about atomic file renaming in this function? 
+ base = filename.stem path = filename.parent pid = os.getpid() cname = filename.with_name(f"{base}_p{pid}.{extension}") @@ -606,11 +608,10 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): # Compile on compilation communicator (ccomm) rank 0 if ccomm.rank == 0: - if pyop2_tempdir is None: + if tempdir is None: filename.parent.mkdir(exist_ok=True) else: - pyop2_tempdir.mkdir(exist_ok=True) - tempdir.mkdir(exist_ok=True) + tempdir.mkdir(parents=True, exist_ok=True) logfile = path.joinpath(f"{base}_p{pid}.log") errfile = path.joinpath(f"{base}_p{pid}.err") with progress(INFO, 'Compiling wrapper'): @@ -632,13 +633,9 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): return ccomm.bcast(soname, root=0) -# JBTODO: Probably don't want to do this if we fail to compile... -# ~ @atexit -# ~ def _cleanup_tempdir(): - # ~ pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") - - def _run(cc, logfile, errfile, step="Compilation", filemode="w"): + """ Run a compilation command and handle logging + errors. + """ debug(f"{step} command: {' '.join(cc)}") try: if configuration['no_fork_available']: @@ -686,28 +683,29 @@ def clear_cache(prompt=False): :arg prompt: if ``True`` prompt before removing any files """ - cachedir = configuration['cache_dir'] - - if not os.path.exists(cachedir): - print("Cache directory could not be found") - return - if len(os.listdir(cachedir)) == 0: - print("No cached libraries to remove") - return - - remove = True - if prompt: - user = input(f"Remove cached libraries from {cachedir}? [Y/n]: ") - - while user.lower() not in ['', 'y', 'n']: - print("Please answer y or n.") - user = input(f"Remove cached libraries from {cachedir}? [Y/n]: ") - - if user.lower() == 'n': - remove = False - - if remove: - print(f"Removing cached libraries from {cachedir}") - shutil.rmtree(cachedir, ignore_errors=True) - else: - print("Not removing cached libraries") + cachedirs = [configuration['cache_dir'], MEM_TMP_DIR] + + for directory in cachedirs: + if not os.path.exists(directory): + print("Cache directory could not be found") + return + if len(os.listdir(directory)) == 0: + print("No cached libraries to remove") + return + + remove = True + if prompt: + user = input(f"Remove cached libraries from {directory}? [Y/n]: ") + + while user.lower() not in ['', 'y', 'n']: + print("Please answer y or n.") + user = input(f"Remove cached libraries from {directory}? [Y/n]: ") + + if user.lower() == 'n': + remove = False + + if remove: + print(f"Removing cached libraries from {directory}") + shutil.rmtree(directory, ignore_errors=True) + else: + print("Not removing cached libraries") diff --git a/pyop2/configuration.py b/pyop2/configuration.py index ff3721a6f..dc4db1679 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -40,6 +40,7 @@ from pyop2.exceptions import ConfigurationError +# JBTODO: Add a PYOP2_SPMD_STRICT environment variable to add various SPMD checks. 
 class Configuration(dict):
     r"""PyOP2 configuration parameters
 
diff --git a/pyop2/exceptions.py b/pyop2/exceptions.py
index 9211857d0..eec5eedac 100644
--- a/pyop2/exceptions.py
+++ b/pyop2/exceptions.py
@@ -146,3 +146,13 @@ class CompilationError(RuntimeError):
 
 class SparsityFormatError(ValueError):
 
     """Unable to produce a sparsity for this matrix format."""
+
+
+class CachingError(ValueError):
+
+    """A caching error."""
+
+
+class HashError(CachingError):
+
+    """Something is wrong with the hash."""
diff --git a/pyop2/mpi.py b/pyop2/mpi.py
index 0237433cb..2831bc04f 100644
--- a/pyop2/mpi.py
+++ b/pyop2/mpi.py
@@ -160,6 +160,8 @@ class PyOP2CommError(ValueError):
 # PYOP2_FINALISED flag.
 
 
+# JBTODO: Make this decorator infinitely more useful by adding barriers before
+# and after the function call, if being run with PYOP2_SPMD_STRICT=1.
 def collective(fn):
     extra = trim("""
     This function is logically collective over MPI ranks, it is an
From aeddb093f15d64fd73cfa22435e227675820f523 Mon Sep 17 00:00:00 2001
From: Jack Betteridge 
Date: Fri, 23 Aug 2024 19:35:08 +0100
Subject: [PATCH 33/38] Add additional cache tests

---
 test/unit/test_updated_caching.py | 52 +++++++++++++++++++++++++++++--
 1 file changed, 49 insertions(+), 3 deletions(-)

diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py
index 93af9f46c..5066554a1 100644
--- a/test/unit/test_updated_caching.py
+++ b/test/unit/test_updated_caching.py
@@ -1,14 +1,19 @@
+import ctypes
 import pytest
+import os
+import tempfile
 from functools import partial
+from itertools import chain
+from textwrap import dedent
 
-from pyop2.caching import (  # noqa: F401
+from pyop2.caching import (
     disk_only_cache,
     memory_cache,
     memory_and_disk_cache,
-    default_parallel_hashkey,
     clear_memory_cache
 )
-from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval  # noqa: F401
+from pyop2.compilation import load
+from pyop2.mpi import MPI, COMM_WORLD
 
 
 class StateIncrement:
@@ -138,3 +143,44 @@ def test_function_over_different_comms(request, state, decorator, uncached_funct
         comm23.Free()
 
     clear_memory_cache(COMM_WORLD)
+
+
+# pyop2/compilation.py uses a custom cache which we test here
+@pytest.mark.parallel(nprocs=2)
+def test_writing_large_so():
+    # This test exercises the compilation caching when handling larger files
+    if COMM_WORLD.rank == 0:
+        preamble = dedent("""\
+            #include <stdio.h>\n
+            void big(double *result){
+        """)
+        variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024))
+        lines = (f"  double {v} = {hash(v)/1000000000};\n  *result += {v};\n" for v in variables)
+        program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", ))))
+        with open("big.c", "w") as fh:
+            fh.write(program)
+
+    COMM_WORLD.Barrier()
+    with open("big.c", "r") as fh:
+        program = fh.read()
+
+    if COMM_WORLD.rank == 1:
+        os.remove("big.c")
+
+    fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD)
+    assert fn is not None
+
+
+@pytest.mark.parallel(nprocs=2)
+def test_two_comms_compile_the_same_code():
+    new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank)
+    new_comm.name = "test_two_comms"
+    code = dedent("""\
+        #include <stdio.h>\n
+        void noop(){
+            printf("Do nothing!\\n");
+        }
+    """)
+
+    fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD)
+    assert fn is not None
From f4f41944ca99e6569c70afea2bc432b130b140fd Mon Sep 17 00:00:00 2001
From: Jack Betteridge 
Date: Sat, 24 Aug 2024 23:45:58 +0100
Subject: [PATCH 34/38] Fix stats printing

---
 pyop2/caching.py     | 127 +++++++++++++++++++++++++------------------
pyop2/compilation.py | 5 +- 2 files changed, 78 insertions(+), 54 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 56b30096e..4771de511 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -32,6 +32,7 @@ # OF THE POSSIBILITY OF SUCH DAMAGE. """Provides common base classes for cached objects.""" +import atexit import cachetools import hashlib import os @@ -148,59 +149,82 @@ def make_obj(): return obj -def cache_stats(comm=None, comm_name=None, alive=True, function=None, cache_type=None): +def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None): caches = _KNOWN_CACHES if comm is not None: with temp_internal_comm(comm) as icomm: cache_collection = icomm.Get_attr(comm_cache_keyval) if cache_collection is None: - print(f"Communicator {icomm.name} as no associated caches") + print(f"Communicator {icomm.name} has no associated caches") comm_name = icomm.name if comm_name is not None: - caches = filter(lambda c: c[2] == comm_name, caches) + caches = filter(lambda c: c.comm_name == comm_name, caches) if alive: - caches = filter(lambda c: c[1] != MPI.COMM_NULL, caches) + caches = filter(lambda c: c.comm != MPI.COMM_NULL, caches) if function is not None: if isinstance(function, str): - caches = filter(lambda c: function in c[3].__qualname__, caches) + caches = filter(lambda c: function in c.func_name, caches) else: - caches = filter(lambda c: c[3] is function, caches) + caches = filter(lambda c: c.func is function, caches) if cache_type is not None: if isinstance(cache_type, str): - caches = filter(lambda c: cache_type in c[4].__qualname__, caches) + caches = filter(lambda c: cache_type in c.cache_name, caches) else: - caches = filter(lambda c: isinstance(c[4], cache_type), caches) + caches = filter(lambda c: c.cache_name == cache_type.__class__.__qualname__, caches) return [*caches] -def get_stats(cache): - hit = miss = size = maxsize = -1 - if isinstance(cache, cachetools.Cache): - size = cache.currsize - maxsize = cache.maxsize - if hasattr(cache, "instrument__"): - hit = cache.hit - miss = cache.miss - if size is None: - try: - size = len(cache) - except NotImplementedError: - pass - if maxsize is None: - try: - maxsize = cache.max_size - except AttributeError: - pass - return hit, miss, size, maxsize +class _CacheRecord: + def __init__(self, cidx, comm, func, cache): + self.cidx = cidx + self.comm = comm + self.comm_name = comm.name + self.func = func + self.func_module = func.__module__ + self.func_name = func.__qualname__ + self.cache = weakref.ref(cache) + fin = weakref.finalize(cache, self.finalize, cache) + fin.atexit = False + self.cache_name = cache.__class__.__qualname__ + try: + self.cache_loc = cache.cachedir + except AttributeError: + self.cache_loc = "Memory" + + def get_stats(self, cache=None): + if cache is None: + cache = self.cache() + hit = miss = size = maxsize = -1 + if cache is None: + hit, miss, size, maxsize = self.hit, self.miss, self.size, self.maxsize + if isinstance(cache, cachetools.Cache): + size = cache.currsize + maxsize = cache.maxsize + if hasattr(cache, "instrument__"): + hit = cache.hit + miss = cache.miss + if size == -1: + try: + size = len(cache) + except NotImplementedError: + pass + if maxsize is None: + try: + maxsize = cache.max_size + except AttributeError: + pass + return hit, miss, size, maxsize + + def finalize(self, cache): + self.hit, self.miss, self.size, self.maxsize = self.get_stats(cache) def print_cache_stats(*args, **kwargs): data = defaultdict(lambda: defaultdict(list)) - for entry in 
cache_stats(*args, **kwargs): - ecid, ecomm, ecomm_name, efunction, ecache = entry - active = (ecomm != MPI.COMM_NULL) - data[(ecomm_name, active)][ecache.__class__.__name__].append( - (ecid, efunction.__module__, efunction.__name__, ecache) + for entry in cache_filter(*args, **kwargs): + active = (entry.comm != MPI.COMM_NULL) + data[(entry.comm_name, active)][(entry.cache_name, entry.cache_loc)].append( + (entry.cidx, entry.func_module, entry.func_name, entry.get_stats()) ) tab = " " @@ -219,22 +243,22 @@ def print_cache_stats(*args, **kwargs): comm_title = f"{ecomm[0]} ({active})" print(f"|{comm_title:{col[0]}}|{no_stats}|") for ecache, function_list in cachedict.items(): - cache_title = f"{tab}{ecache}" + cache_title = f"{tab}{ecache[0]}" print(f"|{cache_title:{col[0]}}|{no_stats}|") - try: - loc = function_list[0][-1].cachedir - except AttributeError: - loc = "Memory" - cache_location = f"{tab} ↳ {loc!s}" - if len(str(loc)) < col[0] - 5: + cache_location = f"{tab} ↳ {ecache[1]!s}" + if len(cache_location) < col[0]: print(f"|{cache_location:{col[0]}}|{no_stats}|") else: print(f"|{cache_location:78}|") for entry in function_list: function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}" - stats = "|".join(f"{s:{w}}" for s, w in zip(get_stats(entry[3]), stats_col)) - print(f"|{function_title:{col[0]}}|{stats:{col[1]}}|") - print(hline) + stats_row = "|".join(f"{s:{w}}" for s, w in zip(entry[3], stats_col)) + print(f"|{function_title:{col[0]}}|{stats_row:{col[1]}}|") + print(hline) + + +if _running_on_ci: + print_cache_stats = atexit.register(print_cache_stats) class _CacheMiss: @@ -379,11 +403,11 @@ class DEFAULT_CACHE(dict): # Examples of how to instrument and use different default caches: -# - DEFAULT_CACHE = instrument(DEFAULT_CACHE) -# - DEFAULT_CACHE = instrument(cachetools.LRUCache) -# - DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100) +# ~ DEFAULT_CACHE = instrument(DEFAULT_CACHE) +# ~ DEFAULT_CACHE = instrument(cachetools.LRUCache) +# ~ DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100) EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100) -# - DictLikeDiskAccess = instrument(DictLikeDiskAccess) +# ~ DictLikeDiskAccess = instrument(DictLikeDiskAccess) # JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT @@ -432,13 +456,11 @@ def wrapper(*args, **kwargs): local_cache = cache_collection[cf.__class__.__name__] # If this is a new cache or function add it to the list of known caches - if (comm, comm.name, func, weakref.ref(local_cache)) not in [c[1:] for c in _KNOWN_CACHES]: - # JBTODO: When a comm is freed we will not hold a ref to the cache, - # but we should have a finalizer that extracts the stats before the object + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache # is deleted. 
- _KNOWN_CACHES.append( - (next(_CACHE_CIDX), comm, comm.name, func, weakref.ref(local_cache)) - ) + _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) # JBTODO: Replace everything below here with: # value = local_cache.get(key, CACHE_MISS) @@ -487,6 +509,7 @@ def wrapper(*args, **kwargs): return decorator +# JBTODO: This needs some more thought def clear_memory_cache(comm): with temp_internal_comm(comm) as icomm: if icomm.Get_attr(comm_cache_keyval) is not None: diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 9f2263d62..ce45c7e99 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -678,6 +678,7 @@ def _add_profiling_events(dll, events): ctypes.c_int.in_dll(dll, 'ID_'+e).value = PETSc.Log.Event(e).id +# JBTODO: Move to caching?? def clear_cache(prompt=False): """Clear the PyOP2 compiler cache. @@ -688,10 +689,10 @@ def clear_cache(prompt=False): for directory in cachedirs: if not os.path.exists(directory): print("Cache directory could not be found") - return + continue if len(os.listdir(directory)) == 0: print("No cached libraries to remove") - return + continue remove = True if prompt: From 1c69860af825c5f6d5b7cef8d46425b384456b18 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 27 Aug 2024 13:01:30 +0100 Subject: [PATCH 35/38] Handle multiple VENVs + instrumenting --- pyop2/caching.py | 15 +++++++-------- pyop2/compilation.py | 11 ++++++----- scripts/pyop2-clean | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 4771de511..6e0d47a36 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -402,12 +402,12 @@ class DEFAULT_CACHE(dict): pass -# Examples of how to instrument and use different default caches: -# ~ DEFAULT_CACHE = instrument(DEFAULT_CACHE) -# ~ DEFAULT_CACHE = instrument(cachetools.LRUCache) -# ~ DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100) +# Example of how to instrument and use different default caches: EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100) -# ~ DictLikeDiskAccess = instrument(DictLikeDiskAccess) +# Turn on cache measurements if printing cache info is enabled +if configuration["print_cache_info"] or _running_on_ci: + DEFAULT_CACHE = instrument(DEFAULT_CACHE) + DictLikeDiskAccess = instrument(DictLikeDiskAccess) # JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT @@ -509,7 +509,6 @@ def wrapper(*args, **kwargs): return decorator -# JBTODO: This needs some more thought def clear_memory_cache(comm): with temp_internal_comm(comm) as icomm: if icomm.Get_attr(comm_cache_keyval) is not None: @@ -540,6 +539,6 @@ def decorator(func): # * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method ✓ # * Systematic investigation into cache sizes/types for Firedrake # - Is a mem cache needed for DLLs? ~~No~~ Yes!! -# - Is LRUCache better than a simple dict? (memory profile test suite) -# - What is the optimal maxsize? +# - Is LRUCache better than a simple dict? (memory profile test suite) No +# - What is the optimal maxsize? ∞ # * Add some docstrings and maybe some exposition! 
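As an illustration (not part of the patch), the instrumented caches expose hit/miss
counters through `get`; a small sketch, assuming `instrument()` is importable from
pyop2.caching and behaves as shown earlier in the series (counting a miss when the
default is returned and a hit otherwise):

    import cachetools
    from pyop2.caching import instrument

    CountingLRU = instrument(cachetools.LRUCache)
    cache = CountingLRU(maxsize=100)

    cache.get(("somekey", "somefunc"))   # absent: records a miss
    cache[("somekey", "somefunc")] = 42
    cache.get(("somekey", "somefunc"))   # present: records a hit
    print(cache.hit, cache.miss, cache.currsize, cache.maxsize)
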
diff --git a/pyop2/compilation.py b/pyop2/compilation.py index ce45c7e99..c7a278feb 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -68,8 +68,10 @@ def _check_hashes(x, y, datatype): _check_op = mpi.MPI.Op.Create(_check_hashes, commute=True) _compiler = None -# Directory must be unique per user for shared machines -MEM_TMP_DIR = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}") +# Directory must be unique per VENV for multiple installs +# _and_ per user for shared machines +_EXE_HASH = md5(sys.executable.encode()).hexdigest()[-6:] +MEM_TMP_DIR = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}").joinpath(_EXE_HASH) def set_default_compiler(compiler): @@ -678,9 +680,8 @@ def _add_profiling_events(dll, events): ctypes.c_int.in_dll(dll, 'ID_'+e).value = PETSc.Log.Event(e).id -# JBTODO: Move to caching?? -def clear_cache(prompt=False): - """Clear the PyOP2 compiler cache. +def clear_compiler_disk_cache(prompt=False): + """Clear the PyOP2 compiler disk cache. :arg prompt: if ``True`` prompt before removing any files """ diff --git a/scripts/pyop2-clean b/scripts/pyop2-clean index ab29f1245..52f667ec4 100755 --- a/scripts/pyop2-clean +++ b/scripts/pyop2-clean @@ -1,6 +1,6 @@ #!/usr/bin/env python -from pyop2.compilation import clear_cache +from pyop2.compilation import clear_compiler_disk_cache if __name__ == '__main__': - clear_cache(prompt=True) + clear_compiler_disk_cache(prompt=True) From 8d0f319f1e3c084293b789a7acfd5931f6be7991 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 27 Aug 2024 14:59:18 +0100 Subject: [PATCH 36/38] Add PYOP2_SPMD_STRICT environment variable for checking MPI correctness --- pyop2/caching.py | 203 +++++++++++++++++------------- pyop2/compilation.py | 2 +- pyop2/configuration.py | 10 +- pyop2/mpi.py | 68 ++++++++-- test/unit/test_updated_caching.py | 7 +- 5 files changed, 187 insertions(+), 103 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 6e0d47a36..96c64de75 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -150,6 +150,8 @@ def make_obj(): def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None): + """ Filter PyOP2 caches based on communicator, function or cache type. + """ caches = _KNOWN_CACHES if comm is not None: with temp_internal_comm(comm) as icomm: @@ -175,6 +177,8 @@ def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_typ class _CacheRecord: + """ Object for keeping a record of Pyop2 Cache statistics. + """ def __init__(self, cidx, comm, func, cache): self.cidx = cidx self.comm = comm @@ -220,6 +224,8 @@ def finalize(self, cache): def print_cache_stats(*args, **kwargs): + """ Print out the cache hit/miss/size/maxsize stats for PyOP2 caches. + """ data = defaultdict(lambda: defaultdict(list)) for entry in cache_filter(*args, **kwargs): active = (entry.comm != MPI.COMM_NULL) @@ -278,6 +284,8 @@ def _as_hexdigest(*args): class DictLikeDiskAccess(MutableMapping): + """ A Dictionary like interface for storing and retrieving objects from a disk cache. + """ def __init__(self, cachedir, extension=".pickle"): """ @@ -344,6 +352,8 @@ def write(self, filehandle, value): def default_comm_fetcher(*args, **kwargs): + """ A sensible default comm fetcher for use with `parallel_cache`. 
+ """ comms = filter( lambda arg: isinstance(arg, MPI.Comm), args + tuple(kwargs.values()) @@ -356,8 +366,10 @@ def default_comm_fetcher(*args, **kwargs): def default_parallel_hashkey(*args, **kwargs): - """ We now want to actively remove any comms from args and kwargs to get the same disk cache key + """ A sensible default hash key for use with `parallel_cache`. """ + # We now want to actively remove any comms from args and kwargs to get + # the same disk cache key. hash_args = tuple(filter( lambda arg: not isinstance(arg, MPI.Comm), args @@ -370,6 +382,8 @@ def default_parallel_hashkey(*args, **kwargs): def instrument(cls): + """ Class decorator for dict-like objects for counting cache hits/misses. + """ @wraps(cls, updated=()) class _wrapper(cls): instrument__ = True @@ -410,106 +424,123 @@ class DEFAULT_CACHE(dict): DictLikeDiskAccess = instrument(DictLikeDiskAccess) -# JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT -# environment variable. -def parallel_cache( - hashkey=default_parallel_hashkey, - comm_fetcher=default_comm_fetcher, - cache_factory=lambda: DEFAULT_CACHE(), - broadcast=True -): - """Memory only cache decorator. - - Decorator for wrapping a function to be called over a communicator in a - cache that stores broadcastable values in memory. If the value is found in - the cache of rank 0 it is broadcast to all other ranks. - - :arg key: Callable returning the cache key for the function inputs. This - function must return a 2-tuple where the first entry is the - communicator to be collective over and the second is the key. This is - required to ensure that deadlocks do not occur when using different - subcommunicators. - """ - def decorator(func): - @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") - @wraps(func) - def wrapper(*args, **kwargs): - """ Extract the key and then try the memory cache before falling back - on calling the function and populating the cache. - """ - k = hashkey(*args, **kwargs) - key = _as_hexdigest(*k), func.__qualname__ - # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits - with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: - # Fetch the per-comm cache_collection or set it up if not present - # A collection is required since different types of cache can be set up on the same comm - cache_collection = comm.Get_attr(comm_cache_keyval) - if cache_collection is None: - cache_collection = {} - comm.Set_attr(comm_cache_keyval, cache_collection) - # If this kind of cache is already present on the - # cache_collection, get it, otherwise create it - local_cache = cache_collection.setdefault( - (cf := cache_factory()).__class__.__name__, - cf - ) - local_cache = cache_collection[cf.__class__.__name__] - - # If this is a new cache or function add it to the list of known caches - if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: - # When a comm is freed we do not hold a reference to the cache. - # We attach a finalizer that extracts the stats before the cache - # is deleted. 
- _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) - - # JBTODO: Replace everything below here with: - # value = local_cache.get(key, CACHE_MISS) - # and add an optional PYOP2_SPMD_STRICT environment variable - - if broadcast: - # Grab value from rank 0 memory cache and broadcast result - if comm.rank == 0: - value = local_cache.get(key, CACHE_MISS) - if value is CACHE_MISS: - debug( - f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " - f"{k} {local_cache.__class__.__name__} cache miss" - ) - else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - # JBTODO: Add communication tags to avoid cross-broadcasting - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - if isinstance(value, _CacheMiss): - # We might have the CACHE_MISS from rank 0 and - # `(value is CACHE_MISS) == False` which is confusing, - # so we set it back to the local value - value = CACHE_MISS - else: +if configuration["spmd_strict"]: + def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + ): + """Parallel cache decorator (SPMD strict-enabled). + """ + def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. SPMD strict ensures + that all ranks cache hit or miss to ensure that the function evaluation + always occurs in parallel. + """ + k = hashkey(*args, **kwargs) + key = _as_hexdigest(*k), func.__qualname__ + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache + # is deleted. 
+ _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) + # Grab value from all ranks cache and broadcast cache hit/miss value = local_cache.get(key, CACHE_MISS) + debug_string = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + debug_string += f"key={k} in cache: {local_cache.__class__.__name__} cache " if value is CACHE_MISS: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') + debug(debug_string + "miss") cache_hit = False else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') + debug(debug_string + "hit") cache_hit = True all_present = comm.allgather(cache_hit) - # If not present in the cache of all ranks we need to recompute on all ranks + # If not present in the cache of all ranks we force re-evaluation on all ranks if not min(all_present): value = CACHE_MISS - if value is CACHE_MISS: - value = func(*args, **kwargs) - return local_cache.setdefault(key, value) + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper + return decorator +else: + def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + ): + """Parallel cache decorator. + """ + def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. + """ + k = hashkey(*args, **kwargs) + key = _as_hexdigest(*k), func.__qualname__ + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache + # is deleted. + _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) - return wrapper - return decorator + value = local_cache.get(key, CACHE_MISS) + + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper + return decorator def clear_memory_cache(comm): + """ Completely remove all PyOP2 caches on a given communicator. 
+ """ with temp_internal_comm(comm) as icomm: if icomm.Get_attr(comm_cache_keyval) is not None: icomm.Set_attr(comm_cache_keyval, {}) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index c7a278feb..86db95b9e 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -425,7 +425,7 @@ def load_hashkey(*args, **kwargs): @mpi.collective -@memory_cache(hashkey=load_hashkey, broadcast=False) +@memory_cache(hashkey=load_hashkey) @PETSc.Log.EventDecorator() def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): diff --git a/pyop2/configuration.py b/pyop2/configuration.py index dc4db1679..0005ceeca 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -40,7 +40,6 @@ from pyop2.exceptions import ConfigurationError -# JBTODO: Add a PYOP2_SPMD_STRICT environment variable to add various SPMD checks. class Configuration(dict): r"""PyOP2 configuration parameters @@ -68,13 +67,16 @@ class Configuration(dict): to a node-local filesystem too. :param log_level: How chatty should PyOP2 be? Valid values are "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL". - :param print_cache_size: Should PyOP2 print the size of caches at + :param print_cache_size: Should PyOP2 print the cache information at program exit? :param matnest: Should matrices on mixed maps be built as nests? (Default yes) :param block_sparsity: Should sparsity patterns on datasets with cdim > 1 be built as block sparsities, or dof sparsities. The former saves memory but changes which preconditioners are available for the resulting matrices. (Default yes) + :param spmd_strict: Enable barriers for calls marked with @collective and + for cache access. This adds considerable overhead, but is useful for + tracking down deadlocks. (Default no) """ # name, env variable, type, default, write once cache_dir = os.path.join(gettempdir(), "pyop2-cache-uid%s" % os.getuid()) @@ -114,7 +116,9 @@ class Configuration(dict): "matnest": ("PYOP2_MATNEST", bool, True), "block_sparsity": - ("PYOP2_BLOCK_SPARSITY", bool, True) + ("PYOP2_BLOCK_SPARSITY", bool, True), + "spmd_strict": + ("PYOP2_SPMD_STRICT", bool, False), } """Default values for PyOP2 configuration parameters""" diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 2831bc04f..7e88b8dd0 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,6 +37,7 @@ from petsc4py import PETSc from mpi4py import MPI # noqa from itertools import count +from functools import wraps import atexit import gc import glob @@ -160,15 +161,64 @@ class PyOP2CommError(ValueError): # PYOP2_FINALISED flag. -# JBTODO: Make this decorator infinitely more useful by adding barriers before -# and after the function call, if being run with PYOP2_SPMD_STRICT=1. -def collective(fn): - extra = trim(""" - This function is logically collective over MPI ranks, it is an - error to call it on fewer than all the ranks in MPI communicator. - """) - fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra - return fn +if configuration["spmd_strict"]: + def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + PYOP2_SPMD_STRICT=1 is in your environment and function calls will be + guarded by a barrier where possible. 
+ """) + + @wraps(fn) + def wrapper(*args, **kwargs): + comms = filter( + lambda arg: isinstance(arg, MPI.Comm), + args + tuple(kwargs.values()) + ) + try: + comm = next(comms) + except StopIteration: + if args and hasattr(args[0], "comm"): + comm = args[0].comm + else: + comm = None + + if comm is None: + debug( + "`@collective` wrapper found no communicators in args or kwargs, " + "this means that the call is implicitly collective over an " + "unknown communicator. " + f"The following call to {fn.__module__}.{fn.__qualname__} is " + "not protected by an MPI barrier." + ) + subcomm = ", UNKNOWN Comm" + else: + subcomm = f", {comm.name} R{comm.rank}" + + debug_string_pt1 = f"{COMM_WORLD.name} R{COMM_WORLD.rank}{subcomm}: " + debug_string_pt2 = f" {fn.__module__}.{fn.__qualname__}" + debug(debug_string_pt1 + "Entering" + debug_string_pt2) + if comm is not None: + comm.Barrier() + value = fn(*args, **kwargs) + debug(debug_string_pt1 + "Leaving" + debug_string_pt2) + if comm is not None: + comm.Barrier() + return value + + wrapper.__doc__ = f"{trim(fn.__doc__)}\n\n{extra}" if fn.__doc__ else extra + return wrapper +else: + def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + You can set PYOP2_SPMD_STRICT=1 in your environment to try and catch + non-collective calls. + """) + fn.__doc__ = f"{trim(fn.__doc__)}\n\n{extra}" if fn.__doc__ else extra + return fn def delcomm_outer(comm, keyval, icomm): diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index 5066554a1..1d9424b05 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -2,7 +2,6 @@ import pytest import os import tempfile -from functools import partial from itertools import chain from textwrap import dedent @@ -63,7 +62,7 @@ def state(): @pytest.mark.parametrize("decorator, uncached_function", [ (memory_cache, twople), - (partial(memory_cache, broadcast=False), n_comms), + (memory_cache, n_comms), (memory_and_disk_cache, twople), (disk_only_cache, twople) ]) @@ -89,7 +88,7 @@ def test_function_args_twice_caches(request, state, decorator, uncached_function @pytest.mark.parametrize("decorator, uncached_function", [ (memory_cache, twople), - (partial(memory_cache, broadcast=False), n_comms), + (memory_cache, n_comms), (memory_and_disk_cache, twople), (disk_only_cache, twople) ]) @@ -114,7 +113,7 @@ def test_function_args_different(request, state, decorator, uncached_function, t @pytest.mark.parallel(nprocs=3) @pytest.mark.parametrize("decorator, uncached_function", [ (memory_cache, twople), - (partial(memory_cache, broadcast=False), n_comms), + (memory_cache, n_comms), (memory_and_disk_cache, twople), (disk_only_cache, twople) ]) From e8f08594154c54bc00acefac0fc011384979195d Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 1 Sep 2024 19:37:12 +0100 Subject: [PATCH 37/38] Actually put mismatched kernels in the directory created --- pyop2/compilation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 86db95b9e..9fb654d9c 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -548,7 +548,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm): if matching != hashval: # Dump all src code to disk for debugging output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels") - srcfile = output.with_name(f"src-rank{icomm.rank}.{extension}") + srcfile = 
output.joinpath(f"src-rank{icomm.rank}.{extension}")
         if icomm.rank == 0:
             output.mkdir(exist_ok=True)
         icomm.barrier()

From 015cec341f6e45fbb3fe18144443f4605eb87dec Mon Sep 17 00:00:00 2001
From: Jack Betteridge
Date: Wed, 25 Sep 2024 00:48:53 +0100
Subject: [PATCH 38/38] Put all shared libraries in a single cache directory

---
 pyop2/caching.py     | 32 ++++++++++++++++++++++++++++++++
 pyop2/compilation.py |  4 ++--
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/pyop2/caching.py b/pyop2/caching.py
index 96c64de75..d0539609a 100644
--- a/pyop2/caching.py
+++ b/pyop2/caching.py
@@ -351,6 +351,38 @@ def write(self, filehandle, value):
         pickle.dump(value, filehandle)
 
 
+class NoShardDiskAccess(DictLikeDiskAccess):
+    def __getitem__(self, key):
+        """Retrieve a value from the disk cache.
+
+        :arg key: The cache key, a 2-tuple of strings.
+        :returns: The cached object if found.
+        """
+        filepath = Path(self.cachedir, key[0] + key[1])
+        try:
+            with self.open(filepath.with_suffix(self.extension), mode="rb") as fh:
+                value = self.read(fh)
+        except FileNotFoundError:
+            raise KeyError("File not on disk, cache miss")
+        return value
+
+    def __setitem__(self, key, value):
+        """Store a new value in the disk cache.
+
+        :arg key: The cache key, a 2-tuple of strings.
+        :arg value: The new item to store in the cache.
+        """
+        k = key[0] + key[1]
+        basedir = Path(self.cachedir)
+        basedir.mkdir(parents=True, exist_ok=True)
+
+        tempfile = basedir.joinpath(f"{k}_p{os.getpid()}.tmp")
+        filepath = basedir.joinpath(k)
+        with self.open(tempfile, mode="wb") as fh:
+            self.write(fh, value)
+        tempfile.rename(filepath.with_suffix(self.extension))
+
+
 def default_comm_fetcher(*args, **kwargs):
     """ A sensible default comm fetcher for use with `parallel_cache`.
     """
diff --git a/pyop2/compilation.py b/pyop2/compilation.py
index 9fb654d9c..58d3c7928 100644
--- a/pyop2/compilation.py
+++ b/pyop2/compilation.py
@@ -51,7 +51,7 @@
 
 from pyop2 import mpi
 
-from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, DictLikeDiskAccess
+from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, NoShardDiskAccess
 from pyop2.configuration import configuration
 from pyop2.logger import warning, debug, progress, INFO
 from pyop2.exceptions import CompilationError
@@ -501,7 +501,7 @@ def expandWl(ldflags):
             yield flag
 
 
-class CompilerDiskAccess(DictLikeDiskAccess):
+class CompilerDiskAccess(NoShardDiskAccess):
     @contextmanager
     def open(self, filename, *args, **kwargs):
         yield filename
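
For illustration, a minimal sketch of the on-disk layout NoShardDiskAccess produces: both key components are concatenated into a single filename directly inside the cache directory, and a pid-suffixed temporary file is renamed over the final name. The directory, key values and ".so" extension below are made-up examples, not values taken from the patches.

    from pathlib import Path
    import os

    cachedir = "/tmp/pyop2-cache-uid1000"   # illustrative cache directory
    key = ("badc0ffe", "kernel_abc123")     # illustrative 2-tuple cache key
    extension = ".so"                       # stands in for self.extension on the subclass

    k = key[0] + key[1]
    tempfile = Path(cachedir).joinpath(f"{k}_p{os.getpid()}.tmp")
    final = Path(cachedir).joinpath(k).with_suffix(extension)

    # Every cached shared library lands directly in cachedir (no per-key
    # subdirectories), and the temporary file is renamed into place.
    print(tempfile)   # e.g. /tmp/pyop2-cache-uid1000/badc0ffekernel_abc123_p12345.tmp
    print(final)      #      /tmp/pyop2-cache-uid1000/badc0ffekernel_abc123.so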
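More broadly, the caching rework above keys in-memory caches to the communicator rather than to module-level globals. A minimal usage sketch, assuming the default hash key and comm fetcher pick the communicator up from a plain `comm=` keyword argument (as their use on `load()` suggests), that the import paths match the modules shown above, and with `expensive_twople` as a hypothetical function:

    from pyop2.caching import parallel_cache, clear_memory_cache
    from pyop2.mpi import COMM_WORLD, comm_cache_keyval, temp_internal_comm

    @parallel_cache()   # in-memory cache attached to the communicator
    def expensive_twople(x, comm=None):
        # Stand-in for a costly, logically collective computation.
        return (x, x)

    a = expensive_twople(1, comm=COMM_WORLD)   # miss: computed and stored
    b = expensive_twople(1, comm=COMM_WORLD)   # hit: the stored object is returned
    assert a is b

    # The cache collection lives on the internal comm, keyed by cache class
    # name, and can be dropped for a whole communicator in one call.
    with temp_internal_comm(COMM_WORLD) as icomm:
        print(icomm.Get_attr(comm_cache_keyval))
    clear_memory_cache(COMM_WORLD)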
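Finally, the spmd_strict option turns `@collective` from a purely documenting decorator into one that brackets calls with barriers. A sketch of how that behaves, where `assemble_thing` is a hypothetical function; the environment variable has to be set before PyOP2 is imported so the configuration picks it up.

    import os
    os.environ["PYOP2_SPMD_STRICT"] = "1"   # read by Configuration at import time

    from pyop2.mpi import COMM_WORLD, collective

    @collective
    def assemble_thing(data, comm=None):
        """Pretend this must be called on every rank of comm."""
        return sum(data)

    # The strict wrapper scans args and kwargs for an MPI.Comm (falling back to
    # args[0].comm), logs entry and exit, and calls comm.Barrier() before and
    # after the body, so a rank that skips this call hangs here rather than
    # deadlocking somewhere later.
    assemble_thing([1, 2, 3], comm=COMM_WORLD)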