From 885e3bacad8136ffa51a37ee467ba45cbab2864d Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 1 Mar 2024 22:06:21 +0100 Subject: [PATCH] FEAT: expose `get_outer_state_ids()` function --- src/ampform/helicity/__init__.py | 27 +++++---------------------- src/ampform/helicity/decay.py | 20 +++++++++++++++++++- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index bb16453aa..843f88fc2 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -14,7 +14,7 @@ from collections import OrderedDict, abc from decimal import Decimal from difflib import get_close_matches -from functools import reduce, singledispatch +from functools import reduce from typing import ( TYPE_CHECKING, Generator, @@ -45,6 +45,7 @@ ) from ampform.helicity.decay import ( TwoBodyDecay, + get_outer_state_ids, get_parent_id, get_prefactor, get_sibling_state_id, @@ -595,7 +596,7 @@ def formulate(self) -> HelicityModel: ) def __formulate_top_expression(self) -> PoolSum: - outer_state_ids = _get_outer_state_ids(self.__reaction) + outer_state_ids = get_outer_state_ids(self.__reaction) spin_projections: collections.defaultdict[sp.Symbol, set[sp.Rational]] = ( collections.defaultdict(set) ) @@ -623,7 +624,7 @@ def __formulate_top_expression(self) -> PoolSum: def __formulate_aligned_amplitude( self, topology_groups: dict[Topology, list[StateTransition]] ) -> sp.Expr: - outer_state_ids = _get_outer_state_ids(self.__reaction) + outer_state_ids = get_outer_state_ids(self.__reaction) amplitude = sp.S.Zero for topology, transitions in topology_groups.items(): base = _create_amplitude_base(topology) @@ -766,7 +767,7 @@ def __generate_amplitude_prefactor( def _create_amplitude_symbol(transition: StateTransition) -> sp.Indexed: - outer_state_ids = _get_outer_state_ids(transition) + outer_state_ids = get_outer_state_ids(transition) helicities = tuple( sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids ) @@ -804,24 +805,6 @@ def _create_spin_projection_symbol(state_id: int) -> sp.Symbol: return sp.Symbol(f"m{suffix}", rational=True) -@singledispatch -def _get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: - msg = f"Cannot get outer state IDs from a {type(obj).__name__}" - raise NotImplementedError(msg) - - -@_get_outer_state_ids.register(StateTransition) -def _(transition: StateTransition) -> list[int]: - outer_state_ids = list(transition.initial_states) - outer_state_ids += sorted(transition.final_states) - return outer_state_ids - - -@_get_outer_state_ids.register(ReactionInfo) -def _(reaction: ReactionInfo) -> list[int]: - return _get_outer_state_ids(reaction.transitions[0]) - - class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder): r"""Amplitude model generator for the canonical helicity formalism. diff --git a/src/ampform/helicity/decay.py b/src/ampform/helicity/decay.py index bb7b7db05..76961f87a 100644 --- a/src/ampform/helicity/decay.py +++ b/src/ampform/helicity/decay.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Iterable from attrs import frozen -from qrules.transition import State, StateTransition +from qrules.transition import ReactionInfo, State, StateTransition if TYPE_CHECKING: from qrules.quantum_numbers import InteractionProperties @@ -295,6 +295,24 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in return sorted(topology.get_originating_final_state_edge_ids(edge.ending_node_id)) +@singledispatch +def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: + msg = f"Cannot get outer state IDs from a {type(obj).__name__}" + raise NotImplementedError(msg) + + +@get_outer_state_ids.register(StateTransition) +def _(transition: StateTransition) -> list[int]: + outer_state_ids = list(transition.initial_states) + outer_state_ids += sorted(transition.final_states) + return outer_state_ids + + +@get_outer_state_ids.register(ReactionInfo) +def _(reaction: ReactionInfo) -> list[int]: + return get_outer_state_ids(reaction.transitions[0]) + + def get_prefactor(transition: StateTransition) -> float: """Calculate the product of all prefactors defined in this transition.