diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 834c7eba..225113b0 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -19,7 +19,16 @@ """ from itertools import product -from typing import Dict, FrozenSet, List, Optional, Sequence, Set, Union +from typing import ( + Dict, + FrozenSet, + Iterable, + List, + Optional, + Sequence, + Set, + Union, +) import attr @@ -282,7 +291,7 @@ def generate_transitions( # pylint: disable=too-many-arguments initial_state: Union[StateDefinition, Sequence[StateDefinition]], final_state: Sequence[StateDefinition], allowed_intermediate_particles: Optional[List[str]] = None, - allowed_interaction_types: Optional[Union[str, List[str]]] = None, + allowed_interaction_types: Optional[Union[str, Iterable[str]]] = None, formalism: str = "canonical-helicity", particle_db: Optional[ParticleCollection] = None, mass_conservation_factor: Optional[float] = 3.0, @@ -310,9 +319,9 @@ def generate_transitions( # pylint: disable=too-many-arguments states that you want to allow as intermediate states. This helps (1) filter out resonances and (2) speed up computation time. - allowed_interaction_types (`str`, optional): Interaction types you want - to consider. For instance, both :code:`"strong and EM"` and - :code:`["s", "em"]` results in `~.InteractionType.EM` and + allowed_interaction_types: Interaction types you want to consider. For + instance, :code:`["s", "em"]` results in `~.InteractionType.EM` and + `~.InteractionType.STRONG` and :code:`["strong"]` results in `~.InteractionType.STRONG`. formalism (`str`, optional): Formalism that you intend to use in @@ -352,7 +361,7 @@ def generate_transitions( # pylint: disable=too-many-arguments ... initial_state="D0", ... final_state=["K~0", "K+", "K-"], ... allowed_intermediate_particles=["a(0)(980)", "a(2)(1320)-"], - ... allowed_interaction_types="ew", + ... allowed_interaction_types=["e", "w"], ... formalism="helicity", ... particle_db=qrules.load_pdg(), ... topology_building="isobar", @@ -379,52 +388,20 @@ def generate_transitions( # pylint: disable=too-many-arguments number_of_threads=number_of_threads, ) if allowed_interaction_types is not None: - interaction_types = _determine_interaction_types( - allowed_interaction_types - ) + if isinstance(allowed_interaction_types, str): + interaction_types = [ + InteractionType.from_str(allowed_interaction_types) + ] + else: + interaction_types = [ + InteractionType.from_str(description) + for description in allowed_interaction_types + ] stm.set_allowed_interaction_types(list(interaction_types)) problem_sets = stm.create_problem_sets() return stm.find_solutions(problem_sets) -def _determine_interaction_types( - description: Union[str, List[str]] -) -> Set[InteractionType]: - interaction_types: Set[InteractionType] = set() - if isinstance(description, list): - for i in description: - interaction_types.update( - _determine_interaction_types(description=i) - ) - return interaction_types - if not isinstance(description, str): - raise TypeError( - "Cannot handle interaction description of type " - f"{description.__class__.__name__}" - ) - if len(description) == 0: - raise ValueError('Provided an empty interaction type ("")') - interaction_name_lower = description.lower() - if "all" in interaction_name_lower: - for interaction in InteractionType: - interaction_types.add(interaction) - if ( - "em" in interaction_name_lower - or "ele" in interaction_name_lower - or interaction_name_lower.startswith("e") - ): - interaction_types.add(InteractionType.EM) - if "w" in interaction_name_lower: - interaction_types.add(InteractionType.WEAK) - if "strong" in interaction_name_lower or interaction_name_lower == "s": - interaction_types.add(InteractionType.STRONG) - if len(interaction_types) == 0: - raise ValueError( - f'Could not determine interaction type from "{description}"' - ) - return interaction_types - - def load_default_particles() -> ParticleCollection: """Load the default particle list that comes with `qrules`. diff --git a/src/qrules/settings.py b/src/qrules/settings.py index 323345ca..c479490c 100644 --- a/src/qrules/settings.py +++ b/src/qrules/settings.py @@ -92,6 +92,19 @@ class InteractionType(Enum): EM = auto() WEAK = auto() + @staticmethod + def from_str(description: str) -> "InteractionType": + description_lower = description.lower() + if description_lower.startswith("e"): + return InteractionType.EM + if description_lower.startswith("s"): + return InteractionType.STRONG + if description_lower.startswith("w"): + return InteractionType.WEAK + raise ValueError( + f'Could not determine interaction type from "{description}"' + ) + def create_interaction_settings( # pylint: disable=too-many-locals,too-many-arguments formalism: str, diff --git a/src/qrules/transition.py b/src/qrules/transition.py index eb197ee6..c966a58f 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -6,7 +6,17 @@ from copy import copy, deepcopy from enum import Enum, auto from multiprocessing import Pool -from typing import Dict, List, Optional, Sequence, Set, Tuple, Type, Union +from typing import ( + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) import attr from tqdm.auto import tqdm @@ -392,7 +402,7 @@ def add_final_state_grouping( self.final_state_groupings.append(fs_group) # type: ignore def set_allowed_interaction_types( - self, allowed_interaction_types: List[InteractionType] + self, allowed_interaction_types: Iterable[InteractionType] ) -> None: # verify order for allowed_types in allowed_interaction_types: @@ -406,7 +416,7 @@ def set_allowed_interaction_types( raise ValueError( f"interaction {allowed_types} not found in settings" ) - self.allowed_interaction_types = allowed_interaction_types + self.allowed_interaction_types = list(allowed_interaction_types) def create_problem_sets(self) -> Dict[float, List[ProblemSet]]: problem_sets = [] diff --git a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py index c4d7fbc2..b4031084 100644 --- a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py +++ b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py @@ -31,7 +31,7 @@ def test_number_of_solutions( initial_state=("J/psi(1S)", [-1, +1]), final_state=["gamma", "pi0", "pi0"], particle_db=particle_database, - allowed_interaction_types="strong and EM", + allowed_interaction_types=["strong", "EM"], allowed_intermediate_particles=allowed_intermediate_particles, number_of_threads=1, formalism="helicity", diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 91c2e689..48da0a54 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -17,6 +17,6 @@ def result(request: SubRequest) -> Result: initial_state=[("J/psi(1S)", [-1, 1])], final_state=["gamma", "pi0", "pi0"], allowed_intermediate_particles=["f(0)(980)", "f(0)(1500)"], - allowed_interaction_types="strong only", + allowed_interaction_types="strong", formalism=formalism, ) diff --git a/tests/unit/test_reaction.py b/tests/unit/test_reaction.py deleted file mode 100644 index ed53435f..00000000 --- a/tests/unit/test_reaction.py +++ /dev/null @@ -1,30 +0,0 @@ -import pytest - -from qrules import _determine_interaction_types -from qrules.settings import InteractionType as IT # noqa: N817 - - -@pytest.mark.parametrize( - ("description", "expected"), - [ - ("all", {IT.STRONG, IT.WEAK, IT.EM}), - ("EM", {IT.EM}), - ("electromagnetic", {IT.EM}), - ("electro-weak", {IT.EM, IT.WEAK}), - ("ew", {IT.EM, IT.WEAK}), - ("w", {IT.WEAK}), - ("strong", {IT.STRONG}), - ("only strong", {IT.STRONG}), - ("S", {IT.STRONG}), - (["e", "s", "w"], {IT.STRONG, IT.WEAK, IT.EM}), - ("strong and EM", {IT.STRONG, IT.EM}), - ("", ValueError), - ("non-existing", ValueError), - ], -) -def test_determine_interaction_types(description, expected): - if expected is ValueError: - with pytest.raises(ValueError, match=r"interaction type"): - assert _determine_interaction_types(description) - else: - assert _determine_interaction_types(description) == expected diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index a30d36a6..608b268d 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -1,3 +1,4 @@ +# pylint: disable=no-self-use import pytest from qrules.particle import ParticleCollection @@ -11,6 +12,29 @@ ) +class TestInteractionType: + @pytest.mark.parametrize( + ("description", "expected"), + [ + ("EM", InteractionType.EM), + ("e", InteractionType.EM), + ("electromagnetic", InteractionType.EM), + ("w", InteractionType.WEAK), + ("weak", InteractionType.WEAK), + ("strong", InteractionType.STRONG), + ("S", InteractionType.STRONG), + ("", ValueError), + ("non-existing", ValueError), + ], + ) + def test_from_str(self, description: str, expected: InteractionType): + if expected is ValueError: + with pytest.raises(ValueError, match=r"interaction type"): + assert InteractionType.from_str(description) + else: + assert InteractionType.from_str(description) == expected + + def test_create_domains(particle_database: ParticleCollection): pdg = particle_database pions = pdg.filter(lambda p: p.name.startswith("pi"))