Skip to content

Commit

Permalink
FFT-based far-field propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Nov 21, 2023
1 parent 970414a commit 29aec68
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 43 deletions.
6 changes: 5 additions & 1 deletion lentil/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
Flip
)

from lentil.propagate import propagate_dft
from lentil.propagate import (
propagate_dft,
propagate_fft,
scratch_shape
)

from lentil import radiometry

Expand Down
5 changes: 0 additions & 5 deletions lentil/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,6 @@ def gaussian2d(size, sigma):
return G/np.sum(G)


def dft_alpha(dx, du, wave, z, oversample):
return ((dx[0]*du[0])/(wave*z*oversample),
(dx[1]*du[0])/(wave*z*oversample))


def boundary_slice(x, threshold=0, pad=(0, 0)):
"""Find bounding row and column indices of data within an array and
return the results as slice objects.
Expand Down
157 changes: 120 additions & 37 deletions lentil/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,62 +7,144 @@

def propagate_fft(wavefront, pixelscale, shape=None, oversample=2,
scratch=None):
"""Propagate a Wavefront in the far-field using the FFT.
Parameters
----------
wavefront : :class:`~lentil.Wavefront`
Wavefront to propagate
pixelscale : float or (2,) float
Physical sampling of output Wavefront. If a single value is supplied,
the output is assumed to be uniformly sampled in both x and y.
shape : int or (2,) tuple of ints or None
Shape of output Wavefront. If None (default), the wavefront shape is
used.
oversample : float, optional
Number of times to oversample the output plane. Default is 2.
scratch : complex ndarray, optional
A pre-allocated array used for padding. Providing a sufficiently
large scratch array can improve broadband propagation performance by
avoiding the repeated allocation of arrays of zeros.
alpha = (wavefront.pixelscale * pixelscale)/(wavefront.wavelength * wavefront.focal_length)
pad_shape = np.round(oversample * np.reciprocal(alpha)).astype(int)
Returns
-------
wavefront : :class:`~lentil.Wavefront`
The propagated Wavefront
# since we can only pad to integer precision, the actual wavelength
# represented by the propagation may be slightly different
wavelength_prop = (pad_shape/oversample * wavefront.pixelscale * pixelscale)/wavefront.focal_length
"""
if _has_tilt(wavefront):
raise NotImplementedError('propagate_fft does not support Wavefronts '
'with fitted tilt. Use propagate_dft instead.')

ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer')
pixelscale = np.broadcast_to(pixelscale, (2,))
fft_shape, prop_wavelength = _fft_shape(wavefront.pixelscale,
pixelscale,
wavefront.focal_length,
wavefront.wavelength,
oversample)

# TODO: verify field_shape is smaller than fft_shape
# TODO: what if field is bigger than fft_shape?
# field_shape = wavefront.shape

if shape is None:
shape_out = pad_shape
shape_out = tuple(fft_shape)
shape = (fft_shape[0]//oversample, fft_shape[1]//oversample)
else:
shape = np.broadcast_to(shape, (2,))
shape_out = (np.array(shape) * oversample).astype(int)

ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer')
shape = tuple(np.broadcast_to(shape, (2,)))
if np.any(shape > fft_shape/oversample):
raise ValueError(f'requested shape {tuple(shape)} is larger in at '
f'least one dimension than maximum propagation '
f'shape {tuple(fft_shape//oversample)}')
else:
shape_out = (shape[0] * oversample, shape[1]*oversample)

out = Wavefront.empty(wavelength=wavelength_prop,
out = Wavefront.empty(wavelength=prop_wavelength,
pixelscale = pixelscale/oversample,
shape = shape_out,
ptype = ptype_out)

if scratch is not None:
if not all(np.asarray(scratch.shape) > fft_shape):
raise ValueError(f'scratch must have shape greater than or '
f'equal to {tuple(fft_shape)}')

# zero out the portion of scratch that we're going to use for the
# propagation and then insert the Wavefront field(s) into scratch
scratch[0:fft_shape[0], 0:fft_shape[1]] = 0
for field in wavefront.data:
scratch[0:fft_shape[0], 0:fft_shape[1]] = lentil.field.insert(field, scratch[0:fft_shape[0], 0:fft_shape[1]])
field =_fft2(scratch[0:fft_shape[0], 0:fft_shape[1]])

# if scratch is not None:
# if scratch.dtype != complex:
# raise TypeError('scratch must be complex')

# if not all(np.asarray(scratch.shape) > pad_shape):
# raise ValueError(f'scratch must have shape >= {pad_shape}')
else:
field = lentil.pad(wavefront.field, fft_shape)
field = _fft2(field)

out.data.append(Field(data=field, pixelscale=pixelscale/oversample))

# field = scratch # field is just a reference to scratch
# field[:] = 0
return out


def scratch_shape(wavelength, dx, du, z, oversample):
"""Compute the scratch shape required for an FFT propagation
# else:
# field = wavefront.field
Parameters
----------
wavelength : float or array_like
Wavelength or list of wavelengths
dx : float or (2,) float
Physical sampling of input Plane. If a single value is supplied,
the input is assumed to be uniformly sampled in both x and y.
du : float or (2,) float
Physical sampling of output Wavefront. If a single value is supplied,
the output is assumed to be uniformly sampled in both x and y.
z : float
Propagation distance
oversample : float
Number of times to oversample the output plane.
# if wavefront.tilt:
# raise ValueError
Returns
-------
shape : tuple
Scratch shape
field = lentil.pad(wavefront.field, pad_shape)
field = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(field), norm='ortho'))
"""
dx = np.broadcast_to(dx, (2,))
du = np.broadcast_to(du, (2,))
fft_shape, _ = _fft_shape(dx, du, z, np.max(wavelength), oversample)
return tuple(fft_shape)

if not (field.shape == shape_out).all():
field = lentil.pad(field, shape_out)

out.data.append(Field(data=field, pixelscale=pixelscale/oversample))

return out
def _dft_alpha(dx, du, wavelength, z, oversample):
return ((dx[0]*du[0])/(wavelength*z*oversample),
(dx[1]*du[1])/(wavelength*z*oversample))


def _fft_shape(dx, du, z, wavelength, oversample):
# Compute pad shape to satisfy requested sampling. Propagation wavelength
# is recomputed to account for integer padding of input plane
alpha = _dft_alpha(dx, du, z, wavelength, oversample)
fft_shape = np.round(np.reciprocal(alpha)).astype(int)
prop_wavelength = np.min((fft_shape/oversample * dx * du)/z)
return fft_shape, prop_wavelength


def _fft2(x):
return np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(x), norm='ortho'))



def _has_tilt(wavefront):
# Return True if and Wavefront Field has nonempty tilt
for field in wavefront.data:
if field.tilt:
return True
return False


def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
oversample=2):
"""Propagate a Wavefront using Fraunhofer diffraction.
"""Propagate a Wavefront in the far-field using the DFT.
Parameters
----------
Expand All @@ -79,13 +161,13 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
``prop_shape = prop``. If ``prop_shape != prop``, the propagation
result is placed in the appropriate location in the output plane.
``prop_shape`` should not be larger than ``prop``.
oversample : int, optional
oversample : float, optional
Number of times to oversample the output plane. Default is 2.
Returns
-------
wavefront : :class:`~lentil.Wavefront`
The orioagated Wavefront
The propagated Wavefront
"""

ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer')
Expand Down Expand Up @@ -117,9 +199,9 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
subpx_shift = shift - fix_shift

if _overlap(prop_shape_out, fix_shift, shape_out):
alpha = lentil.helper.dft_alpha(dx=dx, du=du, z=z,
wave=wavefront.wavelength,
oversample=oversample)
alpha = _dft_alpha(dx=dx, du=du, z=z,
wavelength=wavefront.wavelength,
oversample=oversample)
data = lentil.fourier.dft2(f=field.data, alpha=alpha,
shape=prop_shape_out,
shift=subpx_shift,
Expand All @@ -132,6 +214,7 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,

return out


def _overlap(field_shape, field_shift, output_shape):
# Return True if there's any overlap between a shifted field and the
# output shape
Expand Down

0 comments on commit 29aec68

Please sign in to comment.