Skip to content

Commit

Permalink
dev: try a newton fixed-point method
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jan 31, 2025
1 parent d15e76d commit f15ce84
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 2 deletions.
174 changes: 174 additions & 0 deletions dev/notebooks/spergel_fixed_point.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d7b8bc37-8799-433c-9399-de95a21a1727",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import galsim\n",
"import numpy as np\n",
"\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "774101b1",
"metadata": {},
"outputs": [],
"source": [
"from jax_galsim.spergel import (\n",
" fz_nup1, _gammap1, _spergel_hlr_pade,\n",
" fluxfractionFunc, fz_nu, calculateFluxRadius,\n",
")\n",
"\n",
"@jax.jit\n",
"def _calculateFluxRadius_newtons_kernel(i, args):\n",
" \"\"\"Newton's method kernel for calculateFluxRadius\n",
"\n",
" Returns\n",
"\n",
" lnz - fluxfractionFunc(z, nu, alpha) / dfluxfractionFunc(z, nu, alpha)_dz / z\n",
"\n",
" which is Newton's kernel but in log space.\n",
" \"\"\"\n",
" lnz, alpha, nu = args\n",
" z = jnp.exp(lnz)\n",
" dn = (jnp.power(2.0, nu) * _gammap1(nu))\n",
" fz = 1.0 - fz_nup1(z, nu) / dn - alpha\n",
" dfzdz = z * fz_nu(z, nu) / dn\n",
"\n",
" # we clip the result to avoid numerical issues near bounds\n",
" lnz = jnp.clip(\n",
" lnz - fz / dfzdz / z,\n",
" min=-100,\n",
" max=100,\n",
" )\n",
"\n",
" return lnz, alpha, nu\n",
"\n",
"\n",
"@jax.jit\n",
"def calculateFluxRadiusNewton(alpha, nu):\n",
" \"\"\"Return radius R enclosing flux fraction alpha in unit of the scale radius r0\n",
"\n",
" Method: Solve F(R/r0=z)/Flux - alpha = 0 using Netwon's method\n",
"\n",
" We can integrate the profile to get\n",
"\n",
" F(R)/F = int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha\n",
"\n",
" So if we define z = R/r0 and f(z) = F(z * r0)/F - alpha, then Newton's method is\n",
"\n",
" z -> z - f(z) / f'(z)\n",
"\n",
" We actually run the method for ln(z) which is\n",
"\n",
" ln(z) -> ln(z) - f(z) / f'(z) / z\n",
"\n",
" Typical use cases include:\n",
"\n",
" - alpha = 1/2 => R = Half-Light-Radius,\n",
" - alpha = 1 - folding-thresold => R used for stepk computation\n",
" \"\"\"\n",
" # seed the iteration with the Pade approximation to the HLR\n",
" # scaled by the fraction of flux to some power\n",
" zalpha = _spergel_hlr_pade(nu) * jnp.sqrt(alpha / 0.5)\n",
" return jnp.exp(jax.lax.fori_loop(\n",
" 0, 100,\n",
" _calculateFluxRadius_newtons_kernel,\n",
" (jnp.log(zalpha), alpha, nu),\n",
" )[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1be23e1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"eps, nu, log10(alpha): 1e-12 -0.84 -12.0\n",
"3.5138887102897e-38 -2.2121720121483927e-17 1.0587911840678754e-21 1.8761616702453412e-07 1.000534100015216e-12\n",
"\n",
"eps, nu, log10(alpha): 1e-12 3.999 -12.0\n",
"3.9966817649384216e-06 -1.576433954596703e-15 3.999832106175669e-06 3.1094518726606304e-16 9.984622740022494e-13\n",
"\n",
"eps, nu, log10(alpha): 1e-12 -0.84 -4.3428487456249e-13\n",
"25.572845945758726 0.0 25.572509765625 -3.3306690738754696e-16 0.999999999999\n",
"\n",
"eps, nu, log10(alpha): 1e-12 3.999 -4.3428487456249e-13\n",
"38.6677767503012 0.0 38.6676025390625 -1.1102230246251565e-16 0.999999999999\n",
"\n",
"eps, nu, log10(alpha): 0.1 -0.84 -1.0\n",
"0.0008333666650951336 6.38378239159465e-16 0.0008333666650951221 3.0531133177191805e-16 0.10000000000000096\n",
"\n",
"eps, nu, log10(alpha): 0.1 3.999 -1.0\n",
"1.3092245672406861 6.38378239159465e-16 1.3092245672406833 8.326672684688674e-17 0.10000000000000037\n",
"\n",
"eps, nu, log10(alpha): 0.1 -0.84 -0.045757490560675115\n",
"1.2147258941802845 0.0 1.214725894180284 -1.1102230246251565e-16 0.9000000000000001\n",
"\n",
"eps, nu, log10(alpha): 0.1 3.999 -0.045757490560675115\n",
"6.899340112339111 -1.1102230246251565e-16 6.899340112339113 -1.1102230246251565e-16 0.8999999999999999\n"
]
}
],
"source": [
"for eps in [1e-12, 0.1]:\n",
" for alpha in [eps, 1.0 - eps]:\n",
" for nu in [-0.84, 3.999]:\n",
"\n",
" print(\"\\neps, nu, log10(alpha):\", eps, nu, np.log10(alpha))\n",
" zfp = calculateFluxRadiusNewton(alpha, nu)\n",
" zbs = calculateFluxRadius(alpha, nu)\n",
" print(\n",
" zfp,\n",
" fluxfractionFunc(zfp, nu, alpha),\n",
" zbs,\n",
" fluxfractionFunc(zbs, nu, alpha),\n",
" galsim.Spergel(nu, scale_radius=1.0).calculateIntegratedFlux(zfp),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29cd9aa2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "jax-galsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 2 additions & 2 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def reducedfluxfractionFunc(z, nu, norm):


@jax.jit
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=40.0):
"""Return radius R enclosing flux fraction alpha in unit of the scale radius r0
Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm
Expand All @@ -186,7 +186,7 @@ def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
nb. it is supposed that nu is in [-0.85, 4.0] checked in the Spergel class init
"""
return bisect_for_root(
partial(fluxfractionFunc, nu=nu, alpha=alpha), zmin, zmax, niter=75
partial(fluxfractionFunc, nu=nu, alpha=alpha), zmin, zmax, niter=75,
)


Expand Down

0 comments on commit f15ce84

Please sign in to comment.