From 5f72ca7f526c333771f3762eeaf0d2f3446c4c0a Mon Sep 17 00:00:00 2001 From: Clare Shanahan Date: Tue, 21 Nov 2023 10:48:07 -0500 Subject: [PATCH] fix warning message --- specreduce/tests/test_tracing.py | 21 +++++++++++++++++++-- specreduce/tracing.py | 13 +++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/specreduce/tests/test_tracing.py b/specreduce/tests/test_tracing.py index 63bcabf..021bd0b 100644 --- a/specreduce/tests/test_tracing.py +++ b/specreduce/tests/test_tracing.py @@ -1,7 +1,8 @@ import numpy as np import pytest from astropy.modeling import models - +from astropy.nddata import NDData +import astropy.units as u from specreduce.utils.synth_data import make_2d_trace_image from specreduce.tracing import Trace, FlatTrace, ArrayTrace, FitTrace @@ -148,4 +149,20 @@ def test_fit_trace(): with pytest.raises(ValueError, match=r'image is fully masked'): FitTrace(img_all_nans) - # could try to catch warning thrown for all-nan bins + # test that warning is raised when several bins are masked + mask = np.zeros(img.shape) + mask[:, 100] = 1 + mask[:, 20] = 1 + mask[:, 30] = 1 + nddat = NDData(data=img, mask=mask, unit=u.DN) + msg = "All pixels in bins 20, 30, 100 are masked. Falling back on trace value from all-bin fit." + with pytest.warns(UserWarning, match=msg): + FitTrace(nddat) + + # and when many bins are masked + mask = np.zeros(img.shape) + mask[:, 0:21] = 1 + nddat = NDData(data=img, mask=mask, unit=u.DN) + msg = 'All pixels in bins 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ..., 20 are masked.' + with pytest.warns(UserWarning, match=msg): + FitTrace(nddat) diff --git a/specreduce/tracing.py b/specreduce/tracing.py index a508dd2..c1a9b6f 100644 --- a/specreduce/tracing.py +++ b/specreduce/tracing.py @@ -300,14 +300,14 @@ def __post_init__(self): self.bins + 1, dtype=int) y_bins = np.tile(np.nan, self.bins) + warn_bins = [] for i in range(self.bins): # repeat earlier steps to create gaussian fit for each bin z_i = img[ilum2, x_bins[i]:x_bins[i+1]].sum(axis=self._disp_axis) if not z_i.mask.all(): peak_y_i = ilum2[z_i.argmax()] else: - warnings.warn(f"All pixels in bin {i} are masked. Falling " - 'to trace value from all-bin fit.') + warn_bins.append(i) peak_y_i = peak_y if self.peak_method == 'gaussian': @@ -346,6 +346,15 @@ def __post_init__(self): # TODO: implement smoothing with provided width y_bins[i] = ilum2[z_i.argmax()] + # warn about fully-masked bins (which, currently, means any bin with a single masked value) + if len(warn_bins) > 0: + # if there are a ton of bins, we don't want to print them all out + if len(warn_bins) > 20: + warn_bins = warn_bins[0: 10] + ['...'] + [warn_bins[-1]] + warnings.warn(f"All pixels in {'bins' if len(warn_bins) else 'bin'} " + f"{', '.join([str(x) for x in warn_bins])}" + " are masked. Falling back on trace value from all-bin fit.") + # recenter bin positions x_bins = (x_bins[:-1] + x_bins[1:]) / 2