From 07f7abb8873fc7cebe86d60cc717c767ddea20f7 Mon Sep 17 00:00:00 2001 From: SVJ_Vitor Date: Fri, 23 Aug 2024 14:34:43 +0200 Subject: [PATCH] ENH: add sliders and c-phi for etapip auto notebook (#92) --- docs/eta-pi-p/automated.ipynb | 192 ++++++++++++++++++++++++++-------- pyproject.toml | 1 + 2 files changed, 147 insertions(+), 46 deletions(-) diff --git a/docs/eta-pi-p/automated.ipynb b/docs/eta-pi-p/automated.ipynb index e71ea69..2f4949d 100644 --- a/docs/eta-pi-p/automated.ipynb +++ b/docs/eta-pi-p/automated.ipynb @@ -41,12 +41,14 @@ "import graphviz\n", "import ipywidgets as w\n", "import jax\n", + "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "import qrules\n", "import sympy as sp\n", "from ampform.dynamics.builder import RelativisticBreitWignerBuilder\n", "from ampform.io import aslatex, improve_latex_rendering\n", - "from IPython.display import SVG, Math, display\n", + "from IPython.display import SVG, Image, Math, display\n", "from qrules.particle import Particle, Spin, create_particle, load_pdg\n", "from tensorwaves.data import (\n", " SympyDataTransformer,\n", @@ -287,41 +289,6 @@ "phsp = helicity_transformer(phsp_momenta)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - }, - "tags": [ - "hide-input", - "scroll-input" - ] - }, - "outputs": [], - "source": [ - "%config InlineBackend.figure_formats = ['png']\n", - "\n", - "fig, ax = plt.subplots(dpi=200)\n", - "hist = ax.hist2d(\n", - " phsp[\"m_01\"].real ** 2,\n", - " phsp[\"m_12\"].real ** 2,\n", - " bins=200,\n", - " cmin=1e-6,\n", - " density=True,\n", - " cmap=\"jet\",\n", - " vmax=0.15,\n", - " weights=intensity_func(phsp),\n", - ")\n", - "ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", - "ax.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "ax.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", - "cbar = fig.colorbar(hist[3], ax=ax)\n", - "fig.tight_layout()\n", - "plt.show()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -366,8 +333,9 @@ "outputs": [], "source": [ "sliders = {}\n", - "mass_sliders = []\n", - "gamma_sliders = []\n", + "categorized_sliders_m = defaultdict(list)\n", + "categorized_sliders_gamma = defaultdict(list)\n", + "categorized_cphi_pair = defaultdict(list)\n", "\n", "for symbol, value in model.parameter_defaults.items():\n", " if symbol.name.startswith(R\"\\Gamma_{\"):\n", @@ -375,25 +343,156 @@ " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", " min=0.0,\n", " max=1.0,\n", + " step=0.01,\n", " value=value,\n", + " continuous_update=False,\n", " )\n", - " gamma_sliders.append(slider)\n", " sliders[symbol.name] = slider\n", + " if symbol.name.startswith(R\"\\Gamma_{N\"):\n", + " categorized_sliders_gamma[0].append(slider)\n", + " elif symbol.name.startswith(R\"\\Gamma_{\\D\"):\n", + " categorized_sliders_gamma[1].append(slider)\n", + " elif symbol.name.startswith(R\"\\Gamma_{a\"):\n", + " categorized_sliders_gamma[2].append(slider)\n", "\n", " if symbol.name.startswith(\"m_{\"):\n", " slider = w.FloatSlider(\n", " description=Rf\"\\({sp.latex(symbol)}\\)\",\n", " min=0.63,\n", " max=4,\n", + " step=0.01,\n", " value=value,\n", + " continuous_update=False,\n", " )\n", - " mass_sliders.append(slider)\n", " sliders[symbol.name] = slider\n", + " if symbol.name.startswith(\"m_{N\"):\n", + " categorized_sliders_m[0].append(slider)\n", + " elif symbol.name.startswith(R\"m_{\\D\"):\n", + " categorized_sliders_m[1].append(slider)\n", + " elif symbol.name.startswith(\"m_{a\"):\n", + " categorized_sliders_m[2].append(slider)\n", + "\n", + " if symbol.name.startswith(\"C_{\"):\n", + " c_latex = sp.latex(symbol)\n", + " phi_latex = c_latex.replace(\"C\", R\"\\phi\", 1)\n", + "\n", + " slider_c = w.FloatSlider(\n", + " description=Rf\"\\({c_latex}\\)\",\n", + " min=0,\n", + " max=10,\n", + " value=abs(value),\n", + " continuous_update=False,\n", + " )\n", + " slider_phi = w.FloatSlider(\n", + " description=Rf\"\\({phi_latex}\\)\",\n", + " min=-np.pi,\n", + " max=+np.pi,\n", + " value=np.angle(value),\n", + " continuous_update=False,\n", + " )\n", + "\n", + " sliders[symbol.name] = slider_c\n", + " sliders[symbol.name.replace(\"C\", \"phi\", 1)] = slider_phi\n", + "\n", + " cphi_hbox = w.HBox([slider_c, slider_phi])\n", + " if R\"\\D\" in symbol.name:\n", + " categorized_cphi_pair[1].append(cphi_hbox)\n", + " elif \"N\" in symbol.name:\n", + " categorized_cphi_pair[0].append(cphi_hbox)\n", + " elif \"a\" in symbol.name:\n", + " categorized_cphi_pair[2].append(cphi_hbox)\n", + "\n", + "tab_contents = []\n", + "resonances_name = [\"N*\", \"Δ*\", \"a₂*\"]\n", + "for i in range(len(resonances_name)):\n", + " tab_content = w.VBox([\n", + " w.HBox(categorized_sliders_m[i] + categorized_sliders_gamma[i]),\n", + " w.VBox(categorized_cphi_pair[i]),\n", + " ])\n", + " tab_contents.append(tab_content)\n", + "UI = w.Tab(tab_contents, titles=resonances_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def insert_phi(parameters: dict) -> dict:\n", + " updated_parameters = {}\n", + " for key, value in parameters.items():\n", + " if key.startswith(\"phi_\"):\n", + " continue\n", + " if key.startswith(\"C_\"):\n", + " phi_key = key.replace(\"C_\", \"phi_\")\n", + " if phi_key in parameters:\n", + " phi = parameters[phi_key]\n", + " value *= np.exp(1j * phi) # noqa:PLW2901\n", + " updated_parameters[key] = value\n", + " return updated_parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input", + "full-width" + ] + }, + "outputs": [], + "source": [ + "%matplotlib widget\n", + "%config InlineBackend.figure_formats = ['png']\n", + "fig_2d, ax_2d = plt.subplots(dpi=200)\n", + "ax_2d.set_title(\"Model-weighted Phase space Dalitz Plot\")\n", + "ax_2d.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "ax_2d.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n", + "\n", + "mesh = None\n", + "\n", + "\n", + "def update_histogram(**parameters):\n", + " global mesh\n", + " parameters = insert_phi(parameters)\n", + " intensity_func.update_parameters(parameters)\n", + " intensity_weights = intensity_func(phsp)\n", + " bin_values, xedges, yedges = jnp.histogram2d(\n", + " phsp[\"m_01\"].real ** 2,\n", + " phsp[\"m_12\"].real ** 2,\n", + " bins=200,\n", + " weights=intensity_weights,\n", + " density=True,\n", + " )\n", + " bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n", + " x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n", + " if mesh is None:\n", + " mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n", + " else:\n", + " mesh.set_array(bin_values.T)\n", + " fig_2d.canvas.draw_idle()\n", + "\n", "\n", - "UI = w.HBox((\n", - " w.VBox(mass_sliders),\n", - " w.VBox(gamma_sliders),\n", - "))" + "interactive_plot = w.interactive_output(update_histogram, sliders)\n", + "fig_2d.tight_layout()\n", + "fig_2d.colorbar(mesh, ax=ax_2d)\n", + "\n", + "if STATIC_PAGE:\n", + " filename = \"dalitz-plot.png\"\n", + " fig_2d.savefig(filename)\n", + " plt.close(fig_2d)\n", + " display(UI, Image(filename))\n", + "else:\n", + " display(UI, interactive_plot)" ] }, { @@ -430,11 +529,12 @@ "\n", "\n", "def update_plot(**parameters):\n", + " parameters = insert_phi(parameters)\n", + " intensity_func.update_parameters(parameters)\n", + " intensities = intensity_func(phsp)\n", " max_value = 0\n", " resonance_colors: dict[Particle, int] = {}\n", " color_id = 0\n", - " intensity_func.update_parameters(parameters)\n", - " intensities = intensity_func(phsp)\n", " for recoil_id, ax in enumerate(axes):\n", " decay_products = sorted({0, 1, 2} - {recoil_id})\n", " key = f\"m_{''.join(str(i) for i in decay_products)}\"\n", diff --git a/pyproject.toml b/pyproject.toml index 9b9f206..00c3b80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,7 @@ split-on-trailing-comma = false "E303", "E501", "N806", + "N816", "PLC2401", "PLC2701", "PLR2004",