From 5ecf211947135e77e183207b615b0a68d3c204ad Mon Sep 17 00:00:00 2001 From: William D'Arcy Kenworthy Date: Wed, 3 Jul 2024 10:05:50 -0500 Subject: [PATCH] Accidentally removed outputdir from gradientdescent causing a crash Fixed an initialization bug which wrote intial x0 guess to x1 instead --- saltshaker/training/TrainSALT.py | 5 ++--- saltshaker/training/base.py | 14 +++++++------- saltshaker/training/optimizers/gradientdescent.py | 1 + 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/saltshaker/training/TrainSALT.py b/saltshaker/training/TrainSALT.py index 7e46d4fe..6d150e7d 100755 --- a/saltshaker/training/TrainSALT.py +++ b/saltshaker/training/TrainSALT.py @@ -336,7 +336,6 @@ def bestfit(p): from numpy.random import default_rng rng = default_rng(134912348) - for sn in datadict.keys(): if self.options.snparlist: # hacky matching, but SN names are a mess as usual @@ -354,8 +353,8 @@ def bestfit(p): else: guess[parlist==f'x{i}_{sn}'] = rng.standard_normal() if snpar['x0'][iSN]<= 0: - log.warning(f'Bad input value for {sn}: x0= {snpar["x0"][iSN]}') - guess[parlist==f'x{i}_{sn}'] = 10**(-0.4*(cosmo.distmod(datadict[sn].zHelio).value-19.36-10.635)) + log.warning(f'Bad input value for {sn}: x0={ float(snpar["x0"][iSN])}') + guess[parlist==f'x0_{sn}'] = 10**(-0.4*(cosmo.distmod(datadict[sn].zHelio).value-19.36-10.635)) guess[parlist == 'c0_%s'%sn] = snpar['c'][iSN] guess[parlist == 'c1_%s'%sn] = np.random.exponential(0.2) diff --git a/saltshaker/training/base.py b/saltshaker/training/base.py index 99ec000b..7d73617b 100644 --- a/saltshaker/training/base.py +++ b/saltshaker/training/base.py @@ -322,14 +322,14 @@ def mkcuts(self,datadict): return outdict,cutdict def filter_select(self,survey,flt): - select = True if flt in self.options.__dict__[f"{survey.split('(')[0]}_ignore_filters"].replace(' ','').split(','): - select = False + return False - lambdaeff = self.kcordict[survey][flt]['lambdaeff'] - if lambdaeff < self.options.filtercen_obs_waverange[0] or \ - lambdaeff > self.options.filtercen_obs_waverange[1] : - select = False + else: + lambdaeff = self.kcordict[survey][flt]['lambdaeff'] + if lambdaeff < self.options.filtercen_obs_waverange[0] or \ + lambdaeff > self.options.filtercen_obs_waverange[1] : + return False - return select + return True # end filter_select diff --git a/saltshaker/training/optimizers/gradientdescent.py b/saltshaker/training/optimizers/gradientdescent.py index 683a1697..bfdbecf5 100644 --- a/saltshaker/training/optimizers/gradientdescent.py +++ b/saltshaker/training/optimizers/gradientdescent.py @@ -47,6 +47,7 @@ def __init__(self,guess,saltresids,outputdir,options): self.saltobj=saltresids for x in self.configoptionnames: self.__dict__[x]=getattr(options,x) + self.outputdir=options.outputdir assert(0