Skip to content

Commit

Permalink
make fft a static arg and test we can still jit/vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Oct 23, 2024
1 parent 93755cc commit 860adc8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 33 deletions.
8 changes: 5 additions & 3 deletions bpd/draw.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import galsim
import jax_galsim as xgalsim

GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256)
from jax_galsim import GSParams


def draw_gaussian(
Expand All @@ -16,14 +15,17 @@ def draw_gaussian(
pixel_scale: float = 0.2,
slen: int = 53,
psf_hlr: float = 0.7,
fft_size: int = 256,
):
gsparams = GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)

# x, y arguments in pixels
gal = xgalsim.Gaussian(flux=f, half_light_radius=hlr)
gal = gal.shear(g1=e1, g2=e2)
gal = gal.shear(g1=g1, g2=g2)

psf = xgalsim.Gaussian(flux=1, half_light_radius=psf_hlr)
gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS)
gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(gsparams)
image = gal_conv.drawImage(nx=slen, ny=slen, scale=pixel_scale, offset=(x, y))
return image.array

Expand Down
27 changes: 27 additions & 0 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Minimal amount of code that checks jax-galsim is working correctly."""

from functools import partial

import jax.numpy as jnp
from jax import jit as jjit
from jax import vmap

from bpd.draw import draw_gaussian


def test_jax_galsim():
_draw_fnc1 = partial(draw_gaussian, slen=101, fft_size=512)
_draw_fnc = vmap(jjit(_draw_fnc1))

f = jnp.array([1000, 2000])
hlr = jnp.array([0.9, 1.0])
e1 = jnp.array([0.2, -0.1])
e2 = jnp.array([0.0, 0.2])
g1 = jnp.array([0.02, 0.0])
g2 = jnp.array([0.0, 0.02])
x = jnp.array([1.0, 0.0])
y = jnp.array([0.0, 1.0])

a = _draw_fnc(f=f, hlr=hlr, e1=e1, e2=e2, g1=g1, g2=g2, x=x, y=y)
assert a.ndim == 3
assert a.shape == (2, 101, 101)
30 changes: 0 additions & 30 deletions tests/test_jax_galsim.py

This file was deleted.

0 comments on commit 860adc8

Please sign in to comment.