diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index babfe0c8a..96d80e26c 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -14,13 +14,12 @@ import itertools import logging -import os import pickle # noqa: S403 import re import sys import warnings from abc import abstractmethod -from os.path import abspath, dirname +from pathlib import Path from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat import sympy as sp @@ -43,6 +42,10 @@ make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401 ) +if sys.version_info < (3, 8): + from importlib_metadata import version +else: + from importlib.metadata import version if sys.version_info < (3, 12): from typing_extensions import override else: @@ -335,7 +338,7 @@ def _warn_if_scipy_not_installed() -> None: def perform_cached_doit( - unevaluated_expr: sp.Expr, cache_directory: str | None = None + unevaluated_expr: sp.Expr, cache_directory: Path | str | None = None ) -> sp.Expr: """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk. @@ -357,11 +360,14 @@ def perform_cached_doit( """ if cache_directory is None: system_cache_dir = get_system_cache_directory() - cache_directory = abspath(f"{system_cache_dir}/ampform") + sympy_version = version("sympy") + cache_directory = Path(system_cache_dir) / "ampform" / f"sympy-v{sympy_version}" + if not isinstance(cache_directory, Path): + cache_directory = Path(cache_directory) + cache_directory.mkdir(exist_ok=True, parents=True) h = get_readable_hash(unevaluated_expr) - filename = f"{cache_directory}/{h}.pkl" - os.makedirs(dirname(filename), exist_ok=True) - if os.path.exists(filename): + filename = cache_directory / f"{h}.pkl" + if filename.exists(): with open(filename, "rb") as f: return pickle.load(f) # noqa: S301 _LOGGER.warning( diff --git a/src/ampform/sympy/_cache.py b/src/ampform/sympy/_cache.py index 421f4d89c..4eedf83b6 100644 --- a/src/ampform/sympy/_cache.py +++ b/src/ampform/sympy/_cache.py @@ -18,7 +18,7 @@ def get_system_cache_directory() -> str: r"""Return the system cache directory for the current platform. - >>> import sys, pytest + >>> import sys >>> if sys.platform.startswith("darwin"): ... assert get_system_cache_directory().endswith("/Library/Caches") >>> if sys.platform.startswith("linux"):