From 3eef196795ed00c397e1fcf4c86cae2220128c94 Mon Sep 17 00:00:00 2001
From: Jack Betteridge
Date: Sat, 24 Aug 2024 23:45:58 +0100
Subject: [PATCH] Fix stats printing

---
 pyop2/caching.py     | 127 +++++++++++++++++++++++++------------------
 pyop2/compilation.py |   5 +-
 2 files changed, 78 insertions(+), 54 deletions(-)

diff --git a/pyop2/caching.py b/pyop2/caching.py
index 56b30096e..4771de511 100644
--- a/pyop2/caching.py
+++ b/pyop2/caching.py
@@ -32,6 +32,7 @@
 # OF THE POSSIBILITY OF SUCH DAMAGE.
 
 """Provides common base classes for cached objects."""
+import atexit
 import cachetools
 import hashlib
 import os
@@ -148,59 +149,82 @@ def make_obj():
             return obj
 
 
-def cache_stats(comm=None, comm_name=None, alive=True, function=None, cache_type=None):
+def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None):
     caches = _KNOWN_CACHES
     if comm is not None:
         with temp_internal_comm(comm) as icomm:
             cache_collection = icomm.Get_attr(comm_cache_keyval)
             if cache_collection is None:
-                print(f"Communicator {icomm.name} as no associated caches")
+                print(f"Communicator {icomm.name} has no associated caches")
             comm_name = icomm.name
     if comm_name is not None:
-        caches = filter(lambda c: c[2] == comm_name, caches)
+        caches = filter(lambda c: c.comm_name == comm_name, caches)
     if alive:
-        caches = filter(lambda c: c[1] != MPI.COMM_NULL, caches)
+        caches = filter(lambda c: c.comm != MPI.COMM_NULL, caches)
     if function is not None:
         if isinstance(function, str):
-            caches = filter(lambda c: function in c[3].__qualname__, caches)
+            caches = filter(lambda c: function in c.func_name, caches)
         else:
-            caches = filter(lambda c: c[3] is function, caches)
+            caches = filter(lambda c: c.func is function, caches)
     if cache_type is not None:
         if isinstance(cache_type, str):
-            caches = filter(lambda c: cache_type in c[4].__qualname__, caches)
+            caches = filter(lambda c: cache_type in c.cache_name, caches)
         else:
-            caches = filter(lambda c: isinstance(c[4], cache_type), caches)
+            caches = filter(lambda c: c.cache_name == cache_type.__class__.__qualname__, caches)
     return [*caches]
 
 
-def get_stats(cache):
-    hit = miss = size = maxsize = -1
-    if isinstance(cache, cachetools.Cache):
-        size = cache.currsize
-        maxsize = cache.maxsize
-    if hasattr(cache, "instrument__"):
-        hit = cache.hit
-        miss = cache.miss
-    if size is None:
-        try:
-            size = len(cache)
-        except NotImplementedError:
-            pass
-    if maxsize is None:
-        try:
-            maxsize = cache.max_size
-        except AttributeError:
-            pass
-    return hit, miss, size, maxsize
+class _CacheRecord:
+    def __init__(self, cidx, comm, func, cache):
+        self.cidx = cidx
+        self.comm = comm
+        self.comm_name = comm.name
+        self.func = func
+        self.func_module = func.__module__
+        self.func_name = func.__qualname__
+        self.cache = weakref.ref(cache)
+        fin = weakref.finalize(cache, self.finalize, cache)
+        fin.atexit = False
+        self.cache_name = cache.__class__.__qualname__
+        try:
+            self.cache_loc = cache.cachedir
+        except AttributeError:
+            self.cache_loc = "Memory"
+
+    def get_stats(self, cache=None):
+        if cache is None:
+            cache = self.cache()
+        hit = miss = size = maxsize = -1
+        if cache is None:
+            hit, miss, size, maxsize = self.hit, self.miss, self.size, self.maxsize
+        if isinstance(cache, cachetools.Cache):
+            size = cache.currsize
+            maxsize = cache.maxsize
+        if hasattr(cache, "instrument__"):
+            hit = cache.hit
+            miss = cache.miss
+        if size == -1:
+            try:
+                size = len(cache)
+            except NotImplementedError:
+                pass
+        if maxsize is None:
+            try:
+                maxsize = cache.max_size
+            except AttributeError:
+                pass
+        return hit, miss, size, maxsize
+
+    def finalize(self, cache):
+        self.hit, self.miss, self.size, self.maxsize = self.get_stats(cache)
 
 
 def print_cache_stats(*args, **kwargs):
     data = defaultdict(lambda: defaultdict(list))
-    for entry in cache_stats(*args, **kwargs):
-        ecid, ecomm, ecomm_name, efunction, ecache = entry
-        active = (ecomm != MPI.COMM_NULL)
-        data[(ecomm_name, active)][ecache.__class__.__name__].append(
-            (ecid, efunction.__module__, efunction.__name__, ecache)
+    for entry in cache_filter(*args, **kwargs):
+        active = (entry.comm != MPI.COMM_NULL)
+        data[(entry.comm_name, active)][(entry.cache_name, entry.cache_loc)].append(
+            (entry.cidx, entry.func_module, entry.func_name, entry.get_stats())
         )
 
     tab = " "
@@ -219,22 +243,22 @@ def print_cache_stats(*args, **kwargs):
         comm_title = f"{ecomm[0]} ({active})"
         print(f"|{comm_title:{col[0]}}|{no_stats}|")
         for ecache, function_list in cachedict.items():
-            cache_title = f"{tab}{ecache}"
+            cache_title = f"{tab}{ecache[0]}"
             print(f"|{cache_title:{col[0]}}|{no_stats}|")
-            try:
-                loc = function_list[0][-1].cachedir
-            except AttributeError:
-                loc = "Memory"
-            cache_location = f"{tab} ↳ {loc!s}"
-            if len(str(loc)) < col[0] - 5:
+            cache_location = f"{tab} ↳ {ecache[1]!s}"
+            if len(cache_location) < col[0]:
                 print(f"|{cache_location:{col[0]}}|{no_stats}|")
             else:
                 print(f"|{cache_location:78}|")
             for entry in function_list:
                 function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}"
-                stats = "|".join(f"{s:{w}}" for s, w in zip(get_stats(entry[3]), stats_col))
-                print(f"|{function_title:{col[0]}}|{stats:{col[1]}}|")
-    print(hline)
+                stats_row = "|".join(f"{s:{w}}" for s, w in zip(entry[3], stats_col))
+                print(f"|{function_title:{col[0]}}|{stats_row:{col[1]}}|")
+        print(hline)
+
+
+if _running_on_ci:
+    print_cache_stats = atexit.register(print_cache_stats)
 
 
 class _CacheMiss:
@@ -379,11 +403,11 @@ class DEFAULT_CACHE(dict):
 
 
 # Examples of how to instrument and use different default caches:
-# - DEFAULT_CACHE = instrument(DEFAULT_CACHE)
-# - DEFAULT_CACHE = instrument(cachetools.LRUCache)
-# - DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100)
+# ~ DEFAULT_CACHE = instrument(DEFAULT_CACHE)
+# ~ DEFAULT_CACHE = instrument(cachetools.LRUCache)
+# ~ DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100)
 EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100)
-# - DictLikeDiskAccess = instrument(DictLikeDiskAccess)
+# ~ DictLikeDiskAccess = instrument(DictLikeDiskAccess)
 
 
 # JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT
@@ -432,13 +456,11 @@ def wrapper(*args, **kwargs):
                 local_cache = cache_collection[cf.__class__.__name__]
 
                 # If this is a new cache or function add it to the list of known caches
-                if (comm, comm.name, func, weakref.ref(local_cache)) not in [c[1:] for c in _KNOWN_CACHES]:
-                    # JBTODO: When a comm is freed we will not hold a ref to the cache,
-                    # but we should have a finalizer that extracts the stats before the object
+                if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
+                    # When a comm is freed we do not hold a reference to the cache.
+                    # We attach a finalizer that extracts the stats before the cache
                     # is deleted.
-                    _KNOWN_CACHES.append(
-                        (next(_CACHE_CIDX), comm, comm.name, func, weakref.ref(local_cache))
-                    )
+                    _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))
 
                 # JBTODO: Replace everything below here with:
                 # value = local_cache.get(key, CACHE_MISS)
@@ -487,6 +509,7 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
+# JBTODO: This needs some more thought
def clear_memory_cache(comm):
     with temp_internal_comm(comm) as icomm:
         if icomm.Get_attr(comm_cache_keyval) is not None:
diff --git a/pyop2/compilation.py b/pyop2/compilation.py
index 9f2263d62..ce45c7e99 100644
--- a/pyop2/compilation.py
+++ b/pyop2/compilation.py
@@ -678,6 +678,7 @@ def _add_profiling_events(dll, events):
             ctypes.c_int.in_dll(dll, 'ID_'+e).value = PETSc.Log.Event(e).id
 
 
+# JBTODO: Move to caching??
 def clear_cache(prompt=False):
     """Clear the PyOP2 compiler cache.
 
@@ -688,10 +689,10 @@ def clear_cache(prompt=False):
     for directory in cachedirs:
         if not os.path.exists(directory):
             print("Cache directory could not be found")
-            return
+            continue
         if len(os.listdir(directory)) == 0:
             print("No cached libraries to remove")
-            return
+            continue
         remove = True
         if prompt:
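
Illustrative usage (a sketch, not part of the patch): with this change the stats table
is built from _CacheRecord entries, so the same records can also be inspected
programmatically via cache_filter. The driver below is hypothetical; it assumes PyOP2 is
importable and that some caches have already been populated by decorated functions.

    # hypothetical driver script, not included in the patch
    from mpi4py import MPI
    from pyop2.caching import cache_filter, print_cache_stats

    # Print the formatted hit/miss/size/maxsize table for every known cache.
    print_cache_stats(comm=MPI.COMM_WORLD)

    # Walk the records that print_cache_stats would tabulate.
    for record in cache_filter(alive=True):
        hit, miss, size, maxsize = record.get_stats()
        print(record.comm_name, record.func_name, hit, miss, size, maxsize)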