Skip to content

Commit

Permalink
tweaks for cupy compability
Browse files Browse the repository at this point in the history
  • Loading branch information
kvangorkom committed Feb 23, 2025
1 parent 363b119 commit 9e924ed
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 40 deletions.
88 changes: 57 additions & 31 deletions poppy/polarized_wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,19 @@ def __init__(self,
**kwargs
)
# TO DO: clean up the logic of checking which is specified and handling appropriately
self.input_stokes_vector = input_stokes_vector
self.input_polarization = input_polarization
#self.input_stokes_vector = input_stokes_vector
#self.input_polarization = input_polarization
self.pol_type = None

if input_stokes_vector is not None: # wavefront tensor
self.input_polarization = None
self.pol_type = 'tensor'
self.wavefront = self.wavefront * xp.eye(2)[:, :, xp.newaxis, xp.newaxis]
self.input_stokes_vector = xp.asarray(input_stokes_vector)
elif input_polarization is not None: # wavefront vector
self.pol_type = 'vector'
self.wavefront = self.wavefront * xp.asarray(input_polarization)[:, xp.newaxis, xp.newaxis]
self.input_polarization = xp.asarray(input_polarization)
self.wavefront = self.wavefront * self.input_polarization[:, xp.newaxis, xp.newaxis]
else:
raise ValueError('Either input_stokes_vector or input_polarization must be specified! For scalar diffraction, use Wavefront or FresnelWavefront.')

Expand Down Expand Up @@ -232,41 +234,65 @@ def jones_to_mueller(jones_matrix):
"""
shape = jones_matrix.shape
# ordering convention below starts with diagonal terms
j = xp.concatenate([[jones_matrix[0,0],
jones_matrix[1,1],
jones_matrix[1,0],
jones_matrix[0,1]]],
axis=0)
j = xp.concatenate(xp.asanyarray([[jones_matrix[0,0], # cupy requires casting the list to a cupy array
jones_matrix[1,1],
jones_matrix[1,0],
jones_matrix[0,1]]]),
axis=0)
jc = j.conj()
e = j * jc

e0, e1, e2, e3 = e
j0, j1, j2, j3 = j
jc0, jc1, jc2, jc3 = jc

# construct the Mueller matrix
# row 1
M00 = ne.evaluate('0.5*(e0 + e1 + e2 + e3)')
M01 = ne.evaluate('0.5*(e0 - e1 - e2 + e3)')
M02 = ne.evaluate('(j0*jc2).real + (j3*jc1).real')
M03 = ne.evaluate('-(jc0*j2).imag - (jc3*j1).imag')
# row 2
M10 = ne.evaluate('0.5*(e0 - e1 + e2 - e3)')
M11 = ne.evaluate('0.5*(e0 + e1 - e2 - e3)')
M12 = ne.evaluate('(j0*jc2).real - (j3*jc1).real')
M13 = ne.evaluate('-(jc0*j2).imag + (jc3*j1).imag')
# row 3
M20 = ne.evaluate('(j0*jc3).real + (j2*jc1).real')
M21 = ne.evaluate('(j0*jc3).real - (j2*jc1).real')
M22 = ne.evaluate('(j0*jc1).real + (j2*jc3).real')
M23 = ne.evaluate('-(jc0*j1).imag + (jc2*j3).imag')
# row 4
M30 = ne.evaluate('(jc0*j3).imag + (jc2*j1).imag')
M31 = ne.evaluate('(jc0*j3).imag - (jc2*j1).imag')
M32 = ne.evaluate('(jc0*j1).imag + (jc2*j3).imag')
M33 = ne.evaluate('(j0*jc1).real - (j2*jc3).real')

M = np.asarray([[M00, M01, M02, M03],
if accel_math._USE_NUMEXPR:
# construct the Mueller matrix
# row 1
M00 = ne.evaluate('0.5*(e0 + e1 + e2 + e3)')
M01 = ne.evaluate('0.5*(e0 - e1 - e2 + e3)')
M02 = ne.evaluate('(j0*jc2).real + (j3*jc1).real')
M03 = ne.evaluate('-(jc0*j2).imag - (jc3*j1).imag')
# row 2
M10 = ne.evaluate('0.5*(e0 - e1 + e2 - e3)')
M11 = ne.evaluate('0.5*(e0 + e1 - e2 - e3)')
M12 = ne.evaluate('(j0*jc2).real - (j3*jc1).real')
M13 = ne.evaluate('-(jc0*j2).imag + (jc3*j1).imag')
# row 3
M20 = ne.evaluate('(j0*jc3).real + (j2*jc1).real')
M21 = ne.evaluate('(j0*jc3).real - (j2*jc1).real')
M22 = ne.evaluate('(j0*jc1).real + (j2*jc3).real')
M23 = ne.evaluate('-(jc0*j1).imag + (jc2*j3).imag')
# row 4
M30 = ne.evaluate('(jc0*j3).imag + (jc2*j1).imag')
M31 = ne.evaluate('(jc0*j3).imag - (jc2*j1).imag')
M32 = ne.evaluate('(jc0*j1).imag + (jc2*j3).imag')
M33 = ne.evaluate('(j0*jc1).real - (j2*jc3).real')
else:
# construct the Mueller matrix
# row 1
M00 =0.5*(e0 + e1 + e2 + e3)
M01 =0.5*(e0 - e1 - e2 + e3)
M02 = (j0*jc2).real + (j3*jc1).real
M03 =-(jc0*j2).imag - (jc3*j1).imag
# row 2
M10 = 0.5*(e0 - e1 + e2 - e3)
M11 = 0.5*(e0 + e1 - e2 - e3)
M12 =(j0*jc2).real - (j3*jc1).real
M13 = -(jc0*j2).imag + (jc3*j1).imag
# row 3
M20 = (j0*jc3).real + (j2*jc1).real
M21 = (j0*jc3).real - (j2*jc1).real
M22 = (j0*jc1).real + (j2*jc3).real
M23 = -(jc0*j1).imag + (jc2*j3).imag
# row 4
M30 = (jc0*j3).imag + (jc2*j1).imag
M31 = (jc0*j3).imag - (jc2*j1).imag
M32 = (jc0*j1).imag + (jc2*j3).imag
M33 = (j0*jc1).real - (j2*jc3).real


M = xp.asarray([[M00, M01, M02, M03],
[M10, M11, M12, M13],
[M20, M21, M22, M23],
[M30, M31, M32, M33]])
Expand Down
33 changes: 24 additions & 9 deletions poppy/poppy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,10 +852,10 @@ def _resample_wavefront_pixelscale(self, detector):
_log.debug("Wavefront pixel scale: {:.3f}".format(self.pixelscale.to(detector.pixelscale.unit)))
_log.debug("Desired detector pixel scale: {:.3f}".format(detector.pixelscale))
_log.debug("Wavefront FOV: {} pixels, {:.3f}".format(self.shape,
self.shape[0]*u.pixel*self.pixelscale.to(
self.shape[-2]*u.pixel*self.pixelscale.to(
detector.pixelscale.unit)))
_log.debug("Desired detector FOV: {} pixels, {:.3f}".format(detector.shape,
detector.shape[0]*u.pixel*detector.pixelscale))
detector.shape[-2]*u.pixel*detector.pixelscale))

# Provide 2-pixel margin around image to reduce interpolation errors at edge, but also make
# sure that image is centered properly after it gets cropped down to detector size
Expand Down Expand Up @@ -910,14 +910,14 @@ def interpolator_multidim(x_out, y_out):
#wf_xmin = pixscale * cropped_wf.shape[0]/2
# Note, carefully handle the offset-by-one to be consistent with
# the use of arange above; avoid fencepost error.
wf_xmax = pixscale_in * cropped_wf.shape[0]/2
wf_xmax = pixscale_in * cropped_wf.shape[-2]/2

x,y = xp.ogrid[-wf_xmax:wf_xmax-pixscale_in:cropped_wf.shape[0]*1j,
-wf_xmax:wf_xmax-pixscale_in:cropped_wf.shape[1]*1j]
x,y = xp.ogrid[-wf_xmax:wf_xmax-pixscale_in:cropped_wf.shape[-2]*1j,
-wf_xmax:wf_xmax-pixscale_in:cropped_wf.shape[-1]*1j]

det_xmax = pixscale_out * detector.shape[0]/2
newx,newy = xp.mgrid[-det_xmax:det_xmax-pixscale_out:detector.shape[0]*1j,
-det_xmax:det_xmax-pixscale_out:detector.shape[1]*1j]
det_xmax = pixscale_out * detector.shape[-2]/2
newx,newy = xp.mgrid[-det_xmax:det_xmax-pixscale_out:detector.shape[-2]*1j,
-det_xmax:det_xmax-pixscale_out:detector.shape[-1]*1j]

x0 = x[0,0]
y0 = y[0,0]
Expand All @@ -929,7 +929,22 @@ def interpolator_multidim(x_out, y_out):

coords = xp.array([ivals, jvals])

new_wf = _scipy.ndimage.map_coordinates(cropped_wf, coords, order=detector.interp_order)
def interpolate(arr):
"""
Handle the interpolation for scalar and polarized wavefronts
"""
if xp.ndim(arr) == 2:
return _scipy.ndimage.map_coordinates(arr, coords, order=detector.interp_order)
else:
# for polarized wavefronts, loop over polarization axis/axes and perform the interpolation
pol_shape = arr.shape[:-2]
resampled_arr = xp.empty((*pol_shape, len(newx), len(newy)), dtype=arr.dtype)
for i in np.ndindex(pol_shape):
resampled_arr[i] = _scipy.ndimage.map_coordinates(arr[i], coords, order=detector.interp_order)
return resampled_arr

#new_wf = _scipy.ndimage.map_coordinates(cropped_wf, coords, order=detector.interp_order)
new_wf = interpolate(cropped_wf)

# enforce conservation of energy:
new_wf *= 1. / pixscale_ratio
Expand Down

0 comments on commit 9e924ed

Please sign in to comment.