Skip to content

Commit

Permalink
fix for masked values in FitTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
cshanahan1 committed Jan 11, 2024
1 parent aea9d50 commit 779438e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 35 deletions.
95 changes: 70 additions & 25 deletions specreduce/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,28 +141,73 @@ def test_fit_trace():
with pytest.raises(ValueError, match=r'bins must be <'):
FitTrace(img, bins=ncols + 1)

# error on trace of otherwise valid image with all-nan window around guess
with pytest.raises(ValueError, match='pixels in window region are masked'):
FitTrace(img_win_nans, guess=guess, window=window)

# error on trace of all-nan image
with pytest.raises(ValueError, match=r'image is fully masked'):
FitTrace(img_all_nans)

# test that warning is raised when several bins are masked
mask = np.zeros(img.shape)
mask[:, 100] = 1
mask[:, 20] = 1
mask[:, 30] = 1
nddat = NDData(data=img, mask=mask, unit=u.DN)
msg = "All pixels in bins 20, 30, 100 are masked. Falling back on trace value from all-bin fit."
with pytest.warns(UserWarning, match=msg):
FitTrace(nddat)

# and when many bins are masked
mask = np.zeros(img.shape)
mask[:, 0:21] = 1
nddat = NDData(data=img, mask=mask, unit=u.DN)
msg = 'All pixels in bins 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 20 are masked.'
with pytest.warns(UserWarning, match=msg):
FitTrace(nddat)

class TestMasksTracing():

def mk_img(self, nrows=200, ncols=160):

np.random.seed(7)

sigma_pix = 4
sigma_noise = 1

col_model = models.Gaussian1D(amplitude=1, mean=nrows/2, stddev=sigma_pix)
noise = np.random.normal(scale=sigma_noise, size=(nrows, ncols))

index_arr = np.tile(np.arange(nrows), (ncols, 1))
img = col_model(index_arr.T) + noise

return img

def test_window_fit_trace(self):

"""This test function will test that masked values are treated correctly in
FitTrace, and produce the correct results and warning messages based on
`peak_method`."""
img = self.mk_img()

# create same-shaped variations of image with invalid values
nrows = 200
ncols = 160
img_all_nans = np.tile(np.nan, (nrows, ncols))

window = 10
guess = int(nrows/2)
img_win_nans = img.copy()
img_win_nans[guess - window:guess + window] = np.nan

# error on trace of otherwise valid image with all-nan window around guess
with pytest.raises(ValueError, match='pixels in window region are masked'):
FitTrace(img_win_nans, guess=guess, window=window)

# error on trace of all-nan image
with pytest.raises(ValueError, match=r'image is fully masked'):
FitTrace(img_all_nans)

def test_fit_trace_all_nan_columns(self):

img = self.mk_img()

# test that warning (dependent on choice of `peak_method`) is raised when a
# few bins are masked, and that theyre listed individually
mask = np.zeros(img.shape)
mask[:, 100] = 1
mask[:, 20] = 1
mask[:, 30] = 1
nddat = NDData(data=img, mask=mask, unit=u.DN)

with pytest.warns(UserWarning, match='All pixels in bins 20, 30, 100 are fully masked. Setting to zero.'):
FitTrace(nddat, peak_method='max')

# with pytest.warns(UserWarning, match='All pixels in bins 20, 30, 100 are fully masked. Setting to largest bin index (200).'):
# FitTrace(nddat, peak_method='centroid')

# with pytest.warns(UserWarning, match='All pixels in bins 20, 30, 100 are fully masked. Setting to nan.'):
# FitTrace(nddat, peak_method='gaussian')

# and when many bins are masked, that the message is consolidated
# mask = np.zeros(img.shape)
# mask[:, 0:21] = 1
# nddat = NDData(data=img, mask=mask, unit=u.DN)
# with pytest.warns(UserWarning, match='All pixels in bins 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 20 are masked.'):
# FitTrace(nddat)
43 changes: 33 additions & 10 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,23 @@ def __post_init__(self):

warn_bins = []
for i in range(self.bins):
# repeat earlier steps to create gaussian fit for each bin

# binned column
z_i = img[ilum2, x_bins[i]:x_bins[i+1]].sum(axis=self._disp_axis)
if not z_i.mask.all():
peak_y_i = ilum2[z_i.argmax()]
else:
warn_bins.append(i)
peak_y_i = peak_y
#print('z_i', z_i)

if self.peak_method == 'gaussian':

# if binned column is fully masked for peak_method='gaussian',
# the fit value for this bin should be nan, then continue to next
warn_msg = 'Setting to nan'
if z_i.mask.all():
warn_bins.append(i)
y_bins[i] = np.nan
continue

Check warning on line 318 in specreduce/tracing.py

View check run for this annotation

Codecov / codecov/patch

specreduce/tracing.py#L316-L318

Added lines #L316 - L318 were not covered by tests

peak_y_i = ilum2[z_i.argmax()]

yy_i_above_half_max = np.sum(z_i > (z_i.max() / 2))
width_guess_i = yy_i_above_half_max / gaussian_sigma_to_fwhm

Expand All @@ -336,24 +344,39 @@ def __post_init__(self):
y_bins[i] = popt_i.mean_0.value
popt_tot = popt_i

if z_i.mask.all(): # all-masked bins when peak_method is 'centroid' or 'max'
warn_bins.append(i)

elif self.peak_method == 'centroid':
z_i_cumsum = np.cumsum(z_i)
# find the interpolated index where the cumulative array reaches half the total
# cumulative values
# find the interpolated index where the cumulative array reaches
# half the total cumulative values
y_bins[i] = np.interp(z_i_cumsum[-1]/2., z_i_cumsum, ilum2)

# warning message for fully masked bin
# NOTE this reflects current behavior, should eventually be changed
# to set to nan by default (or zero fill / interpoate option once
# available) and warn accordingly.
warn_msg = f'Setting to largest bin index ({z_i_cumsum.shape[0]})'

elif self.peak_method == 'max':
# TODO: implement smoothing with provided width
y_bins[i] = ilum2[z_i.argmax()]

# warn about fully-masked bins (which, currently, means any bin with a single masked value)
# warning message for fully masked bin
# NOTE this reflects current behavior, should eventually be changed
# to set to nan by default (or zero fill / interpoate option once
# available) and warn accordingly.
warn_msg = 'Setting to zero'

# warn about fully-masked bins
if len(warn_bins) > 0:
# if there are a ton of bins, we don't want to print them all out
if len(warn_bins) > 20:
warn_bins = warn_bins[0: 10] + ['...'] + [warn_bins[-1]]
warnings.warn(f"All pixels in {'bins' if len(warn_bins) else 'bin'} "
f"{', '.join([str(x) for x in warn_bins])}"
" are masked. Falling back on trace value from all-bin fit.")
f" are fully masked. {warn_msg}.")

# recenter bin positions
x_bins = (x_bins[:-1] + x_bins[1:]) / 2
Expand Down

0 comments on commit 779438e

Please sign in to comment.