Skip to content

Commit

Permalink
change _wraps to implements for compatibility with new jax version
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Jun 3, 2024
1 parent 4b12d6b commit 23c718a
Show file tree
Hide file tree
Showing 30 changed files with 269 additions and 269 deletions.
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,14 @@ And that's all you need to do from now on.

JAX-GalSim follows the NumPy/SciPy format: <https://numpydoc.readthedocs.io/en/latest/format.html>

However, most JAX-GalSim function will directly inherit the documentation from the reference GalSim project. We recommend avoid copy/pasting documentation, and instead using the `_wraps` utility to automatically reuse GalSim documentation:
However, most JAX-GalSim function will directly inherit the documentation from the reference GalSim project. We recommend avoid copy/pasting documentation, and instead using the `implements` utility to automatically reuse GalSim documentation:

```python
import galsim as _galsim
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

@_wraps(_galsim.Add,
@implements(_galsim.Add,
lax_description="Does not support `ChromaticObject` at this point.")
def Add(*args, **kwargs):
return Sum(*args, **kwargs)
Expand All @@ -160,4 +160,4 @@ Note that this tool has the option of providing a `lax_description` which will b
In order to be able to use JAX transformations, we need to be able to flatten and unflatten objects. This happens within the `tree_flatten` and `tree_unflatten` methods.
The unflattening can fail to work as expected when type checks are performed in the `__init__` method of a given object. To avoid this issue, the following strategy can used:

https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
20 changes: 10 additions & 10 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
# SOFTWARE.
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, ensure_hashable


@_wraps(_galsim.AngleUnit)
@implements(_galsim.AngleUnit)
@register_pytree_node_class
class AngleUnit(object):
valid_names = ["rad", "deg", "hr", "hour", "arcmin", "arcsec"]
Expand Down Expand Up @@ -61,7 +61,7 @@ def __div__(self, unit):
__truediv__ = __div__

@staticmethod
@_wraps(_galsim.AngleUnit.from_name)
@implements(_galsim.AngleUnit.from_name)
def from_name(unit):
unit = unit.strip().lower()
if unit.startswith("rad"):
Expand Down Expand Up @@ -127,7 +127,7 @@ def tree_unflatten(cls, aux_data, children):
arcsec = AngleUnit(jnp.pi / 648000.0)


@_wraps(_galsim.Angle)
@implements(_galsim.Angle)
@register_pytree_node_class
class Angle(object):
def __init__(self, theta, unit=None):
Expand Down Expand Up @@ -198,7 +198,7 @@ def __div__(self, other):

__truediv__ = __div__

@_wraps(_galsim.Angle.wrap)
@implements(_galsim.Angle.wrap)
def wrap(self, center=None):
if center is None:
center = _Angle(0.0)
Expand Down Expand Up @@ -329,15 +329,15 @@ def _make_dms_string(decimal, sep, prec, pad, plus_sign):
string = string + sep3
return string

@_wraps(_galsim.Angle.hms)
@implements(_galsim.Angle.hms)
def hms(self, sep=":", prec=None, pad=True, plus_sign=False):
if not len(sep) <= 3:
raise ValueError("sep must be a string or tuple of length <= 3")
if prec is not None and not prec >= 0:
raise ValueError("prec must be >= 0")
return self._make_dms_string(self / hours, sep, prec, pad, plus_sign)

@_wraps(_galsim.Angle.dms)
@implements(_galsim.Angle.dms)
def dms(self, sep=":", prec=None, pad=True, plus_sign=False):
if not len(sep) <= 3:
raise ValueError("sep must be a string or tuple of length <= 3")
Expand All @@ -346,12 +346,12 @@ def dms(self, sep=":", prec=None, pad=True, plus_sign=False):
return self._make_dms_string(self / degrees, sep, prec, pad, plus_sign)

@staticmethod
@_wraps(_galsim.Angle.from_hms)
@implements(_galsim.Angle.from_hms)
def from_hms(str):
return Angle._parse_dms(str) * hours

@staticmethod
@_wraps(_galsim.Angle.from_dms)
@implements(_galsim.Angle.from_dms)
def from_dms(str):
return Angle._parse_dms(str) * degrees

Expand Down Expand Up @@ -400,7 +400,7 @@ def tree_unflatten(cls, aux_data, children):
return ret


@_wraps(_galsim._Angle)
@implements(_galsim._Angle)
def _Angle(theta):
ret = Angle.__new__(Angle)
ret._rad = theta
Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/bessel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements


# the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp
Expand Down Expand Up @@ -91,7 +91,7 @@ def _si_small_pade(x, x2):
# fmt: on


@_wraps(_galsim.bessel.si)
@implements(_galsim.bessel.si)
@jax.jit
def si(x):
x2 = x * x
Expand Down
12 changes: 6 additions & 6 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable
Expand All @@ -16,7 +16,7 @@


# The reason for avoid these tests is that they are not easy to do for jitted code.
@_wraps(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class Bounds(_galsim.Bounds):
def _parse_args(self, *args, **kwargs):
Expand Down Expand Up @@ -104,7 +104,7 @@ def true_center(self):
)
return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0)

@_wraps(_galsim.Bounds.includes)
@implements(_galsim.Bounds.includes)
def includes(self, *args):
if len(args) == 1:
if isinstance(args[0], Bounds):
Expand Down Expand Up @@ -138,7 +138,7 @@ def includes(self, *args):
else:
raise TypeError("include takes at most 2 arguments (%d given)" % len(args))

@_wraps(_galsim.Bounds.expand)
@implements(_galsim.Bounds.expand)
def expand(self, factor_x, factor_y=None):
if factor_y is None:
factor_y = factor_x
Expand Down Expand Up @@ -266,7 +266,7 @@ def from_galsim(cls, galsim_bounds):
return _cls()


@_wraps(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsD(Bounds):
_pos_class = PositionD
Expand Down Expand Up @@ -300,7 +300,7 @@ def _center(self):
return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0)


@_wraps(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR)
@implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsI(Bounds):
_pos_class = PositionI
Expand Down
12 changes: 6 additions & 6 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import galsim as _galsim
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy.util import implements
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
Expand All @@ -9,7 +9,7 @@
from jax_galsim.random import UniformDeviate


@_wraps(_galsim.Box)
@implements(_galsim.Box)
@register_pytree_node_class
class Box(GSObject):
_has_hard_edges = True
Expand Down Expand Up @@ -100,7 +100,7 @@ def _drawKImage(self, image, jac=None):
_jac = jnp.eye(2) if jac is None else jac
return draw_by_kValue(self, image, _jac)

@_wraps(_galsim.Box.withFlux)
@implements(_galsim.Box.withFlux)
def withFlux(self, flux):
return Box(
width=self.width, height=self.height, flux=flux, gsparams=self.gsparams
Expand All @@ -116,7 +116,7 @@ def tree_unflatten(cls, aux_data, children):
**aux_data
)

@_wraps(_galsim.Box._shoot)
@implements(_galsim.Box._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

Expand All @@ -126,7 +126,7 @@ def _shoot(self, photons, rng):
photons.flux = self.flux / photons.size()


@_wraps(_galsim.Pixel)
@implements(_galsim.Pixel)
@register_pytree_node_class
class Pixel(Box):
def __init__(self, scale, flux=1.0, gsparams=None):
Expand All @@ -153,7 +153,7 @@ def __str__(self):
s += ")"
return s

@_wraps(_galsim.Pixel.withFlux)
@implements(_galsim.Pixel.withFlux)
def withFlux(self, flux):
return Pixel(scale=self.scale, flux=flux, gsparams=self.gsparams)

Expand Down
Loading

0 comments on commit 23c718a

Please sign in to comment.