Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
andykee committed Nov 18, 2023
1 parent 01fac46 commit f5ba587
Showing 1 changed file with 82 additions and 24 deletions.
106 changes: 82 additions & 24 deletions lentil/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,

shape = wavefront.shape if shape is None else np.broadcast_to(shape, (2,))
prop_shape = shape if prop_shape is None else np.broadcast_to(prop_shape, (2,))
shape_out = shape * oversample
prop_shape_out = prop_shape * oversample
shape_out = np.asarray(shape) * oversample
prop_shape_out = np.asarray(prop_shape) * oversample

dx = wavefront.pixelscale
du = np.broadcast_to(pixelscale, (2,))
Expand Down Expand Up @@ -70,15 +70,21 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,
fix_shift = np.fix(shift)
subpx_shift = shift - fix_shift

if _overlap(prop_shape_out, fix_shift, shape_out):
#if _overlap(prop_shape_out, fix_shift, shape_out):
prop_shape_out, fix_shift = _update_shape_offset(out_shape=shape_out,
field_shape=prop_shape_out,
field_offset=fix_shift)
if any(prop_shape_out):
alpha = lentil.helper.dft_alpha(dx=dx, du=du, z=z,
wave=wavefront.wavelength,
oversample=oversample)
data = lentil.fourier.dft2(f=field.data, alpha=alpha,
shape=prop_shape_out,
shift=subpx_shift,
cin=field.offset,
cin=field.offset,
cout=(0,0),
unitary=True)

out.data.append(Field(data=data, pixelscale=du/oversample,
offset=fix_shift))

Expand All @@ -87,25 +93,6 @@ def propagate_dft(wavefront, pixelscale, shape=None, prop_shape=None,

return out

def _overlap(field_shape, field_shift, output_shape):
# Return True if there's any overlap between a shifted field and the
# output shape
output_shape = np.asarray(output_shape)
field_shape = np.asarray(field_shape)
field_shift = np.asarray(field_shift)

# Output coordinates of the upper left corner of the shifted data array
field_shifted_ul = (output_shape / 2) - (field_shape / 2) + field_shift

if field_shifted_ul[0] > output_shape[0]:
return False
if field_shifted_ul[0] + field_shape[0] < 0:
return False
if field_shifted_ul[1] > output_shape[1]:
return False
if field_shifted_ul[1] + field_shape[1] < 0:
return False
return True

def _propagate_ptype(ptype, method='fraunhofer'):
if method == 'fraunhofer':
Expand All @@ -116,4 +103,75 @@ def _propagate_ptype(ptype, method='fraunhofer'):
if ptype == lentil.pupil:
return lentil.image
else:
return lentil.pupil
return lentil.pupil


def _update_shape_offset(out_shape, field_shape, field_offset, mask=None):
out_shape = tuple(int(n) for n in out_shape)
field_shape = tuple(int(n) for n in field_shape)
field_offset = tuple(int(n) for n in field_offset)
field_ul = (out_shape[0]//2 - field_shape[0]//2 + field_offset[0],
out_shape[1]//2 - field_shape[1]//2 + field_offset[1])

if _overlap(out_shape, field_shape, field_ul):

# Field slice indices
field_rmin = int(0)
field_rmax = int(field_shape[0])
field_cmin = int(0)
field_cmax = int(field_shape[1])

# Output insertion slice indices
out_rmin = int(field_ul[0])
out_rmax = int(field_ul[0] + field_shape[0])
out_cmin = int(field_ul[1])
out_cmax = int(field_ul[1] + field_shape[1])

# reconcile the field and output insertion indices
if out_rmin < 0:
field_rmin = -1 * out_rmin
out_rmin = 0

if out_rmax > out_shape[0]:
field_rmax -= out_rmax - out_shape[0]
out_rmax = out_shape[0]

if out_cmin < 0:
field_cmin = -1 * out_cmin
out_cmin = 0

if out_cmax > out_shape[1]:
field_cmax -= out_cmax - out_shape[1]
out_cmax = out_shape[1]

out_center = (out_shape[0]//2, out_shape[1]//2)

field_shape = (field_rmax-field_rmin, field_cmax-field_cmin)
#print('field shape:', field_shape)

field_center = (out_rmin + field_shape[0]//2,
out_cmin + field_shape[1]//2)
#print('field center:', field_center)

field_offset = (field_center[0] - out_center[0],
field_center[1] - out_center[1])
#print('field offset:', field_offset)

return field_shape, field_offset

else:
return (), field_offset


def _overlap(out_shape, field_shape, field_ul):
# Return True if there's any overlap between a shifted field and the
# output shape
if field_ul[0] > out_shape[0]:
return False
if field_ul[0] + field_shape[0] < 0:
return False
if field_ul[1] > out_shape[1]:
return False
if field_ul[1] + field_shape[1] < 0:
return False
return True

0 comments on commit f5ba587

Please sign in to comment.