Skip to content

Commit

Permalink
perf: unroll everything
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jun 18, 2024
1 parent 1389966 commit 4798a97
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jax_galsim/interpolatedimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ def _body(i, args):
interp.xrange + 1,
_body_1d,
[z, wy, msk, yind, xi, xp, zp],
unroll=int(interp.xrange),
unroll=int(interp.xrange) * 2 + 1,
)[0]
return [z, xi, yi, xp, yp, zp]

Expand All @@ -1056,7 +1056,7 @@ def _body(i, args):
interp.xrange + 1,
_body,
[jnp.zeros(x.shape, dtype=float), xi, yi, xp, yp, zp],
unroll=int(interp.xrange),
unroll=int(interp.xrange) * 2 + 1,
)[0]

# we reshape on the way out to match the input shape
Expand Down Expand Up @@ -1167,7 +1167,7 @@ def _body(i, args):
interp.xrange + 1,
_body_1d,
[z, wky, kyind, kxi, nkx, nkx_2, kxp, zp],
unroll=int(interp.xrange),
unroll=int(interp.xrange) * 2 + 1,
)[0]
return [z, kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp]

Expand All @@ -1176,7 +1176,7 @@ def _body(i, args):
interp.xrange + 1,
_body,
[jnp.zeros(kx.shape, dtype=complex), kxi, kyi, nky, nkx, nkx_2, kxp, kyp, zp],
unroll=int(interp.xrange),
unroll=int(interp.xrange) * 2 + 1,
)[0]
return z.reshape(orig_shape)

Expand Down

0 comments on commit 4798a97

Please sign in to comment.