Skip to content

Commit

Permalink
ENH: write caches to user cache directory
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 19, 2024
1 parent 145e9e5 commit fe71561
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 55 deletions.
91 changes: 91 additions & 0 deletions src/ampform_dpd/_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Helper functions for :func:`.perform_cached_doit`.
Implementation taken from
https://github.com/ComPWA/ampform/blob/40a898f/src/ampform/sympy/_cache.py
"""

from __future__ import annotations

import functools
import hashlib
import logging
import os
import pickle # noqa: S403
import sys
from textwrap import dedent

import sympy as sp

_LOGGER = logging.getLogger(__name__)


def get_system_cache_directory() -> str:
r"""Return the system cache directory for the current platform.
>>> import sys
>>> if sys.platform.startswith("darwin"):
... assert get_system_cache_directory().endswith("/Library/Caches")
>>> if sys.platform.startswith("linux"):
... assert get_system_cache_directory().endswith("/.cache")
>>> if sys.platform.startswith("win"):
... assert get_system_cache_directory().endswith(R"\AppData\Local")
"""
if sys.platform.startswith("linux"):
cache_directory = os.getenv("XDG_CACHE_HOME")
if cache_directory is not None:
return cache_directory
if sys.platform.startswith("darwin"): # macos
return os.path.expanduser("~/Library/Caches")
if sys.platform.startswith("win"):
cache_directory = os.getenv("LocalAppData") # noqa: SIM112
if cache_directory is not None:
return cache_directory
return os.path.expanduser("~/AppData/Local")
return os.path.expanduser("~/.cache")


def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str:
"""Get a human-readable hash of any hashable Python object.
The algorithm is fastest if `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ is set.
Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`.
Args:
obj: Any hashable object, mutable or immutable, to be hashed.
ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If
:code:`True`, the hash seed is ignored and the hash is computed with
:func:`hashlib.sha256`.
"""
python_hash_seed = _get_python_hash_seed()
if ignore_hash_seed or python_hash_seed is None:
b = _to_bytes(obj)
return hashlib.sha256(b).hexdigest()
return f"pythonhashseed-{python_hash_seed}{hash(obj):+}"


def _to_bytes(obj) -> bytes:
if isinstance(obj, sp.Expr):
# Using the str printer is slower and not necessarily unique,
# but pickle.dumps() does not always result in the same bytes stream.
_warn_about_unsafe_hash()
return str(obj).encode()
return pickle.dumps(obj)


def _get_python_hash_seed() -> int | None:
python_hash_seed = os.environ.get("PYTHONHASHSEED")
if python_hash_seed is not None and python_hash_seed.isdigit():
return int(python_hash_seed)
return None


@functools.lru_cache(maxsize=None) # warn once
def _warn_about_unsafe_hash():
message = """
PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)
69 changes: 15 additions & 54 deletions src/ampform_dpd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@

from __future__ import annotations

import hashlib
import logging
import os
import pickle
from collections import abc
from functools import lru_cache
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, overload

import cloudpickle
Expand All @@ -36,6 +33,7 @@
)
from tensorwaves.function.sympy import create_function, create_parametrized_function

from ampform_dpd._cache import get_readable_hash, get_system_cache_directory
from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecay, ThreeBodyDecayChain

if TYPE_CHECKING:
Expand Down Expand Up @@ -240,7 +238,7 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload]
expr: sp.Expr,
parameters: Mapping[sp.Symbol, ParameterValue] | None = None,
backend: str = "jax",
directory: str | None = None,
cache_directory: Path | str | None = None,
) -> ParametrizedFunction | Function:
"""Lambdify a SymPy `~sympy.core.expr.Expr` and cache the result to disk.
Expand All @@ -266,9 +264,15 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload]
.. seealso:: :func:`ampform.sympy.perform_cached_doit`
"""
if directory is None:
main_cache_dir = _get_main_cache_dir()
directory = abspath(f"{main_cache_dir}/.sympy-cache-{backend}")
if cache_directory is None:
system_cache_dir = get_system_cache_directory()
backend_version = version(backend)
cache_directory = (
Path(system_cache_dir) / "ampform_dpd" / f"{backend}-v{backend_version}"
)
if not isinstance(cache_directory, Path):
cache_directory = Path(cache_directory)
cache_directory.mkdir(exist_ok=True, parents=True)
if parameters is None:
hash_obj = expr
else:
Expand All @@ -277,63 +281,20 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload]
tuple((s, parameters[s]) for s in sorted(parameters, key=str)),
)
h = get_readable_hash(hash_obj)
filename = f"{directory}/{h}.pkl"
if os.path.exists(filename):
filename = cache_directory / f"{h}.pkl"
if filename.exists():
with open(filename, "rb") as f:
return pickle.load(f)
_LOGGER.warning(f"Cached function file {filename} not found, lambdifying...")
if parameters is None:
func = create_function(expr, backend)
else:
func = create_parametrized_function(expr, parameters, backend)
os.makedirs(dirname(filename), exist_ok=True)
with open(filename, "wb") as f:
cloudpickle.dump(func, f)
return func


def _get_main_cache_dir() -> str:
cache_dir = os.environ.get("SYMPY_CACHE_DIR")
if cache_dir is None:
cache_dir = expanduser("~") # home directory
return cache_dir


def get_readable_hash(obj) -> str:
python_hash_seed = _get_python_hash_seed()
if python_hash_seed is not None:
return f"pythonhashseed-{python_hash_seed}{hash(obj):+}"
b = _to_bytes(obj)
return hashlib.sha256(b).hexdigest()


def _to_bytes(obj) -> bytes:
if isinstance(obj, sp.Expr):
# Using the str printer is slower and not necessarily unique,
# but pickle.dumps() does not always result in the same bytes stream.
_warn_about_unsafe_hash()
return str(obj).encode()
return pickle.dumps(obj)


def _get_python_hash_seed() -> int | None:
python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
if python_hash_seed.isdigit():
return int(python_hash_seed)
return None


@lru_cache(maxsize=None) # warn once
def _warn_about_unsafe_hash():
message = """
PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)


def simplify_latex_rendering() -> None:
"""Improve LaTeX rendering of an `~sympy.tensor.indexed.Indexed` object."""

Expand Down
3 changes: 2 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import pytest
import sympy as sp

from ampform_dpd._cache import _warn_about_unsafe_hash
from ampform_dpd.decay import IsobarNode, Particle
from ampform_dpd.io import _warn_about_unsafe_hash, aslatex, get_readable_hash
from ampform_dpd.io import aslatex, get_readable_hash

if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
Expand Down

0 comments on commit fe71561

Please sign in to comment.