Skip to content

Commit

Permalink
ENH: specify formalism with Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 3, 2024
1 parent fb132e2 commit 3ddf349
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/qrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
EdgeSettings,
ProblemSet,
ReactionInfo,
SpinFormalism,
StateTransitionManager,
)

Expand Down Expand Up @@ -264,7 +265,7 @@ def generate_transitions( # noqa: PLR0917
final_state: Sequence[StateDefinition],
allowed_intermediate_particles: list[str] | None = None,
allowed_interaction_types: str | Iterable[str] | None = None,
formalism: str = "canonical-helicity",
formalism: SpinFormalism = "canonical-helicity",
particle_db: ParticleCollection | None = None,
mass_conservation_factor: float | None = 3.0,
max_angular_momentum: int = 2,
Expand Down
3 changes: 2 additions & 1 deletion src/qrules/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

if TYPE_CHECKING:
from qrules.particle import Particle, ParticleCollection
from qrules.transition import SpinFormalism

Check warning on line 53 in src/qrules/settings.py

View check run for this annotation

Codecov / codecov/patch

src/qrules/settings.py#L53

Added line #L53 was not covered by tests

__QRULES_PATH = dirname(realpath(__file__))
ADDITIONAL_PARTICLES_DEFINITIONS_PATH: str = join(
Expand Down Expand Up @@ -118,7 +119,7 @@ def from_str(description: str) -> InteractionType:


def create_interaction_settings( # noqa: PLR0917
formalism: str,
formalism: SpinFormalism,
particle_db: ParticleCollection,
nbody_topology: bool = False,
mass_conservation_factor: float | None = 3.0,
Expand Down
27 changes: 14 additions & 13 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from copy import copy, deepcopy
from enum import Enum, auto
from multiprocessing import Pool
from typing import TYPE_CHECKING, Iterable, Sequence, overload
from typing import TYPE_CHECKING, Iterable, Literal, Sequence, overload

import attrs
from attrs import define, field, frozen
from attrs.validators import instance_of
from attrs.validators import in_, instance_of
from tqdm.auto import tqdm

from qrules._implementers import implement_pretty_repr
Expand Down Expand Up @@ -83,6 +83,12 @@

_LOGGER = logging.getLogger(__name__)

SpinFormalism = Literal[
"helicity",
"canonical-helicity",
"canonical",
]


class SolvingMode(Enum):
"""Types of modes for solving."""
Expand Down Expand Up @@ -226,7 +232,7 @@ def __init__( # noqa: C901, PLR0912, PLR0917
InteractionType, tuple[EdgeSettings, NodeSettings]
]
| None = None,
formalism: str = "helicity",
formalism: SpinFormalism = "helicity",
topology_building: str = "isobar",
solving_mode: SolvingMode = SolvingMode.FAST,
reload_pdg: bool = False,
Expand All @@ -240,18 +246,13 @@ def __init__( # noqa: C901, PLR0912, PLR0917
self.__number_of_threads = NumberOfThreads.get()
if interaction_type_settings is None:
interaction_type_settings = {}
allowed_formalisms = [
"helicity",
"canonical-helicity",
"canonical",
]
if formalism not in allowed_formalisms:
if formalism not in set(SpinFormalism.__args__): # type: ignore[attr-defined]
msg = (
f'Formalism "{formalism}" not implemented. Use one of'
f" {allowed_formalisms} instead."
f" {', '.join(SpinFormalism.__args__)} instead." # type: ignore[attr-defined]
)
raise NotImplementedError(msg)
self.__formalism = str(formalism)
self.__formalism = formalism
self.__particles = ParticleCollection()
if particle_db is not None:
self.__particles = particle_db
Expand Down Expand Up @@ -343,7 +344,7 @@ def set_allowed_intermediate_particles(
self.__intermediate_particle_filters = selected_particles.names

@property
def formalism(self) -> str:
def formalism(self) -> SpinFormalism:
return self.__formalism

def add_final_state_grouping(self, fs_group: list[str] | list[list[str]]) -> None:
Expand Down Expand Up @@ -744,7 +745,7 @@ class ReactionInfo:
"""Ordered collection of `StateTransition` instances."""

transitions: tuple[StateTransition, ...] = field(converter=_sort_tuple)
formalism: str = field(validator=instance_of(str))
formalism: SpinFormalism = field(validator=in_(SpinFormalism.__args__)) # type: ignore[attr-defined]

initial_state: FrozenDict[int, Particle] = field(init=False, repr=False, eq=False)
final_state: FrozenDict[int, Particle] = field(init=False, repr=False, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion tests/channels/test_psi2s_to_eta_k_kstar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import qrules
from qrules.particle import ParticleCollection
from qrules.transition import SpinFormalism


@pytest.mark.parametrize("formalism", ["helicity", "canonical-helicity"])
Expand All @@ -15,7 +16,7 @@
["h(1)(1415)", "omega(1650)"],
],
)
def test_resonances(formalism, resonances, modified_pdg):
def test_resonances(formalism: SpinFormalism, resonances, modified_pdg):
reaction = qrules.generate_transitions(
initial_state=("psi(2S)", [+1, -1]),
final_state=["eta", "K-", "K*(892)+"],
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import TYPE_CHECKING

import pytest
from _pytest.fixtures import SubRequest
Expand All @@ -7,12 +8,15 @@
from qrules import ReactionInfo
from qrules.topology import Edge, Topology

if TYPE_CHECKING:
from qrules.transition import SpinFormalism

logging.basicConfig(level=logging.ERROR)


@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"])
def reaction(request: SubRequest) -> ReactionInfo:
formalism: str = request.param
formalism: SpinFormalism = request.param
return qrules.generate_transitions(
initial_state=[("J/psi(1S)", [-1, 1])],
final_state=["gamma", "pi0", "pi0"],
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/io/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
create_isobar_topologies,
create_n_body_topology,
)
from qrules.transition import ReactionInfo
from qrules.transition import ReactionInfo, SpinFormalism


def test_asdot(reaction: ReactionInfo):
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_asdot_no_label_overwriting(reaction: ReactionInfo):
"formalism",
["canonical", "canonical-helicity", "helicity"],
)
def test_asdot_problemset(formalism: str):
def test_asdot_problemset(formalism: SpinFormalism):
stm = qrules.StateTransitionManager(
initial_state=[("J/psi(1S)", [+1])],
final_state=["gamma", "pi0", "pi0"],
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_int_domain,
create_interaction_settings,
)
from qrules.transition import SpinFormalism


class TestInteractionType:
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_create_interaction_settings(
particle_database: ParticleCollection,
interaction_type: InteractionType,
nbody_topology: bool,
formalism: str,
formalism: SpinFormalism,
):
settings = create_interaction_settings(
formalism,
Expand Down

0 comments on commit 3ddf349

Please sign in to comment.