Skip to content

Commit

Permalink
Masked DFT propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Mar 17, 2024
1 parent ea74fc4 commit cad2faa
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/ref/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Extent
:toctree: generated/

lentil.extent.array_extent
lentil.extent.array_center
lentil.extent.intersect
lentil.extent.intersection_extent
lentil.extent.intersection_shape
Expand Down
84 changes: 76 additions & 8 deletions lentil/extent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Functions for working with extent (rmon, rmax, cmin, cmax) data
# Functions for working with extent (rmin, rmax, cmin, cmax) data

import numpy as np

Expand All @@ -8,16 +8,20 @@ def array_extent(shape, shift, parent_shape=None):
Parameters
----------
shape : (2,) array_like
Array shape
shift : (2,) array_like
Array shift in (r, c)
parent_shape : (2,) array like or None, optional
Enclosing parent shape. If None, the returned extent is relative
to the origin (0,0). If provided, the returned extent is relative
to the upper left corner of the parent shape.
Notes
-----
To use the values returned by ``extent()`` in a slice,
``rmax`` and ``cmax`` should be increased by 1.
"""

if len(shape) < 2:
shape = (1, 1)

Expand All @@ -36,13 +40,32 @@ def array_extent(shape, shift, parent_shape=None):
return rmin, rmax, cmin, cmax


def array_center(extent):
"""Compute the center of an extent
Parameters
----------
a : (4,) array_like
Array extent (rmin, rmax, cmin, cmax)
Returns
-------
tuple
"""

rmin, rmax, cmin, cmax = extent
nrow = rmax - rmin + 1
ncol = cmax - cmin + 1
return rmin + nrow//2, cmin + ncol//2


def intersect(a, b):
"""Return True if two extents intersect, otherwise False
Parameters
----------
a, b : (4,) array like
Two array extents (rmin, rmax, cmin, cmax)
Array extents (rmin, rmax, cmin, cmax)
Returns
-------
Expand All @@ -55,7 +78,18 @@ def intersect(a, b):


def intersection_extent(a, b):
# bounding array indices to be multiplied
"""Compute the extent of two overlapping extents
Parameters
----------
a, b : (4,) array_like
Array extents (rmin, rmax, cmin, cmax)
Returns
-------
tuple
"""

armin, armax, acmin, acmax = a
brmin, brmax, bcmin, bcmax = b

Expand All @@ -66,13 +100,23 @@ def intersection_extent(a, b):


def intersection_shape(a, b):
"""Compute the shape
"""Compute the shape of two overlapping extents. If there is no
overlap, an empty tuple is returned.
Parameters
----------
a, b : (4,) array like
Array extents (rmin, rmax, cmin, cmax)
Returns
-------
tuple
"""

rmin, rmax, cmin, cmax = intersection_extent(a, b)
nr, nc = rmax - rmin + 1, cmax - cmin + 1

if nr < 0 or nc < 0:
if nr <= 0 or nc <= 0:
shape = ()
else:
shape = (nr, nc)
Expand All @@ -81,6 +125,18 @@ def intersection_shape(a, b):


def intersection_slices(a, b):
"""Compute slices of overlapping areas between two overlapping extents
Parameters
----------
a, b : (4,) array like
Array extents (rmin, rmax, cmin, cmax)
Returns
-------
tuples of slices
"""

rmin, rmax, cmin, cmax = intersection_extent(a, b)

armin, armax, acmin, acmax = a
Expand All @@ -95,6 +151,18 @@ def intersection_slices(a, b):


def intersection_shift(a, b):
"""Compute the shift between two overlapping extents
Parameters
----------
a, b : (4,) array like
Array extents (rmin, rmax, cmin, cmax)
Returns
-------
tuple
"""

rmin, rmax, cmin, cmax = intersection_extent(a, b)
nrow = rmax - rmin + 1
ncol = cmax - cmin + 1
Expand Down
62 changes: 54 additions & 8 deletions lentil/propagate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

import lentil
import lentil.extent
from lentil.field import Field
from lentil.wavefront import Wavefront

Expand Down Expand Up @@ -144,7 +145,7 @@ def _has_tilt(wavefront):


def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
oversample=2):
oversample=2, mask=None):
"""Propagate a Wavefront in the far-field using the DFT.
Parameters
Expand Down Expand Up @@ -178,6 +179,16 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
shape_out = shape * oversample
prop_shape_out = prop_shape * oversample

if mask is not None:
mask = np.asarray(mask)
if np.all(mask.shape != shape_out):
raise ValueError(f'shape mismatch: mask shape {mask.shape} != output shape {tuple(shape_out)}')
mask_shape = _mask_shape(mask, threshold=0)
mask_shift = _mask_shift(mask, threshold=0)
out_extent = lentil.extent.array_extent(mask_shape, mask_shift)
else:
out_extent = lentil.extent.array_extent(shape_out, shift=(0,0))

dx = wavefront.pixelscale
du = np.broadcast_to(pixelscale, (2,))
z = wavefront.focal_length
Expand All @@ -192,31 +203,66 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,

for field in data:
# compute the field shift from any embedded tilts. note the return value
# is specified in terms of (r, c)
# is specified as (r, c)
shift = field.shift(z=wavefront.focal_length, wavelength=wavefront.wavelength,
pixelscale=du, oversample=oversample,
indexing='ij')

fix_shift = np.fix(shift)
subpx_shift = shift - fix_shift

if _overlap(prop_shape_out, fix_shift, shape_out):
prop_extent = lentil.extent.array_extent(prop_shape_out, fix_shift)

if lentil.extent.intersect(out_extent, prop_extent):
intersect_shape = lentil.extent.intersection_shape(out_extent, prop_extent)
intersect_shift = lentil.extent.intersection_shift(out_extent, prop_extent)
intersect_extent = lentil.extent.array_extent(intersect_shape, intersect_shift)

# compute additional shift rquired to offset any output clipping that
# may have occurred due to a mask or extending beyond the full output
# shape.
# NOTE: It appears the same functionality is available using
# extent.intersection_shift() but there is an off by one error on only
# one of the shift dimensions for some reason that causes the tests to
# fail
prop_center = lentil.extent.array_center(prop_extent)
intersect_center = lentil.extent.array_center(intersect_extent)
prop_shift = np.array(prop_center) - np.array(intersect_center)

alpha = _dft_alpha(dx=dx, du=du, z=z,
wavelength=wavefront.wavelength,
oversample=oversample)
data = lentil.fourier.dft2(f=field.data, alpha=alpha,
shape=prop_shape_out,
shift=subpx_shift,
shape=intersect_shape,
shift=prop_shift + subpx_shift,
offset=field.offset, unitary=True)
out.data.append(Field(data=data, pixelscale=du/oversample,
offset=fix_shift))
out.data.append(Field(data=data, pixelscale=du / oversample,
offset=intersect_shift))

if not out.data:
out.data.append(Field(data=0))

return out


def _mask_shape(x, threshold=0):
# compute the shape of a masked area inside an array of zeros
rmin, rmax, cmin, cmax = lentil.boundary(x, threshold)
return rmax - rmin + 1, cmax - cmin + 1


def _mask_shift(x, threshold=0):
# compute the shift of a masked area inside an array of zeros
shape_full = x.shape
rc_full, cc_full = shape_full[0]//2, shape_full[1]//2

rmin_extent, rmax_extent, cmin_extent, cmax_extent = lentil.boundary(x, threshold)
shape_extent = (rmax_extent-rmin_extent+1, cmax_extent-cmin_extent+1)
rc_extent, cc_extent = rmin_extent + shape_extent[0]//2, cmin_extent + shape_extent[1]//2

return rc_extent - rc_full, cc_extent - cc_full


def _overlap(field_shape, field_shift, output_shape):
# Return True if there's any overlap between a shifted field and the
# output shape
Expand Down
38 changes: 38 additions & 0 deletions tests/test_propagate_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import lentil

def test_propagate_mask(pupil):
p = pupil(focal_length=10, diameter=1, shape=(256,256), radius=120, coeffs=None)

mask = lentil.rectangle((256,256), 128,128, shift=(20, -30), antialias=False)

w = lentil.Wavefront(650e-9)
w = p.multiply(w)
w = lentil.propagate_dft(w, shape=128, pixelscale=5e-6, oversample=2)
psf = w.intensity

w_mask = lentil.Wavefront(650e-9)
w_mask = p.multiply(w_mask)
w_mask = lentil.propagate_dft(w_mask, shape=128, pixelscale=5e-6, oversample=2, mask=mask)
psf_mask = w_mask.intensity

assert np.allclose(psf_mask, psf*mask)


def test_propagate_mask_tilt_analytic(pupil):
p = pupil(focal_length=10, diameter=1, shape=(256,256), radius=120, coeffs=[0,1e-6, 2e-6])

mask = lentil.rectangle((256,256), 64, 64, shift=(20, -30), antialias=False)

w = lentil.Wavefront(650e-9)
w = p.multiply(w)
w = lentil.propagate_dft(w, shape=128, pixelscale=5e-6, oversample=2)
psf = w.intensity

w_mask = lentil.Wavefront(650e-9)
w_mask = p.multiply(w_mask)
w_mask = lentil.propagate_dft(w_mask, shape=128, pixelscale=5e-6, oversample=2, mask=mask)
psf_mask = w_mask.intensity

assert np.allclose(psf_mask, psf*mask)

0 comments on commit cad2faa

Please sign in to comment.