diff --git a/lentil/__init__.py b/lentil/__init__.py index 528fd83..c35ef09 100644 --- a/lentil/__init__.py +++ b/lentil/__init__.py @@ -32,7 +32,11 @@ Flip ) -from lentil.propagate import propagate_dft +from lentil.propagate import ( + propagate_dft, + propagate_fft, + scratch_shape +) from lentil import radiometry diff --git a/lentil/helper.py b/lentil/helper.py index e84ad48..ede86c6 100644 --- a/lentil/helper.py +++ b/lentil/helper.py @@ -23,11 +23,6 @@ def gaussian2d(size, sigma): return G/np.sum(G) -def dft_alpha(dx, du, wave, z, oversample): - return ((dx[0]*du[0])/(wave*z*oversample), - (dx[1]*du[0])/(wave*z*oversample)) - - def boundary_slice(x, threshold=0, pad=(0, 0)): """Find bounding row and column indices of data within an array and return the results as slice objects. diff --git a/lentil/propagate.py b/lentil/propagate.py index eecbd4c..cae72c4 100644 --- a/lentil/propagate.py +++ b/lentil/propagate.py @@ -7,62 +7,144 @@ def propagate_fft(wavefront, pixelscale, shape=None, oversample=2, scratch=None): + """Propagate a Wavefront in the far-field using the FFT. + + Parameters + ---------- + wavefront : :class:`~lentil.Wavefront` + Wavefront to propagate + pixelscale : float or (2,) float + Physical sampling of output Wavefront. If a single value is supplied, + the output is assumed to be uniformly sampled in both x and y. + shape : int or (2,) tuple of ints or None + Shape of output Wavefront. If None (default), the wavefront shape is + used. + oversample : float, optional + Number of times to oversample the output plane. Default is 2. + scratch : complex ndarray, optional + A pre-allocated array used for padding. Providing a sufficiently + large scratch array can improve broadband propagation performance by + avoiding the repeated allocation of arrays of zeros. - alpha = (wavefront.pixelscale * pixelscale)/(wavefront.wavelength * wavefront.focal_length) - pad_shape = np.round(oversample * np.reciprocal(alpha)).astype(int) + Returns + ------- + wavefront : :class:`~lentil.Wavefront` + The propagated Wavefront - # since we can only pad to integer precision, the actual wavelength - # represented by the propagation may be slightly different - wavelength_prop = (pad_shape/oversample * wavefront.pixelscale * pixelscale)/wavefront.focal_length + """ + if _has_tilt(wavefront): + raise NotImplementedError('propagate_fft does not support Wavefronts ' + 'with fitted tilt. Use propagate_dft instead.') + + ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer') + pixelscale = np.broadcast_to(pixelscale, (2,)) + fft_shape, prop_wavelength = _fft_shape(wavefront.pixelscale, + pixelscale, + wavefront.focal_length, + wavefront.wavelength, + oversample) + + # TODO: verify field_shape is smaller than fft_shape + # TODO: what if field is bigger than fft_shape? + # field_shape = wavefront.shape if shape is None: - shape_out = pad_shape + shape_out = tuple(fft_shape) + shape = (fft_shape[0]//oversample, fft_shape[1]//oversample) else: - shape = np.broadcast_to(shape, (2,)) - shape_out = (np.array(shape) * oversample).astype(int) - - ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer') + shape = tuple(np.broadcast_to(shape, (2,))) + if np.any(shape > fft_shape/oversample): + raise ValueError(f'requested shape {tuple(shape)} is larger in at ' + f'least one dimension than maximum propagation ' + f'shape {tuple(fft_shape//oversample)}') + else: + shape_out = (shape[0] * oversample, shape[1]*oversample) - out = Wavefront.empty(wavelength=wavelength_prop, + out = Wavefront.empty(wavelength=prop_wavelength, pixelscale = pixelscale/oversample, shape = shape_out, ptype = ptype_out) + + if scratch is not None: + if not all(np.asarray(scratch.shape) > fft_shape): + raise ValueError(f'scratch must have shape greater than or ' + f'equal to {tuple(fft_shape)}') + + # zero out the portion of scratch that we're going to use for the + # propagation and then insert the Wavefront field(s) into scratch + scratch[0:fft_shape[0], 0:fft_shape[1]] = 0 + for field in wavefront.data: + scratch[0:fft_shape[0], 0:fft_shape[1]] = lentil.field.insert(field, scratch[0:fft_shape[0], 0:fft_shape[1]]) + field =_fft2(scratch[0:fft_shape[0], 0:fft_shape[1]]) - # if scratch is not None: - # if scratch.dtype != complex: - # raise TypeError('scratch must be complex') - - # if not all(np.asarray(scratch.shape) > pad_shape): - # raise ValueError(f'scratch must have shape >= {pad_shape}') + else: + field = lentil.pad(wavefront.field, fft_shape) + field = _fft2(field) + + out.data.append(Field(data=field, pixelscale=pixelscale/oversample)) - # field = scratch # field is just a reference to scratch - # field[:] = 0 + return out + +def scratch_shape(wavelength, dx, du, z, oversample): + """Compute the scratch shape required for an FFT propagation - # else: - # field = wavefront.field + Parameters + ---------- + wavelength : float or array_like + Wavelength or list of wavelengths + dx : float or (2,) float + Physical sampling of input Plane. If a single value is supplied, + the input is assumed to be uniformly sampled in both x and y. + du : float or (2,) float + Physical sampling of output Wavefront. If a single value is supplied, + the output is assumed to be uniformly sampled in both x and y. + z : float + Propagation distance + oversample : float + Number of times to oversample the output plane. - # if wavefront.tilt: - # raise ValueError + Returns + ------- + shape : tuple + Scratch shape - field = lentil.pad(wavefront.field, pad_shape) - field = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(field), norm='ortho')) + """ + dx = np.broadcast_to(dx, (2,)) + du = np.broadcast_to(du, (2,)) + fft_shape, _ = _fft_shape(dx, du, z, np.max(wavelength), oversample) + return tuple(fft_shape) - if not (field.shape == shape_out).all(): - field = lentil.pad(field, shape_out) - - out.data.append(Field(data=field, pixelscale=pixelscale/oversample)) - return out +def _dft_alpha(dx, du, wavelength, z, oversample): + return ((dx[0]*du[0])/(wavelength*z*oversample), + (dx[1]*du[1])/(wavelength*z*oversample)) +def _fft_shape(dx, du, z, wavelength, oversample): + # Compute pad shape to satisfy requested sampling. Propagation wavelength + # is recomputed to account for integer padding of input plane + alpha = _dft_alpha(dx, du, z, wavelength, oversample) + fft_shape = np.round(np.reciprocal(alpha)).astype(int) + prop_wavelength = np.min((fft_shape/oversample * dx * du)/z) + return fft_shape, prop_wavelength + + +def _fft2(x): + return np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(x), norm='ortho')) - + +def _has_tilt(wavefront): + # Return True if and Wavefront Field has nonempty tilt + for field in wavefront.data: + if field.tilt: + return True + return False def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None, oversample=2): - """Propagate a Wavefront using Fraunhofer diffraction. + """Propagate a Wavefront in the far-field using the DFT. Parameters ---------- @@ -79,13 +161,13 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None, ``prop_shape = prop``. If ``prop_shape != prop``, the propagation result is placed in the appropriate location in the output plane. ``prop_shape`` should not be larger than ``prop``. - oversample : int, optional + oversample : float, optional Number of times to oversample the output plane. Default is 2. Returns ------- wavefront : :class:`~lentil.Wavefront` - The orioagated Wavefront + The propagated Wavefront """ ptype_out = _propagate_ptype(wavefront.ptype, method='fraunhofer') @@ -117,9 +199,9 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None, subpx_shift = shift - fix_shift if _overlap(prop_shape_out, fix_shift, shape_out): - alpha = lentil.helper.dft_alpha(dx=dx, du=du, z=z, - wave=wavefront.wavelength, - oversample=oversample) + alpha = _dft_alpha(dx=dx, du=du, z=z, + wavelength=wavefront.wavelength, + oversample=oversample) data = lentil.fourier.dft2(f=field.data, alpha=alpha, shape=prop_shape_out, shift=subpx_shift, @@ -132,6 +214,7 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None, return out + def _overlap(field_shape, field_shift, output_shape): # Return True if there's any overlap between a shifted field and the # output shape