Skip to content

Commit

Permalink
FEAT: expose amplitude naming functions
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 1, 2024
1 parent 885e3ba commit 25a1979
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 41 deletions.
51 changes: 10 additions & 41 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@
CanonicalAmplitudeNameGenerator,
HelicityAmplitudeNameGenerator,
NameGenerator,
create_amplitude_base,
create_amplitude_symbol,
create_helicity_symbol,
create_spin_projection_symbol,
generate_transition_label,
get_helicity_angle_symbols,
get_helicity_suffix,
get_topology_identifier,
natural_sorting,
)
from ampform.kinematics import HelicityAdapter
Expand Down Expand Up @@ -606,7 +609,7 @@ def __formulate_top_expression(self) -> PoolSum:
for transition in group:
for i in outer_state_ids:
state = transition.states[i]
symbol = _create_spin_projection_symbol(i)
symbol = create_spin_projection_symbol(i)
value = sp.Rational(state.spin_projection)
spin_projections[symbol].add(value)

Expand All @@ -616,8 +619,7 @@ def __formulate_top_expression(self) -> PoolSum:
else:
indices = list(spin_projections)
amplitude = sum( # type: ignore[assignment]
_create_amplitude_base(topology)[indices]
for topology in topology_groups
create_amplitude_base(topology)[indices] for topology in topology_groups
)
return PoolSum(abs(amplitude) ** 2, *spin_projections.items())

Expand All @@ -627,10 +629,10 @@ def __formulate_aligned_amplitude(
outer_state_ids = get_outer_state_ids(self.__reaction)
amplitude = sp.S.Zero
for topology, transitions in topology_groups.items():
base = _create_amplitude_base(topology)
base = create_amplitude_base(topology)
helicities = [
_get_opposite_helicity_sign(topology, i)
* _create_helicity_symbol(topology, i)
* create_helicity_symbol(topology, i)
for i in outer_state_ids
]
amplitude_symbol = base[helicities]
Expand Down Expand Up @@ -667,7 +669,7 @@ def __formulate_topology_amplitude(
sequential_expressions.append(expression)

first_transition = transitions[0]
symbol = _create_amplitude_symbol(first_transition)
symbol = create_amplitude_symbol(first_transition)
expression = sum(sequential_expressions) # type: ignore[assignment]
self.__ingredients.amplitudes[symbol] = expression
return expression
Expand Down Expand Up @@ -766,45 +768,12 @@ def __generate_amplitude_prefactor(
return None


def _create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
base = _create_amplitude_base(transition.topology)
return base[helicities]


def _get_opposite_helicity_sign(topology: Topology, state_id: int) -> Literal[-1, 1]:
if state_id != -1 and is_opposite_helicity_state(topology, state_id):
return -1
return 1


def _create_amplitude_base(topology: Topology) -> sp.IndexedBase:
superscript = get_topology_identifier(topology)
return sp.IndexedBase(f"A^{superscript}", complex=True)


def _create_helicity_symbol(
topology: Topology, state_id: int, root: str = "lambda"
) -> sp.Symbol:
if state_id == -1: # initial state
name = "m_A"
else:
suffix = get_helicity_suffix(topology, state_id)
name = f"{root}{suffix}"
return sp.Symbol(name, rational=True)


def _create_spin_projection_symbol(state_id: int) -> sp.Symbol:
if state_id == -1: # initial state
suffix = "_A"
else:
suffix = str(state_id)
return sp.Symbol(f"m{suffix}", rational=True)


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.
Expand Down Expand Up @@ -1016,7 +985,7 @@ def formulate_rotation_chain(
plus a Wigner rotation (see :func:`.formulate_wigner_rotation`) in case there is
more than one helicity rotation.
"""
helicity_symbol = _create_spin_projection_symbol(rotated_state_id)
helicity_symbol = create_spin_projection_symbol(rotated_state_id)
helicity_rotations = formulate_helicity_rotation_chain(
transition, rotated_state_id, helicity_symbol
)
Expand Down
34 changes: 34 additions & 0 deletions src/ampform/helicity/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
assert_isobar_topology,
determine_attached_final_state,
get_helicity_info,
get_outer_state_ids,
get_sorted_states,
)

Expand Down Expand Up @@ -287,6 +288,20 @@ def __generate_ls_arrow(transition: StateTransition, node_id: int) -> str:
return Rf" \xrightarrow[S={coupled_spin}]{{L={angular_momentum}}} "


def create_amplitude_symbol(transition: StateTransition) -> sp.Indexed:
outer_state_ids = get_outer_state_ids(transition)
helicities = tuple(
sp.Rational(transition.states[i].spin_projection) for i in outer_state_ids
)
base = create_amplitude_base(transition.topology)
return base[helicities]


def create_amplitude_base(topology: Topology) -> sp.IndexedBase:
superscript = get_topology_identifier(topology)
return sp.IndexedBase(f"A^{superscript}", complex=True)


def generate_transition_label(transition: StateTransition) -> str:
r"""Generate a label for a coherent intensity, including spin projection.
Expand Down Expand Up @@ -495,3 +510,22 @@ def _render_float(value: float) -> str:
if value > 0:
return f"+{rational}"
return str(rational)


def create_helicity_symbol(
topology: Topology, state_id: int, root: str = "lambda"
) -> sp.Symbol:
if state_id == -1: # initial state
name = "m_A"
else:
suffix = get_helicity_suffix(topology, state_id)
name = f"{root}{suffix}"
return sp.Symbol(name, rational=True)


def create_spin_projection_symbol(state_id: int) -> sp.Symbol:
if state_id == -1: # initial state
suffix = "_A"
else:
suffix = str(state_id)
return sp.Symbol(f"m{suffix}", rational=True)

0 comments on commit 25a1979

Please sign in to comment.