Skip to content

Commit

Permalink
FEAT: add unevaluated integral class (#394)
Browse files Browse the repository at this point in the history
See https://compwa.github.io/report/016#sympy-integral

* DOC: explain numerical integral examples
* DX: avoid IPython debug warnings
* ENH: emit warning if SciPy is not installed
  • Loading branch information
shenvitor authored and redeboer committed Feb 12, 2024
1 parent 84aa469 commit a761bc8
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"Dalitzplot",
"MAINT",
"Minkowski",
"PYDEVD",
"adrs",
"aitchison",
"arange",
Expand Down Expand Up @@ -80,9 +81,12 @@
"dtype",
"dummified",
"dummifies",
"dummify",
"einsum",
"elif",
"epem",
"epsabs",
"epsrel",
"eqnarray",
"eval",
"evalf",
Expand Down Expand Up @@ -202,6 +206,7 @@
"xlim",
"xreplace",
"xticks",
"xytext",
"yaxis",
"ylabel",
"ylim",
Expand Down Expand Up @@ -272,6 +277,7 @@
"tensorwaves",
"toctree",
"topness",
"unevaluatable",
"unitarity",
"venv",
"weisskopf",
Expand Down
197 changes: 196 additions & 1 deletion docs/usage/sympy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,201 @@
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Numerical integrals"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In hadron physics and high-energy physics, it often happens that models contain integrals that do not have an analytical solution.. They can arise in theoretical models, complex scattering problems, or in the analysis of experimental data. In such cases, we need to resort to numerical integrations.\n",
"\n",
"SymPy provides the [`sympy.Integral`](https://docs.sympy.org/latest/modules/integrals/integrals.html#sympy.integrals.integrals.Integral) class, but this does not give us control over whether or not we want to avoid integrating the class analytically. An example of such an analytically unsolvable integral is shown below. Note that the integral does not evaluate despite the `doit()` call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sympy as sp\n",
"\n",
"x, a, b = sp.symbols(\"x a b\")\n",
"p = sp.Symbol(\"p\", positive=True)\n",
"integral_expr = sp.Integral(sp.exp(x) / (x**p + 1), (x, a, b))\n",
"integral_expr.doit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For amplitude models that contain such integrals that should not be solved analytically, AmpForm provides the {class}`.UnevaluatableIntegral` class. It functions in the same way as [`sympy.Integral`](https://docs.sympy.org/latest/modules/integrals/integrals.html#sympy.integrals.integrals.Integral), but prevents the class from evaluating at all, even if the integral can be solved analytically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"from ampform.sympy import UnevaluatableIntegral\n",
"\n",
"UnevaluatableIntegral(x**p, (x, a, b)).doit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sp.Integral(x**p, (x, a, b)).doit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This allows {class}`.UnevaluatableIntegral` to serve as a placeholder in expression trees that we call `doit` on when lambdifying to a numerical function. The resulting numerical function takes **complex-valued** and **multidimensional arrays** as function arguments.\n",
"\n",
"In the following, we see an example where the parameter $p$ inside the integral gets an array as input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"integral_expr = UnevaluatableIntegral(sp.exp(x) / (x**p + 1), (x, a, b))\n",
"integral_func = sp.lambdify(args=[p, a, b], expr=integral_expr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"a_val = 1.2\n",
"b_val = 3.6\n",
"p_array = np.array([0.4, 0.6, 0.8])\n",
"\n",
"areas = integral_func(p_array, a_val, b_val)\n",
"areas"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"jupyter": {
"source_hidden": true
},
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-input",
"scroll-input"
]
},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['svg']\n",
"import matplotlib.pyplot as plt\n",
"\n",
"x_area = np.linspace(a_val, b_val, num=100)\n",
"x_line = np.linspace(0, 4, num=100)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.set_xlabel(\"$x$\")\n",
"ax.set_ylabel(\"$x^p$\")\n",
"\n",
"for i, p_val in enumerate(p_array):\n",
" ax.plot(x_line, x_line**p_val, label=f\"$p={p_val}$\", c=f\"C{i}\")\n",
" ax.fill_between(x_area, x_area**p_val, alpha=(0.7 - i * 0.2), color=\"C0\")\n",
"\n",
"ax.text(\n",
" x=(a_val + b_val) / 2,\n",
" y=((a_val ** p_array[0] + b_val ** p_array[0]) / 2) * 0.5,\n",
" s=\"Area\",\n",
" horizontalalignment=\"center\",\n",
" verticalalignment=\"center\",\n",
")\n",
"text_kwargs = dict(ha=\"center\", textcoords=\"offset points\", xytext=(0, -15))\n",
"ax.annotate(\"a\", (a_val, 0.08), **text_kwargs)\n",
"ax.annotate(\"b\", (b_val, 0.08), **text_kwargs)\n",
"\n",
"ax.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"The arrays can be complex-valued as well. This is particularly useful when calculating dispersion integrals (see **[TR-003](https://compwa.github.io/report/003#general-dispersion-integral)**)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"integral_func(\n",
" p=np.array([1.5 - 8.6j, -4.6 + 5.5j]),\n",
" a=a_val,\n",
" b=b_val,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"## Summations"
]
Expand Down Expand Up @@ -275,7 +470,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ dependencies:
- pip:
- -c .constraints/py3.11.txt -e .[dev]
variables:
PYDEVD_DISABLE_FILE_VALIDATION: 1
PRETTIER_LEGACY_CLI: "1"
PYTHONHASHSEED: 0
13 changes: 11 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ name = "ampform"
requires-python = ">=3.7"

[project.optional-dependencies]
all = ["ampform[viz]"]
all = [
"ampform[scipy]",
"ampform[viz]",
]
dev = [
"ampform[all]",
"ampform[doc]",
Expand All @@ -62,7 +65,7 @@ dev = [
]
doc = [
"Sphinx >=3",
"ampform[viz]",
"ampform[all]",
"black",
"ipympl",
"matplotlib",
Expand Down Expand Up @@ -102,13 +105,15 @@ mypy = [
"mypy >=0.730",
"sphinx-api-relink >=0.0.3",
]
scipy = ["scipy"]
sty = [
"ampform[format]",
"ampform[lint]",
"ampform[test]", # for pytest type hints
"pre-commit >=1.4.0",
]
test = [
"ampform[scipy]",
"black",
"ipywidgets", # symplot
"nbmake",
Expand Down Expand Up @@ -184,6 +189,10 @@ warn_unused_configs = true
ignore_missing_imports = true
module = ["graphviz.*"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["scipy.*"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["ipywidgets.*"]
Expand Down
46 changes: 46 additions & 0 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import pickle
import re
import warnings
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
Expand All @@ -27,6 +28,7 @@
import sympy as sp
from sympy.printing.conventions import split_super_sub
from sympy.printing.precedence import PRECEDENCE
from sympy.printing.pycode import _unpack_integral_limits

from ._decorator import (
ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport]
Expand Down Expand Up @@ -350,3 +352,47 @@ def _warn_about_unsafe_hash():
"""
message = dedent(message).replace("\n", " ").strip()
_LOGGER.warning(message)


class UnevaluatableIntegral(sp.Integral):
abs_tolerance = 1e-5
rel_tolerance = 1e-5
limit = 50
dummify = True

def doit(self, **hints):
args = [arg.doit(**hints) for arg in self.args]
return self.func(*args)

def _numpycode(self, printer, *args):
_warn_if_scipy_not_installed()
integration_vars, limits = _unpack_integral_limits(self)
if len(limits) != 1 or len(integration_vars) != 1:
msg = f"Cannot handle {len(limits)}-dimensional integrals"
raise ValueError(msg)
x = integration_vars[0]
a, b = limits[0]
expr = self.args[0]
if self.dummify:
dummy = sp.Dummy()
expr = expr.xreplace({x: dummy})
x = dummy
integrate_func = "quad_vec"
printer.module_imports["scipy.integrate"].add(integrate_func)
return (
f"{integrate_func}(lambda {printer._print(x)}: {printer._print(expr)},"
f" {printer._print(a)}, {printer._print(b)},"
f" epsabs={self.abs_tolerance}, epsrel={self.abs_tolerance},"
f" limit={self.limit})[0]"
)


def _warn_if_scipy_not_installed() -> None:
try:
import scipy # noqa: F401 # pyright: ignore[reportUnusedImport, reportMissingImports]
except ImportError:
warnings.warn(
"Scipy is not installed. Install with 'pip install scipy' or with 'pip"
" install ampform[scipy]'",
stacklevel=1,
)
31 changes: 31 additions & 0 deletions tests/sympy/test_integral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import pytest
import sympy as sp

from ampform.sympy import UnevaluatableIntegral


class TestUnevaluatableIntegral:
def test_real_value_function(self):
x = sp.symbols("x")
integral_expr = UnevaluatableIntegral(x**2, (x, 1, 3))
func = sp.lambdify(args=[], expr=integral_expr)
assert func() == 26 / 3

@pytest.mark.parametrize(
"p_value,expected",
[
(2, 26 / 3),
(1, 4),
(1j, (1 / 2 - 1j / 2) * (-1 + 3 ** (1 + 1j))),
(
np.array([0, 0.5, 1, 2]),
np.array([2, 2 * 3 ** (1 / 2) - 2 / 3, 4, 8 + 2 / 3]),
),
],
)
def test_vectorized_parameter_function(self, p_value, expected):
x, p = sp.symbols("x,p")
integral_expr = UnevaluatableIntegral(x**p, (x, 1, 3))
func = sp.lambdify(args=[p], expr=integral_expr)
assert pytest.approx(func(p=p_value)) == expected

0 comments on commit a761bc8

Please sign in to comment.