From fe764e122b01be818d25c06fd356432c980f8f6b Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Fri, 18 Jun 2021 16:50:22 +0200 Subject: [PATCH] refactor!: implement StateTransition classes (#75) * feat: define ReactionInfo, State, StateTransition, and StateTransitionCollection classes * feat: implement as/fromdict for ReactionInfo etc * feat: provide pretty printer for FrozenDict * feat: provide pretty printer for Topology * fix: small improvement to Spin error message * refactor!: remove Result class * refactor: extract _dot.__strip_spin * style: clean up ignore statements on top * style: update notebook language version --- docs/abbreviate_signature.py | 6 +- docs/usage.ipynb | 6 +- docs/usage/particle.ipynb | 2 +- docs/usage/reaction.ipynb | 56 +-- docs/usage/visualize.ipynb | 20 +- src/qrules/__init__.py | 11 +- src/qrules/io/__init__.py | 37 +- src/qrules/io/_dict.py | 67 ++- src/qrules/io/_dot.py | 58 ++- src/qrules/particle.py | 11 +- src/qrules/solving.py | 1 - src/qrules/topology.py | 71 +++- src/qrules/transition.py | 395 +++++++++++++++--- tests/channels/test_d0_to_ks_kp_km.py | 9 +- tests/channels/test_jpsi_to_gamma_pi0_pi0.py | 46 +- tests/channels/test_y_to_d0_d0bar_pi0_pi0.py | 10 +- tests/unit/conftest.py | 4 +- .../conservation_rules/test_duck_typing.py | 1 - tests/unit/io/test_dot.py | 68 +-- tests/unit/io/test_io.py | 14 +- tests/unit/test_parity_prefactor.py | 20 +- tests/unit/test_solving.py | 8 - tests/unit/test_topology.py | 8 +- tests/unit/test_transition.py | 189 ++++++++- 24 files changed, 856 insertions(+), 262 deletions(-) delete mode 100644 tests/unit/test_solving.py diff --git a/docs/abbreviate_signature.py b/docs/abbreviate_signature.py index 48eca499..86c19ce2 100644 --- a/docs/abbreviate_signature.py +++ b/docs/abbreviate_signature.py @@ -1,3 +1,6 @@ +# cspell:ignore docutils +# pylint: disable=import-error +# pyright: reportMissingImports=false """Abbreviated the annotations generated by sphinx-autodoc. It's not necessary to generate the full path of type hints, because they are @@ -6,9 +9,6 @@ See also https://github.com/sphinx-doc/sphinx/issues/5868. """ -# cspell:ignore docutils -# pylint: disable=import-error -# pyright: reportMissingImports=false import sphinx.domains.python from docutils import nodes from sphinx import addnodes diff --git a/docs/usage.ipynb b/docs/usage.ipynb index 784337d1..ab61b79f 100644 --- a/docs/usage.ipynb +++ b/docs/usage.ipynb @@ -88,7 +88,7 @@ "source": [ "import qrules\n", "\n", - "result = qrules.generate_transitions(\n", + "reaction = qrules.generate_transitions(\n", " initial_state=\"J/psi(1S)\",\n", " final_state=[\"K0\", \"Sigma+\", \"p~\"],\n", " allowed_interaction_types=\"strong\",\n", @@ -104,7 +104,7 @@ "source": [ "import graphviz\n", "\n", - "dot = qrules.io.asdot(result, collapse_graphs=True)\n", + "dot = qrules.io.asdot(reaction, collapse_graphs=True)\n", "graphviz.Source(dot)" ] }, @@ -271,7 +271,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/particle.ipynb b/docs/usage/particle.ipynb index a74ca58b..a8a5c658 100644 --- a/docs/usage/particle.ipynb +++ b/docs/usage/particle.ipynb @@ -467,7 +467,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/reaction.ipynb b/docs/usage/reaction.ipynb index 4f861ab5..938601ce 100644 --- a/docs/usage/reaction.ipynb +++ b/docs/usage/reaction.ipynb @@ -215,14 +215,14 @@ "metadata": {}, "outputs": [], "source": [ - "result = stm.find_solutions(problem_sets)" + "reaction = stm.find_solutions(problem_sets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The {meth}`~.StateTransitionManager.find_solutions` method returns a {class}`.Result` object from which you can extract the {attr}`~.Result.transitions`. Now, you can use {meth}`~.Result.get_intermediate_particles` to print the names of the intermediate states that the {class}`.StateTransitionManager` found:" + "The {meth}`~.StateTransitionManager.find_solutions` method returns a {class}`.ReactionInfo` object from which you can extract the {attr}`~.ReactionInfo.transitions`. Now, you can use {meth}`~.ReactionInfo.get_intermediate_particles` to print the names of the intermediate states that the {class}`.StateTransitionManager` found:" ] }, { @@ -231,8 +231,8 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"found\", len(result.transitions), \"solutions!\")\n", - "result.get_intermediate_particles().names" + "print(\"found\", len(reaction.transitions), \"solutions!\")\n", + "reaction.get_intermediate_particles().names" ] }, { @@ -244,9 +244,9 @@ "class: dropdown\n", "----\n", "\n", - "The \"number of {attr}`~.Result.transitions`\" is the total number of allowed {obj}`.StateTransitionGraph` instances that the {class}`.StateTransitionManager` has found. This also includes all allowed **spin projection combinations**. In this channel, we for example consider a $J/\\psi$ with spin projection $\\pm1$ that decays into a $\\gamma$ with spin projection $\\pm1$, which already gives us four possibilities.\n", + "The \"number of {attr}`~.ReactionInfo.transitions`\" is the total number of allowed {obj}`.StateTransitionGraph` instances that the {class}`.StateTransitionManager` has found. This also includes all allowed **spin projection combinations**. In this channel, we for example consider a $J/\\psi$ with spin projection $\\pm1$ that decays into a $\\gamma$ with spin projection $\\pm1$, which already gives us four possibilities.\n", "\n", - "On the other hand, the intermediate state names that was extracted with {meth}`.Result.get_intermediate_particles`, is just a {obj}`set` of the state names on the intermediate edges of the list of {attr}`~.Result.transitions`, regardless of spin projection.\n", + "On the other hand, the intermediate state names that was extracted with {meth}`.ReactionInfo.get_intermediate_particles`, is just a {obj}`set` of the state names on the intermediate edges of the list of {attr}`~.ReactionInfo.transitions`, regardless of spin projection.\n", "````" ] }, @@ -272,10 +272,10 @@ "source": [ "stm.set_allowed_interaction_types([InteractionType.STRONG])\n", "problem_sets = stm.create_problem_sets()\n", - "result = stm.find_solutions(problem_sets)\n", + "reaction = stm.find_solutions(problem_sets)\n", "\n", - "print(\"found\", len(result.transitions), \"solutions!\")\n", - "result.get_intermediate_particles().names" + "print(\"found\", len(reaction.transitions), \"solutions!\")\n", + "reaction.get_intermediate_particles().names" ] }, { @@ -297,10 +297,10 @@ "source": [ "stm.set_allowed_interaction_types([InteractionType.STRONG, InteractionType.EM])\n", "problem_sets = stm.create_problem_sets()\n", - "result = stm.find_solutions(problem_sets)\n", + "reaction = stm.find_solutions(problem_sets)\n", "\n", - "print(\"found\", len(result.transitions), \"solutions!\")\n", - "result.get_intermediate_particles().names" + "print(\"found\", len(reaction.transitions), \"solutions!\")\n", + "reaction.get_intermediate_particles().names" ] }, { @@ -339,10 +339,10 @@ "# i.e. f2 will find all f2's and f all f's independent of their spin\n", "stm.set_allowed_intermediate_particles([\"f(0)\", \"f(2)\"])\n", "\n", - "result = stm.find_solutions(problem_sets)\n", + "reaction = stm.find_solutions(problem_sets)\n", "\n", - "print(\"found\", len(result.transitions), \"solutions!\")\n", - "result.get_intermediate_particles().names" + "print(\"found\", len(reaction.transitions), \"solutions!\")\n", + "reaction.get_intermediate_particles().names" ] }, { @@ -366,7 +366,7 @@ "\n", "from qrules import io\n", "\n", - "dot = io.asdot(result, collapse_graphs=True, render_node=False)\n", + "dot = io.asdot(reaction, collapse_graphs=True, render_node=False)\n", "graphviz.Source(dot)" ] }, @@ -392,7 +392,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {class}`.Result`, {class}`.StateTransitionGraph`, and {class}`.Topology` can be serialized to and from a {obj}`dict` with {func}`.io.asdict` and {func}`.io.fromdict`:" + "The {class}`.ReactionInfo`, {class}`.StateTransitionGraph`, and {class}`.Topology` can be serialized to and from a {obj}`dict` with {func}`.io.asdict` and {func}`.io.fromdict`:" ] }, { @@ -403,8 +403,7 @@ "source": [ "from qrules import io\n", "\n", - "graph = result.transitions[0]\n", - "io.asdict(graph.topology)" + "io.asdict(reaction.transition_groups[0].topology)" ] }, { @@ -415,7 +414,7 @@ "YAML is more human-readable than JSON, but reading and writing JSON is faster.\n", "```\n", "\n", - "This also means that the {obj}`.Result` can be written to JSON or YAML format with {func}`.io.write` and loaded again with {func}`.io.load`:" + "This also means that the {obj}`.ReactionInfo` can be written to JSON or YAML format with {func}`.io.write` and loaded again with {func}`.io.load`:" ] }, { @@ -424,9 +423,9 @@ "metadata": {}, "outputs": [], "source": [ - "io.write(result, \"transitions.json\")\n", - "imported_result = io.load(\"transitions.json\")\n", - "assert imported_result == result" + "io.write(reaction, \"transitions.json\")\n", + "imported_reaction = io.load(\"transitions.json\")\n", + "assert imported_reaction == reaction" ] }, { @@ -436,15 +435,6 @@ "Handy if it takes a lot of computation time to re-generate the transitions!" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{warning}\n", - "It's not possible to {mod}`pickle` a {class}`.Result`, because {class}`.StateTransitionGraph` makes use of {class}`~typing.Generic`.\n", - "```" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -472,7 +462,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/visualize.ipynb b/docs/usage/visualize.ipynb index 53fd630c..9895697b 100644 --- a/docs/usage/visualize.ipynb +++ b/docs/usage/visualize.ipynb @@ -51,7 +51,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The {mod}`~qrules.io` module allows you to convert {class}`.StateTransitionGraph` and {class}`.Topology` instances to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.Result` object with a {class}`.list` of {class}`.StateTransitionGraph` instances (see {doc}`/usage/reaction`)." + "The {mod}`~qrules.io` module allows you to convert {class}`.StateTransitionGraph` and {class}`.Topology` instances to [DOT language](https://graphviz.org/doc/info/lang.html) with {func}`.asdot`. You can visualize its output with third-party libraries, such as [Graphviz](https://graphviz.org). This is particularly useful after running {meth}`~.StateTransitionManager.find_solutions`, which produces a {class}`.ReactionInfo` object with a {class}`.list` of {class}`.StateTransitionGraph` instances (see {doc}`/usage/reaction`)." ] }, { @@ -173,7 +173,7 @@ "source": [ "import qrules\n", "\n", - "result = qrules.generate_transitions(\n", + "reaction = qrules.generate_transitions(\n", " initial_state=\"psi(2S)\",\n", " final_state=[\"gamma\", \"eta\", \"eta\"],\n", " allowed_interaction_types=\"EM\",\n", @@ -184,7 +184,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "As noted in {ref}`usage/reaction:3. Find solutions`, the {attr}`~.Result.transitions` contain all spin projection combinations (which is necessary for the {mod}`ampform` package). It is possible to convert all these solutions to DOT language with {func}`~.asdot`. To avoid visualizing all solutions, we just take a subset of the {attr}`~.Result.transitions`:" + "As noted in {ref}`usage/reaction:3. Find solutions`, the {attr}`~.ReactionInfo.transitions` contain all spin projection combinations (which is necessary for the {mod}`ampform` package). It is possible to convert all these solutions to DOT language with {func}`~.asdot`. To avoid visualizing all solutions, we just take a subset of the {attr}`~.ReactionInfo.transitions`:" ] }, { @@ -193,7 +193,7 @@ "metadata": {}, "outputs": [], "source": [ - "dot = qrules.io.asdot(result.transitions[::50][:3]) # just some selection" + "dot = qrules.io.asdot(reaction.transitions[::50][:3]) # just some selection" ] }, { @@ -223,7 +223,7 @@ "import graphviz\n", "\n", "dot = qrules.io.asdot(\n", - " result.transitions[::50][:3], render_node=False\n", + " reaction.transitions[::50][:3], render_node=False\n", ") # just some selection\n", "graphviz.Source(dot)" ] @@ -241,7 +241,7 @@ "metadata": {}, "outputs": [], "source": [ - "qrules.io.write(result, \"decay_topologies_with_spin.gv\")" + "qrules.io.write(reaction, \"decay_topologies_with_spin.gv\")" ] }, { @@ -255,7 +255,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Since this list of all possible spin projections {attr}`~.Result.transitions` is rather long, it is often useful to use `strip_spin=True` or `collapse_graphs=True` to bundle comparable graphs. First, {code}`strip_spin=True` allows one collapse (ignore) the spin projections (we again show a selection only):" + "Since this list of all possible spin projections {attr}`~.ReactionInfo.transitions` is rather long, it is often useful to use `strip_spin=True` or `collapse_graphs=True` to bundle comparable graphs. First, {code}`strip_spin=True` allows one collapse (ignore) the spin projections (we again show a selection only):" ] }, { @@ -264,7 +264,7 @@ "metadata": {}, "outputs": [], "source": [ - "dot = qrules.io.asdot(result.transitions[:3], strip_spin=True)\n", + "dot = qrules.io.asdot(reaction.transitions[:3], strip_spin=True)\n", "graphviz.Source(dot)" ] }, @@ -290,7 +290,7 @@ "metadata": {}, "outputs": [], "source": [ - "dot = qrules.io.asdot(result, collapse_graphs=True, render_node=False)\n", + "dot = qrules.io.asdot(reaction, collapse_graphs=True, render_node=False)\n", "graphviz.Source(dot)" ] } @@ -311,7 +311,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/src/qrules/__init__.py b/src/qrules/__init__.py index 225113b0..d046da69 100644 --- a/src/qrules/__init__.py +++ b/src/qrules/__init__.py @@ -1,5 +1,4 @@ # pylint: disable=too-many-lines - """A rule based system that facilitates particle reaction analysis. QRules generates allowed particle transitions from a set of conservation rules @@ -74,7 +73,7 @@ from .transition import ( EdgeSettings, ProblemSet, - Result, + ReactionInfo, StateTransitionManager, ) @@ -299,7 +298,7 @@ def generate_transitions( # pylint: disable=too-many-arguments max_spin_magnitude: float = 2.0, topology_building: str = "isobar", number_of_threads: Optional[int] = None, -) -> Result: +) -> ReactionInfo: """Generate allowed transitions between an initial and final state. Serves as a facade to the `.StateTransitionManager` (see @@ -357,7 +356,7 @@ def generate_transitions( # pylint: disable=too-many-arguments arguments) would be: >>> import qrules - >>> result = qrules.generate_transitions( + >>> reaction = qrules.generate_transitions( ... initial_state="D0", ... final_state=["K~0", "K+", "K-"], ... allowed_intermediate_particles=["a(0)(980)", "a(2)(1320)-"], @@ -366,7 +365,9 @@ def generate_transitions( # pylint: disable=too-many-arguments ... particle_db=qrules.load_pdg(), ... topology_building="isobar", ... ) - >>> len(result.transitions) + >>> len(reaction.transition_groups) + 3 + >>> len(reaction.transitions) 4 """ if isinstance(initial_state, str) or ( diff --git a/src/qrules/io/__init__.py b/src/qrules/io/__init__.py index 99d19aee..e51f2971 100644 --- a/src/qrules/io/__init__.py +++ b/src/qrules/io/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable=too-many-return-statements """Serialization module for the `qrules`. The `.io` module provides tools to export or import objects from `qrules` to @@ -14,18 +15,32 @@ from qrules.particle import Particle, ParticleCollection from qrules.topology import StateTransitionGraph, Topology -from qrules.transition import Result +from qrules.transition import ( + ReactionInfo, + State, + StateTransition, + StateTransitionCollection, +) from . import _dict, _dot def asdict(instance: object) -> dict: + # pylint: disable=protected-access if isinstance(instance, Particle): return _dict.from_particle(instance) if isinstance(instance, ParticleCollection): return _dict.from_particle_collection(instance) - if isinstance(instance, Result): - return _dict.from_result(instance) + if isinstance( + instance, + (ReactionInfo, State, StateTransition, StateTransitionCollection), + ): + return attr.asdict( + instance, + recurse=True, + filter=lambda attr, _: attr.init, + value_serializer=_dict._value_serializer, + ) if isinstance(instance, StateTransitionGraph): return _dict.from_stg(instance) if isinstance(instance, Topology): @@ -41,8 +56,12 @@ def fromdict(definition: dict) -> object: return _dict.build_particle(definition) if keys == {"particles"}: return _dict.build_particle_collection(definition) - if keys == {"transitions", "formalism"}: - return _dict.build_result(definition) + if keys == {"transition_groups", "formalism"}: + return _dict.build_reaction_info(definition) + if keys == {"topology", "states", "interactions"}: + return _dict.build_state_transition(definition) + if keys == {"transitions"}: + return _dict.build_stc(definition) if keys == {"topology", "edge_props", "node_props"}: return _dict.build_stg(definition) if keys == __REQUIRED_TOPOLOGY_FIELDS: @@ -105,6 +124,8 @@ def asdot( .. seealso:: :doc:`/usage/visualize` """ + if isinstance(instance, StateTransition): + instance = instance.to_graph() if isinstance(instance, (StateTransitionGraph, Topology)): return _dot.graph_to_dot( instance, @@ -113,9 +134,9 @@ def asdot( render_resonance_id=render_resonance_id, render_initial_state_id=render_initial_state_id, ) - if isinstance(instance, (Result, abc.Sequence)): - if isinstance(instance, Result): - instance = instance.transitions + if isinstance(instance, (ReactionInfo, StateTransitionCollection)): + instance = instance.to_graphs() + if isinstance(instance, abc.Sequence): return _dot.graph_list_to_dot( instance, render_node=render_node, diff --git a/src/qrules/io/_dict.py b/src/qrules/io/_dict.py index feabea48..bee8b83c 100644 --- a/src/qrules/io/_dict.py +++ b/src/qrules/io/_dict.py @@ -17,7 +17,12 @@ ) from qrules.quantum_numbers import InteractionProperties from qrules.topology import Edge, StateTransitionGraph, Topology -from qrules.transition import Result +from qrules.transition import ( + ReactionInfo, + State, + StateTransition, + StateTransitionCollection, +) def from_particle_collection(particles: ParticleCollection) -> dict: @@ -28,20 +33,11 @@ def from_particle(particle: Particle) -> dict: return attr.asdict( particle, recurse=True, - value_serializer=__value_serializer, + value_serializer=_value_serializer, filter=lambda attr, value: attr.default != value, ) -def from_result(result: Result) -> dict: - output: Dict[str, Any] = { - "transitions": [from_stg(graph) for graph in result.transitions], - } - if result.formalism is not None: - output["formalism"] = result.formalism - return output - - def from_stg(graph: StateTransitionGraph[ParticleWithSpin]) -> dict: topology = graph.topology edge_props_def = {} @@ -70,19 +66,21 @@ def from_topology(topology: Topology) -> dict: return attr.asdict( topology, recurse=True, - value_serializer=__value_serializer, + value_serializer=_value_serializer, filter=lambda a, v: a.init and a.default != v, ) -def __value_serializer( # pylint: disable=unused-argument +def _value_serializer( # pylint: disable=unused-argument inst: type, field: attr.Attribute, value: Any ) -> Any: if isinstance(value, abc.Mapping): if all(map(lambda p: isinstance(p, Particle), value.values())): return {k: v.name for k, v in value.items()} return dict(value) - if isinstance(value, Particle): + if not isinstance( + inst, (ReactionInfo, State, StateTransition, StateTransitionCollection) + ) and isinstance(value, Particle): return value.name if isinstance(value, Parity): return {"value": value.value} @@ -115,14 +113,13 @@ def build_particle(definition: dict) -> Particle: return Particle(**definition) -def build_result(definition: dict) -> Result: - formalism = definition.get("formalism") - transitions = [ - build_stg(graph_def) for graph_def in definition["transitions"] +def build_reaction_info(definition: dict) -> ReactionInfo: + transition_groups = [ + build_stc(graph_def) for graph_def in definition["transition_groups"] ] - return Result( - transitions=transitions, - formalism=formalism, + return ReactionInfo( + transition_groups=transition_groups, + formalism=definition["formalism"], ) @@ -148,6 +145,34 @@ def build_stg(definition: dict) -> StateTransitionGraph[ParticleWithSpin]: ) +def build_stc(definition: dict) -> StateTransitionCollection: + transitions = [ + build_state_transition(graph_def) + for graph_def in definition["transitions"] + ] + return StateTransitionCollection(transitions=transitions) + + +def build_state_transition(definition: dict) -> StateTransition: + topology = build_topology(definition["topology"]) + states = { + int(i): State( + particle=build_particle(state_def["particle"]), + spin_projection=float(state_def["spin_projection"]), + ) + for i, state_def in definition["states"].items() + } + interactions = { + int(i): InteractionProperties(**interaction_def) + for i, interaction_def in definition["interactions"].items() + } + return StateTransition( + topology=topology, + states=states, + interactions=interactions, + ) + + def build_topology(definition: dict) -> Topology: nodes = definition["nodes"] edges_def: Dict[int, dict] = definition["edges"] diff --git a/src/qrules/io/_dot.py b/src/qrules/io/_dot.py index 81b324ed..389ade32 100644 --- a/src/qrules/io/_dot.py +++ b/src/qrules/io/_dot.py @@ -8,6 +8,7 @@ from qrules.particle import Particle, ParticleCollection, ParticleWithSpin from qrules.quantum_numbers import InteractionProperties, _to_fraction from qrules.topology import StateTransitionGraph, Topology +from qrules.transition import StateTransition _DOT_HEAD = """digraph { rankdir=LR; @@ -100,7 +101,7 @@ def __graph_to_dot_content( # pylint: disable=too-many-locals,too-many-branches render_initial_state_id: bool, ) -> str: dot = "" - if isinstance(graph, StateTransitionGraph): + if isinstance(graph, (StateTransition, StateTransitionGraph)): topology = graph.topology elif isinstance(graph, Topology): topology = graph @@ -170,6 +171,8 @@ def __get_edge_label( edge_id: int, render_edge_id: bool, ) -> str: + if isinstance(graph, StateTransition): + graph = graph.to_graph() if isinstance(graph, StateTransitionGraph): edge_prop = graph.get_edge_props(edge_id) if not edge_prop: @@ -239,6 +242,8 @@ def _get_particle_graphs( """ inventory: List[StateTransitionGraph[Particle]] = [] for transition in graphs: + if isinstance(transition, StateTransition): + transition = transition.to_graph() if any( transition.compare( other, edge_comparator=lambda e1, e2: e1[0] == e2 @@ -246,28 +251,8 @@ def _get_particle_graphs( for other in inventory ): continue - new_edge_props = {} - for edge_id in transition.topology.edges: - edge_props = transition.get_edge_props(edge_id) - if edge_props: - new_edge_props[edge_id] = edge_props[0] - inventory.append( - StateTransitionGraph[Particle]( - topology=transition.topology, - node_props={ - i: node_props - for i, node_props in zip( - transition.topology.nodes, - map( - transition.get_node_props, - transition.topology.nodes, - ), - ) - if node_props - }, - edge_props=new_edge_props, - ) - ) + stripped_graph = __strip_spin(transition) + inventory.append(stripped_graph) inventory = sorted( inventory, key=lambda g: [ @@ -277,6 +262,33 @@ def _get_particle_graphs( return inventory +def __strip_spin( + graph: StateTransitionGraph[ParticleWithSpin], +) -> StateTransitionGraph[Particle]: + if isinstance(graph, StateTransition): + graph = graph.to_graph() + new_edge_props = {} + for edge_id in graph.topology.edges: + edge_props = graph.get_edge_props(edge_id) + if edge_props: + new_edge_props[edge_id] = edge_props[0] + return StateTransitionGraph[Particle]( + topology=graph.topology, + node_props={ + i: node_props + for i, node_props in zip( + graph.topology.nodes, + map( + graph.get_node_props, + graph.topology.nodes, + ), + ) + if node_props + }, + edge_props=new_edge_props, + ) + + def _collapse_graphs( graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], ) -> List[StateTransitionGraph[ParticleCollection]]: diff --git a/src/qrules/particle.py b/src/qrules/particle.py index 8ec14886..9b3d3309 100644 --- a/src/qrules/particle.py +++ b/src/qrules/particle.py @@ -1,9 +1,10 @@ """A collection of particle info containers. -The `.particle` module is the starting point of `qrules`. Its main interface is -the `ParticleCollection`, which is a collection of immutable `Particle` -instances that are uniquely defined by their properties. As such, it can be -used stand-alone as a database of quantum numbers (see :doc:`/usage/particle`). +The :mod:`.particle` module is the starting point of `qrules`. Its main +interface is the `ParticleCollection`, which is a collection of immutable +`Particle` instances that are uniquely defined by their properties. As such, it +can be used stand-alone as a database of quantum numbers (see +:doc:`/usage/particle`). The `.transition` module uses the properties of `Particle` instances when it computes which `.StateTransitionGraph` s are allowed between an initial state @@ -67,7 +68,7 @@ def __attrs_post_init__(self) -> None: if abs(self.projection) > self.magnitude: if self.magnitude < 0.0: raise ValueError( - "Spin magnitude has to be positive:\n" f" {self.magnitude}" + f"Spin magnitude has to be positive, but is {self.magnitude}" ) raise ValueError( "Absolute value of spin projection cannot be larger than its " diff --git a/src/qrules/solving.py b/src/qrules/solving.py index ff3205ee..dd5833b5 100644 --- a/src/qrules/solving.py +++ b/src/qrules/solving.py @@ -1,5 +1,4 @@ # pylint: disable=too-many-lines - """Functions to solve a particle reaction problem. This module is responsible for solving a particle reaction problem stated by a diff --git a/src/qrules/topology.py b/src/qrules/topology.py index b8ff444f..1c4e84d0 100644 --- a/src/qrules/topology.py +++ b/src/qrules/topology.py @@ -11,8 +11,11 @@ import copy import itertools import logging +from abc import abstractmethod from collections import abc +from functools import total_ordering from typing import ( + Any, Callable, Collection, Dict, @@ -36,12 +39,30 @@ from .quantum_numbers import InteractionProperties -KeyType = TypeVar("KeyType") +try: + from typing import Protocol +except ImportError: + from typing_extensions import Protocol # type: ignore + +try: + from IPython.lib.pretty import PrettyPrinter +except ImportError: + PrettyPrinter = Any + + +class Comparable(Protocol): + @abstractmethod + def __lt__(self, other: Any) -> bool: + ... + + +KeyType = TypeVar("KeyType", bound=Comparable) """Type the keys of the `~typing.Mapping`, see `~typing.KeysView`.""" ValueType = TypeVar("ValueType") """Type the value of the `~typing.Mapping`, see `~typing.ValuesView`.""" +@total_ordering class FrozenDict( # pylint: disable=too-many-ancestors Generic[KeyType, ValueType], abc.Hashable, abc.Mapping ): @@ -58,6 +79,20 @@ def __init__(self, mapping: Optional[Mapping] = None): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.__mapping})" + def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None: + class_name = type(self).__name__ + if cycle: + p.text(f"{class_name}(...)") + else: + with p.group(indent=2, open=f"{class_name}({{"): + for key, value in self.items(): + p.breakable() + p.text(f"{key}: ") + p.pretty(value) + p.text(",") + p.breakable() + p.text("})") + def __iter__(self) -> Iterator[KeyType]: return iter(self.__mapping) @@ -67,6 +102,17 @@ def __len__(self) -> int: def __getitem__(self, key: KeyType) -> ValueType: return self.__mapping[key] + def __gt__(self, other: Any) -> bool: + if isinstance(other, abc.Mapping): + sorted_self = _convert_mapping_to_sorted_tuple(self) + sorted_other = _convert_mapping_to_sorted_tuple(other) + return sorted_self > sorted_other + + raise NotImplementedError( + f"Can only compare {self.__class__.__name__} with a mapping," + f" not with {other.__class__.__name__}" + ) + def __hash__(self) -> int: return self.__hash @@ -80,6 +126,12 @@ def values(self) -> ValuesView[ValueType]: return self.__mapping.values() +def _convert_mapping_to_sorted_tuple( + mapping: Mapping[KeyType, ValueType], +) -> Tuple[Tuple[KeyType, ValueType], ...]: + return tuple((key, mapping[key]) for key in sorted(mapping.keys())) + + def _to_optional_int(optional_int: Optional[int]) -> Optional[int]: if optional_int is None: return None @@ -189,6 +241,23 @@ def __get_surrounding_nodes(self, node_id: int) -> Set[int]: surrounding_nodes.discard(node_id) return surrounding_nodes + def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None: + class_name = type(self).__name__ + if cycle: + p.text(f"{class_name}(...)") + else: + with p.group(indent=2, open=f"{class_name}("): + for field in attr.fields(type(self)): + if not field.init: + continue + value = getattr(self, field.name) + p.breakable() + p.text(f"{field.name}=") + p.pretty(value) + p.text(",") + p.breakable() + p.text(")") + def is_isomorphic(self, other: "Topology") -> bool: """Check if two graphs are isomorphic. diff --git a/src/qrules/transition.py b/src/qrules/transition.py index c966a58f..885be9e1 100644 --- a/src/qrules/transition.py +++ b/src/qrules/transition.py @@ -2,14 +2,21 @@ import logging import multiprocessing -from collections import defaultdict +from collections import abc, defaultdict from copy import copy, deepcopy from enum import Enum, auto +from itertools import zip_longest from multiprocessing import Pool from typing import ( + Any, + Collection, + DefaultDict, Dict, + FrozenSet, Iterable, + Iterator, List, + Mapping, Optional, Sequence, Set, @@ -19,6 +26,7 @@ ) import attr +from attr.validators import instance_of from tqdm.auto import tqdm from ._system_control import ( @@ -38,7 +46,13 @@ create_initial_facts, match_external_edges, ) -from .particle import Particle, ParticleCollection, ParticleWithSpin, load_pdg +from .particle import ( + Particle, + ParticleCollection, + ParticleWithSpin, + _to_float, + load_pdg, +) from .quantum_numbers import ( EdgeQuantumNumber, EdgeQuantumNumbers, @@ -58,12 +72,18 @@ QNResult, ) from .topology import ( + FrozenDict, StateTransitionGraph, Topology, create_isobar_topologies, create_n_body_topology, ) +try: + from IPython.lib.pretty import PrettyPrinter +except ImportError: + PrettyPrinter = Any + class SolvingMode(Enum): """Types of modes for solving.""" @@ -150,61 +170,12 @@ def extend( ) -@attr.s(on_setattr=attr.setters.frozen) -class Result: - transitions: List[StateTransitionGraph[ParticleWithSpin]] = attr.ib( - factory=list - ) - formalism: Optional[str] = attr.ib(default=None) - - def get_initial_state(self) -> List[Particle]: - graph = self.__get_first_graph() - return [ - x[0] - for x in map( - graph.get_edge_props, graph.topology.incoming_edge_ids - ) - if x - ] - - def get_final_state(self) -> List[Particle]: - graph = self.__get_first_graph() - return [ - x[0] - for x in map( - graph.get_edge_props, graph.topology.outgoing_edge_ids - ) - if x - ] - - def __get_first_graph(self) -> StateTransitionGraph[ParticleWithSpin]: - if len(self.transitions) == 0: - raise ValueError( - f"No solutions in {self.__class__.__name__} object" - ) - return next(iter(self.transitions)) - - def get_intermediate_particles(self) -> ParticleCollection: - """Extract the names of the intermediate state particles.""" - intermediate_states = ParticleCollection() - for transition in self.transitions: - for edge_props in map( - transition.get_edge_props, - transition.topology.intermediate_edge_ids, - ): - if edge_props: - particle, _ = edge_props - if particle not in intermediate_states: - intermediate_states.add(particle) - return intermediate_states - - @attr.s class ProblemSet: """Particle reaction problem set, defined as a graph like data structure. Args: - topology: `~.Topology` that contains the structure of the reaction. + topology: `.Topology` that contains the structure of the reaction. initial_facts: `~.InitialFacts` that contain the info of initial and final state in connection with the topology. solving_settings: Solving related settings such as the conservation @@ -558,7 +529,7 @@ def create_edge_settings(edge_id: int) -> EdgeSettings: def find_solutions( # pylint: disable=too-many-branches self, problem_sets: Dict[float, List[ProblemSet]], - ) -> Result: + ) -> "ReactionInfo": # pylint: disable=too-many-locals """Check for solutions for a specific set of interaction settings.""" results: Dict[float, _SolutionContainer] = {} @@ -663,10 +634,7 @@ def find_solutions( # pylint: disable=too-many-branches raise ValueError("No solutions were found") match_external_edges(final_solutions) - return Result( - final_solutions, - formalism=self.formalism, - ) + return ReactionInfo.from_graphs(final_solutions, self.formalism) def _solve( self, qn_problem_set: QNProblemSet @@ -678,7 +646,7 @@ def _solve( def __convert_result( self, topology: Topology, qn_result: QNResult ) -> _SolutionContainer: - """Converts a `.QNResult` with a `.Topology` into a `.Result`. + """Converts a `.QNResult` with a `.Topology` into `.ReactionInfo`. The ParticleCollection is used to retrieve a particle instance reference to lower the memory footprint. @@ -707,3 +675,314 @@ def __convert_result( not_executed_edge_rules=qn_result.not_executed_edge_rules, ), ) + + +@attr.s(frozen=True) +class State: + particle: Particle = attr.ib(validator=instance_of(Particle)) + spin_projection: float = attr.ib(converter=_to_float) + + def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None: + class_name = type(self).__name__ + if cycle: + p.text(f"{class_name}(...)") + else: + with p.group(indent=2, open=f"{class_name}("): + for field in attr.fields(type(self)): + value = getattr(self, field.name) + p.breakable() + p.text(f"{field.name}=") + p.pretty(value) + p.text(",") + p.breakable() + p.text(")") + + +@attr.s(frozen=True) +class StateTransition: + """Frozen instance of a `.StateTransitionGraph` of `.Particle` with spin.""" + + topology: Topology = attr.ib(validator=instance_of(Topology)) + states: FrozenDict[int, State] = attr.ib(converter=FrozenDict) + interactions: FrozenDict[int, InteractionProperties] = attr.ib( + converter=FrozenDict + ) + + def __attrs_post_init__(self) -> None: + _assert_defined(self.topology.edges, self.states) + _assert_defined(self.topology.nodes, self.interactions) + + def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None: + class_name = type(self).__name__ + if cycle: + p.text(f"{class_name}(...)") + else: + with p.group(indent=2, open=f"{class_name}("): + for field in attr.fields(type(self)): + value = getattr(self, field.name) + p.breakable() + p.text(f"{field.name}=") + p.pretty(value) + p.text(",") + p.breakable() + p.text(")") + + @staticmethod + def from_graph( + graph: StateTransitionGraph[ParticleWithSpin], + ) -> "StateTransition": + return StateTransition( + topology=graph.topology, + states=FrozenDict( + { + i: State(*graph.get_edge_props(i)) + for i in graph.topology.edges + } + ), + interactions=FrozenDict( + {i: graph.get_node_props(i) for i in graph.topology.nodes} + ), + ) + + def to_graph(self) -> StateTransitionGraph[ParticleWithSpin]: + return StateTransitionGraph[ParticleWithSpin]( + topology=self.topology, + edge_props={ + i: (state.particle, state.spin_projection) + for i, state in self.states.items() + }, + node_props=self.interactions, + ) + + @property + def initial_states(self) -> Dict[int, State]: + return self.filter_states(self.topology.incoming_edge_ids) + + @property + def final_states(self) -> Dict[int, State]: + return self.filter_states(self.topology.outgoing_edge_ids) + + @property + def intermediate_states(self) -> Dict[int, State]: + return self.filter_states(self.topology.intermediate_edge_ids) + + def filter_states(self, edge_ids: Iterable[int]) -> Dict[int, State]: + return {i: self.states[i] for i in edge_ids} + + @property + def particles(self) -> Dict[int, Particle]: + return {i: edge_prop.particle for i, edge_prop in self.states.items()} + + +def _assert_defined(items: Collection, properties: Mapping) -> None: + existing = set(items) + defined = set(properties) + if existing & defined != existing: + raise ValueError( + "Some items have no property assigned to them." + f" Available items: {existing}, items with property: {defined}" + ) + + +def _to_frozenset( + iterable: Iterable[StateTransition], +) -> FrozenSet[StateTransition]: + if not all(map(lambda t: isinstance(t, StateTransition), iterable)): + raise TypeError( + f"Not all instances are of type {StateTransition.__name__}" + ) + return frozenset(iterable) + + +@attr.s(frozen=True, eq=False) +class StateTransitionCollection(abc.Set): + """`.StateTransition` instances with the same `.Topology` and edge IDs.""" + + transitions: FrozenSet[StateTransition] = attr.ib(converter=_to_frozenset) + topology: Topology = attr.ib(init=False, repr=False) + initial_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) + final_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) + + def __attrs_post_init__(self) -> None: + if not any(self.transitions): + ValueError(f"At least one {StateTransition.__name__} required") + some_transition = next(iter(self.transitions)) + topology = some_transition.topology + if not all(map(lambda t: t.topology == topology, self.transitions)): + raise TypeError( + f"Not all {StateTransition.__name__} items have the same" + f" underlying topology. Expecting: {topology}" + ) + object.__setattr__(self, "topology", topology) + object.__setattr__( + self, + "initial_state", + FrozenDict( + { + i: s.particle + for i, s in some_transition.states.items() + if i in some_transition.topology.incoming_edge_ids + } + ), + ) + object.__setattr__( + self, + "final_state", + FrozenDict( + { + i: s.particle + for i, s in some_transition.states.items() + if i in some_transition.topology.outgoing_edge_ids + } + ), + ) + + def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None: + class_name = type(self).__name__ + if cycle: + p.text(f"{class_name}(...)") + else: + with p.group(indent=2, open=f"{class_name}({{"): + for transition in self: + p.breakable() + p.pretty(transition) + p.text(",") + p.breakable() + p.text("})") + + def __contains__(self, item: object) -> bool: + return item in self.transitions + + def __iter__(self) -> Iterator[StateTransition]: + return iter(self.transitions) + + def __len__(self) -> int: + return len(self.transitions) + + @staticmethod + def from_graphs( + graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], + ) -> "StateTransitionCollection": + transitions = [StateTransition.from_graph(g) for g in graphs] + return StateTransitionCollection(transitions) + + def to_graphs(self) -> List[StateTransitionGraph[ParticleWithSpin]]: + return [transition.to_graph() for transition in sorted(self)] + + def get_intermediate_particles(self) -> ParticleCollection: + """Extract the particle names of the intermediate states.""" + intermediate_states = ParticleCollection() + for transition in self.transitions: + for state in transition.intermediate_states.values(): + if state.particle not in intermediate_states: + intermediate_states.add(state.particle) + return intermediate_states + + +def _to_tuple( + iterable: Iterable[StateTransitionCollection], +) -> Tuple[StateTransitionCollection, ...]: + if not all( + map(lambda t: isinstance(t, StateTransitionCollection), iterable) + ): + raise TypeError( + f"Not all instances are of type {StateTransitionCollection.__name__}" + ) + return tuple(iterable) + + +@attr.s(frozen=True, eq=False) +class ReactionInfo: + """`StateTransitionCollection` instances, grouped by `.Topology`.""" + + transition_groups: Tuple[StateTransitionCollection, ...] = attr.ib( + converter=_to_tuple + ) + transitions: List[StateTransition] = attr.ib( + init=False, repr=False, eq=False + ) + initial_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) + final_state: FrozenDict[int, Particle] = attr.ib(init=False, repr=False) + formalism: str = attr.ib(validator=instance_of(str)) + + def __attrs_post_init__(self) -> None: + if len(self.transition_groups) == 0: + ValueError( + f"At least one {StateTransitionCollection.__name__} required" + ) + transitions: List[StateTransition] = [] + for grouping in self.transition_groups: + transitions.extend(sorted(grouping)) + first_grouping = self.transition_groups[0] + object.__setattr__(self, "transitions", transitions) + object.__setattr__(self, "final_state", first_grouping.final_state) + object.__setattr__(self, "initial_state", first_grouping.initial_state) + + def __eq__(self, other: object) -> bool: + if isinstance(other, ReactionInfo): + for own_grouping, other_grouping in zip_longest( + self.transition_groups, other.transition_groups + ): + if own_grouping != other_grouping: + return False + return True + raise NotImplementedError( + f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) + + def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None: + class_name = type(self).__name__ + if cycle: + p.text(f"{class_name}(...)") + else: + with p.group(indent=2, open=f"{class_name}("): + p.breakable() + p.text("transition_groups=") + with p.group(indent=2, open="("): + for transition_grouping in self.transition_groups: + p.breakable() + p.pretty(transition_grouping) + p.text(",") + p.breakable() + p.text("),") + p.breakable() + p.text("formalism=") + p.pretty(self.formalism) + p.text(",") + p.breakable() + p.text(")") + + def get_intermediate_particles(self) -> ParticleCollection: + """Extract the names of the intermediate state particles.""" + return ParticleCollection( + set().union( + *[ + grouping.get_intermediate_particles() + for grouping in self.transition_groups + ] + ) + ) + + @staticmethod + def from_graphs( + graphs: Iterable[StateTransitionGraph[ParticleWithSpin]], + formalism: str, + ) -> "ReactionInfo": + transition_mapping: DefaultDict[ + Topology, List[StateTransition] + ] = defaultdict(list) + for graph in graphs: + transition_mapping[graph.topology].append( + StateTransition.from_graph(graph) + ) + transition_groups = tuple( + StateTransitionCollection(transitions) + for transitions in transition_mapping.values() + ) + return ReactionInfo(transition_groups, formalism) + + def to_graphs(self) -> List[StateTransitionGraph[ParticleWithSpin]]: + graphs: List[StateTransitionGraph[ParticleWithSpin]] = [] + for grouping in self.transition_groups: + graphs.extend(grouping.to_graphs()) + return graphs diff --git a/tests/channels/test_d0_to_ks_kp_km.py b/tests/channels/test_d0_to_ks_kp_km.py index 10e133b2..1ef5a9aa 100644 --- a/tests/channels/test_d0_to_ks_kp_km.py +++ b/tests/channels/test_d0_to_ks_kp_km.py @@ -2,7 +2,7 @@ def test_script(): - result = qrules.generate_transitions( + reaction = qrules.generate_transitions( initial_state="D0", final_state=["K~0", "K+", "K-"], allowed_intermediate_particles=[ @@ -12,8 +12,11 @@ def test_script(): ], number_of_threads=1, ) - assert len(result.transitions) == 5 - assert result.get_intermediate_particles().names == [ + assert len(reaction.transition_groups) == 3 + assert len(reaction.transition_groups[0]) == 2 + assert len(reaction.transition_groups[1]) == 1 + assert len(reaction.transition_groups[2]) == 2 + assert reaction.get_intermediate_particles().names == [ "a(0)(980)-", "a(0)(980)0", "a(0)(980)+", diff --git a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py index b4031084..d1b165aa 100644 --- a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py +++ b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py @@ -2,15 +2,17 @@ import qrules from qrules.combinatorics import _create_edge_id_particle_mapping +from qrules.particle import ParticleWithSpin +from qrules.topology import StateTransitionGraph @pytest.mark.parametrize( - ("allowed_intermediate_particles", "number_of_solutions"), + ("allowed_intermediate_particles", "n_topologies", "number_of_solutions"), [ - (["f(0)(1500)"], 4), - (["f(0)(980)", "f(0)(1500)"], 8), - (["f(2)(1270)"], 12), - (["omega(782)"], 8), + (["f(0)(1500)"], 1, 4), + (["f(0)(980)", "f(0)(1500)"], 1, 8), + (["f(2)(1270)"], 1, 12), + (["omega(782)"], 1, 8), ( [ "f(0)(980)", @@ -19,15 +21,19 @@ "f(2)(1950)", "omega(782)", ], + 2, 40, ), ], ) @pytest.mark.slow() def test_number_of_solutions( - particle_database, allowed_intermediate_particles, number_of_solutions + particle_database, + allowed_intermediate_particles, + n_topologies, + number_of_solutions, ): - result = qrules.generate_transitions( + reaction = qrules.generate_transitions( initial_state=("J/psi(1S)", [-1, +1]), final_state=["gamma", "pi0", "pi0"], particle_db=particle_database, @@ -36,15 +42,16 @@ def test_number_of_solutions( number_of_threads=1, formalism="helicity", ) - assert len(result.transitions) == number_of_solutions + assert len(reaction.transition_groups) == n_topologies + assert len(reaction.transitions) == number_of_solutions assert ( - result.get_intermediate_particles().names + reaction.get_intermediate_particles().names == allowed_intermediate_particles ) def test_id_to_particle_mappings(particle_database): - result = qrules.generate_transitions( + reaction = qrules.generate_transitions( initial_state=("J/psi(1S)", [-1, +1]), final_state=["gamma", "pi0", "pi0"], particle_db=particle_database, @@ -53,19 +60,22 @@ def test_id_to_particle_mappings(particle_database): number_of_threads=1, formalism="helicity", ) - assert len(result.transitions) == 4 - iter_solutions = iter(result.transitions) - first_solution = next(iter_solutions) + assert len(reaction.transition_groups) == 1 + assert len(reaction.transitions) == 4 + iter_transitions = iter(reaction.transitions) + first_transition = next(iter_transitions) + graph: StateTransitionGraph[ParticleWithSpin] = first_transition.to_graph() ref_mapping_fs = _create_edge_id_particle_mapping( - first_solution, first_solution.topology.outgoing_edge_ids + graph, graph.topology.outgoing_edge_ids ) ref_mapping_is = _create_edge_id_particle_mapping( - first_solution, first_solution.topology.incoming_edge_ids + graph, graph.topology.incoming_edge_ids ) - for solution in iter_solutions: + for transition in iter_transitions: + graph = transition.to_graph() assert ref_mapping_fs == _create_edge_id_particle_mapping( - solution, solution.topology.outgoing_edge_ids + graph, graph.topology.outgoing_edge_ids ) assert ref_mapping_is == _create_edge_id_particle_mapping( - solution, solution.topology.incoming_edge_ids + graph, graph.topology.incoming_edge_ids ) diff --git a/tests/channels/test_y_to_d0_d0bar_pi0_pi0.py b/tests/channels/test_y_to_d0_d0bar_pi0_pi0.py index d89d4fb8..c1552199 100644 --- a/tests/channels/test_y_to_d0_d0bar_pi0_pi0.py +++ b/tests/channels/test_y_to_d0_d0bar_pi0_pi0.py @@ -12,7 +12,7 @@ ], ) def test_simple(formalism, n_solutions, particle_database): - result = qrules.generate_transitions( + reaction = qrules.generate_transitions( initial_state=[("Y(4260)", [-1, +1])], final_state=["D*(2007)0", "D*(2007)~0"], particle_db=particle_database, @@ -20,7 +20,8 @@ def test_simple(formalism, n_solutions, particle_database): allowed_interaction_types="strong", number_of_threads=1, ) - assert len(result.transitions) == n_solutions + assert len(reaction.transition_groups) == 1 + assert len(reaction.transitions) == n_solutions @pytest.mark.slow() @@ -43,5 +44,6 @@ def test_full(formalism, n_solutions, particle_database): stm.set_allowed_interaction_types([InteractionType.STRONG]) stm.add_final_state_grouping([["D0", "pi0"], ["D~0", "pi0"]]) problem_sets = stm.create_problem_sets() - result = stm.find_solutions(problem_sets) - assert len(result.transitions) == n_solutions + reaction = stm.find_solutions(problem_sets) + assert len(reaction.transition_groups) == 1 + assert len(reaction.transitions) == n_solutions diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 48da0a54..2766f620 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,13 +5,13 @@ from _pytest.fixtures import SubRequest import qrules -from qrules import Result +from qrules import ReactionInfo logging.basicConfig(level=logging.ERROR) @pytest.fixture(scope="session", params=["canonical-helicity", "helicity"]) -def result(request: SubRequest) -> Result: +def reaction(request: SubRequest) -> ReactionInfo: formalism: str = request.param return qrules.generate_transitions( initial_state=[("J/psi(1S)", [-1, 1])], diff --git a/tests/unit/conservation_rules/test_duck_typing.py b/tests/unit/conservation_rules/test_duck_typing.py index 0127d0e5..dc703a07 100644 --- a/tests/unit/conservation_rules/test_duck_typing.py +++ b/tests/unit/conservation_rules/test_duck_typing.py @@ -1,4 +1,3 @@ -# cspell:ignore isclass """Check duck typing. Ideally, the rule input classes use a `~typing.Protocol`. This is not possible, diff --git a/tests/unit/io/test_dot.py b/tests/unit/io/test_dot.py index 4c0a0441..291f1acc 100644 --- a/tests/unit/io/test_dot.py +++ b/tests/unit/io/test_dot.py @@ -1,7 +1,7 @@ # pylint: disable=no-self-use import pydot -from qrules import Result, io +from qrules import io from qrules.io._dot import _collapse_graphs, _get_particle_graphs from qrules.particle import ParticleCollection from qrules.topology import ( @@ -10,17 +10,19 @@ create_isobar_topologies, create_n_body_topology, ) +from qrules.transition import ReactionInfo -def test_asdot(result: Result): - for transition in result.transitions: - dot_data = io.asdot(transition) +def test_asdot(reaction: ReactionInfo): + for grouping in reaction.transition_groups: + for transition in grouping: + dot_data = io.asdot(transition) assert pydot.graph_from_dot_data(dot_data) is not None - dot_data = io.asdot(result) + dot_data = io.asdot(reaction) assert pydot.graph_from_dot_data(dot_data) is not None - dot_data = io.asdot(result, strip_spin=True) + dot_data = io.asdot(reaction, strip_spin=True) assert pydot.graph_from_dot_data(dot_data) is not None - dot_data = io.asdot(result, collapse_graphs=True) + dot_data = io.asdot(reaction, collapse_graphs=True) assert pydot.graph_from_dot_data(dot_data) is not None @@ -50,30 +52,32 @@ def test_write_topology(self, output_dir): dot_data = stream.read() assert pydot.graph_from_dot_data(dot_data) is not None - def test_write_single_graph(self, output_dir: str, result: Result): - output_file = output_dir + "test_single_graph.gv" - io.write( - instance=result.transitions[0], - filename=output_file, - ) - with open(output_file, "r") as stream: - dot_data = stream.read() - assert pydot.graph_from_dot_data(dot_data) is not None + def test_write_single_graph(self, output_dir: str, reaction: ReactionInfo): + for i, transition in enumerate(reaction.transitions): + output_file = output_dir + f"test_single_graph_{i}.gv" + io.write( + instance=transition, + filename=output_file, + ) + with open(output_file, "r") as stream: + dot_data = stream.read() + assert pydot.graph_from_dot_data(dot_data) is not None - def test_write_graph_list(self, output_dir: str, result: Result): - output_file = output_dir + "test_graph_list.gv" - io.write( - instance=result.transitions, - filename=output_file, - ) - with open(output_file, "r") as stream: - dot_data = stream.read() - assert pydot.graph_from_dot_data(dot_data) is not None + def test_write_graph_list(self, output_dir: str, reaction: ReactionInfo): + for i, grouping in enumerate(reaction.transition_groups): + output_file = output_dir + f"test_graph_list_{i}.gv" + io.write( + instance=grouping, + filename=output_file, + ) + with open(output_file, "r") as stream: + dot_data = stream.read() + assert pydot.graph_from_dot_data(dot_data) is not None - def test_write_strip_spin(self, output_dir: str, result: Result): + def test_write_strip_spin(self, output_dir: str, reaction: ReactionInfo): output_file = output_dir + "test_particle_graphs.gv" io.write( - instance=io.asdot(result, strip_spin=True), + instance=io.asdot(reaction, strip_spin=True), filename=output_file, ) with open(output_file, "r") as stream: @@ -82,13 +86,13 @@ def test_write_strip_spin(self, output_dir: str, result: Result): def test_collapse_graphs( - result: Result, + reaction: ReactionInfo, particle_database: ParticleCollection, ): pdg = particle_database - particle_graphs = _get_particle_graphs(result.transitions) + particle_graphs = _get_particle_graphs(reaction.to_graphs()) assert len(particle_graphs) == 2 - collapsed_graphs = _collapse_graphs(result.transitions) + collapsed_graphs = _collapse_graphs(reaction.to_graphs()) assert len(collapsed_graphs) == 1 graph = next(iter(collapsed_graphs)) edge_id = next(iter(graph.topology.intermediate_edge_ids)) @@ -99,10 +103,10 @@ def test_collapse_graphs( def test_get_particle_graphs( - result: Result, particle_database: ParticleCollection + reaction: ReactionInfo, particle_database: ParticleCollection ): pdg = particle_database - particle_graphs = _get_particle_graphs(result.transitions) + particle_graphs = _get_particle_graphs(reaction.to_graphs()) assert len(particle_graphs) == 2 assert particle_graphs[0].get_edge_props(3) == pdg["f(0)(980)"] assert particle_graphs[1].get_edge_props(3) == pdg["f(0)(1500)"] diff --git a/tests/unit/io/test_io.py b/tests/unit/io/test_io.py index 8e3a687b..79c62ec9 100644 --- a/tests/unit/io/test_io.py +++ b/tests/unit/io/test_io.py @@ -10,7 +10,7 @@ create_isobar_topologies, create_n_body_topology, ) -from qrules.transition import Result +from qrules.transition import ReactionInfo def through_dict(instance): @@ -42,16 +42,16 @@ def test_asdict_fromdict(particle_selection: ParticleCollection): assert topology == fromdict -def test_asdict_fromdict_result(result: Result): +def test_asdict_fromdict_reaction(reaction: ReactionInfo): # StateTransitionGraph - for graph in result.transitions: + for graph in reaction.to_graphs(): fromdict = through_dict(graph) assert isinstance(fromdict, StateTransitionGraph) assert graph == fromdict - # Result - fromdict = through_dict(result) - assert isinstance(fromdict, Result) - assert result == fromdict + # ReactionInfo + fromdict = through_dict(reaction) + assert isinstance(fromdict, ReactionInfo) + assert reaction == fromdict def test_fromdict_exceptions(): diff --git a/tests/unit/test_parity_prefactor.py b/tests/unit/test_parity_prefactor.py index b9217edc..229493bd 100644 --- a/tests/unit/test_parity_prefactor.py +++ b/tests/unit/test_parity_prefactor.py @@ -68,20 +68,20 @@ def test_parity_prefactor( stm.set_allowed_interaction_types([InteractionType.EM]) problem_sets = stm.create_problem_sets() - result = stm.find_solutions(problem_sets) + reaction = stm.find_solutions(problem_sets) - for solution in result.transitions: - in_edge = [ - edge_id - for edge_id in solution.topology.edges - if solution.get_edge_props(edge_id)[0].name == ingoing_state + assert len(reaction.transition_groups) == 1 + for transition in reaction.transitions: + in_edges = [ + state_id + for state_id, state in transition.states.items() + if state.particle.name == ingoing_state ] - assert len(in_edge) == 1 - node_id = solution.topology.edges[in_edge[0]].ending_node_id + assert len(in_edges) == 1 + node_id = transition.topology.edges[in_edges[0]].ending_node_id assert isinstance(node_id, int) - assert ( relative_parity_prefactor - == solution.get_node_props(node_id).parity_prefactor + == transition.interactions[node_id].parity_prefactor ) diff --git a/tests/unit/test_solving.py b/tests/unit/test_solving.py deleted file mode 100644 index 4ec05c8c..00000000 --- a/tests/unit/test_solving.py +++ /dev/null @@ -1,8 +0,0 @@ -# pylint: disable=no-self-use -from qrules import Result - - -class TestResult: - def test_get_intermediate_state_names(self, result: Result): - intermediate_particles = result.get_intermediate_particles() - assert intermediate_particles.names == ["f(0)(980)", "f(0)(1500)"] diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 7cc868cd..11c7b57d 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -1,9 +1,10 @@ -# pylint: disable=no-self-use, redefined-outer-name, too-many-arguments +# pylint: disable=eval-used, no-self-use, redefined-outer-name, too-many-arguments # pyright: reportUnusedImport=false import typing import attr import pytest +from IPython.lib.pretty import pretty from qrules.topology import ( # noqa: F401 Edge, @@ -201,8 +202,9 @@ def test_constructor_exceptions(self, nodes, edges): ): assert Topology(nodes=nodes, edges=edges) - def test_repr_and_eq(self, two_to_three_decay: Topology): - topology = eval(str(two_to_three_decay)) # pylint: disable=eval-used + @pytest.mark.parametrize("repr_method", [repr, pretty]) + def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology): + topology = eval(repr_method(two_to_three_decay)) assert topology == two_to_three_decay assert topology != float() diff --git a/tests/unit/test_transition.py b/tests/unit/test_transition.py index 49664b42..47c0f4f1 100644 --- a/tests/unit/test_transition.py +++ b/tests/unit/test_transition.py @@ -1,7 +1,192 @@ -# pylint: disable=no-self-use +# pyright: reportUnusedImport=false +# pylint: disable=eval-used, no-self-use +from operator import itemgetter +from typing import List + import pytest +from IPython.lib.pretty import pretty + +from qrules.particle import ( # noqa: F401 + Parity, + Particle, + ParticleCollection, + ParticleWithSpin, + Spin, +) +from qrules.quantum_numbers import InteractionProperties # noqa: F401 +from qrules.topology import ( # noqa: F401 + Edge, + FrozenDict, + StateTransitionGraph, + Topology, +) +from qrules.transition import State # noqa: F401 +from qrules.transition import ( + ReactionInfo, + StateTransition, + StateTransitionCollection, + StateTransitionManager, +) + + +class TestReactionInfo: + def test_properties(self, reaction: ReactionInfo): + assert reaction.initial_state[-1].name == "J/psi(1S)" + assert reaction.final_state[0].name == "gamma" + assert reaction.final_state[1].name == "pi0" + assert reaction.final_state[2].name == "pi0" + assert len(reaction.transition_groups) == 1 + for grouping in reaction.transition_groups: + assert isinstance(grouping, StateTransitionCollection) + if reaction.formalism.startswith("cano"): + assert len(reaction.transitions) == 16 + else: + assert len(reaction.transitions) == 8 + for transition in reaction.transitions: + assert isinstance(transition, StateTransition) + + @pytest.mark.parametrize("repr_method", [repr, pretty]) + def test_repr(self, repr_method, reaction: ReactionInfo): + instance = reaction + from_repr = eval(repr_method(instance)) + assert from_repr == instance + + def test_from_to_graphs(self, reaction: ReactionInfo): + graphs = reaction.to_graphs() + from_graphs = ReactionInfo.from_graphs(graphs, reaction.formalism) + assert from_graphs == reaction + + +class TestState: + @pytest.mark.parametrize( + ("state_def_1", "state_def_2"), + [ + (("a", -1), ("a", +1)), + (("a", 0), ("a", 0)), + (("a", 0), ("b", 0)), + (("a", -1), ("b", +1)), + ], + ) + def test_ordering(self, state_def_1, state_def_2): + def create_state(state_def) -> State: + return State( + particle=Particle(name=state_def[0], pid=0, spin=0, mass=0), + spin_projection=state_def[1], + ) + + state1 = create_state(state_def_1) + state2 = create_state(state_def_2) + assert state2 >= state1 + + +class TestStateTransition: + def test_ordering(self, reaction: ReactionInfo): + sorted_transitions: List[StateTransition] = sorted( + reaction.transitions + ) + if reaction.formalism.startswith("cano"): + first = sorted_transitions[0] + second = sorted_transitions[1] + assert first.interactions[0].l_magnitude == 0.0 + assert second.interactions[0].l_magnitude == 2.0 + assert first.interactions[1] == second.interactions[1] + transition_selection = sorted_transitions[::2] + else: + transition_selection = sorted_transitions + + simplified_rendering = [ + tuple( + ( + transition.states[state_id].particle.name, + int(transition.states[state_id].spin_projection), + ) + for state_id in sorted(transition.states) + ) + for transition in transition_selection + ] + + assert simplified_rendering[:3] == [ + ( + ("J/psi(1S)", -1), + ("gamma", -1), + ("pi0", 0), + ("pi0", 0), + ("f(0)(980)", 0), + ), + ( + ("J/psi(1S)", -1), + ("gamma", -1), + ("pi0", 0), + ("pi0", 0), + ("f(0)(1500)", 0), + ), + ( + ("J/psi(1S)", -1), + ("gamma", +1), + ("pi0", 0), + ("pi0", 0), + ("f(0)(980)", 0), + ), + ] + assert simplified_rendering[-1] == ( + ("J/psi(1S)", +1), + ("gamma", +1), + ("pi0", 0), + ("pi0", 0), + ("f(0)(1500)", 0), + ) + + # J/psi + first_half = slice(0, int(len(simplified_rendering) / 2)) + for item in simplified_rendering[first_half]: + assert item[0] == ("J/psi(1S)", -1) + second_half = slice(int(len(simplified_rendering) / 2), None) + for item in simplified_rendering[second_half]: + assert item[0] == ("J/psi(1S)", +1) + second_half = slice(int(len(simplified_rendering) / 2), None) + # gamma + for item in itemgetter(0, 1, 4, 5)(simplified_rendering): + assert item[1] == ("gamma", -1) + for item in itemgetter(2, 3, 6, 7)(simplified_rendering): + assert item[1] == ("gamma", +1) + # pi0 + for item in simplified_rendering: + assert item[2] == ("pi0", 0) + assert item[3] == ("pi0", 0) + # f0 + for item in simplified_rendering[::2]: + assert item[4] == ("f(0)(980)", 0) + for item in simplified_rendering[1::2]: + assert item[4] == ("f(0)(1500)", 0) + + @pytest.mark.parametrize("repr_method", [repr, pretty]) + def test_repr(self, repr_method, reaction: ReactionInfo): + for instance in reaction.transitions: + from_repr = eval(repr_method(instance)) + assert from_repr == instance + + def test_from_to_graph(self, reaction: ReactionInfo): + assert len(reaction.transition_groups) == 1 + assert len(reaction.transitions) in {8, 16} + for transition in reaction.transitions: + graph = transition.to_graph() + from_graph = StateTransition.from_graph(graph) + assert transition == from_graph + + +class TestStateTransitionCollection: + @pytest.mark.parametrize("repr_method", [repr, pretty]) + def test_repr(self, reaction: ReactionInfo, repr_method): + for instance in reaction.transition_groups: + from_repr = eval(repr_method(instance)) + assert from_repr == instance -from qrules.transition import StateTransitionManager + def test_from_to_graphs(self, reaction: ReactionInfo): + assert len(reaction.transition_groups) == 1 + transition_grouping = reaction.transition_groups[0] + graphs = transition_grouping.to_graphs() + from_graphs = StateTransitionCollection.from_graphs(graphs) + assert transition_grouping == from_graphs class TestStateTransitionManager: