Fix stats printing
JDBetteridge committed Aug 24, 2024
1 parent 2df3f35 commit 3eef196
Showing 2 changed files with 78 additions and 54 deletions.
127 changes: 75 additions & 52 deletions pyop2/caching.py
@@ -32,6 +32,7 @@
# OF THE POSSIBILITY OF SUCH DAMAGE.

"""Provides common base classes for cached objects."""
import atexit
import cachetools
import hashlib
import os
@@ -148,59 +149,82 @@ def make_obj():
return obj


def cache_stats(comm=None, comm_name=None, alive=True, function=None, cache_type=None):
def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None):
caches = _KNOWN_CACHES
if comm is not None:
with temp_internal_comm(comm) as icomm:
cache_collection = icomm.Get_attr(comm_cache_keyval)
if cache_collection is None:
print(f"Communicator {icomm.name} as no associated caches")
print(f"Communicator {icomm.name} has no associated caches")
comm_name = icomm.name
if comm_name is not None:
caches = filter(lambda c: c[2] == comm_name, caches)
caches = filter(lambda c: c.comm_name == comm_name, caches)
if alive:
caches = filter(lambda c: c[1] != MPI.COMM_NULL, caches)
caches = filter(lambda c: c.comm != MPI.COMM_NULL, caches)
if function is not None:
if isinstance(function, str):
caches = filter(lambda c: function in c[3].__qualname__, caches)
caches = filter(lambda c: function in c.func_name, caches)
else:
caches = filter(lambda c: c[3] is function, caches)
caches = filter(lambda c: c.func is function, caches)
if cache_type is not None:
if isinstance(cache_type, str):
caches = filter(lambda c: cache_type in c[4].__qualname__, caches)
caches = filter(lambda c: cache_type in c.cache_name, caches)
else:
caches = filter(lambda c: isinstance(c[4], cache_type), caches)
caches = filter(lambda c: c.cache_name == cache_type.__class__.__qualname__, caches)
return [*caches]


def get_stats(cache):
hit = miss = size = maxsize = -1
if isinstance(cache, cachetools.Cache):
size = cache.currsize
maxsize = cache.maxsize
if hasattr(cache, "instrument__"):
hit = cache.hit
miss = cache.miss
if size is None:
try:
size = len(cache)
except NotImplementedError:
pass
if maxsize is None:
try:
maxsize = cache.max_size
except AttributeError:
pass
return hit, miss, size, maxsize
class _CacheRecord:
def __init__(self, cidx, comm, func, cache):
self.cidx = cidx
self.comm = comm
self.comm_name = comm.name
self.func = func
self.func_module = func.__module__
self.func_name = func.__qualname__
self.cache = weakref.ref(cache)
fin = weakref.finalize(cache, self.finalize, cache)
fin.atexit = False
self.cache_name = cache.__class__.__qualname__
try:
self.cache_loc = cache.cachedir
except AttributeError:
self.cache_loc = "Memory"

def get_stats(self, cache=None):
if cache is None:
cache = self.cache()
hit = miss = size = maxsize = -1
if cache is None:
hit, miss, size, maxsize = self.hit, self.miss, self.size, self.maxsize
if isinstance(cache, cachetools.Cache):
size = cache.currsize
maxsize = cache.maxsize
if hasattr(cache, "instrument__"):
hit = cache.hit
miss = cache.miss
if size == -1:
try:
size = len(cache)
except NotImplementedError:
pass
if maxsize is None:
try:
maxsize = cache.max_size
except AttributeError:
pass
return hit, miss, size, maxsize

def finalize(self, cache):
self.hit, self.miss, self.size, self.maxsize = self.get_stats(cache)
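# Illustrative sketch (not part of this commit): a minimal, self-contained
# example of the weakref.finalize pattern used above -- keep only a weak
# reference to a cache, but still run a callback when it is garbage collected.
# All names are illustrative, and here the snapshot passed to the callback is
# taken when the finalizer is registered rather than when the cache dies.
import weakref

class _RecordSketch:
    def __init__(self, cache):
        self.cache = weakref.ref(cache)  # weak: does not keep the cache alive
        self.size = None
        fin = weakref.finalize(cache, self._capture, len(cache))
        fin.atexit = False               # do not run the callback at interpreter exit

    def _capture(self, size):
        # Runs once the cache becomes unreachable; the weak reference is already dead.
        self.size = size

# d = {"a": 1}; rec = _RecordSketch(d); del d   # rec.size == 1 after collection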


def print_cache_stats(*args, **kwargs):
data = defaultdict(lambda: defaultdict(list))
for entry in cache_stats(*args, **kwargs):
ecid, ecomm, ecomm_name, efunction, ecache = entry
active = (ecomm != MPI.COMM_NULL)
data[(ecomm_name, active)][ecache.__class__.__name__].append(
(ecid, efunction.__module__, efunction.__name__, ecache)
for entry in cache_filter(*args, **kwargs):
active = (entry.comm != MPI.COMM_NULL)
data[(entry.comm_name, active)][(entry.cache_name, entry.cache_loc)].append(
(entry.cidx, entry.func_module, entry.func_name, entry.get_stats())
)
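# Shape of `data` built above (illustrative comment):
#   {(comm_name, active): {(cache_class_name, cache_location):
#       [(cidx, func_module, func_name, (hit, miss, size, maxsize)), ...]}}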

tab = " "
@@ -219,22 +243,22 @@ def print_cache_stats(*args, **kwargs):
comm_title = f"{ecomm[0]} ({active})"
print(f"|{comm_title:{col[0]}}|{no_stats}|")
for ecache, function_list in cachedict.items():
cache_title = f"{tab}{ecache}"
cache_title = f"{tab}{ecache[0]}"
print(f"|{cache_title:{col[0]}}|{no_stats}|")
try:
loc = function_list[0][-1].cachedir
except AttributeError:
loc = "Memory"
cache_location = f"{tab}{loc!s}"
if len(str(loc)) < col[0] - 5:
cache_location = f"{tab}{ecache[1]!s}"
if len(cache_location) < col[0]:
print(f"|{cache_location:{col[0]}}|{no_stats}|")
else:
print(f"|{cache_location:78}|")
for entry in function_list:
function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}"
stats = "|".join(f"{s:{w}}" for s, w in zip(get_stats(entry[3]), stats_col))
print(f"|{function_title:{col[0]}}|{stats:{col[1]}}|")
print(hline)
stats_row = "|".join(f"{s:{w}}" for s, w in zip(entry[3], stats_col))
print(f"|{function_title:{col[0]}}|{stats_row:{col[1]}}|")
print(hline)


if _running_on_ci:
print_cache_stats = atexit.register(print_cache_stats)
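# Usage sketch (not part of this commit): from user code the known caches can be
# inspected directly or the stats table printed on demand; the keyword values
# below are purely illustrative.
#   from pyop2.caching import cache_filter, print_cache_stats
#   records = cache_filter(alive=True, function="compile")       # list of _CacheRecord
#   print_cache_stats(comm_name="PYOP2_COMM_WORLD", cache_type="DictLikeDiskAccess")
# Note that atexit.register(f) returns f, so the assignment above keeps the name
# print_cache_stats callable while also registering it to run at exit.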


class _CacheMiss:
@@ -379,11 +403,11 @@ class DEFAULT_CACHE(dict):


# Examples of how to instrument and use different default caches:
# - DEFAULT_CACHE = instrument(DEFAULT_CACHE)
# - DEFAULT_CACHE = instrument(cachetools.LRUCache)
# - DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100)
# ~ DEFAULT_CACHE = instrument(DEFAULT_CACHE)
# ~ DEFAULT_CACHE = instrument(cachetools.LRUCache)
# ~ DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100)
EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100)
# - DictLikeDiskAccess = instrument(DictLikeDiskAccess)
# ~ DictLikeDiskAccess = instrument(DictLikeDiskAccess)
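# Illustrative sketch (not part of this commit): the real `instrument` helper is
# defined elsewhere in caching.py and is not shown in this diff.  A minimal class
# decorator consistent with the attributes read by get_stats() -- instrument__,
# hit and miss -- might look like this:
def _instrument_sketch(cls):
    class _Instrumented(cls):
        instrument__ = True
        hit = 0
        miss = 0

        def __getitem__(self, key):
            try:
                value = super().__getitem__(key)
            except KeyError:
                self.miss += 1
                raise
            self.hit += 1
            return value
    return _Instrumented
# e.g. _SketchCache = partial(_instrument_sketch(cachetools.LRUCache), maxsize=100)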


# JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT
@@ -432,13 +456,11 @@ def wrapper(*args, **kwargs):
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, weakref.ref(local_cache)) not in [c[1:] for c in _KNOWN_CACHES]:
# JBTODO: When a comm is freed we will not hold a ref to the cache,
# but we should have a finalizer that extracts the stats before the object
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(
(next(_CACHE_CIDX), comm, comm.name, func, weakref.ref(local_cache))
)
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

# JBTODO: Replace everything below here with:
# value = local_cache.get(key, CACHE_MISS)
@@ -487,6 +509,7 @@ def wrapper(*args, **kwargs):
return decorator


# JBTODO: This needs some more thought
def clear_memory_cache(comm):
with temp_internal_comm(comm) as icomm:
if icomm.Get_attr(comm_cache_keyval) is not None:
5 changes: 3 additions & 2 deletions pyop2/compilation.py
@@ -678,6 +678,7 @@ def _add_profiling_events(dll, events):
ctypes.c_int.in_dll(dll, 'ID_'+e).value = PETSc.Log.Event(e).id


# JBTODO: Move to caching??
def clear_cache(prompt=False):
"""Clear the PyOP2 compiler cache.
@@ -688,10 +689,10 @@ def clear_cache(prompt=False):
for directory in cachedirs:
if not os.path.exists(directory):
print("Cache directory could not be found")
return
continue
if len(os.listdir(directory)) == 0:
print("No cached libraries to remove")
return
continue

remove = True
if prompt: