From afe6f6cf8f0b96106f3eec48f68a1d7f1f0aefb9 Mon Sep 17 00:00:00 2001 From: Greg Olmschenk Date: Sun, 24 Jan 2021 17:51:20 -0500 Subject: [PATCH] Committing the changes required to produce the results for the 2021 transit paper --- infer.py | 14 +++--- ramjet/analysis/transit_fitter.py | 2 +- ramjet/models/hades.py | 45 +++++++++++++++++++ ..._eclipsing_binary_lightcurve_collection.py | 16 ++++++- .../derived/tess_ffi_lightcurve_collection.py | 4 +- .../derived/tess_ffi_transit_databases.py | 21 ++++++--- train.py | 14 +++--- 7 files changed, 95 insertions(+), 21 deletions(-) diff --git a/infer.py b/infer.py index 34e698b4..bb1edeb9 100644 --- a/infer.py +++ b/infer.py @@ -3,25 +3,27 @@ import datetime from pathlib import Path -from ramjet.models.hades import Hades +from ramjet.models.hades import Hades, FfiHades +from ramjet.photometric_database.derived.tess_ffi_transit_databases import \ + TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase from ramjet.photometric_database.derived.tess_two_minute_cadence_transit_databases import \ TessTwoMinuteCadenceStandardAndInjectedTransitDatabase from ramjet.analysis.model_loader import get_latest_log_directory from ramjet.trial import infer -log_name = get_latest_log_directory(logs_directory='logs') # Uses the latest model in the log directory. -# log_name = 'logs/baseline YYYY-MM-DD-hh-mm-ss' # Specify the path to the model to use. +# log_name = get_latest_log_directory(logs_directory='logs') # Uses the latest model in the log directory. +log_name = 'logs/FFI transit sai aeb FfiHades mag14 quick pos no neg cont from existing no random start 2020-12-19-16-10-27' # Specify the path to the model to use. saved_log_directory = Path(f'{log_name}') datetime_string = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") print('Setting up dataset...', flush=True) -database = TessTwoMinuteCadenceStandardAndInjectedTransitDatabase() +database = TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase() inference_dataset = database.generate_inference_dataset() print('Loading model...', flush=True) -model = Hades(database.number_of_label_types) +model = FfiHades() model.load_weights(str(saved_log_directory.joinpath('model.ckpt'))).expect_partial() print('Inferring...', flush=True) infer_results_path = saved_log_directory.joinpath(f'infer results {datetime_string}.csv') -infer(model, inference_dataset, infer_results_path) +infer(model, inference_dataset, infer_results_path, number_of_top_predictions_to_keep=5000) diff --git a/ramjet/analysis/transit_fitter.py b/ramjet/analysis/transit_fitter.py index ac556843..dc8d4385 100644 --- a/ramjet/analysis/transit_fitter.py +++ b/ramjet/analysis/transit_fitter.py @@ -350,7 +350,7 @@ def round_value_to_significant_figures(value): if __name__ == '__main__': print('Opening Bokeh application on http://localhost:5006/') # Start the server. - server = Server({'/': TransitFitter(tic_id=362043085).bokeh_application}) + server = Server({'/': TransitFitter(tic_id=297678377).bokeh_application}) server.start() # Start the specific application on the server. server.io_loop.add_callback(server.show, "/") diff --git a/ramjet/models/hades.py b/ramjet/models/hades.py index 60439316..7f171479 100644 --- a/ramjet/models/hades.py +++ b/ramjet/models/hades.py @@ -48,3 +48,48 @@ def call(self, inputs, training=False, mask=None): x = self.prediction_layer(x, training=training) outputs = self.reshape(x, training=training) return outputs + + +class FfiHades(Model): + def __init__(self): + super().__init__() + self.block0 = LightCurveNetworkBlock(filters=8, kernel_size=3, pooling_size=2, batch_normalization=False, + dropout_rate=0) + self.block1 = LightCurveNetworkBlock(filters=8, kernel_size=3, pooling_size=2) + self.block2 = LightCurveNetworkBlock(filters=16, kernel_size=3, pooling_size=2) + self.block3 = LightCurveNetworkBlock(filters=32, kernel_size=3, pooling_size=2) + self.block4 = LightCurveNetworkBlock(filters=64, kernel_size=3, pooling_size=2) + self.block5 = LightCurveNetworkBlock(filters=128, kernel_size=3, pooling_size=2) + self.block6 = LightCurveNetworkBlock(filters=128, kernel_size=3, pooling_size=1) + self.block7 = LightCurveNetworkBlock(filters=128, kernel_size=3, pooling_size=1) + self.block8 = LightCurveNetworkBlock(filters=20, kernel_size=3, pooling_size=1, spatial=False) + self.block9 = LightCurveNetworkBlock(filters=20, kernel_size=7, pooling_size=1) + self.block10 = LightCurveNetworkBlock(filters=20, kernel_size=1, pooling_size=1, batch_normalization=False, + dropout_rate=0) + self.prediction_layer = Convolution1D(1, kernel_size=1, activation=sigmoid) + self.reshape = Reshape([1]) + + def call(self, inputs, training=False, mask=None): + """ + The forward pass of the layer. + + :param inputs: The input tensor. + :param training: A boolean specifying if the layer should be in training mode. + :param mask: A mask for the input tensor. + :return: The output tensor of the layer. + """ + x = inputs + x = self.block0(x, training=training) + x = self.block1(x, training=training) + x = self.block2(x, training=training) + x = self.block3(x, training=training) + x = self.block4(x, training=training) + x = self.block5(x, training=training) + x = self.block6(x, training=training) + x = self.block7(x, training=training) + x = self.block8(x, training=training) + x = self.block9(x, training=training) + x = self.block10(x, training=training) + x = self.prediction_layer(x, training=training) + outputs = self.reshape(x, training=training) + return outputs \ No newline at end of file diff --git a/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_lightcurve_collection.py b/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_lightcurve_collection.py index c29fc3a6..a31d94b9 100644 --- a/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_lightcurve_collection.py +++ b/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_lightcurve_collection.py @@ -1,7 +1,9 @@ """ Code representing the collection of TESS two minute cadence lightcurves containing eclipsing binaries. """ -from typing import Union, List +import pandas as pd +from pathlib import Path +from typing import Union, List, Iterable from peewee import Select @@ -56,3 +58,15 @@ def get_sql_query(self) -> Select: TessEclipsingBinaryMetadata.tic_id.not_in(transit_tic_id_query)) query = query.where(TessFfiLightcurveMetadata.tic_id.in_(eclipsing_binary_tic_id_query)) return query + + +class TessFfiQuickTransitNegativeLightcurveCollection(TessFfiLightcurveCollection): + def __init__(self, dataset_splits: Union[List[int], None] = None, + magnitude_range: (Union[float, None], Union[float, None]) = (None, None)): + super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range) + self.label = 0 + + def get_paths(self) -> Iterable[Path]: + data_frame = pd.read_csv('quick_negative_paths.csv') + paths = list(map(Path, data_frame['Lightcurve path'].values)) + return paths diff --git a/ramjet/photometric_database/derived/tess_ffi_lightcurve_collection.py b/ramjet/photometric_database/derived/tess_ffi_lightcurve_collection.py index 02b9c9e4..18ef033d 100644 --- a/ramjet/photometric_database/derived/tess_ffi_lightcurve_collection.py +++ b/ramjet/photometric_database/derived/tess_ffi_lightcurve_collection.py @@ -34,8 +34,8 @@ def get_sql_query(self) -> Select: :return: The SQL query. """ query = TessFfiLightcurveMetadata().select() - query = self.order_by_dataset_split_with_random_start(query, TessFfiLightcurveMetadata.dataset_split, - self.dataset_splits) + # query = self.order_by_dataset_split_with_random_start(query, TessFfiLightcurveMetadata.dataset_split, + # self.dataset_splits) if self.magnitude_range[0] is not None and self.magnitude_range[1] is not None: query = query.where(TessFfiLightcurveMetadata.magnitude.between(*self.magnitude_range)) elif self.magnitude_range[0] is not None: diff --git a/ramjet/photometric_database/derived/tess_ffi_transit_databases.py b/ramjet/photometric_database/derived/tess_ffi_transit_databases.py index 1f8bfcc1..a298d6c0 100644 --- a/ramjet/photometric_database/derived/tess_ffi_transit_databases.py +++ b/ramjet/photometric_database/derived/tess_ffi_transit_databases.py @@ -1,5 +1,5 @@ from ramjet.photometric_database.derived.tess_ffi_eclipsing_binary_lightcurve_collection import \ - TessFfiAntiEclipsingBinaryForTransitLightcurveCollection + TessFfiAntiEclipsingBinaryForTransitLightcurveCollection, TessFfiQuickTransitNegativeLightcurveCollection from ramjet.photometric_database.derived.tess_ffi_lightcurve_collection import TessFfiLightcurveCollection from ramjet.photometric_database.derived.tess_ffi_transit_lightcurve_collections import \ TessFfiConfirmedTransitLightcurveCollection, TessFfiNonTransitLightcurveCollection @@ -15,11 +15,11 @@ def __init__(self): super().__init__() self.batch_size = 1000 self.time_steps_per_example = 1000 - self.shuffle_buffer_size = 100000 + self.shuffle_buffer_size = 10000 self.out_of_bounds_injection_handling = OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION -magnitude_range = (0, 11) +magnitude_range = (0, 14) class TessFfiStandardTransitDatabase(TessFfiDatabase): @@ -106,11 +106,15 @@ class TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase(TessFfiDataba """ def __init__(self): super().__init__() + self.shuffle_buffer_size = 10000 + self.number_of_parallel_processes_per_map = 6 self.training_standard_lightcurve_collections = [ TessFfiConfirmedTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range), TessFfiNonTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range), TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)), - magnitude_range=magnitude_range) + magnitude_range=magnitude_range), + # TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)), + # magnitude_range=magnitude_range) ] self.training_injectee_lightcurve_collection = TessFfiNonTransitLightcurveCollection( dataset_splits=list(range(8)), magnitude_range=magnitude_range) @@ -118,11 +122,16 @@ def __init__(self): TessFfiConfirmedTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range), TessFfiNonTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range), TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)), - magnitude_range=magnitude_range) + magnitude_range=magnitude_range), + # TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)), + # magnitude_range=magnitude_range) ] self.validation_standard_lightcurve_collections = [ TessFfiConfirmedTransitLightcurveCollection(dataset_splits=[8], magnitude_range=magnitude_range), TessFfiNonTransitLightcurveCollection(dataset_splits=[8], magnitude_range=magnitude_range), TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)), - magnitude_range=magnitude_range) + magnitude_range=magnitude_range), + # TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)), + # magnitude_range=magnitude_range) ] + self.inference_lightcurve_collections = [TessFfiLightcurveCollection(magnitude_range=magnitude_range)] diff --git a/train.py b/train.py index 11943218..520265db 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,10 @@ from tensorflow.python.keras import callbacks from tensorflow.python.keras.losses import BinaryCrossentropy -from ramjet.models.hades import Hades +from ramjet.basic_models import SimplePoolingLightcurveCnn2, FfiSimplePoolingLightcurveCnn2 +from ramjet.models.hades import Hades, FfiHades +from ramjet.photometric_database.derived.tess_ffi_transit_databases import \ + TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase from ramjet.photometric_database.derived.tess_two_minute_cadence_transit_databases import \ TessTwoMinuteCadenceStandardAndInjectedTransitDatabase @@ -14,9 +17,9 @@ def train(): """Runs the training.""" print('Starting training process...', flush=True) # Basic training settings. - trial_name = f'baseline' # Add any desired run name details to this string. - database = TessTwoMinuteCadenceStandardAndInjectedTransitDatabase() - model = Hades(database.number_of_label_types) + trial_name = f'FFI transit sai aeb FfiHades mag14 quick pos no neg cont from existing no random start' # Add any desired run name details to this string. + model = FfiHades() + database = TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase() # database.batch_size = 100 # Reducing the batch size may help if you are running out of memory. epochs_to_run = 1000 logs_directory = 'logs' @@ -33,7 +36,7 @@ def train(): training_dataset, validation_dataset = database.generate_datasets() optimizer = tf.optimizers.Adam(learning_rate=1e-4) loss_metric = BinaryCrossentropy(name='Loss') - metrics = [tf.keras.metrics.AUC(num_thresholds=20, name='Area_under_ROC_curve', multi_label=True), + metrics = [tf.keras.metrics.AUC(num_thresholds=20, name='Area_under_ROC_curve'), tf.metrics.SpecificityAtSensitivity(0.9, name='Specificity_at_90_percent_sensitivity'), tf.metrics.SensitivityAtSpecificity(0.9, name='Sensitivity_at_90_percent_specificity'), tf.metrics.BinaryAccuracy(name='Accuracy'), tf.metrics.Precision(name='Precision'), @@ -41,6 +44,7 @@ def train(): # Compile and train model. model.compile(optimizer=optimizer, loss=loss_metric, metrics=metrics) + model.load_weights('/att/gpfsfs/briskfs01/ppl/golmsche/ramjet/logs/FFI transit sai aeb FfiHades mag13 quick pos no neg cont from existing no random start 2020-10-08-17-11-05/model.ckpt') try: model.fit(training_dataset, epochs=epochs_to_run, validation_data=validation_dataset, callbacks=[tensorboard_callback, model_checkpoint_callback], steps_per_epoch=5000,