Skip to content

Commit

Permalink
MAINT: upgrade to SymPy v1.13 (#435)
Browse files Browse the repository at this point in the history
* DX: ignore missing types `sympy`
* FIX: adjust simplification code for SymPy v1.13
* MAINT: address `mypy` errors
  • Loading branch information
redeboer authored Aug 6, 2024
1 parent d88ef2e commit c2a6fd0
Show file tree
Hide file tree
Showing 18 changed files with 44 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .constraints/py3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ sphinxcontrib-serializinghtml==1.1.10
sqlalchemy==2.0.31
stack-data==0.6.3
starlette==0.37.2
sympy==1.12.1
sympy==1.13.1
tabulate==0.9.0
terminado==0.18.1
tinycss2==1.3.0
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ sphinxcontrib-serializinghtml==1.1.10
sqlalchemy==2.0.31
stack-data==0.6.3
starlette==0.37.2
sympy==1.12.1
sympy==1.13.1
tabulate==0.9.0
terminado==0.18.1
tinycss2==1.3.0
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ sphinxcontrib-serializinghtml==1.1.10
sqlalchemy==2.0.31
stack-data==0.6.3
starlette==0.37.2
sympy==1.12.1
sympy==1.13.1
tabulate==0.9.0
terminado==0.18.1
tinycss2==1.3.0
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sqlalchemy==2.0.31
stack-data==0.6.3
sympy==1.12.1
sympy==1.13.1
tabulate==0.9.0
terminado==0.18.1
tinycss2==1.3.0
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ sphinxcontrib-serializinghtml==1.1.10
sqlalchemy==2.0.31
stack-data==0.6.3
starlette==0.37.2
sympy==1.12.1
sympy==1.13.1
tabulate==0.9.0
terminado==0.18.1
tinycss2==1.3.0
Expand Down
18 changes: 9 additions & 9 deletions docs/usage/dynamics/k-matrix.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,9 @@
"outputs": [],
"source": [
"# reformulate terms\n",
"denominator, nominator = k_matrix.args\n",
"term1 = nominator.args[0] * denominator\n",
"term2 = nominator.args[1] * denominator\n",
"*rest, denominator, nominator = k_matrix.args\n",
"term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n",
"term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n",
"k_matrix = term1 + term2\n",
"k_matrix"
]
Expand Down Expand Up @@ -934,9 +934,9 @@
" sp.sqrt(rho): 1,\n",
" sp.conjugate(sp.sqrt(rho)): 1,\n",
"})\n",
"denominator, nominator = rel_k_matrix_2r.args\n",
"term1 = nominator.args[0] * denominator\n",
"term2 = nominator.args[1] * denominator\n",
"*rest, denominator, nominator = rel_k_matrix_2r.args\n",
"term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n",
"term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n",
"rel_k_matrix_2r = term1 + term2\n",
"rel_k_matrix_2r"
]
Expand Down Expand Up @@ -1081,9 +1081,9 @@
},
"outputs": [],
"source": [
"denominator, nominator = f_vector.args\n",
"term1 = nominator.args[0] * denominator\n",
"term2 = nominator.args[1] * denominator\n",
"*rest, denominator, nominator = f_vector.args\n",
"term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n",
"term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n",
"f_vector = term1 + term2\n",
"f_vector"
]
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ exclude = "_build"
show_error_codes = true
warn_unused_configs = true

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["sympy.*"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["graphviz.*"]
Expand Down
4 changes: 2 additions & 2 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class EnergyDependentWidth(sp.Expr):
m_b: Any
angular_momentum: Any
meson_radius: Any
phsp_factor: PhaseSpaceFactorProtocol = argument(
phsp_factor: PhaseSpaceFactorProtocol = argument( # type:ignore[assignment]
default=PhaseSpaceFactor, sympify=False
)
name: str | None = argument(default=None, sympify=False)
Expand Down Expand Up @@ -92,7 +92,7 @@ def relativistic_breit_wigner_with_ff( # noqa: PLR0917
m_b,
angular_momentum,
meson_radius,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
) -> sp.Expr:
"""Relativistic Breit-Wigner with `.FormFactor`.
Expand Down
8 changes: 4 additions & 4 deletions src/ampform/dynamics/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
phsp_factor: PhaseSpaceFactorProtocol | None = None,
) -> None:
if phsp_factor is None:
phsp_factor = PhaseSpaceFactor
phsp_factor = PhaseSpaceFactor # type:ignore[arg-type,assignment]
self.phsp_factor = phsp_factor
self.energy_dependent_width = energy_dependent_width
self.form_factor = form_factor
Expand Down Expand Up @@ -189,7 +189,7 @@ def __energy_dependent_breit_wigner(
m_b=m_b,
angular_momentum=angular_momentum,
meson_radius=meson_radius,
phsp_factor=self.phsp_factor,
phsp_factor=self.phsp_factor, # type:ignore[arg-type]
)
breit_wigner_expr = (res_mass * res_width) / (
res_mass**2 - s - mass_dependent_width * res_mass * sp.I
Expand Down Expand Up @@ -245,7 +245,7 @@ def __create_symbols(
create_relativistic_breit_wigner_with_ff = RelativisticBreitWignerBuilder(
energy_dependent_width=True,
form_factor=True,
phsp_factor=PhaseSpaceFactor,
phsp_factor=PhaseSpaceFactor, # type:ignore[arg-type]
).__call__
"""Create a `.relativistic_breit_wigner_with_ff` for a two-body decay.
Expand All @@ -256,7 +256,7 @@ def __create_symbols(
create_analytic_breit_wigner = RelativisticBreitWignerBuilder(
energy_dependent_width=True,
form_factor=True,
phsp_factor=EqualMassPhaseSpaceFactor,
phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type]
).__call__
"""Create a `.relativistic_breit_wigner_with_ff` with analytic continuation.
Expand Down
6 changes: 3 additions & 3 deletions src/ampform/dynamics/kmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def formulate( # type: ignore[override] # noqa: D417
n_poles,
parametrize: bool = True,
return_t_hat: bool = False,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
angular_momentum=0,
meson_radius=1,
) -> sp.MutableDenseMatrix:
Expand Down Expand Up @@ -116,7 +116,7 @@ def parametrization( # noqa: PLR0917
pole_id,
angular_momentum=0,
meson_radius=1,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
) -> sp.Expr:
def residue_function(pole_id, i) -> sp.Expr:
return residue_constant[pole_id, i] * sp.sqrt(
Expand Down Expand Up @@ -296,7 +296,7 @@ def formulate( # type: ignore[override] # noqa: D417
n_poles,
parametrize: bool = True,
return_f_hat: bool = False,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
angular_momentum=0,
meson_radius=1,
) -> sp.MutableDenseMatrix:
Expand Down
7 changes: 6 additions & 1 deletion src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Callable, Dict

import sympy as sp
Expand All @@ -17,6 +18,10 @@
)
from ampform.sympy.math import ComplexSqrt

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias
if TYPE_CHECKING:
from qrules.topology import Topology
from sympy.printing.latex import LatexPrinter
Expand Down Expand Up @@ -45,7 +50,7 @@ def create_four_momentum_symbol(index: int) -> FourMomentumSymbol:
It's best to create a `dict` of `.FourMomenta` with
:func:`create_four_momentum_symbols`.
"""
FourMomentumSymbol = ArraySymbol
FourMomentumSymbol: TypeAlias = ArraySymbol
r"""Array-`~sympy.core.symbol.Symbol` that represents an array of four-momenta.
The array is assumed to be of shape :math:`n\times 4` with :math:`n` the number of
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def free_symbols(self) -> set[sp.Basic]:
return super().free_symbols - {s for s, _ in self.indices}

@override
def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override]
def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[misc]
expr = self.evaluate()
if deep:
return expr.doit()
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
return expr.evaluate()
return expr

cls.__new__ = new_method # type: ignore[method-assign]
cls.__new__ = new_method # type: ignore[assignment]
cls.__getnewargs__ = _get_arguments # type: ignore[assignment,method-assign]
cls._hashable_content = _hashable_content_method # type: ignore[method-assign]
if non_sympy_fields:
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:
kwargs = {"name": self._name}
return args, kwargs

@override
@override # type:ignore[misc]
def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ComplexSqrt(NumPyPrintable):
@overload
def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc]
@overload
def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ...
def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... # type:ignore[misc]
@override
def __new__(cls, x, *args, **kwargs):
x = sp.sympify(x)
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamics/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_pickle():
m_b=m_a,
angular_momentum=0,
meson_radius=1,
phsp_factor=EqualMassPhaseSpaceFactor,
phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type]
name="Gamma_1",
)
pickled_obj = pickle.dumps(expr)
Expand Down
4 changes: 2 additions & 2 deletions tests/dynamics/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_init():
m_b=m_b,
angular_momentum=angular_momentum,
meson_radius=d,
phsp_factor=EqualMassPhaseSpaceFactor,
phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type]
name="Gamma_1",
)
assert width.phsp_factor is EqualMassPhaseSpaceFactor
Expand All @@ -70,7 +70,7 @@ def test_doit_and_subs(self, method: str):
m_b=m_a,
angular_momentum=0,
meson_radius=1,
phsp_factor=PhaseSpaceFactorSWave,
phsp_factor=PhaseSpaceFactorSWave, # type:ignore[arg-type]
)
subs_first = round_nested(_subs(width, parameters, method).doit(), n_decimals=3)
doit_first = round_nested(_subs(width.doit(), parameters, method), n_decimals=3)
Expand Down
11 changes: 4 additions & 7 deletions tests/dynamics/test_kmatrix.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING

import pytest
import sympy as sp

from ampform.dynamics.kmatrix import NonRelativisticKMatrix
from symplot import rename_symbols, substitute_indexed_symbols

if TYPE_CHECKING:
import sympy as sp


class TestNonRelativisticKMatrix:
@pytest.mark.parametrize(
Expand All @@ -35,9 +32,9 @@ def test_interference_single_channel(self):
expr = substitute_indexed_symbols(expr)
expr = _remove_residue_constants(expr)
expr = _rename_widths(expr)
denominator, nominator = expr.args
term1 = nominator.args[0] * denominator
term2 = nominator.args[1] * denominator
*rest, denominator, nominator = expr.args
term1 = nominator.args[0] * denominator * sp.Mul(*rest)
term2 = nominator.args[1] * denominator * sp.Mul(*rest)
assert str(term1 / term2) == R"m1*w1*(m2**2 - s)/(m2*w2*(m1**2 - s))"


Expand Down

0 comments on commit c2a6fd0

Please sign in to comment.