From 1d68af5b46b9f8d2942eb7b3888b60f678fc220e Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:57:22 +0200 Subject: [PATCH] MAINT: address `mypy` errors --- src/ampform/dynamics/__init__.py | 4 ++-- src/ampform/dynamics/builder.py | 8 ++++---- src/ampform/dynamics/kmatrix.py | 6 +++--- src/ampform/kinematics/lorentz.py | 7 ++++++- src/ampform/sympy/__init__.py | 2 +- src/ampform/sympy/_decorator.py | 2 +- src/ampform/sympy/deprecated.py | 2 +- src/ampform/sympy/math.py | 2 +- tests/dynamics/test_deprecated.py | 2 +- tests/dynamics/test_dynamics.py | 4 ++-- 10 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index 20d044774..cb0260e88 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -54,7 +54,7 @@ class EnergyDependentWidth(sp.Expr): m_b: Any angular_momentum: Any meson_radius: Any - phsp_factor: PhaseSpaceFactorProtocol = argument( + phsp_factor: PhaseSpaceFactorProtocol = argument( # type:ignore[assignment] default=PhaseSpaceFactor, sympify=False ) name: str | None = argument(default=None, sympify=False) @@ -92,7 +92,7 @@ def relativistic_breit_wigner_with_ff( # noqa: PLR0917 m_b, angular_momentum, meson_radius, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] ) -> sp.Expr: """Relativistic Breit-Wigner with `.FormFactor`. diff --git a/src/ampform/dynamics/builder.py b/src/ampform/dynamics/builder.py index 7e717e661..2a5904c53 100644 --- a/src/ampform/dynamics/builder.py +++ b/src/ampform/dynamics/builder.py @@ -123,7 +123,7 @@ def __init__( phsp_factor: PhaseSpaceFactorProtocol | None = None, ) -> None: if phsp_factor is None: - phsp_factor = PhaseSpaceFactor + phsp_factor = PhaseSpaceFactor # type:ignore[arg-type,assignment] self.phsp_factor = phsp_factor self.energy_dependent_width = energy_dependent_width self.form_factor = form_factor @@ -189,7 +189,7 @@ def __energy_dependent_breit_wigner( m_b=m_b, angular_momentum=angular_momentum, meson_radius=meson_radius, - phsp_factor=self.phsp_factor, + phsp_factor=self.phsp_factor, # type:ignore[arg-type] ) breit_wigner_expr = (res_mass * res_width) / ( res_mass**2 - s - mass_dependent_width * res_mass * sp.I @@ -245,7 +245,7 @@ def __create_symbols( create_relativistic_breit_wigner_with_ff = RelativisticBreitWignerBuilder( energy_dependent_width=True, form_factor=True, - phsp_factor=PhaseSpaceFactor, + phsp_factor=PhaseSpaceFactor, # type:ignore[arg-type] ).__call__ """Create a `.relativistic_breit_wigner_with_ff` for a two-body decay. @@ -256,7 +256,7 @@ def __create_symbols( create_analytic_breit_wigner = RelativisticBreitWignerBuilder( energy_dependent_width=True, form_factor=True, - phsp_factor=EqualMassPhaseSpaceFactor, + phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type] ).__call__ """Create a `.relativistic_breit_wigner_with_ff` with analytic continuation. diff --git a/src/ampform/dynamics/kmatrix.py b/src/ampform/dynamics/kmatrix.py index da2e3bf63..47bf7dd29 100644 --- a/src/ampform/dynamics/kmatrix.py +++ b/src/ampform/dynamics/kmatrix.py @@ -56,7 +56,7 @@ def formulate( # type: ignore[override] # noqa: D417 n_poles, parametrize: bool = True, return_t_hat: bool = False, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] angular_momentum=0, meson_radius=1, ) -> sp.MutableDenseMatrix: @@ -116,7 +116,7 @@ def parametrization( # noqa: PLR0917 pole_id, angular_momentum=0, meson_radius=1, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] ) -> sp.Expr: def residue_function(pole_id, i) -> sp.Expr: return residue_constant[pole_id, i] * sp.sqrt( @@ -296,7 +296,7 @@ def formulate( # type: ignore[override] # noqa: D417 n_poles, parametrize: bool = True, return_f_hat: bool = False, - phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, + phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment] angular_momentum=0, meson_radius=1, ) -> sp.MutableDenseMatrix: diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index 6d4a14deb..076f50554 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -2,6 +2,7 @@ from __future__ import annotations +import sys from typing import TYPE_CHECKING, Any, Callable, Dict import sympy as sp @@ -17,6 +18,10 @@ ) from ampform.sympy.math import ComplexSqrt +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias if TYPE_CHECKING: from qrules.topology import Topology from sympy.printing.latex import LatexPrinter @@ -45,7 +50,7 @@ def create_four_momentum_symbol(index: int) -> FourMomentumSymbol: It's best to create a `dict` of `.FourMomenta` with :func:`create_four_momentum_symbols`. """ -FourMomentumSymbol = ArraySymbol +FourMomentumSymbol: TypeAlias = ArraySymbol r"""Array-`~sympy.core.symbol.Symbol` that represents an array of four-momenta. The array is assumed to be of shape :math:`n\times 4` with :math:`n` the number of diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index c21bedb10..150e7864b 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -164,7 +164,7 @@ def free_symbols(self) -> set[sp.Basic]: return super().free_symbols - {s for s, _ in self.indices} @override - def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override] + def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[misc] expr = self.evaluate() if deep: return expr.doit() diff --git a/src/ampform/sympy/_decorator.py b/src/ampform/sympy/_decorator.py index c7c417d6e..ebd5a27a4 100644 --- a/src/ampform/sympy/_decorator.py +++ b/src/ampform/sympy/_decorator.py @@ -274,7 +274,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]: return expr.evaluate() return expr - cls.__new__ = new_method # type: ignore[method-assign] + cls.__new__ = new_method # type: ignore[assignment] cls.__getnewargs__ = _get_arguments # type: ignore[assignment,method-assign] cls._hashable_content = _hashable_content_method # type: ignore[method-assign] if non_sympy_fields: diff --git a/src/ampform/sympy/deprecated.py b/src/ampform/sympy/deprecated.py index 34211e113..135caaf34 100644 --- a/src/ampform/sympy/deprecated.py +++ b/src/ampform/sympy/deprecated.py @@ -108,7 +108,7 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]: kwargs = {"name": self._name} return args, kwargs - @override + @override # type:ignore[misc] def _hashable_content(self) -> tuple: # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165 # name is converted to string because unstable hash for None diff --git a/src/ampform/sympy/math.py b/src/ampform/sympy/math.py index eac223aa3..04c40dcd5 100644 --- a/src/ampform/sympy/math.py +++ b/src/ampform/sympy/math.py @@ -37,7 +37,7 @@ class ComplexSqrt(NumPyPrintable): @overload def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc] @overload - def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... + def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... # type:ignore[misc] @override def __new__(cls, x, *args, **kwargs): x = sp.sympify(x) diff --git a/tests/dynamics/test_deprecated.py b/tests/dynamics/test_deprecated.py index 0315b30f6..3bc2c364c 100644 --- a/tests/dynamics/test_deprecated.py +++ b/tests/dynamics/test_deprecated.py @@ -38,7 +38,7 @@ def test_pickle(): m_b=m_a, angular_momentum=0, meson_radius=1, - phsp_factor=EqualMassPhaseSpaceFactor, + phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type] name="Gamma_1", ) pickled_obj = pickle.dumps(expr) diff --git a/tests/dynamics/test_dynamics.py b/tests/dynamics/test_dynamics.py index 6b8d83201..06042f4dd 100644 --- a/tests/dynamics/test_dynamics.py +++ b/tests/dynamics/test_dynamics.py @@ -47,7 +47,7 @@ def test_init(): m_b=m_b, angular_momentum=angular_momentum, meson_radius=d, - phsp_factor=EqualMassPhaseSpaceFactor, + phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type] name="Gamma_1", ) assert width.phsp_factor is EqualMassPhaseSpaceFactor @@ -70,7 +70,7 @@ def test_doit_and_subs(self, method: str): m_b=m_a, angular_momentum=0, meson_radius=1, - phsp_factor=PhaseSpaceFactorSWave, + phsp_factor=PhaseSpaceFactorSWave, # type:ignore[arg-type] ) subs_first = round_nested(_subs(width, parameters, method).doit(), n_decimals=3) doit_first = round_nested(_subs(width.doit(), parameters, method), n_decimals=3)