Skip to content

Commit

Permalink
Merge pull request #82 from beckermr/photon-array
Browse files Browse the repository at this point in the history
ENH add photon shooting
  • Loading branch information
beckermr authored Dec 19, 2023
2 parents 856e702 + d06205a commit 7e688ed
Show file tree
Hide file tree
Showing 46 changed files with 4,020 additions and 14,506 deletions.
16 changes: 11 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@
* `Shear`
* `Convolve`
* `InterpolatedImage` and `Interpolant`
* `PhotonArray`
* `Sensor`
* `AngleUnit`, `Angle`, and `CelestialCoord`
* `BaseDeviate` and child classes
* `BaseNoise` and child classes
* Added implementation of fundamental operations:
* `drawImage`
* `drawReal`
* `drawFFT`
* `drawKImage`
* `makePhot`
* `drawPhot`
* Added implementation of simple light profiles:
* `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat`
* `Gaussian`, `Exponential`, `Pixel`, `Box`, `Moffat`, `DeltaFunction`
* Added implementation of simple WCS:
* `PixelScale`, `OffsetWCS`, `JacobianWCS`, `AffineTransform`, `ShearWCS`, `OffsetShearWCS`
* Added automated suite of tests against reference GalSim
* `PixelScale`, `OffsetWCS`, `JacobianWCS`, `AffineTransform`, `ShearWCS`, `OffsetShearWCS`, `GSFitsWCS`, `FitsWCS`, `TanWCS`
* Added automated suite of tests using the reference GalSim and LSSTDESC-Coord test suites
* Added support for the `galsim.fits` module
* Added a `from_galsim` method to convert from GalSim objects to JAX-GalSim objects

* Caveats
* Real space convolution and photon shooting methods are not
yet implemented in drawImage.
* Real space convolution are not yet implemented in `drawImage``.
6 changes: 6 additions & 0 deletions jax_galsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
ImageS,
ImageUI,
ImageUS,
_Image,
)

# GSObject
Expand All @@ -55,6 +56,7 @@
from .sum import Add, Sum
from .transform import Transform, Transformation
from .convolve import Convolve, Convolution, Deconvolution, Deconvolve
from .deltafunction import DeltaFunction

# WCS
from .wcs import (
Expand Down Expand Up @@ -86,6 +88,10 @@
)
from .interpolatedimage import InterpolatedImage, _InterpolatedImage

# Photon Shooting
from .photon_array import PhotonArray
from .sensor import Sensor

# packages kept separate
from . import bessel
from . import fits
Expand Down
8 changes: 5 additions & 3 deletions jax_galsim/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax._src.numpy.util import _wraps
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import cast_to_float_array_scalar, ensure_hashable
from jax_galsim.core.utils import cast_to_float, ensure_hashable


@_wraps(_galsim.AngleUnit)
Expand All @@ -34,7 +34,9 @@ def __init__(self, value):
"""
:param value: The measure of the unit in radians.
"""
self._value = cast_to_float_array_scalar(value)
if isinstance(value, AngleUnit):
raise TypeError("Cannot construct AngleUnit from another AngleUnit")
self._value = cast_to_float(value)

@property
def value(self):
Expand Down Expand Up @@ -142,7 +144,7 @@ def __init__(self, theta, unit=None):
raise TypeError("Invalid unit %s of type %s" % (unit, type(unit)))
else:
# Normal case
self._rad = cast_to_float_array_scalar(theta) * unit.value
self._rad = cast_to_float(theta) * unit.value

@property
def rad(self):
Expand Down
60 changes: 44 additions & 16 deletions jax_galsim/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
from jax_galsim.core.utils import cast_to_float, cast_to_int, ensure_hashable
from jax_galsim.position import Position, PositionD, PositionI

BOUNDS_LAX_DESCR = """\
The JAX implementation
- will not always test whether the bounds are valid
- will not always test whether BoundsI is initialized with integers
"""


# The reason for avoid these tests is that they are not easy to do for jitted code.
@_wraps(
_galsim.Bounds,
lax_description=(
"The JAX implementation will not test whether the bounds are valid."
"This is defined as always true."
"It will also not test whether BoundsI is indeed initialized with integers."
),
)
@_wraps(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class Bounds(_galsim.Bounds):
def _parse_args(self, *args, **kwargs):
Expand Down Expand Up @@ -81,6 +81,16 @@ def _parse_args(self, *args, **kwargs):
if kwargs:
raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys())

# for simple inputs, we can check if the bounds are valid
if (
isinstance(self.xmin, (float, int))
and isinstance(self.xmax, (float, int))
and isinstance(self.ymin, (float, int))
and isinstance(self.ymax, (float, int))
and ((self.xmin > self.xmax) or (self.ymin > self.ymax))
):
self._isdefined = False

@property
def true_center(self):
"""The central position of the `Bounds` as a `PositionD`.
Expand Down Expand Up @@ -245,15 +255,18 @@ def from_galsim(cls, galsim_bounds):
"galsim_bounds must be either a %s or a %s"
% (_galsim.BoundsD.__name__, _galsim.BoundsI.__name__)
)
return _cls(
galsim_bounds.xmin,
galsim_bounds.xmax,
galsim_bounds.ymin,
galsim_bounds.ymax,
)
if galsim_bounds.isDefined():
return _cls(
galsim_bounds.xmin,
galsim_bounds.xmax,
galsim_bounds.ymin,
galsim_bounds.ymax,
)
else:
return _cls()


@_wraps(_galsim.BoundsD)
@_wraps(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsD(Bounds):
_pos_class = PositionD
Expand Down Expand Up @@ -287,13 +300,28 @@ def _center(self):
return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0)


@_wraps(_galsim.BoundsI)
@_wraps(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsI(Bounds):
_pos_class = PositionI

def __init__(self, *args, **kwargs):
self._parse_args(*args, **kwargs)
# for simple inputs, we can check if the bounds are valid ints
if (
isinstance(self.xmin, (float, int))
and isinstance(self.xmax, (float, int))
and isinstance(self.ymin, (float, int))
and isinstance(self.ymax, (float, int))
and (
self.xmin != int(self.xmin)
or self.xmax != int(self.xmax)
or self.ymin != int(self.ymin)
or self.ymax != int(self.ymax)
)
):
raise TypeError("BoundsI must be initialized with integer values")

self.xmin = cast_to_int(self.xmin)
self.xmax = cast_to_int(self.xmax)
self.ymin = cast_to_int(self.ymin)
Expand Down
10 changes: 10 additions & 0 deletions jax_galsim/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate


@_wraps(_galsim.Box)
Expand Down Expand Up @@ -115,6 +116,15 @@ def tree_unflatten(cls, aux_data, children):
**aux_data
)

@_wraps(_galsim.Box._shoot)
def _shoot(self, photons, rng):
ud = UniformDeviate(rng)

# this does not fill arrays like in galsim
photons.x = (ud.generate(photons.x) - 0.5) * self.width
photons.y = (ud.generate(photons.y) - 0.5) * self.height
photons.flux = self.flux / photons.size()


@_wraps(_galsim.Pixel)
@register_pytree_node_class
Expand Down
11 changes: 10 additions & 1 deletion jax_galsim/convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.photon_array import PhotonArray


@_wraps(
Expand Down Expand Up @@ -308,7 +309,15 @@ def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
raise NotImplementedError("Real-space convolutions are not implemented")

def _shoot(self, photons, rng):
raise NotImplementedError("Photon shooting convolutions are not implemented")
self.obj_list[0]._shoot(photons, rng)
# It may be necessary to shuffle when convolving because we do not have a
# guarantee that the convolvee's photons are uncorrelated, e.g., they might
# both have their negative ones at the end.
# However, this decision is now made by the convolve method.
for obj in self.obj_list[1:]:
p1 = PhotonArray(len(photons))
obj._shoot(p1, rng)
photons.convolve(p1, rng)

def _drawKImage(self, image, jac=None):
image = self.obj_list[0]._drawKImage(image, jac)
Expand Down
Loading

0 comments on commit 7e688ed

Please sign in to comment.