Skip to content

Commit

Permalink
ENH(TR-032): compute histograms with JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed May 24, 2024
1 parent 290cded commit 8e8960c
Showing 1 changed file with 83 additions and 47 deletions.
130 changes: 83 additions & 47 deletions docs/report/032.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
"import ampform\n",
"import attrs\n",
"import graphviz\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -967,6 +968,45 @@
"### Weighted data with $F$ vector "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"mystnb": {
"code_prompt_show": "Function for plotting histograms with JAX"
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"def fast_histogram(\n",
" data: jnp.ndarray,\n",
" weights: jnp.ndarray | None = None,\n",
" bins: int = 100,\n",
" density: bool | None = None,\n",
" fill: bool = True,\n",
" ax=plt,\n",
" **plot_kwargs,\n",
") -> None:\n",
" bin_values, bin_edges = jnp.histogram(\n",
" data,\n",
" bins=bins,\n",
" density=density,\n",
" weights=weights,\n",
" )\n",
" if fill:\n",
" bin_rights = bin_edges[1:]\n",
" ax.fill_between(bin_rights, bin_values, step=\"pre\", **plot_kwargs)\n",
" else:\n",
" bin_mids = (bin_edges[:-1] + bin_edges[1:]) / 2\n",
" ax.step(bin_mids, bin_values, **plot_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -980,18 +1020,25 @@
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(9, 4))\n",
"ax.set_xlabel(R\"$m_{p\\eta/K\\Sigma}$ [GeV]\")\n",
"for i in range(n_channels):\n",
" fig, ax = plt.subplots(figsize=(6, 5))\n",
" intensity = np.real(INTENSITY_FUNCS_FVECTOR[i](PHSP[i]))\n",
" c = ax.hist(\n",
" np.real(PHSP[i][\"m_01\"]) ** 2,\n",
" bins=100,\n",
" fast_histogram(\n",
" np.real(PHSP[i][\"m_01\"]),\n",
" weights=intensity,\n",
" alpha=0.5,\n",
" bins=200,\n",
" density=True,\n",
" label=f\"${DECAYS[i].child1.latex} {DECAYS[i].child2.latex}$\",\n",
" ax=ax,\n",
" )\n",
" ax.set_xlabel(R\"$M^2\\left(\\eta p\\right)\\, \\mathrm{[(GeV/c)^2]}$\")\n",
" ax.set_ylabel(R\"Intensity [a.u.]\")\n",
" fig.tight_layout()\n",
" plt.show()"
"mass_pars = {k: v for k, v in new_parameters_fvector.items() if k.startswith(\"m_{\")}\n",
"for i, (k, v) in enumerate(mass_pars.items()):\n",
" ax.axvline(v, c=f\"C{i}\", label=f\"${k}$\", ls=\"dashed\")\n",
"ax.legend()\n",
"ax.set_ylim(0, None)\n",
"fig.show()"
]
},
{
Expand Down Expand Up @@ -1039,33 +1086,23 @@
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(9, 4))\n",
"ax.set_xlabel(R\"$m_{p\\eta/K\\Sigma}$ [GeV]\")\n",
"for i in range(n_channels):\n",
" resonances = sorted(\n",
" MODELS[i].reaction_info.get_intermediate_particles(),\n",
" key=lambda p: p.mass,\n",
" )\n",
" evenly_spaced_interval = np.linspace(\n",
" 0, 1, len(INTENSITY_FUNCS_FVECTOR[i].parameters.items())\n",
" )\n",
" colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n",
" fig, ax = plt.subplots(figsize=(9, 4))\n",
" ax.hist(\n",
" fast_histogram(\n",
" np.real(DATA[i][\"m_01\"]),\n",
" bins=200,\n",
" alpha=0.5,\n",
" bins=200,\n",
" density=True,\n",
" label=f\"${DECAYS[i].child1.latex} {DECAYS[i].child2.latex}$\",\n",
" ax=ax,\n",
" )\n",
" ax.set_xlabel(\"$m$ [GeV]\")\n",
" for (k, v), color in zip(new_parameters_fvector.items(), colors):\n",
" if k.startswith(\"m_{\"):\n",
" ax.axvline(\n",
" x=v,\n",
" linestyle=\"dotted\",\n",
" label=r\"$\" + k + \"$\",\n",
" color=color,\n",
" )\n",
" ax.legend()\n",
" plt.show()"
"mass_pars = {k: v for k, v in new_parameters_fvector.items() if k.startswith(\"m_{\")}\n",
"for i, (k, v) in enumerate(mass_pars.items()):\n",
" ax.axvline(v, c=f\"C{i}\", label=f\"${k}$\", ls=\"dashed\")\n",
"ax.legend()\n",
"ax.set_ylim(0, None)\n",
"fig.show()"
]
},
{
Expand Down Expand Up @@ -1114,16 +1151,12 @@
},
"outputs": [],
"source": [
"def indicate_masses(ax, function):\n",
" ax.set_xlabel(\"$m$ [GeV]\")\n",
" for (k, v), color_F in zip(function.parameters.items(), colors_F):\n",
" if k.startswith(\"m_{N\"):\n",
" ax.axvline(\n",
" x=v,\n",
" linestyle=\"dotted\",\n",
" label=r\"$\" + k + \"$\" \"(F vector)\",\n",
" color=color_F,\n",
" )\n",
"def indicate_masses(ax, intensity_func):\n",
" mass_pars = {\n",
" k: v for k, v in intensity_func.parameters.items() if k.startswith(\"m_{N\")\n",
" }\n",
" for i, (k, v) in enumerate(mass_pars.items()):\n",
" ax.axvline(v, c=f\"C{i}\", label=f\"${k}$\", ls=\"dashed\")\n",
"\n",
"\n",
"def compare_model(\n",
Expand All @@ -1133,23 +1166,25 @@
" function: Function[DataSample, np.ndarray],\n",
" bins: int = 100,\n",
"):\n",
" fig, ax = plt.subplots(figsize=(9, 4))\n",
" ax.hist(\n",
" fig, ax = plt.subplots(figsize=(9, 4), sharex=True)\n",
" fast_histogram(\n",
" data[variable_name].real,\n",
" bins=bins,\n",
" alpha=0.5,\n",
" label=\"data\",\n",
" bins=bins,\n",
" density=True,\n",
" label=\"data\",\n",
" ax=ax,\n",
" )\n",
" intensities = function(phsp)\n",
" ax.hist(\n",
" fast_histogram(\n",
" phsp[variable_name].real,\n",
" weights=intensities,\n",
" bins=bins,\n",
" histtype=\"step\",\n",
" color=\"red\",\n",
" label=\"Fit model with $F$ vector\",\n",
" density=True,\n",
" fill=False,\n",
" label=\"Fit model with $F$ vector\",\n",
" ax=ax,\n",
" )\n",
" indicate_masses(ax, function)\n",
" ax.axvline(\n",
Expand All @@ -1164,6 +1199,7 @@
" linestyle=\"dotted\",\n",
" label=rf\"${DECAYS[1].child1.latex} \\, {DECAYS[1].child2.latex}$ threshold\",\n",
" )\n",
" ax.set_ylim(0, None)\n",
" ax.legend()\n",
" fig.show()"
]
Expand Down

0 comments on commit 8e8960c

Please sign in to comment.