diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 07502dce..0563ba78 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -37,7 +37,15 @@ def _hankel(k, beta, rmax): @jax.jit -def _MoffatCalculateSRFromHLR(re, rm, beta): +def _bodymi(xcur, rm, re, beta): + x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 + x = jnp.power(x, 1 / (1 - beta)) + x = jnp.sqrt(x - 1) + return re / x + + +@partial(jax.jit, static_argnames=("nitr",)) +def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100): """ The basic equation that is relevant here is the flux of a Moffat profile out to some radius. @@ -53,17 +61,10 @@ def _MoffatCalculateSRFromHLR(re, rm, beta): nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf. BUT the case rm==0 is already done, so HERE rm != 0 """ - - # fix loop iteration is faster and reach eps=1e-6 (single precision) - def body(i, xcur): - x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 - x = jnp.power(x, 1 / (1 - beta)) - x = jnp.sqrt(x - 1) - return re / x - - rd = jax.lax.fori_loop(0, 1000, body, re) - - return rd + xcur = re + for _ in range(nitr): + xcur = _bodymi(xcur, rm, re, beta) + return xcur @implements(_galsim.Moffat)