Skip to content

Commit

Permalink
add saving history #92
Browse files Browse the repository at this point in the history
  • Loading branch information
lsantuari committed Jul 21, 2021
1 parent 687bd94 commit 21697db
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions scripts/utils/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,19 @@ def fitness(cnn_filters, cnn_layers, cnn_kernel_size, cnn_fc_nodes,
print('cnn_regularization_rate: ', cnn_regularization_rate)
print()

#if path_input_model is None:

model = create_model(train_X, 2,
learning_rate=cnn_init_learning_rate, regularization_rate=cnn_regularization_rate,
filters=cnn_filters, layers=cnn_layers, kernel_size=cnn_kernel_size, fc_nodes=cnn_fc_nodes)
print(model.summary())

# else:
#
# model = load_model(path_input_model)
# model.trainable = False
# model.summary()

callback_log = TensorBoard(
log_dir='log_dir',
histogram_freq=0,
Expand Down Expand Up @@ -130,22 +138,26 @@ def fitness(cnn_filters, cnn_layers, cnn_kernel_size, cnn_fc_nodes,

hist_df = pd.DataFrame(history.history)
hist_csv_file = 'history.csv'
with open(hist_csv_file, mode='w') as f:
hist_df.to_csv(f)

print()
print('Accuracy: {0:.2%}'.format(accuracy))
if accuracy > best_accuracy:

with open(hist_csv_file, mode='w') as f:
hist_df.to_csv(f)

model.save(path_best_model)
best_accuracy = accuracy

del model
tf.keras.backend.clear_session()
return -accuracy


def optimize(args):

global train_X, val_X, train_y, val_y, class_weights, batch_size, max_epoch, path_best_model
global train_X, val_X, train_y, val_y, class_weights, batch_size,\
max_epoch, path_best_model, path_input_model

randomState = 46
np.random.seed(randomState)
Expand All @@ -154,6 +166,7 @@ def optimize(args):
batch_size = args.batch_size
max_epoch = args.epochs
path_best_model = args.model
path_input_model = args.input_model

# training data
windows = args.windows.split(',')
Expand Down Expand Up @@ -238,6 +251,11 @@ def main():
type=str,
default='DEL',
help="Type of SV")
parser.add_argument('-im',
'--input_model',
type=str,
default='~/Documents/Projects/GTCG/sv-channels/sv-channels_manuscript/1KG_trios/results/manta_loocv/HG01053/manta_model.keras',
help="Input model")
parser.add_argument('-m',
'--model',
type=str,
Expand Down

0 comments on commit 21697db

Please sign in to comment.