diff --git a/SNANA_FITS_to_pd.py b/SNANA_FITS_to_pd.py
index 69b593e..49a9759 100644
--- a/SNANA_FITS_to_pd.py
+++ b/SNANA_FITS_to_pd.py
@@ -54,9 +54,9 @@ def read_fits(fname,drop_separators=False):
         df_phot = df_phot[df_phot.MJD != -777.000]
 
     band_colname = "FLT" if "FLT" in df_phot.columns else "BAND" # check for filter column name from different versions of SNANA
-    df_header = df_header[["SNID", "SNTYPE", "PEAKMJD", "REDSHIFT_FINAL", "MWEBV"]]
+    df_header = df_header[["SNID", "SNTYPE", "PEAKMJD", "REDSHIFT_FINAL", "REDSHIFT_FINAL_ERR", "MWEBV"]]
     df_phot = df_phot[["SNID", "MJD", band_colname, "FLUXCAL", "FLUXCALERR"]]
-    df_header = df_header.rename(columns={"SNID":"object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "MWEBV": "mwebv"})
+    df_header = df_header.rename(columns={"SNID":"object_id", "SNTYPE": "true_target", "PEAKMJD": "true_peakmjd", "REDSHIFT_FINAL": "true_z", "REDSHIFT_FINAL_ERR": "true_z_err", "MWEBV": "mwebv"})
     df_header.replace({"true_target": {120: 42, 20: 42, 121: 42, 21: 42, 122: 42, 22: 42, 130: 62, 30: 62, 131: 62, 31: 62, 101: 90, 1: 90, 102: 52, 2: 52, 104: 64, 4: 64, 103: 95, 3: 95, 191: 67, 91: 67}}, inplace=True)
     df_phot = df_phot.rename(columns={"SNID":"object_id", "MJD": "mjd", band_colname: "passband", "FLUXCAL": "flux", "FLUXCALERR": "flux_err"})
@@ -97,6 +97,9 @@ def save_fits(df, fname):
         csv_metadata_path = os.path.join(args.output_dir, os.path.basename(path).replace("PHOT.FITS.gz", "HEAD.csv"))
         csv_lcdata_path = os.path.join(args.output_dir, os.path.basename(path).replace(".FITS.gz", ".csv"))
 
+        if os.path.exists(csv_metadata_path):
+            continue
+
         print(f"writing to {csv_metadata_path}")
         metadata, lcdata = read_fits(path)
         if metadata.empty:
             continue
diff --git a/config/config_example.yml b/config/config_example.yml
index 0217016..1c72c0d 100644
--- a/config/config_example.yml
+++ b/config/config_example.yml
@@ -39,8 +39,14 @@ sn_type_id_to_name:
   90: "SNIa"
   95: "SLSN-1"
 
-categorical: False
+### MODEL ARCHITECTURE / TRAINING PARAMS ###
+mode: "train" # or "predict", but have to specify path to trained model
+trained_model: "/path/to/trained/model" # only needs to be specified in predict mode
+class_balanced: True
+categorical: False # categorical vs. binary classification
 batch_size: 32
 num_epochs: 400
 train_proportion: 0.8
 val_proportion: 0.1
+has_ids: True
+with_z: False # classification with/without redshift
diff --git a/create_heatmaps/base.py b/create_heatmaps/base.py
index d63abcc..da8bf7c 100644
--- a/create_heatmaps/base.py
+++ b/create_heatmaps/base.py
@@ -54,7 +54,7 @@ def load_data(self):
         self.lcdata_ids = np.intersect1d(self.lcdata['object_id'], metadata_ids)
 
         if self.ids_path:
-            ids_file = h5py.File(IDS_PATH, "r")
+            ids_file = h5py.File(self.ids_path, "r")
             self.ids = [x.decode('utf-8') for x in ids_file["names"]]
             ids_file.close()
             print("job {}: found ids, expect {} total heatmaps".format(self.index, len(self.ids)), flush=True)
@@ -81,10 +81,12 @@ def create_heatmaps(self, output_paths, mjd_minmaxes, fit_on_full_lc=True):
             self.removed_by_type = {}
             self.done_ids = []
 
+            timings = []
             with tf.io.TFRecordWriter("{}/heatmaps_{}.tfrecord".format(output_path, self.index)) as writer:
                 for i, sn_id in enumerate(self.lcdata_ids):
                     if i % 1000 == 0:
                         print("job {}: processing {} of {}".format(self.index, i, len(self.lcdata_ids)), flush=True)
+                    start = time.time()
 
                     sn_data = self._get_sn_data(sn_id)
                     if not sn_data:
@@ -144,9 +146,15 @@ def create_heatmaps(self, output_paths, mjd_minmaxes, fit_on_full_lc=True):
                             break
                         self.type_to_int_label[sn_name] = 1 if sn_name == "SNIa" or sn_name == "Ia" else 0
 
-                    writer.write(image_example(heatmap.flatten().tobytes(), self.type_to_int_label[sn_name], sn_id))
+                    z = sn_metadata['true_z'].iloc[0]
+                    z_err = sn_metadata['true_z_err'].iloc[0]
+                    writer.write(image_example(heatmap.flatten().tobytes(), self.type_to_int_label[sn_name], sn_id, z, z_err))
+                    timings.append(time.time() - start)
+
                     self._done(sn_name, sn_id)
 
+            pd.DataFrame({"timings": timings}).to_csv(os.path.join(output_path, "timings.csv"), index=False)
+
             if not os.path.exists(self.finished_filenames_path):
                 pd.DataFrame({"filenames": [os.path.basename(self.metadata_path)]}).to_csv(self.finished_filenames_path, index=False)
             else:
@@ -167,13 +175,13 @@ def create_heatmaps(self, output_paths, mjd_minmaxes, fit_on_full_lc=True):
     def _get_sn_data(self, sn_id):
         sn_metadata = self.metadata[self.metadata.object_id == sn_id]
         if sn_metadata.empty:
-            return None, None
+            return None
 
         sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
 
         not_in_ids = self.ids and "{}_{}".format(sn_name, sn_id) not in self.ids
         already_done = sn_id in self.done_ids
         if not_in_ids or already_done:
-            return None, None
+            return None
 
         sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']
@@ -192,10 +200,10 @@ def _get_predictions_heatmap(self, gp, mjd_range, milkyway_ebv):
         wavelengths = np.linspace(3000.0, 10100.0, self.wavelength_bins)
         ext = get_extinction(milkyway_ebv, wavelengths)
         ext = np.tile(np.expand_dims(ext, axis=1), len(times))
-        time_wavelength_grid = np.transpose([np.tile(times, len(wavelengths)), np.repeat(wavelengths, len(times))])
+
         predictions, prediction_vars = gp(time_wavelength_grid, return_var=True)
-        ext_corrected_predictions = np.array(predictions).reshape(len(wavelengths), len(times)) + ext
-        prediction_uncertainties = np.sqrt(prediction_vars).reshape(len(wavelengths), len(times))
+        ext_corrected_predictions = np.array(predictions).reshape(32, 180) + ext
+        prediction_uncertainties = np.sqrt(prediction_vars).reshape(32, 180)
 
         return ext_corrected_predictions, prediction_uncertainties
diff --git a/create_heatmaps/heatmaps_types.py b/create_heatmaps/heatmaps_types.py
index c801048..07aacfe 100644
--- a/create_heatmaps/heatmaps_types.py
+++ b/create_heatmaps/heatmaps_types.py
@@ -1,5 +1,8 @@
 import numpy as np
+import os
 from create_heatmaps.base import CreateHeatmapsBase
+import json
+import pandas as pd
 
 class CreateHeatmapsFull(CreateHeatmapsBase):
     def run(self):
@@ -28,7 +31,6 @@ def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
     def _calculate_trigger(sn_metadata, sn_data):
         sn_data.sort("mjd")
         snrs_by_mjd = [[mjd, flux/flux_err] for mjd, flux, flux_err in sn_data.iterrows('mjd', 'flux', 'flux_err')]
-        # snrs_by_mjd = {sn_data.iloc[idx]['mjd']:sn_data.iloc[idx]['flux']/sn_data.iloc[idx]['flux_err'] for idx in range(len(sn_data))}
         detections = [[mjd,snr] for mjd, snr in snrs_by_mjd if snr > 5]
         if len(detections) < 2:
             return
@@ -43,20 +45,21 @@ class CreateHeatmapsEarlyMixed(CreateHeatmapsEarlyBase):
     def run(self):
-        self.create_heatmaps([self.output_path], [[-20, [5,15,25,50]]], fit_on_full_lc=False)
+        print("running early mixed")
+        self.create_heatmaps([self.output_path], [[-20, np.arange(0,51)]], fit_on_full_lc=False)
 
     @staticmethod
     def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
         mjd_min, mjd_max = mjd_minmax
-        trigger = self._calculate_trigger(sn_metadata, sn_data)
+        trigger = CreateHeatmapsEarlyMixed._calculate_trigger(sn_metadata, sn_data)
         if not trigger:
             return
         mjd_max = np.random.choice(mjd_max)
 
-        return [trigger_mjd+mjd_min, trigger_mjd+mjd_max]
+        return [trigger+mjd_min, trigger+mjd_max],mjd_max #TODO: change back to one return val 7/16
 
 class CreateHeatmapsEarly(CreateHeatmapsEarlyBase):
     def run(self):
-        days_after_trigger = [5,15,25,50]
+        days_after_trigger = [5]
         days_before_trigger = -20
         output_paths = [f"{self.output_path}/{days_before_trigger}x{i}_trigger_32x180" for i in days_after_trigger]
         mjd_ranges = [[days_before_trigger, i] for i in days_after_trigger]
@@ -66,12 +69,82 @@ def run(self):
     @staticmethod
     def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
         mjd_min, mjd_max = mjd_minmax
-        trigger = self._calculate_trigger(sn_metadata, sn_data)
+        trigger = CreateHeatmapsEarly._calculate_trigger(sn_metadata, sn_data)
         if not trigger:
             return
-        return [trigger-mjd_min, trigger+mjd_max]
+        return [trigger+mjd_min, trigger+mjd_max]
 
-class LastSNRById(CreateHeatmapsBase):
+class SaveTriggerToCSV(CreateHeatmapsEarlyBase):
+    def run(self):
+        OUTPUT_PATH = os.path.dirname(self.metadata_path)
+        print("writing to {}".format(OUTPUT_PATH))
+        self.metadata["1season_peakmjd"] = np.zeros(len(self.metadata))
+        self.metadata["3season_peakmjd"] = np.zeros(len(self.metadata))
+
+        for i, sn_id in enumerate(self.lcdata_ids):
+            if i % 1000 == 0:
+                print(f"processing {i} of {len(self.lcdata_ids)}")
+            sn_metadata = self.metadata[self.metadata.object_id == sn_id]
+            sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
+            sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']
+
+            sn_mjd = sorted(sn_lcdata['mjd'], reverse=True)
+            trigger = sn_metadata.trigger_mjd.values[0]
+            if np.isnan(trigger):
+                continue
+            sn_mjd_trigger_idx = np.where([round(mjd, 3) for mjd in sn_mjd] == round(trigger, 3))[0]
+            if len(sn_mjd_trigger_idx) == 0:
+                print(trigger)
+                print([round(mjd, 3) for mjd in sn_mjd])
+                break
+            sn_mjd_trigger_idx = sn_mjd_trigger_idx[0]
+
+            season_start_idx = -1
+            for i in range(sn_mjd_trigger_idx, len(sn_mjd)-1):
+                if i == 0:
+                    print(sn_mjd[i] - sn_mjd[i+1])
+                if sn_mjd[i] - sn_mjd[i+1] > 50:
+                    season_start_idx = i
+                    break
+            season_start = sn_mjd[season_start_idx]
+
+            sn_mjd = sorted(sn_lcdata['mjd'])
+            sn_mjd_trigger_idx = np.where([round(mjd, 3) for mjd in sn_mjd] == round(trigger, 3))[0][0]
+            season_end_idx = -1
+            for i in range(sn_mjd_trigger_idx, len(sn_mjd)-1):
+                if sn_mjd[i] - sn_mjd[i+1] < -100:
+                    season_end_idx = i
+                    break
+            season_end = sn_mjd[season_end_idx]
+
+            sn_data = sn_lcdata[np.logical_and(sn_lcdata["mjd"] >= season_start, sn_lcdata["mjd"] <= season_end)]
+            mjd = np.array(sn_data['mjd'])
+            flux = np.array(sn_data['flux'])
+            flux_err = np.array(sn_data['flux_err'])
+            snrs = flux**2 / flux_err**2
+            mask = snrs > 5
+            mjd = mjd[mask]
+            snrs = snrs[mask]
+            peak_mjd_oneseason = np.sum(mjd * snrs) / np.sum(snrs)
+            self.metadata.loc[self.metadata.object_id == sn_id, "1season_peakmjd"] = peak_mjd_oneseason
+
+            mjd = np.array(sn_lcdata['mjd'])
+            flux = np.array(sn_lcdata['flux'])
+            flux_err = np.array(sn_lcdata['flux_err'])
+            snrs = flux**2 / flux_err**2
+            mask = snrs > 5
+            mjd = mjd[mask]
+            snrs = snrs[mask]
+            if len(mjd) == 0 or len(snrs) == 0:
+                print(sn_id)
+            peak_mjd_calculated = np.sum(mjd * snrs) / np.sum(snrs)
+            self.metadata.loc[self.metadata.object_id == sn_id, "3season_peakmjd"] = peak_mjd_calculated
+        self.metadata.to_csv(os.path.join(OUTPUT_PATH, os.path.basename(self.metadata_path)), index=False)
+
+class MagById(CreateHeatmapsBase):
+    @staticmethod
+    def _calculate_mjd_range(sn_metadata, sn_data, mjd_minmax, has_peakmjd):
+        raise NotImplementedError
+
     def run(self):
         def _calculate_detections(sn_data):
             sn_data.sort("mjd")
@@ -95,27 +168,16 @@ def _calculate_trigger(sn_data):
             trigger_mjd = detections_mjd[0]
             return trigger_mjd
 
-        last_snr_by_id = {5: {}, 15: {}, 25: {}, 50: {}}
-        # num_detection_points_by_type = {5: {}, 15: {}, 25: {}, 50: {}}
-        metadata_path = self.config['metadata_paths'][self.index]
-        lcdata_path = self.config['lcdata_paths'][self.index]
-
-        metadata = pd.read_csv(metadata_path, compression="gzip") if os.path.splitext(metadata_path)[1] == ".gz" else pd.read_csv(metadata_path)
-        metadata_ids = metadata[metadata.true_target.isin(self.config["sn_type_id_to_name"].keys())].object_id
-
-        lcdata = pd.read_csv(lcdata_path, compression="gzip") if os.path.splitext(lcdata_path)[1] == ".gz" else pd.read_csv(lcdata_path)
-        lcdata = Table.from_pandas(lcdata)
-        lcdata.add_index('object_id')
-        lcdata_ids = np.intersect1d(lcdata['object_id'], metadata_ids)
-        for i, sn_id in enumerate(lcdata_ids):
+        mag_by_id = {0: [], 5: [], 15: []}
+        for i, sn_id in enumerate(self.lcdata_ids):
             if i % 1000 == 0:
-                print(f"processing {i} of {len(lcdata_ids)}")
-            sn_metadata = metadata[metadata.object_id == sn_id]
-            sn_name = self.config["sn_type_id_to_name"][sn_metadata.true_target.iloc[0]]
-            sn_lcdata = lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']
+                print(f"processing {i} of {len(self.lcdata_ids)}")
+            sn_metadata = self.metadata[self.metadata.object_id == sn_id]
+            sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
+            sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']
 
-            for mjdmax in [5,15,25,50]:
+            for mjdmax in mag_by_id.keys():
                 trigger_mjd = _calculate_trigger(sn_lcdata)
                 detections = _calculate_detections(sn_lcdata)
                 if not detections or not trigger_mjd:
@@ -127,14 +189,47 @@ def _calculate_trigger(sn_data):
                 if not mask.any(): # if all false
                     print("empty sn data after mjd mask", mjd_range, np.min(mjds), np.max(mjds))
                     continue
-                # sn_lcdata_included = sn_lcdata[mask]
+                sn_lcdata_included = sn_lcdata[mask]
+                sn_lcdata_r = sn_lcdata_included[sn_lcdata_included['passband'] == 1]
+                if len(sn_lcdata_r) == 0:
+                    continue
+                last_r_flux = sn_lcdata_r['flux'][-1]
+                last_r_mag = 27.5 - 2.5*np.log10(last_r_flux)
+                if last_r_mag <= 20:
+                    mag_by_id[int(mjdmax)].append(int(sn_id))
+
+        with open(os.path.join(self.output_path, f"mag_over_20_ids_{self.index}.json"), "w+") as outfile:
+            json.dump(mag_by_id, outfile)
 
-                detections_mjd, detections_snr = detections
-                mask = np.logical_and(detections_mjd >= mjd_range[0], detections_mjd <= mjd_range[1])
-                detections_included = np.array(detections_snr)[mask]
-                last_snr = detections_included[-1]
+class SaveFirstDetectionToCSV(CreateHeatmapsEarlyBase):
+    @staticmethod
+    def _calculate_first_detection(sn_metadata, sn_data):
+        sn_data.sort("mjd")
+        snrs_by_mjd = [[mjd, flux/flux_err] for mjd, flux, flux_err in sn_data.iterrows('mjd', 'flux', 'flux_err')]
+        detections = [[mjd,snr] for mjd, snr in snrs_by_mjd if snr > 5]
+        if len(detections) < 2:
+            return
+        first_detection_mjd = detections[0][0]
 
-                last_snr_by_id[mjdmax][int(sn_id)] = last_snr
+        return first_detection_mjd
 
-        with open(self.config["heatmaps_path"] + f"/last_snr_by_id_{self.index}.json", "w+") as outfile:
-            outfile.write(json.dumps(last_snr_by_id))
+    def run(self):
+        OUTPUT_PATH = os.path.dirname(self.metadata_path)
+        print("writing to {}".format(OUTPUT_PATH))
+
+        data = []
+        for i, sn_id in enumerate(self.lcdata_ids):
+            if i % 1000 == 0:
+                print(f"processing {i} of {len(self.lcdata_ids)}")
+            sn_metadata = self.metadata[self.metadata.object_id == sn_id]
+            sn_name = self.sn_type_id_map[sn_metadata.true_target.iloc[0]]
+            sn_lcdata = self.lcdata.loc['object_id', sn_id]['mjd', 'flux', 'flux_err', 'passband']
+
+            sn_mjd = sorted(sn_lcdata['mjd'], reverse=True)
+            trigger = sn_metadata.trigger_mjd.values[0]
+            first_detection = SaveFirstDetectionToCSV._calculate_first_detection(sn_metadata, sn_lcdata)
+            if np.isnan(trigger) or np.isnan(first_detection):
+                continue
+            data.append([sn_id, first_detection, trigger])
+
+        pd.DataFrame(data, columns=["snid", "first_detection_mjd", "trigger_mjd"]).to_csv(os.path.join(OUTPUT_PATH, os.path.basename(self.metadata_path)), index=False)
diff --git a/create_heatmaps/helpers.py b/create_heatmaps/helpers.py
index 6d4a19a..6dd56a5 100644
--- a/create_heatmaps/helpers.py
+++ b/create_heatmaps/helpers.py
@@ -51,7 +51,7 @@ def grad_neg_ln_like(p):
     return gaussian_process
 
-def image_example(image_string, label, id):
+def image_example(image_string, label, id, z, z_err):
     def _bytes_feature(value):
         """Returns a bytes_list from a string / byte."""
         if isinstance(value, type(tf.constant(0))):
@@ -62,9 +62,15 @@ def _int64_feature(value):
         """Returns an int64_list from a bool / enum / int / uint."""
         return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
 
+    def _float_feature(value):
+        """Returns a float_list from a float / double."""
+        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
+
     feature = {
         'id': _int64_feature(id),
         'label': _int64_feature(label),
+        'z': _float_feature(z),
+        'z_err': _float_feature(z_err),
         'image_raw': _bytes_feature(image_string),
     }
diff --git a/create_heatmaps/manager.py b/create_heatmaps/manager.py
index 43965d7..c441d2d 100644
--- a/create_heatmaps/manager.py
+++ b/create_heatmaps/manager.py
@@ -1,4 +1,4 @@
-from create_heatmaps.heatmaps_types import CreateHeatmapsFull, CreateHeatmapsEarlyMixed, CreateHeatmapsEarly
+from create_heatmaps.heatmaps_types import CreateHeatmapsFull, CreateHeatmapsEarlyMixed, CreateHeatmapsEarly, MagById, SaveFirstDetectionToCSV
 
 class CreateHeatmapsManager():
     def run(self, config, index):
diff --git a/create_heatmaps/run.py b/create_heatmaps/run.py
index f8da9d0..a15b763 100644
--- a/create_heatmaps/run.py
+++ b/create_heatmaps/run.py
@@ -9,7 +9,7 @@
 #SBATCH --qos=regular
 #SBATCH -N 1
 #SBATCH --cpus-per-task=32
-#SBATCH --time=40:00:00
+#SBATCH --time=20:00:00
 #SBATCH --output={log_path}
 
 export OMP_PROC_BIND=true
diff --git a/data_utils.py b/data_utils.py
index b07195f..faf1b20 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -7,28 +7,37 @@
 # - raw_record
 # - INPUT_SHAPE
 # - CATEGORICAL
-def get_images(raw_record, input_shape, categorical=False, has_ids=False):
+def get_images(raw_record, input_shape, has_ids=False, with_z=False):
     image_feature_description = {
         'label': tf.io.FixedLenFeature([], tf.int64),
         'image_raw': tf.io.FixedLenFeature([], tf.string),
     }
     if has_ids:
         image_feature_description['id'] = tf.io.FixedLenFeature([], tf.int64)
+    if with_z:
+        image_feature_description['z'] = tf.io.FixedLenFeature([], tf.float32)
+        image_feature_description['z_err'] = tf.io.FixedLenFeature([], tf.float32)
+
     example = tf.io.parse_single_example(raw_record, image_feature_description)
 
     image = tf.reshape(tf.io.decode_raw(example['image_raw'], tf.float64), input_shape)
     image = image / tf.reduce_max(image[:,:,0])
+    # TODO: have to subtract 1 from label to get rid of KN in early classification
+    if with_z:
+        output = [{"image": image, "z": example["z"], "z_err": example["z_err"]}, {"label": example['label']-1}]
+    else:
+        output = [{"image": image}, {"label": example['label']}]
 
     if has_ids:
-        return image, example['label'], tf.cast(example['id'], tf.int32)
-    return image, example['label']
+        output.append(tf.cast(example['id'], tf.int32))
+    return output
 
 # balances classes, splits dataset into train/validation/test sets
 # requires:
 # - dataset
 # - train_proportion
 def stratified_split(dataset, train_proportion, types, include_test_set, class_balance):
-    by_type_data_lists = {sn_type: dataset.filter(lambda image, label, *_: label == sn_type) for sn_type in types}
+    by_type_data_lists = {sn_type: dataset.filter(lambda image, label, *_: label["label"] == sn_type) for sn_type in types}
     by_type_data_lengths = {k: sum([1 for _ in v]) for k,v in by_type_data_lists.items()}
 
     print(f"number of samples per label: {by_type_data_lengths}")
diff --git a/model_utils.py b/model_utils.py
index eb762c1..7979212 100644
--- a/model_utils.py
+++ b/model_utils.py
@@ -37,6 +37,7 @@ def __init__(self, config):
         # self.num_types = len(np.unique(types))
         self.train_proportion = config.get('train_proportion', 0.8)
         self.has_ids = config.get('has_ids', False)
+        self.with_z = config.get('with_z', False)
         self.use_test_set = True if config["mode"] == "predict" else False
         self.external_trained_model = config.get("trained_model")
         self.abundances = None
@@ -139,41 +140,48 @@ def get_test_set(self):
     # - NUM_TYPES
     def _define_and_compile_model(self, metrics=['accuracy']):
         y, x, _ = self.input_shape
-
-        model = models.Sequential()
-        model.add(layers.ZeroPadding2D(padding=(0,1), input_shape=self.input_shape))
-        model.add(layers.Conv2D(y, (y, 3), activation='elu'))
-        model.add(self.Reshape())
-        model.add(layers.BatchNormalization())
-        model.add(layers.ZeroPadding2D(padding=(0,1)))
-        model.add(layers.Conv2D(y, (y, 3), activation='elu'))
-        model.add(self.Reshape())
-        model.add(layers.BatchNormalization())
-
-        model.add(layers.MaxPooling2D((2, 2)))
+        image_input = tf.keras.Input(shape=self.input_shape, name="image")
+        # z_input, z_err_input will only be used when doing classification with redshift
+        z_input = tf.keras.Input(shape=(1,), name="z")
+        z_err_input = tf.keras.Input(shape=(1,), name="z_err")
+        inputs = [image_input] if not self.with_z else [image_input, z_input, z_err_input]
 
-        model.add(layers.ZeroPadding2D(padding=(0,1)))
-        model.add(layers.Conv2D(int(y/2), (int(y/2), 3), activation='elu'))
-        model.add(self.Reshape())
-        model.add(layers.BatchNormalization())
-        model.add(layers.ZeroPadding2D(padding=(0,1)))
-        model.add(layers.Conv2D(int(y/2), (int(y/2), 3), activation='elu'))
-        model.add(self.Reshape())
-        model.add(layers.BatchNormalization())
-
-        model.add(layers.MaxPooling2D((2, 2)))
-
-        model.add(layers.Flatten())
-        model.add(layers.Dropout(0.5))
-        model.add(layers.Dense(32, activation='relu'))
-        model.add(layers.Dropout(0.3))
+        x = layers.ZeroPadding2D(padding=(0,1))(image_input)
+        x = layers.Conv2D(y, (y, 3), activation='elu')(x)
+        x = self.Reshape()(x)
+        x = layers.BatchNormalization()(x)
+        x = layers.ZeroPadding2D(padding=(0,1))(x)
+        x = layers.Conv2D(y, (y, 3), activation='elu')(x)
+        x = self.Reshape()(x)
+        x = layers.BatchNormalization()(x)
+
+        x = layers.MaxPooling2D((2, 2))(x)
+
+        x = layers.ZeroPadding2D(padding=(0,1))(x)
+        x = layers.Conv2D(int(y/2), (int(y/2), 3), activation='elu')(x)
+        x = self.Reshape()(x)
+        x = layers.BatchNormalization()(x)
+        x = layers.ZeroPadding2D(padding=(0,1))(x)
+        x = layers.Conv2D(int(y/2), (int(y/2), 3), activation='elu')(x)
+        x = self.Reshape()(x)
+        x = layers.BatchNormalization()(x)
+
+        x = layers.MaxPooling2D((2, 2))(x)
+
+        x = layers.Flatten()(x)
+        x = layers.Dropout(0.5)(x)
+        if self.with_z:
+            x = layers.concatenate([x, z_input, z_err_input])
+        x = layers.Dense(32, activation='relu')(x)
+        x = layers.Dropout(0.3)(x)
 
         if self.categorical:
-            model.add(layers.Dense(self.num_types, activation='softmax'))
+            sn_type_pred = layers.Dense(self.num_types, activation='softmax', name="label")(x)
         else:
-            model.add(layers.Dense(1, activation='sigmoid'))
+            sn_type_pred = layers.Dense(1, activation='sigmoid', name="label")(x)
+        model = Model(inputs=inputs, outputs=[sn_type_pred])
 
         opt = optimizers.Adam(learning_rate=1e-4)
         loss = 'sparse_categorical_crossentropy' if self.categorical else 'binary_crossentropy'
         print(metrics)
@@ -247,8 +255,8 @@ def _load_dataset(self):
         return raw_dataset
 
     def _retrieve_data(self, raw_dataset):
-        dataset = raw_dataset.map(lambda x: get_images(x, self.input_shape, self.categorical, self.has_ids), num_parallel_calls=40)
-        self.types = [0,1] if not self.categorical else np.unique([data[1] for data in dataset])
+        dataset = raw_dataset.map(lambda x: get_images(x, self.input_shape, self.has_ids, self.with_z), num_parallel_calls=40)
+        self.types = [0,1] if not self.categorical else range(0, self.num_types)
 
         return dataset.apply(tf.data.experimental.ignore_errors())
@@ -268,6 +276,7 @@ def _split_and_retrieve_data(self):
         test_set = test_set.prefetch(tf.data.experimental.AUTOTUNE).cache() if self.use_test_set else None
 
         if self.has_ids:
+            # TODO: maybe don't actually extract the ids - too time-consuming, not that useful
            return extract_ids_and_batch(train_set, val_set, test_set, self.batch_size)
         else:
             train_set = train_set.batch(self.batch_size)
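
Note (not part of the patch): a minimal sketch, assuming TensorFlow, of how the redshift-aware records written by image_example(..., z, z_err) line up with the renamed model inputs when with_z is enabled. The feature keys and layer names ("image", "z", "z_err", "label") come from this diff; the heatmap shape and file name below are illustrative placeholders, and the "label - 1" adjustment flagged by the TODO in get_images is omitted here.

import tensorflow as tf

HEATMAP_SHAPE = (32, 180, 2)  # assumed heatmap dimensions, for illustration only

feature_spec = {
    'id': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'z': tf.io.FixedLenFeature([], tf.float32),
    'z_err': tf.io.FixedLenFeature([], tf.float32),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def parse_with_z(raw_record):
    # Mirrors get_images(..., with_z=True): the dict keys must match the
    # tf.keras.Input names ("image", "z", "z_err") and the output layer name ("label")
    # so Keras can route each entry to the right tensor.
    example = tf.io.parse_single_example(raw_record, feature_spec)
    image = tf.reshape(tf.io.decode_raw(example['image_raw'], tf.float64), HEATMAP_SHAPE)
    image = image / tf.reduce_max(image[:, :, 0])
    return {"image": image, "z": example["z"], "z_err": example["z_err"]}, {"label": example["label"]}

# dataset = tf.data.TFRecordDataset("heatmaps_0.tfrecord").map(parse_with_z).batch(32)
# model.fit(dataset, ...)  # each dict entry is matched to the named input/output above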