From 1e60e40a50e872fcfaa6b5ed74f79b19a52ac2e3 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 1 Mar 2024 22:20:45 +0100 Subject: [PATCH] MAINT: import v0.15 refactorings into v0.14 (#399) * FEAT: expose `get_outer_state_ids()` function * FEAT: expose amplitude naming functions * FEAT: extract `create_four_momentum_symbol()` --- src/ampform/helicity/__init__.py | 76 ++++++------------------------- src/ampform/helicity/decay.py | 20 +++++++- src/ampform/helicity/naming.py | 34 ++++++++++++++ src/ampform/kinematics/lorentz.py | 10 ++-- 4 files changed, 74 insertions(+), 66 deletions(-) diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index bb16453aa..3291b7368 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -14,7 +14,7 @@ from collections import OrderedDict, abc from decimal import Decimal from difflib import get_close_matches -from functools import reduce, singledispatch +from functools import reduce from typing import ( TYPE_CHECKING, Generator, @@ -45,6 +45,7 @@ ) from ampform.helicity.decay import ( TwoBodyDecay, + get_outer_state_ids, get_parent_id, get_prefactor, get_sibling_state_id, @@ -56,10 +57,13 @@ CanonicalAmplitudeNameGenerator, HelicityAmplitudeNameGenerator, NameGenerator, + create_amplitude_base, + create_amplitude_symbol, + create_helicity_symbol, + create_spin_projection_symbol, generate_transition_label, get_helicity_angle_symbols, get_helicity_suffix, - get_topology_identifier, natural_sorting, ) from ampform.kinematics import HelicityAdapter @@ -595,7 +599,7 @@ def formulate(self) -> HelicityModel: ) def __formulate_top_expression(self) -> PoolSum: - outer_state_ids = _get_outer_state_ids(self.__reaction) + outer_state_ids = get_outer_state_ids(self.__reaction) spin_projections: collections.defaultdict[sp.Symbol, set[sp.Rational]] = ( collections.defaultdict(set) ) @@ -605,7 +609,7 @@ def __formulate_top_expression(self) -> PoolSum: for transition in group: for i in outer_state_ids: state = transition.states[i] - symbol = _create_spin_projection_symbol(i) + symbol = create_spin_projection_symbol(i) value = sp.Rational(state.spin_projection) spin_projections[symbol].add(value) @@ -615,21 +619,20 @@ def __formulate_top_expression(self) -> PoolSum: else: indices = list(spin_projections) amplitude = sum( # type: ignore[assignment] - _create_amplitude_base(topology)[indices] - for topology in topology_groups + create_amplitude_base(topology)[indices] for topology in topology_groups ) return PoolSum(abs(amplitude) ** 2, *spin_projections.items()) def __formulate_aligned_amplitude( self, topology_groups: dict[Topology, list[StateTransition]] ) -> sp.Expr: - outer_state_ids = _get_outer_state_ids(self.__reaction) + outer_state_ids = get_outer_state_ids(self.__reaction) amplitude = sp.S.Zero for topology, transitions in topology_groups.items(): - base = _create_amplitude_base(topology) + base = create_amplitude_base(topology) helicities = [ _get_opposite_helicity_sign(topology, i) - * _create_helicity_symbol(topology, i) + * create_helicity_symbol(topology, i) for i in outer_state_ids ] amplitude_symbol = base[helicities] @@ -666,7 +669,7 @@ def __formulate_topology_amplitude( sequential_expressions.append(expression) first_transition = transitions[0] - symbol = _create_amplitude_symbol(first_transition) + symbol = create_amplitude_symbol(first_transition) expression = sum(sequential_expressions) # type: ignore[assignment] self.__ingredients.amplitudes[symbol] = expression return expression @@ -765,63 +768,12 @@ def __generate_amplitude_prefactor( return None -def _create_amplitude_symbol(transition: StateTransition) -> sp.Indexed: - outer_state_ids = _get_outer_state_ids(transition) - helicities = tuple( - sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids - ) - base = _create_amplitude_base(transition.topology) - return base[helicities] - - def _get_opposite_helicity_sign(topology: Topology, state_id: int) -> Literal[-1, 1]: if state_id != -1 and is_opposite_helicity_state(topology, state_id): return -1 return 1 -def _create_amplitude_base(topology: Topology) -> sp.IndexedBase: - superscript = get_topology_identifier(topology) - return sp.IndexedBase(f"A^{superscript}", complex=True) - - -def _create_helicity_symbol( - topology: Topology, state_id: int, root: str = "lambda" -) -> sp.Symbol: - if state_id == -1: # initial state - name = "m_A" - else: - suffix = get_helicity_suffix(topology, state_id) - name = f"{root}{suffix}" - return sp.Symbol(name, rational=True) - - -def _create_spin_projection_symbol(state_id: int) -> sp.Symbol: - if state_id == -1: # initial state - suffix = "_A" - else: - suffix = str(state_id) - return sp.Symbol(f"m{suffix}", rational=True) - - -@singledispatch -def _get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: - msg = f"Cannot get outer state IDs from a {type(obj).__name__}" - raise NotImplementedError(msg) - - -@_get_outer_state_ids.register(StateTransition) -def _(transition: StateTransition) -> list[int]: - outer_state_ids = list(transition.initial_states) - outer_state_ids += sorted(transition.final_states) - return outer_state_ids - - -@_get_outer_state_ids.register(ReactionInfo) -def _(reaction: ReactionInfo) -> list[int]: - return _get_outer_state_ids(reaction.transitions[0]) - - class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder): r"""Amplitude model generator for the canonical helicity formalism. @@ -1033,7 +985,7 @@ def formulate_rotation_chain( plus a Wigner rotation (see :func:`.formulate_wigner_rotation`) in case there is more than one helicity rotation. """ - helicity_symbol = _create_spin_projection_symbol(rotated_state_id) + helicity_symbol = create_spin_projection_symbol(rotated_state_id) helicity_rotations = formulate_helicity_rotation_chain( transition, rotated_state_id, helicity_symbol ) diff --git a/src/ampform/helicity/decay.py b/src/ampform/helicity/decay.py index bb7b7db05..76961f87a 100644 --- a/src/ampform/helicity/decay.py +++ b/src/ampform/helicity/decay.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Iterable from attrs import frozen -from qrules.transition import State, StateTransition +from qrules.transition import ReactionInfo, State, StateTransition if TYPE_CHECKING: from qrules.quantum_numbers import InteractionProperties @@ -295,6 +295,24 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in return sorted(topology.get_originating_final_state_edge_ids(edge.ending_node_id)) +@singledispatch +def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: + msg = f"Cannot get outer state IDs from a {type(obj).__name__}" + raise NotImplementedError(msg) + + +@get_outer_state_ids.register(StateTransition) +def _(transition: StateTransition) -> list[int]: + outer_state_ids = list(transition.initial_states) + outer_state_ids += sorted(transition.final_states) + return outer_state_ids + + +@get_outer_state_ids.register(ReactionInfo) +def _(reaction: ReactionInfo) -> list[int]: + return get_outer_state_ids(reaction.transitions[0]) + + def get_prefactor(transition: StateTransition) -> float: """Calculate the product of all prefactors defined in this transition. diff --git a/src/ampform/helicity/naming.py b/src/ampform/helicity/naming.py index cce42229e..5b6f5bbee 100644 --- a/src/ampform/helicity/naming.py +++ b/src/ampform/helicity/naming.py @@ -15,6 +15,7 @@ assert_isobar_topology, determine_attached_final_state, get_helicity_info, + get_outer_state_ids, get_sorted_states, ) @@ -287,6 +288,20 @@ def __generate_ls_arrow(transition: StateTransition, node_id: int) -> str: return Rf" \xrightarrow[S={coupled_spin}]{{L={angular_momentum}}} " +def create_amplitude_symbol(transition: StateTransition) -> sp.Indexed: + outer_state_ids = get_outer_state_ids(transition) + helicities = tuple( + sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids + ) + base = create_amplitude_base(transition.topology) + return base[helicities] + + +def create_amplitude_base(topology: Topology) -> sp.IndexedBase: + superscript = get_topology_identifier(topology) + return sp.IndexedBase(f"A^{superscript}", complex=True) + + def generate_transition_label(transition: StateTransition) -> str: r"""Generate a label for a coherent intensity, including spin projection. @@ -495,3 +510,22 @@ def _render_float(value: float) -> str: if value > 0: return f"+{rational}" return str(rational) + + +def create_helicity_symbol( + topology: Topology, state_id: int, root: str = "lambda" +) -> sp.Symbol: + if state_id == -1: # initial state + name = "m_A" + else: + suffix = get_helicity_suffix(topology, state_id) + name = f"{root}{suffix}" + return sp.Symbol(name, rational=True) + + +def create_spin_projection_symbol(state_id: int) -> sp.Symbol: + if state_id == -1: # initial state + suffix = "_A" + else: + suffix = str(state_id) + return sp.Symbol(f"m{suffix}", rational=True) diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index 20f7f0b61..a03a4da1a 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -31,12 +31,16 @@ def create_four_momentum_symbols(topology: Topology) -> FourMomenta: >>> create_four_momentum_symbols(topologies[0]) {0: p0, 1: p1, 2: p2} """ - n_final_states = len(topology.outgoing_edge_ids) - return {i: FourMomentumSymbol(f"p{i}", shape=[]) for i in range(n_final_states)} + final_state_ids = sorted(topology.outgoing_edge_ids) + return {i: create_four_momentum_symbol(i) for i in final_state_ids} + + +def create_four_momentum_symbol(index: int) -> FourMomentumSymbol: + return FourMomentumSymbol(f"p{index}", shape=[]) FourMomenta = Dict[int, "FourMomentumSymbol"] -"""A mapping of state IDs to their corresponding `FourMomentumSymbol`. +"""A mapping of state IDs to their corresponding `.FourMomentumSymbol`. It's best to create a `dict` of `.FourMomenta` with :func:`create_four_momentum_symbols`.