diff --git a/msaexp/utils.py b/msaexp/utils.py index 6265b4a..604637f 100644 --- a/msaexp/utils.py +++ b/msaexp/utils.py @@ -1822,7 +1822,12 @@ def drizzled_hdu_figure( lw=2, ) ap.step( - ptab["pfit"] / pmax, xpr, color="coral", where="pre", alpha=0.5, lw=1 + ptab["pfit"] / pmax, + xpr, + color="coral", + where="pre", + alpha=0.5, + lw=1, ) ap.fill_betweenx( xpr + 0.5, xpr * 0.0, ptab["pfit"] / pmax, color="coral", alpha=0.2 @@ -2614,6 +2619,231 @@ def get_prism_wave_bar_correction( return bar, is_wrapped +def slit_shutter_scale(slit): + """ + Pixel scale of the ``slit_frame`` coordinate + + Parameters + ---------- + slit : `~jwst.datamodels.SlitModel` + Slitlet data object + + Returns + ------- + pix_scale : float + delta ``slit_frame`` / delta pixel. + + """ + sh = slit.data.shape + wcs = slit.meta.wcs + d2s = wcs.get_transform("detector", "slit_frame") + + x0 = d2s(sh[1] // 2, sh[0] // 2) + x1 = d2s(sh[1] // 2, sh[0] // 2 + 1) + + dx = np.array(x1) - np.array(x0) + pix_scale = np.sqrt(dx[0] ** 2 + dx[1] ** 2) + + return pix_scale + + +def get_slit_coordinates(slit, trace_with_ypos=False, **kwargs): + """ + Get wavelength and cross-dispersion coordinate arrays for a 2D slitlet + + Parameters + ---------- + slit : `~jwst.datamodels.SlitModel` + Slitlet data object + + trace_with_ypos : bool + Include source y position in trace calculation + + Returns + ------- + wave : array-like + Wavelengths, microns + + slit_frame_y : array-like + ``slit_frame`` cross dispersion coordinate + + yslit : array-like + Pixel offset relative to the trace center + """ + sh = slit.data.shape + yp, xp = np.indices(sh) + + _res = slit_trace_center( + slit, + with_source_xpos=False, + with_source_ypos=trace_with_ypos, + index_offset=0.0, + ) + + _xtr, _ytr, _wtr, slit_ra, slit_dec = _res + + xslit = xp + yslit = yp - _ytr + ypix = yp + + wcs = slit.meta.wcs + d2w = wcs.get_transform("detector", "world") + + _ypi, _xpi = np.indices(slit.data.shape) + _ras, _des, _wave = d2w(_xpi, _ypi) + + d2s = wcs.get_transform("detector", "slit_frame") + _sx, _sy, _slam = np.array(d2s(_xpi, _ypi)) + slit_frame_y = _sy + + return _wave, slit_frame_y, yslit + + +def get_slit_data(slit, wrap=True, **kwargs): + """ + Parse slit coordinates and attributes + + Parameters + ---------- + slit : `~jwst.datamodels.SlitModel` + Slitlet data object + + wrap : bool + Parameter for `msaexp.utils.get_prism_wave_bar_correction` + + Returns + ------- + data : dict + Slit data + + """ + slx = slice(slit.xstart - 1, slit.xstart - 1 + slit.xsize) + sly = slice(slit.ystart - 1, slit.ystart - 1 + slit.ysize) + + wave, slit_frame_y, yslit = get_slit_coordinates(slit, **kwargs) + sh = wave.shape + + shutter_scale = slit_shutter_scale(slit) + + dy = 0.0 + + shutter_y = (slit_frame_y / shutter_scale + dy) / 5.0 + + if slit.meta.exposure.type == "NRS_FIXEDSLIT": + bar = np.ones_like(wave).reshape(sh) + bar_wrapped = False + else: + bar, bar_wrapped = get_prism_wave_bar_correction( + shutter_y.flatten(), + wave.flatten(), + num_shutters=np.minimum(len(slit.shutter_state), 3), + wrap=wrap, + ) + bar = bar.reshape(sh) + + corr = slit.data * 1 / bar + msk = ~np.isfinite(corr + wave + bar) + + corr[msk] = 0 + wave[msk] = 0 + bar[msk] = 0 + + data = { + "wave": wave, + "slit_frame_y": slit_frame_y, + "shutter_y": shutter_y, + "yslit": yslit, + "shape": sh, + "corr": corr, + "bar": bar, + "slx": slx, + "sly": sly, + "shutter_state": slit.shutter_state, + "num_shutters": len(slit.shutter_state), + "bar_wrapped": bar_wrapped, + } + + return data + + +def fixed_slit_flat_field( + slit, apply=True, verbose=True, force=False, **kwargs +): + """ + Fixed slit cross-dispersion profile flat field + + Parameters + ---------- + slit : `~jwst.datamodels.SlitModel` + Slitlet data object + + apply : bool + Apply to ``slit.data`` attribute + + force : bool + Apply even if ``slit.flat_profile`` already found + + Returns + ------- + flat_profile : array-like + 2D flat-field profile + + """ + import yaml + + if slit.meta.exposure.type != "NRS_FIXEDSLIT": + return None + + profile_file = os.path.join( + os.path.dirname(__file__), + "data/extended_sensitivity/", + "fixed_slit_flat_profile_{0}.yaml".format(slit.name.lower()), + ) + + if not os.path.exists(profile_file): + msg = f"fixed_slit_flat_field: {profile_file} not found" + grizli.utils.log_comment(grizli.utils.LOGFILE, msg, verbose=verbose) + return None + + msg = f"fixed_slit_flat_field: {profile_file} (apply={apply})" + grizli.utils.log_comment(grizli.utils.LOGFILE, msg, verbose=verbose) + + with open(profile_file) as fp: + fs_data = yaml.load(fp, Loader=yaml.Loader) + + slit_data = get_slit_data(slit) + + bspl = grizli.utils.bspline_templates( + slit_data["yslit"].flatten(), + df=fs_data["ydf"], + minmax=fs_data["minmax"], + get_matrix=True, + ) + + coeffs = np.array(fs_data["coeffs"]) + flat_profile = bspl.dot(coeffs).reshape(slit.data.shape) + flat_profile[flat_profile < 0] = 0.0 + + if apply: + if (not hasattr(slit, "flat_profile")) | force: + if hasattr(slit, "flat_profile"): + slit.flat_profile *= flat_profile + else: + slit.flat_profile = flat_profile + + slit.data /= flat_profile + slit.err /= flat_profile + slit.var_rnoise /= flat_profile**2 + slit.var_poisson /= flat_profile**2 + else: + msg = f"fixed_slit_flat_field: existing flat_profile found" + grizli.utils.log_comment( + grizli.utils.LOGFILE, msg, verbose=verbose + ) + + return flat_profile + + def slit_extended_flux_calibration( slit, sens_file=None,