Skip to content

Commit

Permalink
MAINT: move cache helper functions to sympy._cache
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 8, 2024
1 parent fde3929 commit 55a15eb
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 85 deletions.
87 changes: 6 additions & 81 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from __future__ import annotations

import functools
import hashlib
import itertools
import logging
import os
Expand All @@ -23,14 +21,14 @@
import warnings
from abc import abstractmethod
from os.path import abspath, dirname
from textwrap import dedent
from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat

import sympy as sp
from sympy.printing.conventions import split_super_sub
from sympy.printing.precedence import PRECEDENCE
from sympy.printing.pycode import _unpack_integral_limits # noqa: PLC2701

from ._cache import get_readable_hash, get_system_cache_directory
from ._decorator import (
ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport]
SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport]
Expand Down Expand Up @@ -342,26 +340,25 @@ def perform_cached_doit(
"""Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk.
The cached result is fetched from disk if the hash of the original expression is the
same as the hash embedded in the filename.
same as the hash embedded in the filename (see :func:`.get_readable_hash`).
Args:
unevaluated_expr: A `sympy.Expr <sympy.core.expr.Expr>` on which to call
:meth:`~sympy.core.basic.Basic.doit`.
cache_directory: The directory in which to cache the result. Defaults to
:file:`ampform` under the system cache directory (see
:func:`_get_system_cache_directory`).
:func:`.get_system_cache_directory`).
.. tip:: For a faster cache, set `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ to a
fixed value.
.. autofunction:: _get_system_cache_directory
.. autofunction:: _get_readable_hash
.. automodule:: ampform.sympy._cache
"""
if cache_directory is None:
system_cache_dir = _get_system_cache_directory()
system_cache_dir = get_system_cache_directory()
cache_directory = abspath(f"{system_cache_dir}/ampform")
h = _get_readable_hash(unevaluated_expr)
h = get_readable_hash(unevaluated_expr)
filename = f"{cache_directory}/{h}.pkl"
os.makedirs(dirname(filename), exist_ok=True)
if os.path.exists(filename):
Expand All @@ -374,75 +371,3 @@ def perform_cached_doit(
with open(filename, "wb") as f:
pickle.dump(unfolded_expr, f)
return unfolded_expr


def _get_system_cache_directory() -> str:
r"""Return the system cache directory for the current platform.
>>> import sys, pytest
>>> if sys.platform.startswith("darwin"):
... assert _get_system_cache_directory().endswith("/Library/Caches")
>>> if sys.platform.startswith("linux"):
... assert _get_system_cache_directory().endswith("/.cache")
>>> if sys.platform.startswith("win"):
... assert _get_system_cache_directory().endswith(R"\AppData\Local")
"""
if sys.platform.startswith("linux"):
cache_directory = os.getenv("XDG_CACHE_HOME")
if cache_directory is not None:
return cache_directory
if sys.platform.startswith("darwin"): # macos
return os.path.expanduser("~/Library/Caches")
if sys.platform.startswith("win"):
cache_directory = os.getenv("LocalAppData") # noqa: SIM112
if cache_directory is not None:
return cache_directory
return os.path.expanduser("~/AppData/Local")
return os.path.expanduser("~/.cache")


def _get_readable_hash(obj, ignore_hash_seed: bool = False) -> str:
"""Get a human-readable hash of any hashable Python object.
The algorithm is fastest if `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ is set.
Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`.
Args:
obj: Any hashable object, mutable or immutable, to be hashed.
ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If
:code:`True`, the hash seed is ignored and the hash is computed with
:func:`hashlib.sha256`.
"""
python_hash_seed = _get_python_hash_seed()
if ignore_hash_seed or python_hash_seed is None:
b = _to_bytes(obj)
return hashlib.sha256(b).hexdigest()
return f"pythonhashseed-{python_hash_seed}{hash(obj):+}"


def _to_bytes(obj) -> bytes:
if isinstance(obj, sp.Expr):
# Using the str printer is slower and not necessarily unique,
# but pickle.dumps() does not always result in the same bytes stream.
_warn_about_unsafe_hash()
return str(obj).encode()
return pickle.dumps(obj)


def _get_python_hash_seed() -> int | None:
python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
if python_hash_seed is not None and python_hash_seed.isdigit():
return int(python_hash_seed)
return None


@functools.lru_cache(maxsize=None) # warn once
def _warn_about_unsafe_hash():
message = """
PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)
87 changes: 87 additions & 0 deletions src/ampform/sympy/_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Helper functions for :func:`.perform_cached_doit`."""

from __future__ import annotations

import functools
import hashlib
import logging
import os
import pickle # noqa: S403
import sys
from textwrap import dedent

import sympy as sp

_LOGGER = logging.getLogger(__name__)


def get_system_cache_directory() -> str:
r"""Return the system cache directory for the current platform.
>>> import sys, pytest
>>> if sys.platform.startswith("darwin"):
... assert get_system_cache_directory().endswith("/Library/Caches")
>>> if sys.platform.startswith("linux"):
... assert get_system_cache_directory().endswith("/.cache")
>>> if sys.platform.startswith("win"):
... assert get_system_cache_directory().endswith(R"\AppData\Local")
"""
if sys.platform.startswith("linux"):
cache_directory = os.getenv("XDG_CACHE_HOME")
if cache_directory is not None:
return cache_directory
if sys.platform.startswith("darwin"): # macos
return os.path.expanduser("~/Library/Caches")
if sys.platform.startswith("win"):
cache_directory = os.getenv("LocalAppData") # noqa: SIM112
if cache_directory is not None:
return cache_directory
return os.path.expanduser("~/AppData/Local")
return os.path.expanduser("~/.cache")


def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str:
"""Get a human-readable hash of any hashable Python object.
The algorithm is fastest if `PYTHONHASHSEED
<https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ is set.
Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`.
Args:
obj: Any hashable object, mutable or immutable, to be hashed.
ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If
:code:`True`, the hash seed is ignored and the hash is computed with
:func:`hashlib.sha256`.
"""
python_hash_seed = _get_python_hash_seed()
if ignore_hash_seed or python_hash_seed is None:
b = _to_bytes(obj)
return hashlib.sha256(b).hexdigest()
return f"pythonhashseed-{python_hash_seed}{hash(obj):+}"


def _to_bytes(obj) -> bytes:
if isinstance(obj, sp.Expr):
# Using the str printer is slower and not necessarily unique,
# but pickle.dumps() does not always result in the same bytes stream.
_warn_about_unsafe_hash()
return str(obj).encode()
return pickle.dumps(obj)


def _get_python_hash_seed() -> int | None:
python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
if python_hash_seed is not None and python_hash_seed.isdigit():
return int(python_hash_seed)
return None


@functools.lru_cache(maxsize=None) # warn once
def _warn_about_unsafe_hash():
message = """
PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)
8 changes: 4 additions & 4 deletions tests/sympy/test_caching.py → tests/sympy/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sympy as sp

from ampform.dynamics import EnergyDependentWidth
from ampform.sympy import _get_readable_hash, _warn_about_unsafe_hash
from ampform.sympy._cache import _warn_about_unsafe_hash, get_readable_hash

if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_get_readable_hash(assumptions, expected_hashes, caplog: LogCaptureFixtu
caplog.set_level(logging.WARNING)
x, y = sp.symbols("x y", **assumptions)
expr = x**2 + y
h_str = _get_readable_hash(expr)
h_str = get_readable_hash(expr)
python_hash_seed = os.environ.get("PYTHONHASHSEED")
if python_hash_seed is None:
assert h_str[:7] == "bbc9833"
Expand All @@ -88,7 +88,7 @@ def test_get_readable_hash_energy_dependent_width():
angular_momentum=angular_momentum,
meson_radius=d,
)
h = _get_readable_hash(expr)
h = get_readable_hash(expr)
python_hash_seed = os.environ.get("PYTHONHASHSEED")
if python_hash_seed is None:
pytest.skip("PYTHONHASHSEED has not been set")
Expand Down Expand Up @@ -124,4 +124,4 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]):
"canonical-helicity": "pythonhashseed-0-8505502895987205495",
"helicity": "pythonhashseed-0-1430245260241162669",
}[formalism]
assert _get_readable_hash(model.expression) == expected_hash
assert get_readable_hash(model.expression) == expected_hash

0 comments on commit 55a15eb

Please sign in to comment.