From fa3a95adf8d3ca24faec5e8ec124332951210c69 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Wed, 24 Apr 2024 18:20:15 +0200 Subject: [PATCH] ENH: make `Parity.value` of type `Literal[-1, 1]` --- docs/conf.py | 1 + src/qrules/quantum_numbers.py | 32 +++++++++++++++++--------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 52e9fe3b..2ef91d0c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,6 +77,7 @@ def create_constraints_inventory() -> None: "qrules.topology.NodeType": "typing.TypeVar", "SpinFormalism": ("obj", "qrules.transition.SpinFormalism"), "StateDefinition": ("obj", "qrules.combinatorics.StateDefinition"), + "typing.Literal[-1, 1]": "typing.Literal", } api_target_types: dict[str, str | tuple[str, str]] = { "qrules.combinatorics.InitialFacts": "obj", diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index bafae8b6..f7fd5aeb 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -8,34 +8,36 @@ from __future__ import annotations +import sys from decimal import Decimal from fractions import Fraction from functools import total_ordering -from typing import Any, Generator, NewType, Union +from typing import Any, Generator, NewType, SupportsInt, Union -import attrs from attrs import field, frozen -from attrs.validators import instance_of from qrules._implementers import implement_pretty_repr +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal -def _check_plus_minus(_: Any, __: attrs.Attribute, value: Any) -> None: - if not isinstance(value, int): - msg = ( - f"Input for {Parity.__name__} has to be of type {int.__name__}, not" - f" {type(value).__name__}" - ) - raise TypeError(msg) - if value not in {-1, +1}: - msg = f"Parity can only be +1 or -1, not {value}" - raise ValueError(msg) + +def _to_parity(value: SupportsInt) -> Literal[-1, 1]: + value = int(value) + if value == -1: + return -1 + if value == +1: + return 1 + msg = f"Parity can only be +1 or -1, not {value}" + raise ValueError(msg) @total_ordering @frozen(eq=False, hash=True, order=False, repr=False) class Parity: # noqa: PLW1641 - value: int = field(validator=[instance_of(int), _check_plus_minus]) + value: Literal[-1, 1] = field(converter=_to_parity) def __eq__(self, other: object) -> bool: if isinstance(other, Parity): @@ -47,7 +49,7 @@ def __gt__(self, other: Any) -> bool: return True return self.value > int(other) - def __int__(self) -> int: + def __int__(self) -> Literal[-1, 1]: return self.value def __neg__(self) -> Parity: