diff --git a/docs/jpsi2pipipi.ipynb b/docs/jpsi2pipipi.ipynb index 46e9c684..b8aedfbc 100644 --- a/docs/jpsi2pipipi.ipynb +++ b/docs/jpsi2pipipi.ipynb @@ -37,6 +37,7 @@ "from typing import Iterable\n", "\n", "import ampform\n", + "import attrs\n", "import graphviz\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", @@ -61,7 +62,11 @@ "from tensorwaves.data.transform import SympyDataTransformer\n", "from tensorwaves.interface import DataSample, ParameterValue, ParametrizedFunction\n", "\n", - "from ampform_dpd import DalitzPlotDecompositionBuilder, simplify_latex_rendering\n", + "from ampform_dpd import (\n", + " DalitzPlotDecompositionBuilder,\n", + " set_initial_state_polarization,\n", + " simplify_latex_rendering,\n", + ")\n", "from ampform_dpd.decay import (\n", " IsobarNode,\n", " Particle,\n", @@ -149,7 +154,7 @@ "outputs": [], "source": [ "reaction = qrules.generate_transitions(\n", - " initial_state=INITIAL_STATE.name,\n", + " initial_state=(INITIAL_STATE.name, [-1, +1]),\n", " final_state=[p.name for p in FINAL_STATE],\n", " allowed_intermediate_particles=[\"a(0)(980)\", \"rho(770)\"],\n", " mass_conservation_factor=0,\n", @@ -282,7 +287,6 @@ { "cell_type": "markdown", "metadata": { - "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ @@ -303,7 +307,16 @@ "outputs": [], "source": [ "model_builder = DalitzPlotDecompositionBuilder(DECAY, min_ls=True)\n", + "initial_polarization = {\n", + " t.initial_states[-1].spin_projection for t in reaction.transitions\n", + "}\n", "dpd_model = model_builder.formulate(reference_subsystem=1)\n", + "dpd_model = attrs.evolve(\n", + " dpd_model,\n", + " intensity=set_initial_state_polarization(\n", + " dpd_model.intensity, initial_polarization\n", + " ),\n", + ")\n", "dpd_model.intensity" ] }, diff --git a/src/ampform_dpd/__init__.py b/src/ampform_dpd/__init__.py index 33b87f33..126bd14d 100644 --- a/src/ampform_dpd/__init__.py +++ b/src/ampform_dpd/__init__.py @@ -4,6 +4,7 @@ import sys from functools import lru_cache from itertools import product +from typing import Iterable import sympy as sp from ampform.sympy import PoolSum @@ -348,6 +349,19 @@ def _print_Indexed_latex(self, printer, *args): sp.Indexed._latex = _print_Indexed_latex +def set_initial_state_polarization( + intensity: PoolSum, spin_projections: Iterable[sp.Rational | float | int] +) -> PoolSum: + """Set the spin projections of the initial state.""" + helicity_symbol, _ = intensity.indices[0] + helicity_values = tuple(sp.Rational(i) for i in spin_projections) + new_indices = ( + (helicity_symbol, helicity_values), + *intensity.indices[1:], + ) + return PoolSum(intensity.expression, *new_indices) + + def _formulate_clebsch_gordan_factors( isobar: IsobarNode, helicities: dict[Particle, sp.Rational | sp.Symbol],