added option to classify with redshift
helenqu committed Oct 26, 2021
1 parent b88fed7 commit b622c7c
Showing 9 changed files with 218 additions and 82 deletions.
7 changes: 5 additions & 2 deletions SNANA_FITS_to_pd.py
@@ -54,9 +54,9 @@ def read_fits(fname,drop_separators=False):
df_phot = df_phot[df_phot.MJD != -777.000]

band_colname = "FLT" if "FLT" in df_phot.columns else "BAND" # check for filter column name from different versions of SNANA
df_header = df_header[["SNID", "SNTYPE", "PEAKMJD", "REDSHIFT_FINAL", "MWEBV"]]
df_header = df_header[["SNID", "SNTYPE", "PEAKMJD", "REDSHIFT_FINAL", "REDSHIFT_FINAL_ERR", "MWEBV"]]
df_phot = df_phot[["SNID", "MJD", band_colname, "FLUXCAL", "FLUXCALERR"]]
df_header = df_header.rename(columns={"SNID":"object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "MWEBV": "mwebv"})
df_header = df_header.rename(columns={"SNID":"object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "REDSHIFT_FINAL_ERR": "true_z_err", "MWEBV": "mwebv"})
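# the integer remapping below collapses SNANA SNTYPE codes onto PLAsTiCC-style class ids: 42=SNII, 62=SNIbc, 52=SNIax, 64=KN, 90=SNIa, 95=SLSN-I, 67=SNIa-91bg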
df_header.replace({"true_target":
{120: 42, 20: 42, 121: 42, 21: 42, 122: 42, 22: 42, 130: 62, 30: 62, 131: 62, 31: 62, 101: 90, 1: 90, 102: 52, 2: 52, 104: 64, 4: 64, 103: 95, 3: 95, 191: 67, 91: 67}}, inplace=True)
df_phot = df_phot.rename(columns={"SNID":"object_id", "MJD": "mjd", band_colname: "passband", "FLUXCAL": "flux", "FLUXCALERR": "flux_err"})
@@ -97,6 +97,9 @@ def save_fits(df, fname):
csv_metadata_path = os.path.join(args.output_dir, os.path.basename(path).replace("PHOT.FITS.gz", "HEAD.csv"))
csv_lcdata_path = os.path.join(args.output_dir, os.path.basename(path).replace(".FITS.gz", ".csv"))

if os.path.exists(csv_metadata_path):
continue
print(f"writing to {csv_metadata_path}")
metadata, lcdata = read_fits(path)
if metadata.empty:
continue
8 changes: 7 additions & 1 deletion config/config_example.yml
@@ -39,8 +39,14 @@ sn_type_id_to_name:
90: "SNIa"
95: "SLSN-1"

categorical: False
### MODEL ARCHITECTURE / TRAINING PARAMS ###
mode: "train" # or "predict", but have to specify path to trained model
trained_model: "/path/to/trained/model" # only needs to be specified in predict mode
class_balanced: True
categorical: False # categorical vs. binary classification
batch_size: 32
num_epochs: 400
train_proportion: 0.8
val_proportion: 0.1
has_ids: True
with_z: False # classification with/without redshift
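A minimal sketch of how these keys might be consumed (assumes PyYAML; the load_config name and the branching are illustrative, not necessarily this repo's API):

import yaml

def load_config(path):
    with open(path) as f:
        return yaml.safe_load(f)

config = load_config("config/config_example.yml")
if config["mode"] == "predict" and not config.get("trained_model"):
    raise ValueError("predict mode requires a trained_model path")
# with_z toggles whether (z, z_err) are fed to the classifier alongside each heatmap
use_redshift = config.get("with_z", False)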
22 changes: 15 additions & 7 deletions create_heatmaps/base.py
@@ -54,7 +54,7 @@ def load_data(self):
self.lcdata_ids = np.intersect1d(self.lcdata['object_id'], metadata_ids)

if self.ids_path:
ids_file = h5py.File(IDS_PATH, "r")
ids_file = h5py.File(self.ids_path, "r")
self.ids = [x.decode('utf-8') for x in ids_file["names"]]
ids_file.close()
print("job {}: found ids, expect {} total heatmaps".format(self.index, len(self.ids)), flush=True)
@@ -81,10 +81,12 @@ def create_heatmaps(self, output_paths, mjd_minmaxes, fit_on_full_lc=True):
self.removed_by_type = {}
self.done_ids = []

timings = []
with tf.io.TFRecordWriter("{}/heatmaps_{}.tfrecord".format(output_path, self.index)) as writer:
for i, sn_id in enumerate(self.lcdata_ids):
if i % 1000 == 0:
print("job {}: processing {} of {}".format(self.index, i, len(self.lcdata_ids)), flush=True)
start = time.time()

sn_data = self._get_sn_data(sn_id)
if not sn_data:
@@ -144,9 +146,15 @@ def create_heatmaps(self, output_paths, mjd_minmaxes, fit_on_full_lc=True):
break
self.type_to_int_label[sn_name] = 1 if sn_name == "SNIa" or sn_name == "Ia" else 0

writer.write(image_example(heatmap.flatten().tobytes(), self.type_to_int_label[sn_name], sn_id))
z = sn_metadata['true_z'].iloc[0]
z_err = sn_metadata['true_z_err'].iloc[0]
writer.write(image_example(heatmap.flatten().tobytes(), self.type_to_int_label[sn_name], sn_id, z, z_err))
timings.append(time.time() - start)

self._done(sn_name, sn_id)

pd.DataFrame({"timings": timings}).to_csv(os.path.join(output_path, "timings.csv"), index=False)

if not os.path.exists(self.finished_filenames_path):
pd.DataFrame({"filenames": [os.path.basename(self.metadata_path)]}).to_csv(self.finished_filenames_path, index=False)
else:
@@ -167,13 +175,13 @@ def create_heatmaps(self, output_paths, mjd_minmaxes, fit_on_full_lc=True):
def _get_sn_data(self, sn_id):
sn_metadata = self.metadata[self.metadata.object_id == sn_id]
if sn_metadata.empty:
return None, None
return None

sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
not_in_ids = self.ids and "{}_{}".format(sn_name, sn_id) not in self.ids
already_done = sn_id in self.done_ids
if not_in_ids or already_done:
return None, None
return None

sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']

@@ -192,10 +200,10 @@ def _get_predictions_heatmap(self, gp, mjd_range, milkyway_ebv):
wavelengths = np.linspace(3000.0, 10100.0, self.wavelength_bins)
ext = get_extinction(milkyway_ebv, wavelengths)
ext = np.tile(np.expand_dims(ext, axis=1), len(times))

time_wavelength_grid = np.transpose([np.tile(times, len(wavelengths)), np.repeat(wavelengths, len(times))])

predictions, prediction_vars = gp(time_wavelength_grid, return_var=True)
ext_corrected_predictions = np.array(predictions).reshape(len(wavelengths), len(times)) + ext
prediction_uncertainties = np.sqrt(prediction_vars).reshape(len(wavelengths), len(times))
ext_corrected_predictions = np.array(predictions).reshape(32, 180) + ext
prediction_uncertainties = np.sqrt(prediction_vars).reshape(32, 180)

return ext_corrected_predictions, prediction_uncertainties
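The Cartesian time-wavelength grid above is built wavelength-major (np.tile repeats the full time vector once per wavelength, np.repeat holds each wavelength fixed across all times), which is why the flat GP predictions reshape to (wavelengths, times); the hardcoded reshape(32, 180) in the new lines assumes wavelength_bins == 32 and 180 time samples, matching the 32x180 heatmap size used elsewhere in this commit. A toy numpy illustration:

import numpy as np

times = np.array([0., 1., 2.])          # 3 time bins
wavelengths = np.array([4000., 8000.])  # 2 wavelength bins
grid = np.transpose([np.tile(times, len(wavelengths)),
                     np.repeat(wavelengths, len(times))])
# grid rows: (0,4000), (1,4000), (2,4000), (0,8000), (1,8000), (2,8000)
preds = np.arange(len(grid), dtype=float)  # stand-in for the gp(...) output
print(preds.reshape(len(wavelengths), len(times)))
# [[0. 1. 2.]
#  [3. 4. 5.]]  -> one row per wavelength, one column per time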
163 changes: 129 additions & 34 deletions create_heatmaps/heatmaps_types.py
@@ -1,5 +1,8 @@
import numpy as np
import os
from create_heatmaps.base import CreateHeatmapsBase
import json
import pandas as pd

class CreateHeatmapsFull(CreateHeatmapsBase):
def run(self):
@@ -28,7 +31,6 @@ def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
def _calculate_trigger(sn_metadata, sn_data):
sn_data.sort("mjd")
snrs_by_mjd = [[mjd, flux/flux_err] for mjd, flux, flux_err in sn_data.iterrows('mjd', 'flux', 'flux_err')]
# snrs_by_mjd = {sn_data.iloc[idx]['mjd']:sn_data.iloc[idx]['flux']/sn_data.iloc[idx]['flux_err'] for idx in range(len(sn_data))}
detections = [[mjd,snr] for mjd, snr in snrs_by_mjd if snr > 5]
if len(detections) < 2:
return
@@ -43,20 +45,21 @@ def _calculate_trigger(sn_metadata, sn_data):

class CreateHeatmapsEarlyMixed(CreateHeatmapsEarlyBase):
def run(self):
self.create_heatmaps([self.output_path], [[-20, [5,15,25,50]]], fit_on_full_lc=False)
print("running early mixed")
self.create_heatmaps([self.output_path], [[-20, np.arange(0,51)]], fit_on_full_lc=False)

@staticmethod
def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
mjd_min, mjd_max = mjd_minmax
trigger = self._calculate_trigger(sn_metadata, sn_data)
trigger = CreateHeatmapsEarlyMixed._calculate_trigger(sn_metadata, sn_data)
if not trigger:
return
mjd_max = np.random.choice(mjd_max)
return [trigger_mjd+mjd_min, trigger_mjd+mjd_max]
return [trigger+mjd_min, trigger+mjd_max], mjd_max # TODO: change back to one return val 7/16

class CreateHeatmapsEarly(CreateHeatmapsEarlyBase):
def run(self):
days_after_trigger = [5,15,25,50]
days_after_trigger = [5]
days_before_trigger = -20
output_paths = [f"{self.output_path}/{days_before_trigger}x{i}_trigger_32x180" for i in days_after_trigger]
mjd_ranges = [[days_before_trigger, i] for i in days_after_trigger]
@@ -66,12 +69,82 @@ def run(self):
@staticmethod
def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
mjd_min, mjd_max = mjd_minmax
trigger = self._calculate_trigger(sn_metadata, sn_data)
trigger = CreateHeatmapsEarly._calculate_trigger(sn_metadata, sn_data)
if not trigger:
return
return [trigger-mjd_min, trigger+mjd_max]
return [trigger+mjd_min, trigger+mjd_max]

class LastSNRById(CreateHeatmapsBase):
class SaveTriggerToCSV(CreateHeatmapsEarlyBase):
def run(self):
OUTPUT_PATH = os.path.dirname(self.metadata_path)
print("writing to {}".format(OUTPUT_PATH))
self.metadata["1season_peakmjd"] = np.zeros(len(self.metadata))
self.metadata["3season_peakmjd"] = np.zeros(len(self.metadata))

for i, sn_id in enumerate(self.lcdata_ids):
if i % 1000 == 0:
print(f"processing {i} of {len(self.lcdata_ids)}")
sn_metadata = self.metadata[self.metadata.object_id == sn_id]
sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']

sn_mjd = sorted(sn_lcdata['mjd'], reverse=True)
trigger = sn_metadata.trigger_mjd.values[0]
if np.isnan(trigger):
continue
sn_mjd_trigger_idx = np.where(np.array([round(mjd, 3) for mjd in sn_mjd]) == round(trigger, 3))[0]
if len(sn_mjd_trigger_idx) == 0:
print(trigger)
print([round(mjd, 3) for mjd in sn_mjd])
break
sn_mjd_trigger_idx = sn_mjd_trigger_idx[0]

season_start_idx = -1
for i in range(sn_mjd_trigger_idx, len(sn_mjd)-1):
if i == 0:
print(sn_mjd[i] - sn_mjd[i+1])
if sn_mjd[i] - sn_mjd[i+1] > 50:
season_start_idx = i
break
season_start = sn_mjd[season_start_idx]

sn_mjd = sorted(sn_lcdata['mjd'])
sn_mjd_trigger_idx = np.where(np.array([round(mjd, 3) for mjd in sn_mjd]) == round(trigger, 3))[0][0]
season_end_idx = -1
for i in range(sn_mjd_trigger_idx, len(sn_mjd)-1):
if sn_mjd[i] - sn_mjd[i+1] < -100:
season_end_idx = i
break
season_end = sn_mjd[season_end_idx]

sn_data = sn_lcdata[np.logical_and(sn_lcdata["mjd"] >= season_start, sn_lcdata["mjd"] <= season_end)]
mjd = np.array(sn_data['mjd'])
flux = np.array(sn_data['flux'])
flux_err = np.array(sn_data['flux_err'])
snrs = flux**2 / flux_err**2
mask = snrs > 5
mjd = mjd[mask]
snrs = snrs[mask]
peak_mjd_oneseason = np.sum(mjd * snrs) / np.sum(snrs)
self.metadata.loc[self.metadata.object_id == sn_id, "1season_peakmjd"] = peak_mjd_oneseason

mjd = np.array(sn_lcdata['mjd'])
flux = np.array(sn_lcdata['flux'])
flux_err = np.array(sn_lcdata['flux_err'])
snrs = flux**2 / flux_err**2
mask = snrs > 5
mjd = mjd[mask]
snrs = snrs[mask]
if len(mjd) == 0 or len(snrs) == 0:
print(sn_id)
peak_mjd_calculated = np.sum(mjd * snrs) / np.sum(snrs)
self.metadata.loc[self.metadata.object_id == sn_id, "3season_peakmjd"] = peak_mjd_calculated
self.metadata.to_csv(os.path.join(OUTPUT_PATH, os.path.basename(self.metadata_path)), index=False)

class MagById(CreateHeatmapsBase):
@staticmethod
def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
raise NotImplementedError
def run(self):
def _calculate_detections(sn_data):
sn_data.sort("mjd")
@@ -95,27 +168,16 @@ def _calculate_trigger(sn_data):
trigger_mjd = detections_mjd[0]

return trigger_mjd
last_snr_by_id = {5: {}, 15: {}, 25: {}, 50: {}}
# num_detection_points_by_type = {5: {}, 15: {}, 25: {}, 50: {}}
metadata_path = self.config['metadata_paths'][self.index]
lcdata_path = self.config['lcdata_paths'][self.index]

metadata = pd.read_csv(metadata_path, compression="gzip") if os.path.splitext(metadata_path)[1] == ".gz" else pd.read_csv(metadata_path)
metadata_ids = metadata[metadata.true_target.isin(self.config["sn_type_id_to_name"].keys())].object_id

lcdata = pd.read_csv(lcdata_path, compression="gzip") if os.path.splitext(lcdata_path)[1] == ".gz" else pd.read_csv(lcdata_path)
lcdata = Table.from_pandas(lcdata)
lcdata.add_index('object_id')
lcdata_ids = np.intersect1d(lcdata['object_id'], metadata_ids)

for i, sn_id in enumerate(lcdata_ids):
mag_by_id = {0: [], 5: [], 15: []}
for i, sn_id in enumerate(self.lcdata_ids):
if i % 1000 == 0:
print(f"processing {i} of {len(lcdata_ids)}")
sn_metadata = metadata[metadata.object_id == sn_id]
sn_name = self.config["sn_type_id_to_name"][sn_metadata.true_target.iloc[0]]
sn_lcdata = lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']
print(f"processing {i} of {len(self.lcdata_ids)}")
sn_metadata = self.metadata[self.metadata.object_id == sn_id]
sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']

for mjdmax in [5,15,25,50]:
for mjdmax in mag_by_id.keys():
trigger_mjd = _calculate_trigger(sn_lcdata)
detections = _calculate_detections(sn_lcdata)
if not detections or not trigger_mjd:
Expand All @@ -127,14 +189,47 @@ def _calculate_trigger(sn_data):
if not mask.any(): # if all false
print("empty sn data after mjd mask", mjd_range, np.min(mjds), np.max(mjds))
continue
# sn_lcdata_included = sn_lcdata[mask]
sn_lcdata_included = sn_lcdata[mask]
sn_lcdata_r = sn_lcdata_included[sn_lcdata_included['passband'] == 1]
if len(sn_lcdata_r) == 0:
continue
last_r_flux = sn_lcdata_r['flux'][-1]
last_r_mag = 27.5 - 2.5*np.log10(last_r_flux)
if last_r_mag <= 20:
mag_by_id[int(mjdmax)].append(int(sn_id))

with open(os.path.join(self.output_path, f"mag_over_20_ids_{self.index}.json"), "w+") as outfile:
json.dump(mag_by_id, outfile)

detections_mjd, detections_snr = detections
mask = np.logical_and(detections_mjd >= mjd_range[0], detections_mjd <= mjd_range[1])
detections_included = np.array(detections_snr)[mask]
last_snr = detections_included[-1]
last_snr_by_id[mjdmax][int(sn_id)] = last_snr

with open(self.config["heatmaps_path"] + f"/last_snr_by_id_{self.index}.json", "w+") as outfile:
    outfile.write(json.dumps(last_snr_by_id))

class SaveFirstDetectionToCSV(CreateHeatmapsEarlyBase):
    @staticmethod
    def _calculate_first_detection(sn_metadata, sn_data):
        sn_data.sort("mjd")
        snrs_by_mjd = [[mjd, flux/flux_err] for mjd, flux, flux_err in sn_data.iterrows('mjd', 'flux', 'flux_err')]
        detections = [[mjd, snr] for mjd, snr in snrs_by_mjd if snr > 5]
        if len(detections) < 2:
            return
        first_detection_mjd = detections[0][0]

        return first_detection_mjd
def run(self):
OUTPUT_PATH = os.path.dirname(self.metadata_path)
print("writing to {}".format(OUTPUT_PATH))

data = []
for i, sn_id in enumerate(self.lcdata_ids):
if i % 1000 == 0:
print(f"processing {i} of {len(self.lcdata_ids)}")
sn_metadata = self.metadata[self.metadata.object_id == sn_id]
sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']

sn_mjd = sorted(sn_lcdata['mjd'], reverse=True)
trigger = sn_metadata.trigger_mjd.values[0]
first_detection = SaveFirstDetectionToCSV._calculate_first_detection(sn_metadata, sn_lcdata)
if np.isnan(trigger) or np.isnan(first_detection):
continue
data.append([sn_id, first_detection, trigger])

pd.DataFrame(data, columns=["snid", "first_detection_mjd", "trigger_mjd"]).to_csv(os.path.join(OUTPUT_PATH, os.path.basename(self.metadata_path)), index=False)
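For reference on the magnitude cut in MagById above: with the zeropoint of 27.5 used there, last_r_mag <= 20 corresponds to FLUXCAL >= 10^((27.5 - 20)/2.5) = 10^3 = 1000.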
8 changes: 7 additions & 1 deletion create_heatmaps/helpers.py
@@ -51,7 +51,7 @@ def grad_neg_ln_like(p):

return gaussian_process

def image_example(image_string, label, id):
def image_example(image_string, label, id, z, z_err):
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
@@ -62,9 +62,15 @@ def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

feature = {
'id': _int64_feature(id),
'label': _int64_feature(label),
'z': _float_feature(z),
'z_err': _float_feature(z_err),
'image_raw': _bytes_feature(image_string),
}

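With z and z_err added, the TFRecord schema changes, so any reader has to include the new features in its parse spec. A minimal decoding sketch (feature names and dtypes follow the writer above; the raw-image dtype depends on what was flattened at write time and is an assumption here):

import tensorflow as tf

feature_description = {
    'id': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'z': tf.io.FixedLenFeature([], tf.float32),
    'z_err': tf.io.FixedLenFeature([], tf.float32),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def parse_example(serialized):
    example = tf.io.parse_single_example(serialized, feature_description)
    heatmap = tf.io.decode_raw(example['image_raw'], tf.float64)  # assumes float64 heatmaps
    return heatmap, example['z'], example['z_err'], example['label']

dataset = tf.data.TFRecordDataset("heatmaps_0.tfrecord").map(parse_example)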
2 changes: 1 addition & 1 deletion create_heatmaps/manager.py
@@ -1,4 +1,4 @@
from create_heatmaps.heatmaps_types import CreateHeatmapsFull, CreateHeatmapsEarlyMixed, CreateHeatmapsEarly
from create_heatmaps.heatmaps_types import CreateHeatmapsFull, CreateHeatmapsEarlyMixed, CreateHeatmapsEarly, MagById, SaveFirstDetectionToCSV

class CreateHeatmapsManager():
def run(self, config, index):
2 changes: 1 addition & 1 deletion create_heatmaps/run.py
@@ -9,7 +9,7 @@
#SBATCH --qos=regular
#SBATCH -N 1
#SBATCH --cpus-per-task=32
#SBATCH --time=40:00:00
#SBATCH --time=20:00:00
#SBATCH --output={log_path}
export OMP_PROC_BIND=true