Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jecampagne committed Jan 18, 2024
1 parent 167cc89 commit 253232b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion jax_galsim/bessel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
import tensorflow_probability as tfp
from jax._src.numpy.util import _wraps


# the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp
Expand Down
2 changes: 1 addition & 1 deletion jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from jax.tree_util import Partial as partial
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.bessel import j0
from jax_galsim.bessel import kv
from jax_galsim.core.bessel import j0
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.integrate import ClenshawCurtisQuad, quad_integral
from jax_galsim.core.utils import bisect_for_root, ensure_hashable
Expand Down
5 changes: 2 additions & 3 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from jax.tree_util import Partial as partial
from jax.tree_util import register_pytree_node_class

from jax_galsim.bessel import gamma, kv
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import bisect_for_root, ensure_hashable
from jax_galsim.bessel import kv, gamma
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate
from jax_galsim.utilities import lazy_property
Expand Down Expand Up @@ -186,8 +186,7 @@ def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):

@_wraps(
_galsim.Spergel,
lax_description="""
The fully normalized Spergel profile (used in both standard GalSim and JAX-GalSim) is
lax_description=r"""The fully normalized Spergel profile (used in both standard GalSim and JAX-GalSim) is
.. math::
I(r) = flux \times \left(2\pi 2^\nu \Gamma(1+\nu) r_0^2\right)^{-1}
Expand Down

0 comments on commit 253232b

Please sign in to comment.