From 8e8960caad9c13cab4dcb9a3d15934b917ad2a8a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 24 May 2024 13:44:38 +0200 Subject: [PATCH] ENH(TR-032): compute histograms with JAX --- docs/report/032.ipynb | 130 +++++++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 47 deletions(-) diff --git a/docs/report/032.ipynb b/docs/report/032.ipynb index 16518fae..4487cc9e 100644 --- a/docs/report/032.ipynb +++ b/docs/report/032.ipynb @@ -80,6 +80,7 @@ "import ampform\n", "import attrs\n", "import graphviz\n", + "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", @@ -967,6 +968,45 @@ "### Weighted data with $F$ vector " ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "mystnb": { + "code_prompt_show": "Function for plotting histograms with JAX" + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "def fast_histogram(\n", + " data: jnp.ndarray,\n", + " weights: jnp.ndarray | None = None,\n", + " bins: int = 100,\n", + " density: bool | None = None,\n", + " fill: bool = True,\n", + " ax=plt,\n", + " **plot_kwargs,\n", + ") -> None:\n", + " bin_values, bin_edges = jnp.histogram(\n", + " data,\n", + " bins=bins,\n", + " density=density,\n", + " weights=weights,\n", + " )\n", + " if fill:\n", + " bin_rights = bin_edges[1:]\n", + " ax.fill_between(bin_rights, bin_values, step=\"pre\", **plot_kwargs)\n", + " else:\n", + " bin_mids = (bin_edges[:-1] + bin_edges[1:]) / 2\n", + " ax.step(bin_mids, bin_values, **plot_kwargs)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -980,18 +1020,25 @@ }, "outputs": [], "source": [ + "fig, ax = plt.subplots(figsize=(9, 4))\n", + "ax.set_xlabel(R\"$m_{p\\eta/K\\Sigma}$ [GeV]\")\n", "for i in range(n_channels):\n", - " fig, ax = plt.subplots(figsize=(6, 5))\n", " intensity = np.real(INTENSITY_FUNCS_FVECTOR[i](PHSP[i]))\n", - " c = ax.hist(\n", - " np.real(PHSP[i][\"m_01\"]) ** 2,\n", - " bins=100,\n", + " fast_histogram(\n", + " np.real(PHSP[i][\"m_01\"]),\n", " weights=intensity,\n", + " alpha=0.5,\n", + " bins=200,\n", + " density=True,\n", + " label=f\"${DECAYS[i].child1.latex} {DECAYS[i].child2.latex}$\",\n", + " ax=ax,\n", " )\n", - " ax.set_xlabel(R\"$M^2\\left(\\eta p\\right)\\, \\mathrm{[(GeV/c)^2]}$\")\n", - " ax.set_ylabel(R\"Intensity [a.u.]\")\n", - " fig.tight_layout()\n", - " plt.show()" + "mass_pars = {k: v for k, v in new_parameters_fvector.items() if k.startswith(\"m_{\")}\n", + "for i, (k, v) in enumerate(mass_pars.items()):\n", + " ax.axvline(v, c=f\"C{i}\", label=f\"${k}$\", ls=\"dashed\")\n", + "ax.legend()\n", + "ax.set_ylim(0, None)\n", + "fig.show()" ] }, { @@ -1039,33 +1086,23 @@ }, "outputs": [], "source": [ + "fig, ax = plt.subplots(figsize=(9, 4))\n", + "ax.set_xlabel(R\"$m_{p\\eta/K\\Sigma}$ [GeV]\")\n", "for i in range(n_channels):\n", - " resonances = sorted(\n", - " MODELS[i].reaction_info.get_intermediate_particles(),\n", - " key=lambda p: p.mass,\n", - " )\n", - " evenly_spaced_interval = np.linspace(\n", - " 0, 1, len(INTENSITY_FUNCS_FVECTOR[i].parameters.items())\n", - " )\n", - " colors = [cm.rainbow(x) for x in evenly_spaced_interval]\n", - " fig, ax = plt.subplots(figsize=(9, 4))\n", - " ax.hist(\n", + " fast_histogram(\n", " np.real(DATA[i][\"m_01\"]),\n", - " bins=200,\n", " alpha=0.5,\n", + " bins=200,\n", " density=True,\n", + " label=f\"${DECAYS[i].child1.latex} {DECAYS[i].child2.latex}$\",\n", + " ax=ax,\n", " )\n", - " ax.set_xlabel(\"$m$ [GeV]\")\n", - " for (k, v), color in zip(new_parameters_fvector.items(), colors):\n", - " if k.startswith(\"m_{\"):\n", - " ax.axvline(\n", - " x=v,\n", - " linestyle=\"dotted\",\n", - " label=r\"$\" + k + \"$\",\n", - " color=color,\n", - " )\n", - " ax.legend()\n", - " plt.show()" + "mass_pars = {k: v for k, v in new_parameters_fvector.items() if k.startswith(\"m_{\")}\n", + "for i, (k, v) in enumerate(mass_pars.items()):\n", + " ax.axvline(v, c=f\"C{i}\", label=f\"${k}$\", ls=\"dashed\")\n", + "ax.legend()\n", + "ax.set_ylim(0, None)\n", + "fig.show()" ] }, { @@ -1114,16 +1151,12 @@ }, "outputs": [], "source": [ - "def indicate_masses(ax, function):\n", - " ax.set_xlabel(\"$m$ [GeV]\")\n", - " for (k, v), color_F in zip(function.parameters.items(), colors_F):\n", - " if k.startswith(\"m_{N\"):\n", - " ax.axvline(\n", - " x=v,\n", - " linestyle=\"dotted\",\n", - " label=r\"$\" + k + \"$\" \"(F vector)\",\n", - " color=color_F,\n", - " )\n", + "def indicate_masses(ax, intensity_func):\n", + " mass_pars = {\n", + " k: v for k, v in intensity_func.parameters.items() if k.startswith(\"m_{N\")\n", + " }\n", + " for i, (k, v) in enumerate(mass_pars.items()):\n", + " ax.axvline(v, c=f\"C{i}\", label=f\"${k}$\", ls=\"dashed\")\n", "\n", "\n", "def compare_model(\n", @@ -1133,23 +1166,25 @@ " function: Function[DataSample, np.ndarray],\n", " bins: int = 100,\n", "):\n", - " fig, ax = plt.subplots(figsize=(9, 4))\n", - " ax.hist(\n", + " fig, ax = plt.subplots(figsize=(9, 4), sharex=True)\n", + " fast_histogram(\n", " data[variable_name].real,\n", - " bins=bins,\n", " alpha=0.5,\n", - " label=\"data\",\n", + " bins=bins,\n", " density=True,\n", + " label=\"data\",\n", + " ax=ax,\n", " )\n", " intensities = function(phsp)\n", - " ax.hist(\n", + " fast_histogram(\n", " phsp[variable_name].real,\n", " weights=intensities,\n", " bins=bins,\n", - " histtype=\"step\",\n", " color=\"red\",\n", - " label=\"Fit model with $F$ vector\",\n", " density=True,\n", + " fill=False,\n", + " label=\"Fit model with $F$ vector\",\n", + " ax=ax,\n", " )\n", " indicate_masses(ax, function)\n", " ax.axvline(\n", @@ -1164,6 +1199,7 @@ " linestyle=\"dotted\",\n", " label=rf\"${DECAYS[1].child1.latex} \\, {DECAYS[1].child2.latex}$ threshold\",\n", " )\n", + " ax.set_ylim(0, None)\n", " ax.legend()\n", " fig.show()" ]