Skip to content

Commit

Permalink
refactor: simplify interaction type determination (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer authored Jun 18, 2021
1 parent 0c5df26 commit 9fa8183
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 82 deletions.
71 changes: 24 additions & 47 deletions src/qrules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@
"""

from itertools import product
from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Union
from typing import (
Dict,
FrozenSet,
Iterable,
List,
Optional,
Sequence,
Set,
Union,
)

import attr

Expand Down Expand Up @@ -282,7 +291,7 @@ def generate_transitions( # pylint: disable=too-many-arguments
initial_state: Union[StateDefinition, Sequence[StateDefinition]],
final_state: Sequence[StateDefinition],
allowed_intermediate_particles: Optional[List[str]] = None,
allowed_interaction_types: Optional[Union[str, List[str]]] = None,
allowed_interaction_types: Optional[Union[str, Iterable[str]]] = None,
formalism: str = "canonical-helicity",
particle_db: Optional[ParticleCollection] = None,
mass_conservation_factor: Optional[float] = 3.0,
Expand Down Expand Up @@ -310,9 +319,9 @@ def generate_transitions( # pylint: disable=too-many-arguments
states that you want to allow as intermediate states. This helps
(1) filter out resonances and (2) speed up computation time.
allowed_interaction_types (`str`, optional): Interaction types you want
to consider. For instance, both :code:`"strong and EM"` and
:code:`["s", "em"]` results in `~.InteractionType.EM` and
allowed_interaction_types: Interaction types you want to consider. For
instance, :code:`["s", "em"]` results in `~.InteractionType.EM` and
`~.InteractionType.STRONG` and :code:`["strong"]` results in
`~.InteractionType.STRONG`.
formalism (`str`, optional): Formalism that you intend to use in
Expand Down Expand Up @@ -352,7 +361,7 @@ def generate_transitions( # pylint: disable=too-many-arguments
... initial_state="D0",
... final_state=["K~0", "K+", "K-"],
... allowed_intermediate_particles=["a(0)(980)", "a(2)(1320)-"],
... allowed_interaction_types="ew",
... allowed_interaction_types=["e", "w"],
... formalism="helicity",
... particle_db=qrules.load_pdg(),
... topology_building="isobar",
Expand All @@ -379,52 +388,20 @@ def generate_transitions( # pylint: disable=too-many-arguments
number_of_threads=number_of_threads,
)
if allowed_interaction_types is not None:
interaction_types = _determine_interaction_types(
allowed_interaction_types
)
if isinstance(allowed_interaction_types, str):
interaction_types = [
InteractionType.from_str(allowed_interaction_types)
]
else:
interaction_types = [
InteractionType.from_str(description)
for description in allowed_interaction_types
]
stm.set_allowed_interaction_types(list(interaction_types))
problem_sets = stm.create_problem_sets()
return stm.find_solutions(problem_sets)


def _determine_interaction_types(
description: Union[str, List[str]]
) -> Set[InteractionType]:
interaction_types: Set[InteractionType] = set()
if isinstance(description, list):
for i in description:
interaction_types.update(
_determine_interaction_types(description=i)
)
return interaction_types
if not isinstance(description, str):
raise TypeError(
"Cannot handle interaction description of type "
f"{description.__class__.__name__}"
)
if len(description) == 0:
raise ValueError('Provided an empty interaction type ("")')
interaction_name_lower = description.lower()
if "all" in interaction_name_lower:
for interaction in InteractionType:
interaction_types.add(interaction)
if (
"em" in interaction_name_lower
or "ele" in interaction_name_lower
or interaction_name_lower.startswith("e")
):
interaction_types.add(InteractionType.EM)
if "w" in interaction_name_lower:
interaction_types.add(InteractionType.WEAK)
if "strong" in interaction_name_lower or interaction_name_lower == "s":
interaction_types.add(InteractionType.STRONG)
if len(interaction_types) == 0:
raise ValueError(
f'Could not determine interaction type from "{description}"'
)
return interaction_types


def load_default_particles() -> ParticleCollection:
"""Load the default particle list that comes with `qrules`.
Expand Down
13 changes: 13 additions & 0 deletions src/qrules/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ class InteractionType(Enum):
EM = auto()
WEAK = auto()

@staticmethod
def from_str(description: str) -> "InteractionType":
description_lower = description.lower()
if description_lower.startswith("e"):
return InteractionType.EM
if description_lower.startswith("s"):
return InteractionType.STRONG
if description_lower.startswith("w"):
return InteractionType.WEAK
raise ValueError(
f'Could not determine interaction type from "{description}"'
)


def create_interaction_settings( # pylint: disable=too-many-locals,too-many-arguments
formalism: str,
Expand Down
16 changes: 13 additions & 3 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@
from copy import copy, deepcopy
from enum import Enum, auto
from multiprocessing import Pool
from typing import Dict, List, Optional, Sequence, Set, Tuple, Type, Union
from typing import (
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)

import attr
from tqdm.auto import tqdm
Expand Down Expand Up @@ -392,7 +402,7 @@ def add_final_state_grouping(
self.final_state_groupings.append(fs_group) # type: ignore

def set_allowed_interaction_types(
self, allowed_interaction_types: List[InteractionType]
self, allowed_interaction_types: Iterable[InteractionType]
) -> None:
# verify order
for allowed_types in allowed_interaction_types:
Expand All @@ -406,7 +416,7 @@ def set_allowed_interaction_types(
raise ValueError(
f"interaction {allowed_types} not found in settings"
)
self.allowed_interaction_types = allowed_interaction_types
self.allowed_interaction_types = list(allowed_interaction_types)

def create_problem_sets(self) -> Dict[float, List[ProblemSet]]:
problem_sets = []
Expand Down
2 changes: 1 addition & 1 deletion tests/channels/test_jpsi_to_gamma_pi0_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_number_of_solutions(
initial_state=("J/psi(1S)", [-1, +1]),
final_state=["gamma", "pi0", "pi0"],
particle_db=particle_database,
allowed_interaction_types="strong and EM",
allowed_interaction_types=["strong", "EM"],
allowed_intermediate_particles=allowed_intermediate_particles,
number_of_threads=1,
formalism="helicity",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ def result(request: SubRequest) -> Result:
initial_state=[("J/psi(1S)", [-1, 1])],
final_state=["gamma", "pi0", "pi0"],
allowed_intermediate_particles=["f(0)(980)", "f(0)(1500)"],
allowed_interaction_types="strong only",
allowed_interaction_types="strong",
formalism=formalism,
)
30 changes: 0 additions & 30 deletions tests/unit/test_reaction.py

This file was deleted.

24 changes: 24 additions & 0 deletions tests/unit/test_settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=no-self-use
import pytest

from qrules.particle import ParticleCollection
Expand All @@ -11,6 +12,29 @@
)


class TestInteractionType:
@pytest.mark.parametrize(
("description", "expected"),
[
("EM", InteractionType.EM),
("e", InteractionType.EM),
("electromagnetic", InteractionType.EM),
("w", InteractionType.WEAK),
("weak", InteractionType.WEAK),
("strong", InteractionType.STRONG),
("S", InteractionType.STRONG),
("", ValueError),
("non-existing", ValueError),
],
)
def test_from_str(self, description: str, expected: InteractionType):
if expected is ValueError:
with pytest.raises(ValueError, match=r"interaction type"):
assert InteractionType.from_str(description)
else:
assert InteractionType.from_str(description) == expected


def test_create_domains(particle_database: ParticleCollection):
pdg = particle_database
pions = pdg.filter(lambda p: p.name.startswith("pi"))
Expand Down

0 comments on commit 9fa8183

Please sign in to comment.