From 89136ba3ffa0e510b61cce62f1a7b0f2698c31a0 Mon Sep 17 00:00:00 2001 From: Tyler Pauly Date: Wed, 29 Jan 2025 15:55:27 -0500 Subject: [PATCH] JP-3862: Apply code style to pixel_replace module (#9107) --- .pre-commit-config.yaml | 1 - .ruff.toml | 4 +- jwst/pixel_replace/__init__.py | 2 + jwst/pixel_replace/pixel_replace.py | 303 +++++++++++++---------- jwst/pixel_replace/pixel_replace_step.py | 80 +++--- 5 files changed, 225 insertions(+), 165 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 808abb072d..c357640595 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -75,7 +75,6 @@ repos: jwst/persistence/.* | jwst/photom/.* | jwst/pipeline/.* | - jwst/pixel_replace/.* | jwst/ramp_fitting/.* | jwst/refpix/.* | jwst/regtest/.* | diff --git a/.ruff.toml b/.ruff.toml index acd3554d12..bc0d3a6048 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -67,7 +67,7 @@ exclude = [ "jwst/persistence/**.py", "jwst/photom/**.py", "jwst/pipeline/**.py", - "jwst/pixel_replace/**.py", + # "jwst/pixel_replace/**.py", "jwst/ramp_fitting/**.py", "jwst/refpix/**.py", "jwst/regtest/**.py", @@ -194,7 +194,7 @@ ignore-fully-untyped = true # Turn of annotation checking for fully untyped cod "jwst/persistence/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/photom/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/pipeline/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] -"jwst/pixel_replace/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] +# "jwst/pixel_replace/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/ramp_fitting/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/refpix/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] "jwst/regtest/**.py" = ["D", "N", "A", "ARG", "B", "C4", "ICN", "INP", "ISC", "LOG", "NPY", "PGH", "PTH", "S", "SLF", "SLOT", "T20", "TRY", "UP", "YTT", "E501"] diff --git a/jwst/pixel_replace/__init__.py b/jwst/pixel_replace/__init__.py index 5ff05aeab1..b26b26cbb0 100644 --- a/jwst/pixel_replace/__init__.py +++ b/jwst/pixel_replace/__init__.py @@ -1,3 +1,5 @@ +"""Estimate missing pixel values in spectral data.""" + from .pixel_replace_step import PixelReplaceStep __all__ = ["PixelReplaceStep"] diff --git a/jwst/pixel_replace/pixel_replace.py b/jwst/pixel_replace/pixel_replace.py index 098fdeed9c..53748a9136 100644 --- a/jwst/pixel_replace/pixel_replace.py +++ b/jwst/pixel_replace/pixel_replace.py @@ -10,24 +10,24 @@ class PixelReplacement: - """Main class for performing pixel replacement. + """ + Main class for performing pixel replacement. This class controls loading the input data model, selecting the method for pixel replacement, and executing each step. This class should provide modularization to allow for multiple options and possible future reference files. - """ # Shortcuts for DQ Flags - DO_NOT_USE = datamodels.dqflags.pixel['DO_NOT_USE'] - FLUX_ESTIMATED = datamodels.dqflags.pixel['FLUX_ESTIMATED'] - NON_SCIENCE = datamodels.dqflags.pixel['NON_SCIENCE'] + DO_NOT_USE = datamodels.dqflags.pixel["DO_NOT_USE"] + FLUX_ESTIMATED = datamodels.dqflags.pixel["FLUX_ESTIMATED"] + NON_SCIENCE = datamodels.dqflags.pixel["NON_SCIENCE"] # Shortcuts for dispersion direction for ease of reading HORIZONTAL = 1 VERTICAL = 2 - LOG_SLICE = ['column', 'row'] + LOG_SLICE = ["column", "row"] def __init__(self, input_model, **pars): """ @@ -35,44 +35,48 @@ def __init__(self, input_model, **pars): Parameters ---------- - input_model : DataModel, str - list of data models as ModelContainer or ASN file, - one data model for each input image + input_model : datamodel, str + Datamodel or list of data models as ModelContainer + or ASN file, one datamodel for each input image - pars : dict, optional + **pars : dict, optional Optional parameters to modify how pixel replacement will execute. """ self.input = input_model - self.pars = dict() + self.pars = {} self.pars.update(pars) self.output = self.input.copy() # Store algorithm options here. self.algorithm_dict = { - 'fit_profile': self.fit_profile, - 'mingrad': self.mingrad, + "fit_profile": self.fit_profile, + "mingrad": self.mingrad, } # Choose algorithm from dict using input par. try: - self.algorithm = self.algorithm_dict[self.pars['algorithm']] + self.algorithm = self.algorithm_dict[self.pars["algorithm"]] - except KeyError: - log.critical(f"Algorithm name {self.pars['algorithm']} provided does " - "not match an implemented algorithm!") - raise Exception + except KeyError as err: + log.critical( + f"Algorithm name {self.pars['algorithm']} provided does " + "not match an implemented algorithm!" + ) + raise KeyError from err def replace(self): """ + Unpack model and apply pixel replacement algorithm. + Process the input DataModel, unpack any model that holds more than one 2D spectrum, then apply selected algorithm to each 2D spectrum in input. """ # ImageModel inputs (MIR_LRS-FIXEDSLIT) # or 2D SlitModel inputs (e.g. NRS_FIXEDSLIT in spec3) - if (isinstance(self.input, datamodels.ImageModel) - or (isinstance(self.input, datamodels.SlitModel) - and self.input.data.ndim == 2)): + if isinstance(self.input, datamodels.ImageModel) or ( + isinstance(self.input, datamodels.SlitModel) and self.input.data.ndim == 2 + ): self.output = self.algorithm(self.input) n_replaced = np.count_nonzero(self.output.dq & self.FLUX_ESTIMATED) log.info(f"Input model had {n_replaced} pixels replaced.") @@ -80,25 +84,27 @@ def replace(self): # Attempt to run pixel replacement on each throw of the IFU slicer # individually. xx, yy = np.indices(self.input.data.shape) - if self.input.meta.exposure.type == 'MIR_MRS': - if self.pars['algorithm'] == 'mingrad': + if self.input.meta.exposure.type == "MIR_MRS": + if self.pars["algorithm"] == "mingrad": # mingrad method new_model = self.algorithm(self.input) self.output = new_model else: # fit_profile method - _, beta_array, _ = self.input.meta.wcs.transform('detector', 'alpha_beta', yy, xx) + _, beta_array, _ = self.input.meta.wcs.transform( + "detector", "alpha_beta", yy, xx + ) unique_beta = np.unique(beta_array) unique_beta = unique_beta[~np.isnan(unique_beta)] for i, beta in enumerate(unique_beta): # Define a mask that is True where this trace is located - trace_mask = (beta_array == beta) + trace_mask = beta_array == beta trace_model = self.input.copy() trace_model.dq = np.where( # When not in this trace, set NON_SCIENCE and DO_NOT_USE ~trace_mask, trace_model.dq | self.DO_NOT_USE | self.NON_SCIENCE, - trace_model.dq + trace_model.dq, ) trace_model = self.algorithm(trace_model) @@ -106,48 +112,55 @@ def replace(self): # Where trace is located, set replaced values trace_mask, trace_model.data, - self.output.data + self.output.data, ) # do the same for dq, err, and var - self.output.dq = np.where( - trace_mask, trace_model.dq, self.output.dq) - self.output.err = np.where( - trace_mask, trace_model.err, self.output.err) + self.output.dq = np.where(trace_mask, trace_model.dq, self.output.dq) + self.output.err = np.where(trace_mask, trace_model.err, self.output.err) self.output.var_poisson = np.where( - trace_mask, trace_model.var_poisson, self.output.var_poisson) + trace_mask, trace_model.var_poisson, self.output.var_poisson + ) self.output.var_rnoise = np.where( - trace_mask, trace_model.var_rnoise, self.output.var_rnoise) + trace_mask, trace_model.var_rnoise, self.output.var_rnoise + ) self.output.var_flat = np.where( - trace_mask, trace_model.var_flat, self.output.var_flat) + trace_mask, trace_model.var_flat, self.output.var_flat + ) n_replaced = np.count_nonzero(trace_model.dq & self.FLUX_ESTIMATED) - log.info(f"Input MRS frame had {n_replaced} pixels replaced in IFU slice {i+1}.") + log.info( + f"Input MRS frame had {n_replaced} pixels replaced " + f"in IFU slice {i + 1}." + ) trace_model.close() n_replaced = np.count_nonzero(self.output.dq & self.FLUX_ESTIMATED) log.info(f"Input MRS frame had {n_replaced} total pixels replaced.") else: - if self.pars['algorithm'] == 'mingrad': + if self.pars["algorithm"] == "mingrad": # mingrad method new_model = self.algorithm(self.input) self.output = new_model else: # fit_profile method - iterate over IFU slices - wcsobj, tr1, tr2, tr3 = nirspec._get_transforms(self.input, np.arange(30)) + wcsobj, tr1, tr2, tr3 = nirspec._get_transforms( # noqa: SLF001 + self.input, np.arange(30) + ) for i in range(30): - slice_wcs = nirspec._nrs_wcs_set_input_lite(self.input, wcsobj, i, - [tr1, tr2[i], tr3[i]]) - _, _, wave = slice_wcs.transform('detector', 'slicer', yy, xx) + slice_wcs = nirspec._nrs_wcs_set_input_lite( # noqa: SLF001 + self.input, wcsobj, i, [tr1, tr2[i], tr3[i]] + ) + _, _, wave = slice_wcs.transform("detector", "slicer", yy, xx) # Define a mask that is True where this trace is located - trace_mask = (wave > 0) + trace_mask = wave > 0 trace_model = self.input.copy() trace_model.dq = np.where( # When not in this trace, set NON_SCIENCE and DO_NOT_USE ~trace_mask, trace_model.dq | self.DO_NOT_USE | self.NON_SCIENCE, - trace_model.dq + trace_model.dq, ) trace_model = self.algorithm(trace_model) @@ -155,23 +168,27 @@ def replace(self): # Where trace is located, set replaced values trace_mask, trace_model.data, - self.output.data + self.output.data, ) # do the same for dq, err, and var - self.output.dq = np.where( - trace_mask, trace_model.dq, self.output.dq) - self.output.err = np.where( - trace_mask, trace_model.err, self.output.err) + self.output.dq = np.where(trace_mask, trace_model.dq, self.output.dq) + self.output.err = np.where(trace_mask, trace_model.err, self.output.err) self.output.var_poisson = np.where( - trace_mask, trace_model.var_poisson, self.output.var_poisson) + trace_mask, trace_model.var_poisson, self.output.var_poisson + ) self.output.var_rnoise = np.where( - trace_mask, trace_model.var_rnoise, self.output.var_rnoise) + trace_mask, trace_model.var_rnoise, self.output.var_rnoise + ) self.output.var_flat = np.where( - trace_mask, trace_model.var_flat, self.output.var_flat) + trace_mask, trace_model.var_flat, self.output.var_flat + ) n_replaced = np.count_nonzero(trace_model.dq & self.FLUX_ESTIMATED) - log.info(f"Input NRS_IFU frame had {n_replaced} pixels replaced in IFU slice {i + 1}.") + log.info( + f"Input NRS_IFU frame had {n_replaced} pixels " + f"replaced in IFU slice {i + 1}." + ) trace_model.close() @@ -180,8 +197,7 @@ def replace(self): # MultiSlitModel inputs (WFSS, NRS_FIXEDSLIT, ?) elif isinstance(self.input, datamodels.MultiSlitModel): - - for i, slit in enumerate(self.input.slits): + for i, _slit in enumerate(self.input.slits): slit_model = datamodels.SlitModel(self.input.slits[i].instance) slit_replaced = self.algorithm(slit_model) @@ -193,10 +209,11 @@ def replace(self): # CubeModel inputs are TSO (so far?); SlitModel may be NRS_BRIGHTOBJ, # also requiring a re-packaging of the data into 2D inputs for the algorithm - elif isinstance(self.input, (datamodels.CubeModel, datamodels.SlitModel)): + elif isinstance(self.input, datamodels.CubeModel | datamodels.SlitModel): for i in range(len(self.input.data)): img_model = datamodels.ImageModel( - data=self.input.data[i], dq=self.input.dq[i], + data=self.input.data[i], + dq=self.input.dq[i], err=self.input.err[i], var_poisson=self.input.var_poisson[i], var_rnoise=self.input.var_rnoise[i], @@ -218,7 +235,9 @@ def replace(self): else: # This should never happen, as these should be caught in the step code. - log.critical("Pixel replacement code did not filter this input correctly - skipping step.") + log.critical( + "Pixel replacement code did not filter this input correctly - skipping step." + ) return def fit_profile(self, model): @@ -248,12 +267,11 @@ def fit_profile(self, model): DataModel with flagged bad pixels now flagged with FLUX_ESTIMATED and holding a flux value estimated from spatial profile, derived from adjacent columns. - """ # np.nanmedian() entry full of NaN values would produce a numpy # warning (despite well-defined behavior - return a NaN) # so we suppress that here. - warnings.filterwarnings(action='ignore', message='All-NaN slice encountered') + warnings.filterwarnings(action="ignore", message="All-NaN slice encountered") dispaxis = model.meta.wcsinfo.dispersion_direction @@ -262,8 +280,10 @@ def fit_profile(self, model): # Truncate array to region where good pixels exist good_pixels = np.where(~model.dq & self.DO_NOT_USE) if np.any(0 in np.shape(good_pixels)): - log.warning("No good pixels in at least one dimension of " - "data array - skipping pixel replacement.") + log.warning( + "No good pixels in at least one dimension of " + "data array - skipping pixel replacement." + ) return model x_range = [np.min(good_pixels[0]), np.max(good_pixels[0]) + 1] y_range = [np.min(good_pixels[1]), np.max(good_pixels[1]) + 1] @@ -286,13 +306,9 @@ def fit_profile(self, model): # but only iterate through slices with valid data. for ind in range(*valid_shape[2 - dispaxis]): # Exclude regions with no data for dq slice. - dq_slice = model.dq[self.custom_slice(dispaxis, ind)][profile_cut[0]: profile_cut[1]] + dq_slice = model.dq[self.custom_slice(dispaxis, ind)][profile_cut[0] : profile_cut[1]] # Exclude regions with NON_SCIENCE flag - dq_slice = np.where( - dq_slice & self.NON_SCIENCE, - self.NON_SCIENCE, - dq_slice - ) + dq_slice = np.where(dq_slice & self.NON_SCIENCE, self.NON_SCIENCE, dq_slice) # Find bad pixels in region containing valid data. n_bad = np.count_nonzero(dq_slice & self.DO_NOT_USE) n_nonscience = np.count_nonzero(dq_slice & self.NON_SCIENCE) @@ -307,10 +323,11 @@ def fit_profile(self, model): log.debug(f"Number of profiles with at least one bad pixel: {len(profiles_to_replace)}") - for i, ind in enumerate(profiles_to_replace): - + for ind in profiles_to_replace: # Use sets for convenient finding of neighboring slices to use in profile creation - adjacent_inds = set(range(ind - self.pars['n_adjacent_cols'], ind + self.pars['n_adjacent_cols'] + 1)) + adjacent_inds = set( + range(ind - self.pars["n_adjacent_cols"], ind + self.pars["n_adjacent_cols"] + 1) + ) adjacent_inds.discard(ind) valid_adjacent_inds = list(adjacent_inds.intersection(valid_profiles)) @@ -336,16 +353,16 @@ def fit_profile(self, model): profile_norm_scale = np.nanmax(np.abs(profile_data), axis=(dispaxis - 1), keepdims=True) # If profile data has SNR < 5 everywhere just use unity scaling # (so we don't normalize to noise) - if (np.nanmax(profile_snr) < 5): + if np.nanmax(profile_snr) < 5: profile_norm_scale[:] = 1.0 normalized = profile_data / profile_norm_scale # Get corresponding error and variance data and scale and mask to match # Handle the variance arrays as errors, so the scales match. - err_names = ['err', 'var_poisson', 'var_rnoise', 'var_flat'] + err_names = ["err", "var_poisson", "var_rnoise", "var_flat"] norm_errors = {} for err_name in err_names: - if err_name.startswith('var'): + if err_name.startswith("var"): err = np.sqrt(getattr(model, err_name)) else: err = getattr(model, err_name) @@ -361,21 +378,21 @@ def fit_profile(self, model): # Do the same for the errors for err_name in norm_errors: - norm_errors[err_name] = np.nanmedian( - norm_errors[err_name], axis=(2 - dispaxis)) + norm_errors[err_name] = np.nanmedian(norm_errors[err_name], axis=(2 - dispaxis)) # Clean current profile of values flagged as bad current_condition = self.custom_slice(dispaxis, ind) current_profile = model.data[current_condition] cleaned_current = np.where( - model.dq[current_condition] & self.DO_NOT_USE, - np.nan, - current_profile + model.dq[current_condition] & self.DO_NOT_USE, np.nan, current_profile )[range(*profile_cut)] replace_mask = np.where(~np.isnan(cleaned_current))[0] if len(replace_mask) == 0: - log.info(f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} has no valid values - skipping.") + log.info( + f"Profile in {self.LOG_SLICE[dispaxis - 1]} {ind} " + f"has no valid values - skipping." + ) continue min_median = median_profile[replace_mask] min_current = cleaned_current[replace_mask] @@ -385,12 +402,19 @@ def fit_profile(self, model): # Only do this scaling if we didn't default to all-unity scaling above, # and require input values below 1e20 so that we don't overflow the # minimization routine with extremely bad noise. - if ((np.nanmedian(profile_norm_scale) != 1.0) & (np.nanmax(np.abs(min_median)) < 1e20) - & (np.nanmax(np.abs(norm_current)) < 1e20)): + if ( + (np.nanmedian(profile_norm_scale) != 1.0) + & (np.nanmax(np.abs(min_median)) < 1e20) + & (np.nanmax(np.abs(norm_current)) < 1e20) + ): # TODO: check on signs here - absolute max sometimes picks up # large negative outliers - norm_scale = minimize(self.profile_mse, x0=np.abs(np.nanmax(norm_current)), - args=(np.abs(min_median), np.abs(norm_current)), method='Nelder-Mead').x + norm_scale = minimize( + self.profile_mse, + x0=np.abs(np.nanmax(norm_current)), + args=(np.abs(min_median), np.abs(norm_current)), + method="Nelder-Mead", + ).x scale = np.max(min_current) else: norm_scale = 1.0 @@ -398,19 +422,16 @@ def fit_profile(self, model): # Replace pixels that are do-not-use but not non-science current_dq = model.dq[current_condition][range(*profile_cut)] - replace_condition = (current_dq & self.DO_NOT_USE - ^ current_dq & self.NON_SCIENCE) == 1 + replace_condition = (current_dq & self.DO_NOT_USE ^ current_dq & self.NON_SCIENCE) == 1 replaced_current = np.where( - replace_condition, - median_profile * norm_scale * scale, - cleaned_current + replace_condition, median_profile * norm_scale * scale, cleaned_current ) # Change the dq bits where old flag was DO_NOT_USE and new value is not nan replaced_dq = np.where( replace_condition & ~(np.isnan(replaced_current)), current_dq ^ self.DO_NOT_USE ^ self.FLUX_ESTIMATED, - current_dq + current_dq, ) # Update data and DQ in the output model @@ -420,33 +441,29 @@ def fit_profile(self, model): # Also update the errors and variances current_err = model.err[current_condition][range(*profile_cut)] replaced_err = np.where( - replace_condition, - norm_errors['err'] * norm_scale * scale, - current_err + replace_condition, norm_errors["err"] * norm_scale * scale, current_err ) model_replaced.err[current_condition][range(*profile_cut)] = replaced_err current_var = model.var_poisson[current_condition][range(*profile_cut)] replaced_var = np.where( replace_condition, - (norm_errors['var_poisson'] * norm_scale * scale)**2, - current_var + (norm_errors["var_poisson"] * norm_scale * scale) ** 2, + current_var, ) model_replaced.var_poisson[current_condition][range(*profile_cut)] = replaced_var current_var = model.var_rnoise[current_condition][range(*profile_cut)] replaced_var = np.where( replace_condition, - (norm_errors['var_rnoise'] * norm_scale * scale)**2, - current_var + (norm_errors["var_rnoise"] * norm_scale * scale) ** 2, + current_var, ) model_replaced.var_rnoise[current_condition][range(*profile_cut)] = replaced_var current_var = model.var_flat[current_condition][range(*profile_cut)] replaced_var = np.where( - replace_condition, - (norm_errors['var_flat'] * norm_scale * scale)**2, - current_var + replace_condition, (norm_errors["var_flat"] * norm_scale * scale) ** 2, current_var ) model_replaced.var_flat[current_condition][range(*profile_cut)] = replaced_var @@ -484,12 +501,11 @@ def mingrad(self, model): DataModel with flagged bad pixels now flagged with FLUX_ESTIMATED and holding a flux value estimated from spatial profile, derived from adjacent columns. - """ # np.nanmedian() entry full of NaN values would produce a numpy # warning (despite well-defined behavior - return a NaN) # so we suppress that here. - warnings.filterwarnings(action='ignore', message='All-NaN slice encountered') + warnings.filterwarnings(action="ignore", message="All-NaN slice encountered") log.info("Using minimum gradient method.") @@ -515,39 +531,74 @@ def mingrad(self, model): # Make an array of x/y values on the detector (ysize, xsize) = indata.shape basex, basey = np.meshgrid(np.arange(xsize), np.arange(ysize)) - pad = 1 # Padding around edge of array to ensure we don't look for neighbors outside array + pad = 1 # Padding around edge of array to ensure we don't look for neighbors outside array # Find NaN-valued pixels - indx = np.where((~np.isfinite(indata)) - & (basey > pad) & (basey < ysize-pad) & (basex > pad) & (basex < xsize-pad)) + indx = np.where( + (~np.isfinite(indata)) + & (basey > pad) + & (basey < ysize - pad) + & (basex > pad) + & (basex < xsize - pad) + ) # X and Y indices yindx, xindx = indx[0], indx[1] # Loop over these NaN-valued pixels nreplaced = 0 for ii in range(0, len(xindx)): - left_data, right_data = indata[yindx[ii], xindx[ii] - 1], indata[yindx[ii], xindx[ii] + 1] - top_data, bottom_data = indata[yindx[ii] - 1, xindx[ii]], indata[yindx[ii] + 1, xindx[ii]] + left_data, right_data = ( + indata[yindx[ii], xindx[ii] - 1], + indata[yindx[ii], xindx[ii] + 1], + ) + top_data, bottom_data = ( + indata[yindx[ii] - 1, xindx[ii]], + indata[yindx[ii] + 1, xindx[ii]], + ) left_err, right_err = inerr[yindx[ii], xindx[ii] - 1], inerr[yindx[ii], xindx[ii] + 1] top_err, bottom_err = inerr[yindx[ii] - 1, xindx[ii]], inerr[yindx[ii] + 1, xindx[ii]] - left_var_p, right_var_p = in_var_p[yindx[ii], xindx[ii] - 1], in_var_p[yindx[ii], xindx[ii] + 1] - top_var_p, bottom_var_p = in_var_p[yindx[ii] - 1, xindx[ii]], in_var_p[yindx[ii] + 1, xindx[ii]] + left_var_p, right_var_p = ( + in_var_p[yindx[ii], xindx[ii] - 1], + in_var_p[yindx[ii], xindx[ii] + 1], + ) + top_var_p, bottom_var_p = ( + in_var_p[yindx[ii] - 1, xindx[ii]], + in_var_p[yindx[ii] + 1, xindx[ii]], + ) - left_var_r, right_var_r = in_var_r[yindx[ii], xindx[ii] - 1], in_var_r[yindx[ii], xindx[ii] + 1] - top_var_r, bottom_var_r = in_var_r[yindx[ii] - 1, xindx[ii]], in_var_r[yindx[ii] + 1, xindx[ii]] + left_var_r, right_var_r = ( + in_var_r[yindx[ii], xindx[ii] - 1], + in_var_r[yindx[ii], xindx[ii] + 1], + ) + top_var_r, bottom_var_r = ( + in_var_r[yindx[ii] - 1, xindx[ii]], + in_var_r[yindx[ii] + 1, xindx[ii]], + ) - left_var_f, right_var_f = in_var_f[yindx[ii], xindx[ii] - 1], in_var_f[yindx[ii], xindx[ii] + 1] - top_var_f, bottom_var_f = in_var_f[yindx[ii] - 1, xindx[ii]], in_var_f[yindx[ii] + 1, xindx[ii]] + left_var_f, right_var_f = ( + in_var_f[yindx[ii], xindx[ii] - 1], + in_var_f[yindx[ii], xindx[ii] + 1], + ) + top_var_f, bottom_var_f = ( + in_var_f[yindx[ii] - 1, xindx[ii]], + in_var_f[yindx[ii] + 1, xindx[ii]], + ) # Compute absolute difference (slope) and average value in each direction (may be NaN) diffs = np.array([np.abs(left_data - right_data), np.abs(top_data - bottom_data)]) - interp_data = np.array([(left_data + right_data) / 2., (top_data + bottom_data) / 2.]) - interp_err = np.array([(left_err + right_err) / 2., (top_err + bottom_err) / 2.]) - interp_var_p = np.array([(left_var_p + right_var_p) / 2., (top_var_p + bottom_var_p) / 2.]) - interp_var_r = np.array([(left_var_r + right_var_r) / 2., (top_var_r + bottom_var_r) / 2.]) - interp_var_f = np.array([(left_var_f + right_var_f) / 2., (top_var_f + bottom_var_f) / 2.]) + interp_data = np.array([(left_data + right_data) / 2.0, (top_data + bottom_data) / 2.0]) + interp_err = np.array([(left_err + right_err) / 2.0, (top_err + bottom_err) / 2.0]) + interp_var_p = np.array( + [(left_var_p + right_var_p) / 2.0, (top_var_p + bottom_var_p) / 2.0] + ) + interp_var_r = np.array( + [(left_var_r + right_var_r) / 2.0, (top_var_r + bottom_var_r) / 2.0] + ) + interp_var_f = np.array( + [(left_var_f + right_var_f) / 2.0, (top_var_f + bottom_var_f) / 2.0] + ) # Replace with the value from the lowest absolute slope estimator that was not NaN try: @@ -562,8 +613,9 @@ def mingrad(self, model): # If original pixel was in the science array, remove # the DO_NOT_USE flag - if ((indq[yindx[ii], xindx[ii]] & self.DO_NOT_USE) - and not (indq[yindx[ii], xindx[ii]] & self.NON_SCIENCE)): + if (indq[yindx[ii], xindx[ii]] & self.DO_NOT_USE) and not ( + indq[yindx[ii], xindx[ii]] & self.NON_SCIENCE + ): newdq[yindx[ii], xindx[ii]] -= self.DO_NOT_USE # Either way, add the FLUX_ESTIMATED flag @@ -582,8 +634,7 @@ def mingrad(self, model): def custom_slice(self, dispaxis, index): """ - Construct slice for ease of use with varying - dispersion axis. + Construct slice for ease of use with varying dispersion axis. Parameters ---------- @@ -605,10 +656,12 @@ def custom_slice(self, dispaxis, index): elif dispaxis == self.VERTICAL: return np.s_[index, :] else: - raise Exception + raise IndexError("Custom slice requires valid dispersion axis specification!") def profile_mse(self, scale, median, current): - """Function to feed optimization routine + """ + Calculate mean squared error of fitted profile. + Parameters ---------- scale : float @@ -626,6 +679,6 @@ def profile_mse(self, scale, median, current): float Mean squared error for minimization purposes """ - - return (np.nansum((current - (median * scale)) ** 2.) / - (len(median) - np.count_nonzero(np.isnan(current)))) + return np.nansum((current - (median * scale)) ** 2.0) / ( + len(median) - np.count_nonzero(np.isnan(current)) + ) diff --git a/jwst/pixel_replace/pixel_replace_step.py b/jwst/pixel_replace/pixel_replace_step.py index 2bfe166c0e..665dbcc5b9 100644 --- a/jwst/pixel_replace/pixel_replace_step.py +++ b/jwst/pixel_replace/pixel_replace_step.py @@ -9,8 +9,7 @@ class PixelReplaceStep(Step): """ - PixelReplaceStep: Module for replacing flagged bad pixels with an estimate - of their flux, prior to spectral extraction. + Module for replacing flagged bad pixels prior to spectral extraction. Attributes ---------- @@ -29,17 +28,21 @@ class PixelReplaceStep(Step): spec = """ algorithm = option("fit_profile", "mingrad", "N/A", default="fit_profile") - n_adjacent_cols = integer(default=3) # Number of adjacent columns to use in creation of profile + # Number of adjacent columns to use in profile creation + n_adjacent_cols = integer(default=3) skip = boolean(default=True) # Step must be turned on by parameter reference or user output_use_model = boolean(default=True) # Use input filenames in the output models """ - def process(self, input): - """Execute the step. + def process(self, input_data): + """ + Execute the step. Parameters ---------- - input : JWST DataModel + input_data : datamodel, str + The input datamodel or filename containing + spectral data in need of pixel replacement. Returns ------- @@ -48,27 +51,30 @@ def process(self, input): it will be a model containing data arrays with estimated fluxes for any bad pixels, now flagged as TO-BE-DETERMINED (DQ bit 7?). """ - with datamodels.open(input) as input_model: + with datamodels.open(input_data) as input_model: # If more than one 2d spectrum exists in input, call replacement - if isinstance(input_model, (datamodels.MultiSlitModel, - datamodels.SlitModel, - datamodels.ImageModel, - datamodels.IFUImageModel, - datamodels.CubeModel)): - self.log.debug(f'Input is a {input_model.meta.model_type}.') + if isinstance( + input_model, + datamodels.MultiSlitModel + | datamodels.SlitModel + | datamodels.ImageModel + | datamodels.IFUImageModel + | datamodels.CubeModel, + ): + self.log.debug(f"Input is a {input_model.meta.model_type}.") elif isinstance(input_model, datamodels.ModelContainer): - self.log.debug('Input is a ModelContainer.') + self.log.debug("Input is a ModelContainer.") else: - self.log.error(f'Input is of type {str(type(input_model))} for which') - self.log.error('pixel_replace does not have an algorithm.') - self.log.error('Pixel replacement will be skipped.') - input_model.meta.cal_step.pixel_replace = 'SKIPPED' + self.log.error(f"Input is of type {str(type(input_model))} for which") + self.log.error("pixel_replace does not have an algorithm.") + self.log.error("Pixel replacement will be skipped.") + input_model.meta.cal_step.pixel_replace = "SKIPPED" return input_model pars = { - 'algorithm': self.algorithm, - 'n_adjacent_cols': self.n_adjacent_cols, + "algorithm": self.algorithm, + "n_adjacent_cols": self.n_adjacent_cols, } # calwebb_spec3 case / ModelContainer @@ -83,27 +89,27 @@ def process(self, input): except (AttributeError, KeyError): pass if asn_id is None: - asn_id = self.search_attr('asn_id') + asn_id = self.search_attr("asn_id") if asn_id is not None: - _make_output_path = self.search_attr( - '_make_output_path', parent_first=True - ) - self._make_output_path = partial( - _make_output_path, - asn_id=asn_id - ) + _make_output_path = self.search_attr("_make_output_path", parent_first=True) + self._make_output_path = partial(_make_output_path, asn_id=asn_id) # Check models to confirm they are the correct type for i, model in enumerate(output_model): run_pixel_replace = True - if model.meta.model_type in ['MultiSlitModel', 'SlitModel', - 'ImageModel', 'IFUImageModel', 'CubeModel']: - self.log.debug('Input is a {model.meta.model_type}.') + if model.meta.model_type in [ + "MultiSlitModel", + "SlitModel", + "ImageModel", + "IFUImageModel", + "CubeModel", + ]: + self.log.debug("Input is a {model.meta.model_type}.") else: - self.log.error(f'Input is of type {model.meta.model_type} for which') - self.log.error('pixel_replace does not have an algorithm.') - self.log.error('Pixel replacement will be skipped.') - model.meta.cal_step.pixel_replace = 'SKIPPED' + self.log.error(f"Input is of type {model.meta.model_type} for which") + self.log.error("pixel_replace does not have an algorithm.") + self.log.error("Pixel replacement will be skipped.") + model.meta.cal_step.pixel_replace = "SKIPPED" run_pixel_replace = False # all checks on model have passed. Now run pixel replacement @@ -111,7 +117,7 @@ def process(self, input): replacement = PixelReplacement(model, **pars) replacement.replace() output_model[i] = replacement.output - record_step_status(output_model[i], 'pixel_replace', success=True) + record_step_status(output_model[i], "pixel_replace", success=True) return output_model @@ -121,6 +127,6 @@ def process(self, input): result = input_model.copy() replacement = PixelReplacement(result, **pars) replacement.replace() - record_step_status(replacement.output, 'pixel_replace', success=True) + record_step_status(replacement.output, "pixel_replace", success=True) result = replacement.output return result