Skip to content

Commit

Permalink
add mask replacement of nan and negative mag errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sschmidt23 committed Jun 24, 2024
1 parent 8dac610 commit 683a763
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions src/rail/estimation/algos/gpz.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
import qp


# default magnitude errors, one per band, used to replace error values that are non-positive (<= 0) or np.nan
default_err_repl = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag):

def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag, repl_err_vals):
"""Put data in 2D np array expected by GPz.
For some reason they like to take the log of the magnitude errors, so
have that as a boolean option. Also replace nondetect vals for each
Expand All @@ -21,14 +24,17 @@ def _prepare_data(data_dict, bands, err_bands, nondet_val, maglims, logflag):
numbands = len(bands)
totrows = len(data_dict[bands[0]])
data = np.empty([totrows, 2 * numbands])
for i, (band, eband, lim) in enumerate(zip(bands, err_bands, maglims.values())):
for i, (band, eband, lim, rplval) in enumerate(zip(bands, err_bands, maglims.values(), repl_err_vals)):
data[:, i] = data_dict[band]
mask = np.bitwise_or(np.isclose(data_dict[band], nondet_val), np.isnan(data_dict[band]))
data[:, i][mask] = lim
errband = data_dict[eband]
emask = np.bitwise_or(errband <= 0., np.isnan(errband))
errband[emask] = rplval
if logflag:
data[:, numbands + i] = np.log(data_dict[eband])
data[:, numbands + i] = np.log(errband)
else: # pragma: no cover
data[:, numbands + i] = data_dict[eband]
data[:, numbands + i] = errband
data[:, numbands + i][mask] = 1.0
return data

Expand Down Expand Up @@ -63,7 +69,8 @@ class GPzInformer(CatInformer):
pca_decorrelate=Param(bool, True, msg="if True, decorrelate data using PCA as preprocessing stage"),
max_iter=Param(int, 200, msg="max number of iterations"),
max_attempt=Param(int, 100, msg="max iterations if no progress on validation"),
log_errors=Param(bool, True, msg="if true, take log of magnitude errors")
log_errors=Param(bool, True, msg="if true, take log of magnitude errors"),
replace_error_vals=Param(list, default_err_repl, msg="list of values to replace negative and nan mag err values")
)

def __init__(self, args, comm=None):
Expand All @@ -81,9 +88,14 @@ def run(self):
else: # pragma: no cover
training_data = self.get_data('input')

# check that lengths of bands, err_bands, and replace_error_vals match
if not np.logical_and(len(self.config.bands) == len(self.config.err_bands),
len(self.config.err_bands) == len(self.config.replace_error_vals)): # pragma: no cover
raise ValueError("lengths of bands, err_bands, and replace_error_vals do not match!")

input_array = _prepare_data(training_data, self.config.bands, self.config.err_bands,
self.config.nondetect_val, self.config.mag_limits,
self.config.log_errors)
self.config.log_errors, self.config.replace_error_vals)

sz = np.expand_dims(training_data[self.config.redshift_col], -1)
# need permutation mask to define training vs validation
Expand Down Expand Up @@ -128,19 +140,25 @@ class GPzEstimator(CatEstimator):
bands=SHARED_PARAMS,
err_bands=SHARED_PARAMS,
ref_band=SHARED_PARAMS,
log_errors=Param(bool, True, msg="if true, take log of magnitude errors"))
log_errors=Param(bool, True, msg="if true, take log of magnitude errors"),
replace_error_vals=Param(list, default_err_repl, msg="list of values to replace negative and nan mag err values")
)

def __init__(self, args, comm=None):
    """Do CatEstimator specific initialization and validate the band config.

    Parameters
    ----------
    args
        Passed straight through to ``CatEstimator.__init__``.
    comm
        Passed straight through to ``CatEstimator.__init__`` (default None).

    Raises
    ------
    ValueError
        If ``bands``, ``err_bands`` and ``replace_error_vals`` in the
        configuration do not all have the same length.
    """
    CatEstimator.__init__(self, args, comm=comm)
    self.zgrid = None
    # check that lengths of bands, err_bands, and replace_error_vals match;
    # a chained comparison is enough here since these are plain Python ints
    num_bands = len(self.config.bands)
    num_err_bands = len(self.config.err_bands)
    num_repl_vals = len(self.config.replace_error_vals)
    if not num_bands == num_err_bands == num_repl_vals:  # pragma: no cover
        raise ValueError("lengths of bands, err_bands, and replace_error_vals do not match!")

def _process_chunk(self, start, end, data, first):
print(f"Process {self.rank} estimating GPz PZ PDF for rows {start:,} - {end:,}")
test_array = _prepare_data(data, self.config.bands, self.config.err_bands,
self.config.nondetect_val, self.config.mag_limits,
self.config.log_errors)
self.config.log_errors, self.config.replace_error_vals)

mu, totalV, modelV, noiseV, _ = self.model.predict(test_array)
ens = qp.Ensemble(qp.stats.norm, data=dict(loc=mu, scale=totalV))
Expand Down

0 comments on commit 683a763

Please sign in to comment.