Skip to content

Commit

Permalink
fix: match final state IDs with final state order (#145)
Browse files Browse the repository at this point in the history
* fix: switch order ABCs and Generic
  This is a problem identified by newer versions of Pyright
* fix: match final state IDs in ReactionInfo to argument order
* test: assert final state ID matches order
  • Loading branch information
redeboer authored Jan 28, 2022
1 parent 8ea1be2 commit abd875d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/qrules/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __lt__(self, other: Any) -> bool:

@total_ordering
class FrozenDict( # pylint: disable=too-many-ancestors
Generic[KeyType, ValueType], abc.Hashable, abc.Mapping
abc.Hashable, abc.Mapping, Generic[KeyType, ValueType]
):
def __init__(self, mapping: Optional[Mapping] = None):
self.__mapping: Dict[KeyType, ValueType] = {}
Expand Down
36 changes: 36 additions & 0 deletions src/qrules/transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,10 @@ def find_solutions( # pylint: disable=too-many-branches
raise ValueError("No solutions were found")

match_external_edges(final_solutions)
final_solutions = [
_match_final_state_ids(graph, self.final_state)
for graph in final_solutions
]
return ReactionInfo.from_graphs(final_solutions, self.formalism)

def _solve(
Expand Down Expand Up @@ -699,6 +703,38 @@ def _safe_wrap_list(
)


def _match_final_state_ids(
graph: StateTransitionGraph[ParticleWithSpin],
state_definition: Sequence[StateDefinition],
) -> StateTransitionGraph[ParticleWithSpin]:
"""Temporary fix to https://github.com/ComPWA/qrules/issues/143."""
particle_names = _strip_spin(state_definition)
name_to_id = {name: i for i, name in enumerate(particle_names)}
id_remapping = {
name_to_id[graph.get_edge_props(i)[0].name]: i
for i in graph.topology.outgoing_edge_ids
}
new_topology = graph.topology.relabel_edges(id_remapping)
return StateTransitionGraph(
new_topology,
edge_props={
i: graph.get_edge_props(id_remapping.get(i, i))
for i in graph.topology.edges
},
node_props={i: graph.get_node_props(i) for i in graph.topology.nodes},
)


def _strip_spin(state_definition: Sequence[StateDefinition]) -> List[str]:
particle_names = []
for state in state_definition:
if isinstance(state, str):
particle_names.append(state)
else:
particle_names.append(state[0])
return particle_names


@implement_pretty_repr()
@attr.frozen(order=True)
class State:
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_qrules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from qrules import generate_transitions


@pytest.mark.parametrize(
"resonance_names",
[
["Sigma(1660)~-"],
["N(1650)+"],
["K*(1680)~0"],
["Sigma(1660)~-", "N(1650)+"],
["Sigma(1660)~-", "K*(1680)~0"],
["N(1650)+", "K*(1680)~0"],
["Sigma(1660)~-", "N(1650)+", "K*(1680)~0"],
],
)
def test_generate_transitions(resonance_names):
final_state_names = ["K0", "Sigma+", "p~"]
reaction = generate_transitions(
initial_state="J/psi(1S)",
final_state=final_state_names,
allowed_intermediate_particles=resonance_names,
allowed_interaction_types="strong",
)
assert len(reaction.transition_groups) == len(resonance_names)
final_state = dict(enumerate(final_state_names))
for transition in reaction.transitions:
this_final_state = {
i: state.particle.name
for i, state in transition.final_states.items()
}
assert final_state == this_final_state

0 comments on commit abd875d

Please sign in to comment.