diff --git a/docs/lambda-k-pi/ampform-dpd.ipynb b/docs/lambda-k-pi/ampform-dpd.ipynb index d6c6a08..7d318b7 100644 --- a/docs/lambda-k-pi/ampform-dpd.ipynb +++ b/docs/lambda-k-pi/ampform-dpd.ipynb @@ -40,7 +40,6 @@ "import ampform\n", "import graphviz\n", "import ipywidgets as w\n", - "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", @@ -49,9 +48,7 @@ "from ampform.dynamics.builder import RelativisticBreitWignerBuilder\n", "from ampform.io import improve_latex_rendering\n", "from ampform_dpd import DalitzPlotDecompositionBuilder\n", - "from ampform_dpd.adapter.qrules import (\n", - " to_three_body_decay,\n", - ")\n", + "from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay\n", "from ampform_dpd.decay import DecayNode, ThreeBodyDecayChain\n", "from ampform_dpd.dynamics.builder import create_mass_symbol, get_mandelstam_s\n", "from ampform_dpd.io import aslatex\n", @@ -59,8 +56,6 @@ "from qrules.particle import Particle, Spin, create_particle, load_pdg\n", "from tensorwaves.data import (\n", " SympyDataTransformer,\n", - " TFPhaseSpaceGenerator,\n", - " TFUniformRealNumberGenerator,\n", ")\n", "from tensorwaves.function.sympy import create_parametrized_function\n", "\n", @@ -269,16 +264,14 @@ " max_angular_momentum=4,\n", " max_spin_magnitude=4,\n", " mass_conservation_factor=0,\n", - ")" + ")\n", + "reaction = normalize_state_ids(reaction)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ "hide-input" ] @@ -314,7 +307,7 @@ "source": [ "model_builder = ampform.get_builder(reaction)\n", "model_builder.config.scalar_initial_state_mass = True\n", - "model_builder.config.stable_final_state_ids = 0, 1, 2\n", + "model_builder.config.stable_final_state_ids = list(reaction.final_state)\n", "bw_builder = RelativisticBreitWignerBuilder(\n", " energy_dependent_width=False,\n", " form_factor=False,\n", @@ -446,13 +439,20 @@ "Math(aslatex(model.variables))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualization" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "unfolded_expression = model.expression.doit()\n", + "unfolded_expression = model.full_expression.doit()\n", "intensity_func = create_parametrized_function(\n", " expression=unfolded_expression,\n", " parameters=model.parameter_defaults,\n", @@ -460,45 +460,81 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "i, j = 3, 1\n", + "(k,) = {1, 2, 3} - {i, j}\n", + "σk, σk_expr = list(model.invariants.items())[k - 1]\n", + "Latex(aslatex({σk: σk_expr}))" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { + "jupyter": { + "source_hidden": true + }, "tags": [ - "remove-output" + "hide-input" ] }, "outputs": [], "source": [ - "phsp_event = 500_000\n", - "rng = TFUniformRealNumberGenerator(seed=0)\n", - "phsp_generator = TFPhaseSpaceGenerator(\n", - " initial_state_mass=reaction.initial_state[-1].mass,\n", - " final_state_masses={i: p.mass for i, p in reaction.final_state.items()},\n", - ")\n", - "phsp_momenta = phsp_generator.generate(phsp_event, rng)" + "resolution = 1_000\n", + "m = sorted(model.masses, key=str)\n", + "x_min = float(((m[j] + m[k]) ** 2).xreplace(model.masses))\n", + "x_max = float(((m[0] - m[i]) ** 2).xreplace(model.masses))\n", + "y_min = float(((m[i] + m[k]) ** 2).xreplace(model.masses))\n", + "y_max = float(((m[0] - m[j]) ** 2).xreplace(model.masses))\n", + "x_diff = x_max - x_min\n", + "y_diff = y_max - y_min\n", + "x_min -= 0.05 * x_diff\n", + "x_max += 0.05 * x_diff\n", + "y_min -= 0.05 * y_diff\n", + "y_max += 0.05 * y_diff\n", + "X, Y = jnp.meshgrid(\n", + " jnp.linspace(x_min, x_max, num=resolution),\n", + " jnp.linspace(y_min, y_max, num=resolution),\n", + ")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, "outputs": [], "source": [ - "helicity_transformer = SympyDataTransformer.from_sympy(\n", - " model.kinematic_variables,\n", - " backend=\"jax\",\n", - ")\n", - "phsp = helicity_transformer(phsp_momenta)" + "definitions = dict(model.variables)\n", + "definitions[σk] = σk_expr\n", + "definitions = {\n", + " symbol: expr.xreplace(definitions).xreplace(model.masses)\n", + " for symbol, expr in definitions.items()\n", + "}\n", + "\n", + "data_transformer = SympyDataTransformer.from_sympy(definitions, backend=\"jax\")\n", + "dalitz_data = {\n", + " f\"sigma{i}\": X,\n", + " f\"sigma{j}\": Y,\n", + "}\n", + "dalitz_data.update(data_transformer(dalitz_data))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "jupyter": { - "source_hidden": true - }, "tags": [ "hide-output", "hide-input" @@ -537,6 +573,7 @@ "categorized_sliders_m = defaultdict(list)\n", "categorized_sliders_gamma = defaultdict(list)\n", "categorized_cphi_pair = defaultdict(list)\n", + "couplings_name_root = R\"\\mathcal{H}^\\mathrm{decay}\"\n", "\n", "for symbol, value in model.parameter_defaults.items():\n", " if symbol.name.startswith(R\"\\Gamma_{\"):\n", @@ -549,11 +586,11 @@ " )\n", " sliders[symbol.name] = slider\n", " if symbol.name.startswith(R\"\\Gamma_{K\"):\n", - " categorized_sliders_gamma[0].append(slider)\n", - " elif symbol.name.startswith(R\"\\Gamma_{\\S\"):\n", " categorized_sliders_gamma[1].append(slider)\n", - " elif symbol.name.startswith(R\"\\Gamma_{N\"):\n", + " elif symbol.name.startswith(R\"\\Gamma_{\\S\"):\n", " categorized_sliders_gamma[2].append(slider)\n", + " elif symbol.name.startswith(R\"\\Gamma_{N\"):\n", + " categorized_sliders_gamma[3].append(slider)\n", "\n", " if symbol.name.startswith(\"m_{\"):\n", " slider = w.FloatSlider(\n", @@ -565,15 +602,15 @@ " )\n", " sliders[symbol.name] = slider\n", " if symbol.name.startswith(\"m_{K\"):\n", - " categorized_sliders_m[0].append(slider)\n", - " elif symbol.name.startswith(R\"m_{\\S\"):\n", " categorized_sliders_m[1].append(slider)\n", - " elif symbol.name.startswith(\"m_{N\"):\n", + " elif symbol.name.startswith(R\"m_{\\S\"):\n", " categorized_sliders_m[2].append(slider)\n", + " elif symbol.name.startswith(\"m_{N\"):\n", + " categorized_sliders_m[3].append(slider)\n", "\n", - " if symbol.name.startswith(\"C_{\"):\n", + " if symbol.name.startswith(couplings_name_root):\n", " c_latex = sp.latex(symbol)\n", - " phi_latex = c_latex.replace(\"C\", R\"\\phi\", 1)\n", + " phi_latex = c_latex.replace(couplings_name_root, R\"\\phi\", 1)\n", "\n", " slider_c = w.FloatSlider(\n", " description=Rf\"\\({c_latex}\\)\",\n", @@ -591,15 +628,15 @@ " )\n", "\n", " sliders[symbol.name] = slider_c\n", - " sliders[symbol.name.replace(\"C\", \"phi\", 1)] = slider_phi\n", + " sliders[symbol.name.replace(couplings_name_root, \"phi\", 1)] = slider_phi\n", "\n", " cphi_hbox = w.HBox([slider_c, slider_phi])\n", " if \"Sigma\" in symbol.name:\n", - " categorized_cphi_pair[1].append(cphi_hbox)\n", - " elif R\"\\to N\" in symbol.name:\n", " categorized_cphi_pair[2].append(cphi_hbox)\n", + " elif \"N\" in symbol.name:\n", + " categorized_cphi_pair[3].append(cphi_hbox)\n", " else:\n", - " categorized_cphi_pair[0].append(cphi_hbox)\n", + " categorized_cphi_pair[1].append(cphi_hbox)\n", "\n", "\n", "assert len(categorized_sliders_gamma) == 3\n", @@ -607,33 +644,33 @@ "assert len(categorized_cphi_pair) == 3\n", "\n", "subtabs = {}\n", - "for category, resonance_list in resonances.items():\n", - " subtabs[category] = []\n", + "for recoild_id, resonance_list in resonances.items():\n", + " subtabs[recoild_id] = []\n", " for particle in resonance_list:\n", " m_sliders = [\n", " slider\n", - " for slider in categorized_sliders_m[category]\n", + " for slider in categorized_sliders_m[recoild_id]\n", " if particle.latex in slider.description\n", " ]\n", " gamma_sliders = [\n", " slider\n", - " for slider in categorized_sliders_gamma[category]\n", + " for slider in categorized_sliders_gamma[recoild_id]\n", " if particle.latex in slider.description\n", " ]\n", " cphi_pairs = [\n", " hbox\n", - " for hbox in categorized_cphi_pair[category]\n", + " for hbox in categorized_cphi_pair[recoild_id]\n", " if particle.latex in hbox.children[0].description\n", " ]\n", " pole_pair = w.HBox(m_sliders + gamma_sliders)\n", " resonance_tab = w.VBox([pole_pair, *cphi_pairs])\n", - " subtabs[category].append(resonance_tab)\n", + " subtabs[recoild_id].append(resonance_tab)\n", "assert len(subtabs) == 3\n", "\n", "main_tabs = []\n", - "for category, slider_boxes in subtabs.items():\n", + "for recoild_id, slider_boxes in subtabs.items():\n", " sub_tab_widget = w.Tab(children=slider_boxes)\n", - " for i, particle in enumerate(resonances[category]):\n", + " for i, particle in enumerate(resonances[recoild_id]):\n", " sub_tab_widget.set_title(i, particle.name)\n", "\n", " main_tabs.append(sub_tab_widget)\n", @@ -654,10 +691,10 @@ "def insert_phi(parameters: dict) -> dict:\n", " updated_parameters = {}\n", " for key, value in parameters.items():\n", - " if key.startswith(\"phi_\"):\n", + " if key.startswith(\"phi\"):\n", " continue\n", - " if key.startswith(\"C_\"):\n", - " phi_key = key.replace(\"C_\", \"phi_\")\n", + " if key.startswith(couplings_name_root):\n", + " phi_key = key.replace(couplings_name_root, \"phi\")\n", " if phi_key in parameters:\n", " phi = parameters[phi_key]\n", " value *= np.exp(1j * phi) # noqa:PLW2901\n", @@ -689,28 +726,21 @@ "mesh = None\n", "\n", "\n", - "def update_histogram(**parameters):\n", + "def update_dalitz_plot(**parameters):\n", " global mesh\n", " parameters = insert_phi(parameters)\n", " intensity_func.update_parameters(parameters)\n", - " intensity_weights = intensity_func(phsp)\n", - " bin_values, xedges, yedges = jnp.histogram2d(\n", - " phsp[\"m_01\"].real ** 2,\n", - " phsp[\"m_12\"].real ** 2,\n", - " bins=200,\n", - " weights=intensity_weights,\n", - " density=True,\n", - " )\n", - " bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", - " x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", + " intensities = intensity_func(dalitz_data) # z\n", + " intensities /= jnp.nansum(intensities) # normalization\n", + "\n", " if mesh is None:\n", - " mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + " mesh = ax_2d.pcolormesh(X, Y, intensities, cmap=\"jet\", vmax=3e-5)\n", " else:\n", - " mesh.set_array(bin_values.T)\n", + " mesh.set_array(intensities)\n", " fig_2d.canvas.draw_idle()\n", "\n", "\n", - "interactive_plot = w.interactive_output(update_histogram, sliders)\n", + "interactive_plot = w.interactive_output(update_dalitz_plot, sliders)\n", "fig_2d.tight_layout()\n", "fig_2d.colorbar(mesh, ax=ax_2d)\n", "display(UI, interactive_plot)" @@ -734,42 +764,44 @@ "%matplotlib widget\n", "%config InlineBackend.figure_formats = ['svg']\n", "\n", - "fig, axes = plt.subplots(figsize=(11, 3.5), ncols=3, sharey=True)\n", + "fig, axes = plt.subplots(figsize=(11, 3.5), ncols=2, sharey=True)\n", "fig.canvas.toolbar_visible = False\n", "fig.canvas.header_visible = False\n", "fig.canvas.footer_visible = False\n", - "ax1, ax2, ax3 = axes\n", + "ax1, ax2 = axes\n", "\n", - "for recoil_id, ax in enumerate(axes):\n", - " decay_products = sorted({0, 1, 2} - {recoil_id})\n", + "for ax in axes:\n", + " recoil_id = 3 if ax is ax1 else 1\n", + " decay_products = sorted(set(reaction.final_state) - {recoil_id})\n", " product_latex = \" \".join([reaction.final_state[i].latex for i in decay_products])\n", " ax.set_xlabel(f\"$m({product_latex})$ [GeV]\")\n", "\n", - "LINES = 3 * [None]\n", - "RESONANCE_LINE = [None] * len(reaction.get_intermediate_particles())\n", + "LINES = defaultdict(lambda: None)\n", + "RESONANCE_LINE = defaultdict(lambda: None)\n", "\n", "\n", "def update_plot(**parameters):\n", " parameters = insert_phi(parameters)\n", " intensity_func.update_parameters(parameters)\n", - " intensities = intensity_func(phsp)\n", + " intensities = intensity_func(dalitz_data) # z\n", + " intensities /= jnp.nansum(intensities) # normalization\n", + "\n", " max_value = 0\n", " color_id = 0\n", - " for recoil_id, ax in enumerate(axes):\n", - " decay_products = sorted({0, 1, 2} - {recoil_id})\n", - " key = f\"m_{''.join(str(i) for i in decay_products)}\"\n", - " bin_values, bin_edges = jax.numpy.histogram(\n", - " phsp[key].real,\n", - " bins=120,\n", - " density=True,\n", - " weights=intensities,\n", - " )\n", - " max_value = max(max_value, bin_values.max())\n", + " for ax in axes:\n", + " if ax is ax1:\n", + " x = jnp.sqrt(X[0])\n", + " y = jnp.nansum(intensities, axis=0)\n", + " else:\n", + " x = jnp.sqrt(Y[:, 0])\n", + " y = jnp.nansum(intensities, axis=1)\n", "\n", + " max_value = max(max_value, y.max())\n", + " recoil_id = 3 if ax is ax1 else 1\n", " if LINES[recoil_id] is None:\n", - " LINES[recoil_id] = ax.step(bin_edges[:-1], bin_values, alpha=0.5)[0]\n", + " LINES[recoil_id] = ax.plot(x, y, alpha=0.5)[0]\n", " else:\n", - " LINES[recoil_id].set_ydata(bin_values)\n", + " LINES[recoil_id].set_ydata(y)\n", "\n", " for resonance in resonances[recoil_id]:\n", " key = f\"m_{{{resonance.latex}}}\"\n", diff --git a/pyproject.toml b/pyproject.toml index 45228dc..a93930e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,6 +110,7 @@ builtins-ignorelist = ["display"] "D103", "E303", "E501", + "PLC2401", "PLR2004", "PLW0603", "S101",