Skip to content

Commit

Permalink
MAINT: address remnaining Ruff issues
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Aug 6, 2024
1 parent ffba738 commit 1aa02ca
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 33 deletions.
4 changes: 2 additions & 2 deletions docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,8 @@ def _append_to_docstring(class_type: Callable | type, appended_text: str) -> Non


def __generate_transitions_cached(
initial_state: list[tuple[str, list[float | int]] | str],
final_state: list[tuple[str, list[float | int]] | str],
initial_state: list[tuple[str, list[float]] | str],
final_state: list[tuple[str, list[float]] | str],
formalism: SpinFormalism,
) -> ReactionInfo:
version = get_package_version("qrules")
Expand Down
18 changes: 12 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ line-ending = "lf"

[tool.ruff.lint]
ignore = [
"ANN401",
"ANN",
"ARG00",
"COM812",
"CPY001",
"D101",
Expand All @@ -296,6 +297,7 @@ ignore = [
"D416",
"DOC",
"E501",
"FBT00",
"FURB101",
"FURB103",
"FURB140",
Expand All @@ -319,6 +321,14 @@ extend-immutable-calls = [
[tool.ruff.lint.flake8-builtins]
builtins-ignorelist = ["display"]

[tool.ruff.lint.flake8-self]
ignore-names = [
"_latex",
"_module",
"_numpycode",
"_print",
]

[tool.ruff.lint.isort]
known-third-party = ["sympy"]
split-on-trailing-comma = false
Expand All @@ -328,7 +338,6 @@ split-on-trailing-comma = false
"**/docs/usage/symplot.ipynb" = ["RUF027"]
"**/docs/usage/sympy.ipynb" = ["E731"]
"*.ipynb" = [
"ANN",
"B018",
"C408",
"C90",
Expand All @@ -347,13 +356,12 @@ split-on-trailing-comma = false
"S101",
"S301",
"S403",
"SLF001",
"T20",
"TCH00",
]
"benchmarks/*" = [
"ANN",
"D",
"FBT001",
"INP001",
"PGH001",
"PLC2701",
Expand Down Expand Up @@ -383,10 +391,8 @@ split-on-trailing-comma = false
]
"setup.py" = ["D100"]
"tests/*" = [
"ANN",
"C408",
"D",
"FBT001",
"INP001",
"PGH001",
"PLC2701",
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _(obj: complex, **kwargs) -> str:
return f"{real}{plus}{imag}i"


def __downcast(obj: float, **kwargs) -> float | int:
def __downcast(obj: float, **kwargs) -> float:
if obj.is_integer():
return int(obj)
return obj
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _implement_latex_subscript( # pyright: ignore[reportUnusedFunction]
def decorator(decorated_class: type[ExprClass]) -> type[ExprClass]:
def _latex_repr_(self: sp.Expr, printer: LatexPrinter, *args) -> str:
momentum = printer._print(self.momentum) # type: ignore[attr-defined]
if printer._needs_mul_brackets(self.momentum): # type: ignore[attr-defined]
if printer._needs_mul_brackets(self.momentum): # type: ignore[attr-defined] # noqa: SLF001
momentum = Rf"\left({momentum}\right)"
else:
momentum = Rf"{{{momentum}}}"
Expand Down
6 changes: 5 additions & 1 deletion src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
from importlib_metadata import version
else:
from importlib.metadata import version
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
if sys.version_info < (3, 12):
from typing_extensions import override
else:
Expand Down Expand Up @@ -137,7 +141,7 @@ def __new__(
*indices: tuple[sp.Symbol, Iterable[sp.Basic]],
evaluate: bool = False,
**hints,
) -> PoolSum:
) -> Self:
converted_indices = []
for idx_symbol, values in indices:
values = tuple(values)
Expand Down
30 changes: 17 additions & 13 deletions src/ampform/sympy/_array_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
get_shape,
)

if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self
if sys.version_info < (3, 12):
from typing_extensions import override
else:
Expand All @@ -37,7 +41,7 @@

class ArrayElement(_ArrayExpr):
@override
def __new__(cls, parent: sp.Expr, indices: Iterable) -> ArrayElement:
def __new__(cls, parent: sp.Expr, indices: Iterable) -> Self:
# cspell:ignore sympified
sympified_indices = sp.Tuple(*map(_sympify, indices))
parent_shape = get_shape(parent)
Expand Down Expand Up @@ -73,7 +77,7 @@ def indices(self) -> sp.Tuple:


# required for lambdify
_ArrayExpr._iterable = False # type: ignore[attr-defined]
_ArrayExpr._iterable = False # type: ignore[attr-defined] # noqa: SLF001


@overload
Expand Down Expand Up @@ -132,7 +136,7 @@ def __new__(
cls,
parent: sp.Basic,
indices: tuple[sp.Basic | int | slice, ...],
) -> ArraySlice:
) -> Self:
parent_shape = get_shape(parent)
normalized_indices = []
for idx, axis_size in zip_longest(indices, parent_shape):
Expand Down Expand Up @@ -249,17 +253,17 @@ def _slice_to_str(self: Printer, x, dim) -> str:
return ":".join("" if xi in {none, None} else self._print(xi) for xi in x)


LatexPrinter._print_ArrayElement = _print_latex_ArrayElement # type: ignore[assignment]
LatexPrinter._print_ArraySlice = _print_latex_ArraySlice # type: ignore[attr-defined]
StrPrinter._print_ArrayElement = _print_str_ArrayElement # type: ignore[assignment]
StrPrinter._print_ArraySlice = _print_str_ArraySlice # type: ignore[attr-defined]
LatexPrinter._print_ArrayElement = _print_latex_ArrayElement # type: ignore[assignment] # noqa: SLF001
LatexPrinter._print_ArraySlice = _print_latex_ArraySlice # type: ignore[attr-defined] # noqa: SLF001
StrPrinter._print_ArrayElement = _print_str_ArrayElement # type: ignore[assignment] # noqa: SLF001
StrPrinter._print_ArraySlice = _print_str_ArraySlice # type: ignore[attr-defined] # noqa: SLF001


class ArraySum(sp.Expr):
precedence = PRECEDENCE["Add"]

@override
def __new__(cls, *terms: sp.Basic, **hints) -> ArraySum:
def __new__(cls, *terms: sp.Basic, **hints) -> Self:
terms = sp.sympify(terms)
return sp.Expr.__new__(cls, *terms, **hints)

Expand All @@ -274,15 +278,15 @@ def _latex(self, printer: LatexPrinter, *args) -> str:
name = next(iter(names))
subscript = "".join(map(_get_subscript, self.terms))
return f"{{{name}}}_{{{subscript}}}"
return printer._print_ArraySum(self) # type: ignore[attr-defined]
return printer._print_ArraySum(self) # type: ignore[attr-defined] # noqa: SLF001


def _print_array_sum(self: Printer, expr: ArraySum) -> str:
terms = map(self._print, expr.terms)
return " + ".join(terms)


Printer._print_ArraySum = _print_array_sum # type: ignore[attr-defined]
Printer._print_ArraySum = _print_array_sum # type: ignore[attr-defined] # noqa: SLF001


def _get_subscript(symbol: sp.Basic) -> str:
Expand Down Expand Up @@ -316,7 +320,7 @@ class ArrayAxisSum(sp.Expr):
is_commutative = True

@override
def __new__(cls, array: sp.Expr, axis: int | None = None, **hints) -> ArrayAxisSum:
def __new__(cls, array: sp.Expr, axis: int | None = None, **hints) -> Self:
if axis is not None and not isinstance(axis, (int, sp.Integer)):
msg = "Only single digits allowed for axis"
raise TypeError(msg)
Expand Down Expand Up @@ -356,7 +360,7 @@ class ArrayMultiplication(sp.Expr):
"""

@override
def __new__(cls, *tensors: sp.Basic, **hints) -> ArrayMultiplication:
def __new__(cls, *tensors: sp.Basic, **hints) -> Self:
tensors = sp.sympify(tensors)
return sp.Expr.__new__(cls, *tensors, **hints)

Expand Down Expand Up @@ -411,7 +415,7 @@ class MatrixMultiplication(sp.Expr):
"""

@override
def __new__(cls, *tensors: sp.Basic, **hints) -> MatrixMultiplication:
def __new__(cls, *tensors: sp.Basic, **hints) -> Self:
tensors = sp.sympify(tensors)
return sp.Expr.__new__(cls, *tensors, **hints)

Expand Down
4 changes: 2 additions & 2 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def _eval_subs_method(self, old, new, **hints):
continue
if isclass(old_arg):
continue
new_attr = old_arg._subs(old, new, **hints)
new_attr = old_arg._subs(old, new, **hints) # noqa: SLF001
if not _aresame(new_attr, old_arg):
hit = True
new_args[i] = new_attr
Expand Down Expand Up @@ -533,7 +533,7 @@ def _xreplace_method(self, rule) -> tuple[sp.Expr, bool]:
hit = False
for arg in _get_arguments(self):
if hasattr(arg, "_xreplace") and not isclass(arg):
replace_result, is_replaced = arg._xreplace(rule)
replace_result, is_replaced = arg._xreplace(rule) # noqa: SLF001
elif isinstance(rule, abc.Mapping):
is_replaced = bool(arg in rule)
replace_result = rule.get(arg, arg)
Expand Down
8 changes: 4 additions & 4 deletions src/ampform/sympy/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def __new__(
"""
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
obj = object.__new__(cls)
obj._args = args
obj._assumptions = cls.default_assumptions # type: ignore[attr-defined]
obj._mhash = None # cspell:ignore mhash
obj._name = name
obj._args = args # noqa: SLF001
obj._assumptions = cls.default_assumptions # type: ignore[attr-defined] # noqa: SLF001
obj._mhash = None # cspell:ignore mhash # noqa: SLF001
obj._name = name # noqa: SLF001
return obj

def __getnewargs_ex__(self) -> tuple[tuple, dict]:
Expand Down
4 changes: 2 additions & 2 deletions src/symplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def set_values(self, *args: dict[str, float], **kwargs: float) -> None:
for keyword, value in value_mapping.items():
try:
self[keyword].value = value
except KeyError:
except KeyError: # noqa: PERF203
_LOGGER.warning(f'There is no slider with name or symbol "{keyword}"')
continue

Expand Down Expand Up @@ -219,7 +219,7 @@ def _is_min_max(

def _is_min_max_step(
range_def: RangeDefinition,
) -> TypeGuard[tuple[float, float, float | int]]:
) -> TypeGuard[tuple[float, float, float]]:
return len(range_def) == 3 # noqa: PLR2004


Expand Down
2 changes: 1 addition & 1 deletion tests/symplot/test_symplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_set_ranges(
slider_name: str,
min_: float,
max_: float,
n_steps: float | int | None,
n_steps: float | None,
step_size: float,
slider_kwargs: SliderKwargs,
) -> None:
Expand Down

0 comments on commit 1aa02ca

Please sign in to comment.