Skip to content

Commit

Permalink
MAINT: import v0.15 refactorings into v0.14 (#399)
Browse files Browse the repository at this point in the history
* FEAT: expose `get_outer_state_ids()` function
* FEAT: expose amplitude naming functions
* FEAT: extract `create_four_momentum_symbol()`
  • Loading branch information
redeboer authored Mar 1, 2024
1 parent dc92284 commit 1e60e40
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 66 deletions.
76 changes: 14 additions & 62 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import OrderedDict, abc
from decimal import Decimal
from difflib import get_close_matches
from functools import reduce, singledispatch
from functools import reduce
from typing import (
TYPE_CHECKING,
Generator,
Expand Down Expand Up @@ -45,6 +45,7 @@
)
from ampform.helicity.decay import (
TwoBodyDecay,
get_outer_state_ids,
get_parent_id,
get_prefactor,
get_sibling_state_id,
Expand All @@ -56,10 +57,13 @@
CanonicalAmplitudeNameGenerator,
HelicityAmplitudeNameGenerator,
NameGenerator,
create_amplitude_base,
create_amplitude_symbol,
create_helicity_symbol,
create_spin_projection_symbol,
generate_transition_label,
get_helicity_angle_symbols,
get_helicity_suffix,
get_topology_identifier,
natural_sorting,
)
from ampform.kinematics import HelicityAdapter
Expand Down Expand Up @@ -595,7 +599,7 @@ def formulate(self) -> HelicityModel:
)

def __formulate_top_expression(self) -> PoolSum:
outer_state_ids = _get_outer_state_ids(self.__reaction)
outer_state_ids = get_outer_state_ids(self.__reaction)
spin_projections: collections.defaultdict[sp.Symbol, set[sp.Rational]] = (
collections.defaultdict(set)
)
Expand All @@ -605,7 +609,7 @@ def __formulate_top_expression(self) -> PoolSum:
for transition in group:
for i in outer_state_ids:
state = transition.states[i]
symbol = _create_spin_projection_symbol(i)
symbol = create_spin_projection_symbol(i)
value = sp.Rational(state.spin_projection)
spin_projections[symbol].add(value)

Expand All @@ -615,21 +619,20 @@ def __formulate_top_expression(self) -> PoolSum:
else:
indices = list(spin_projections)
amplitude = sum( # type: ignore[assignment]
_create_amplitude_base(topology)[indices]
for topology in topology_groups
create_amplitude_base(topology)[indices] for topology in topology_groups
)
return PoolSum(abs(amplitude) ** 2, *spin_projections.items())

def __formulate_aligned_amplitude(
self, topology_groups: dict[Topology, list[StateTransition]]
) -> sp.Expr:
outer_state_ids = _get_outer_state_ids(self.__reaction)
outer_state_ids = get_outer_state_ids(self.__reaction)
amplitude = sp.S.Zero
for topology, transitions in topology_groups.items():
base = _create_amplitude_base(topology)
base = create_amplitude_base(topology)
helicities = [
_get_opposite_helicity_sign(topology, i)
* _create_helicity_symbol(topology, i)
* create_helicity_symbol(topology, i)
for i in outer_state_ids
]
amplitude_symbol = base[helicities]
Expand Down Expand Up @@ -666,7 +669,7 @@ def __formulate_topology_amplitude(
sequential_expressions.append(expression)

first_transition = transitions[0]
symbol = _create_amplitude_symbol(first_transition)
symbol = create_amplitude_symbol(first_transition)
expression = sum(sequential_expressions) # type: ignore[assignment]
self.__ingredients.amplitudes[symbol] = expression
return expression
Expand Down Expand Up @@ -765,63 +768,12 @@ def __generate_amplitude_prefactor(
return None


def _create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = _get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
base = _create_amplitude_base(transition.topology)
return base[helicities]


def _get_opposite_helicity_sign(topology: Topology, state_id: int) -> Literal[-1, 1]:
if state_id != -1 and is_opposite_helicity_state(topology, state_id):
return -1
return 1


def _create_amplitude_base(topology: Topology) -> sp.IndexedBase:
superscript = get_topology_identifier(topology)
return sp.IndexedBase(f"A^{superscript}", complex=True)


def _create_helicity_symbol(
topology: Topology, state_id: int, root: str = "lambda"
) -> sp.Symbol:
if state_id == -1: # initial state
name = "m_A"
else:
suffix = get_helicity_suffix(topology, state_id)
name = f"{root}{suffix}"
return sp.Symbol(name, rational=True)


def _create_spin_projection_symbol(state_id: int) -> sp.Symbol:
if state_id == -1: # initial state
suffix = "_A"
else:
suffix = str(state_id)
return sp.Symbol(f"m{suffix}", rational=True)


@singledispatch
def _get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
msg = f"Cannot get outer state IDs from a {type(obj).__name__}"
raise NotImplementedError(msg)


@_get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


@_get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return _get_outer_state_ids(reaction.transitions[0])


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.
Expand Down Expand Up @@ -1033,7 +985,7 @@ def formulate_rotation_chain(
plus a Wigner rotation (see :func:`.formulate_wigner_rotation`) in case there is
more than one helicity rotation.
"""
helicity_symbol = _create_spin_projection_symbol(rotated_state_id)
helicity_symbol = create_spin_projection_symbol(rotated_state_id)
helicity_rotations = formulate_helicity_rotation_chain(
transition, rotated_state_id, helicity_symbol
)
Expand Down
20 changes: 19 additions & 1 deletion src/ampform/helicity/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Iterable

from attrs import frozen
from qrules.transition import State, StateTransition
from qrules.transition import ReactionInfo, State, StateTransition

if TYPE_CHECKING:
from qrules.quantum_numbers import InteractionProperties
Expand Down Expand Up @@ -295,6 +295,24 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in
return sorted(topology.get_originating_final_state_edge_ids(edge.ending_node_id))


@singledispatch
def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
msg = f"Cannot get outer state IDs from a {type(obj).__name__}"
raise NotImplementedError(msg)


@get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


@get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return get_outer_state_ids(reaction.transitions[0])


def get_prefactor(transition: StateTransition) -> float:
"""Calculate the product of all prefactors defined in this transition.
Expand Down
34 changes: 34 additions & 0 deletions src/ampform/helicity/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
assert_isobar_topology,
determine_attached_final_state,
get_helicity_info,
get_outer_state_ids,
get_sorted_states,
)

Expand Down Expand Up @@ -287,6 +288,20 @@ def __generate_ls_arrow(transition: StateTransition, node_id: int) -> str:
return Rf" \xrightarrow[S={coupled_spin}]{{L={angular_momentum}}} "


def create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
base = create_amplitude_base(transition.topology)
return base[helicities]


def create_amplitude_base(topology: Topology) -> sp.IndexedBase:
superscript = get_topology_identifier(topology)
return sp.IndexedBase(f"A^{superscript}", complex=True)


def generate_transition_label(transition: StateTransition) -> str:
r"""Generate a label for a coherent intensity, including spin projection.
Expand Down Expand Up @@ -495,3 +510,22 @@ def _render_float(value: float) -> str:
if value > 0:
return f"+{rational}"
return str(rational)


def create_helicity_symbol(
topology: Topology, state_id: int, root: str = "lambda"
) -> sp.Symbol:
if state_id == -1: # initial state
name = "m_A"
else:
suffix = get_helicity_suffix(topology, state_id)
name = f"{root}{suffix}"
return sp.Symbol(name, rational=True)


def create_spin_projection_symbol(state_id: int) -> sp.Symbol:
if state_id == -1: # initial state
suffix = "_A"
else:
suffix = str(state_id)
return sp.Symbol(f"m{suffix}", rational=True)
10 changes: 7 additions & 3 deletions src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@ def create_four_momentum_symbols(topology: Topology) -> FourMomenta:
>>> create_four_momentum_symbols(topologies[0])
{0: p0, 1: p1, 2: p2}
"""
n_final_states = len(topology.outgoing_edge_ids)
return {i: FourMomentumSymbol(f"p{i}", shape=[]) for i in range(n_final_states)}
final_state_ids = sorted(topology.outgoing_edge_ids)
return {i: create_four_momentum_symbol(i) for i in final_state_ids}


def create_four_momentum_symbol(index: int) -> FourMomentumSymbol:
return FourMomentumSymbol(f"p{index}", shape=[])


FourMomenta = Dict[int, "FourMomentumSymbol"]
"""A mapping of state IDs to their corresponding `FourMomentumSymbol`.
"""A mapping of state IDs to their corresponding `.FourMomentumSymbol`.
It's best to create a `dict` of `.FourMomenta` with
:func:`create_four_momentum_symbols`.
Expand Down

0 comments on commit 1e60e40

Please sign in to comment.