diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f2a49dfb..0913828d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -137,14 +137,14 @@ And that's all you need to do from now on. JAX-GalSim follows the NumPy/SciPy format: -However, most JAX-GalSim function will directly inherit the documentation from the reference GalSim project. We recommend avoid copy/pasting documentation, and instead using the `_wraps` utility to automatically reuse GalSim documentation: +However, most JAX-GalSim function will directly inherit the documentation from the reference GalSim project. We recommend avoid copy/pasting documentation, and instead using the `implements` utility to automatically reuse GalSim documentation: ```python import galsim as _galsim -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class -@_wraps(_galsim.Add, +@implements(_galsim.Add, lax_description="Does not support `ChromaticObject` at this point.") def Add(*args, **kwargs): return Sum(*args, **kwargs) @@ -160,4 +160,4 @@ Note that this tool has the option of providing a `lax_description` which will b In order to be able to use JAX transformations, we need to be able to flatten and unflatten objects. This happens within the `tree_flatten` and `tree_unflatten` methods. The unflattening can fail to work as expected when type checks are performed in the `__init__` method of a given object. To avoid this issue, the following strategy can used: -https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization \ No newline at end of file +https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization diff --git a/jax_galsim/angle.py b/jax_galsim/angle.py index cbb04c54..1bf4315e 100644 --- a/jax_galsim/angle.py +++ b/jax_galsim/angle.py @@ -19,13 +19,13 @@ # SOFTWARE. import galsim as _galsim import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import cast_to_float, ensure_hashable -@_wraps(_galsim.AngleUnit) +@implements(_galsim.AngleUnit) @register_pytree_node_class class AngleUnit(object): valid_names = ["rad", "deg", "hr", "hour", "arcmin", "arcsec"] @@ -61,7 +61,7 @@ def __div__(self, unit): __truediv__ = __div__ @staticmethod - @_wraps(_galsim.AngleUnit.from_name) + @implements(_galsim.AngleUnit.from_name) def from_name(unit): unit = unit.strip().lower() if unit.startswith("rad"): @@ -127,7 +127,7 @@ def tree_unflatten(cls, aux_data, children): arcsec = AngleUnit(jnp.pi / 648000.0) -@_wraps(_galsim.Angle) +@implements(_galsim.Angle) @register_pytree_node_class class Angle(object): def __init__(self, theta, unit=None): @@ -198,7 +198,7 @@ def __div__(self, other): __truediv__ = __div__ - @_wraps(_galsim.Angle.wrap) + @implements(_galsim.Angle.wrap) def wrap(self, center=None): if center is None: center = _Angle(0.0) @@ -329,7 +329,7 @@ def _make_dms_string(decimal, sep, prec, pad, plus_sign): string = string + sep3 return string - @_wraps(_galsim.Angle.hms) + @implements(_galsim.Angle.hms) def hms(self, sep=":", prec=None, pad=True, plus_sign=False): if not len(sep) <= 3: raise ValueError("sep must be a string or tuple of length <= 3") @@ -337,7 +337,7 @@ def hms(self, sep=":", prec=None, pad=True, plus_sign=False): raise ValueError("prec must be >= 0") return self._make_dms_string(self / hours, sep, prec, pad, plus_sign) - @_wraps(_galsim.Angle.dms) + @implements(_galsim.Angle.dms) def dms(self, sep=":", prec=None, pad=True, plus_sign=False): if not len(sep) <= 3: raise ValueError("sep must be a string or tuple of length <= 3") @@ -346,12 +346,12 @@ def dms(self, sep=":", prec=None, pad=True, plus_sign=False): return self._make_dms_string(self / degrees, sep, prec, pad, plus_sign) @staticmethod - @_wraps(_galsim.Angle.from_hms) + @implements(_galsim.Angle.from_hms) def from_hms(str): return Angle._parse_dms(str) * hours @staticmethod - @_wraps(_galsim.Angle.from_dms) + @implements(_galsim.Angle.from_dms) def from_dms(str): return Angle._parse_dms(str) * degrees @@ -400,7 +400,7 @@ def tree_unflatten(cls, aux_data, children): return ret -@_wraps(_galsim._Angle) +@implements(_galsim._Angle) def _Angle(theta): ret = Angle.__new__(Angle) ret._rad = theta diff --git a/jax_galsim/bessel.py b/jax_galsim/bessel.py index 1cfa9f0c..49ebec08 100644 --- a/jax_galsim/bessel.py +++ b/jax_galsim/bessel.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp import tensorflow_probability as tfp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements # the code here for Si, f, g and _si_small_pade is taken from galsim/src/math/Sinc.cpp @@ -91,7 +91,7 @@ def _si_small_pade(x, x2): # fmt: on -@_wraps(_galsim.bessel.si) +@implements(_galsim.bessel.si) @jax.jit def si(x): x2 = x * x diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 3f99f34d..40241e20 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable @@ -16,7 +16,7 @@ # The reason for avoid these tests is that they are not easy to do for jitted code. -@_wraps(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR) +@implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class Bounds(_galsim.Bounds): def _parse_args(self, *args, **kwargs): @@ -104,7 +104,7 @@ def true_center(self): ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) - @_wraps(_galsim.Bounds.includes) + @implements(_galsim.Bounds.includes) def includes(self, *args): if len(args) == 1: if isinstance(args[0], Bounds): @@ -138,7 +138,7 @@ def includes(self, *args): else: raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) - @_wraps(_galsim.Bounds.expand) + @implements(_galsim.Bounds.expand) def expand(self, factor_x, factor_y=None): if factor_y is None: factor_y = factor_x @@ -266,7 +266,7 @@ def from_galsim(cls, galsim_bounds): return _cls() -@_wraps(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) +@implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsD(Bounds): _pos_class = PositionD @@ -300,7 +300,7 @@ def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) -@_wraps(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) +@implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsI(Bounds): _pos_class = PositionI diff --git a/jax_galsim/box.py b/jax_galsim/box.py index 5ac8ca00..3c9b9111 100644 --- a/jax_galsim/box.py +++ b/jax_galsim/box.py @@ -1,6 +1,6 @@ import galsim as _galsim import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue @@ -9,7 +9,7 @@ from jax_galsim.random import UniformDeviate -@_wraps(_galsim.Box) +@implements(_galsim.Box) @register_pytree_node_class class Box(GSObject): _has_hard_edges = True @@ -100,7 +100,7 @@ def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac) - @_wraps(_galsim.Box.withFlux) + @implements(_galsim.Box.withFlux) def withFlux(self, flux): return Box( width=self.width, height=self.height, flux=flux, gsparams=self.gsparams @@ -116,7 +116,7 @@ def tree_unflatten(cls, aux_data, children): **aux_data ) - @_wraps(_galsim.Box._shoot) + @implements(_galsim.Box._shoot) def _shoot(self, photons, rng): ud = UniformDeviate(rng) @@ -126,7 +126,7 @@ def _shoot(self, photons, rng): photons.flux = self.flux / photons.size() -@_wraps(_galsim.Pixel) +@implements(_galsim.Pixel) @register_pytree_node_class class Pixel(Box): def __init__(self, scale, flux=1.0, gsparams=None): @@ -153,7 +153,7 @@ def __str__(self): s += ")" return s - @_wraps(_galsim.Pixel.withFlux) + @implements(_galsim.Pixel.withFlux) def withFlux(self, flux): return Pixel(scale=self.scale, flux=flux, gsparams=self.gsparams) diff --git a/jax_galsim/celestial.py b/jax_galsim/celestial.py index c00cbd33..c261e65b 100644 --- a/jax_galsim/celestial.py +++ b/jax_galsim/celestial.py @@ -24,7 +24,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians @@ -33,7 +33,7 @@ # we have to copy this one since JAX sends in `t` as a traced array # and the coord.Angle classes don't know how to handle that -@_wraps(_coord.util.ecliptic_obliquity) +@implements(_coord.util.ecliptic_obliquity) def _ecliptic_obliquity(epoch): # We need to figure out the time in Julian centuries from J2000 for this epoch. t = (epoch - 2000.0) / 100.0 @@ -53,7 +53,7 @@ def _sun_position_ecliptic(date): return _Angle(_coord.util.sun_position_ecliptic(date).rad) -@_wraps( +@implements( _galsim.celestial.CelestialCoord, lax_description=( "The JAX version of this object does not check that the declination is between -90 and 90." @@ -116,13 +116,13 @@ def _set_aux(self): self._z, ) = aux - @_wraps(_galsim.celestial.CelestialCoord.get_xyz) + @implements(_galsim.celestial.CelestialCoord.get_xyz) def get_xyz(self): return self._get_aux()[4:] @staticmethod @jax.jit - @_wraps( + @implements( _galsim.celestial.CelestialCoord.from_xyz, lax_description=( "The JAX version of this static method does not check that the norm of the input " @@ -156,7 +156,7 @@ def from_xyz(x, y, z): @staticmethod @jax.jit - @_wraps(_galsim.celestial.CelestialCoord.radec_to_xyz) + @implements(_galsim.celestial.CelestialCoord.radec_to_xyz) def radec_to_xyz(ra, dec, r=1.0): cosdec = jnp.cos(dec) x = cosdec * jnp.cos(ra) * r @@ -166,7 +166,7 @@ def radec_to_xyz(ra, dec, r=1.0): @staticmethod @partial(jax.jit, static_argnames=("return_r",)) - @_wraps(_galsim.celestial.CelestialCoord.xyz_to_radec) + @implements(_galsim.celestial.CelestialCoord.xyz_to_radec) def xyz_to_radec(x, y, z, return_r=False): xy2 = x**2 + y**2 ra = jnp.arctan2(y, x) @@ -182,7 +182,7 @@ def xyz_to_radec(x, y, z, return_r=False): else: return ra, dec - @_wraps(_galsim.celestial.CelestialCoord.normal) + @implements(_galsim.celestial.CelestialCoord.normal) def normal(self): return _CelestialCoord(self.ra.wrap(_Angle(jnp.pi)), self.dec) @@ -206,7 +206,7 @@ def _raw_cross(auxc1, auxc2): c1_x * c2_y - c2_x * c1_y, ) - @_wraps(_galsim.celestial.CelestialCoord.distanceTo) + @implements(_galsim.celestial.CelestialCoord.distanceTo) @jax.jit def distanceTo(self, coord2): # The easiest way to do this in a way that is stable for small separations @@ -240,7 +240,7 @@ def distanceTo(self, coord2): return _Angle(theta) - @_wraps( + @implements( _galsim.celestial.CelestialCoord.greatCirclePoint, lax_description=( "The JAX version of this method does not check that coord2 defines a unique great " @@ -349,7 +349,7 @@ def _alt_triple(self, aux, auxc2, auxc3): dsq_AB = self._raw_dsq(auxc2, auxc3) return 0.5 * (dsq_AC + dsq_BC - dsq_AB - 0.5 * dsq_AC * dsq_BC) - @_wraps(_galsim.celestial.CelestialCoord.angleBetween) + @implements(_galsim.celestial.CelestialCoord.angleBetween) @jax.jit def angleBetween(self, coord2, coord3): # Call A = coord2, B = coord3, C = self @@ -374,7 +374,7 @@ def angleBetween(self, coord2, coord3): C = jnp.arctan2(sinC, cosC) return _Angle(C) - @_wraps(_galsim.celestial.CelestialCoord.area) + @implements(_galsim.celestial.CelestialCoord.area) @jax.jit def area(self, coord2, coord3): # The area of a spherical triangle is defined by the "spherical excess", E. @@ -419,7 +419,7 @@ def area(self, coord2, coord3): _valid_projections = [None, "gnomonic", "stereographic", "lambert", "postel"] - @_wraps(_galsim.celestial.CelestialCoord.project) + @implements(_galsim.celestial.CelestialCoord.project) def project(self, coord2, projection=None): if projection not in CelestialCoord._valid_projections: raise ValueError("Unknown projection: %s" % projection) @@ -429,7 +429,7 @@ def project(self, coord2, projection=None): return u * radians, v * radians - @_wraps(_galsim.celestial.CelestialCoord.project_rad) + @implements(_galsim.celestial.CelestialCoord.project_rad) def project_rad(self, ra, dec, projection=None): if projection not in CelestialCoord._valid_projections: raise ValueError("Unknown projection: %s" % projection) @@ -501,7 +501,7 @@ def _project(self, auxc, projection): return u, v - @_wraps(_galsim.celestial.CelestialCoord.deproject) + @implements(_galsim.celestial.CelestialCoord.deproject) def deproject(self, u, v, projection=None): if projection not in CelestialCoord._valid_projections: raise ValueError("Unknown projection: %s" % projection) @@ -511,7 +511,7 @@ def deproject(self, u, v, projection=None): return CelestialCoord(_Angle(ra), _Angle(dec)) - @_wraps(_galsim.celestial.CelestialCoord.deproject_rad) + @implements(_galsim.celestial.CelestialCoord.deproject_rad) def deproject_rad(self, u, v, projection=None): if projection not in CelestialCoord._valid_projections: raise ValueError("Unknown projection: %s" % projection) @@ -587,14 +587,14 @@ def _deproject(self, u, v, projection): return ra, dec - @_wraps(_galsim.celestial.CelestialCoord.jac_deproject) + @implements(_galsim.celestial.CelestialCoord.jac_deproject) def jac_deproject(self, u, v, projection=None): if projection not in CelestialCoord._valid_projections: raise ValueError("Unknown projection: %s" % projection) return self._jac_deproject(u.rad, v.rad, projection) - @_wraps(_galsim.celestial.CelestialCoord.jac_deproject_rad) + @implements(_galsim.celestial.CelestialCoord.jac_deproject_rad) def jac_deproject_rad(self, u, v, projection=None): if projection not in CelestialCoord._valid_projections: raise ValueError("Unknown projection: %s" % projection) @@ -709,13 +709,13 @@ def _jac_deproject(self, u, v, projection): drdv *= cosdec return jnp.array([[drdu, drdv], [dddu, dddv]]) - @_wraps(_galsim.celestial.CelestialCoord.precess) + @implements(_galsim.celestial.CelestialCoord.precess) def precess(self, from_epoch, to_epoch): return CelestialCoord._precess( from_epoch, to_epoch, self._ra.rad, self._dec.rad ) - @_wraps(_galsim.celestial.CelestialCoord.galactic) + @implements(_galsim.celestial.CelestialCoord.galactic) def galactic(self, epoch=2000.0): # cf. Lang, Astrophysical Formulae, page 13 # cos(b) cos(el-33) = cos(dec) cos(ra-282.25) @@ -742,7 +742,7 @@ def galactic(self, epoch=2000.0): return (el, b) @staticmethod - @_wraps(_galsim.celestial.CelestialCoord.from_galactic) + @implements(_galsim.celestial.CelestialCoord.from_galactic) def from_galactic(el, b, epoch=2000.0): el0 = 32.93191857 * degrees r0 = 282.859481208 * degrees @@ -763,7 +763,7 @@ def from_galactic(el, b, epoch=2000.0): return CelestialCoord(temp.ra + r0, temp.dec).normal() @partial(jax.jit, static_argnames=("date",)) - @_wraps(_galsim.celestial.CelestialCoord.ecliptic) + @implements(_galsim.celestial.CelestialCoord.ecliptic) def ecliptic(self, epoch=2000.0, date=None): # We are going to work in terms of the (x, y, z) projections. _x, _y, _z = self._get_aux()[4:] @@ -794,7 +794,7 @@ def ecliptic(self, epoch=2000.0, date=None): @staticmethod @partial(jax.jit, static_argnames=("date",)) - @_wraps(_galsim.celestial.CelestialCoord.from_ecliptic) + @implements(_galsim.celestial.CelestialCoord.from_ecliptic) def from_ecliptic(lam, beta, epoch=2000.0, date=None): if date is not None: lam += _sun_position_ecliptic(date) diff --git a/jax_galsim/convolve.py b/jax_galsim/convolve.py index 855807a3..15fe777d 100644 --- a/jax_galsim/convolve.py +++ b/jax_galsim/convolve.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax.numpy as jnp from galsim.errors import galsim_warn -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.gsobject import GSObject @@ -9,7 +9,7 @@ from jax_galsim.photon_array import PhotonArray -@_wraps( +@implements( _galsim.Convolve, lax_description="""Does not support ChromaticConvolutions""", ) @@ -31,7 +31,7 @@ def Convolve(*args, **kwargs): return Convolution(*args, **kwargs) -@_wraps( +@implements( _galsim.Convolution, lax_description="""Only supports 'fft' convolution.""", ) @@ -345,7 +345,7 @@ def tree_unflatten(cls, aux_data, children): return cls(children[0]["obj_list"], **aux_data) -@_wraps( +@implements( _galsim.convolve.Deconvolve, lax_description="Does not support ChromaticDeconvolution", ) @@ -364,7 +364,7 @@ def Deconvolve(obj, gsparams=None, propagate_gsparams=True): ) -@_wraps(_galsim.convolve.Deconvolution) +@implements(_galsim.convolve.Deconvolution) @register_pytree_node_class class Deconvolution(GSObject): _has_hard_edges = False diff --git a/jax_galsim/deltafunction.py b/jax_galsim/deltafunction.py index 2401e47c..fa4a396a 100644 --- a/jax_galsim/deltafunction.py +++ b/jax_galsim/deltafunction.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue @@ -9,7 +9,7 @@ from jax_galsim.gsobject import GSObject -@_wraps(_galsim.DeltaFunction) +@implements(_galsim.DeltaFunction) @register_pytree_node_class class DeltaFunction(GSObject): _opt_params = {"flux": float} @@ -80,6 +80,6 @@ def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac) - @_wraps(_galsim.DeltaFunction.withFlux) + @implements(_galsim.DeltaFunction.withFlux) def withFlux(self, flux): return DeltaFunction(flux=flux, gsparams=self.gsparams) diff --git a/jax_galsim/deprecated.py b/jax_galsim/deprecated.py index 62345a3e..0155e083 100644 --- a/jax_galsim/deprecated.py +++ b/jax_galsim/deprecated.py @@ -1,12 +1,12 @@ import warnings import galsim as _galsim -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax_galsim.errors import GalSimDeprecationWarning -@_wraps( +@implements( _galsim.deprecated.depr, lax_description="""\ The JAX version of this function uses `stacklevel=3` to show where the diff --git a/jax_galsim/exponential.py b/jax_galsim/exponential.py index 584960c1..d41a90ce 100644 --- a/jax_galsim/exponential.py +++ b/jax_galsim/exponential.py @@ -1,6 +1,6 @@ import galsim as _galsim import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue @@ -10,7 +10,7 @@ from jax_galsim.utilities import lazy_property -@_wraps(_galsim.Exponential) +@implements(_galsim.Exponential) @register_pytree_node_class class Exponential(GSObject): # The half-light-radius is not analytic, but can be calculated numerically @@ -142,7 +142,7 @@ def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac) - @_wraps(_galsim.Exponential.withFlux) + @implements(_galsim.Exponential.withFlux) def withFlux(self, flux): return Exponential( scale_radius=self.scale_radius, flux=flux, gsparams=self.gsparams @@ -184,7 +184,7 @@ def _shoot_cdf(self): _cdf /= _cdf[-1] return _u_cdf, _cdf - @_wraps(_galsim.Exponential._shoot) + @implements(_galsim.Exponential._shoot) def _shoot(self, photons, rng): ud = UniformDeviate(rng) diff --git a/jax_galsim/fits.py b/jax_galsim/fits.py index 8c680041..f4c61ac6 100644 --- a/jax_galsim/fits.py +++ b/jax_galsim/fits.py @@ -6,7 +6,7 @@ import numpy as np from galsim.fits import FitsHeader, closeHDUList, readFile, writeFile # noqa: F401 from galsim.utilities import galsim_warn -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax_galsim.image import Image @@ -27,14 +27,14 @@ def _maybe_convert_and_warn(image): return image -@_wraps(_galsim.fits.read) +@implements(_galsim.fits.read) def read(*args, **kwargs): gsimage = _galsim.fits.read(*args, **kwargs) # galsim tests the dtypes against its Image class, so we need to test again here return _maybe_convert_and_warn(Image.from_galsim(gsimage)) -@_wraps(_galsim.fits.readMulti) +@implements(_galsim.fits.readMulti) def readMulti(*args, **kwargs): gsimage_list = _galsim.fits.readMulti(*args, **kwargs) return [ @@ -42,7 +42,7 @@ def readMulti(*args, **kwargs): ] -@_wraps(_galsim.fits.readCube) +@implements(_galsim.fits.readCube) def readCube(*args, **kwargs): gsimage_list = _galsim.fits.readCube(*args, **kwargs) return [ @@ -75,7 +75,7 @@ def _image_as_numpy(image): pass -@_wraps(_galsim.fits.write) +@implements(_galsim.fits.write) def write(*args, **kwargs): if len(args) >= 1 and isinstance(args[0], Image): with _image_as_numpy(args[0]) as image: @@ -84,7 +84,7 @@ def write(*args, **kwargs): _galsim.fits.write(*args, **kwargs) -@_wraps(_galsim.fits.writeMulti) +@implements(_galsim.fits.writeMulti) def writeMulti(*args, **kwargs): if len(args) >= 1: with ExitStack() as stack: @@ -99,7 +99,7 @@ def writeMulti(*args, **kwargs): _galsim.fits.writeMulti(*args, **kwargs) -@_wraps(_galsim.fits.writeCube) +@implements(_galsim.fits.writeCube) def writeCube(*args, **kwargs): if len(args) >= 1: with ExitStack() as stack: diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index d4a6414a..5113789c 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -5,7 +5,7 @@ import jax import jax.numpy as jnp import numpy as np -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim import fits @@ -47,7 +47,7 @@ ######################################################################################### -@_wraps( +@implements( _galsim.fitswcs.GSFitsWCS, lax_description=( "The JAX-GalSim version of this class does not raise errors if inverting the WCS to " @@ -866,7 +866,7 @@ def __hash__(self): return hash(repr(self)) -@_wraps(_galsim.fitswcs.TanWCS) +@implements(_galsim.fitswcs.TanWCS) def TanWCS(affine, world_origin, units=arcsec): # These will raise the appropriate errors if affine is not the right type. dudx = affine.dudx * units / degrees @@ -911,7 +911,7 @@ def TanWCS(affine, world_origin, units=arcsec): ] -@_wraps( +@implements( _galsim.fitswcs.FitsWCS, lax_description="JAX-GalSim only supports the GSFitsWCS class for celestial WCS types.", ) diff --git a/jax_galsim/gaussian.py b/jax_galsim/gaussian.py index eccd0eeb..6d5f1a91 100644 --- a/jax_galsim/gaussian.py +++ b/jax_galsim/gaussian.py @@ -1,6 +1,6 @@ import galsim as _galsim import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue @@ -9,7 +9,7 @@ from jax_galsim.random import GaussianDeviate -@_wraps(_galsim.Gaussian) +@implements(_galsim.Gaussian) @register_pytree_node_class class Gaussian(GSObject): # The FWHM of a Gaussian is 2 sqrt(2 ln2) sigma @@ -144,11 +144,11 @@ def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac) - @_wraps(_galsim.Gaussian.withFlux) + @implements(_galsim.Gaussian.withFlux) def withFlux(self, flux): return Gaussian(sigma=self.sigma, flux=flux, gsparams=self.gsparams) - @_wraps(_galsim.Gaussian._shoot) + @implements(_galsim.Gaussian._shoot) def _shoot(self, photons, rng): gd = GaussianDeviate(rng, sigma=self.sigma) diff --git a/jax_galsim/gsobject.py b/jax_galsim/gsobject.py index 85c1bbc2..797f1a46 100644 --- a/jax_galsim/gsobject.py +++ b/jax_galsim/gsobject.py @@ -5,7 +5,7 @@ import jax import jax.numpy as jnp import numpy as np -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements import jax_galsim.photon_array as pa from jax_galsim.core.draw import calculate_n_photons @@ -25,7 +25,7 @@ from jax_galsim.utilities import parse_pos_args -@_wraps(_galsim.GSObject) +@implements(_galsim.GSObject) class GSObject: def __init__(self, *, gsparams=None, **params): self._params = params # Dictionary containing all traced parameters @@ -115,12 +115,12 @@ def _centroid(self): return PositionD(0, 0) @property - @_wraps(_galsim.GSObject.positive_flux) + @implements(_galsim.GSObject.positive_flux) def positive_flux(self): return self._positive_flux @property - @_wraps(_galsim.GSObject.negative_flux) + @implements(_galsim.GSObject.negative_flux) def negative_flux(self): return self._negative_flux @@ -146,7 +146,7 @@ def _calculate_flux_per_photon(self): return 1.0 - 2.0 * eta @property - @_wraps(_galsim.GSObject.max_sb) + @implements(_galsim.GSObject.max_sb) def max_sb(self): return self._max_sb @@ -209,7 +209,7 @@ def __eq__(self, other): and is_equal_with_arrays(self.tree_flatten(), other.tree_flatten()) ) - @_wraps(_galsim.GSObject.xValue) + @implements(_galsim.GSObject.xValue) def xValue(self, *args, **kwargs): pos = parse_pos_args(args, kwargs, "x", "y") return self._xValue(pos) @@ -227,7 +227,7 @@ def _xValue(self, pos): "%s does not implement xValue" % self.__class__.__name__ ) - @_wraps(_galsim.GSObject.kValue) + @implements(_galsim.GSObject.kValue) def kValue(self, *args, **kwargs): kpos = parse_pos_args(args, kwargs, "kx", "ky") return self._kValue(kpos) @@ -238,7 +238,7 @@ def _kValue(self, kpos): "%s does not implement kValue" % self.__class__.__name__ ) - @_wraps(_galsim.GSObject.withGSParams) + @implements(_galsim.GSObject.withGSParams) def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: return self @@ -249,34 +249,34 @@ def withGSParams(self, gsparams=None, **kwargs): aux_data["gsparams"] = gsparams return self.tree_unflatten(aux_data, children) - @_wraps(_galsim.GSObject.withFlux) + @implements(_galsim.GSObject.withFlux) def withFlux(self, flux): return self.withScaledFlux(flux / self.flux) - @_wraps(_galsim.GSObject.withScaledFlux) + @implements(_galsim.GSObject.withScaledFlux) def withScaledFlux(self, flux_ratio): from jax_galsim.transform import Transform return Transform(self, flux_ratio=flux_ratio) - @_wraps(_galsim.GSObject.expand) + @implements(_galsim.GSObject.expand) def expand(self, scale): from jax_galsim.transform import Transform return Transform(self, jac=[scale, 0.0, 0.0, scale]) - @_wraps(_galsim.GSObject.dilate) + @implements(_galsim.GSObject.dilate) def dilate(self, scale): from jax_galsim.transform import Transform # equivalent to self.expand(scale) * (1./scale**2) return Transform(self, jac=[scale, 0.0, 0.0, scale], flux_ratio=scale**-2) - @_wraps(_galsim.GSObject.magnify) + @implements(_galsim.GSObject.magnify) def magnify(self, mu): return self.expand(jnp.sqrt(mu)) - @_wraps(_galsim.GSObject.shear) + @implements(_galsim.GSObject.shear) def shear(self, *args, **kwargs): from jax_galsim.shear import Shear from jax_galsim.transform import Transform @@ -385,13 +385,13 @@ def rotate(self, theta): s, c = theta.sincos() return Transform(self, jac=[c, -s, s, c]) - @_wraps(_galsim.GSObject.transform) + @implements(_galsim.GSObject.transform) def transform(self, dudx, dudy, dvdx, dvdy): from jax_galsim.transform import Transform return Transform(self, jac=[dudx, dudy, dvdx, dvdy]) - @_wraps(_galsim.GSObject.shift) + @implements(_galsim.GSObject.shift) def shift(self, *args, **kwargs): from jax_galsim.transform import Transform @@ -640,7 +640,7 @@ def _determine_wcs(self, scale, wcs, image, default_wcs=None): return wcs - @_wraps( + @implements( _galsim.GSObject.drawImage, lax_description="""\ The JAX-GalSim version of `drawImage` @@ -871,7 +871,7 @@ def drawImage( return image - @_wraps(_galsim.GSObject.drawReal) + @implements(_galsim.GSObject.drawReal) def drawReal(self, image, add_to_image=False): if image.wcs is None or not image.wcs.isPixelScale(): raise _galsim.GalSimValueError( @@ -897,7 +897,7 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): "%s does not implement drawReal" % self.__class__.__name__ ) - @_wraps(_galsim.GSObject.getGoodImageSize) + @implements(_galsim.GSObject.getGoodImageSize) def getGoodImageSize(self, pixel_scale): # Start with a good size from stepk and the pixel scale Nd = 2.0 * jnp.pi / (pixel_scale * self.stepk) @@ -910,7 +910,7 @@ def getGoodImageSize(self, pixel_scale): N = 2 * ((N + 1) // 2) return N - @_wraps(_galsim.GSObject.drawFFT_makeKImage) + @implements(_galsim.GSObject.drawFFT_makeKImage) def drawFFT_makeKImage(self, image): from jax_galsim.bounds import BoundsI from jax_galsim.image import ImageCD, ImageCF @@ -1045,7 +1045,7 @@ def drawFFT(self, image, add_to_image=False): kimage = self._drawKImage(kimage) return self.drawFFT_finish(image, kimage, wrap_size, add_to_image) - @_wraps(_galsim.GSObject.drawKImage) + @implements(_galsim.GSObject.drawKImage) def drawKImage( self, image=None, @@ -1151,7 +1151,7 @@ def drawKImage( return image - @_wraps(_galsim.GSObject._drawKImage) + @implements(_galsim.GSObject._drawKImage) def _drawKImage( self, image, jac=None ): # pragma: no cover (all our classes override this) @@ -1159,7 +1159,7 @@ def _drawKImage( "%s does not implement drawKImage" % self.__class__.__name__ ) - @_wraps(_galsim.GSObject._calculate_nphotons) + @implements(_galsim.GSObject._calculate_nphotons) def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): n_photons, g, _rng = calculate_n_photons( self.flux, @@ -1174,7 +1174,7 @@ def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng): rng._state = _rng._state return n_photons, g - @_wraps( + @implements( _galsim.GSObject.makePhot, lax_description="""\ The JAX-GalSim version of `makePhot` @@ -1242,7 +1242,7 @@ def makePhot( return photons - @_wraps( + @implements( _galsim.GSObject.drawPhot, lax_description="""\ The JAX-GalSim version of `drawPhot` @@ -1381,7 +1381,7 @@ def drawPhot( return _dfret.added_flux, _dfret.photons - @_wraps(_galsim.GSObject.shoot) + @implements(_galsim.GSObject.shoot) def shoot(self, n_photons, rng=None): photons = pa.PhotonArray(n_photons) @@ -1393,13 +1393,13 @@ def shoot(self, n_photons, rng=None): return photons - @_wraps(_galsim.GSObject._shoot) + @implements(_galsim.GSObject._shoot) def _shoot(self, photons, rng): raise NotImplementedError( "%s does not implement shoot" % self.__class__.__name__ ) - @_wraps(_galsim.GSObject.applyTo) + @implements(_galsim.GSObject.applyTo) def applyTo(self, photon_array, local_wcs=None, rng=None): # galsim does not deal with dxdz and dydz here - IDK why p1 = pa.PhotonArray(len(photon_array)) diff --git a/jax_galsim/gsparams.py b/jax_galsim/gsparams.py index cd514ba8..851a62f5 100644 --- a/jax_galsim/gsparams.py +++ b/jax_galsim/gsparams.py @@ -1,10 +1,10 @@ from dataclasses import dataclass import galsim as _galsim -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements -@_wraps(_galsim.GSParams) +@implements(_galsim.GSParams) @dataclass(frozen=True, repr=False) class GSParams: minimum_fft_size: int = 128 diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 541cddc9..51534803 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax.numpy as jnp import numpy as np -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.bounds import Bounds, BoundsD, BoundsI @@ -23,7 +23,7 @@ """ -@_wraps( +@implements( _galsim.Image, lax_description=IMAGE_LAX_DOCS, ) @@ -630,7 +630,7 @@ def __setitem__(self, *args): else: raise TypeError("image[..] requires either 1 or 2 args") - @_wraps(_galsim.Image.wrap) + @implements(_galsim.Image.wrap) def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") @@ -715,7 +715,7 @@ def _wrap(self, bounds, hermx, hermy): return self.subImage(bounds) - @_wraps( + @implements( _galsim.Image.calculate_fft, lax_description="JAX-GalSim does not support forward FFTs of complex dtypes.", ) @@ -770,7 +770,7 @@ def calculate_fft(self): out.setOrigin(0, -No2) return out - @_wraps(_galsim.Image.calculate_inverse_fft) + @implements(_galsim.Image.calculate_inverse_fft) def calculate_inverse_fft(self): if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( @@ -863,7 +863,7 @@ def copyFrom(self, rhs): ) self._array = rhs._array - @_wraps( + @implements( _galsim.Image.view, lax_description="Contrary to GalSim, this will create a copy of the orginal image.", ) @@ -920,7 +920,7 @@ def view( return ret - @_wraps(_galsim.Image.shift) + @implements(_galsim.Image.shift) def shift(self, *args, **kwargs): delta = parse_pos_args(args, kwargs, "dx", "dy", integer=True) self._shift(delta) @@ -936,28 +936,28 @@ def _shift(self, delta): if self.wcs is not None: self.wcs = self.wcs.shiftOrigin(delta) - @_wraps(_galsim.Image.setCenter) + @implements(_galsim.Image.setCenter) def setCenter(self, *args, **kwargs): cen = parse_pos_args(args, kwargs, "xcen", "ycen", integer=True) self._shift(cen - self.center) - @_wraps(_galsim.Image.setOrigin) + @implements(_galsim.Image.setOrigin) def setOrigin(self, *args, **kwargs): origin = parse_pos_args(args, kwargs, "x0", "y0", integer=True) self._shift(origin - self.origin) @property - @_wraps(_galsim.Image.center) + @implements(_galsim.Image.center) def center(self): return self.bounds.center @property - @_wraps(_galsim.Image.true_center) + @implements(_galsim.Image.true_center) def true_center(self): return self.bounds.true_center @property - @_wraps(_galsim.Image.origin) + @implements(_galsim.Image.origin) def origin(self): return self.bounds.origin @@ -970,7 +970,7 @@ def __call__(self, *args, **kwargs): pos = parse_pos_args(args, kwargs, "x", "y", integer=True) return self.getValue(pos.x, pos.y) - @_wraps(_galsim.Image.getValue) + @implements(_galsim.Image.getValue) def getValue(self, x, y): if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( @@ -990,7 +990,7 @@ def _getValue(self, x, y): """ return self.array[y - self.ymin, x - self.xmin] - @_wraps(_galsim.Image.setValue) + @implements(_galsim.Image.setValue) def setValue(self, *args, **kwargs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) @@ -1018,7 +1018,7 @@ def _setValue(self, x, y, value): """ self._array = self._array.at[y - self.ymin, x - self.xmin].set(value) - @_wraps(_galsim.Image.addValue) + @implements(_galsim.Image.addValue) def addValue(self, *args, **kwargs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) @@ -1125,30 +1125,30 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - @_wraps(_galsim.Image.transpose) + @implements(_galsim.Image.transpose) def transpose(self): bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) return _Image(self.array.T, bT, None) - @_wraps(_galsim.Image.flip_lr) + @implements(_galsim.Image.flip_lr) def flip_lr(self): return _Image(self.array.at[:, ::-1].get(), self._bounds, None) - @_wraps(_galsim.Image.flip_ud) + @implements(_galsim.Image.flip_ud) def flip_ud(self): return _Image(self.array.at[::-1, :].get(), self._bounds, None) - @_wraps(_galsim.Image.rot_cw) + @implements(_galsim.Image.rot_cw) def rot_cw(self): bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) return _Image(self.array.T.at[::-1, :].get(), bT, None) - @_wraps(_galsim.Image.rot_ccw) + @implements(_galsim.Image.rot_ccw) def rot_ccw(self): bT = BoundsI(self.ymin, self.ymax, self.xmin, self.xmax) return _Image(self.array.T.at[:, ::-1].get(), bT, None) - @_wraps(_galsim.Image.rot_180) + @implements(_galsim.Image.rot_180) def rot_180(self): return _Image(self.array.at[::-1, ::-1].get(), self._bounds, None) @@ -1198,7 +1198,7 @@ def from_galsim(cls, galsim_image): return im -@_wraps( +@implements( _galsim._Image, lax_description=IMAGE_LAX_DOCS, ) diff --git a/jax_galsim/interpolant.py b/jax_galsim/interpolant.py index 50d472c1..8c19bb19 100644 --- a/jax_galsim/interpolant.py +++ b/jax_galsim/interpolant.py @@ -10,7 +10,7 @@ import jax import jax.numpy as jnp from galsim.errors import GalSimValueError -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.bessel import si @@ -21,7 +21,7 @@ from jax_galsim.utilities import lazy_property -@_wraps(_galsim.interpolant.Interpolant) +@implements(_galsim.interpolant.Interpolant) @register_pytree_node_class class Interpolant: def __init__(self): @@ -313,7 +313,7 @@ def _shoot(self, photons, rng): # _unit_integrals, _positive_flux, _negative_flux, urange, and xrange -@_wraps(_galsim.interpolant.Delta) +@implements(_galsim.interpolant.Delta) @register_pytree_node_class class Delta(Interpolant): _positive_flux = 1.0 @@ -365,7 +365,7 @@ def _shoot(self, photons, rng): photons.flux = 1.0 / photons.size() -@_wraps(_galsim.interpolant.Nearest) +@implements(_galsim.interpolant.Nearest) @register_pytree_node_class class Nearest(Interpolant): _positive_flux = 1.0 @@ -409,7 +409,7 @@ def _shoot(self, photons, rng): photons.flux = 1.0 / photons.size() -@_wraps(_galsim.interpolant.SincInterpolant) +@implements(_galsim.interpolant.SincInterpolant) @register_pytree_node_class class SincInterpolant(Interpolant): def __init__(self, tol=None, gsparams=None): @@ -485,7 +485,7 @@ def _shoot(self, photons, rng): ) -@_wraps(_galsim.interpolant.Linear) +@implements(_galsim.interpolant.Linear) @register_pytree_node_class class Linear(Interpolant): _positive_flux = 1.0 @@ -537,7 +537,7 @@ def _shoot(self, photons, rng): photons.flux = 1.0 / photons.size() -@_wraps(_galsim.interpolant.Cubic) +@implements(_galsim.interpolant.Cubic) @register_pytree_node_class class Cubic(Interpolant): # these constants are from galsim itself in the cpp layer @@ -602,7 +602,7 @@ def ixrange(self): return 4 -@_wraps(_galsim.interpolant.Quintic) +@implements(_galsim.interpolant.Quintic) @register_pytree_node_class class Quintic(Interpolant): # these constants are from galsim itself in the cpp layer @@ -697,7 +697,7 @@ def ixrange(self): return 6 -@_wraps(_galsim.interpolant.Lanczos) +@implements(_galsim.interpolant.Lanczos) @register_pytree_node_class class Lanczos(Interpolant): # this data was generated in the dev notebook at diff --git a/jax_galsim/interpolatedimage.py b/jax_galsim/interpolatedimage.py index 3ca1cdde..f5de81ef 100644 --- a/jax_galsim/interpolatedimage.py +++ b/jax_galsim/interpolatedimage.py @@ -14,7 +14,7 @@ GalSimValueError, ) from galsim.utilities import doc_inherit -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim import fits @@ -52,7 +52,7 @@ def __dir__(cls): return list(keys) -@_wraps( +@implements( _galsim.InterpolatedImage, lax_description=textwrap.dedent( """The JAX equivalent of galsim.InterpolatedImage does not support @@ -337,7 +337,7 @@ def tree_unflatten(cls, aux_data, children): val.update(children[1]) return cls(children[0], **val) - @_wraps(_galsim.InterpolatedImage.withGSParams) + @implements(_galsim.InterpolatedImage.withGSParams) def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: return self @@ -916,7 +916,7 @@ def _shoot(self, photons, rng): photons.convolve(x_photons) -@_wraps(_galsim._InterpolatedImage) +@implements(_galsim._InterpolatedImage) def _InterpolatedImage( image, x_interpolant=Quintic(), diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 75925ed3..3c02c86b 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import Partial as partial from jax.tree_util import register_pytree_node_class @@ -68,7 +68,7 @@ def body(i, xcur): return rd -@_wraps(_galsim.Moffat) +@implements(_galsim.Moffat) @register_pytree_node_class class Moffat(GSObject): _is_axisymmetric = True @@ -372,7 +372,7 @@ def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac) - @_wraps(_galsim.Moffat.withFlux) + @implements(_galsim.Moffat.withFlux) def withFlux(self, flux): return Moffat( beta=self.beta, @@ -382,7 +382,7 @@ def withFlux(self, flux): gsparams=self.gsparams, ) - @_wraps(_galsim.Moffat.shoot) + @implements(_galsim.Moffat.shoot) def _shoot(self, photons, rng): # from the galsim C++ in SBMoffat.cpp ud = UniformDeviate(rng) diff --git a/jax_galsim/noise.py b/jax_galsim/noise.py index f04acd94..2a921987 100644 --- a/jax_galsim/noise.py +++ b/jax_galsim/noise.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import cast_to_float, ensure_hashable @@ -10,13 +10,13 @@ from jax_galsim.random import BaseDeviate, GaussianDeviate, PoissonDeviate -@_wraps(_galsim.noise.addNoise) +@implements(_galsim.noise.addNoise) def addNoise(self, noise): # This will be inserted into the Image class as a method. So self = image. noise.applyTo(self) -@_wraps(_galsim.noise.addNoiseSNR) +@implements(_galsim.noise.addNoiseSNR) def addNoiseSNR(self, noise, snr, preserve_flux=False): # This will be inserted into the Image class as a method. So self = image. noise_var = noise.getVariance() @@ -38,7 +38,7 @@ def addNoiseSNR(self, noise, snr, preserve_flux=False): Image.addNoiseSNR = addNoiseSNR -@_wraps(_galsim.noise.BaseNoise) +@implements(_galsim.noise.BaseNoise) @register_pytree_node_class class BaseNoise: def __init__(self, rng=None): @@ -162,7 +162,7 @@ def tree_unflatten(cls, aux_data, children): return cls(rng=children[0]) -@_wraps(_galsim.noise.GaussianNoise) +@implements(_galsim.noise.GaussianNoise) @register_pytree_node_class class GaussianNoise(BaseNoise): def __init__(self, rng=None, sigma=1.0): @@ -188,7 +188,7 @@ def _withVariance(self, variance): def _withScaledVariance(self, variance_ratio): return GaussianNoise(self.rng, self.sigma * jnp.sqrt(variance_ratio)) - @_wraps( + @implements( _galsim.noise.GaussianNoise.copy, lax_description="JAX-GalSim RNGs cannot be shared so a copy is made if None is given.", ) @@ -221,7 +221,7 @@ def tree_unflatten(cls, aux_data, children): return cls(sigma=children[0], rng=children[1]) -@_wraps(_galsim.noise.PoissonNoise) +@implements(_galsim.noise.PoissonNoise) @register_pytree_node_class class PoissonNoise(BaseNoise): def __init__(self, rng=None, sky_level=0.0): @@ -282,7 +282,7 @@ def _withVariance(self, variance): def _withScaledVariance(self, variance_ratio): return PoissonNoise(self.rng, self.sky_level * variance_ratio) - @_wraps( + @implements( _galsim.noise.PoissonNoise.copy, lax_description="JAX-GalSim RNGs cannot be shared so a copy is made if None is given.", ) @@ -312,7 +312,7 @@ def tree_unflatten(cls, aux_data, children): return cls(sky_level=children[0], rng=children[1]) -@_wraps(_galsim.noise.CCDNoise) +@implements(_galsim.noise.CCDNoise) @register_pytree_node_class class CCDNoise(BaseNoise): def __init__(self, rng=None, sky_level=0.0, gain=1.0, read_noise=0.0): @@ -435,7 +435,7 @@ def _withScaledVariance(self, variance_ratio): read_noise=self.read_noise * jnp.sqrt(variance_ratio), ) - @_wraps( + @implements( _galsim.noise.CCDNoise.copy, lax_description="JAX-GalSim RNGs cannot be shared so a copy is made if None is given.", ) @@ -479,7 +479,7 @@ def tree_unflatten(cls, aux_data, children): ) -@_wraps(_galsim.noise.DeviateNoise) +@implements(_galsim.noise.DeviateNoise) @register_pytree_node_class class DeviateNoise(BaseNoise): def __init__(self, dev): @@ -499,7 +499,7 @@ def _withVariance(self, variance): def _withScaledVariance(self, variance): raise GalSimError("Changing the variance is not allowed for DeviateNoise") - @_wraps( + @implements( _galsim.noise.DeviateNoise.copy, lax_description="JAX-GalSim RNGs cannot be shared so a copy is made if None is given.", ) @@ -534,7 +534,7 @@ def tree_unflatten(cls, aux_data, children): return cls(rng=children[0]) -@_wraps(_galsim.noise.VariableGaussianNoise) +@implements(_galsim.noise.VariableGaussianNoise) @register_pytree_node_class class VariableGaussianNoise(BaseNoise): def __init__(self, rng, var_image): @@ -550,7 +550,7 @@ def var_image(self): # Repeat this here, since we want to add an extra sanity check, which should go in the # non-underscore version. - @_wraps(_galsim.noise.VariableGaussianNoise.applyTo) + @implements(_galsim.noise.VariableGaussianNoise.applyTo) def applyTo(self, image): if not isinstance(image, Image): raise TypeError("Provided image must be a galsim.Image") @@ -567,7 +567,7 @@ def _applyTo(self, image): noise_array = self._rng.generate_from_variance(self.var_image.array) image._array = image._array + noise_array.astype(image.dtype) - @_wraps( + @implements( _galsim.noise.VariableGaussianNoise.copy, lax_description="JAX-GalSim RNGs cannot be shared so a copy is made if None is given.", ) diff --git a/jax_galsim/photon_array.py b/jax_galsim/photon_array.py index c0eb1136..c0201ec8 100644 --- a/jax_galsim/photon_array.py +++ b/jax_galsim/photon_array.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp import jax.random as jrng -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import cast_to_python_int @@ -33,7 +33,7 @@ def fixed_photon_array_size(size): _JAX_GALSIM_PHOTON_ARRAY_SIZE = old_size -@_wraps( +@implements( _galsim.PhotonArray, lax_description="""\ JAX-GalSim PhotonArrays have significant differences from the original GalSim. @@ -138,7 +138,7 @@ def __init__( self.time = time @classmethod - @_wraps( + @implements( _galsim.PhotonArray.fromArrays, lax_description="JAX-GalSim does not do input type/size checking.", ) @@ -160,7 +160,7 @@ def fromArrays( ) @classmethod - @_wraps(_galsim.PhotonArray._fromArrays) + @implements(_galsim.PhotonArray._fromArrays) def _fromArrays( cls, x, @@ -508,7 +508,7 @@ def _set_self_at_inds(self, sinds): return self - @_wraps(_galsim.PhotonArray.assignAt) + @implements(_galsim.PhotonArray.assignAt) def assignAt(self, istart, rhs): from .deprecated import depr @@ -525,7 +525,7 @@ def assignAt(self, istart, rhs): s = slice(istart, istart + rhs.size()) return self._copyFrom(rhs, s, slice(None)) - @_wraps( + @implements( _galsim.PhotonArray.copyFrom, lax_description="The JAX version of PhotonArray.copyFrom does not raise for out of bounds indices.", ) @@ -833,7 +833,7 @@ def __eq__(self, other): def __ne__(self, other): return not self == other - @_wraps( + @implements( _galsim.PhotonArray.addTo, lax_description="The JAX equivalent of galsim.PhotonArray.addTo may not raise for undefined bounds.", ) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index f89d9c4f..df8ba08d 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -1,13 +1,13 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable -@_wraps(_galsim.Position) +@implements(_galsim.Position) class Position(object): def __init__(self): raise NotImplementedError( @@ -174,7 +174,7 @@ def from_galsim(cls, galsim_position): return _cls(galsim_position.x, galsim_position.y) -@_wraps(_galsim.PositionD) +@implements(_galsim.PositionD) @register_pytree_node_class class PositionD(Position): def __init__(self, *args, **kwargs): @@ -199,7 +199,7 @@ def _check_scalar(self, other, op): raise TypeError("Can only %s a PositionD by float values" % op) -@_wraps(_galsim.PositionI) +@implements(_galsim.PositionI) @register_pytree_node_class class PositionI(Position): def __init__(self, *args, **kwargs): diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 03735294..548614f1 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -5,7 +5,7 @@ import jax import jax.numpy as jnp import jax.random as jrandom -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class try: @@ -67,7 +67,7 @@ def tree_unflatten(cls, aux_data, children): return cls(children[0]) -@_wraps( +@implements( _galsim.BaseDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -79,28 +79,28 @@ def __init__(self, seed=None): self._params = {} @property - @_wraps(_galsim.BaseDeviate.has_reliable_discard) + @implements(_galsim.BaseDeviate.has_reliable_discard) def has_reliable_discard(self): return True @property - @_wraps(_galsim.BaseDeviate.generates_in_pairs) + @implements(_galsim.BaseDeviate.generates_in_pairs) def generates_in_pairs(self): return False - @_wraps( + @implements( _galsim.BaseDeviate.seed, lax_description="The JAX version of this method does no type checking.", ) def seed(self, seed=None): self._seed(seed=seed) - @_wraps(_galsim.BaseDeviate._seed) + @implements(_galsim.BaseDeviate._seed) def _seed(self, seed=None): _initial_seed = seed or secrets.randbelow(2**31) self._state.key = jrandom.key(_initial_seed) - @_wraps( + @implements( _galsim.BaseDeviate.reset, lax_description=("The JAX version of this method does no type checking."), ) @@ -137,26 +137,26 @@ def serialize(self): return repr(ensure_hashable(jrandom.key_data(self._key))) @property - @_wraps(_galsim.BaseDeviate.np) + @implements(_galsim.BaseDeviate.np) def np(self): raise NotImplementedError( "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." ) - @_wraps(_galsim.BaseDeviate.as_numpy_generator) + @implements(_galsim.BaseDeviate.as_numpy_generator) def as_numpy_generator(self): raise NotImplementedError( "The JAX galsim.BaseDeviate does not support being used as a numpy PRNG." ) - @_wraps( + @implements( _galsim.BaseDeviate.clearCache, lax_description="This method is a no-op for the JAX version of this class.", ) def clearCache(self): pass - @_wraps( + @implements( _galsim.BaseDeviate.discard, lax_description=( "The JAX version of this class has reliable discarding and uses one key per value " @@ -174,7 +174,7 @@ def __discard(i, key): return jax.lax.fori_loop(0, n, __discard, key) - @_wraps( + @implements( _galsim.BaseDeviate.raw, lax_description=( "The JAX version of this class does not use the raw value to " @@ -185,7 +185,7 @@ def raw(self): self._key, subkey = jrandom.split(self._key) return jrandom.bits(subkey, dtype=jnp.uint32) - @_wraps( + @implements( _galsim.BaseDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -196,7 +196,7 @@ def generate(self, array): self._key, array = self.__class__._generate(self._key, array) return array - @_wraps( + @implements( _galsim.BaseDeviate.add_generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -210,7 +210,7 @@ def __call__(self): self._key, val = self.__class__._generate_one(self._key, None) return val - @_wraps(_galsim.BaseDeviate.duplicate) + @implements(_galsim.BaseDeviate.duplicate) def duplicate(self): ret = self.__class__.__new__(self.__class__) ret._state = _DeviateState(self._state.key) @@ -257,7 +257,7 @@ def __str__(self): return self.__repr__() -@_wraps( +@implements( _galsim.UniformDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -285,7 +285,7 @@ def __str__(self): return "galsim.UniformDeviate()" -@_wraps( +@implements( _galsim.GaussianDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -297,16 +297,16 @@ def __init__(self, seed=None, mean=0.0, sigma=1.0): self._params["sigma"] = sigma @property - @_wraps(_galsim.GaussianDeviate.mean) + @implements(_galsim.GaussianDeviate.mean) def mean(self): return self._params["mean"] @property - @_wraps(_galsim.GaussianDeviate.sigma) + @implements(_galsim.GaussianDeviate.sigma) def sigma(self): return self._params["sigma"] - @_wraps( + @implements( _galsim.GaussianDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -334,7 +334,7 @@ def _generate_one(key, x): _key, subkey = jrandom.split(key) return _key, jrandom.normal(subkey, dtype=float) - @_wraps(_galsim.GaussianDeviate.generate_from_variance) + @implements(_galsim.GaussianDeviate.generate_from_variance) def generate_from_variance(self, array): self._key, _array = self.__class__._generate(self._key, array) return _array * jnp.sqrt(array) @@ -353,7 +353,7 @@ def __str__(self): ) -@_wraps( +@implements( _galsim.BinomialDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -365,16 +365,16 @@ def __init__(self, seed=None, N=1, p=0.5): self._params["p"] = p @property - @_wraps(_galsim.BinomialDeviate.n) + @implements(_galsim.BinomialDeviate.n) def n(self): return self._params["N"] @property - @_wraps(_galsim.BinomialDeviate.p) + @implements(_galsim.BinomialDeviate.p) def p(self): return self._params["p"] - @_wraps( + @implements( _galsim.BinomialDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -425,7 +425,7 @@ def __str__(self): ) -@_wraps( +@implements( _galsim.PoissonDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -436,11 +436,11 @@ def __init__(self, seed=None, mean=1.0): self._params["mean"] = mean @property - @_wraps(_galsim.PoissonDeviate.mean) + @implements(_galsim.PoissonDeviate.mean) def mean(self): return self._params["mean"] - @_wraps( + @implements( _galsim.PoissonDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -480,7 +480,7 @@ def _generate_one(key, mean): ) return _key, val - @_wraps(_galsim.PoissonDeviate.generate_from_expectation) + @implements(_galsim.PoissonDeviate.generate_from_expectation) def generate_from_expectation(self, array): self._key, _array = self.__class__._generate_from_exp(self._key, array) return _array @@ -506,7 +506,7 @@ def __str__(self): return "galsim.PoissonDeviate(mean=%r)" % (ensure_hashable(self.mean),) -@_wraps( +@implements( _galsim.WeibullDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -518,16 +518,16 @@ def __init__(self, seed=None, a=1.0, b=1.0): self._params["b"] = b @property - @_wraps(_galsim.WeibullDeviate.a) + @implements(_galsim.WeibullDeviate.a) def a(self): return self._params["a"] @property - @_wraps(_galsim.WeibullDeviate.b) + @implements(_galsim.WeibullDeviate.b) def b(self): return self._params["b"] - @_wraps( + @implements( _galsim.WeibullDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -576,7 +576,7 @@ def __str__(self): ) -@_wraps( +@implements( _galsim.GammaDeviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -588,16 +588,16 @@ def __init__(self, seed=None, k=1.0, theta=1.0): self._params["theta"] = theta @property - @_wraps(_galsim.GammaDeviate.k) + @implements(_galsim.GammaDeviate.k) def k(self): return self._params["k"] @property - @_wraps(_galsim.GammaDeviate.theta) + @implements(_galsim.GammaDeviate.theta) def theta(self): return self._params["theta"] - @_wraps( + @implements( _galsim.GammaDeviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -642,7 +642,7 @@ def __str__(self): ) -@_wraps( +@implements( _galsim.Chi2Deviate, lax_description=LAX_FUNCTIONAL_RNG, ) @@ -653,11 +653,11 @@ def __init__(self, seed=None, n=1.0): self._params["n"] = n @property - @_wraps(_galsim.Chi2Deviate.n) + @implements(_galsim.Chi2Deviate.n) def n(self): return self._params["n"] - @_wraps( + @implements( _galsim.Chi2Deviate.generate, lax_description=( "JAX arrays cannot be changed in-place, so the JAX version of " @@ -940,7 +940,7 @@ def __str__(self): # self._npoints == other._npoints)) -@_wraps( +@implements( _galsim.random.permute, lax_description="The JAX implementation of this function cannot operate in-place and so returns a new list of arrays.", ) diff --git a/jax_galsim/sensor.py b/jax_galsim/sensor.py index 80ffe1fa..9ff02452 100644 --- a/jax_galsim/sensor.py +++ b/jax_galsim/sensor.py @@ -1,18 +1,18 @@ import galsim as _galsim -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.errors import GalSimUndefinedBoundsError from jax_galsim.position import PositionI -@_wraps(_galsim.Sensor) +@implements(_galsim.Sensor) @register_pytree_node_class class Sensor: def __init__(self): pass - @_wraps(_galsim.Sensor.accumulate) + @implements(_galsim.Sensor.accumulate) def accumulate(self, photons, image, orig_center=None, resume=False): if not image.bounds.isDefined(): raise GalSimUndefinedBoundsError( @@ -20,7 +20,7 @@ def accumulate(self, photons, image, orig_center=None, resume=False): ) return photons.addTo(image) - @_wraps(_galsim.Sensor.calculate_pixel_areas) + @implements(_galsim.Sensor.calculate_pixel_areas) def calculate_pixel_areas(self, image, orig_center=PositionI(0, 0), use_flux=True): return 1.0 diff --git a/jax_galsim/shear.py b/jax_galsim/shear.py index e3fab21d..503ff63e 100644 --- a/jax_galsim/shear.py +++ b/jax_galsim/shear.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax.numpy as jnp from galsim.errors import GalSimIncompatibleValuesError -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.angle import Angle, _Angle, radians @@ -9,7 +9,7 @@ @register_pytree_node_class -@_wraps( +@implements( _galsim.Shear, lax_description="""\ The jax_galsim implementation of ``Shear`` does not perform range checking of the \ @@ -254,13 +254,13 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - @_wraps(_galsim.Shear.getMatrix) + @implements(_galsim.Shear.getMatrix) def getMatrix(self): return jnp.array( [[1.0 + self.g1, self.g2], [self.g2, 1.0 - self.g1]] ) / jnp.sqrt(1.0 - self.g**2) - @_wraps(_galsim.Shear.rotationWith) + @implements(_galsim.Shear.rotationWith) def rotationWith(self, other): # Save a little time by only working on the first column. S3 = self.getMatrix().dot(other.getMatrix()[:, :1]) @@ -297,7 +297,7 @@ def from_galsim(cls, galsim_shear): return cls(g1=galsim_shear.g1, g2=galsim_shear.g2) -@_wraps(_galsim._Shear) +@implements(_galsim._Shear) def _Shear(shear): ret = Shear.__new__(Shear) ret._g = shear diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index a87ad3d7..b0d00640 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import Partial as partial from jax.tree_util import register_pytree_node_class @@ -191,7 +191,7 @@ def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0): ) -@_wraps( +@implements( _galsim.Spergel, lax_description=r"""The fully normalized Spergel profile (used in both standard GalSim and JAX-GalSim) is .. math:: @@ -377,7 +377,7 @@ def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac) - @_wraps(_galsim.Spergel.withFlux) + @implements(_galsim.Spergel.withFlux) def withFlux(self, flux): return Spergel( nu=self.nu, @@ -459,7 +459,7 @@ def _shoot_neg(self, u): r = z * self._r0 return r - @_wraps(_galsim.Spergel._shoot) + @implements(_galsim.Spergel._shoot) def _shoot(self, photons, rng): ud = UniformDeviate(rng) u = ud.generate(photons.x) diff --git a/jax_galsim/sum.py b/jax_galsim/sum.py index 855f8590..84da7372 100644 --- a/jax_galsim/sum.py +++ b/jax_galsim/sum.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp import numpy as np -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.gsobject import GSObject @@ -11,14 +11,14 @@ from jax_galsim.random import BaseDeviate -@_wraps( +@implements( _galsim.Add, lax_description="Does not support `ChromaticObject` at this point." ) def Add(*args, **kwargs): return Sum(*args, **kwargs) -@_wraps( +@implements( _galsim.Sum, lax_description="Does not support `ChromaticObject` at this point." ) @register_pytree_node_class @@ -74,12 +74,12 @@ def obj_list(self): return self._params["obj_list"] @property - @_wraps(_galsim.Sum.flux) + @implements(_galsim.Sum.flux) def flux(self): flux_list = jnp.array([obj.flux for obj in self.obj_list]) return jnp.sum(flux_list) - @_wraps(_galsim.Sum.withGSParams) + @implements(_galsim.Sum.withGSParams) def withGSParams(self, gsparams=None, **kwargs): if gsparams == self.gsparams: return self diff --git a/jax_galsim/transform.py b/jax_galsim/transform.py index 7b5ff2a0..9ed61e13 100644 --- a/jax_galsim/transform.py +++ b/jax_galsim/transform.py @@ -1,7 +1,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import compute_major_minor_from_jacobian, ensure_hashable @@ -10,7 +10,7 @@ from jax_galsim.position import PositionD -@_wraps( +@implements( _galsim.Transform, lax_description="Does not support Chromatic Objects or Convolutions.", ) @@ -36,7 +36,7 @@ def Transform( ) -@_wraps(_galsim.Transformation) +@implements(_galsim.Transformation) @register_pytree_node_class class Transformation(GSObject): def __init__( @@ -126,7 +126,7 @@ def flux_ratio(self): def _flux(self): return self._flux_scaling * self._original.flux - @_wraps(_galsim.Transformation.withGSParams) + @implements(_galsim.Transformation.withGSParams) def withGSParams(self, gsparams=None, **kwargs): """Create a version of the current object with the given gsparams diff --git a/jax_galsim/utilities.py b/jax_galsim/utilities.py index 95a7a985..0abb1c90 100644 --- a/jax_galsim/utilities.py +++ b/jax_galsim/utilities.py @@ -3,7 +3,7 @@ import galsim as _galsim import jax import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax_galsim.core.utils import has_tracers from jax_galsim.errors import GalSimIncompatibleValuesError, GalSimValueError @@ -12,7 +12,7 @@ printoptions = _galsim.utilities.printoptions -@_wraps( +@implements( _galsim.utilities.lazy_property, lax_description=( "The LAX version of this decorator uses an `_workspace` attribute " @@ -54,7 +54,7 @@ def wrapper(self): ) -@_wraps(_galsim.utilities.parse_pos_args) +@implements(_galsim.utilities.parse_pos_args) def parse_pos_args(args, kwargs, name1, name2, integer=False, others=[]): def canindex(arg): try: @@ -120,7 +120,7 @@ def canindex(arg): return pos -@_wraps(_galsim.utilities.g1g2_to_e1e2) +@implements(_galsim.utilities.g1g2_to_e1e2) def g1g2_to_e1e2(g1, g2): # Conversion: # e = (a^2-b^2) / (a^2+b^2) @@ -139,7 +139,7 @@ def g1g2_to_e1e2(g1, g2): return e1, e2 -@_wraps(_galsim.utilities.convert_interpolant) +@implements(_galsim.utilities.convert_interpolant) def convert_interpolant(interpolant): from jax_galsim.interpolant import Interpolant @@ -150,7 +150,7 @@ def convert_interpolant(interpolant): return Interpolant.from_name(interpolant) -@_wraps(_galsim.utilities.unweighted_moments) +@implements(_galsim.utilities.unweighted_moments) def unweighted_moments(image, origin=None): from jax_galsim.position import PositionD @@ -171,7 +171,7 @@ def unweighted_moments(image, origin=None): return dict(M0=M0, Mx=Mx, My=My, Mxx=Mxx, Myy=Myy, Mxy=Mxy) -@_wraps(_galsim.utilities.unweighted_shape) +@implements(_galsim.utilities.unweighted_shape) def unweighted_shape(arg): from jax_galsim.image import Image @@ -183,7 +183,7 @@ def unweighted_shape(arg): ) -@_wraps(_galsim.utilities.horner) +@implements(_galsim.utilities.horner) def horner(x, coef, dtype=None): x = jnp.array(x) coef = jnp.atleast_1d(coef) @@ -212,7 +212,7 @@ def horner(x, coef, dtype=None): ) -@_wraps(_galsim.utilities.horner2d) +@implements(_galsim.utilities.horner2d) def horner2d(x, y, coefs, dtype=None, triangle=False): x = jnp.array(x) y = jnp.array(y) diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index c55b3e2e..f4fda2c0 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -1,6 +1,6 @@ import galsim as _galsim import jax.numpy as jnp -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import implements from jax.tree_util import register_pytree_node_class from jax_galsim.angle import AngleUnit, arcsec, radians @@ -16,7 +16,7 @@ # We inherit from the reference BaseWCS and only redefine the methods that # make references to jax_galsim objects. class BaseWCS(_galsim.BaseWCS): - @_wraps(_galsim.BaseWCS.toWorld) + @implements(_galsim.BaseWCS.toWorld) def toWorld(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): @@ -33,7 +33,7 @@ def toWorld(self, *args, **kwargs): else: raise TypeError("toWorld() takes either 1 or 2 positional arguments") - @_wraps(_galsim.BaseWCS.posToWorld) + @implements(_galsim.BaseWCS.posToWorld) def posToWorld(self, image_pos, color=None, **kwargs): if color is None: color = self._color @@ -41,7 +41,7 @@ def posToWorld(self, image_pos, color=None, **kwargs): raise TypeError("image_pos must be a PositionD or PositionI argument") return self._posToWorld(image_pos, color=color, **kwargs) - @_wraps(_galsim.BaseWCS.profileToWorld) + @implements(_galsim.BaseWCS.profileToWorld) def profileToWorld( self, image_profile, @@ -57,13 +57,13 @@ def profileToWorld( image_profile, flux_ratio, PositionD(offset) ) - @_wraps(_galsim.BaseWCS.shearToWorld) + @implements(_galsim.BaseWCS.shearToWorld) def shearToWorld(self, image_shear, image_pos=None, world_pos=None, color=None): if color is None: color = self._color return self.local(image_pos, world_pos, color=color)._shearToWorld(image_shear) - @_wraps(_galsim.BaseWCS.toImage) + @implements(_galsim.BaseWCS.toImage) def toImage(self, *args, **kwargs): if len(args) == 1: if isinstance(args[0], GSObject): @@ -80,7 +80,7 @@ def toImage(self, *args, **kwargs): else: raise TypeError("toImage() takes either 1 or 2 positional arguments") - @_wraps(_galsim.BaseWCS.posToImage) + @implements(_galsim.BaseWCS.posToImage) def posToImage(self, world_pos, color=None): if color is None: color = self._color @@ -90,7 +90,7 @@ def posToImage(self, world_pos, color=None): raise TypeError("world_pos must be a PositionD or PositionI argument") return self._posToImage(world_pos, color=color) - @_wraps(_galsim.BaseWCS.profileToImage) + @implements(_galsim.BaseWCS.profileToImage) def profileToImage( self, world_profile, @@ -106,13 +106,13 @@ def profileToImage( world_profile, flux_ratio, PositionD(offset) ) - @_wraps(_galsim.BaseWCS.shearToImage) + @implements(_galsim.BaseWCS.shearToImage) def shearToImage(self, world_shear, image_pos=None, world_pos=None, color=None): if color is None: color = self._color return self.local(image_pos, world_pos, color=color)._shearToImage(world_shear) - @_wraps(_galsim.BaseWCS.local) + @implements(_galsim.BaseWCS.local) def local(self, image_pos=None, world_pos=None, color=None): if color is None: color = self._color @@ -128,13 +128,13 @@ def local(self, image_pos=None, world_pos=None, color=None): raise TypeError("image_pos must be a PositionD or PositionI argument") return self._local(image_pos, color) - @_wraps(_galsim.BaseWCS.jacobian) + @implements(_galsim.BaseWCS.jacobian) def jacobian(self, image_pos=None, world_pos=None, color=None): if color is None: color = self._color return self.local(image_pos, world_pos, color=color)._toJacobian() - @_wraps(_galsim.BaseWCS.affine) + @implements(_galsim.BaseWCS.affine) def affine(self, image_pos=None, world_pos=None, color=None): if color is None: color = self._color @@ -153,7 +153,7 @@ def affine(self, image_pos=None, world_pos=None, color=None): world_pos = self.toWorld(image_pos, color=color) return jac.shiftOrigin(image_pos, world_pos, color=color) - @_wraps(_galsim.BaseWCS.shiftOrigin) + @implements(_galsim.BaseWCS.shiftOrigin) def shiftOrigin(self, origin, world_origin=None, color=None): if color is None: color = self._color @@ -161,7 +161,7 @@ def shiftOrigin(self, origin, world_origin=None, color=None): raise TypeError("origin must be a PositionD or PositionI argument") return self._shiftOrigin(origin, world_origin, color) - @_wraps(_galsim.BaseWCS.withOrigin) + @implements(_galsim.BaseWCS.withOrigin) def withOrigin(self, origin, world_origin=None, color=None): from .deprecated import depr @@ -557,7 +557,7 @@ class LocalWCS(UniformWCS): as (0,0) in world coordinates """ - @_wraps(_galsim.wcs.LocalWCS.isLocal) + @implements(_galsim.wcs.LocalWCS.isLocal) def isLocal(self): return True @@ -838,7 +838,7 @@ def __ne__(self, other): ######################################################################################### -@_wraps(_galsim.PixelScale) +@implements(_galsim.PixelScale) @register_pytree_node_class class PixelScale(LocalWCS): _isPixelScale = True @@ -937,7 +937,7 @@ def __hash__(self): return hash(repr(self)) -@_wraps(_galsim.ShearWCS) +@implements(_galsim.ShearWCS) @register_pytree_node_class class ShearWCS(LocalWCS): _req_params = {"scale": float, "shear": Shear} @@ -1063,7 +1063,7 @@ def tree_unflatten(cls, aux_data, children): return cls(*children) -@_wraps(_galsim.JacobianWCS) +@implements(_galsim.JacobianWCS) @register_pytree_node_class class JacobianWCS(LocalWCS): def __init__(self, dudx, dudy, dvdx, dvdy): @@ -1161,7 +1161,7 @@ def getMatrix(self): """ return jnp.array([[self.dudx, self.dudy], [self.dvdx, self.dvdy]], dtype=float) - @_wraps(_galsim.JacobianWCS.getDecomposition) + @implements(_galsim.JacobianWCS.getDecomposition) def getDecomposition(self): from .angle import radians @@ -1297,7 +1297,7 @@ def __hash__(self): ######################################################################################### -@_wraps(_galsim.OffsetWCS) +@implements(_galsim.OffsetWCS) @register_pytree_node_class class OffsetWCS(UniformWCS): _isPixelScale = True @@ -1369,7 +1369,7 @@ def __hash__(self): return hash(repr(self)) -@_wraps(_galsim.OffsetShearWCS) +@implements(_galsim.OffsetShearWCS) @register_pytree_node_class class OffsetShearWCS(UniformWCS): _req_params = {"scale": float, "shear": Shear} @@ -1452,7 +1452,7 @@ def tree_unflatten(cls, aux_data, children): return cls(*children) -@_wraps(_galsim.AffineTransform) +@implements(_galsim.AffineTransform) @register_pytree_node_class class AffineTransform(UniformWCS): def __init__(self, dudx, dudy, dvdx, dvdy, origin=None, world_origin=None):