Skip to content

Commit

Permalink
FEAT: expose get_outer_state_ids() function
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 1, 2024
1 parent 8e10424 commit 885e3ba
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
27 changes: 5 additions & 22 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import OrderedDict, abc
from decimal import Decimal
from difflib import get_close_matches
from functools import reduce, singledispatch
from functools import reduce
from typing import (
TYPE_CHECKING,
Generator,
Expand Down Expand Up @@ -45,6 +45,7 @@
)
from ampform.helicity.decay import (
TwoBodyDecay,
get_outer_state_ids,
get_parent_id,
get_prefactor,
get_sibling_state_id,
Expand Down Expand Up @@ -595,7 +596,7 @@ def formulate(self) -> HelicityModel:
)

def __formulate_top_expression(self) -> PoolSum:
outer_state_ids = _get_outer_state_ids(self.__reaction)
outer_state_ids = get_outer_state_ids(self.__reaction)
spin_projections: collections.defaultdict[sp.Symbol, set[sp.Rational]] = (
collections.defaultdict(set)
)
Expand Down Expand Up @@ -623,7 +624,7 @@ def __formulate_top_expression(self) -> PoolSum:
def __formulate_aligned_amplitude(
self, topology_groups: dict[Topology, list[StateTransition]]
) -> sp.Expr:
outer_state_ids = _get_outer_state_ids(self.__reaction)
outer_state_ids = get_outer_state_ids(self.__reaction)
amplitude = sp.S.Zero
for topology, transitions in topology_groups.items():
base = _create_amplitude_base(topology)
Expand Down Expand Up @@ -766,7 +767,7 @@ def __generate_amplitude_prefactor(


def _create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = _get_outer_state_ids(transition)
outer_state_ids = get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
Expand Down Expand Up @@ -804,24 +805,6 @@ def _create_spin_projection_symbol(state_id: int) -> sp.Symbol:
return sp.Symbol(f"m{suffix}", rational=True)


@singledispatch
def _get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
msg = f"Cannot get outer state IDs from a {type(obj).__name__}"
raise NotImplementedError(msg)


@_get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


@_get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return _get_outer_state_ids(reaction.transitions[0])


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.
Expand Down
20 changes: 19 additions & 1 deletion src/ampform/helicity/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Iterable

from attrs import frozen
from qrules.transition import State, StateTransition
from qrules.transition import ReactionInfo, State, StateTransition

if TYPE_CHECKING:
from qrules.quantum_numbers import InteractionProperties
Expand Down Expand Up @@ -295,6 +295,24 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in
return sorted(topology.get_originating_final_state_edge_ids(edge.ending_node_id))


@singledispatch
def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
msg = f"Cannot get outer state IDs from a {type(obj).__name__}"
raise NotImplementedError(msg)


@get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


@get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return get_outer_state_ids(reaction.transitions[0])


def get_prefactor(transition: StateTransition) -> float:
"""Calculate the product of all prefactors defined in this transition.
Expand Down

0 comments on commit 885e3ba

Please sign in to comment.