From 55a15eb472ea67a715ed6be6ecb13ec3c51ba63c Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:13:53 +0100 Subject: [PATCH] MAINT: move cache helper functions to `sympy._cache` --- src/ampform/sympy/__init__.py | 87 ++----------------- src/ampform/sympy/_cache.py | 87 +++++++++++++++++++ .../sympy/{test_caching.py => test_cache.py} | 8 +- 3 files changed, 97 insertions(+), 85 deletions(-) create mode 100644 src/ampform/sympy/_cache.py rename tests/sympy/{test_caching.py => test_cache.py} (95%) diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 76bf00a2e..babfe0c8a 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -12,8 +12,6 @@ from __future__ import annotations -import functools -import hashlib import itertools import logging import os @@ -23,7 +21,6 @@ import warnings from abc import abstractmethod from os.path import abspath, dirname -from textwrap import dedent from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat import sympy as sp @@ -31,6 +28,7 @@ from sympy.printing.precedence import PRECEDENCE from sympy.printing.pycode import _unpack_integral_limits # noqa: PLC2701 +from ._cache import get_readable_hash, get_system_cache_directory from ._decorator import ( ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport] @@ -342,26 +340,25 @@ def perform_cached_doit( """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk. The cached result is fetched from disk if the hash of the original expression is the - same as the hash embedded in the filename. + same as the hash embedded in the filename (see :func:`.get_readable_hash`). Args: unevaluated_expr: A `sympy.Expr ` on which to call :meth:`~sympy.core.basic.Basic.doit`. cache_directory: The directory in which to cache the result. Defaults to :file:`ampform` under the system cache directory (see - :func:`_get_system_cache_directory`). + :func:`.get_system_cache_directory`). .. tip:: For a faster cache, set `PYTHONHASHSEED `_ to a fixed value. - .. autofunction:: _get_system_cache_directory - .. autofunction:: _get_readable_hash + .. automodule:: ampform.sympy._cache """ if cache_directory is None: - system_cache_dir = _get_system_cache_directory() + system_cache_dir = get_system_cache_directory() cache_directory = abspath(f"{system_cache_dir}/ampform") - h = _get_readable_hash(unevaluated_expr) + h = get_readable_hash(unevaluated_expr) filename = f"{cache_directory}/{h}.pkl" os.makedirs(dirname(filename), exist_ok=True) if os.path.exists(filename): @@ -374,75 +371,3 @@ def perform_cached_doit( with open(filename, "wb") as f: pickle.dump(unfolded_expr, f) return unfolded_expr - - -def _get_system_cache_directory() -> str: - r"""Return the system cache directory for the current platform. - - >>> import sys, pytest - >>> if sys.platform.startswith("darwin"): - ... assert _get_system_cache_directory().endswith("/Library/Caches") - >>> if sys.platform.startswith("linux"): - ... assert _get_system_cache_directory().endswith("/.cache") - >>> if sys.platform.startswith("win"): - ... assert _get_system_cache_directory().endswith(R"\AppData\Local") - """ - if sys.platform.startswith("linux"): - cache_directory = os.getenv("XDG_CACHE_HOME") - if cache_directory is not None: - return cache_directory - if sys.platform.startswith("darwin"): # macos - return os.path.expanduser("~/Library/Caches") - if sys.platform.startswith("win"): - cache_directory = os.getenv("LocalAppData") # noqa: SIM112 - if cache_directory is not None: - return cache_directory - return os.path.expanduser("~/AppData/Local") - return os.path.expanduser("~/.cache") - - -def _get_readable_hash(obj, ignore_hash_seed: bool = False) -> str: - """Get a human-readable hash of any hashable Python object. - - The algorithm is fastest if `PYTHONHASHSEED - `_ is set. - Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. - - Args: - obj: Any hashable object, mutable or immutable, to be hashed. - ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If - :code:`True`, the hash seed is ignored and the hash is computed with - :func:`hashlib.sha256`. - """ - python_hash_seed = _get_python_hash_seed() - if ignore_hash_seed or python_hash_seed is None: - b = _to_bytes(obj) - return hashlib.sha256(b).hexdigest() - return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" - - -def _to_bytes(obj) -> bytes: - if isinstance(obj, sp.Expr): - # Using the str printer is slower and not necessarily unique, - # but pickle.dumps() does not always result in the same bytes stream. - _warn_about_unsafe_hash() - return str(obj).encode() - return pickle.dumps(obj) - - -def _get_python_hash_seed() -> int | None: - python_hash_seed = os.environ.get("PYTHONHASHSEED", "") - if python_hash_seed is not None and python_hash_seed.isdigit(): - return int(python_hash_seed) - return None - - -@functools.lru_cache(maxsize=None) # warn once -def _warn_about_unsafe_hash(): - message = """ - PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions, - set the PYTHONHASHSEED environment variable to a fixed value and rerun the program. - See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED - """ - message = dedent(message).replace("\n", " ").strip() - _LOGGER.warning(message) diff --git a/src/ampform/sympy/_cache.py b/src/ampform/sympy/_cache.py new file mode 100644 index 000000000..421f4d89c --- /dev/null +++ b/src/ampform/sympy/_cache.py @@ -0,0 +1,87 @@ +"""Helper functions for :func:`.perform_cached_doit`.""" + +from __future__ import annotations + +import functools +import hashlib +import logging +import os +import pickle # noqa: S403 +import sys +from textwrap import dedent + +import sympy as sp + +_LOGGER = logging.getLogger(__name__) + + +def get_system_cache_directory() -> str: + r"""Return the system cache directory for the current platform. + + >>> import sys, pytest + >>> if sys.platform.startswith("darwin"): + ... assert get_system_cache_directory().endswith("/Library/Caches") + >>> if sys.platform.startswith("linux"): + ... assert get_system_cache_directory().endswith("/.cache") + >>> if sys.platform.startswith("win"): + ... assert get_system_cache_directory().endswith(R"\AppData\Local") + """ + if sys.platform.startswith("linux"): + cache_directory = os.getenv("XDG_CACHE_HOME") + if cache_directory is not None: + return cache_directory + if sys.platform.startswith("darwin"): # macos + return os.path.expanduser("~/Library/Caches") + if sys.platform.startswith("win"): + cache_directory = os.getenv("LocalAppData") # noqa: SIM112 + if cache_directory is not None: + return cache_directory + return os.path.expanduser("~/AppData/Local") + return os.path.expanduser("~/.cache") + + +def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str: + """Get a human-readable hash of any hashable Python object. + + The algorithm is fastest if `PYTHONHASHSEED + `_ is set. + Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. + + Args: + obj: Any hashable object, mutable or immutable, to be hashed. + ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If + :code:`True`, the hash seed is ignored and the hash is computed with + :func:`hashlib.sha256`. + """ + python_hash_seed = _get_python_hash_seed() + if ignore_hash_seed or python_hash_seed is None: + b = _to_bytes(obj) + return hashlib.sha256(b).hexdigest() + return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" + + +def _to_bytes(obj) -> bytes: + if isinstance(obj, sp.Expr): + # Using the str printer is slower and not necessarily unique, + # but pickle.dumps() does not always result in the same bytes stream. + _warn_about_unsafe_hash() + return str(obj).encode() + return pickle.dumps(obj) + + +def _get_python_hash_seed() -> int | None: + python_hash_seed = os.environ.get("PYTHONHASHSEED", "") + if python_hash_seed is not None and python_hash_seed.isdigit(): + return int(python_hash_seed) + return None + + +@functools.lru_cache(maxsize=None) # warn once +def _warn_about_unsafe_hash(): + message = """ + PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions, + set the PYTHONHASHSEED environment variable to a fixed value and rerun the program. + See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED + """ + message = dedent(message).replace("\n", " ").strip() + _LOGGER.warning(message) diff --git a/tests/sympy/test_caching.py b/tests/sympy/test_cache.py similarity index 95% rename from tests/sympy/test_caching.py rename to tests/sympy/test_cache.py index 036744ab1..41fd50906 100644 --- a/tests/sympy/test_caching.py +++ b/tests/sympy/test_cache.py @@ -9,7 +9,7 @@ import sympy as sp from ampform.dynamics import EnergyDependentWidth -from ampform.sympy import _get_readable_hash, _warn_about_unsafe_hash +from ampform.sympy._cache import _warn_about_unsafe_hash, get_readable_hash if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture @@ -61,7 +61,7 @@ def test_get_readable_hash(assumptions, expected_hashes, caplog: LogCaptureFixtu caplog.set_level(logging.WARNING) x, y = sp.symbols("x y", **assumptions) expr = x**2 + y - h_str = _get_readable_hash(expr) + h_str = get_readable_hash(expr) python_hash_seed = os.environ.get("PYTHONHASHSEED") if python_hash_seed is None: assert h_str[:7] == "bbc9833" @@ -88,7 +88,7 @@ def test_get_readable_hash_energy_dependent_width(): angular_momentum=angular_momentum, meson_radius=d, ) - h = _get_readable_hash(expr) + h = get_readable_hash(expr) python_hash_seed = os.environ.get("PYTHONHASHSEED") if python_hash_seed is None: pytest.skip("PYTHONHASHSEED has not been set") @@ -124,4 +124,4 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]): "canonical-helicity": "pythonhashseed-0-8505502895987205495", "helicity": "pythonhashseed-0-1430245260241162669", }[formalism] - assert _get_readable_hash(model.expression) == expected_hash + assert get_readable_hash(model.expression) == expected_hash