diff --git a/CHANGES.rst b/CHANGES.rst index 160699a..7bcb599 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,7 +1,13 @@ 1.6.dev (unreleased) ================ -- none yet +- Define the ``input_units``, ``return_units``, ``input_units_equivalencies``, + and ``bounding_box`` properties for all models. Use Astropy models' built-in + unit conversion support. + +- All models now require inputs with valid units (wavelength, wavenumber, or + frequency). Dimensionless inputs are no longer automatically converted to + wavenumber. 1.5 (2024-08-16) ================ diff --git a/docs/dust_extinction/dev_model.rst b/docs/dust_extinction/dev_model.rst index 020f7ee..357365a 100644 --- a/docs/dust_extinction/dev_model.rst +++ b/docs/dust_extinction/dev_model.rst @@ -16,7 +16,7 @@ All All dust extinction models have at least the following: * A member variable `x_range` that that define the valid range of wavelengths. These are defined in inverse microns as is common for extinction curve research. -* A member function `evaluate` that computes the extinction at a given `x` and any model parameter values. The `x` values are checked to be within the valid `x_range`. The `x` values should have astropy.units. If they do not, then they are assumed to be in inverse microns and a warning is issued stating such. +* A member function `evaluate` that computes the extinction at a given `x` and any model parameter values. The `x` values are checked to be within the valid `x_range`. The `x` values should have astropy.units. All of these classes used in ``dust_extinction`` are based on the `Model `_ astropy.modeling class. diff --git a/dust_extinction/averages.py b/dust_extinction/averages.py index 0135b8c..0a99b23 100644 --- a/dust_extinction/averages.py +++ b/dust_extinction/averages.py @@ -5,7 +5,6 @@ from astropy.table import Table from astropy.modeling.models import PowerLaw1D -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range from .baseclasses import BaseExtModel from .shapes import P92, G21, _curve_F99_method @@ -92,17 +91,14 @@ class RL85_MWGC(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 1e-6 - def evaluate(self, in_x): + def evaluate(self, x): r""" RL85 MWGC function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] - - internally wavenumbers are used + x: float + expects either x in units of wavelengths, frequency, or wavenumber Returns ------- @@ -114,16 +110,11 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function using simple linear interpolation # avoids negative values of alav that happens with cubic splines f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class RRP89_MWGC(BaseExtModel): @@ -191,15 +182,14 @@ class RRP89_MWGC(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 1e-6 - def evaluate(self, in_x): + def evaluate(self, x): r""" RRP89 MWGC function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -213,16 +203,11 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function using simple linear interpolation # avoids negative values of alav that happens with cubic splines f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class B92_MWAvg(BaseExtModel): @@ -293,15 +278,14 @@ class B92_MWAvg(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-3 - def evaluate(self, in_x): + def evaluate(self, x): """ B92 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -315,16 +299,10 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.name) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class G03_SMCBar(BaseExtModel): @@ -409,15 +387,14 @@ class G03_SMCBar(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-2 - def evaluate(self, in_x): + def evaluate(self, x): """ G03 SMCBar function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -447,7 +424,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, self.Rv, C1, C2, @@ -540,15 +517,14 @@ class G03_LMCAvg(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-2 - def evaluate(self, in_x): + def evaluate(self, x): """ G03 LMCAvg function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -576,7 +552,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, self.Rv, C1, C2, @@ -672,15 +648,14 @@ class G03_LMC2(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 6e-2 - def evaluate(self, in_x): + def evaluate(self, x): """ G03 LMC2 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -708,7 +683,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, self.Rv, C1, C2, @@ -789,15 +764,14 @@ class I05_MWAvg(BaseExtModel): # accuracy of the observed data based on published table obsdata_tolerance = 1e-6 - def evaluate(self, in_x): + def evaluate(self, x): """ I05 MWAvg function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -811,15 +785,10 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class CT06_MWGC(BaseExtModel): @@ -891,13 +860,13 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ CT06 MWGC function Parameters ---------- - in_x: float + x: float expects either x in units of wavelengths or frequency or assumes wavelengths in wavenumbers [1/micron] @@ -913,15 +882,10 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class CT06_MWLoc(BaseExtModel): @@ -993,15 +957,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ CG06 MWLoc function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1015,15 +978,10 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class GCC09_MWAvg(BaseExtModel): @@ -1145,15 +1103,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ GCC09_MWAvg function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1167,11 +1124,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # P92 parameters fit to the data using uncs as weights p92_fit = P92( BKG_amp=203.805939127, @@ -1201,7 +1153,7 @@ def evaluate(self, in_x): ) # return A(x)/A(V) - return p92_fit(in_x) + return p92_fit(x) class F11_MWGC(BaseExtModel): @@ -1275,15 +1227,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ F11 MWGC function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1297,15 +1248,10 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation f = interp1d(self.obsdata_x, self.obsdata_axav) - return f(x) + return f(x.value) class G21_MWAvg(BaseExtModel): @@ -1402,15 +1348,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ G21_MWAvg function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1424,11 +1369,6 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # G21 parameters fit to the data using uncs as weights g21_fit = G21( scale=0.366, @@ -1445,7 +1385,7 @@ def evaluate(self, in_x): # return A(x)/A(V) # G21 a full dust_extinction model, hence send in x with units - return g21_fit(in_x) + return g21_fit(x) class D22_MWAvg(BaseExtModel): @@ -1526,15 +1466,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ D22_MWAvg function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1548,17 +1487,12 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # setup the model d22_fit = PowerLaw1D(alpha=1.71, amplitude=0.386, x_0=1.0) # return A(x)/A(V) # Note that model in D22 was done versus wavelength in microns - return d22_fit(1.0 / x) + return d22_fit(1.0 / x.value) class G24_SMCAvg(BaseExtModel): @@ -1643,15 +1577,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ G24 SMCAvg function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1677,7 +1610,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, self.Rv, C1, C2, @@ -1778,15 +1711,14 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - def evaluate(self, in_x): + def evaluate(self, x): """ G24 SMCBumps function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1812,7 +1744,7 @@ def evaluate(self, in_x): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, self.Rv, C1, C2, diff --git a/dust_extinction/baseclasses.py b/dust_extinction/baseclasses.py index 9b88630..04cca1d 100644 --- a/dust_extinction/baseclasses.py +++ b/dust_extinction/baseclasses.py @@ -2,8 +2,9 @@ from scipy.interpolate import interp1d from astropy.modeling import Model, Parameter, InputParameterError +from astropy import units as u -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range +from .helpers import _test_valid_x_range __all__ = ["BaseExtModel", "BaseExtRvModel", "BaseExtRvAfAModel", "BaseExtGrainModel"] @@ -15,6 +16,19 @@ class BaseExtModel(Model): n_inputs = 1 n_outputs = 1 + input_units = {"x": u.micron**-1} + return_units = {"y": u.dimensionless_unscaled} + input_units_equivalencies = {"x": u.spectral()} + _input_units_strict = True + + def bounding_box(self): + return self.x_range / self.input_units["x"] + + def prepare_inputs(self, *inputs, model_set_axis=None, equivalencies=None, **kwargs): + xs, *rest = super().prepare_inputs(*inputs, model_set_axis=model_set_axis, equivalencies=equivalencies, **kwargs) + for x in xs: + _test_valid_x_range(x.value, self.x_range, self.__class__.__name__) + return xs, *rest def extinguish(self, x, Av=None, Ebv=None): """ @@ -170,15 +184,14 @@ class BaseExtGrainModel(BaseExtModel): None """ - def evaluate(self, in_x): + def evaluate(self, x): """ Generic dust grain model function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -192,14 +205,9 @@ def evaluate(self, in_x): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # define the function allowing for spline interpolation # fill value needed to handle numerical issues at the edges # the x values has already been checked to be in range f = interp1d(self.data_x, self.data_axav, fill_value="extrapolate") - return f(x) + return f(x.value) diff --git a/dust_extinction/helpers.py b/dust_extinction/helpers.py index d42ff82..a61fcf1 100644 --- a/dust_extinction/helpers.py +++ b/dust_extinction/helpers.py @@ -1,44 +1,7 @@ -import warnings - import numpy as np from scipy.special import comb -import astropy.units as u - -__all__ = ["_get_x_in_wavenumbers", "_test_valid_x_range", "_smoothstep"] - - -def _get_x_in_wavenumbers(in_x): - """ - Convert input x to wavenumber given x has units. - Otherwise, assume x is in waveneumbers and issue a warning to this effect. - - Parameters - ---------- - in_x : astropy.quantity or simple floats - x values - - Returns - ------- - x : floats - input x values in wavenumbers w/o units - """ - # handles the case where x is a scaler - in_x = np.atleast_1d(in_x) - - # check if in_x is an astropy quantity, if not issue a warning - if not isinstance(in_x, u.Quantity): - warnings.warn( - "x has no units, assuming x units are inverse microns", UserWarning - ) - - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - with u.add_enabled_equivalencies(u.spectral()): - x_quant = u.Quantity(in_x, 1.0 / u.micron, dtype=np.float64) - # strip the quantity to avoid needing to add units to all the - # polynomical coefficients - return x_quant.value +__all__ = ["_smoothstep"] def _test_valid_x_range(x, x_range, outname): diff --git a/dust_extinction/parameter_averages.py b/dust_extinction/parameter_averages.py index 6b1023e..d9f00e2 100644 --- a/dust_extinction/parameter_averages.py +++ b/dust_extinction/parameter_averages.py @@ -8,7 +8,7 @@ from astropy.modeling.models import Drude1D, Polynomial1D, PowerLaw1D from .baseclasses import BaseExtRvModel, BaseExtRvAfAModel -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range, _smoothstep +from .helpers import _smoothstep from .averages import G03_SMCBar from .shapes import _curve_F99_method, _modified_drude, FM90 @@ -88,15 +88,14 @@ class CCM89(BaseExtRvModel): x_range = x_range_CCM89 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ CCM89 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -110,15 +109,12 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_CCM89, "CCM89") + shape = np.shape(x) + x = np.atleast_1d(x.value) # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges ir_indxs = np.where(np.logical_and(0.3 <= x, x < 1.1)) @@ -160,7 +156,7 @@ def evaluate(in_x, Rv): b[fuv_indxs] = np.polyval((0.374, -0.42, 4.257, 13.67), y) # return A(x)/A(V) - return a + b / Rv + return (a + b / Rv).reshape(shape) class O94(BaseExtRvModel): @@ -224,15 +220,14 @@ class O94(BaseExtRvModel): x_range = x_range_O94 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ O94 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -246,15 +241,12 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_O94, "O94") + shape = np.shape(x) + x = np.atleast_1d(x.value) # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges ir_indxs = np.where(np.logical_and(0.3 <= x, x < 1.1)) @@ -296,7 +288,7 @@ def evaluate(in_x, Rv): b[fuv_indxs] = np.polyval((0.374, -0.42, 4.257, 13.67), y) # return A(x)/A(V) - return a + b / Rv + return (a + b / Rv).reshape(shape) class F99(BaseExtRvModel): @@ -362,15 +354,14 @@ class F99(BaseExtRvModel): Rv_range = [2.0, 6.0] x_range = x_range_F99 - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ F99 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -430,7 +421,7 @@ def evaluate(self, in_x, Rv): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, Rv, C1, C2, @@ -511,15 +502,14 @@ class F04(BaseExtRvModel): Rv_range = [2.0, 6.0] x_range = x_range_F04 - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ F04 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -578,7 +568,7 @@ def evaluate(self, in_x, Rv): # return A(x)/A(V) return _curve_F99_method( - in_x, + x.value, Rv, C1, C2, @@ -657,16 +647,16 @@ class VCG04(BaseExtRvModel): x_range = x_range_VCG04 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ VCG04 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] - internally wavenumbers are used + x: float + expects either x in units of wavelengths, frequency, or wavenumber + + internally wavenumbers are used Returns ------- @@ -678,17 +668,12 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_VCG04, "VCG04") + shape = np.shape(x) + x = np.atleast_1d(x.value) # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges nuv_indxs = np.where(np.logical_and(3.3 <= x, x <= 8.0)) @@ -710,7 +695,7 @@ def evaluate(in_x, Rv): b[fnuv_indxs] += 0.2060 * (y**2) + 0.0550 * (y**3) # return A(x)/A(V) - return a + b / Rv + return (a + b / Rv).reshape(shape) class GCC09(BaseExtRvModel): @@ -777,15 +762,14 @@ class GCC09(BaseExtRvModel): x_range = x_range_GCC09 @staticmethod - def evaluate(in_x, Rv): + def evaluate(x, Rv): """ GCC09 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -799,17 +783,12 @@ def evaluate(in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_GCC09, "GCC09") + shape = np.shape(x) + x = np.atleast_1d(x.value) # setup the a & b coefficient vectors - n_x = len(x) - a = np.zeros(n_x) - b = np.zeros(n_x) + a = np.zeros(x.shape) + b = np.zeros(x.shape) # define the ranges nuv_indxs = np.where(np.logical_and(3.3 <= x, x <= 11.0)) @@ -831,7 +810,7 @@ def evaluate(in_x, Rv): b[fnuv_indxs] += 0.531 * (y**2) + 0.0544 * (y**3) # return A(x)/A(V) - return a + b / Rv + return (a + b / Rv).reshape(shape) class M14(BaseExtRvModel): @@ -914,15 +893,14 @@ class M14(BaseExtRvModel): Rv_range = [2.0, 6.0] x_range = x_range_M14 - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ M14 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -936,10 +914,7 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) + x = x.value # just in case someone calls evaluate explicitly Rv = np.atleast_1d(Rv) @@ -1157,15 +1132,14 @@ class G16(BaseExtRvAfAModel): x_range = x_range_G16 @staticmethod - def evaluate(in_x, RvA, fA): + def evaluate(x, RvA, fA): """ G16 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1179,10 +1153,7 @@ def evaluate(in_x, RvA, fA): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_G16, "G16") + x = x.value # just in case someone calls evaluate explicitly RvA = np.atleast_1d(RvA) @@ -1287,15 +1258,14 @@ def __init__(self, Rv=3.1, **kwargs): super().__init__(Rv, **kwargs) - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ F19 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1309,13 +1279,6 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) - # just in case someone calls evaluate explicitly Rv = np.atleast_1d(Rv) @@ -1323,7 +1286,7 @@ def evaluate(self, in_x, Rv): Rv = Rv[0] # use spline interpolation to evaluate the curve for the input x values - k_rV = interpolate.splev(x, self.spline_rep, der=0) + k_rV = interpolate.splev(x.value, self.spline_rep, der=0) # convert to A(x)/A(55) from E(x-55)/E(44-55) a_rV = k_rV / Rv + 1.0 @@ -1400,15 +1363,14 @@ def __init__(self, Rv=3.1, **kwargs): super().__init__(Rv, **kwargs) - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ D22 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1422,12 +1384,7 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - # convert to wavenumbers (1/micron) if x input in units - # otherwise, assume x in appropriate wavenumber units - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) + x = x.value # just in case someone calls evaluate explicitly Rv = np.atleast_1d(Rv) @@ -1499,15 +1456,14 @@ class G23(BaseExtRvModel): Rv_range = [2.3, 5.6] x_range = x_range_G23 - def evaluate(self, in_x, Rv): + def evaluate(self, x, Rv): """ G23 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -1521,15 +1477,12 @@ def evaluate(self, in_x, Rv): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, "G23") + shape = np.shape(x) + x = np.atleast_1d(x.value) # setup the a & b coefficient vectors - n_x = len(x) - self.a = np.zeros(n_x) - self.b = np.zeros(n_x) + self.a = np.zeros(x.shape) + self.b = np.zeros(x.shape) # define the ranges ir_indxs = np.where(np.logical_and(1.0 / 35.0 <= x, x < 1.0 / 1.0)) @@ -1615,7 +1568,7 @@ def evaluate(self, in_x, Rv): self.b[uvopt_overlap] += weights * m20_model_b(x[uvopt_overlap]) # return A(x)/A(V) - return self.a + self.b * (1 / Rv - 1 / 3.1) + return (self.a + self.b * (1 / Rv - 1 / 3.1)).reshape(shape) @staticmethod def nirmir_intercept(x, params): diff --git a/dust_extinction/shapes.py b/dust_extinction/shapes.py index 6a7a74c..a5f9b09 100644 --- a/dust_extinction/shapes.py +++ b/dust_extinction/shapes.py @@ -4,7 +4,7 @@ import astropy.units as u from astropy.modeling import Fittable1DModel, Parameter -from .helpers import _get_x_in_wavenumbers, _test_valid_x_range +from .baseclasses import BaseExtModel __all__ = ["FM90", "FM90_B3", "P92", "G21"] @@ -13,7 +13,7 @@ def _curve_F99_method( - in_x, + x, Rv, C1, C2, @@ -32,8 +32,7 @@ def _curve_F99_method( Parameters ---------- in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + assumes wavelengths in wavenumbers [1/micron] internally wavenumbers are used @@ -74,13 +73,10 @@ def _curve_F99_method( ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, valid_x_range, model_name) - # initialize extinction curve storage - axav = np.zeros(len(x)) + shape = np.shape(x) + x = np.atleast_1d(x) + axav = np.zeros(x.shape) # x value above which FM90 parametrization used x_cutval_uv = 10000.0 / 2700.0 @@ -132,7 +128,7 @@ def _curve_F99_method( axav[indxs_opir] = interpolate.splev(x[indxs_opir], spline_rep, der=0) # return A(x)/A(V) - return axav + return axav.reshape(shape) def _modified_drude(x, scale, x_o, gamma_o, asym): @@ -166,7 +162,15 @@ def _modified_drude(x, scale, x_o, gamma_o, asym): return y -class FM90(Fittable1DModel): +class BaseExtFittable1DModel(BaseExtModel, Fittable1DModel): + + def _parameter_units_for_data_units(self, inputs_unit, outputs_unit): + # Declare that all of the parameters are dimensionless, + # regardless of the input units. + return dict.fromkeys(self.param_names) + + +class FM90(BaseExtFittable1DModel): r""" Fitzpatrick & Massa (1990) 6 parameter ultraviolet shape model @@ -241,9 +245,6 @@ class FM90(Fittable1DModel): plt.show() """ - n_inputs = 1 - n_outputs = 1 - # bounds based on Gordon et al. (2024) results C1 = Parameter( description="linear term: y-intercept", default=0.10, bounds=(-10.0, 5.0) @@ -257,15 +258,14 @@ class FM90(Fittable1DModel): x_range = x_range_FM90 @staticmethod - def evaluate(in_x, C1, C2, C3, C4, xo, gamma): + def evaluate(x, C1, C2, C3, C4, xo, gamma): """ FM90 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -279,10 +279,8 @@ def evaluate(in_x, C1, C2, C3, C4, xo, gamma): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_FM90, "FM90") + shape = np.shape(x) + x = np.atleast_1d(x.value) # linear term exvebv = C1 + C2 * x @@ -298,14 +296,15 @@ def evaluate(in_x, C1, C2, C3, C4, xo, gamma): exvebv[fnuv_indxs] += C4 * (0.5392 * (y**2) + 0.05644 * (y**3)) # return E(x-V)/E(B-V) - return exvebv + return exvebv.reshape(shape) @staticmethod - def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma): + def fit_deriv(x, C1, C2, C3, C4, xo, gamma): """ Derivatives of the FM90 function with respect to the parameters """ - x = in_x + shape = np.shape(x) + x = np.atleast_1d(x.value) # useful quantitites x2 = x**2 @@ -315,7 +314,7 @@ def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma): denom = (x2mxo2_2 - x2 * g2) ** 2 # derivatives - d_C1 = np.full((len(x)), 1.0) + d_C1 = np.full(x.shape, 1.0) d_C2 = x d_C3 = x2 / (x2mxo2_2 + x2 * g2) @@ -324,13 +323,15 @@ def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma): d_gamma = (2.0 * C2 * (x2**2) * gamma) / denom - d_C4 = np.zeros((len(x))) + d_C4 = np.zeros(x.shape) fuv_indxs = np.where(x >= 5.9) if len(fuv_indxs) > 0: y = x[fuv_indxs] - 5.9 d_C4[fuv_indxs] = 0.5392 * (y**2) + 0.05644 * (y**3) - return [d_C1, d_C2, d_C3, d_C4, d_xo, d_gamma] + return [d_C1.reshape(shape), d_C2.reshape(shape), d_C3.reshape(shape), + d_C4.reshape(shape), d_xo.reshape(shape), + d_gamma.reshape(shape)] # @property # def input_units(self): @@ -346,7 +347,7 @@ def fit_deriv(in_x, C1, C2, C3, C4, xo, gamma): # 'C4': outputs_unit[self.outputs[0]]} -class FM90_B3(Fittable1DModel): +class FM90_B3(BaseExtFittable1DModel): r""" Fitzpatrick & Massa (1990) 6 parameter ultraviolet shape model Version with bump amplitude B3 = C3/gamma^2 @@ -422,9 +423,6 @@ class FM90_B3(Fittable1DModel): plt.show() """ - n_inputs = 1 - n_outputs = 1 - # bounds based on Gordon et al. (2024) results C1 = Parameter( description="linear term: y-intercept", default=0.10, bounds=(-10.0, 5.0) @@ -438,15 +436,14 @@ class FM90_B3(Fittable1DModel): x_range = x_range_FM90 @staticmethod - def evaluate(in_x, C1, C2, B3, C4, xo, gamma): + def evaluate(x, C1, C2, B3, C4, xo, gamma): """ FM90 function Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -460,10 +457,7 @@ def evaluate(in_x, C1, C2, B3, C4, xo, gamma): ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, x_range_FM90, "FM90_B3") + x = x.value # linear term exvebv = C1 + C2 * x @@ -482,7 +476,7 @@ def evaluate(in_x, C1, C2, B3, C4, xo, gamma): return exvebv -class P92(Fittable1DModel): +class P92(BaseExtFittable1DModel): r""" Pei (1992) 24 parameter shape model @@ -619,9 +613,6 @@ class P92(Fittable1DModel): plt.show() """ - n_inputs = 1 - n_outputs = 1 - # constant for conversion from Ax/Ab to (more standard) Ax/Av AbAv = 1.0 / 3.08 + 1.0 @@ -711,7 +702,7 @@ def _p92_single_term(in_lambda, amplitude, cen_wave, b, n): def evaluate( self, - in_x, + x, BKG_amp, BKG_lambda, BKG_b, @@ -742,9 +733,8 @@ def evaluate( Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber internally wavenumbers are used @@ -758,10 +748,7 @@ def evaluate( ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, self.__class__.__name__) + x = x.value # calculate the terms lam = 1.0 / x @@ -781,7 +768,7 @@ def evaluate( fit_deriv = None -class G21(Fittable1DModel): +class G21(BaseExtFittable1DModel): r""" Gordon et al. (2021) powerlaw plus two modified Drude profiles (for the 10 & 20 micron silicate features) @@ -894,7 +881,7 @@ class G21(Fittable1DModel): def evaluate( self, - in_x, + x, scale, alpha, sil1_amp, @@ -911,9 +898,8 @@ def evaluate( Parameters ---------- - in_x: float - expects either x in units of wavelengths or frequency - or assumes wavelengths in wavenumbers [1/micron] + x: float + expects either x in units of wavelengths, frequency, or wavenumber Returns ------- @@ -925,12 +911,8 @@ def evaluate( ValueError Input x values outside of defined range """ - x = _get_x_in_wavenumbers(in_x) - - # check that the wavenumbers are within the defined range - _test_valid_x_range(x, self.x_range, "G21") - wave = 1 / x + wave = 1 / x.value # powerlaw axav = scale * (wave ** (-1.0 * alpha)) diff --git a/dust_extinction/tests/test_fm90.py b/dust_extinction/tests/test_fm90.py index 30dd842..52165d3 100644 --- a/dust_extinction/tests/test_fm90.py +++ b/dust_extinction/tests/test_fm90.py @@ -75,7 +75,7 @@ def test_FM90_fitting(): fm90_init = FM90() fit = LevMarLSQFitter() - g03_fit = fit(fm90_init, x[gindxs], y[gindxs]) + g03_fit = fit(fm90_init, x[gindxs] / u.micron, y[gindxs]) fit_vals = [ g03_fit.C1.value, g03_fit.C2.value, diff --git a/dust_extinction/tests/test_g16.py b/dust_extinction/tests/test_g16.py index e95e455..82207d1 100644 --- a/dust_extinction/tests/test_g16.py +++ b/dust_extinction/tests/test_g16.py @@ -40,7 +40,7 @@ def test_extinction_G16_fA_0_values(): # get the correct values gmodel = G03_SMCBar() - x = gmodel.obsdata_x + x = gmodel.obsdata_x / u.micron cor_vals = gmodel.obsdata_axav tolerance = gmodel.obsdata_tolerance diff --git a/dust_extinction/tests/test_warnings.py b/dust_extinction/tests/test_warnings.py index c179aab..b589852 100644 --- a/dust_extinction/tests/test_warnings.py +++ b/dust_extinction/tests/test_warnings.py @@ -16,17 +16,6 @@ ) -@pytest.mark.parametrize("model", all_models) -def test_nounits_warning(model): - ext = model() - x = np.arange(ext.x_range[0], ext.x_range[1], 0.1) - - with pytest.warns( - UserWarning, match="x has no units, assuming x units are inverse microns" - ): - ext(x) - - @pytest.mark.skip("Testing for no warnings got more complicated/does not work") @pytest.mark.parametrize("model", all_models) def test_units_nowarning_expected(model): @@ -67,7 +56,6 @@ def test_invalid_wavenumbers(model): tmodel = model() x_invalid_all = [-1.0, 0.9 * tmodel.x_range[0], 1.1 * tmodel.x_range[1]] for x_invalid in x_invalid_all: - _invalid_x_range(x_invalid, tmodel, tmodel.__class__.__name__) _invalid_x_range(x_invalid / u.micron, tmodel, tmodel.__class__.__name__) _invalid_x_range(u.micron / x_invalid, tmodel, tmodel.__class__.__name__) _invalid_x_range( @@ -79,7 +67,7 @@ def test_invalid_wavenumbers(model): def test_extinguish_no_av_or_ebv(model): ext = model() with pytest.raises(InputParameterError) as exc: - ext.extinguish(ext.x_range[0]) + ext.extinguish(ext.x_range[0] / u.micron) assert exc.value.args[0] == "neither Av or Ebv passed, one required"