From be1e58b21bdf135a2fe3dd4eceed0d6dac6d15e5 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 19 Aug 2024 18:16:24 +0100 Subject: [PATCH] Add instrumentation to caches --- pyop2/caching.py | 291 +++++++++++++++++++++++++++----------- pyop2/configuration.py | 4 +- pyop2/op2.py | 9 +- test/unit/test_caching.py | 25 ++-- 4 files changed, 227 insertions(+), 102 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 2bc72525c..65954beda 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -32,7 +32,6 @@ # OF THE POSSIBILITY OF SUCH DAMAGE. """Provides common base classes for cached objects.""" - import cachetools import hashlib import os @@ -40,44 +39,27 @@ from collections.abc import MutableMapping from pathlib import Path from warnings import warn # noqa F401 -from functools import wraps +from collections import defaultdict +from itertools import count +from functools import partial, wraps from pyop2.configuration import configuration from pyop2.logger import debug -from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval +from pyop2.mpi import ( + MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm +) -# TODO: Remove this? Rewrite? -def report_cache(typ): - """Report the size of caches of type ``typ`` - - :arg typ: A class of cached object. For example - :class:`ObjectCached` or :class:`Cached`. - """ - from collections import defaultdict - from inspect import getmodule - from gc import get_objects - typs = defaultdict(lambda: 0) - n = 0 - for x in get_objects(): - if isinstance(x, typ): - typs[type(x)] += 1 - n += 1 - if n == 0: - print("\nNo %s objects in caches" % typ.__name__) - return - print("\n%d %s objects in caches" % (n, typ.__name__)) - print("Object breakdown") - print("================") - for k, v in typs.iteritems(): - mod = getmodule(k) - if mod is not None: - name = "%s.%s" % (mod.__name__, k.__name__) - else: - name = k.__name__ - print('%s: %d' % (name, v)) +# Caches created here are registered as a tuple of +# (creation_index, comm, comm.name, function, cache) +# in _KNOWN_CACHES +_CACHE_CIDX = count() +_KNOWN_CACHES = [] +# Flag for outputting information at the end of testing (do not abuse!) +_running_on_ci = bool(os.environ.get('PYOP2_CI_TESTS')) +# FIXME: (Later) Remove ObjectCached class ObjectCached(object): """Base class for objects that should be cached on another object. @@ -163,6 +145,95 @@ def make_obj(): return obj +def cache_stats(comm=None, comm_name=None, alive=True, function=None, cache_type=None): + caches = _KNOWN_CACHES + if comm is not None: + with temp_internal_comm(comm) as icomm: + cache_collection = icomm.Get_attr(comm_cache_keyval) + if cache_collection is None: + print(f"Communicator {icomm.name} as no associated caches") + comm_name = icomm.name + if comm_name is not None: + caches = filter(lambda c: c[2] == comm_name, caches) + if alive: + caches = filter(lambda c: c[1] != MPI.COMM_NULL, caches) + if function is not None: + if isinstance(function, str): + caches = filter(lambda c: function in c[3].__qualname__, caches) + else: + caches = filter(lambda c: c[3] is function, caches) + if cache_type is not None: + if isinstance(cache_type, str): + caches = filter(lambda c: cache_type in c[4].__qualname__, caches) + else: + caches = filter(lambda c: isinstance(c[4], cache_type), caches) + return [*caches] + + +def get_stats(cache): + hit = miss = size = maxsize = -1 + if isinstance(cache, cachetools.Cache): + size = cache.currsize + maxsize = cache.maxsize + if hasattr(cache, "instrument__"): + hit = cache.hit + miss = cache.miss + if size is None: + try: + size = len(cache) + except NotImplementedError: + pass + if maxsize is None: + try: + maxsize = cache.max_size + except AttributeError: + pass + return hit, miss, size, maxsize + + +def print_cache_stats(*args, **kwargs): + data = defaultdict(lambda: defaultdict(list)) + for entry in cache_stats(*args, **kwargs): + ecid, ecomm, ecomm_name, efunction, ecache = entry + active = (ecomm != MPI.COMM_NULL) + data[(ecomm_name, active)][ecache.__class__.__name__].append( + (ecid, efunction.__module__, efunction.__name__, ecache) + ) + + tab = " " + hline = "-"*120 + col = (90, 27) + stats_col = (6, 6, 6, 6) + stats = ("hit", "miss", "size", "max") + no_stats = "|".join(" "*ii for ii in stats_col) + print(hline) + print(f"|{'Cache':^{col[0]}}|{'Stats':^{col[1]}}|") + subtitles = "|".join(f"{st:^{w}}" for st, w in zip(stats, stats_col)) + print("|" + " "*col[0] + f"|{subtitles:{col[1]}}|") + print(hline) + for ecomm, cachedict in data.items(): + active = "Active" if ecomm[1] else "Freed" + comm_title = f"{ecomm[0]} ({active})" + print(f"|{comm_title:{col[0]}}|{no_stats}|") + for ecache, function_list in cachedict.items(): + cache_title = f"{tab}{ecache}" + print(f"|{cache_title:{col[0]}}|{no_stats}|") + try: + loc = function_list[0][-1].cachedir + except AttributeError: + loc = "Memory" + cache_location = f"{tab} ↳ {loc!s}" + if len(str(loc)) < col[0] - 5: + print(f"|{cache_location:{col[0]}}|{no_stats}|") + else: + print(f"|{cache_location:78}|") + for entry in function_list: + function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}" + stats = "|".join(f"{s:{w}}" for s, w in zip(get_stats(entry[3]), stats_col)) + print(f"|{function_title:{col[0]}}|{stats:{col[1]}}|") + print(hline) + + class _CacheMiss: pass @@ -180,11 +251,6 @@ def _as_hexdigest(*args): return hash_.hexdigest() -def clear_memory_cache(comm): - if comm.Get_attr(comm_cache_keyval) is not None: - comm.Set_attr(comm_cache_keyval, {}) - - class DictLikeDiskAccess(MutableMapping): def __init__(self, cachedir): """ @@ -224,17 +290,21 @@ def __setitem__(self, key, value): tempfile.rename(filepath) def __delitem__(self, key): - raise ValueError(f"Cannot remove items from {self.__class__.__name__}") + raise NotImplementedError(f"Cannot remove items from {self.__class__.__name__}") def __iter__(self): - raise ValueError(f"Cannot iterate over keys in {self.__class__.__name__}") + raise NotImplementedError(f"Cannot iterate over keys in {self.__class__.__name__}") def __len__(self): - raise ValueError(f"Cannot query length of {self.__class__.__name__}") + raise NotImplementedError(f"Cannot query length of {self.__class__.__name__}") def __repr__(self): return f"{self.__class__.__name__}(cachedir={self.cachedir})" + def __eq__(self, other): + # Instances are the same if they have the same cachedir + return self.cachedir == other.cachedir + def open(self, *args, **kwargs): return open(*args, **kwargs) @@ -271,10 +341,47 @@ def default_parallel_hashkey(*args, **kwargs): return cachetools.keys.hashkey(*hash_args, **hash_kwargs) +def instrument(cls): + @wraps(cls, updated=()) + class _wrapper(cls): + instrument__ = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hit = 0 + self.miss = 0 + + def get(self, key, default=None): + value = super().get(key, default) + if value is default: + self.miss += 1 + else: + self.hit += 1 + return value + + def __getitem__(self, key): + try: + value = super().__getitem__(key) + self.hit += 1 + except KeyError as e: + self.miss += 1 + raise e + return value + return _wrapper + + class DEFAULT_CACHE(dict): pass +# Examples of how to instrument and use different default caches: +# - DEFAULT_CACHE = instrument(DEFAULT_CACHE) +# - DEFAULT_CACHE = instrument(cachetools.LRUCache) +# - DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100) +EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100) +# - DictLikeDiskAccess = instrument(DictLikeDiskAccess) + + def parallel_cache( hashkey=default_parallel_hashkey, comm_fetcher=default_comm_fetcher, @@ -299,54 +406,62 @@ def wrapper(*args, **kwargs): """ Extract the key and then try the memory cache before falling back on calling the function and populating the cache. """ - comm = comm_fetcher(*args, **kwargs) k = hashkey(*args, **kwargs) key = _as_hexdigest(*k), func.__qualname__ - - # Fetch the per-comm cache_collection or set it up if not present - # A collection is required since different types of cache can be set up on the same comm - cache_collection = comm.Get_attr(comm_cache_keyval) - if cache_collection is None: - cache_collection = {} - comm.Set_attr(comm_cache_keyval, cache_collection) - # If this kind of cache is already present on the - # cache_collection, get it, otherwise create it - local_cache = cache_collection.setdefault( - (cf := cache_factory()).__class__.__name__, - cf - ) - - if broadcast: - # Grab value from rank 0 memory cache and broadcast result - if comm.rank == 0: + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [k[1:] for k in _KNOWN_CACHES]: + _KNOWN_CACHES.append((next(_CACHE_CIDX), comm, comm.name, func, local_cache)) + + if broadcast: + # Grab value from rank 0 memory cache and broadcast result + if comm.rank == 0: + value = local_cache.get(key, CACHE_MISS) + if value is CACHE_MISS: + debug( + f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + f"{k} {local_cache.__class__.__name__} cache miss" + ) + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') + # TODO: Add communication tags to avoid cross-broadcasting + comm.bcast(value, root=0) + else: + value = comm.bcast(CACHE_MISS, root=0) + if isinstance(value, _CacheMiss): + # We might have the CACHE_MISS from rank 0 and + # `(value is CACHE_MISS) == False` which is confusing, + # so we set it back to the local value + value = CACHE_MISS + else: + # Grab value from all ranks cache and broadcast cache hit/miss value = local_cache.get(key, CACHE_MISS) if value is CACHE_MISS: debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') + cache_hit = False else: debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - # TODO: Add communication tags to avoid cross-broadcasting - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - if isinstance(value, _CacheMiss): - # We might have the CACHE_MISS from rank 0 and - # `(value is CACHE_MISS) == False` which is confusing, - # so we set it back to the local value - value = CACHE_MISS - else: - # Grab value from all ranks cache and broadcast cache hit/miss - value = local_cache.get(key, CACHE_MISS) - if value is CACHE_MISS: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') - cache_hit = False - else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - cache_hit = True - all_present = comm.allgather(cache_hit) + cache_hit = True + all_present = comm.allgather(cache_hit) - # If not present in the cache of all ranks we need to recompute on all ranks - if not min(all_present): - value = CACHE_MISS + # If not present in the cache of all ranks we need to recompute on all ranks + if not min(all_present): + value = CACHE_MISS if value is CACHE_MISS: value = func(*args, **kwargs) @@ -356,6 +471,12 @@ def wrapper(*args, **kwargs): return decorator +def clear_memory_cache(comm): + with temp_internal_comm(comm) as icomm: + if icomm.Get_attr(comm_cache_keyval) is not None: + icomm.Set_attr(comm_cache_keyval, {}) + + # A small collection of default simple caches memory_cache = parallel_cache @@ -374,8 +495,12 @@ def decorator(func): return decorator # TODO: (Wishlist) -# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) -# * Add some sort of cache reporting -# * Add some sort of cache statistics +# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) ✓ +# * Add some sort of cache reporting ✓ +# * Add some sort of cache statistics ✓ # * Refactor compilation.py to use @mem_and_disk_cached, where get_so is just uses DictLikeDiskAccess with an overloaded self.write() method +# * Systematic investigation into cache sizes/types for Firedrake +# - Is a mem cache needed for DLLs? +# - Is LRUCache better than a simple dict? (memory profile test suite) +# - What is the optimal maxsize? # * Add some docstrings and maybe some exposition! diff --git a/pyop2/configuration.py b/pyop2/configuration.py index 29717718c..ff3721a6f 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -108,8 +108,8 @@ class Configuration(dict): ("PYOP2_NODE_LOCAL_COMPILATION", bool, True), "no_fork_available": ("PYOP2_NO_FORK_AVAILABLE", bool, False), - "print_cache_size": - ("PYOP2_PRINT_CACHE_SIZE", bool, False), + "print_cache_info": + ("PYOP2_CACHE_INFO", bool, False), "matnest": ("PYOP2_MATNEST", bool, True), "block_sparsity": diff --git a/pyop2/op2.py b/pyop2/op2.py index 85788eafa..35e5649f4 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -112,11 +112,10 @@ def init(**kwargs): @collective def exit(): """Exit OP2 and clean up""" - if configuration['print_cache_size'] and COMM_WORLD.rank == 0: - from caching import report_cache, Cached, ObjectCached - print('**** PyOP2 cache sizes at exit ****') - report_cache(typ=ObjectCached) - report_cache(typ=Cached) + if configuration['print_cache_info'] and COMM_WORLD.rank == 0: + from pyop2.caching import print_cache_stats + print(f"{' PyOP2 cache sizes on rank 0 at exit ':*^120}") + print_cache_stats(alive=False) configuration.reset() global _initialised _initialised = False diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index d34859cc2..6ab909b29 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -45,6 +45,7 @@ def _seed(): nelems = 8 +default_cache_name = DEFAULT_CACHE().__class__.__name__ @pytest.fixture @@ -286,11 +287,11 @@ class TestGeneratedCodeCache: @property def cache(self): int_comm = mpi.internal_comm(mpi.COMM_WORLD, self) - _cache = int_comm.Get_attr(mpi.comm_cache_keyval) - if _cache is None: - _cache = {'DEFAULT_CACHE': DEFAULT_CACHE()} - mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache) - return _cache['DEFAULT_CACHE'] + _cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval) + if _cache_collection is None: + _cache_collection = {default_cache_name: DEFAULT_CACHE()} + mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache_collection) + return _cache_collection[default_cache_name] @pytest.fixture def a(cls, diterset): @@ -541,7 +542,7 @@ def comm(self): """This fixture provides a temporary comm so that each test gets it's own communicator and that caches are cleaned on free.""" temporary_comm = mpi.COMM_WORLD.Dup() - temporary_comm.name = "pytest temporary COMM_WORLD" + temporary_comm.name = "pytest temp COMM_WORLD" with mpi.temp_internal_comm(temporary_comm) as comm: yield comm temporary_comm.Free() @@ -556,7 +557,7 @@ def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm): )(self.myfunc) obj1 = decorated_func("input1", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 @@ -571,7 +572,7 @@ def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cach )(self.myfunc) temporary_comm = mpi.COMM_SELF.Dup() - temporary_comm.name = "pytest temporary COMM_SELF" + temporary_comm.name = "pytest temp COMM_SELF" with mpi.temp_internal_comm(temporary_comm) as comm_self: comm_self_func = memory_and_disk_cache( cachedir=cachedir.name @@ -579,13 +580,13 @@ def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cach # obj1 should be cached on the COMM_WORLD cache obj1 = comm_world_func("input1", comm=comm) - comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert len(comm_world_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 # obj2 should be cached on the COMM_SELF cache obj2 = comm_self_func("input1", comm=comm_self) - comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 assert len(comm_world_cache) == 1 assert len(comm_self_cache) == 1 @@ -599,7 +600,7 @@ def test_decorator_disk_cache_reuses_results(self, cachedir, comm): obj1 = decorated_func("input1", comm=comm) clear_memory_cache(comm) obj2 = decorated_func("input1", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 @@ -609,7 +610,7 @@ def test_decorator_cache_misses(self, cachedir, comm): obj1 = decorated_func("input1", comm=comm) obj2 = decorated_func("input2", comm=comm) - mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"] + mem_cache = comm.Get_attr(mpi.comm_cache_keyval)[default_cache_name] assert obj1 != obj2 assert len(mem_cache) == 2 assert len(os.listdir(cachedir.name)) == 2