diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index 57fc4e2c3..f7ece5410 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -69,6 +69,10 @@ else: from singledispatchmethod import singledispatchmethod +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override if TYPE_CHECKING: from IPython.lib.pretty import PrettyPrinter @@ -575,10 +579,12 @@ class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder): .. seealso:: `HelicityAmplitudeBuilder` and :doc:`/usage/helicity/formalism`. """ + @override def __init__(self, reaction: ReactionInfo) -> None: super().__init__(reaction) self._naming = CanonicalAmplitudeNameGenerator(reaction) + @override def _formulate_partial_decay( self, transition: StateTransition, node_id: int ) -> sp.Expr: diff --git a/src/ampform/helicity/naming.py b/src/ampform/helicity/naming.py index b427b458a..71f15575c 100644 --- a/src/ampform/helicity/naming.py +++ b/src/ampform/helicity/naming.py @@ -3,6 +3,7 @@ from __future__ import annotations import re +import sys from abc import ABC, abstractmethod from collections import defaultdict from functools import lru_cache @@ -20,6 +21,10 @@ group_by_spin_projection, ) +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override if TYPE_CHECKING: from qrules.topology import Topology @@ -219,6 +224,7 @@ def generate_sequential_amplitude_suffix(self, transition: StateTransition) -> s class CanonicalAmplitudeNameGenerator(HelicityAmplitudeNameGenerator): + @override def __init__( self, transitions: ReactionInfo | Iterable[StateTransition], @@ -243,6 +249,7 @@ def insert_ls_combinations(self, value: bool) -> None: self.__insert_ls_combinations = value self._register_amplitude_coefficients() + @override def generate_amplitude_name( self, transition: StateTransition, @@ -262,6 +269,7 @@ def generate_amplitude_name( names.append(canonical_name) return "; ".join(names) + @override def _get_coefficient_components( self, transition: StateTransition, node_id: int ) -> tuple[str, str, str]: diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 169b6baf7..c0f6895b1 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -19,6 +19,7 @@ import os import pickle import re +import sys import warnings from abc import abstractmethod from os.path import abspath, dirname, expanduser @@ -44,6 +45,10 @@ make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401 ) +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter from sympy.printing.numpy import NumPyPrinter @@ -124,6 +129,7 @@ class PoolSum(sp.Expr): precedence = PRECEDENCE["Mul"] + @override def __new__( cls, expression, @@ -156,6 +162,7 @@ def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]: def free_symbols(self) -> set[sp.Basic]: return super().free_symbols - {s for s, _ in self.indices} + @override def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override] expr = self.evaluate() if deep: @@ -360,10 +367,12 @@ class UnevaluatableIntegral(sp.Integral): limit = 50 dummify = True + @override def doit(self, **hints): args = [arg.doit(**hints) for arg in self.args] return self.func(*args) + @override def _numpycode(self, printer, *args): _warn_if_scipy_not_installed() integration_vars, limits = _unpack_integral_limits(self) diff --git a/src/ampform/sympy/_array_expressions.py b/src/ampform/sympy/_array_expressions.py index 22128596c..2d9428938 100644 --- a/src/ampform/sympy/_array_expressions.py +++ b/src/ampform/sympy/_array_expressions.py @@ -7,6 +7,7 @@ from __future__ import annotations import string +import sys from collections import abc from itertools import zip_longest from typing import TYPE_CHECKING, Iterable, overload @@ -26,11 +27,16 @@ get_shape, ) +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override if TYPE_CHECKING: from sympy.printing.numpy import NumPyPrinter class ArrayElement(_ArrayExpr): + @override def __new__(cls, parent: sp.Expr, indices: Iterable) -> ArrayElement: # cspell:ignore sympified sympified_indices = sp.Tuple(*map(_sympify, indices)) @@ -121,6 +127,7 @@ class ArraySlice(_ArrayExpr): indices: tuple[sp.Tuple, ...] = property(lambda self: tuple(self.args[1])) # type: ignore[assignment] is_commutative = True + @override def __new__( cls, parent: sp.Basic, @@ -251,6 +258,7 @@ def _slice_to_str(self: Printer, x, dim) -> str: class ArraySum(sp.Expr): precedence = PRECEDENCE["Add"] + @override def __new__(cls, *terms: sp.Basic, **hints) -> ArraySum: terms = sp.sympify(terms) return sp.Expr.__new__(cls, *terms, **hints) @@ -307,6 +315,7 @@ def _strip_subscript_superscript(symbol: sp.Basic) -> str: class ArrayAxisSum(sp.Expr): is_commutative = True + @override def __new__(cls, array: sp.Expr, axis: int | None = None, **hints) -> ArrayAxisSum: if axis is not None and not isinstance(axis, (int, sp.Integer)): msg = "Only single digits allowed for axis" @@ -346,6 +355,7 @@ class ArrayMultiplication(sp.Expr): Lorentz matrices) to :math:`n\times\times4` (:math:`n` four-momentum tuples). """ + @override def __new__(cls, *tensors: sp.Basic, **hints) -> ArrayMultiplication: tensors = sp.sympify(tensors) return sp.Expr.__new__(cls, *tensors, **hints) @@ -400,6 +410,7 @@ class MatrixMultiplication(sp.Expr): Lorentz matrices) to :math:`n\times\times4\times4` (:math:`n` four-momentum tuples). """ + @override def __new__(cls, *tensors: sp.Basic, **hints) -> MatrixMultiplication: tensors = sp.sympify(tensors) return sp.Expr.__new__(cls, *tensors, **hints) diff --git a/src/ampform/sympy/deprecated.py b/src/ampform/sympy/deprecated.py index 2b32a574d..f1939b57d 100644 --- a/src/ampform/sympy/deprecated.py +++ b/src/ampform/sympy/deprecated.py @@ -6,12 +6,17 @@ from __future__ import annotations import functools +import sys from abc import abstractmethod from typing import TYPE_CHECKING, Callable, TypeVar from warnings import warn import sympy as sp +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override if TYPE_CHECKING: from sympy.printing.latex import LatexPrinter @@ -58,6 +63,7 @@ def __init_subclass__(cls, **kwargs): ) super().__init_subclass__(**kwargs) + @override def __new__( cls: type[DecoratedClass], *args, @@ -103,6 +109,7 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]: kwargs = {"name": self._name} return args, kwargs + @override def _hashable_content(self) -> tuple: # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165 # name is converted to string because unstable hash for None diff --git a/src/ampform/sympy/math.py b/src/ampform/sympy/math.py index 526cb3d69..eac223aa3 100644 --- a/src/ampform/sympy/math.py +++ b/src/ampform/sympy/math.py @@ -3,6 +3,7 @@ # cspell:ignore Lambdifier from __future__ import annotations +import sys from typing import TYPE_CHECKING, overload import sympy as sp @@ -10,6 +11,10 @@ from ampform.sympy import NumPyPrintable +if sys.version_info < (3, 12): + from typing_extensions import override +else: + from typing import override if TYPE_CHECKING: from sympy.printing.numpy import NumPyPrinter from sympy.printing.printer import Printer @@ -33,6 +38,7 @@ class ComplexSqrt(NumPyPrintable): def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc] @overload def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ... + @override def __new__(cls, x, *args, **kwargs): x = sp.sympify(x) args = sp.sympify((x, *args))