From e3768bab5d86343c55366069ac1afb8d30cc7191 Mon Sep 17 00:00:00 2001 From: Andy Kee Date: Tue, 17 Oct 2023 18:29:34 -0700 Subject: [PATCH] Wavefront.propagate_image -> propagate.propagate.dft --- lentil/__init__.py | 2 + lentil/propagate.py | 115 ++++++++++++++++++++++++++++++ lentil/wavefront.py | 75 ------------------- tests/test_propagate.py | 28 ++++---- tests/test_propagate_segmented.py | 6 +- tests/test_wavefront.py | 4 +- 6 files changed, 134 insertions(+), 96 deletions(-) create mode 100644 lentil/propagate.py diff --git a/lentil/__init__.py b/lentil/__init__.py index 46a60cd..51baee8 100644 --- a/lentil/__init__.py +++ b/lentil/__init__.py @@ -24,6 +24,8 @@ Flip ) +from lentil.propagate import propagate_dft + from lentil.ptype import ptype none = ptype('none') diff --git a/lentil/propagate.py b/lentil/propagate.py new file mode 100644 index 0000000..9a1b316 --- /dev/null +++ b/lentil/propagate.py @@ -0,0 +1,115 @@ +import numpy as np + +import lentil +from lentil.field import Field +from lentil.wavefront import Wavefront + +def propagate_dft(wavefront, shape, pixelscale, prop_shape=None, + oversample=2, inplace=True): + """Propagate a Wavefront using Fraunhofer diffraction. + + Parameters + ---------- + shape : int or (2,) tuple of ints + Shape of output Wavefront. + pixelscale : float or (2,) float + Physical sampling of output Wavefront. If a single value is supplied, + the output is assumed to be uniformly sampled in both x and y. + prop_shape : int or (2,) tuple of ints, optional + Shape of propagation output. If None (default), + ``prop_shape = prop``. If ``prop_shape != prop``, the propagation + result is placed in the appropriate location in the output plane. + ``prop_shape`` should not be larger than ``prop``. + oversample : int, optional + Number of times to oversample the output plane. Default is 2. + inplace : bool, optional + If True (default) the Wavefront is propagated in-place, otherwise + a copy is created and propagated. + + Returns + ------- + wavefront : :class:`~lentil.Wavefront` + A Wavefront propagated to the specified image plane + """ + + ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer') + + shape = np.broadcast_to(shape, (2,)) + prop_shape = shape if prop_shape is None else np.broadcast_to(prop_shape, (2,)) + shape_out = shape * oversample + prop_shape_out = prop_shape * oversample + + dx = wavefront.pixelscale + du = np.broadcast_to(pixelscale, (2,)) + z = wavefront.focal_length + + data = wavefront.data + + if inplace: + out = wavefront + out.data = [] + out.pixelscale = du/oversample + out.shape = shape_out + out.ptype = ptype_out + else: + out = Wavefront.empty(wavelength=wavefront.wavelength, + pixelscale = du/oversample, + shape = shape_out, + ptype = ptype_out) + + for field in data: + # compute the field shift from any embedded tilts. note the return value + # is specified in terms of (r, c) + shift = field.shift(z=wavefront.focal_length, wavelength=wavefront.wavelength, + pixelscale=du, oversample=oversample, + indexing='ij') + + fix_shift = np.fix(shift) + subpx_shift = shift - fix_shift + + if _overlap(prop_shape_out, fix_shift, shape_out): + alpha = lentil.helper.dft_alpha(dx=dx, du=du, z=z, + wave=wavefront.wavelength, + oversample=oversample) + data = lentil.fourier.dft2(f=field.data, alpha=alpha, + npix=prop_shape_out, + shift=subpx_shift, + offset=field.offset, unitary=True) + out.data.append(Field(data=data, pixelscale=du/oversample, + offset=fix_shift)) + + if not out.data: + out.data.append(Field(data=0)) + + return out + +def _overlap(field_shape, field_shift, output_shape): + # Return True if there's any overlap between a shifted field and the + # output shape + output_shape = np.asarray(output_shape) + field_shape = np.asarray(field_shape) + field_shift = np.asarray(field_shift) + + # Output coordinates of the upper left corner of the shifted data array + field_shifted_ul = (output_shape / 2) - (field_shape / 2) + field_shift + + if field_shifted_ul[0] > output_shape[0]: + return False + if field_shifted_ul[0] + field_shape[0] < 0: + return False + if field_shifted_ul[1] > output_shape[1]: + return False + if field_shifted_ul[1] + field_shape[1] < 0: + return False + return True + +def _propagate_ptype(ptype, method='fraunhofer'): + if method == 'fraunhofer': + if ptype not in (lentil.pupil, lentil.image): + raise TypeError("Wavefront must have ptype 'pupil' "\ + "or 'image'") + + if ptype == lentil.pupil: + return lentil.image + else: + return lentil.pupil \ No newline at end of file diff --git a/lentil/wavefront.py b/lentil/wavefront.py index 0718bb6..3bef737 100644 --- a/lentil/wavefront.py +++ b/lentil/wavefront.py @@ -118,81 +118,6 @@ def insert(self, out, weight=1): out = lentil.field.insert(field, out, intensity=True, weight=weight) return out - def propagate_image(self, pixelscale, npix, npix_prop=None, oversample=2, - inplace=True): - """Propagate the Wavefront from a Pupil to an Image plane using - Fraunhofer diffraction. - - Parameters - ---------- - pixelscale : float or (2,) float - Physical sampling of output (image) plane. If a single value is supplied, - the output is assumed to be uniformly sampled in both x and y. - npix : int or (2,) tuple of ints - Shape of output plane. - npix_prop : int or (2,) tuple of ints, optional - Shape of propagation output plane. If None (default), - ``npix_prop = npix``. If ``npix_prop != npix``, the propagation - result is placed in the appropriate location in the output plane. - npix_prop cannot be larger than npix. - oversample : int, optional - Number of times to oversample the output plane. Default is 2. - inplace : bool, optional - If True (default) the wavefront is propagated in-place, otherwise - a copy is created and propagated. - - Returns - ------- - wavefront : :class:`~lentil.Wavefront` - A Wavefront propagated to the specified image plane - - """ - if self.ptype != lentil.pupil: - raise ValueError("Wavefront must have planetype 'pupil'") - - npix = np.asarray(lentil.sanitize_shape(npix)) - npix_prop = npix if npix_prop is None else np.asarray(lentil.sanitize_shape(npix_prop)) - prop_shape = npix_prop * oversample - - dx = self.pixelscale - du = np.asarray(lentil.sanitize_shape(pixelscale)) - z = self.focal_length - data = self.data - - if inplace: - out = self - out.data = [] - out.pixelscale = du / oversample - out.shape = npix * oversample - out.focal_length = np.inf - out.ptype = lentil.image - else: - out = Wavefront.empty(wavelength=self.wavelength, - pixelscale=du/oversample, - shape=npix*oversample, - ptype=lentil.image) - - for field in data: - # compute the field shift from any embedded tilts. note the return value - # is specified in terms of (r, c) - shift = field.shift(z=z, wavelength=self.wavelength, - pixelscale=du, oversample=oversample, - indexing='ij') - - fix_shift = np.fix(shift) - dft_shift = shift - fix_shift - - if _overlap(prop_shape, fix_shift, out.shape): - alpha = lentil.helper.dft_alpha(dx=dx, du=du, - wave=self.wavelength, z=z, - oversample=oversample) - data = lentil.fourier.dft2(f=field.data, alpha=alpha, - npix=prop_shape, shift=dft_shift, - offset=field.offset, unitary=True) - out.data.append(Field(data=data, pixelscale=du/oversample, - offset=fix_shift)) - return out - def _overlap(field_shape, field_shift, output_shape): # Return True if there's any overlap between a shifted field and the diff --git a/tests/test_propagate.py b/tests/test_propagate.py index 7417210..18e6bda 100644 --- a/tests/test_propagate.py +++ b/tests/test_propagate.py @@ -146,7 +146,7 @@ def test_amplitude_normalize_power(): p = TiltPupil(npix=256) w = lentil.Wavefront(wavelength=650e-9) w *= p - w = w.propagate_image(pixelscale=5e-6, npix=(64,64), oversample=1) + w = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=1) psf = w.intensity assert (np.sum(psf) <= 1) and (np.sum(psf) >= 0.95) @@ -155,11 +155,8 @@ def test_amplitude_normalize_power_oversample(): p = TiltPupil(npix=256) w = lentil.Wavefront(wavelength=650e-9) w *= p - w = w.propagate_image(pixelscale=5e-6, npix=(64,64), oversample=2) + w = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2) psf = w.intensity - - #planes = [TiltPupil(npix=256), BasicDetector()] - #psf = lentil.propagate(planes, 650e-9, npix=(64, 64), oversample=2, rebin=True) assert (np.sum(psf) <= 1) and (np.sum(psf) >= 0.95) @@ -176,26 +173,25 @@ def test_propagate_airy(): w = lentil.Wavefront(wavelength=650e-9) w *= p - w = w.propagate_image(pixelscale=5e-6, npix=511, oversample=1) + w = lentil.propagate_dft(w, shape=511, pixelscale=5e-6, oversample=1) psf = w.intensity psf = psf/np.max(psf) assert np.all(np.isclose(psf, psf_airy, atol=1e-3)) def test_propagate_tilt_angle(): - #planes = [TiltPupil(npix=256), BasicDetector()] p = TiltPupil(npix=256) w_phase = lentil.Wavefront(650e-9) w_phase = p.multiply(w_phase) - w_phase = w_phase.propagate_image(pixelscale=5e-6, npix=128, oversample=2) + w_phase = lentil.propagate_dft(w_phase, shape=128, pixelscale=5e-6, oversample=2) psf_phase = w_phase.intensity - p.fit_tilt() + p.fit_tilt(inplace=True) w_angle = lentil.Wavefront(650e-9) w_angle = w_angle * p - w_angle = w_angle.propagate_image(pixelscale=5e-6, npix=128, oversample=2) + w_angle = lentil.propagate_dft(w_angle, shape=128, pixelscale=5e-6, oversample=2) psf_angle = w_angle.intensity # threshold the PSFs so that the centroiding is consistent @@ -214,7 +210,7 @@ def test_propagate_tilt_phase_analytic(): w = lentil.Wavefront(650e-9) w = pupil.multiply(w) - w = w.propagate_image(pixelscale=pixelscale, npix=npix, oversample=oversample) + w = lentil.propagate_dft(w, shape=npix, pixelscale=pixelscale, oversample=oversample) psf = w.intensity psf = psf/np.max(psf) @@ -244,7 +240,7 @@ def test_propagate_tilt_angle_analytic(): pupil.fit_tilt() w = lentil.Wavefront(650e-9) w = w * pupil - w = w.propagate_image(pixelscale=pixelscale, npix=npix, oversample=oversample) + w = lentil.propagate_dft(w, shape=npix, pixelscale=pixelscale, oversample=oversample) psf = w.intensity psf = psf/np.max(psf) @@ -274,12 +270,12 @@ def test_propagate_resample(): p = lentil.Pupil(focal_length=10, pixelscale=1 / 240, amplitude=amp, phase=opd) w = lentil.Wavefront(650e-9) w *= p - wi = w.propagate_image(pixelscale=5e-6, npix=64, oversample=10) + wi = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=10) p2 = p.rescale(scale=3, inplace=False) w2 = lentil.Wavefront(650e-9) w2 *= p2 - w2i = w2.propagate_image(pixelscale=5e-6, npix=64, oversample=10) + w2i = lentil.propagate_dft(w2, shape=(64,64), pixelscale=5e-6, oversample=10) # compute cross correlation between wi and w2i xc = np.fft.ifftshift(np.conj(np.fft.fft2(wi.intensity)) * np.fft.fft2(w2i.intensity)) @@ -294,8 +290,8 @@ def test_propagate_image_inplace(): amplitude=lentil.circle((256, 256), 120)) w = lentil.Wavefront(650e-9) w *= p - w_copy = w.propagate_image(pixelscale=5e-6, npix=64, oversample=2, inplace=False) - w_inplace = w.propagate_image(pixelscale=5e-6, npix=64, oversample=2, inplace=True) + w_copy = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2, inplace=False) + w_inplace = lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6, oversample=2, inplace=True) assert w_copy is not w assert w_inplace is w diff --git a/tests/test_propagate_segmented.py b/tests/test_propagate_segmented.py index ca2968d..f2d1b71 100644 --- a/tests/test_propagate_segmented.py +++ b/tests/test_propagate_segmented.py @@ -30,13 +30,13 @@ def test_propagate_tilt_angle_mono(): w1 = lentil.Wavefront(wavelength=650e-9) w1 *= p - w1 = w1.propagate_image(pixelscale=5e-6, npix=128) + w1 = lentil.propagate_dft(w1, shape=(128,128), pixelscale=5e-6) psf_phase = w1.intensity - p.fit_tilt() + p.fit_tilt(inplace=True) w2 = lentil.Wavefront(wavelength=650e-9) w2 *= p - w2 = w2.propagate_image(pixelscale=5e-6, npix=128) + w2 = lentil.propagate_dft(w2, shape=(128,128), pixelscale=5e-6) psf_angle = w2.intensity # Normalize and threshold the PSFs so that the centroiding is consistent diff --git a/tests/test_wavefront.py b/tests/test_wavefront.py index 659b133..216c9b2 100644 --- a/tests/test_wavefront.py +++ b/tests/test_wavefront.py @@ -15,8 +15,8 @@ def test_wavefront_rmul(): def test_wavefront_propagate_image_non_pupil(): w = lentil.Wavefront(wavelength=500e-9) - with pytest.raises(ValueError): - w.propagate_image(pixelscale=5e-6, npix=64) + with pytest.raises(TypeError): + lentil.propagate_dft(w, shape=(64,64), pixelscale=5e-6) @pytest.mark.parametrize('field_shape, field_shift, output_shape', [