Skip to content

Commit

Permalink
Finalize draft of Dalitz plot and 1d projection
Browse files Browse the repository at this point in the history
  • Loading branch information
shenvitor committed Aug 16, 2024
1 parent d844ba4 commit 7e8975e
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 85 deletions.
202 changes: 117 additions & 85 deletions docs/lambda-k-pi/ampform-dpd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
"import ampform\n",
"import graphviz\n",
"import ipywidgets as w\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand All @@ -49,18 +48,14 @@
"from ampform.dynamics.builder import RelativisticBreitWignerBuilder\n",
"from ampform.io import improve_latex_rendering\n",
"from ampform_dpd import DalitzPlotDecompositionBuilder\n",
"from ampform_dpd.adapter.qrules import (\n",
" to_three_body_decay,\n",
")\n",
"from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n",
"from ampform_dpd.decay import DecayNode, ThreeBodyDecayChain\n",
"from ampform_dpd.dynamics.builder import create_mass_symbol, get_mandelstam_s\n",
"from ampform_dpd.io import aslatex\n",
"from IPython.display import Latex, Markdown, Math, display\n",
"from qrules.particle import Particle, Spin, create_particle, load_pdg\n",
"from tensorwaves.data import (\n",
" SympyDataTransformer,\n",
" TFPhaseSpaceGenerator,\n",
" TFUniformRealNumberGenerator,\n",
")\n",
"from tensorwaves.function.sympy import create_parametrized_function\n",
"\n",
Expand Down Expand Up @@ -269,16 +264,14 @@
" max_angular_momentum=4,\n",
" max_spin_magnitude=4,\n",
" mass_conservation_factor=0,\n",
")"
")\n",
"reaction = normalize_state_ids(reaction)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input"
]
Expand Down Expand Up @@ -314,7 +307,7 @@
"source": [
"model_builder = ampform.get_builder(reaction)\n",
"model_builder.config.scalar_initial_state_mass = True\n",
"model_builder.config.stable_final_state_ids = 0, 1, 2\n",
"model_builder.config.stable_final_state_ids = list(reaction.final_state)\n",
"bw_builder = RelativisticBreitWignerBuilder(\n",
" energy_dependent_width=False,\n",
" form_factor=False,\n",
Expand Down Expand Up @@ -446,59 +439,102 @@
"Math(aslatex(model.variables))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"unfolded_expression = model.expression.doit()\n",
"unfolded_expression = model.full_expression.doit()\n",
"intensity_func = create_parametrized_function(\n",
" expression=unfolded_expression,\n",
" parameters=model.parameter_defaults,\n",
" backend=\"jax\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"i, j = 3, 1\n",
"(k,) = {1, 2, 3} - {i, j}\n",
"σk, σk_expr = list(model.invariants.items())[k - 1]\n",
"Latex(aslatex({σk: σk_expr}))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"remove-output"
"hide-input"
]
},
"outputs": [],
"source": [
"phsp_event = 500_000\n",
"rng = TFUniformRealNumberGenerator(seed=0)\n",
"phsp_generator = TFPhaseSpaceGenerator(\n",
" initial_state_mass=reaction.initial_state[-1].mass,\n",
" final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n",
")\n",
"phsp_momenta = phsp_generator.generate(phsp_event, rng)"
"resolution = 1_000\n",
"m = sorted(model.masses, key=str)\n",
"x_min = float(((m[j] + m[k]) ** 2).xreplace(model.masses))\n",
"x_max = float(((m[0] - m[i]) ** 2).xreplace(model.masses))\n",
"y_min = float(((m[i] + m[k]) ** 2).xreplace(model.masses))\n",
"y_max = float(((m[0] - m[j]) ** 2).xreplace(model.masses))\n",
"x_diff = x_max - x_min\n",
"y_diff = y_max - y_min\n",
"x_min -= 0.05 * x_diff\n",
"x_max += 0.05 * x_diff\n",
"y_min -= 0.05 * y_diff\n",
"y_max += 0.05 * y_diff\n",
"X, Y = jnp.meshgrid(\n",
" jnp.linspace(x_min, x_max, num=resolution),\n",
" jnp.linspace(y_min, y_max, num=resolution),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"helicity_transformer = SympyDataTransformer.from_sympy(\n",
" model.kinematic_variables,\n",
" backend=\"jax\",\n",
")\n",
"phsp = helicity_transformer(phsp_momenta)"
"definitions = dict(model.variables)\n",
"definitions[σk] = σk_expr\n",
"definitions = {\n",
" symbol: expr.xreplace(definitions).xreplace(model.masses)\n",
" for symbol, expr in definitions.items()\n",
"}\n",
"\n",
"data_transformer = SympyDataTransformer.from_sympy(definitions, backend=\"jax\")\n",
"dalitz_data = {\n",
" f\"sigma{i}\": X,\n",
" f\"sigma{j}\": Y,\n",
"}\n",
"dalitz_data.update(data_transformer(dalitz_data))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-output",
"hide-input"
Expand Down Expand Up @@ -537,6 +573,7 @@
"categorized_sliders_m = defaultdict(list)\n",
"categorized_sliders_gamma = defaultdict(list)\n",
"categorized_cphi_pair = defaultdict(list)\n",
"couplings_name_root = R\"\\mathcal{H}^\\mathrm{decay}\"\n",
"\n",
"for symbol, value in model.parameter_defaults.items():\n",
" if symbol.name.startswith(R\"\\Gamma_{\"):\n",
Expand All @@ -549,11 +586,11 @@
" )\n",
" sliders[symbol.name] = slider\n",
" if symbol.name.startswith(R\"\\Gamma_{K\"):\n",
" categorized_sliders_gamma[0].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{\\S\"):\n",
" categorized_sliders_gamma[1].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{N\"):\n",
" elif symbol.name.startswith(R\"\\Gamma_{\\S\"):\n",
" categorized_sliders_gamma[2].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{N\"):\n",
" categorized_sliders_gamma[3].append(slider)\n",
"\n",
" if symbol.name.startswith(\"m_{\"):\n",
" slider = w.FloatSlider(\n",
Expand All @@ -565,15 +602,15 @@
" )\n",
" sliders[symbol.name] = slider\n",
" if symbol.name.startswith(\"m_{K\"):\n",
" categorized_sliders_m[0].append(slider)\n",
" elif symbol.name.startswith(R\"m_{\\S\"):\n",
" categorized_sliders_m[1].append(slider)\n",
" elif symbol.name.startswith(\"m_{N\"):\n",
" elif symbol.name.startswith(R\"m_{\\S\"):\n",
" categorized_sliders_m[2].append(slider)\n",
" elif symbol.name.startswith(\"m_{N\"):\n",
" categorized_sliders_m[3].append(slider)\n",
"\n",
" if symbol.name.startswith(\"C_{\"):\n",
" if symbol.name.startswith(couplings_name_root):\n",
" c_latex = sp.latex(symbol)\n",
" phi_latex = c_latex.replace(\"C\", R\"\\phi\", 1)\n",
" phi_latex = c_latex.replace(couplings_name_root, R\"\\phi\", 1)\n",
"\n",
" slider_c = w.FloatSlider(\n",
" description=Rf\"\\({c_latex}\\)\",\n",
Expand All @@ -591,49 +628,49 @@
" )\n",
"\n",
" sliders[symbol.name] = slider_c\n",
" sliders[symbol.name.replace(\"C\", \"phi\", 1)] = slider_phi\n",
" sliders[symbol.name.replace(couplings_name_root, \"phi\", 1)] = slider_phi\n",
"\n",
" cphi_hbox = w.HBox([slider_c, slider_phi])\n",
" if \"Sigma\" in symbol.name:\n",
" categorized_cphi_pair[1].append(cphi_hbox)\n",
" elif R\"\\to N\" in symbol.name:\n",
" categorized_cphi_pair[2].append(cphi_hbox)\n",
" elif \"N\" in symbol.name:\n",
" categorized_cphi_pair[3].append(cphi_hbox)\n",
" else:\n",
" categorized_cphi_pair[0].append(cphi_hbox)\n",
" categorized_cphi_pair[1].append(cphi_hbox)\n",
"\n",
"\n",
"assert len(categorized_sliders_gamma) == 3\n",
"assert len(categorized_sliders_m) == 3\n",
"assert len(categorized_cphi_pair) == 3\n",
"\n",
"subtabs = {}\n",
"for category, resonance_list in resonances.items():\n",
" subtabs[category] = []\n",
"for recoild_id, resonance_list in resonances.items():\n",
" subtabs[recoild_id] = []\n",
" for particle in resonance_list:\n",
" m_sliders = [\n",
" slider\n",
" for slider in categorized_sliders_m[category]\n",
" for slider in categorized_sliders_m[recoild_id]\n",
" if particle.latex in slider.description\n",
" ]\n",
" gamma_sliders = [\n",
" slider\n",
" for slider in categorized_sliders_gamma[category]\n",
" for slider in categorized_sliders_gamma[recoild_id]\n",
" if particle.latex in slider.description\n",
" ]\n",
" cphi_pairs = [\n",
" hbox\n",
" for hbox in categorized_cphi_pair[category]\n",
" for hbox in categorized_cphi_pair[recoild_id]\n",
" if particle.latex in hbox.children[0].description\n",
" ]\n",
" pole_pair = w.HBox(m_sliders + gamma_sliders)\n",
" resonance_tab = w.VBox([pole_pair, *cphi_pairs])\n",
" subtabs[category].append(resonance_tab)\n",
" subtabs[recoild_id].append(resonance_tab)\n",
"assert len(subtabs) == 3\n",
"\n",
"main_tabs = []\n",
"for category, slider_boxes in subtabs.items():\n",
"for recoild_id, slider_boxes in subtabs.items():\n",
" sub_tab_widget = w.Tab(children=slider_boxes)\n",
" for i, particle in enumerate(resonances[category]):\n",
" for i, particle in enumerate(resonances[recoild_id]):\n",
" sub_tab_widget.set_title(i, particle.name)\n",
"\n",
" main_tabs.append(sub_tab_widget)\n",
Expand All @@ -654,10 +691,10 @@
"def insert_phi(parameters: dict) -> dict:\n",
" updated_parameters = {}\n",
" for key, value in parameters.items():\n",
" if key.startswith(\"phi_\"):\n",
" if key.startswith(\"phi\"):\n",
" continue\n",
" if key.startswith(\"C_\"):\n",
" phi_key = key.replace(\"C_\", \"phi_\")\n",
" if key.startswith(couplings_name_root):\n",
" phi_key = key.replace(couplings_name_root, \"phi\")\n",
" if phi_key in parameters:\n",
" phi = parameters[phi_key]\n",
" value *= np.exp(1j * phi) # noqa:PLW2901\n",
Expand Down Expand Up @@ -689,28 +726,21 @@
"mesh = None\n",
"\n",
"\n",
"def update_histogram(**parameters):\n",
"def update_dalitz_plot(**parameters):\n",
" global mesh\n",
" parameters = insert_phi(parameters)\n",
" intensity_func.update_parameters(parameters)\n",
" intensity_weights = intensity_func(phsp)\n",
" bin_values, xedges, yedges = jnp.histogram2d(\n",
" phsp[\"m_01\"].real ** 2,\n",
" phsp[\"m_12\"].real ** 2,\n",
" bins=200,\n",
" weights=intensity_weights,\n",
" density=True,\n",
" )\n",
" bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n",
" x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n",
" intensities = intensity_func(dalitz_data) # z\n",
" intensities /= jnp.nansum(intensities) # normalization\n",
"\n",
" if mesh is None:\n",
" mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n",
" mesh = ax_2d.pcolormesh(X, Y, intensities, cmap=\"jet\", vmax=3e-5)\n",
" else:\n",
" mesh.set_array(bin_values.T)\n",
" mesh.set_array(intensities)\n",
" fig_2d.canvas.draw_idle()\n",
"\n",
"\n",
"interactive_plot = w.interactive_output(update_histogram, sliders)\n",
"interactive_plot = w.interactive_output(update_dalitz_plot, sliders)\n",
"fig_2d.tight_layout()\n",
"fig_2d.colorbar(mesh, ax=ax_2d)\n",
"display(UI, interactive_plot)"
Expand All @@ -734,42 +764,44 @@
"%matplotlib widget\n",
"%config InlineBackend.figure_formats = ['svg']\n",
"\n",
"fig, axes = plt.subplots(figsize=(11, 3.5), ncols=3, sharey=True)\n",
"fig, axes = plt.subplots(figsize=(11, 3.5), ncols=2, sharey=True)\n",
"fig.canvas.toolbar_visible = False\n",
"fig.canvas.header_visible = False\n",
"fig.canvas.footer_visible = False\n",
"ax1, ax2, ax3 = axes\n",
"ax1, ax2 = axes\n",
"\n",
"for recoil_id, ax in enumerate(axes):\n",
" decay_products = sorted({0, 1, 2} - {recoil_id})\n",
"for ax in axes:\n",
" recoil_id = 3 if ax is ax1 else 1\n",
" decay_products = sorted(set(reaction.final_state) - {recoil_id})\n",
" product_latex = \" \".join([reaction.final_state[i].latex for i in decay_products])\n",
" ax.set_xlabel(f\"$m({product_latex})$ [GeV]\")\n",
"\n",
"LINES = 3 * [None]\n",
"RESONANCE_LINE = [None] * len(reaction.get_intermediate_particles())\n",
"LINES = defaultdict(lambda: None)\n",
"RESONANCE_LINE = defaultdict(lambda: None)\n",
"\n",
"\n",
"def update_plot(**parameters):\n",
" parameters = insert_phi(parameters)\n",
" intensity_func.update_parameters(parameters)\n",
" intensities = intensity_func(phsp)\n",
" intensities = intensity_func(dalitz_data) # z\n",
" intensities /= jnp.nansum(intensities) # normalization\n",
"\n",
" max_value = 0\n",
" color_id = 0\n",
" for recoil_id, ax in enumerate(axes):\n",
" decay_products = sorted({0, 1, 2} - {recoil_id})\n",
" key = f\"m_{''.join(str(i) for i in decay_products)}\"\n",
" bin_values, bin_edges = jax.numpy.histogram(\n",
" phsp[key].real,\n",
" bins=120,\n",
" density=True,\n",
" weights=intensities,\n",
" )\n",
" max_value = max(max_value, bin_values.max())\n",
" for ax in axes:\n",
" if ax is ax1:\n",
" x = jnp.sqrt(X[0])\n",
" y = jnp.nansum(intensities, axis=0)\n",
" else:\n",
" x = jnp.sqrt(Y[:, 0])\n",
" y = jnp.nansum(intensities, axis=1)\n",
"\n",
" max_value = max(max_value, y.max())\n",
" recoil_id = 3 if ax is ax1 else 1\n",
" if LINES[recoil_id] is None:\n",
" LINES[recoil_id] = ax.step(bin_edges[:-1], bin_values, alpha=0.5)[0]\n",
" LINES[recoil_id] = ax.plot(x, y, alpha=0.5)[0]\n",
" else:\n",
" LINES[recoil_id].set_ydata(bin_values)\n",
" LINES[recoil_id].set_ydata(y)\n",
"\n",
" for resonance in resonances[recoil_id]:\n",
" key = f\"m_{{{resonance.latex}}}\"\n",
Expand Down
Loading

0 comments on commit 7e8975e

Please sign in to comment.