Skip to content

Commit

Permalink
Merge pull request #80 from beckermr/gsfitswcs
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr authored Nov 21, 2023
2 parents 50514dd + 87607e3 commit 856e702
Show file tree
Hide file tree
Showing 16 changed files with 1,673 additions and 123 deletions.
1 change: 1 addition & 0 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
)
from .fits import FitsHeader
from .celestial import CelestialCoord
from .fitswcs import TanWCS, FitsWCS, GSFitsWCS

# Shear
from .shear import Shear, _Shear
Expand Down
6 changes: 3 additions & 3 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax._src.numpy.util import _wraps
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_scalar_to_float, ensure_hashable
from jax_galsim.core.utils import cast_to_float_array_scalar, ensure_hashable


@_wraps(_galsim.AngleUnit)
Expand All @@ -34,7 +34,7 @@ def __init__(self, value):
"""
:param value: The measure of the unit in radians.
"""
self._value = cast_scalar_to_float(value)
self._value = cast_to_float_array_scalar(value)

@property
def value(self):
Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(self, theta, unit=None):
raise TypeError("Invalid unit %s of type %s" % (unit, type(unit)))
else:
# Normal case
self._rad = cast_scalar_to_float(theta) * unit.value
self._rad = cast_to_float_array_scalar(theta) * unit.value

@property
def rad(self):
Expand Down
22 changes: 9 additions & 13 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from jax._src.numpy.util import _wraps
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
cast_scalar_to_float,
cast_scalar_to_int,
ensure_hashable,
)
from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable
from jax_galsim.position import Position, PositionD, PositionI


Expand Down Expand Up @@ -264,10 +260,10 @@ class BoundsD(Bounds):

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
self.xmin = cast_scalar_to_float(self.xmin)
self.xmax = cast_scalar_to_float(self.xmax)
self.ymin = cast_scalar_to_float(self.ymin)
self.ymax = cast_scalar_to_float(self.ymax)
self.xmin = cast_to_float(self.xmin)
self.xmax = cast_to_float(self.xmax)
self.ymin = cast_to_float(self.ymin)
self.ymax = cast_to_float(self.ymax)

def _check_scalar(self, x, name):
try:
Expand Down Expand Up @@ -298,10 +294,10 @@ class BoundsI(Bounds):

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
self.xmin = cast_scalar_to_int(self.xmin)
self.xmax = cast_scalar_to_int(self.xmax)
self.ymin = cast_scalar_to_int(self.ymin)
self.ymax = cast_scalar_to_int(self.ymax)
self.xmin = cast_to_int(self.xmin)
self.xmax = cast_to_int(self.xmax)
self.ymin = cast_to_int(self.ymin)
self.ymax = cast_to_int(self.ymax)

def _check_scalar(self, x, name):
try:
Expand Down
24 changes: 17 additions & 7 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,27 @@ def compute_major_minor_from_jacobian(jac):
return major, minor


def convert_to_float(x):
def cast_to_float_array_scalar(x):
"""Cast the input to a float array scalar. Works on python floats, iterables and jax arrays.
For iterables it always takes the first element after a call to .ravel()"""
if isinstance(x, jax.Array):
if x.shape == ():
return x.item()
else:
return x[0].astype(float).item()
return jnp.atleast_1d(x).astype(float).ravel()[0]
elif hasattr(x, "astype"):
return x.astype(float).ravel()[0]
else:
return jnp.atleast_1d(jnp.array(x, dtype=float)).ravel()[0]


def cast_to_python_float(x):
"""Cast the input to a python float. Works on python floats and jax arrays.
For jax arrays it always takes the first element after a call to .ravel()"""
if isinstance(x, jax.Array):
return cast_to_float_array_scalar(x).item()
else:
return float(x)


def cast_scalar_to_float(x):
def cast_to_float(x):
"""Cast the input to a float. Works on python floats and jax arrays."""
if isinstance(x, jax.Array):
return x.astype(float)
Expand All @@ -40,7 +50,7 @@ def cast_scalar_to_float(x):
return x


def cast_scalar_to_int(x):
def cast_to_int(x):
"""Cast the input to an int. Works on python floats/ints and jax arrays."""
if isinstance(x, jax.Array):
return x.astype(int)
Expand Down
1 change: 1 addition & 0 deletions jax_galsim/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
GalSimUndefinedBoundsError,
GalSimValueError,
GalSimWarning,
galsim_warn,
)
Loading

0 comments on commit 856e702

Please sign in to comment.