This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit 020c311

Round things out and fix remaining tests

JDBetteridge committed Aug 14, 2024
1 parent a4d85af commit 020c311
Showing 5 changed files with 180 additions and 243 deletions.
142 changes: 44 additions & 98 deletions pyop2/caching.py
@@ -40,13 +40,14 @@
from collections.abc import MutableMapping
from pathlib import Path
from warnings import warn # noqa F401
from functools import wraps, partial
from functools import wraps

from pyop2.configuration import configuration
from pyop2.logger import debug
from pyop2.mpi import comm_cache_keyval, COMM_WORLD
from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval


# TODO: Remove this? Rewrite?
def report_cache(typ):
"""Report the size of caches of type ``typ``
@@ -169,11 +170,6 @@ class _CacheMiss:
CACHE_MISS = _CacheMiss()


class _CacheKey:
def __init__(self, key_value):
self.value = key_value


def _as_hexdigest(*args):
hash_ = hashlib.md5()
for a in args:
@@ -193,7 +189,6 @@ def __init__(self, cachedir):
:arg cachedir: The cache directory.
"""
self.cachedir = cachedir
self._keys = set()

def __getitem__(self, key):
"""Retrieve a value from the disk cache.
@@ -215,7 +210,6 @@ def __setitem__(self, key, value):
:arg key: The cache key, a 2-tuple of strings.
:arg value: The new item to store in the cache.
"""
self._keys.add(key)
k1, k2 = key[0][:2], key[0][2:] + key[1]
basedir = Path(self.cachedir, k1)
basedir.mkdir(parents=True, exist_ok=True)
@@ -229,18 +223,14 @@ def __setitem__(self, key, value):
def __delitem__(self, key):
raise ValueError(f"Cannot remove items from {self.__class__.__name__}")

def keys(self):
return self._keys

def __iter__(self):
for k in self._keys:
yield k
raise ValueError(f"Cannot iterate over keys in {self.__class__.__name__}")

def __len__(self):
return len(self._keys)
raise ValueError(f"Cannot query length of {self.__class__.__name__}")

def __repr__(self):
return "{" + " ".join(f"{k}: {v}" for k, v in self.items()) + "}"
return f"{self.__class__.__name__}(cachedir={self.cachedir})"

def open(self, *args, **kwargs):
return open(*args, **kwargs)
@@ -253,13 +243,41 @@ def write(self, filehandle, value):


def default_comm_fetcher(*args, **kwargs):
return kwargs.get("comm")
comms = filter(
lambda arg: isinstance(arg, MPI.Comm),
args + tuple(kwargs.values())
)
try:
comm = next(comms)
except StopIteration:
raise TypeError("No comms found in args or kwargs")
return comm


def default_parallel_hashkey(*args, **kwargs):
""" We now want to actively remove any comms from args and kwargs to get the same disk cache key
"""
hash_args = tuple(filter(
lambda arg: not isinstance(arg, MPI.Comm),
args
))
hash_kwargs = dict(filter(
lambda arg: not isinstance(arg[1], MPI.Comm),
kwargs.items()
))
return cachetools.keys.hashkey(*hash_args, **hash_kwargs)


default_parallel_hashkey = cachetools.keys.hashkey
class DEFAULT_CACHE(dict):
pass


def parallel_cache(hashkey=default_parallel_hashkey, comm_fetcher=default_comm_fetcher, cache_factory=None, broadcast=True):
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
cache_factory=lambda: DEFAULT_CACHE(),
broadcast=True
):
"""Memory only cache decorator.
Decorator for wrapping a function to be called over a communicator in a
@@ -336,11 +354,7 @@ def wrapper(*args, **kwargs):


# A small collection of default simple caches
class DEFAULT_CACHE(dict):
pass


memory_cache = partial(parallel_cache, cache_factory=lambda: DEFAULT_CACHE())
memory_cache = parallel_cache


def disk_only_cache(*args, cachedir=configuration["cache_dir"], **kwargs):
@@ -352,77 +366,9 @@ def decorator(func):
return memory_cache(*args, **kwargs)(disk_only_cache(*args, cachedir=cachedir, **kwargs)(func))
return decorator


# ### Some notes from Connor:
#
# def pcache(comm_seeker, key=None, cache_factory=dict):
#
# comm = comm_seeker()
# cache = cache_factory()
#
# @pcache(cachetools.LRUCache)
#
# @pcache(DiskCache)
#
# @pcache(MemDiskCache)
#
# @pcache(MemCache)
#
# mem_cache = pcache(cache_factory=cachetools.LRUCache)
# disk_cache = mem_cache(cache_factory=DiskCache)
#
# @pcache(comm_seeker=lambda obj, *_, **_: obj.comm, cache_factory=lambda: cachetools.LRUCache(maxsize=1000))
#
#
# @pmemcache
#
# @pmemdiskcache
#
# class ParallelObject(ABC):
# @abc.abstractproperty
# def _comm(self):
# pass
#
# class MyObj(ParallelObject):
#
# @pcached_property # assumes that obj has a "comm" attr
# @pcached_property(lambda self: self.comm)
# def myproperty(self):
# ...
#
#
# def pcached_property():
# def wrapper(self):
# assert isinstance(self, ParallelObject)
# ...
#
#
# from futils.mpi import ParallelObject
#
# from futils.cache import pcached_property
#
# from footils.cache import *
#
# footils == firedrake utils

# * parallel cached property
# * memcache / cache / cached
# * diskonlycache / disk_only_cached
# * memdiskcache / diskcache / disk_cached
# * memcache_no_bcast / broadcast=False
#
# parallel_cached_property = parallel_cache(lambda self: self._comm, key=lambda self: ())
#
# @time
# @timed
# def myslowfunc():
#
# ..
#
# my_fast_fun = cache(my_slow_fn)
#
####


# TODO:
# Implement an @parallel_cached_property decorator function
# TODO: (Wishlist)
# * Try more exotic caches, e.g. memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000))
# * Add some sort of cache reporting
# * Add some sort of cache statistics
# * Refactor compilation.py to use @mem_and_disk_cached, where get_so just uses DictLikeDiskAccess with an overloaded self.write() method
# * Add some docstrings and maybe some exposition!
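
To make the new behaviour concrete, here is a minimal usage sketch assembled from the decorators in this diff and the tests further down; the cached function and cache directory are illustrative only:

    import tempfile

    from pyop2 import mpi
    from pyop2.caching import (
        clear_memory_cache,
        default_parallel_hashkey,
        memory_and_disk_cache,
    )

    cachedir = tempfile.TemporaryDirectory()

    # The communicator is located by default_comm_fetcher and stripped from
    # the hash key by default_parallel_hashkey, so every rank computes the
    # same disk cache key while memory caches stay per-communicator.
    @memory_and_disk_cache(cachedir=cachedir.name)
    def myfunc(arg, comm):
        return {arg}

    first = myfunc("input1", comm=mpi.COMM_WORLD)   # miss: computed and stored
    second = myfunc("input1", comm=mpi.COMM_WORLD)  # memory hit: same object
    assert first is second

    clear_memory_cache(mpi.COMM_WORLD)              # drops only the memory cache
    third = myfunc("input1", comm=mpi.COMM_WORLD)   # recovered from disk
    assert first == third and first is not third

    # Whichever comm is passed, it is excluded from the key:
    assert (default_parallel_hashkey("x", comm=mpi.COMM_WORLD)
            == default_parallel_hashkey("x", comm=mpi.COMM_SELF))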
3 changes: 1 addition & 2 deletions pyop2/compilation.py
@@ -566,8 +566,7 @@ def load_hashkey(*args, **kwargs):
code_hash = md5(str(args[0].cache_key).encode()).hexdigest()
else:
pass # This will raise an error in load
comm = kwargs.get('comm')
return comm, cachetools.keys.hashkey(code_hash, *args[1:], **kwargs)
return cachetools.keys.hashkey(code_hash, *args[1:], **kwargs)


@mpi.collective
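
The practical effect of this change, sketched with hypothetical inputs: the communicator no longer forms part of the memoisation key, so identical code produces identical keys on every communicator, and the parallel caching layer recovers the comm separately via its comm fetcher:

    import cachetools
    from hashlib import md5

    # Hypothetical kernel source: only its hash enters the key.
    code_hash = md5(b"static void kernel(double *x) { x[0] = 1.0; }").hexdigest()

    # load_hashkey previously returned (comm, hashkey(...)); the comm is now
    # omitted, so the key is the same whichever communicator compiles it.
    key = cachetools.keys.hashkey(code_hash)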
1 change: 1 addition & 0 deletions requirements-git.txt
@@ -1 +1,2 @@
git+https://github.com/firedrakeproject/loopy.git@main#egg=loopy
git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi
102 changes: 54 additions & 48 deletions test/unit/test_caching.py
@@ -35,10 +35,9 @@
import os
import pytest
import tempfile
import cachetools
import numpy
from pyop2 import op2, mpi
from pyop2.caching import memory_and_disk_cache
from pyop2.caching import DEFAULT_CACHE, memory_and_disk_cache, clear_memory_cache


def _seed():
@@ -289,9 +288,9 @@ def cache(self):
int_comm = mpi.internal_comm(mpi.COMM_WORLD, self)
_cache = int_comm.Get_attr(mpi.comm_cache_keyval)
if _cache is None:
_cache = {}
_cache = {'DEFAULT_CACHE': DEFAULT_CACHE()}
mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache)
return _cache
return _cache['DEFAULT_CACHE']

@pytest.fixture
def a(cls, diterset):
@@ -533,79 +532,86 @@ def test_sparsities_different_ordered_map_tuple_cached(self, m1, m2, ds2):
class TestDiskCachedDecorator:

@staticmethod
def myfunc(arg):
def myfunc(arg, comm):
"""Example function to cache the outputs of."""
return {arg}

def comm_fetcher(self, *args):
"""Communicator returning function."""
return mpi.internal_comm(mpi.COMM_WORLD, self)

def hash_key(self, *args):
"""Hash key suitable for caching"""
return cachetools.keys.hashkey(*args)
@pytest.fixture
def comm(self):
"""This fixture provides a temporary comm so that each test gets it's own
communicator and that caches are cleaned on free."""
temporary_comm = mpi.COMM_WORLD.Dup()
temporary_comm.name = "pytest temporary COMM_WORLD"
with mpi.temp_internal_comm(temporary_comm) as comm:
yield comm
temporary_comm.Free()

@pytest.fixture
def cachedir(cls):
return tempfile.TemporaryDirectory()

def test_decorator_in_memory_cache_reuses_results(self, cachedir):
def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm):
decorated_func = memory_and_disk_cache(
comm_fetcher=self.comm_fetcher,
cachedir=cachedir.name
)(self.myfunc)

obj1 = decorated_func("input1")
caches = self.comm_fetcher().Get_attr(mpi.comm_cache_keyval)
mem_cache = caches["DEFAULT_CACHE"]
disk_cache = caches["DictLikeDiskAccess"]
obj1 = decorated_func("input1", comm=comm)
mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"]
assert len(mem_cache) == 1
assert len(disk_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

obj2 = decorated_func("input1")
obj2 = decorated_func("input1", comm=comm)
assert obj1 is obj2
assert len(mem_cache) == 1
assert len(disk_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

def test_decorator_collective_uses_different_in_memory_caches(self, cachedir):
decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc)
collective_func = memory_and_disk_cache(
None, cachedir.name, self.collective_key, collective=True
def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cachedir, comm):
comm_world_func = memory_and_disk_cache(
cachedir=cachedir.name
)(self.myfunc)

# obj1 should be cached on the comm cache and not the self.cache
obj1 = collective_func("input1")
comm_cache = self.comm.Get_attr(mpi.comm_cache_keyval)
assert len(cache) == 0
assert len(comm_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

# obj2 should be cached on the self.cache and not the comm cache
obj2 = decorated_func("input1")
assert obj1 == obj2 and obj1 is not obj2
assert len(cache) == 1
assert len(comm_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

def test_decorator_disk_cache_reuses_results(self, cachedir):
temporary_comm = mpi.COMM_SELF.Dup()
temporary_comm.name = "pytest temporary COMM_SELF"
with mpi.temp_internal_comm(temporary_comm) as comm_self:
comm_self_func = memory_and_disk_cache(
cachedir=cachedir.name
)(self.myfunc)

# obj1 should be cached on the COMM_WORLD cache
obj1 = comm_world_func("input1", comm=comm)
comm_world_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"]
assert len(comm_world_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

# obj2 should be cached on the COMM_SELF cache
obj2 = comm_self_func("input1", comm=comm_self)
comm_self_cache = comm_self.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"]
assert obj1 == obj2 and obj1 is not obj2
assert len(comm_world_cache) == 1
assert len(comm_self_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

temporary_comm.Free()

def test_decorator_disk_cache_reuses_results(self, cachedir, comm):
decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc)

obj1 = decorated_func("input1")
cache.clear()
obj2 = decorated_func("input1")
obj1 = decorated_func("input1", comm=comm)
clear_memory_cache(comm)
obj2 = decorated_func("input1", comm=comm)
mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"]
assert obj1 == obj2 and obj1 is not obj2
assert len(cache) == 1
assert len(mem_cache) == 1
assert len(os.listdir(cachedir.name)) == 1

def test_decorator_cache_misses(self, cachedir):
def test_decorator_cache_misses(self, cachedir, comm):
decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc)

obj1 = decorated_func("input1")
obj2 = decorated_func("input2")
obj1 = decorated_func("input1", comm=comm)
obj2 = decorated_func("input2", comm=comm)
mem_cache = comm.Get_attr(mpi.comm_cache_keyval)["DEFAULT_CACHE"]
assert obj1 != obj2
assert len(cache) == 2
assert len(mem_cache) == 2
assert len(os.listdir(cachedir.name)) == 2

