Skip to content

Commit

Permalink
Committing the changes required to produce the results for the 2021 transit paper
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Jan 24, 2021
1 parent 1ed893b commit afe6f6c
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 21 deletions.
14 changes: 8 additions & 6 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,27 @@
import datetime
from pathlib import Path

from ramjet.models.hades import Hades
from ramjet.models.hades import Hades, FfiHades
from ramjet.photometric_database.derived.tess_ffi_transit_databases import \
TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase
from ramjet.photometric_database.derived.tess_two_minute_cadence_transit_databases import \
TessTwoMinuteCadenceStandardAndInjectedTransitDatabase
from ramjet.analysis.model_loader import get_latest_log_directory
from ramjet.trial import infer

log_name = get_latest_log_directory(logs_directory='logs') # Uses the latest model in the log directory.
# log_name = 'logs/baseline YYYY-MM-DD-hh-mm-ss' # Specify the path to the model to use.
# log_name = get_latest_log_directory(logs_directory='logs') # Uses the latest model in the log directory.
log_name = 'logs/FFI transit sai aeb FfiHades mag14 quick pos no neg cont from existing no random start 2020-12-19-16-10-27' # Specify the path to the model to use.
saved_log_directory = Path(f'{log_name}')
datetime_string = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

print('Setting up dataset...', flush=True)
database = TessTwoMinuteCadenceStandardAndInjectedTransitDatabase()
database = TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase()
inference_dataset = database.generate_inference_dataset()

print('Loading model...', flush=True)
model = Hades(database.number_of_label_types)
model = FfiHades()
model.load_weights(str(saved_log_directory.joinpath('model.ckpt'))).expect_partial()

print('Inferring...', flush=True)
infer_results_path = saved_log_directory.joinpath(f'infer results {datetime_string}.csv')
infer(model, inference_dataset, infer_results_path)
infer(model, inference_dataset, infer_results_path, number_of_top_predictions_to_keep=5000)
2 changes: 1 addition & 1 deletion ramjet/analysis/transit_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def round_value_to_significant_figures(value):
if __name__ == '__main__':
print('Opening Bokeh application on http://localhost:5006/')
# Start the server.
server = Server({'/': TransitFitter(tic_id=362043085).bokeh_application})
server = Server({'/': TransitFitter(tic_id=297678377).bokeh_application})
server.start()
# Start the specific application on the server.
server.io_loop.add_callback(server.show, "/")
Expand Down
45 changes: 45 additions & 0 deletions ramjet/models/hades.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,48 @@ def call(self, inputs, training=False, mask=None):
x = self.prediction_layer(x, training=training)
outputs = self.reshape(x, training=training)
return outputs


class FfiHades(Model):
    """A convolutional neural network model for classifying transits in TESS FFI lightcurves."""

    def __init__(self):
        super().__init__()
        # Per-block settings, in network order: (filters, kernel_size, pooling_size, keyword overrides).
        block_settings = [
            (8, 3, 2, {'batch_normalization': False, 'dropout_rate': 0}),
            (8, 3, 2, {}),
            (16, 3, 2, {}),
            (32, 3, 2, {}),
            (64, 3, 2, {}),
            (128, 3, 2, {}),
            (128, 3, 1, {}),
            (128, 3, 1, {}),
            (20, 3, 1, {'spatial': False}),
            (20, 7, 1, {}),
            (20, 1, 1, {'batch_normalization': False, 'dropout_rate': 0}),
        ]
        # Attribute names `block0`..`block10` are kept so checkpoint weights saved against the
        # original attribute layout still load correctly.
        for block_index, (filters, kernel_size, pooling_size, overrides) in enumerate(block_settings):
            block = LightCurveNetworkBlock(filters=filters, kernel_size=kernel_size,
                                           pooling_size=pooling_size, **overrides)
            setattr(self, f'block{block_index}', block)
        self.prediction_layer = Convolution1D(1, kernel_size=1, activation=sigmoid)
        self.reshape = Reshape([1])

    def call(self, inputs, training=False, mask=None):
        """
        The forward pass of the layer.

        :param inputs: The input tensor.
        :param training: A boolean specifying if the layer should be in training mode.
        :param mask: A mask for the input tensor.
        :return: The output tensor of the layer.
        """
        x = inputs
        # Apply the convolutional blocks sequentially, then the sigmoid prediction head.
        for block_index in range(11):
            x = getattr(self, f'block{block_index}')(x, training=training)
        x = self.prediction_layer(x, training=training)
        outputs = self.reshape(x, training=training)
        return outputs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
Code representing the collection of TESS two minute cadence lightcurves containing eclipsing binaries.
"""
from pathlib import Path
from typing import Union, List, Iterable, Optional, Tuple

import pandas as pd

from peewee import Select

Expand Down Expand Up @@ -56,3 +58,15 @@ def get_sql_query(self) -> Select:
TessEclipsingBinaryMetadata.tic_id.not_in(transit_tic_id_query))
query = query.where(TessFfiLightcurveMetadata.tic_id.in_(eclipsing_binary_tic_id_query))
return query


class TessFfiQuickTransitNegativeLightcurveCollection(TessFfiLightcurveCollection):
    """
    A collection of TESS FFI lightcurves to be used as negative (non-transit) training examples,
    whose paths are loaded from a precomputed CSV file rather than generated by an SQL query.
    """
    def __init__(self, dataset_splits: Union[List[int], None] = None,
                 magnitude_range: Tuple[Optional[float], Optional[float]] = (None, None)):
        """
        :param dataset_splits: The dataset splits to include, or None for all splits.
        :param magnitude_range: The (minimum, maximum) magnitude bounds; either bound may be None.
        """
        super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range)
        self.label = 0  # Every lightcurve in this collection is labeled as a negative example.

    def get_paths(self) -> Iterable[Path]:
        """
        Gets the paths of the lightcurves in the collection.

        :return: The lightcurve paths, read from `quick_negative_paths.csv` in the working directory.
        """
        data_frame = pd.read_csv('quick_negative_paths.csv')
        paths = list(map(Path, data_frame['Lightcurve path'].values))
        return paths
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def get_sql_query(self) -> Select:
:return: The SQL query.
"""
query = TessFfiLightcurveMetadata().select()
query = self.order_by_dataset_split_with_random_start(query, TessFfiLightcurveMetadata.dataset_split,
self.dataset_splits)
# query = self.order_by_dataset_split_with_random_start(query, TessFfiLightcurveMetadata.dataset_split,
# self.dataset_splits)
if self.magnitude_range[0] is not None and self.magnitude_range[1] is not None:
query = query.where(TessFfiLightcurveMetadata.magnitude.between(*self.magnitude_range))
elif self.magnitude_range[0] is not None:
Expand Down
21 changes: 15 additions & 6 deletions ramjet/photometric_database/derived/tess_ffi_transit_databases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ramjet.photometric_database.derived.tess_ffi_eclipsing_binary_lightcurve_collection import \
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection, TessFfiQuickTransitNegativeLightcurveCollection
from ramjet.photometric_database.derived.tess_ffi_lightcurve_collection import TessFfiLightcurveCollection
from ramjet.photometric_database.derived.tess_ffi_transit_lightcurve_collections import \
TessFfiConfirmedTransitLightcurveCollection, TessFfiNonTransitLightcurveCollection
Expand All @@ -15,11 +15,11 @@ def __init__(self):
super().__init__()
self.batch_size = 1000
self.time_steps_per_example = 1000
self.shuffle_buffer_size = 100000
self.shuffle_buffer_size = 10000
self.out_of_bounds_injection_handling = OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION


magnitude_range = (0, 11)
magnitude_range = (0, 14)


class TessFfiStandardTransitDatabase(TessFfiDatabase):
Expand Down Expand Up @@ -106,23 +106,32 @@ class TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase(TessFfiDataba
"""
def __init__(self):
super().__init__()
self.shuffle_buffer_size = 10000
self.number_of_parallel_processes_per_map = 6
self.training_standard_lightcurve_collections = [
TessFfiConfirmedTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
TessFfiNonTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)),
magnitude_range=magnitude_range)
magnitude_range=magnitude_range),
# TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)),
# magnitude_range=magnitude_range)
]
self.training_injectee_lightcurve_collection = TessFfiNonTransitLightcurveCollection(
dataset_splits=list(range(8)), magnitude_range=magnitude_range)
self.training_injectable_lightcurve_collections = [
TessFfiConfirmedTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
TessFfiNonTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)),
magnitude_range=magnitude_range)
magnitude_range=magnitude_range),
# TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)),
# magnitude_range=magnitude_range)
]
self.validation_standard_lightcurve_collections = [
TessFfiConfirmedTransitLightcurveCollection(dataset_splits=[8], magnitude_range=magnitude_range),
TessFfiNonTransitLightcurveCollection(dataset_splits=[8], magnitude_range=magnitude_range),
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)),
magnitude_range=magnitude_range)
magnitude_range=magnitude_range),
# TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)),
# magnitude_range=magnitude_range)
]
self.inference_lightcurve_collections = [TessFfiLightcurveCollection(magnitude_range=magnitude_range)]
14 changes: 9 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from tensorflow.python.keras import callbacks
from tensorflow.python.keras.losses import BinaryCrossentropy

from ramjet.models.hades import Hades
from ramjet.basic_models import SimplePoolingLightcurveCnn2, FfiSimplePoolingLightcurveCnn2
from ramjet.models.hades import Hades, FfiHades
from ramjet.photometric_database.derived.tess_ffi_transit_databases import \
TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase
from ramjet.photometric_database.derived.tess_two_minute_cadence_transit_databases import \
TessTwoMinuteCadenceStandardAndInjectedTransitDatabase

Expand All @@ -14,9 +17,9 @@ def train():
"""Runs the training."""
print('Starting training process...', flush=True)
# Basic training settings.
trial_name = f'baseline' # Add any desired run name details to this string.
database = TessTwoMinuteCadenceStandardAndInjectedTransitDatabase()
model = Hades(database.number_of_label_types)
trial_name = f'FFI transit sai aeb FfiHades mag14 quick pos no neg cont from existing no random start' # Add any desired run name details to this string.
model = FfiHades()
database = TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase()
# database.batch_size = 100 # Reducing the batch size may help if you are running out of memory.
epochs_to_run = 1000
logs_directory = 'logs'
Expand All @@ -33,14 +36,15 @@ def train():
training_dataset, validation_dataset = database.generate_datasets()
optimizer = tf.optimizers.Adam(learning_rate=1e-4)
loss_metric = BinaryCrossentropy(name='Loss')
metrics = [tf.keras.metrics.AUC(num_thresholds=20, name='Area_under_ROC_curve', multi_label=True),
metrics = [tf.keras.metrics.AUC(num_thresholds=20, name='Area_under_ROC_curve'),
tf.metrics.SpecificityAtSensitivity(0.9, name='Specificity_at_90_percent_sensitivity'),
tf.metrics.SensitivityAtSpecificity(0.9, name='Sensitivity_at_90_percent_specificity'),
tf.metrics.BinaryAccuracy(name='Accuracy'), tf.metrics.Precision(name='Precision'),
tf.metrics.Recall(name='Recall')]

# Compile and train model.
model.compile(optimizer=optimizer, loss=loss_metric, metrics=metrics)
model.load_weights('/att/gpfsfs/briskfs01/ppl/golmsche/ramjet/logs/FFI transit sai aeb FfiHades mag13 quick pos no neg cont from existing no random start 2020-10-08-17-11-05/model.ckpt')
try:
model.fit(training_dataset, epochs=epochs_to_run, validation_data=validation_dataset,
callbacks=[tensorboard_callback, model_checkpoint_callback], steps_per_epoch=5000,
Expand Down

0 comments on commit afe6f6c

Please sign in to comment.