From c2ed129475fdfda07172600513cc1fa6ccafcd72 Mon Sep 17 00:00:00 2001 From: Jarron Leisenring Date: Fri, 30 Aug 2024 14:23:05 -0700 Subject: [PATCH] optimize PSF alignment options --- webbpsf_ext/imreg_tools.py | 435 +++++++++++++++++++++++++++++-------- 1 file changed, 349 insertions(+), 86 deletions(-) diff --git a/webbpsf_ext/imreg_tools.py b/webbpsf_ext/imreg_tools.py index 82c7b27..dffa585 100644 --- a/webbpsf_ext/imreg_tools.py +++ b/webbpsf_ext/imreg_tools.py @@ -943,7 +943,7 @@ def get_expected_loc(input, return_indices=True, add_sroffset=None): if not add_sroffset: xoff_asec = yoff_asec = 0.0 # Add in a SGD offsets if they exist - if is_sgd: + if is_sgd and (sgd_pattern is not None): xoff_arr, yoff_arr = get_sgd_offsets(sgd_pattern) xoff_asec += xoff_arr[sgd_pos] yoff_asec += yoff_arr[sgd_pos] @@ -1229,7 +1229,7 @@ def load_cropped_files(save_dir, files, xysub=65, bgsub=False, def recenter_psf(psfs_over, niter=3, halfwidth=7, gfit=True, in_place=False, **kwargs): - """Use center of mass algorithm to relocate PSF to center of image. + """Use Gaussian fit or center of mass algorithm to relocate PSF to center of image. Returns recentered PSFs and shift values used. @@ -1276,9 +1276,13 @@ def recenter_psf(psfs_over, niter=3, halfwidth=7, psfs_over[i] = psf xyoff_psfs_over.append(np.array([xsh_sum, ysh_sum])) + gc_str = 'Gaussian Fit' if gfit else 'CoM' + _log.info(f"Recentered oversampled PSF ({xsh_sum:.3f}, {ysh_sum:.3f}) pixels using {gc_str} algorithm.") + # Oversampled offsets xyoff_psfs_over = np.array(xyoff_psfs_over) + # If input was a single image, return same dimensions if ndim==2: psfs_over = psfs_over[0] @@ -1287,6 +1291,178 @@ def recenter_psf(psfs_over, niter=3, halfwidth=7, return psfs_over, xyoff_psfs_over +def subtract_psf(image, psf, osamp=1, bpmask=None, rin=None, rout=None, + xyshift=(0,0), psf_scale=None, psf_offset=0, + method='fourier', interp='lanczos', pad=True, cval=0, + kipc=None, kppc=None, diffusion_sigma=None, psf_corr_over=None, + weights=None, return_sum2=False, return_scale=False, **kwargs): + """ Subtract PSF from image + + Provide scale, offset, and shift values to PSF before subtraction. + Uses `fractional_image_shift` function to shift PSF. + + Parameters + ---------- + image: ndarray + Observed science image. + psf: ndarray + Oversampled PSF (shifted and scaled to match). + osamp: int + Oversampling factor of PSF. + bpmask: bool array + Bad pixel mask indicating pixels in input image to ignore. + rin: float + Inner radius of annulus for subtraction. Default is None. + rout: float + Outer radius of annulus for subtraction. Default is None. + xyshift: tuple + Shift values in (x,y) directions. Units of pixels. + psf_scale: float + Scale factor to apply to PSF. If set to None, then will + find the best scaling factor. + psf_offset: float + Offset to apply to PSF. + psf_corr_over: ndarray + Oversampled PSF correction image. If provided, then this + image is multiplied with the PSF after diffusion. These are + empirical corrections to the WebbPSF model to better match + the observed PSF. + kipc: ndarray + 3x3 array of IPC kernel values. If None, then no IPC is applied. + kppc: ndarray + 3x3 array of PPC kernel values. If None, then no PPC is applied. + Should already be oriented along readout direction of PSF. + diffusion_sigma: float + Sigma value for Gaussian diffusion kernel. If None, then + no diffusion is applied. In units of detector pixels. + weights: ndarray + Array of weights to use during the fitting process. + Useful if you have bad pixels to mask out (ie., + set them to zero). Default is None (no weights). + Should be same size as image. + Recommended is inverse variance map. + method : str + Method to use for shifting. Options are: + - 'fourier' : Shift in Fourier space + - 'fshift' : Shift using interpolation + - 'opencv' : Shift using OpenCV warpAffine + interp : str + Interpolation method to use for shifting using 'fshift' or 'opencv. + Default is 'cubic'. + For 'opencv', valid options are 'linear', 'cubic', and 'lanczos'. + for 'fshift', valid options are 'linear', 'cubic', and 'quintic'. + pad : bool + Should we pad the array before shifting, then truncate? + Otherwise, the image is wrapped. + cval : sequence or float, optional + The values to set the padded values for each axis. Default is 0. + ((before_1, after_1), ... (before_N, after_N)) unique pad constants for each axis. + ((before, after),) yields same before and after constants for each axis. + (constant,) or int is a shortcut for before = after = constant for all axes. + return_sum2 : bool + Return the sum of the squared difference between the image + and PSF. Default is False. + + Keyword Args + ------------ + gstd_pix : float + Standard deviation of Gaussian kernel to blur PSF during shift. + oversample : int + Oversampling factor for fractional shift. Default is 1. + order : int + Interpolation order for oversampling during shifting. Default is 1. + rescale_pix : bool + Explicitly rescale the pixel values during resampling to ensure that + the flux within a superpixel is preserved. + Default is False (zoom default behavior). + """ + + from webbpsf_ext.image_manip import image_shift_with_nans + # from webbpsf_ext.image_manip import apply_pixel_diffusion, add_ipc, add_ppc + # from webbpsf_ext.coords import dist_image + + # Shift oversampled PSF and + xsh_over, ysh_over = np.array(xyshift) * osamp + if method is not None: + kwargs_shift = {} + kwargs_shift['pad'] = pad + kwargs_shift['cval'] = cval + if method in ['fshift', 'opencv']: + kwargs_shift['interp'] = interp + # Scale Gaussian std dev by oversampling factor + gstd_pix = kwargs.pop('gstd_pix', None) + if gstd_pix is not None: + kwargs_shift['gstd_pix'] = gstd_pix * osamp + # psf_over = fractional_image_shift(psf, xsh_over, ysh_over, method=method, **kwargs_shift) + + # Perform oversampling during shifting process? + kwargs_shift['oversample'] = kwargs.pop('oversample', 1) + kwargs_shift['order'] = kwargs.pop('order', 1) + kwargs_shift['rescale_pix'] = kwargs.pop('rescale_pix', False) + psf_over = image_shift_with_nans(psf, xsh_over, ysh_over, shift_method=method, **kwargs_shift) + + # Charge diffusion + if diffusion_sigma is not None: + sigma_osamp = diffusion_sigma * osamp + psf_over = apply_pixel_diffusion(psf_over, sigma_osamp) + + # Apply PSF correction + if psf_corr_over is not None: + psf_over *= crop_image(psf_corr_over, psf_over.shape, fill_val=1) + + # Rebin to detector sampling + psf_det = frebin(psf_over, scale=1/osamp) if osamp!=1 else psf_over + + # Add IPC to detector-sampled PSF + if kipc is not None: + psf_det = add_ipc(psf_det, kernel=kipc) + + if kppc is not None: + psf_det = add_ppc(psf_det, kernel=kppc, nchans=1) + + # Crop image + if psf_det.shape != image.shape: + psf_det = crop_image(psf_det, image.shape) + + if psf_scale is None: + # Get optimal scale factor between images + # Ignore NaNs and zeros + good_mask = ~np.isnan(image) & ~np.isnan(psf_det) + good_mask = good_mask & (~np.isclose(image,0)) & (~np.isclose(psf_det,0)) + if bpmask is not None: + good_mask &= ~bpmask + + if (rin is not None) or (rout is not None): + rho = dist_image(image) + rin = 0 if rin is None else rin + rout = np.inf if rout is None else rout + good_mask &= (rho >= rin) & (rho <= rout) + + im_good = image[good_mask].flatten() - psf_offset + psf_good = psf_det[good_mask].flatten() + cf = np.linalg.lstsq(psf_good.reshape([1,-1]).T, im_good, rcond=None)[0] + psf_scale = cf[0] + + psf_det = psf_det * psf_scale + psf_offset + + # Subtract PSF from image + diff = image - psf_det + + if weights is not None: + diff = diff * weights + + if return_sum2: + # Set anything that are 0 in either image as zero in difference + zmask = np.isclose(image,0) | np.isclose(psf_det,0) + nmask = np.isnan(image) | np.isnan(psf_det) + mask = zmask | nmask + if bpmask is not None: + mask |= bpmask + diff[mask] = 0 + return (np.sum(diff**2), psf_scale) if return_scale else np.sum(diff**2) + else: + return (diff, psf_scale) if return_scale else diff + def correl_images(im1, im2, mask=None): """ Image correlation coefficient @@ -1395,7 +1571,8 @@ def find_max_crosscorr(corr, xsh_arr, ysh_arr, sub_sample): def gen_psf_offsets(psf, crop=65, xlim_pix=(-3,3), ylim_pix=(-3,3), dxy=0.05, psf_osamp=1, shift_func=fourier_imshift, ipc_vals=None, kipc=None, - kppc=None, diffusion_sigma=None, monitor=False, prog_leave=False, **kwargs): + kppc=None, diffusion_sigma=None, psf_corr_image=None, + monitor=False, prog_leave=False, **kwargs): """ Generate a series of downsampled cropped and shifted PSF images If fov_pix is odd, then crop should be odd. @@ -1454,7 +1631,7 @@ def gen_psf_offsets(psf, crop=65, xlim_pix=(-3,3), ylim_pix=(-3,3), dxy=0.05, yoff_over = yoff*psf_osamp crop_over = crop*psf_osamp - psf_sh = crop_image(psf0, crop_over, xyloc=None, delx=-xoff_over, dely=-yoff_over, + psf_sh = crop_image(psf0, crop_over, xyloc=None, delx=xoff_over, dely=yoff_over, shift_func=shift_func, **kwargs) # psf_sh = pad_or_cut_to_size(psf0, crop_over, offset_vals=(-yoff_over,-xoff_over), # shift_func=shift_func, pad=True) @@ -1464,6 +1641,13 @@ def gen_psf_offsets(psf, crop=65, xlim_pix=(-3,3), ylim_pix=(-3,3), dxy=0.05, dsig = diffusion_sigma * psf_osamp psf_sh = apply_pixel_diffusion(psf_sh, dsig) + # Apply PSF correction image + if psf_corr_image is not None: + psf_corr_im_sh = crop_image(psf_corr_image, crop_over, xyloc=None, + delx=xoff_over, dely=yoff_over, + shift_func=shift_func, fill_val=1, **kwargs) + psf_sh *= psf_corr_im_sh + # Rebin to detector pixels psf_sh = frebin(psf_sh, scale=1/psf_osamp) psf_sh_all.append(psf_sh) @@ -1547,7 +1731,8 @@ def find_offsets(input, psf, crop=65, xlim_pix=(-3,3), ylim_pix=(-3,3), def find_offsets2(input, xoff_pix, yoff_pix, psf_sh_all, bpmasks=None, - crop=65, rin=0, rout=None, dxy_fine=0.01, prog_leave=True, **kwargs): + crop=65, rin=0, rout=None, dxy_fine=0.01, prog_leave=True, + return_more=False, lsq_diff=False, **kwargs): """Find offsets necessary to align observations with input psf""" # Check if input is a dictionary @@ -1581,6 +1766,8 @@ def find_offsets2(input, xoff_pix, yoff_pix, psf_sh_all, bpmasks=None, iter_vals = tqdm(input, leave=prog_leave) # iter_vals = tqdm(keys,leave=prog_leave) if is_dict else tqdm(input,leave=prog_leave) i = 0 + if return_more: + res_dict = {} for val in iter_vals: if crop is None: @@ -1599,7 +1786,7 @@ def find_offsets2(input, xoff_pix, yoff_pix, psf_sh_all, bpmasks=None, im = crop_image(val, crop) # Crop PSFs to match size - psf_sh_crop = np.array([crop_image(psf, crop) for psf in psf_sh_all]) + psf_sh_crop = crop_image(psf_sh_all, crop) # Crop bp mask to match if bpmasks is None: @@ -1620,12 +1807,20 @@ def find_offsets2(input, xoff_pix, yoff_pix, psf_sh_all, bpmasks=None, zmask2 = np.sum(nanmask_psf, axis=0) == 0 ind_mask = rmask & zmask & zmask2 & (~bpmask) - # Cross-correlate to find best x,y shift to align image with PSF - cc = correl_images(psf_sh_crop, im, mask=ind_mask) - cc = cc.reshape(sh_grid) + if lsq_diff: + # Least squares difference + bpmask = ~ind_mask + sum_sqrs = np.array([subtract_psf(im, psf, bpmask=bpmask, return_sum2=True) for psf in psf_sh_crop]) + correlation_metric = 1 / sum_sqrs.reshape(sh_grid) + else: + # Cross-correlate to find best (x,y) shift to align image with PSF + cc = correl_images(psf_sh_crop, im, mask=ind_mask) + correlation_metric = cc.reshape(sh_grid) # Cubic interplotion of cross correlation image onto a finer grid - xsh, ysh = find_max_crosscorr(cc, xoff_pix, yoff_pix, dxy_fine) + xsh, ysh = find_max_crosscorr(correlation_metric, xoff_pix, yoff_pix, dxy_fine) + if return_more: + res_dict[i] = {'corr_map':correlation_metric, 'xoff_pix':xoff_pix, 'yoff_pix':yoff_pix} xsh0_pix.append(xsh) ysh0_pix.append(ysh) @@ -1638,7 +1833,10 @@ def find_offsets2(input, xoff_pix, yoff_pix, psf_sh_all, bpmasks=None, xsh0_pix = xsh0_pix[0] ysh0_pix = ysh0_pix[0] - return xsh0_pix, ysh0_pix + if return_more: + return xsh0_pix, ysh0_pix, res_dict + else: + return xsh0_pix, ysh0_pix def find_offsets_phase(input, psf, crop=65, rin=0, rout=None, dxy_fine=0.01, @@ -1659,7 +1857,7 @@ def find_offsets_phase(input, psf, crop=65, rin=0, rout=None, dxy_fine=0.01, keys = list(input.keys()) if is_dict else None # Ensure PSF is correct size - psf_sub = pad_or_cut_to_size(psf, crop) + psf_sub = crop_image(psf, crop, fill_val=0) xsh0_pix = [] ysh0_pix = [] @@ -1674,7 +1872,7 @@ def find_offsets_phase(input, psf, crop=65, rin=0, rout=None, dxy_fine=0.01, im = crop_observation(imfull, d['ap'], crop).copy() else: imfull = val - im = pad_or_cut_to_size(imfull, crop) + im = crop_image(imfull, crop, fill_val=0) # Create masks rdist = dist_image(im) @@ -1711,29 +1909,43 @@ def find_offsets_phase(input, psf, crop=65, rin=0, rout=None, dxy_fine=0.01, return res.squeeze() -def find_pix_offsets(imsub_arr, psfs, psf_osamp=1, kipc=None, kppc=None, - diffusion_sigma=None, phase=False, bpmask_arr=None, crop=None, **kwargs): +def find_pix_offsets(imsub_arr, psfs, psf_osamp=1, bpmask_arr=None, + crop=None, kipc=None, kppc=None, diffusion_sigma=None, + psf_corr_image=None, phase=False, xcorr=True, lsq_diff=False, + **kwargs): """Find number of pixels to offset PSFs to corrsponding images + + If multple methods are selected, then will return values for each in a dictionary. + If only one method is selected, then will return a single array of offsets. Parameters ---------- imsub_arr : ndarray Array of cropped images psfs : ndarray - Array of PSFs to align to images + Array of PSFs to align to images. Either same number of images + or a single PSF to align to all images. psf_osamp : int Oversampling factor of PSFs + bpmask_arr : ndarray + Bad pixel mask array. Should be same shape as imsub_arr. diffusion_sigma : float - Diffusion kernel sigma value + Diffusion kernel sigma value to apply to psfs. kipc : ndarray - IPC kernel + IPC kernel to apply to PSFs. kppc : ndarray PPC kernel. Should already align to readout direction of detector along rows. phase : bool Use phase cross-correlation to find offsets - bpmask_arr : ndarray - Bad pixel mask array. Should be same shape as imsub_arr. + psf_corr_image : ndarray + Correction factor to multiply PSF after diffussion + align_method : str + Method to use to align images. Options are 'xcorr', 'phase', + or 'lsqdiff'. Default is 'xcorr'. For 'xcorr', peform traditional + corr correlation to find offsets. For 'phase', use phase cross + correlation to find offsets. For 'lsqdiff', use least squares + difference to find offsets. Keyword Args ============ @@ -1743,22 +1955,10 @@ def find_pix_offsets(imsub_arr, psfs, psf_osamp=1, kipc=None, kppc=None, Exclude pixel exterior to this radius. xylim_pix : tuple or list Initial coarse step range in detector pixels. - corr_avg : bool - If True, then find best position using weighted average of the cross-correlation - map along the x and y axes. Otherwise, return the position of the maximum value - in the map. """ - sh_orig = imsub_arr.shape - sh_orig_psfs = psfs.shape - if len(sh_orig)==2: - imsub_arr = [imsub_arr] - bpmask_arr = [bpmask_arr] - psfs = [psfs] - elif len(sh_orig_psfs)==2: - psfs = [psfs] - - def find_pix_phase(im, psf, psf_osamp, kipc=None, kppc=None, diffusion_sigma=None, crop=15): + def find_pix_phase(im, psf, psf_osamp, kipc=None, kppc=None, diffusion_sigma=None, + psf_corr_image=None, crop=15, **kwargs): # Rebin to detector sampling if psf_osamp!=1: psf = frebin(psf, scale=1/psf_osamp) @@ -1766,6 +1966,9 @@ def find_pix_phase(im, psf, psf_osamp, kipc=None, kppc=None, diffusion_sigma=Non # Add diffusion if (diffusion_sigma is not None) and (diffusion_sigma>0): psf = apply_pixel_diffusion(psf, diffusion_sigma) + # Apply PSF correction image + if psf_corr_image is not None: + psf *= crop_image(psf_corr_image, psf.shape[-2:], fill_val=1) # Add IPC if kipc is not None: psf = add_ipc(psf, kernel=kipc) @@ -1773,11 +1976,14 @@ def find_pix_phase(im, psf, psf_osamp, kipc=None, kppc=None, diffusion_sigma=Non if kppc is not None: psf = add_ppc(psf, kernel=kppc, nchans=1) - res = find_offsets_phase(im, psf, crop=crop, rin=0, rout=None, dxy_fine=0.01) + rin = kwargs.get('rin', 0) + rout = kwargs.get('rout', None) + res = find_offsets_phase(im, psf, crop=crop, rin=rin, rout=rout, dxy_fine=0.001) return res - def find_pix_cc(im, psf, psf_osamp, kipc=None, kppc=None, diffusion_sigma=None, - crop=33, bpmask=None, return_grids=False, **kwargs): + def find_pix_cc(im, psf, psf_osamp, bpmask=None, crop=33, + kipc=None, kppc=None, diffusion_sigma=None, psf_corr_image=None, + lsq_diff=False, return_grids=False, **kwargs): """Cross correlate by shifting PSF in fine steps""" # Create a series of coarse offset PSFs to find initial estimate @@ -1786,72 +1992,129 @@ def find_pix_cc(im, psf, psf_osamp, kipc=None, kppc=None, diffusion_sigma=None, xlim_pix = ylim_pix = xylim_pix else: xlim_pix = ylim_pix = (-5,5) - res1 = kwargs.get('res1', None) - if res1 is None: - res1 = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=xlim_pix, dxy=0.5, - psf_osamp=psf_osamp, kipc=None, kppc=None, diffusion_sigma=None, - prog_leave=False, shift_func=fshift, **kwargs) - xoff_pix, yoff_pix, psf_sh_all = res1 - - # psf_sh_all are cropped to `crop`` value, whereas im is still input size - xsh_coarse, ysh_coarse = find_offsets2(im, xoff_pix, yoff_pix, psf_sh_all, crop=crop, - dxy_fine=0.5, bpmasks=bpmask, prog_leave=False, **kwargs) - - # Create finer grid off offset PSFs - xlim_pix = (xsh_coarse-0.25, xsh_coarse+0.25) - ylim_pix = (ysh_coarse-0.25, ysh_coarse+0.25) - res2 = kwargs.get('res2', None) - if res2 is None: - res2 = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=ylim_pix, dxy=0.01, - psf_osamp=psf_osamp, kipc=kipc, kppc=kppc, diffusion_sigma=diffusion_sigma, - prog_leave=False, **kwargs) + + dxy_coarse = kwargs.pop('dxy_coarse', 0.250) + dxy_fine = kwargs.pop('dxy_fine', 0.005) + + res_coarse = kwargs.get('res_coarse', None) + if res_coarse is None: + res_coarse = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=xlim_pix, dxy=dxy_coarse, + psf_osamp=psf_osamp, kipc=None, kppc=None, diffusion_sigma=None, + psf_corr_image=psf_corr_image, prog_leave=False, + shift_func=fshift, **kwargs) + xoff_pix, yoff_pix, psf_sh_all = res_coarse + + # psf_sh_all are cropped to `crop` value, whereas im is still input size + xsh_coarse, ysh_coarse = find_offsets2(im, xoff_pix, yoff_pix, psf_sh_all, bpmasks=bpmask, crop=crop, + dxy_fine=dxy_coarse, prog_leave=False, **kwargs) + + # Create finer grid of offset PSFs + xlim_pix = (xsh_coarse-dxy_coarse/2, xsh_coarse+dxy_coarse/2) + ylim_pix = (ysh_coarse-dxy_coarse/2, ysh_coarse+dxy_coarse/2) + res2 = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=ylim_pix, dxy=dxy_fine, + psf_osamp=psf_osamp, kipc=kipc, kppc=kppc, diffusion_sigma=diffusion_sigma, + psf_corr_image=psf_corr_image, prog_leave=False, **kwargs) xoff_pix, yoff_pix, psf_sh_all = res2 - # Perform cross correlations and interpolate at 0.01 pixel - xsh_fine, ysh_fine = find_offsets2(im, xoff_pix, yoff_pix, psf_sh_all, crop=crop, - dxy_fine=0.005, bpmasks=bpmask, prog_leave=False, **kwargs) + # Perform cross correlations and interpolate at 0.001 pixel + xsh_fine, ysh_fine = find_offsets2(im, xoff_pix, yoff_pix, psf_sh_all, bpmasks=bpmask, crop=crop, + dxy_fine=0.001, lsq_diff=lsq_diff, prog_leave=False, **kwargs) res = (xsh_fine, ysh_fine) if return_grids: - return res, res1, res2 + return res, res_coarse else: return res - xysh_pix = [] - iter_vals = trange(len(imsub_arr), desc='Image XCorr', leave=False) if len(imsub_arr)>=10 else range(len(imsub_arr)) + sh_orig = imsub_arr.shape + sh_orig_psfs = psfs.shape + if len(sh_orig)==2: + imsub_arr = [imsub_arr] + bpmask_arr = [bpmask_arr] + psfs = [psfs] + elif len(sh_orig_psfs)==2: + psfs = [psfs] + + xysh_pix_phase = [] + xysh_pix_cc = [] + xysh_pix_lsq = [] + iter_vals = trange(len(imsub_arr), desc='Image Alignment', leave=False) if len(imsub_arr)>=10 else range(len(imsub_arr)) for i in iter_vals: im = imsub_arr[i] # If only a single PSF was passed, then use it for all images psf = psfs[i] if sh_orig==sh_orig_psfs else psfs[0] - if phase: - crop = 15 if crop is None else crop - res = find_pix_phase(im, psf, psf_osamp, kipc=kipc, kppc=kppc, - diffusion_sigma=diffusion_sigma, crop=crop) - else: - crop = 21 if crop is None else crop + if crop is None: + crop = 15 if phase else 21 + # Ensure crop is at least 20 pixels larger than rin + rin = kwargs.get('rin', 0) + if crop-rin < 20: + crop = rin + 20 + # Ensure crop is odd + if np.mod(crop, 2)==0: + crop += 1 + if phase: + res = find_pix_phase(im, psf, psf_osamp, kipc=kipc, kppc=kppc, + diffusion_sigma=diffusion_sigma, + psf_corr_image=psf_corr_image, crop=crop, **kwargs) + xysh_pix_phase.append(res) + elif xcorr or lsq_diff: # Only set to return grid on first iteration - if len(sh_orig)==3 and len(sh_orig_psfs)==2 and i==0: - return_grids = True - else: - return_grids = False + return_grids = True if len(sh_orig)==3 and len(sh_orig_psfs)==2 and i==0 else False - res = find_pix_cc(im, psf, psf_osamp, kipc=kipc, kppc=kppc, - diffusion_sigma=diffusion_sigma, crop=crop, - bpmask=bpmask_arr[i], return_grids=return_grids, **kwargs) + try: + bpmask = bpmask_arr[i] + except TypeError: + bpmask = None + + # Do cross-correlation + if xcorr: + res = find_pix_cc(im, psf, psf_osamp, bpmask=bpmask, crop=crop, + kipc=kipc, kppc=kppc, diffusion_sigma=diffusion_sigma, + psf_corr_image=psf_corr_image, lsq_diff=False, + return_grids=return_grids, **kwargs) - # Set res1 and res2 going forward - if return_grids and i==0: - res, res1, res2 = res - kwargs['res1'] = res1 - kwargs['res2'] = res2 - - xysh_pix.append(res) - - if len(sh_orig)==2: - return np.asarray(xysh_pix[0]) + # Set res_coarse going forward + if return_grids and i==0: + res, res_coarse = res + kwargs['res_coarse'] = res_coarse + return_grids = False + + xysh_pix_cc.append(res) + + # Do least squares difference + if lsq_diff: + res = find_pix_cc(im, psf, psf_osamp, bpmask=bpmask, crop=crop, + kipc=kipc, kppc=kppc, diffusion_sigma=diffusion_sigma, + psf_corr_image=psf_corr_image, lsq_diff=True, + return_grids=return_grids, **kwargs) + + # Set res_coarse going forward + if return_grids and i==0: + res, res_coarse = res + kwargs['res_coarse'] = res_coarse + return_grids = False + + xysh_pix_lsq.append(res) + + if len(sh_orig)==2 and len(xysh_pix_phase)>0: + xysh_pix_phase = np.asarray(xysh_pix_phase[0]) + if len(sh_orig)==2 and len(xysh_pix_cc)>0: + xysh_pix_cc = np.asarray(xysh_pix_cc[0]) + if len(sh_orig)==2 and len(xysh_pix_lsq)>0: + xysh_pix_lsq = np.asarray(xysh_pix_lsq[0]) + + if phase + xcorr + lsq_diff > 1: + res = {} + if phase: res['phase'] = xysh_pix_phase + if xcorr: res['xcorr'] = xysh_pix_cc + if lsq_diff: res['lsqdiff'] = xysh_pix_lsq else: - return np.asarray(xysh_pix) + if phase: res = xysh_pix_phase + elif xcorr: res = xysh_pix_cc + elif lsq_diff: res = xysh_pix_lsq + + return res