From fe71561cf4ed048fc834a899f590059b78b63a15 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Tue, 19 Mar 2024 11:22:29 +0100 Subject: [PATCH] ENH: write caches to user cache directory --- src/ampform_dpd/_cache.py | 91 +++++++++++++++++++++++++++++++++++++++ src/ampform_dpd/io.py | 69 +++++++---------------------- tests/test_io.py | 3 +- 3 files changed, 108 insertions(+), 55 deletions(-) create mode 100644 src/ampform_dpd/_cache.py diff --git a/src/ampform_dpd/_cache.py b/src/ampform_dpd/_cache.py new file mode 100644 index 00000000..e67e01b0 --- /dev/null +++ b/src/ampform_dpd/_cache.py @@ -0,0 +1,91 @@ +"""Helper functions for :func:`.perform_cached_doit`. + +Implementation taken from +https://github.com/ComPWA/ampform/blob/40a898f/src/ampform/sympy/_cache.py +""" + +from __future__ import annotations + +import functools +import hashlib +import logging +import os +import pickle # noqa: S403 +import sys +from textwrap import dedent + +import sympy as sp + +_LOGGER = logging.getLogger(__name__) + + +def get_system_cache_directory() -> str: + r"""Return the system cache directory for the current platform. + + >>> import sys + >>> if sys.platform.startswith("darwin"): + ... assert get_system_cache_directory().endswith("/Library/Caches") + >>> if sys.platform.startswith("linux"): + ... assert get_system_cache_directory().endswith("/.cache") + >>> if sys.platform.startswith("win"): + ... 
assert get_system_cache_directory().endswith(R"\AppData\Local") + """ + if sys.platform.startswith("linux"): + cache_directory = os.getenv("XDG_CACHE_HOME") + if cache_directory is not None: + return cache_directory + if sys.platform.startswith("darwin"): # macos + return os.path.expanduser("~/Library/Caches") + if sys.platform.startswith("win"): + cache_directory = os.getenv("LocalAppData") # noqa: SIM112 + if cache_directory is not None: + return cache_directory + return os.path.expanduser("~/AppData/Local") + return os.path.expanduser("~/.cache") + + +def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str: + """Get a human-readable hash of any hashable Python object. + + The algorithm is fastest if `PYTHONHASHSEED + <https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ is set. + Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. + + Args: + obj: Any hashable object, mutable or immutable, to be hashed. + ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If + :code:`True`, the hash seed is ignored and the hash is computed with + :func:`hashlib.sha256`. + """ + python_hash_seed = _get_python_hash_seed() + if ignore_hash_seed or python_hash_seed is None: + b = _to_bytes(obj) + return hashlib.sha256(b).hexdigest() + return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" + + +def _to_bytes(obj) -> bytes: + if isinstance(obj, sp.Expr): + # Using the str printer is slower and not necessarily unique, + # but pickle.dumps() does not always result in the same bytes stream. + _warn_about_unsafe_hash() + return str(obj).encode() + return pickle.dumps(obj) + + +def _get_python_hash_seed() -> int | None: + python_hash_seed = os.environ.get("PYTHONHASHSEED") + if python_hash_seed is not None and python_hash_seed.isdigit(): + return int(python_hash_seed) + return None + + +@functools.lru_cache(maxsize=None) # warn once +def _warn_about_unsafe_hash(): + message = """ + PYTHONHASHSEED has not been set. 
For faster and safer hashing of SymPy expressions, + set the PYTHONHASHSEED environment variable to a fixed value and rerun the program. + See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED + """ + message = dedent(message).replace("\n", " ").strip() + _LOGGER.warning(message) diff --git a/src/ampform_dpd/io.py b/src/ampform_dpd/io.py index 4321c7ed..9e89bb3a 100644 --- a/src/ampform_dpd/io.py +++ b/src/ampform_dpd/io.py @@ -18,14 +18,11 @@ from __future__ import annotations -import hashlib import logging -import os import pickle from collections import abc -from functools import lru_cache -from os.path import abspath, dirname, expanduser -from textwrap import dedent +from importlib.metadata import version +from pathlib import Path from typing import TYPE_CHECKING, Iterable, Mapping, Sequence, overload import cloudpickle @@ -36,6 +33,7 @@ ) from tensorwaves.function.sympy import create_function, create_parametrized_function +from ampform_dpd._cache import get_readable_hash, get_system_cache_directory from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecay, ThreeBodyDecayChain if TYPE_CHECKING: @@ -240,7 +238,7 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] expr: sp.Expr, parameters: Mapping[sp.Symbol, ParameterValue] | None = None, backend: str = "jax", - directory: str | None = None, + cache_directory: Path | str | None = None, ) -> ParametrizedFunction | Function: """Lambdify a SymPy `~sympy.core.expr.Expr` and cache the result to disk. @@ -266,9 +264,15 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] .. 
seealso:: :func:`ampform.sympy.perform_cached_doit` """ - if directory is None: - main_cache_dir = _get_main_cache_dir() - directory = abspath(f"{main_cache_dir}/.sympy-cache-{backend}") + if cache_directory is None: + system_cache_dir = get_system_cache_directory() + backend_version = version(backend) + cache_directory = ( + Path(system_cache_dir) / "ampform_dpd" / f"{backend}-v{backend_version}" + ) + if not isinstance(cache_directory, Path): + cache_directory = Path(cache_directory) + cache_directory.mkdir(exist_ok=True, parents=True) if parameters is None: hash_obj = expr else: @@ -277,8 +281,8 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] tuple((s, parameters[s]) for s in sorted(parameters, key=str)), ) h = get_readable_hash(hash_obj) - filename = f"{directory}/{h}.pkl" - if os.path.exists(filename): + filename = cache_directory / f"{h}.pkl" + if filename.exists(): with open(filename, "rb") as f: return pickle.load(f) _LOGGER.warning(f"Cached function file {filename} not found, lambdifying...") @@ -286,54 +290,11 @@ def perform_cached_lambdify( # pyright: ignore[reportInconsistentOverload] func = create_function(expr, backend) else: func = create_parametrized_function(expr, parameters, backend) - os.makedirs(dirname(filename), exist_ok=True) with open(filename, "wb") as f: cloudpickle.dump(func, f) return func -def _get_main_cache_dir() -> str: - cache_dir = os.environ.get("SYMPY_CACHE_DIR") - if cache_dir is None: - cache_dir = expanduser("~") # home directory - return cache_dir - - -def get_readable_hash(obj) -> str: - python_hash_seed = _get_python_hash_seed() - if python_hash_seed is not None: - return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" - b = _to_bytes(obj) - return hashlib.sha256(b).hexdigest() - - -def _to_bytes(obj) -> bytes: - if isinstance(obj, sp.Expr): - # Using the str printer is slower and not necessarily unique, - # but pickle.dumps() does not always result in the same bytes stream. 
- _warn_about_unsafe_hash() - return str(obj).encode() - return pickle.dumps(obj) - - -def _get_python_hash_seed() -> int | None: - python_hash_seed = os.environ.get("PYTHONHASHSEED", "") - if python_hash_seed.isdigit(): - return int(python_hash_seed) - return None - - -@lru_cache(maxsize=None) # warn once -def _warn_about_unsafe_hash(): - message = """ - PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions, - set the PYTHONHASHSEED environment variable to a fixed value and rerun the program. - See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED - """ - message = dedent(message).replace("\n", " ").strip() - _LOGGER.warning(message) - - def simplify_latex_rendering() -> None: """Improve LaTeX rendering of an `~sympy.tensor.indexed.Indexed` object.""" diff --git a/tests/test_io.py b/tests/test_io.py index db1499e9..c3700498 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -10,8 +10,9 @@ import pytest import sympy as sp +from ampform_dpd._cache import _warn_about_unsafe_hash from ampform_dpd.decay import IsobarNode, Particle -from ampform_dpd.io import _warn_about_unsafe_hash, aslatex, get_readable_hash +from ampform_dpd.io import aslatex, get_readable_hash if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture