Skip to content

Commit

Permalink
Revert "ENH: compute model with JAX"
Browse files Browse the repository at this point in the history
This reverts commit 0f2a40f.
  • Loading branch information
redeboer committed Jul 16, 2024
1 parent 3ef0a1e commit a6dff95
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions docs/report/033.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
},
"outputs": [],
"source": [
"%pip install -q iminuit==2.26.0 jax==0.4.30 jaxlib==0.4.30 matplotlib==3.9.1 numpy==1.26.4 pandas==2.2.2 particle==0.24.0 phasespace==1.10.3 scipy==1.14.0 tqdm==4.66.4 vector==1.4.1"
"%pip install -q iminuit==2.26.0 matplotlib==3.9.1 numpy==1.26.4 pandas==2.2.2 particle==0.24.0 phasespace==1.10.3 scipy==1.14.0 tqdm==4.66.4 vector==1.4.1"
]
},
{
Expand All @@ -63,7 +63,7 @@
},
"source": [
":::{admonition} Abstract\n",
"This document introduces Amplitude Analysis / Partial Wave Analysis (PWA) by demonstrating its application to a specific reaction channel and amplitude model. It aims to equip readers with a basic understanding of the full workflow and methodologies of PWA in hadron physics through a practical, hands-on example. Only basic Python programming and libraries (e.g. [`numpy`](https://numpy.org/doc/stable), [`scipy`](https://docs.scipy.org/doc/scipy), etc.) are used to illustrate the more fundamental steps in a PWA. Calculations with 4-vectors in this report are performed with the [`vector`](https://vector.readthedocs.io/en/latest/usage/intro.html) package. We use the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) module to speed up calculations, but its interface is interchangeable with `numpy`.\n",
"This document introduces Amplitude Analysis / Partial Wave Analysis (PWA) by demonstrating its application to a specific reaction channel and amplitude model. It aims to equip readers with a basic understanding of the full workflow and methodologies of PWA in hadron physics through a practical, hands-on example. Only basic Python programming and libraries (e.g. [`numpy`](https://numpy.org/doc/stable), [`scipy`](https://docs.scipy.org/doc/scipy), etc.) are used to illustrate the more fundamental steps in a PWA. Calculations with 4-vectors in this report are performed with the [`vector`](https://vector.readthedocs.io/en/latest/usage/intro.html) package.\n",
":::"
]
},
Expand Down Expand Up @@ -135,7 +135,6 @@
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
Expand Down Expand Up @@ -374,7 +373,7 @@
" A12 = BW(s12, M12, Gamma12)\n",
" A23 = BW(s23, M23, Gamma23)\n",
" A31 = BW(s31, M31, Gamma31)\n",
" return jnp.abs(A12 + A23 + A31) ** 2\n",
" return np.abs(A12 + A23 + A31) ** 2\n",
"\n",
"\n",
"def BW(s, m, Gamma):\n",
Expand Down Expand Up @@ -522,7 +521,7 @@
"outputs": [],
"source": [
"def SH_model(phi1, theta1, phi2, theta2, *, c_0, **pars):\n",
" return jnp.abs(Ylm12(phi1, theta1, **pars) + Ylm23(phi2, theta2, **pars) + c_0) ** 2"
" return np.abs(Ylm12(phi1, theta1, **pars) + Ylm23(phi2, theta2, **pars) + c_0) ** 2"
]
},
{
Expand Down Expand Up @@ -585,7 +584,7 @@
" A12 = BW(s12, M12, Gamma12) * Ylm12(phi1, theta1, **parameters)\n",
" A23 = BW(s23, M23, Gamma23) * Ylm23(phi2, theta2, **parameters)\n",
" A31 = BW(s31, M31, Gamma31) * c_0\n",
" return jnp.abs(A12 + A23 + A31) ** 2"
" return np.abs(A12 + A23 + A31) ** 2"
]
},
{
Expand Down Expand Up @@ -826,7 +825,7 @@
"m_proton = 0.938\n",
"m_eta = 0.548\n",
"m_pi = 0.135\n",
"m_0 = jnp.sqrt(2 * E_lab_gamma * m_proton + m_proton**2)\n",
"m_0 = np.sqrt(2 * E_lab_gamma * m_proton + m_proton**2)\n",
"m_0"
]
},
Expand Down Expand Up @@ -1234,9 +1233,9 @@
"p12_phsp = p1_phsp + p2_phsp\n",
"p23_phsp = p2_phsp + p3_phsp\n",
"p31_phsp = p3_phsp + p1_phsp\n",
"s12_phsp = jnp.array(p12_phsp.m2)\n",
"s23_phsp = jnp.array(p23_phsp.m2)\n",
"s31_phsp = jnp.array(p31_phsp.m2)"
"s12_phsp = p12_phsp.m2\n",
"s23_phsp = p23_phsp.m2\n",
"s31_phsp = p31_phsp.m2"
]
},
{
Expand Down Expand Up @@ -1497,12 +1496,12 @@
},
"outputs": [],
"source": [
"theta1_phsp = jnp.array(theta_helicity(p1_phsp, p12_phsp))\n",
"theta2_phsp = jnp.array(theta_helicity(p2_phsp, p23_phsp))\n",
"theta3_phsp = jnp.array(theta_helicity(p3_phsp, p31_phsp))\n",
"phi1_phsp = jnp.array(phi_helicity(p1_phsp, p12_phsp))\n",
"phi2_phsp = jnp.array(phi_helicity(p2_phsp, p23_phsp))\n",
"phi3_phsp = jnp.array(phi_helicity(p3_phsp, p31_phsp))"
"theta1_phsp = theta_helicity(p1_phsp, p12_phsp)\n",
"theta2_phsp = theta_helicity(p2_phsp, p23_phsp)\n",
"theta3_phsp = theta_helicity(p3_phsp, p31_phsp)\n",
"phi1_phsp = phi_helicity(p1_phsp, p12_phsp)\n",
"phi2_phsp = phi_helicity(p2_phsp, p23_phsp)\n",
"phi3_phsp = phi_helicity(p3_phsp, p31_phsp)"
]
},
{
Expand Down Expand Up @@ -1609,9 +1608,9 @@
},
"outputs": [],
"source": [
"PHI, THETA = jnp.meshgrid(\n",
" jnp.linspace(-np.pi, +np.pi, num=1_000),\n",
" jnp.linspace(0, np.pi, num=1_000),\n",
"PHI, THETA = np.meshgrid(\n",
" np.linspace(-np.pi, +np.pi, num=1_000),\n",
" np.linspace(0, np.pi, num=1_000),\n",
")\n",
"Z12 = Ylm12(PHI, THETA, **toy_parameters)\n",
"Z23 = Ylm23(PHI, THETA, **toy_parameters)"
Expand Down Expand Up @@ -2082,9 +2081,9 @@
"p23_data = p2_data + p3_data\n",
"p31_data = p3_data + p1_data\n",
"\n",
"s12_data = jnp.array(p12_data.m2)\n",
"s23_data = jnp.array(p23_data.m2)\n",
"s31_data = jnp.array(p31_data.m2)\n",
"s12_data = p12_data.m2\n",
"s23_data = p23_data.m2\n",
"s31_data = p31_data.m2\n",
"\n",
"theta1_CM_data = p1_data.theta\n",
"theta2_CM_data = p2_data.theta\n",
Expand All @@ -2093,12 +2092,12 @@
"phi2_CM_data = p2_data.phi\n",
"phi3_CM_data = p3_data.phi\n",
"\n",
"theta1_data = jnp.array(theta_helicity(p1_data, p12_data))\n",
"theta2_data = jnp.array(theta_helicity(p2_data, p23_data))\n",
"theta3_data = jnp.array(theta_helicity(p3_data, p31_data))\n",
"phi1_data = jnp.array(phi_helicity(p1_data, p12_data))\n",
"phi2_data = jnp.array(phi_helicity(p2_data, p23_data))\n",
"phi3_data = jnp.array(phi_helicity(p3_data, p31_data))"
"theta1_data = theta_helicity(p1_data, p12_data)\n",
"theta2_data = theta_helicity(p2_data, p23_data)\n",
"theta3_data = theta_helicity(p3_data, p31_data)\n",
"phi1_data = phi_helicity(p1_data, p12_data)\n",
"phi2_data = phi_helicity(p2_data, p23_data)\n",
"phi3_data = phi_helicity(p3_data, p31_data)"
]
},
{
Expand Down Expand Up @@ -3147,7 +3146,7 @@
" model_integral = BW_SH_model(*phsp, **parameters).mean()\n",
" data_intensities = BW_SH_model(*data, **parameters)\n",
" likelihoods = data_intensities / model_integral\n",
" log_likelihood = jnp.log(likelihoods).sum()\n",
" log_likelihood = np.log(likelihoods).sum()\n",
" return -log_likelihood"
]
},
Expand Down

0 comments on commit a6dff95

Please sign in to comment.