Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spergel #86

Merged
merged 35 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
3f07298
first import spergel profile
jecampagne Jan 5, 2024
255d95d
version w/o shooting photons
jecampagne Jan 5, 2024
718f4ad
black...
jecampagne Jan 5, 2024
606f816
jitting/vmapping tests
jecampagne Jan 5, 2024
dccff0a
jitting/vmapping tests
jecampagne Jan 5, 2024
c31d1fd
import Spergel
jecampagne Jan 5, 2024
7b337bd
allow test_spergel.py from Galsim test suite
jecampagne Jan 5, 2024
320e6bc
add math expression to class comment
jecampagne Jan 6, 2024
0a8f918
1) take care of gamma(nu) with interger nu. 2) care of z=0 for fsmall…
jecampagne Jan 6, 2024
80c7205
fix bug in _xValue
jecampagne Jan 7, 2024
4a1808e
fix bug in _kValue
jecampagne Jan 7, 2024
7009fc7
fix typo in _shootxnorm
jecampagne Jan 7, 2024
b5b641a
add calculateFluxRadius & calculateIntegratedFlux functions
jecampagne Jan 7, 2024
d265d12
black...
jecampagne Jan 7, 2024
9cb1a8e
fix C&P typo
jecampagne Jan 8, 2024
3f75546
test_api fails
jecampagne Jan 8, 2024
1c62077
Update tests/jax/test_api.py
jecampagne Jan 9, 2024
b560bed
Update tests/jax/test_api.py
jecampagne Jan 9, 2024
2704fc5
run pre-commit
beckermr Jan 9, 2024
afa3bd8
fix fsmallz_nu missing c3 coeff
jecampagne Jan 9, 2024
8f39867
Merge branch 'spergel' of github.com:GalSim-developers/JAX-GalSim int…
jecampagne Jan 9, 2024
25f6ee3
fix to get workable function for low z and nu integer
jecampagne Jan 9, 2024
0f5160f
blacken
beckermr Jan 9, 2024
53526ec
adapt calculateFluxRadius to accept zmin/zmax tuning
jecampagne Jan 11, 2024
1a7a557
code _shoot_neg (shoot with negative nu) as it is in Galsim even if t…
jecampagne Jan 12, 2024
ee6effe
black
jecampagne Jan 12, 2024
89fb1d8
implement shoot for negative nu, such as 1) use a linear approx of th…
jecampagne Jan 14, 2024
15508ea
add Spergel in the changelog, 2) use _shootxnorm Galsim variable inst…
jecampagne Jan 15, 2024
ec69ac8
Update jax_galsim/spergel.py
jecampagne Jan 17, 2024
2ba762c
Modified Bessel 2nd Kind (kv)and Gamm(x) moved to bessel.py both for …
jecampagne Jan 18, 2024
fe808bd
Merge branch 'spergel' of github.com:GalSim-developers/JAX-GalSim int…
jecampagne Jan 18, 2024
167cc89
spurious import
jecampagne Jan 18, 2024
253232b
pre-commit
jecampagne Jan 18, 2024
34ffa4c
1) rm gamma from bessel ralated code, 2) adapt the input args of kv
jecampagne Jan 18, 2024
ebc0e5a
float needed
jecampagne Jan 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
import tensorflow_probability as tfp


# the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp
Expand Down Expand Up @@ -101,3 +102,15 @@ def si(x):
- _g_pade(x, x2) * jnp.sin(x),
_si_small_pade(x, x2),
)


@jax.jit
def kv(nu, x):
"""Modified Bessel 2nd kind"""
return tfp.substrates.jax.math.bessel_kve(nu * 1.0, x) / jnp.exp(jnp.abs(x))


@jax.jit
def gamma(x):
"""Gamma(x)"""
return jnp.exp(jax.lax.lgamma(x * 1.0))
3 changes: 2 additions & 1 deletion jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.bessel import j0
from jax_galsim.bessel import kv
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.core.utils import bisect_for_root, ensure_hashable
Expand All @@ -18,7 +19,7 @@
@jax.jit
def _Knu(nu, x):
"""Modified Bessel 2nd kind for Untruncated Moffat"""
return tfp.substrates.jax.math.bessel_kve(nu * 1.0, x) / jnp.exp(jnp.abs(x))
return kv(nu, x)


@jax.jit
Expand Down
32 changes: 13 additions & 19 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
from jax._src.numpy.util import _wraps
from jax.tree_util import Partial as partial
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import bisect_for_root, ensure_hashable
from jax_galsim.bessel import kv, gamma
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate
from jax_galsim.utilities import lazy_property


@jax.jit
def _Knu(nu, x):
"""Modified Bessel 2nd kind"""
return tfp.substrates.jax.math.bessel_kve(nu * 1.0, x) / jnp.exp(jnp.abs(x))


@jax.jit
def _gamma(nu):
"""Gamma(nu) with care for integer nu in [0,5]"""
return jnp.select(
[nu == 0, nu == 1, nu == 2, nu == 3, nu == 4, nu == 5],
[jnp.inf, 1.0, 1.0, 2.0, 6.0, 24.0],
default=jnp.exp(jax.lax.lgamma(nu * 1.0)),
default=gamma(nu),
)


Expand All @@ -43,7 +37,7 @@ def z2lz(z):

@jax.jit
def f0(z):
"""K_0[z] z -> 0 O(z^4)"""
"""K_0[z] with z -> 0 O(z^4)"""
z2 = z * z
z4 = z2 * z2
c0 = 0.11593151565841244881
Expand All @@ -54,7 +48,7 @@ def f0(z):

@jax.jit
def f1(z):
"""z^1 K_1[z] z -> 0 O(z^4)"""
"""z^1 K_1[z] with z -> 0 O(z^4)"""
z2 = z * z
z4 = z2 * z2
c0 = z2lz(z) # z^2 log(z)
Expand All @@ -65,7 +59,7 @@ def f1(z):

@jax.jit
def f2(z):
"""z^2 K_2[z] z -> 0 O(z^4)"""
"""z^2 K_2[z] with z -> 0 O(z^4)"""
c1 = 0.10824143945730155610
z2 = z * z
z4 = z2 * z2
Expand All @@ -75,23 +69,23 @@ def f2(z):

@jax.jit
def f3(z):
"""z^3 K_3[z] z -> 0 O(z^4)"""
"""z^3 K_3[z] with z -> 0 O(z^4)"""
z2 = z * z
z4 = z2 * z2
return 8.0 - z2 + 0.125 * z4


@jax.jit
def f4(z):
"""z^4 K_4[z] z -> 0 O(z^4)"""
"""z^4 K_4[z] with z -> 0 O(z^4)"""
z2 = z * z
z4 = z2 * z2
return 48.0 - 4 * z2 + 0.25 * z4


@jax.jit
def f5(z):
"""z^5 K_5[z] z -> 0 O(z^4)"""
"""z^5 K_5[z] with z -> 0 O(z^4)"""
z2 = z * z
z4 = z2 * z2
return 384.0 - 24.0 * z2 + z4
Expand All @@ -100,7 +94,7 @@ def f5(z):
@jax.jit
def fsmallz_nu(z, nu):
def fnu(z, nu):
"""z^nu K_nu[z] z -> 0 O(z^4) z > 0"""
"""z^nu K_nu[z] with z -> 0 O(z^4) z > 0"""
nu += 1.0e-10 # to garanty that nu is not an integer
z2 = z * z
z4 = z2 * z2
Expand All @@ -121,14 +115,14 @@ def fnu(z, nu):

@jax.jit
def fz_nu(z, nu):
"""z^nu K_nu[z], z > 0"""
return jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * _Knu(nu, z))
"""z^nu K_nu[z] with z > 0"""
return jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * kv(nu, z))


@jax.jit
def fsmallz_nup1(z, nu):
def fnu(z, nu):
"""z^(nu+1) K_(nu+1)[z] z -> 0"""
"""z^(nu+1) K_(nu+1)[z] with z -> 0"""
z2 = z * z
z4 = z2 * z2
c1 = -jnp.power(2.0, -4.0 - nu)
Expand All @@ -150,7 +144,7 @@ def fnu(z, nu):
def fz_nup1(z, nu):
"""z^(nu+1) K_{nu+1}(z)"""
return jnp.where(
z <= 1.0e-10, fsmallz_nup1(z, nu), jnp.power(z, nu + 1.0) * _Knu(nu + 1.0, z)
z <= 1.0e-10, fsmallz_nup1(z, nu), jnp.power(z, nu + 1.0) * kv(nu + 1.0, z)
)


Expand Down
Loading