Skip to content

Commit

Permalink
MAINT: mark methods with @override (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Feb 12, 2024
1 parent 58f5614 commit 750dfd7
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@
else:
from singledispatchmethod import singledispatchmethod

if sys.version_info < (3, 12):
from typing_extensions import override
else:
from typing import override
if TYPE_CHECKING:
from IPython.lib.pretty import PrettyPrinter

Expand Down Expand Up @@ -575,10 +579,12 @@ class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
.. seealso:: `HelicityAmplitudeBuilder` and :doc:`/usage/helicity/formalism`.
"""

@override
def __init__(self, reaction: ReactionInfo) -> None:
super().__init__(reaction)
self._naming = CanonicalAmplitudeNameGenerator(reaction)

@override
def _formulate_partial_decay(
self, transition: StateTransition, node_id: int
) -> sp.Expr:
Expand Down
8 changes: 8 additions & 0 deletions src/ampform/helicity/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import re
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache
Expand All @@ -20,6 +21,10 @@
group_by_spin_projection,
)

if sys.version_info < (3, 12):
from typing_extensions import override
else:
from typing import override
if TYPE_CHECKING:
from qrules.topology import Topology

Expand Down Expand Up @@ -219,6 +224,7 @@ def generate_sequential_amplitude_suffix(self, transition: StateTransition) -> s


class CanonicalAmplitudeNameGenerator(HelicityAmplitudeNameGenerator):
@override
def __init__(
self,
transitions: ReactionInfo | Iterable[StateTransition],
Expand All @@ -243,6 +249,7 @@ def insert_ls_combinations(self, value: bool) -> None:
self.__insert_ls_combinations = value
self._register_amplitude_coefficients()

@override
def generate_amplitude_name(
self,
transition: StateTransition,
Expand All @@ -262,6 +269,7 @@ def generate_amplitude_name(
names.append(canonical_name)
return "; ".join(names)

@override
def _get_coefficient_components(
self, transition: StateTransition, node_id: int
) -> tuple[str, str, str]:
Expand Down
9 changes: 9 additions & 0 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import pickle
import re
import sys
import warnings
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
Expand All @@ -44,6 +45,10 @@
make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401
)

if sys.version_info < (3, 12):
from typing_extensions import override
else:
from typing import override
if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter
from sympy.printing.numpy import NumPyPrinter
Expand Down Expand Up @@ -124,6 +129,7 @@ class PoolSum(sp.Expr):

precedence = PRECEDENCE["Mul"]

@override
def __new__(
cls,
expression,
Expand Down Expand Up @@ -156,6 +162,7 @@ def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]:
def free_symbols(self) -> set[sp.Basic]:
return super().free_symbols - {s for s, _ in self.indices}

@override
def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override]
expr = self.evaluate()
if deep:
Expand Down Expand Up @@ -360,10 +367,12 @@ class UnevaluatableIntegral(sp.Integral):
limit = 50
dummify = True

@override
def doit(self, **hints):
args = [arg.doit(**hints) for arg in self.args]
return self.func(*args)

@override
def _numpycode(self, printer, *args):
_warn_if_scipy_not_installed()
integration_vars, limits = _unpack_integral_limits(self)
Expand Down
11 changes: 11 additions & 0 deletions src/ampform/sympy/_array_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import string
import sys
from collections import abc
from itertools import zip_longest
from typing import TYPE_CHECKING, Iterable, overload
Expand All @@ -26,11 +27,16 @@
get_shape,
)

if sys.version_info < (3, 12):
from typing_extensions import override
else:
from typing import override
if TYPE_CHECKING:
from sympy.printing.numpy import NumPyPrinter


class ArrayElement(_ArrayExpr):
@override
def __new__(cls, parent: sp.Expr, indices: Iterable) -> ArrayElement:
# cspell:ignore sympified
sympified_indices = sp.Tuple(*map(_sympify, indices))
Expand Down Expand Up @@ -121,6 +127,7 @@ class ArraySlice(_ArrayExpr):
indices: tuple[sp.Tuple, ...] = property(lambda self: tuple(self.args[1])) # type: ignore[assignment]
is_commutative = True

@override
def __new__(
cls,
parent: sp.Basic,
Expand Down Expand Up @@ -251,6 +258,7 @@ def _slice_to_str(self: Printer, x, dim) -> str:
class ArraySum(sp.Expr):
precedence = PRECEDENCE["Add"]

@override
def __new__(cls, *terms: sp.Basic, **hints) -> ArraySum:
terms = sp.sympify(terms)
return sp.Expr.__new__(cls, *terms, **hints)
Expand Down Expand Up @@ -307,6 +315,7 @@ def _strip_subscript_superscript(symbol: sp.Basic) -> str:
class ArrayAxisSum(sp.Expr):
is_commutative = True

@override
def __new__(cls, array: sp.Expr, axis: int | None = None, **hints) -> ArrayAxisSum:
if axis is not None and not isinstance(axis, (int, sp.Integer)):
msg = "Only single digits allowed for axis"
Expand Down Expand Up @@ -346,6 +355,7 @@ class ArrayMultiplication(sp.Expr):
Lorentz matrices) to :math:`n\times\times4` (:math:`n` four-momentum tuples).
"""

@override
def __new__(cls, *tensors: sp.Basic, **hints) -> ArrayMultiplication:
tensors = sp.sympify(tensors)
return sp.Expr.__new__(cls, *tensors, **hints)
Expand Down Expand Up @@ -400,6 +410,7 @@ class MatrixMultiplication(sp.Expr):
Lorentz matrices) to :math:`n\times\times4\times4` (:math:`n` four-momentum tuples).
"""

@override
def __new__(cls, *tensors: sp.Basic, **hints) -> MatrixMultiplication:
tensors = sp.sympify(tensors)
return sp.Expr.__new__(cls, *tensors, **hints)
Expand Down
7 changes: 7 additions & 0 deletions src/ampform/sympy/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
from __future__ import annotations

import functools
import sys
from abc import abstractmethod
from typing import TYPE_CHECKING, Callable, TypeVar
from warnings import warn

import sympy as sp

if sys.version_info < (3, 12):
from typing_extensions import override
else:
from typing import override
if TYPE_CHECKING:
from sympy.printing.latex import LatexPrinter

Expand Down Expand Up @@ -58,6 +63,7 @@ def __init_subclass__(cls, **kwargs):
)
super().__init_subclass__(**kwargs)

@override
def __new__(
cls: type[DecoratedClass],
*args,
Expand Down Expand Up @@ -103,6 +109,7 @@ def __getnewargs_ex__(self) -> tuple[tuple, dict]:
kwargs = {"name": self._name}
return args, kwargs

@override
def _hashable_content(self) -> tuple:
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
# name is converted to string because unstable hash for None
Expand Down
6 changes: 6 additions & 0 deletions src/ampform/sympy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
# cspell:ignore Lambdifier
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, overload

import sympy as sp
from sympy.plotting.experimental_lambdify import Lambdifier

from ampform.sympy import NumPyPrintable

if sys.version_info < (3, 12):
from typing_extensions import override
else:
from typing import override
if TYPE_CHECKING:
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.printer import Printer
Expand All @@ -33,6 +38,7 @@ class ComplexSqrt(NumPyPrintable):
def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr: ... # type: ignore[misc]
@overload
def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt: ...
@override
def __new__(cls, x, *args, **kwargs):
x = sp.sympify(x)
args = sp.sympify((x, *args))
Expand Down

0 comments on commit 750dfd7

Please sign in to comment.