Skip to content

Commit

Permalink
fix for masked values in FitTrace
Browse files Browse the repository at this point in the history
  • Loading branch information
cshanahan1 committed Jan 17, 2024
1 parent aea9d50 commit b8c65c3
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 40 deletions.
5 changes: 5 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ Bug Fixes

- HorneExtract now accepts 'None' as a valid option for ``bkgrd_prof``. [#171]

- Fix for fully masked bins in FitTrace when using ``gaussian`` for ``peak_method``.
Fully masked columns now have a peak of nan, which is used for the all-bin fit
for the Trace. Warning messages for ``peak_method`` == ``max`` and ``centroid``
are also now reflective of what the bin peak is being set to. [#205]

Other changes
^^^^^^^^^^^^^

Expand Down
107 changes: 79 additions & 28 deletions specreduce/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ def test_fit_trace():
with pytest.raises(ValueError):
t = FitTrace(img, peak_method='invalid')

# create same-shaped variations of image with invalid values
img_all_nans = np.tile(np.nan, (nrows, ncols))

window = 10
guess = int(nrows/2)
img_win_nans = img.copy()
Expand All @@ -141,28 +138,82 @@ def test_fit_trace():
with pytest.raises(ValueError, match=r'bins must be <'):
FitTrace(img, bins=ncols + 1)

# error on trace of otherwise valid image with all-nan window around guess
with pytest.raises(ValueError, match='pixels in window region are masked'):
FitTrace(img_win_nans, guess=guess, window=window)

# error on trace of all-nan image
with pytest.raises(ValueError, match=r'image is fully masked'):
FitTrace(img_all_nans)

# test that warning is raised when several bins are masked
mask = np.zeros(img.shape)
mask[:, 100] = 1
mask[:, 20] = 1
mask[:, 30] = 1
nddat = NDData(data=img, mask=mask, unit=u.DN)
msg = "All pixels in bins 20, 30, 100 are masked. Falling back on trace value from all-bin fit."
with pytest.warns(UserWarning, match=msg):
FitTrace(nddat)

# and when many bins are masked
mask = np.zeros(img.shape)
mask[:, 0:21] = 1
nddat = NDData(data=img, mask=mask, unit=u.DN)
msg = 'All pixels in bins 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 20 are masked.'
with pytest.warns(UserWarning, match=msg):
FitTrace(nddat)

class TestMasksTracing():
    """Tests of how FitTrace handles masked and non-finite pixels."""

    def mk_img(self, nrows=200, ncols=160):
        """Build a noisy synthetic image with one Gaussian trace.

        Every column holds the same Gaussian profile (amplitude 1,
        sigma 4 pixels) centered on row ``nrows / 2``, plus unit-sigma
        normal noise.

        Parameters
        ----------
        nrows : int
            Number of rows of the returned image.
        ncols : int
            Number of columns of the returned image.

        Returns
        -------
        img : numpy.ndarray
            Array of shape ``(nrows, ncols)``.
        """
        # A private, fixed-seed generator keeps the image reproducible
        # without reseeding NumPy's global RNG (which would leak state into
        # whichever test happens to run next). RandomState(7) produces the
        # same stream as np.random.seed(7) followed by np.random.normal.
        rng = np.random.RandomState(7)

        sigma_pix = 4
        sigma_noise = 1

        col_model = models.Gaussian1D(amplitude=1, mean=nrows/2, stddev=sigma_pix)
        noise = rng.normal(scale=sigma_noise, size=(nrows, ncols))

        # row index of every pixel, laid out as (nrows, ncols) once transposed
        index_arr = np.tile(np.arange(nrows), (ncols, 1))
        img = col_model(index_arr.T) + noise

        return img

    def test_window_fit_trace(self):
        """Check that FitTrace raises ValueError when the rows in the window
        around ``guess`` are all nan, and when the whole image is nan."""
        img = self.mk_img()

        # create same-shaped variations of image with invalid values;
        # the shape is taken from the image itself so it cannot drift from
        # the defaults of mk_img
        nrows = img.shape[0]
        img_all_nans = np.full(img.shape, np.nan)

        window = 10
        guess = int(nrows/2)
        img_win_nans = img.copy()
        img_win_nans[guess - window:guess + window] = np.nan

        # error on trace of otherwise valid image with all-nan window around guess
        with pytest.raises(ValueError, match='pixels in window region are masked'):
            FitTrace(img_win_nans, guess=guess, window=window)

        # error on trace of all-nan image
        with pytest.raises(ValueError, match=r'image is fully masked'):
            FitTrace(img_all_nans)

    @pytest.mark.filterwarnings("ignore:The fit may be unsuccessful")
    @pytest.mark.filterwarnings("ignore:Model is linear in parameters")
    def test_fit_trace_all_nan_columns(self):
        """Check the UserWarning emitted for fully masked columns: its text
        depends on ``peak_method``, and the list of bins is abbreviated when
        many bins are masked."""
        img = self.mk_img()

        # test that warning (dependent on choice of `peak_method`) is raised when a
        # few bins are masked, and that they're listed individually
        mask = np.zeros(img.shape)
        mask[:, 100] = 1
        mask[:, 20] = 1
        mask[:, 30] = 1
        nddat = NDData(data=img, mask=mask, unit=u.DN)

        # NOTE: `match` is applied as a regular expression, which is why the
        # parentheses in the 'centroid' message are escaped below.
        match_str = 'All pixels in bins 20, 30, 100 are fully masked. '

        with pytest.warns(UserWarning, match=match_str +
                          'Setting bin peaks to zero.'):
            FitTrace(nddat, peak_method='max')

        with pytest.warns(UserWarning, match=match_str +
                          'Setting bin peaks to largest bin index \\(200\\)'):
            FitTrace(nddat, peak_method='centroid')

        with pytest.warns(UserWarning, match=match_str +
                          'Setting bin peaks to nan.'):
            FitTrace(nddat, peak_method='gaussian')

        # and when many bins are masked, that the message is consolidated
        # (no `peak_method` is passed here; the expected text is the same
        # 'zero' message checked for peak_method='max' above)
        mask = np.zeros(img.shape)
        mask[:, 0:21] = 1
        nddat = NDData(data=img, mask=mask, unit=u.DN)
        with pytest.warns(UserWarning, match='All pixels in bins '
                          '0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 20 are '
                          'fully masked. Setting bin peaks to zero.'):
            FitTrace(nddat)
57 changes: 45 additions & 12 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ class FitTrace(Trace, _ImageParser):
_disp_axis = 1

def __post_init__(self):

# parse image
self.image = self._parse_image(self.image)

# mask any previously uncaught invalid values
Expand Down Expand Up @@ -262,14 +264,20 @@ def __post_init__(self):
warnings.warn('TRACE: Converting window to int')
self.window = int(self.window)

# fit the trace
self._fit_trace(img)

def _fit_trace(self, img):

yy = np.arange(img.shape[self._crossdisp_axis])

# set max peak location by user choice or wavelength with max avg flux
ztot = img.sum(axis=self._disp_axis) / img.shape[self._disp_axis]
peak_y = self.guess if self.guess is not None else ztot.argmax()
# NOTE: peak finder can be bad if multiple objects are on slit

yy = np.arange(img.shape[self._crossdisp_axis])

if self.peak_method == 'gaussian':

# guess the peak width as the FWHM, roughly converted to gaussian sigma
yy_above_half_max = np.sum(ztot > (ztot.max() / 2))
width_guess = yy_above_half_max / gaussian_sigma_to_fwhm
Expand All @@ -292,6 +300,8 @@ def __post_init__(self):
ilum2 = (yy if self.window is None
else yy[np.arange(peak_y - self.window,
peak_y + self.window, dtype=int)])

# check if everything in window region is masked
if img[ilum2].mask.all():
raise ValueError('All pixels in window region are masked. Check '
'for invalid values or use a larger window value.')
Expand All @@ -302,15 +312,21 @@ def __post_init__(self):

warn_bins = []
for i in range(self.bins):
# repeat earlier steps to create gaussian fit for each bin

# binned column (sum) if bins < ncols. otherwise, just columns.
z_i = img[ilum2, x_bins[i]:x_bins[i+1]].sum(axis=self._disp_axis)
if not z_i.mask.all():
peak_y_i = ilum2[z_i.argmax()]
else:
warn_bins.append(i)
peak_y_i = peak_y

if self.peak_method == 'gaussian':
# if binned column is fully masked for peak_method='gaussian',
# the fit value for this bin should be nan, then continue to next
warn_msg = 'nan'
if z_i.mask.all():
warn_bins.append(i)
y_bins[i] = np.nan
continue

peak_y_i = ilum2[z_i.argmax()]

yy_i_above_half_max = np.sum(z_i > (z_i.max() / 2))
width_guess_i = yy_i_above_half_max / gaussian_sigma_to_fwhm

Expand All @@ -336,24 +352,41 @@ def __post_init__(self):
y_bins[i] = popt_i.mean_0.value
popt_tot = popt_i

if z_i.mask.all(): # all-masked bins when peak_method is 'centroid' or 'max'
warn_bins.append(i)

elif self.peak_method == 'centroid':
z_i_cumsum = np.cumsum(z_i)
# find the interpolated index where the cumulative array reaches half the total
# cumulative values
# find the interpolated index where the cumulative array reaches
# half the total cumulative values
y_bins[i] = np.interp(z_i_cumsum[-1]/2., z_i_cumsum, ilum2)

# warning message for fully masked bin
# NOTE this reflects current behavior, should eventually be changed
# to set to nan by default (or zero fill / interpolate option once
# available) and warn accordingly.
warn_msg = f'largest bin index ({str(z_i_cumsum.shape[0])})'

elif self.peak_method == 'max':
# TODO: implement smoothing with provided width
y_bins[i] = ilum2[z_i.argmax()]

# warn about fully-masked bins (which, currently, means any bin with a single masked value)
# warning message for fully masked bin
# NOTE: should eventually be changed to set to nan by default
# (or zero fill / interpolate option once available) and warn.
warn_msg = 'zero'

# warn about fully-masked bins
if len(warn_bins) > 0:

# if there are a ton of bins, we don't want to print them all out
if len(warn_bins) > 20:
warn_bins = warn_bins[0: 10] + ['...'] + [warn_bins[-1]]

warnings.warn(f"All pixels in {'bins' if len(warn_bins) else 'bin'} "
f"{', '.join([str(x) for x in warn_bins])}"
" are masked. Falling back on trace value from all-bin fit.")
" are fully masked. Setting bin"
f" peak{'s' if len(warn_bins) else ''} to {warn_msg}.")

# recenter bin positions
x_bins = (x_bins[:-1] + x_bins[1:]) / 2
Expand Down

0 comments on commit b8c65c3

Please sign in to comment.