Skip to content

Commit

Permalink
Wavefront.propagate_image -> propagate.propagate.dft
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Oct 18, 2023
1 parent a27c769 commit e3768ba
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 96 deletions.
2 changes: 2 additions & 0 deletions lentil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
Flip
)

from lentil.propagate import propagate_dft

from lentil.ptype import ptype

none = ptype('none')
Expand Down
115 changes: 115 additions & 0 deletions lentil/propagate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np

import lentil
from lentil.field import Field
from lentil.wavefront import Wavefront

def propagate_dft(wavefront, shape, pixelscale, prop_shape=None,
oversample=2, inplace=True):
"""Propagate a Wavefront using Fraunhofer diffraction.
Parameters
----------
shape : int or (2,) tuple of ints
Shape of output Wavefront.
pixelscale : float or (2,) float
Physical sampling of output Wavefront. If a single value is supplied,
the output is assumed to be uniformly sampled in both x and y.
prop_shape : int or (2,) tuple of ints, optional
Shape of propagation output. If None (default),
``prop_shape = prop``. If ``prop_shape != prop``, the propagation
result is placed in the appropriate location in the output plane.
``prop_shape`` should not be larger than ``prop``.
oversample : int, optional
Number of times to oversample the output plane. Default is 2.
inplace : bool, optional
If True (default) the Wavefront is propagated in-place, otherwise
a copy is created and propagated.
Returns
-------
wavefront : :class:`~lentil.Wavefront`
A Wavefront propagated to the specified image plane
"""

ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer')

shape = np.broadcast_to(shape, (2,))
prop_shape = shape if prop_shape is None else np.broadcast_to(prop_shape, (2,))
shape_out = shape * oversample
prop_shape_out = prop_shape * oversample

dx = wavefront.pixelscale
du = np.broadcast_to(pixelscale, (2,))
z = wavefront.focal_length

data = wavefront.data

if inplace:
out = wavefront
out.data = []
out.pixelscale = du/oversample
out.shape = shape_out
out.ptype = ptype_out
else:
out = Wavefront.empty(wavelength=wavefront.wavelength,
pixelscale = du/oversample,
shape = shape_out,
ptype = ptype_out)

for field in data:
# compute the field shift from any embedded tilts. note the return value
# is specified in terms of (r, c)
shift = field.shift(z=wavefront.focal_length, wavelength=wavefront.wavelength,
pixelscale=du, oversample=oversample,
indexing='ij')

fix_shift = np.fix(shift)
subpx_shift = shift - fix_shift

if _overlap(prop_shape_out, fix_shift, shape_out):
alpha = lentil.helper.dft_alpha(dx=dx, du=du, z=z,
wave=wavefront.wavelength,
oversample=oversample)
data = lentil.fourier.dft2(f=field.data, alpha=alpha,
npix=prop_shape_out,
shift=subpx_shift,
offset=field.offset, unitary=True)
out.data.append(Field(data=data, pixelscale=du/oversample,
offset=fix_shift))

if not out.data:
out.data.append(Field(data=0))

return out

def _overlap(field_shape, field_shift, output_shape):
# Return True if there's any overlap between a shifted field and the
# output shape
output_shape = np.asarray(output_shape)
field_shape = np.asarray(field_shape)
field_shift = np.asarray(field_shift)

# Output coordinates of the upper left corner of the shifted data array
field_shifted_ul = (output_shape / 2) - (field_shape / 2) + field_shift

if field_shifted_ul[0] > output_shape[0]:
return False
if field_shifted_ul[0] + field_shape[0] < 0:
return False
if field_shifted_ul[1] > output_shape[1]:
return False
if field_shifted_ul[1] + field_shape[1] < 0:
return False
return True

def _propagate_ptype(ptype, method='fraunhofer'):
if method == 'fraunhofer':
if ptype not in (lentil.pupil, lentil.image):
raise TypeError("Wavefront must have ptype 'pupil' "\
"or 'image'")

if ptype == lentil.pupil:
return lentil.image
else:
return lentil.pupil
75 changes: 0 additions & 75 deletions lentil/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,81 +118,6 @@ def insert(self, out, weight=1):
out = lentil.field.insert(field, out, intensity=True, weight=weight)
return out

def propagate_image(self, pixelscale, npix, npix_prop=None, oversample=2,
inplace=True):
"""Propagate the Wavefront from a Pupil to an Image plane using
Fraunhofer diffraction.
Parameters
----------
pixelscale : float or (2,) float
Physical sampling of output (image) plane. If a single value is supplied,
the output is assumed to be uniformly sampled in both x and y.
npix : int or (2,) tuple of ints
Shape of output plane.
npix_prop : int or (2,) tuple of ints, optional
Shape of propagation output plane. If None (default),
``npix_prop = npix``. If ``npix_prop != npix``, the propagation
result is placed in the appropriate location in the output plane.
npix_prop cannot be larger than npix.
oversample : int, optional
Number of times to oversample the output plane. Default is 2.
inplace : bool, optional
If True (default) the wavefront is propagated in-place, otherwise
a copy is created and propagated.
Returns
-------
wavefront : :class:`~lentil.Wavefront`
A Wavefront propagated to the specified image plane
"""
if self.ptype != lentil.pupil:
raise ValueError("Wavefront must have planetype 'pupil'")

npix = np.asarray(lentil.sanitize_shape(npix))
npix_prop = npix if npix_prop is None else np.asarray(lentil.sanitize_shape(npix_prop))
prop_shape = npix_prop * oversample

dx = self.pixelscale
du = np.asarray(lentil.sanitize_shape(pixelscale))
z = self.focal_length
data = self.data

if inplace:
out = self
out.data = []
out.pixelscale = du / oversample
out.shape = npix * oversample
out.focal_length = np.inf
out.ptype = lentil.image
else:
out = Wavefront.empty(wavelength=self.wavelength,
pixelscale=du/oversample,
shape=npix*oversample,
ptype=lentil.image)

for field in data:
# compute the field shift from any embedded tilts. note the return value
# is specified in terms of (r, c)
shift = field.shift(z=z, wavelength=self.wavelength,
pixelscale=du, oversample=oversample,
indexing='ij')

fix_shift = np.fix(shift)
dft_shift = shift - fix_shift

if _overlap(prop_shape, fix_shift, out.shape):
alpha = lentil.helper.dft_alpha(dx=dx, du=du,
wave=self.wavelength, z=z,
oversample=oversample)
data = lentil.fourier.dft2(f=field.data, alpha=alpha,
npix=prop_shape, shift=dft_shift,
offset=field.offset, unitary=True)
out.data.append(Field(data=data, pixelscale=du/oversample,
offset=fix_shift))
return out


def _overlap(field_shape, field_shift, output_shape):
# Return True if there's any overlap between a shifted field and the
Expand Down
28 changes: 12 additions & 16 deletions tests/test_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_amplitude_normalize_power():
p = TiltPupil(npix=256)
w = lentil.Wavefront(wavelength=650e-9)
w *= p
w = w.propagate_image(pixelscale=5e-6, npix=(64,64), oversample=1)
w = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=1)
psf = w.intensity
assert (np.sum(psf) <= 1) and (np.sum(psf) >= 0.95)

Expand All @@ -155,11 +155,8 @@ def test_amplitude_normalize_power_oversample():
p = TiltPupil(npix=256)
w = lentil.Wavefront(wavelength=650e-9)
w *= p
w = w.propagate_image(pixelscale=5e-6, npix=(64,64), oversample=2)
w = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2)
psf = w.intensity

#planes = [TiltPupil(npix=256), BasicDetector()]
#psf = lentil.propagate(planes, 650e-9, npix=(64, 64), oversample=2, rebin=True)
assert (np.sum(psf) <= 1) and (np.sum(psf) >= 0.95)


Expand All @@ -176,26 +173,25 @@ def test_propagate_airy():

w = lentil.Wavefront(wavelength=650e-9)
w *= p
w = w.propagate_image(pixelscale=5e-6, npix=511, oversample=1)
w = lentil.propagate_dft(w, shape=511, pixelscale=5e-6, oversample=1)
psf = w.intensity
psf = psf/np.max(psf)
assert np.all(np.isclose(psf, psf_airy, atol=1e-3))


def test_propagate_tilt_angle():
#planes = [TiltPupil(npix=256), BasicDetector()]

p = TiltPupil(npix=256)

w_phase = lentil.Wavefront(650e-9)
w_phase = p.multiply(w_phase)
w_phase = w_phase.propagate_image(pixelscale=5e-6, npix=128, oversample=2)
w_phase = lentil.propagate_dft(w_phase, shape=128, pixelscale=5e-6, oversample=2)
psf_phase = w_phase.intensity

p.fit_tilt()
p.fit_tilt(inplace=True)
w_angle = lentil.Wavefront(650e-9)
w_angle = w_angle * p
w_angle = w_angle.propagate_image(pixelscale=5e-6, npix=128, oversample=2)
w_angle = lentil.propagate_dft(w_angle, shape=128, pixelscale=5e-6, oversample=2)
psf_angle = w_angle.intensity

# threshold the PSFs so that the centroiding is consistent
Expand All @@ -214,7 +210,7 @@ def test_propagate_tilt_phase_analytic():

w = lentil.Wavefront(650e-9)
w = pupil.multiply(w)
w = w.propagate_image(pixelscale=pixelscale, npix=npix, oversample=oversample)
w = lentil.propagate_dft(w, shape=npix, pixelscale=pixelscale, oversample=oversample)
psf = w.intensity

psf = psf/np.max(psf)
Expand Down Expand Up @@ -244,7 +240,7 @@ def test_propagate_tilt_angle_analytic():
pupil.fit_tilt()
w = lentil.Wavefront(650e-9)
w = w * pupil
w = w.propagate_image(pixelscale=pixelscale, npix=npix, oversample=oversample)
w = lentil.propagate_dft(w, shape=npix, pixelscale=pixelscale, oversample=oversample)
psf = w.intensity

psf = psf/np.max(psf)
Expand Down Expand Up @@ -274,12 +270,12 @@ def test_propagate_resample():
p = lentil.Pupil(focal_length=10, pixelscale=1 / 240, amplitude=amp, phase=opd)
w = lentil.Wavefront(650e-9)
w *= p
wi = w.propagate_image(pixelscale=5e-6, npix=64, oversample=10)
wi = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=10)

p2 = p.rescale(scale=3, inplace=False)
w2 = lentil.Wavefront(650e-9)
w2 *= p2
w2i = w2.propagate_image(pixelscale=5e-6, npix=64, oversample=10)
w2i = lentil.propagate_dft(w2, shape=(64,64), pixelscale=5e-6, oversample=10)

# compute cross correlation between wi and w2i
xc = np.fft.ifftshift(np.conj(np.fft.fft2(wi.intensity)) * np.fft.fft2(w2i.intensity))
Expand All @@ -294,8 +290,8 @@ def test_propagate_image_inplace():
amplitude=lentil.circle((256, 256), 120))
w = lentil.Wavefront(650e-9)
w *= p
w_copy = w.propagate_image(pixelscale=5e-6, npix=64, oversample=2, inplace=False)
w_inplace = w.propagate_image(pixelscale=5e-6, npix=64, oversample=2, inplace=True)
w_copy = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2, inplace=False)
w_inplace = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2, inplace=True)

assert w_copy is not w
assert w_inplace is w
6 changes: 3 additions & 3 deletions tests/test_propagate_segmented.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def test_propagate_tilt_angle_mono():

w1 = lentil.Wavefront(wavelength=650e-9)
w1 *= p
w1 = w1.propagate_image(pixelscale=5e-6, npix=128)
w1 = lentil.propagate_dft(w1, shape=(128,128), pixelscale=5e-6)
psf_phase = w1.intensity

p.fit_tilt()
p.fit_tilt(inplace=True)
w2 = lentil.Wavefront(wavelength=650e-9)
w2 *= p
w2 = w2.propagate_image(pixelscale=5e-6, npix=128)
w2 = lentil.propagate_dft(w2, shape=(128,128), pixelscale=5e-6)
psf_angle = w2.intensity

# Normalize and threshold the PSFs so that the centroiding is consistent
Expand Down
4 changes: 2 additions & 2 deletions tests/test_wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def test_wavefront_rmul():

def test_wavefront_propagate_image_non_pupil():
w = lentil.Wavefront(wavelength=500e-9)
with pytest.raises(ValueError):
w.propagate_image(pixelscale=5e-6, npix=64)
with pytest.raises(TypeError):
lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6)


@pytest.mark.parametrize('field_shape, field_shift, output_shape', [
Expand Down

0 comments on commit e3768ba

Please sign in to comment.