Skip to content

Commit

Permalink
MAINT: address mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Aug 2, 2024
1 parent fc89828 commit 1d68af5
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/ampform/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class EnergyDependentWidth(sp.Expr):
m_b: Any
angular_momentum: Any
meson_radius: Any
phsp_factor: PhaseSpaceFactorProtocol = argument(
phsp_factor: PhaseSpaceFactorProtocol = argument( # type:ignore[assignment]
default=PhaseSpaceFactor, sympify=False
)
name: str | None = argument(default=None, sympify=False)
Expand Down Expand Up @@ -92,7 +92,7 @@ def relativistic_breit_wigner_with_ff( # noqa: PLR0917
m_b,
angular_momentum,
meson_radius,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
) -> sp.Expr:
"""Relativistic Breit-Wigner with `.FormFactor`.
Expand Down
8 changes: 4 additions & 4 deletions src/ampform/dynamics/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
phsp_factor: PhaseSpaceFactorProtocol | None = None,
) -> None:
if phsp_factor is None:
phsp_factor = PhaseSpaceFactor
phsp_factor = PhaseSpaceFactor # type:ignore[arg-type,assignment]
self.phsp_factor = phsp_factor
self.energy_dependent_width = energy_dependent_width
self.form_factor = form_factor
Expand Down Expand Up @@ -189,7 +189,7 @@ def __energy_dependent_breit_wigner(
m_b=m_b,
angular_momentum=angular_momentum,
meson_radius=meson_radius,
phsp_factor=self.phsp_factor,
phsp_factor=self.phsp_factor, # type:ignore[arg-type]
)
breit_wigner_expr = (res_mass * res_width) / (
res_mass**2 - s - mass_dependent_width * res_mass * sp.I
Expand Down Expand Up @@ -245,7 +245,7 @@ def __create_symbols(
create_relativistic_breit_wigner_with_ff = RelativisticBreitWignerBuilder(
energy_dependent_width=True,
form_factor=True,
phsp_factor=PhaseSpaceFactor,
phsp_factor=PhaseSpaceFactor, # type:ignore[arg-type]
).__call__
"""Create a `.relativistic_breit_wigner_with_ff` for a two-body decay.
Expand All @@ -256,7 +256,7 @@ def __create_symbols(
create_analytic_breit_wigner = RelativisticBreitWignerBuilder(
energy_dependent_width=True,
form_factor=True,
phsp_factor=EqualMassPhaseSpaceFactor,
phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type]
).__call__
"""Create a `.relativistic_breit_wigner_with_ff` with analytic continuation.
Expand Down
6 changes: 3 additions & 3 deletions src/ampform/dynamics/kmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def formulate( # type: ignore[override] # noqa: D417
n_poles,
parametrize: bool = True,
return_t_hat: bool = False,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
angular_momentum=0,
meson_radius=1,
) -> sp.MutableDenseMatrix:
Expand Down Expand Up @@ -116,7 +116,7 @@ def parametrization( # noqa: PLR0917
pole_id,
angular_momentum=0,
meson_radius=1,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
) -> sp.Expr:
def residue_function(pole_id, i) -> sp.Expr:
return residue_constant[pole_id, i] * sp.sqrt(
Expand Down Expand Up @@ -296,7 +296,7 @@ def formulate( # type: ignore[override] # noqa: D417
n_poles,
parametrize: bool = True,
return_f_hat: bool = False,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor,
phsp_factor: PhaseSpaceFactorProtocol = PhaseSpaceFactor, # type:ignore[assignment]
angular_momentum=0,
meson_radius=1,
) -> sp.MutableDenseMatrix:
Expand Down
7 changes: 6 additions & 1 deletion src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Callable, Dict

import sympy as sp
Expand All @@ -17,6 +18,10 @@
)
from ampform.sympy.math import ComplexSqrt

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias
if TYPE_CHECKING:
from qrules.topology import Topology
from sympy.printing.latex import LatexPrinter
Expand Down Expand Up @@ -45,7 +50,7 @@ def create_four_momentum_symbol(index: int) -> FourMomentumSymbol:
It's best to create a `dict` of `.FourMomenta` with
:func:`create_four_momentum_symbols`.
"""
FourMomentumSymbol = ArraySymbol
FourMomentumSymbol: TypeAlias = ArraySymbol
r"""Array-`~sympy.core.symbol.Symbol` that represents an array of four-momenta.
The array is assumed to be of shape :math:`n\times 4` with :math:`n` the number of
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def free_symbols(self) -> set[sp.Basic]:
return super().free_symbols - {s for s, _ in self.indices}

@override
def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override]
def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[misc]
expr = self.evaluate()
if deep:
return expr.doit()
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def new_method(cls, *args, evaluate: bool = False, **kwargs) -> type[ExprClass]:
return expr.evaluate()
return expr

cls.__new__ = new_method # type: ignore[method-assign]
cls.__new__ = new_method # type: ignore[assignment]
cls.__getnewargs__ = _get_arguments # type: ignore[assignment,method-assign]
cls._hashable_content = _hashable_content_method # type: ignore[method-assign]
if non_sympy_fields:
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:
kwargs = {"name": self._name}
return args, kwargs

@override
@override # type:ignore[misc]
def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/sympy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ComplexSqrt(NumPyPrintable):
@overload
def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc]
@overload
def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ...
def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... # type:ignore[misc]
@override
def __new__(cls, x, *args, **kwargs):
x = sp.sympify(x)
Expand Down
2 changes: 1 addition & 1 deletion tests/dynamics/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_pickle():
m_b=m_a,
angular_momentum=0,
meson_radius=1,
phsp_factor=EqualMassPhaseSpaceFactor,
phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type]
name="Gamma_1",
)
pickled_obj = pickle.dumps(expr)
Expand Down
4 changes: 2 additions & 2 deletions tests/dynamics/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_init():
m_b=m_b,
angular_momentum=angular_momentum,
meson_radius=d,
phsp_factor=EqualMassPhaseSpaceFactor,
phsp_factor=EqualMassPhaseSpaceFactor, # type:ignore[arg-type]
name="Gamma_1",
)
assert width.phsp_factor is EqualMassPhaseSpaceFactor
Expand All @@ -70,7 +70,7 @@ def test_doit_and_subs(self, method: str):
m_b=m_a,
angular_momentum=0,
meson_radius=1,
phsp_factor=PhaseSpaceFactorSWave,
phsp_factor=PhaseSpaceFactorSWave, # type:ignore[arg-type]
)
subs_first = round_nested(_subs(width, parameters, method).doit(), n_decimals=3)
doit_first = round_nested(_subs(width.doit(), parameters, method), n_decimals=3)
Expand Down

0 comments on commit 1d68af5

Please sign in to comment.