diff --git a/.cspell.json b/.cspell.json index 83f931cec..53c4c1fac 100644 --- a/.cspell.json +++ b/.cspell.json @@ -42,6 +42,7 @@ "Dalitzplot", "MAINT", "Minkowski", + "PYDEVD", "adrs", "aitchison", "arange", @@ -80,9 +81,12 @@ "dtype", "dummified", "dummifies", + "dummify", "einsum", "elif", "epem", + "epsabs", + "epsrel", "eqnarray", "eval", "evalf", @@ -203,6 +207,7 @@ "xlim", "xreplace", "xticks", + "xytext", "yaxis", "ylabel", "ylim", @@ -273,6 +278,7 @@ "tensorwaves", "toctree", "topness", + "unevaluatable", "unitarity", "venv", "weisskopf", diff --git a/docs/usage/sympy.ipynb b/docs/usage/sympy.ipynb index 80369b35d..bc35b5cbf 100644 --- a/docs/usage/sympy.ipynb +++ b/docs/usage/sympy.ipynb @@ -216,6 +216,201 @@ { "cell_type": "markdown", "metadata": {}, + "source": [ + "## Numerical integrals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In hadron physics and high-energy physics, it often happens that models contain integrals that do not have an analytical solution.. They can arise in theoretical models, complex scattering problems, or in the analysis of experimental data. In such cases, we need to resort to numerical integrations.\n", + "\n", + "SymPy provides the [`sympy.Integral`](https://docs.sympy.org/latest/modules/integrals/integrals.html#sympy.integrals.integrals.Integral) class, but this does not give us control over whether or not we want to avoid integrating the class analytically. An example of such an analytically unsolvable integral is shown below. Note that the integral does not evaluate despite the `doit()` call." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sympy as sp\n", + "\n", + "x, a, b = sp.symbols(\"x a b\")\n", + "p = sp.Symbol(\"p\", positive=True)\n", + "integral_expr = sp.Integral(sp.exp(x) / (x**p + 1), (x, a, b))\n", + "integral_expr.doit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For amplitude models that contain such integrals that should not be solved analytically, AmpForm provides the {class}`.UnevaluatableIntegral` class. It functions in the same way as [`sympy.Integral`](https://docs.sympy.org/latest/modules/integrals/integrals.html#sympy.integrals.integrals.Integral), but prevents the class from evaluating at all, even if the integral can be solved analytically." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from ampform.sympy import UnevaluatableIntegral\n", + "\n", + "UnevaluatableIntegral(x**p, (x, a, b)).doit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sp.Integral(x**p, (x, a, b)).doit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This allows {class}`.UnevaluatableIntegral` to serve as a placeholder in expression trees that we call `doit` on when lambdifying to a numerical function. The resulting numerical function takes **complex-valued** and **multidimensional arrays** as function arguments.\n", + "\n", + "In the following, we see an example where the parameter $p$ inside the integral gets an array as input." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "integral_expr = UnevaluatableIntegral(sp.exp(x) / (x**p + 1), (x, a, b))\n", + "integral_func = sp.lambdify(args=[p, a, b], expr=integral_expr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "a_val = 1.2\n", + "b_val = 3.6\n", + "p_array = np.array([0.4, 0.6, 0.8])\n", + "\n", + "areas = integral_func(p_array, a_val, b_val)\n", + "areas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-input", + "scroll-input" + ] + }, + "outputs": [], + "source": [ + "%config InlineBackend.figure_formats = ['svg']\n", + "import matplotlib.pyplot as plt\n", + "\n", + "x_area = np.linspace(a_val, b_val, num=100)\n", + "x_line = np.linspace(0, 4, num=100)\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.set_xlabel(\"$x$\")\n", + "ax.set_ylabel(\"$x^p$\")\n", + "\n", + "for i, p_val in enumerate(p_array):\n", + " ax.plot(x_line, x_line**p_val, label=f\"$p={p_val}$\", c=f\"C{i}\")\n", + " ax.fill_between(x_area, x_area**p_val, alpha=(0.7 - i * 0.2), color=\"C0\")\n", + "\n", + "ax.text(\n", + " x=(a_val + b_val) / 2,\n", + " y=((a_val ** p_array[0] + b_val ** p_array[0]) / 2) * 0.5,\n", + " s=\"Area\",\n", + " horizontalalignment=\"center\",\n", + " verticalalignment=\"center\",\n", + ")\n", + "text_kwargs = dict(ha=\"center\", textcoords=\"offset points\", xytext=(0, -15))\n", + "ax.annotate(\"a\", (a_val, 0.08), **text_kwargs)\n", + "ax.annotate(\"b\", (b_val, 0.08), **text_kwargs)\n", + "\n", + "ax.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "The arrays can be complex-valued as well. This is particularly useful when calculating dispersion integrals (see **[TR-003](https://compwa.github.io/report/003#general-dispersion-integral)**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "integral_func(\n", + " p=np.array([1.5 - 8.6j, -4.6 + 5.5j]),\n", + " a=a_val,\n", + " b=b_val,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "source": [ "## Summations" ] @@ -275,7 +470,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.7" } }, "nbformat": 4, diff --git a/environment.yml b/environment.yml index fa1316535..8fb56b271 100644 --- a/environment.yml +++ b/environment.yml @@ -8,5 +8,6 @@ dependencies: - pip: - -c .constraints/py3.11.txt -e .[dev] variables: + PYDEVD_DISABLE_FILE_VALIDATION: 1 PRETTIER_LEGACY_CLI: "1" PYTHONHASHSEED: 0 diff --git a/pyproject.toml b/pyproject.toml index df518f736..0a346585f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,10 @@ name = "ampform" requires-python = ">=3.7" [project.optional-dependencies] -all = ["ampform[viz]"] +all = [ + "ampform[scipy]", + "ampform[viz]", +] dev = [ "ampform[all]", "ampform[doc]", @@ -62,7 +65,7 @@ dev = [ ] doc = [ "Sphinx >=3", - "ampform[viz]", + "ampform[all]", "black", "ipympl", "matplotlib", @@ -102,6 +105,7 @@ mypy = [ "mypy >=0.730", "sphinx-api-relink >=0.0.3", ] +scipy = ["scipy"] sty = [ "ampform[format]", "ampform[lint]", @@ -109,6 +113,7 @@ sty = [ "pre-commit >=1.4.0", ] test = [ + "ampform[scipy]", "black", "ipywidgets", # symplot "nbmake", @@ -184,6 +189,10 @@ warn_unused_configs = true ignore_missing_imports = true module = ["graphviz.*"] +[[tool.mypy.overrides]] +ignore_missing_imports = true +module = ["scipy.*"] + [[tool.mypy.overrides]] ignore_missing_imports = true module = ["ipywidgets.*"] diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index e0c9db2de..169b6baf7 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -19,6 +19,7 @@ import os import pickle import re +import warnings from abc import abstractmethod from os.path import abspath, dirname, expanduser from textwrap import dedent @@ -27,6 +28,7 @@ import sympy as sp from sympy.printing.conventions import split_super_sub from sympy.printing.precedence import PRECEDENCE +from sympy.printing.pycode import _unpack_integral_limits from ._decorator import ( ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] @@ -350,3 +352,47 @@ def _warn_about_unsafe_hash(): """ message = dedent(message).replace("\n", " ").strip() _LOGGER.warning(message) + + +class UnevaluatableIntegral(sp.Integral): + abs_tolerance = 1e-5 + rel_tolerance = 1e-5 + limit = 50 + dummify = True + + def doit(self, **hints): + args = [arg.doit(**hints) for arg in self.args] + return self.func(*args) + + def _numpycode(self, printer, *args): + _warn_if_scipy_not_installed() + integration_vars, limits = _unpack_integral_limits(self) + if len(limits) != 1 or len(integration_vars) != 1: + msg = f"Cannot handle {len(limits)}-dimensional integrals" + raise ValueError(msg) + x = integration_vars[0] + a, b = limits[0] + expr = self.args[0] + if self.dummify: + dummy = sp.Dummy() + expr = expr.xreplace({x: dummy}) + x = dummy + integrate_func = "quad_vec" + printer.module_imports["scipy.integrate"].add(integrate_func) + return ( + f"{integrate_func}(lambda {printer._print(x)}: {printer._print(expr)}," + f" {printer._print(a)}, {printer._print(b)}," + f" epsabs={self.abs_tolerance}, epsrel={self.abs_tolerance}," + f" limit={self.limit})[0]" + ) + + +def _warn_if_scipy_not_installed() -> None: + try: + import scipy # noqa: F401 # pyright: ignore[reportUnusedImport, reportMissingImports] + except ImportError: + warnings.warn( + "Scipy is not installed. Install with 'pip install scipy' or with 'pip" + " install ampform[scipy]'", + stacklevel=1, + ) diff --git a/tests/sympy/test_integral.py b/tests/sympy/test_integral.py new file mode 100644 index 000000000..8d66423b9 --- /dev/null +++ b/tests/sympy/test_integral.py @@ -0,0 +1,31 @@ +import numpy as np +import pytest +import sympy as sp + +from ampform.sympy import UnevaluatableIntegral + + +class TestUnevaluatableIntegral: + def test_real_value_function(self): + x = sp.symbols("x") + integral_expr = UnevaluatableIntegral(x**2, (x, 1, 3)) + func = sp.lambdify(args=[], expr=integral_expr) + assert func() == 26 / 3 + + @pytest.mark.parametrize( + "p_value,expected", + [ + (2, 26 / 3), + (1, 4), + (1j, (1 / 2 - 1j / 2) * (-1 + 3 ** (1 + 1j))), + ( + np.array([0, 0.5, 1, 2]), + np.array([2, 2 * 3 ** (1 / 2) - 2 / 3, 4, 8 + 2 / 3]), + ), + ], + ) + def test_vectorized_parameter_function(self, p_value, expected): + x, p = sp.symbols("x,p") + integral_expr = UnevaluatableIntegral(x**p, (x, 1, 3)) + func = sp.lambdify(args=[p], expr=integral_expr) + assert pytest.approx(func(p=p_value)) == expected