Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add sliders in sympy etapipi notebook #95

Merged
merged 15 commits into from
Aug 27, 2024
276 changes: 245 additions & 31 deletions docs/eta-pi-p/manual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This section is a follow-up to formulate the amplitude model for the $\\gamma p \\to \\eta\\pi^0 p$ channel example symbolically. See **[TR‑033](https://compwa.github.io/report/033)** for a purely numerical tutorial.\n",
"This section is a follow-up of previous chapter:[Reaction and Models](reaction-model.md), to formulate the amplitude model for the $\\gamma p \\to \\eta\\pi^0 p$ channel example symbolically. See **[TR‑033](https://compwa.github.io/report/033)** for a purely numerical tutorial.\n",
redeboer marked this conversation as resolved.
Show resolved Hide resolved
"\n",
"The model we want to implement is"
]
Expand All @@ -24,7 +24,7 @@
"\\begin{array}{rcl}\n",
"I &=& \\left|A^{12} + A^{23} + A^{31}\\right|^2 \\\\\n",
"A^{12} &=& \\frac{\\sum a_m Y_2^m (\\Omega_1)}{s_{12}-m^2_{a_2}+im_{a_2} \\Gamma_{a_2}} \\\\\n",
"A^{23} &=& \\frac{\\sum b_m Y_1^m (\\Omega_2)}{s_{23}-m^2_{\\Delta}+im_{\\Delta} \\Gamma_{\\Delta}} \\\\\n",
"A^{23} &=& \\frac{\\sum b_m Y_1^m (\\Omega_2)}{s_{23}-m^2_{\\Delta^+}+im_{\\Delta^+} \\Gamma_{\\Delta^+}} \\\\\n",
"A^{31} &=& \\frac{c_0}{s_{31}-m^2_{N^*}+im_{N^*} \\Gamma_{N^*}} \\,,\n",
"\\end{array}\n",
"$$"
Expand All @@ -41,17 +41,24 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"mystnb": {
"code_prompt_show": "Import Python libraries"
},
"tags": [
"hide-cell"
]
}
shenvitor marked this conversation as resolved.
Show resolved Hide resolved
},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import logging\n",
"import os\n",
"import warnings\n",
"from collections import defaultdict\n",
"\n",
"import ipywidgets as w\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import sympy as sp\n",
Expand All @@ -71,13 +78,19 @@
")\n",
"from ampform.sympy import unevaluated\n",
"from ampform.sympy._array_expressions import ArraySum\n",
"from IPython.display import Latex\n",
"from IPython.display import Image, Latex, display\n",
"from tensorwaves.data import (\n",
" SympyDataTransformer,\n",
" TFPhaseSpaceGenerator,\n",
" TFUniformRealNumberGenerator,\n",
")\n",
"from tensorwaves.function.sympy import create_parametrized_function"
"from tensorwaves.function.sympy import create_parametrized_function\n",
"\n",
"STATIC_PAGE = \"EXECUTE_NB\" in os.environ\n",
"\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n",
"logging.disable(logging.WARNING)\n",
"warnings.filterwarnings(\"ignore\")"
]
},
{
Expand Down Expand Up @@ -171,7 +184,7 @@
"metadata": {},
"outputs": [],
"source": [
"s23, m_delta, gamma_delta = sp.symbols(\"s_{23} m_Delta Gamma_Delta\")\n",
"s23, m_delta, gamma_delta = sp.symbols(r\"s_{23} m_{\\Delta^+} \\Gamma_{\\Delta^+}\")\n",
redeboer marked this conversation as resolved.
Show resolved Hide resolved
"b = sp.IndexedBase(\"b\")\n",
"m = sp.symbols(\"m\", cls=sp.Idx)\n",
"theta2, phi2 = sp.symbols(\"theta_2 phi_2\")\n",
Expand Down Expand Up @@ -431,9 +444,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
redeboer marked this conversation as resolved.
Show resolved Hide resolved
"tags": [
"hide-input"
]
Expand Down Expand Up @@ -471,6 +481,67 @@
"Latex(aslatex(parameters_default))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
":::{note}\n",
"The mass and width of resonances are customsed to make the resonance bands in a better visible form.\n",
shenvitor marked this conversation as resolved.
Show resolved Hide resolved
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"isinstance(a[-2], sp.Indexed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"isinstance(m_a2, sp.Indexed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"type(m_a2).__mro__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"type(a[-2]).__mro__"
]
},
shenvitor marked this conversation as resolved.
Show resolved Hide resolved
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -601,6 +672,98 @@
"### Dalitz Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
}
},
shenvitor marked this conversation as resolved.
Show resolved Hide resolved
"outputs": [],
"source": [
"sliders = {}\n",
"categorized_sliders_m = defaultdict(list)\n",
"categorized_sliders_gamma = defaultdict(list)\n",
"categorized_cphi_pair = defaultdict(list)\n",
"\n",
"for symbol, value in parameters_default.items():\n",
" if symbol.name.startswith(R\"\\Gamma_{\"):\n",
" slider = w.FloatSlider(\n",
" description=Rf\"\\({sp.latex(symbol)}\\)\",\n",
" min=0.0,\n",
" max=1.0,\n",
" step=0.01,\n",
" value=value,\n",
" continuous_update=False,\n",
" )\n",
" sliders[symbol.name] = slider\n",
" if symbol.name.startswith(R\"\\Gamma_{N\"):\n",
" categorized_sliders_gamma[0].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{\\D\"):\n",
" categorized_sliders_gamma[1].append(slider)\n",
" elif symbol.name.startswith(R\"\\Gamma_{a\"):\n",
" categorized_sliders_gamma[2].append(slider)\n",
"\n",
" elif symbol.name.startswith(\"m_{\"):\n",
" slider = w.FloatSlider(\n",
" description=Rf\"\\({sp.latex(symbol)}\\)\",\n",
" min=0.63,\n",
" max=4,\n",
" step=0.01,\n",
" value=value,\n",
" continuous_update=False,\n",
" )\n",
" sliders[symbol.name] = slider\n",
" if symbol.name.startswith(\"m_{N\"):\n",
" categorized_sliders_m[0].append(slider)\n",
" elif symbol.name.startswith(R\"m_{\\D\"):\n",
" categorized_sliders_m[1].append(slider)\n",
" elif symbol.name.startswith(\"m_{a\"):\n",
" categorized_sliders_m[2].append(slider)\n",
"\n",
" else:\n",
" c_latex = sp.latex(symbol)\n",
" phi_latex = Rf\"\\phi_{{{c_latex}}}\"\n",
" slider_c = w.FloatSlider(\n",
" description=Rf\"\\({c_latex}\\)\",\n",
" min=0,\n",
" max=10,\n",
" value=abs(value),\n",
" continuous_update=False,\n",
" )\n",
" slider_phi = w.FloatSlider(\n",
" description=Rf\"\\({phi_latex}\\)\",\n",
" min=-np.pi,\n",
" max=+np.pi,\n",
" value=np.angle(value),\n",
" continuous_update=False,\n",
" )\n",
"\n",
" sliders[symbol.name] = slider_c\n",
" sliders[f\"phi_{symbol.name}\"] = slider_phi\n",
" cphi_hbox = w.HBox([slider_c, slider_phi])\n",
" if symbol.base is a:\n",
" categorized_cphi_pair[2].append(cphi_hbox)\n",
" elif symbol.base is b:\n",
" categorized_cphi_pair[1].append(cphi_hbox)\n",
" elif symbol.base is c:\n",
" categorized_cphi_pair[0].append(cphi_hbox)\n",
" else:\n",
" raise NotImplementedError(symbol.name)\n",
"\n",
"tab_contents = []\n",
"resonances_name = [\"N*\", \"Δ*\", \"a₂*\"]\n",
"for i in range(len(resonances_name)):\n",
" tab_content = w.VBox([\n",
" w.HBox(categorized_sliders_m[i] + categorized_sliders_gamma[i]),\n",
" w.VBox(categorized_cphi_pair[i]),\n",
" ])\n",
" tab_contents.append(tab_content)\n",
"UI = w.Tab(tab_contents, titles=resonances_name)\n",
"UI"
shenvitor marked this conversation as resolved.
Show resolved Hide resolved
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -624,6 +787,34 @@
"intensities = intensity_func(phsp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def insert_phi(parameters: dict) -> dict:\n",
" updated_parameters = {}\n",
" for key, value in parameters.items():\n",
" if key.startswith(\"phi_\"):\n",
" continue\n",
" if key.startswith((\"a\", \"b\", \"c\")):\n",
" phi_key = f\"phi_{key}\"\n",
" if phi_key in parameters:\n",
" phi = parameters[phi_key]\n",
" value *= np.exp(1j * phi) # noqa:PLW2901\n",
" updated_parameters[key] = value\n",
"\n",
" return updated_parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -638,25 +829,48 @@
},
"outputs": [],
"source": [
"%matplotlib widget\n",
"%config InlineBackend.figure_formats = ['png']\n",
"fig_2d, ax_2d = plt.subplots(dpi=200)\n",
"ax_2d.set_title(\"Model-weighted Phase space Dalitz Plot\")\n",
"ax_2d.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"ax_2d.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"\n",
"fig, ax = plt.subplots(dpi=200)\n",
"hist = ax.hist2d(\n",
" phsp[\"s_{12}\"],\n",
" phsp[\"s_{23}\"],\n",
" bins=200,\n",
" cmin=1e-6,\n",
" density=True,\n",
" cmap=\"jet\",\n",
" vmax=0.15,\n",
" weights=intensities,\n",
")\n",
"ax.set_title(\"Model-weighted Phase space Dalitz Plot\")\n",
"ax.set_xlabel(R\"$m^2(\\eta \\pi^0)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"ax.set_ylabel(R\"$m^2(\\pi^0 p)\\;\\left[\\mathrm{GeV}^2\\right]$\")\n",
"fig.colorbar(hist[3], ax=ax)\n",
"fig.tight_layout()\n",
"plt.show()"
"mesh = None\n",
"\n",
"\n",
"def update_histogram(**parameters):\n",
" global mesh\n",
" parameters = insert_phi(parameters)\n",
" intensity_func.update_parameters(parameters)\n",
" intensity_weights = intensity_func(phsp)\n",
" bin_values, xedges, yedges = jnp.histogram2d(\n",
" phsp[\"s_{12}\"],\n",
" phsp[\"s_{23}\"],\n",
" bins=200,\n",
" weights=intensity_weights,\n",
" density=True,\n",
" )\n",
" bin_values = jnp.where(bin_values < 1e-6, jnp.nan, bin_values)\n",
" x, y = jnp.meshgrid(xedges[:-1], yedges[:-1])\n",
" if mesh is None:\n",
" mesh = ax_2d.pcolormesh(x, y, bin_values.T, cmap=\"jet\", vmax=0.15)\n",
" else:\n",
" mesh.set_array(bin_values.T)\n",
" fig_2d.canvas.draw_idle()\n",
"\n",
"\n",
"interactive_plot = w.interactive_output(update_histogram, sliders)\n",
"fig_2d.tight_layout()\n",
"fig_2d.colorbar(mesh, ax=ax_2d)\n",
"\n",
"if STATIC_PAGE:\n",
" filename = \"dalitz-plot.png\"\n",
" fig_2d.savefig(filename)\n",
" plt.close(fig_2d)\n",
" display(UI, Image(filename))\n",
"else:\n",
" display(UI, interactive_plot)"
]
},
{
Expand All @@ -674,8 +888,8 @@
"source_hidden": true
},
"tags": [
"hide-input",
"full-width"
"full-width",
"hide-input"
shenvitor marked this conversation as resolved.
Show resolved Hide resolved
]
},
"outputs": [],
Expand Down
Loading