Skip to content

Commit

Permalink
ENH: add sliders and c-phi for etapip auto notebook (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
shenvitor authored Aug 23, 2024
1 parent 90c90a6 commit 07f7abb
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 46 deletions.
192 changes: 146 additions & 46 deletions docs/eta-pi-p/automated.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
"import graphviz\n",
"import ipywidgets as w\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import qrules\n",
"import sympy as sp\n",
"from ampform.dynamics.builder import RelativisticBreitWignerBuilder\n",
"from ampform.io import aslatex, improve_latex_rendering\n",
"from IPython.display import SVG, Math, display\n",
"from IPython.display import SVG, Image, Math, display\n",
"from qrules.particle import Particle, Spin, create_particle, load_pdg\n",
"from tensorwaves.data import (\n",
" SympyDataTransformer,\n",
Expand Down Expand Up @@ -287,41 +289,6 @@
"phsp = helicity_transformer(phsp_momenta)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input",
"scroll-input"
]
},
"outputs": [],
"source": [
"%config InlineBackend.figure_formats = ['png']\n",
"\n",
"fig, ax = plt.subplots(dpi=200)\n",
"hist = ax.hist2d(\n",
" phsp[\"m_01\"].real ** 2,\n",
" phsp[\"m_12\"].real ** 2,\n",
" bins=200,\n",
" cmin=1e-6,\n",
" density=True,\n",
" cmap=\"jet\",\n",
" vmax=0.15,\n",
" weights=intensity_func(phsp),\n",
")\n",
"ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n",
"ax.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"ax.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"cbar = fig.colorbar(hist[3], ax=ax)\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -366,34 +333,166 @@
"outputs": [],
"source": [
"sliders = {}\n",
"mass_sliders = []\n",
"gamma_sliders = []\n",
"categorized_sliders_m = defaultdict(list)\n",
"categorized_sliders_gamma = defaultdict(list)\n",
"categorized_cphi_pair = defaultdict(list)\n",
"\n",
"for symbol, value in model.parameter_defaults.items():\n",
" if symbol.name.startswith(R\"\\Gamma_{\"):\n",
" slider = w.FloatSlider(\n",
" description=Rf\"\\({sp.latex(symbol)}\\)\",\n",
" min=0.0,\n",
" max=1.0,\n",
" step=0.01,\n",
" value=value,\n",
" continuous_update=False,\n",
" )\n",
" gamma_sliders.append(slider)\n",
" sliders[symbol.name] = slider\n",
" if symbol.name.startswith(R\"\\Gamma_{N\"):\n",
" categorized_sliders_gamma[0].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{\\D\"):\n",
" categorized_sliders_gamma[1].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{a\"):\n",
" categorized_sliders_gamma[2].append(slider)\n",
"\n",
" if symbol.name.startswith(\"m_{\"):\n",
" slider = w.FloatSlider(\n",
" description=Rf\"\\({sp.latex(symbol)}\\)\",\n",
" min=0.63,\n",
" max=4,\n",
" step=0.01,\n",
" value=value,\n",
" continuous_update=False,\n",
" )\n",
" mass_sliders.append(slider)\n",
" sliders[symbol.name] = slider\n",
" if symbol.name.startswith(\"m_{N\"):\n",
" categorized_sliders_m[0].append(slider)\n",
" elif symbol.name.startswith(R\"m_{\\D\"):\n",
" categorized_sliders_m[1].append(slider)\n",
" elif symbol.name.startswith(\"m_{a\"):\n",
" categorized_sliders_m[2].append(slider)\n",
"\n",
" if symbol.name.startswith(\"C_{\"):\n",
" c_latex = sp.latex(symbol)\n",
" phi_latex = c_latex.replace(\"C\", R\"\\phi\", 1)\n",
"\n",
" slider_c = w.FloatSlider(\n",
" description=Rf\"\\({c_latex}\\)\",\n",
" min=0,\n",
" max=10,\n",
" value=abs(value),\n",
" continuous_update=False,\n",
" )\n",
" slider_phi = w.FloatSlider(\n",
" description=Rf\"\\({phi_latex}\\)\",\n",
" min=-np.pi,\n",
" max=+np.pi,\n",
" value=np.angle(value),\n",
" continuous_update=False,\n",
" )\n",
"\n",
" sliders[symbol.name] = slider_c\n",
" sliders[symbol.name.replace(\"C\", \"phi\", 1)] = slider_phi\n",
"\n",
" cphi_hbox = w.HBox([slider_c, slider_phi])\n",
" if R\"\\D\" in symbol.name:\n",
" categorized_cphi_pair[1].append(cphi_hbox)\n",
" elif \"N\" in symbol.name:\n",
" categorized_cphi_pair[0].append(cphi_hbox)\n",
" elif \"a\" in symbol.name:\n",
" categorized_cphi_pair[2].append(cphi_hbox)\n",
"\n",
"tab_contents = []\n",
"resonances_name = [\"N*\", \"Δ*\", \"a₂*\"]\n",
"for i in range(len(resonances_name)):\n",
" tab_content = w.VBox([\n",
" w.HBox(categorized_sliders_m[i] + categorized_sliders_gamma[i]),\n",
" w.VBox(categorized_cphi_pair[i]),\n",
" ])\n",
" tab_contents.append(tab_content)\n",
"UI = w.Tab(tab_contents, titles=resonances_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def insert_phi(parameters: dict) -> dict:\n",
" updated_parameters = {}\n",
" for key, value in parameters.items():\n",
" if key.startswith(\"phi_\"):\n",
" continue\n",
" if key.startswith(\"C_\"):\n",
" phi_key = key.replace(\"C_\", \"phi_\")\n",
" if phi_key in parameters:\n",
" phi = parameters[phi_key]\n",
" value *= np.exp(1j * phi) # noqa:PLW2901\n",
" updated_parameters[key] = value\n",
" return updated_parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input",
"full-width"
]
},
"outputs": [],
"source": [
"%matplotlib widget\n",
"%config InlineBackend.figure_formats = ['png']\n",
"fig_2d, ax_2d = plt.subplots(dpi=200)\n",
"ax_2d.set_title(\"Model-weighted Phase space Dalitz Plot\")\n",
"ax_2d.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"ax_2d.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"\n",
"mesh = None\n",
"\n",
"\n",
"def update_histogram(**parameters):\n",
" global mesh\n",
" parameters = insert_phi(parameters)\n",
" intensity_func.update_parameters(parameters)\n",
" intensity_weights = intensity_func(phsp)\n",
" bin_values, xedges, yedges = jnp.histogram2d(\n",
" phsp[\"m_01\"].real ** 2,\n",
" phsp[\"m_12\"].real ** 2,\n",
" bins=200,\n",
" weights=intensity_weights,\n",
" density=True,\n",
" )\n",
" bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n",
" x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n",
" if mesh is None:\n",
" mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n",
" else:\n",
" mesh.set_array(bin_values.T)\n",
" fig_2d.canvas.draw_idle()\n",
"\n",
"\n",
"UI = w.HBox((\n",
" w.VBox(mass_sliders),\n",
" w.VBox(gamma_sliders),\n",
"))"
"interactive_plot = w.interactive_output(update_histogram, sliders)\n",
"fig_2d.tight_layout()\n",
"fig_2d.colorbar(mesh, ax=ax_2d)\n",
"\n",
"if STATIC_PAGE:\n",
" filename = \"dalitz-plot.png\"\n",
" fig_2d.savefig(filename)\n",
" plt.close(fig_2d)\n",
" display(UI, Image(filename))\n",
"else:\n",
" display(UI, interactive_plot)"
]
},
{
Expand Down Expand Up @@ -430,11 +529,12 @@
"\n",
"\n",
"def update_plot(**parameters):\n",
" parameters = insert_phi(parameters)\n",
" intensity_func.update_parameters(parameters)\n",
" intensities = intensity_func(phsp)\n",
" max_value = 0\n",
" resonance_colors: dict[Particle, int] = {}\n",
" color_id = 0\n",
" intensity_func.update_parameters(parameters)\n",
" intensities = intensity_func(phsp)\n",
" for recoil_id, ax in enumerate(axes):\n",
" decay_products = sorted({0, 1, 2} - {recoil_id})\n",
" key = f\"m_{''.join(str(i) for i in decay_products)}\"\n",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ split-on-trailing-comma = false
"E303",
"E501",
"N806",
"N816",
"PLC2401",
"PLC2701",
"PLR2004",
Expand Down

0 comments on commit 07f7abb

Please sign in to comment.