Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: implement perform_cached_doit() #333

Merged
merged 13 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 55 additions & 54 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,65 +1,66 @@
[flake8]
application-import-names =
ampform
ampform
filename =
./docs/*.py
./src/*.py
./tests/*.py
./docs/*.py
./src/*.py
./tests/*.py
exclude =
**/__pycache__
**/_build
*.pyi
/typings/**
**/__pycache__
**/_build
*.pyi
/typings/**
ignore =
# False positive with attribute docstrings
B018
# https://github.com/psf/black#slices
E203
# allowed by black
E231
# https://github.com/psf/black#line-length
E501
# should be possible to use {} in latex strings
FS003
# block quote ends without a blank line (black formatting)
RST201
# missing pygments
RST299
# unexpected indentation (related to google style docstring)
RST301
# false-positive error in math directive
RST307
# enforce type ignore with mypy error codes (combined --extend-select=TI100)
TI1
# https://github.com/psf/black#line-breaks--binary-operators
W503
# False positive with attribute docstrings
B018
# https://github.com/psf/black#slices
E203
# allowed by black
E231
# https://github.com/psf/black#line-length
E501
# should be possible to use {} in latex strings
FS003
# block quote ends without a blank line (black formatting)
RST201
# missing pygments
RST299
# unexpected indentation (related to google style docstring)
RST301
# false-positive error in math directive
RST307
# enforce type ignore with mypy error codes (combined --extend-select=TI100)
TI1
# https://github.com/psf/black#line-breaks--binary-operators
W503
extend-select =
TI100
TI100
per-file-ignores =
# unused imports for backward compatibility
src/ampform/dynamics/__init__.py:F401
# λ symbols for DPD paper
src/ampform/helicity/align/dpd.py:N806
# unused imports for backward compatibility
src/ampform/dynamics/__init__.py:F401
# λ symbols for DPD paper
src/ampform/helicity/align/dpd.py:N806
tests/sympy/test_caching.py:C408
radon-max-cc = 8
radon-no-assert = True
rst-roles =
attr
cite
class
doc
download
eq
file
func
meth
mod
pdg-review
ref
term
attr
cite
class
doc
download
eq
file
func
meth
mod
pdg-review
ref
term
rst-directives =
autolink-preface
automethod
deprecated
envvar
exception
seealso
autolink-preface
automethod
deprecated
envvar
exception
seealso
2 changes: 2 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
name: pytest
env:
PYTHONHASHSEED: "0"

on:
push:
Expand Down
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
],
"python.analysis.autoImportCompletions": false,
"python.analysis.diagnosticMode": "workspace",
"python.analysis.typeCheckingMode": "strict",
"python.formatting.provider": "black",
"python.linting.banditEnabled": false,
"python.linting.enabled": true,
Expand Down
2 changes: 1 addition & 1 deletion docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def _graphviz_to_image( # pylint: disable=too-many-arguments
options = {}
global _GRAPHVIZ_COUNTER # pylint: disable=global-statement
output_file = f"graphviz_{_GRAPHVIZ_COUNTER}"
_GRAPHVIZ_COUNTER += 1
_GRAPHVIZ_COUNTER += 1 # pyright: reportConstantRedefinition=false
graphviz.Source(dot).render(f"{_IMAGE_DIR}/{output_file}", format=format)
restructuredtext = "\n"
if label:
Expand Down
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

import requests

# pyright: reportConstantRedefinition=false
# pyright: reportMissingImports=false
# pyright: reportUntypedBaseClass=false
# pyright: reportUntypedFunctionDecorator=false
from pybtex.database import Entry
from pybtex.plugin import register_plugin
from pybtex.richtext import Tag, Text
Expand Down
38 changes: 37 additions & 1 deletion docs/usage/amplitude.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,34 @@
" model = pickle.load(stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cached expression 'unfolding'\n",
"\n",
"Amplitude model expressions can be extremely large. AmpForm can formulate such expressions relatively fast, but {mod}`sympy` has to 'unfold' these expressions with {meth}`~sympy.core.basic.Basic.doit`, which can take a long time. AmpForm provides a function that can cache the 'unfolded' expression to disk, so that the expression unfolding runs faster upon the next run."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ampform.sympy import perform_cached_doit\n",
"\n",
"full_expression = perform_cached_doit(model.expression)\n",
"sp.count_ops(full_expression)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See {func}`.perform_cached_doit` for some tips on how to improve performance."
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -1035,8 +1063,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.8.12"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ dependencies:
- |
-c .constraints/py3.8.txt
-e .[dev]
variables:
PYTHONHASHSEED: 0
16 changes: 15 additions & 1 deletion pyrightconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@
"exclude": [".git", ".tox", "docs/_build", "docs/adr"],
"include": ["docs", "src", "tests"],
"reportGeneralTypeIssues": false,
"reportIncompatibleMethodOverride": false,
"reportMissingParameterType": false,
"reportMissingTypeArgument": false,
"reportMissingTypeStubs": false,
"reportOverlappingOverload": false,
"reportPrivateImportUsage": false,
"reportPrivateUsage": false,
"reportUnboundVariable": false,
"reportUnknownArgumentType": false,
"reportUnknownMemberType": false,
"reportUnknownParameterType": false,
"reportUnknownVariableType": false,
"reportUnnecessaryComparison": false,
"reportUnnecessaryContains": false,
"reportUnnecessaryIsInstance": false,
"reportUnusedClass": true,
"reportUnusedFunction": true,
"reportUnusedImport": true,
"reportUnusedVariable": true
"reportUnusedVariable": true,
"typeCheckingMode": "strict"
}
3 changes: 2 additions & 1 deletion src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:

def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
return (*self.args, self.phsp_factor, self._name)
# phsp_factor is converted to string because of unstable hash for classes
return (*super()._hashable_content(), str(self.phsp_factor))

def evaluate(self) -> sp.Expr:
s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius = self.args
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def unfold_poolsums(expr: sp.Expr) -> sp.Expr:

intensity = self.intensity.evaluate()
intensity = unfold_poolsums(intensity)
return intensity.subs(self.amplitudes)
return intensity.xreplace(self.amplitudes)

def rename_symbols( # noqa: R701
self, renames: Iterable[tuple[str, str]] | Mapping[str, str]
Expand Down
84 changes: 84 additions & 0 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
from __future__ import annotations

import functools
import hashlib
import itertools
import logging
import os
import pickle
import re
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from typing import Callable, Iterable, Sequence, SupportsFloat, TypeVar

import sympy as sp
Expand All @@ -16,6 +22,8 @@
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.precedence import PRECEDENCE

_LOGGER = logging.getLogger(__name__)


class UnevaluatedExpression(sp.Expr):
"""Base class for expression classes with an :meth:`evaluate` method.
Expand Down Expand Up @@ -95,6 +103,11 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:
kwargs = {"name": self._name}
return args, kwargs

def _hashable_content(self) -> tuple:
    """Return the content that SymPy uses to hash and compare this expression.

    Extends the content of the parent class with the string form of ``_name``.
    """
    # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
    parent_content = super()._hashable_content()
    # The name is stringified because the hash of None is not stable.
    name_as_string = str(self._name)
    return (*parent_content, name_as_string)

@abstractmethod
def evaluate(self) -> sp.Expr:
"""Evaluate and 'unfold' this `UnevaluatedExpression` by one level.
Expand Down Expand Up @@ -490,3 +503,74 @@ def determine_indices(symbol: sp.Basic) -> list[int]:
except SyntaxError:
return []
return list(indices)


def perform_cached_doit(
    unevaluated_expr: sp.Expr, directory: str | None = None
) -> sp.Expr:
    """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk.

    The cached result is fetched from disk if the hash of the original expression is the
    same as the hash embedded in the filename.

    Args:
        unevaluated_expr: A `sympy.Expr <sympy.core.expr.Expr>` on which to call
            :meth:`~sympy.core.basic.Basic.doit`.
        directory: The directory in which to cache the result. If `None`, the cache
            directory will be put under the home directory.

    Returns:
        The unfolded expression, either unpickled from the cache file or freshly
        computed with :meth:`~sympy.core.basic.Basic.doit`.

    .. tip:: For a faster cache, set `PYTHONHASHSEED
        <https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ to a
        fixed value.

    .. warning:: Cache files are read with :func:`pickle.load`, so only point
        :code:`directory` to a location that you trust.
    """
    if directory is None:
        home_directory = expanduser("~")
        directory = abspath(f"{home_directory}/.sympy-cache")
    h = get_readable_hash(unevaluated_expr)
    filename = f"{directory}/{h}.pkl"
    os.makedirs(dirname(filename), exist_ok=True)
    if os.path.exists(filename):
        # NOTE: unpickling executes arbitrary code; the cache directory has to be
        # trusted (by default it lives under the user's home directory).
        with open(filename, "rb") as stream:
            return pickle.load(stream)
    _LOGGER.warning(
        "Cached expression file %s not found, performing doit()...", filename
    )
    unfolded_expr = unevaluated_expr.doit()
    # Write to a temporary file first and move it into place afterwards, so that an
    # interrupted or failed dump never leaves a truncated cache file that a later
    # call would try to unpickle.
    temporary_filename = f"{filename}.{os.getpid()}.tmp"
    try:
        with open(temporary_filename, "wb") as stream:
            pickle.dump(unfolded_expr, stream)
        os.replace(temporary_filename, filename)
    finally:
        # Only still present if dumping or moving failed.
        if os.path.exists(temporary_filename):
            os.remove(temporary_filename)
    return unfolded_expr


def get_readable_hash(obj) -> str:
    """Create a hash string for an object that can be embedded in a filename.

    If :code:`PYTHONHASHSEED` is set to a number, the built-in :func:`hash` is used
    and the seed is embedded in the returned string. Otherwise, the SHA-256 hex
    digest of the byte representation of the object is returned.
    """
    seed = _get_python_hash_seed()
    if seed is None:
        # No fixed seed: fall back to hashing a byte representation of the object.
        digest = hashlib.sha256(_to_bytes(obj))
        return digest.hexdigest()
    signed_hash = f"{hash(obj):+}"
    return f"pythonhashseed-{seed}{signed_hash}"


def _to_bytes(obj) -> bytes:
    """Convert an object to bytes that can be fed to a hash function.

    SymPy expressions are encoded through their string representation (after
    emitting a one-time warning); any other object is pickled.
    """
    if not isinstance(obj, sp.Expr):
        return pickle.dumps(obj)
    # The str printer is slower and does not necessarily give a unique result,
    # but pickle.dumps() does not always produce the same bytes stream.
    _warn_about_unsafe_hash()
    expression_string = str(obj)
    return expression_string.encode()


def _get_python_hash_seed() -> int | None:
    """Read the :code:`PYTHONHASHSEED` environment variable as an integer.

    Returns:
        The seed as an `int` if the variable consists of decimal digits only, and
        `None` if it is unset, empty, or has any other value (such as
        :code:`random`).
    """
    python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
    # isdecimal() instead of isdigit(): the latter also accepts characters such as
    # superscript digits, for which int() raises a ValueError.
    if python_hash_seed.isdecimal():
        return int(python_hash_seed)
    return None


@functools.lru_cache(maxsize=None)  # warn once
def _warn_about_unsafe_hash():
    """Log a warning that :code:`PYTHONHASHSEED` has not been set.

    The result is memoized, so the warning is logged on the first call only.
    """
    multiline_text = """
    PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
    set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
    See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
    """
    # Collapse the indented, multi-line text into a single line.
    lines = dedent(multiline_text).strip().splitlines()
    _LOGGER.warning(" ".join(lines))
1 change: 1 addition & 0 deletions tests/dynamics/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_breit_wigner_with_energy_dependent_width(

builder.form_factor = True
bw_with_ff, parameters = builder(particle, variable_set)
# pyright: reportConstantRedefinition=false
L = variable_set.angular_momentum # noqa: N806
form_factor = formulate_form_factor(
s, m1, m2, angular_momentum=L, meson_radius=d
Expand Down
Loading