-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_hyperopt.py
81 lines (64 loc) · 2.11 KB
/
run_hyperopt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import random
from datetime import datetime
from pathlib import Path
from typing import Dict, List
import hydra
import numpy as np
import omegaconf
import tensorflow as tf
from omegaconf import DictConfig
from tqdm import tqdm
from wandb.keras import WandbCallback
import wandb
from src.base_model import build_model
from src.data import generate_data
# Process-wide TensorFlow setup, executed once at import time.
tf.keras.backend.clear_session()
# Silence AutoGraph conversion messages.
tf.autograph.set_verbosity(level=0, alsologtostdout=False)
# Only let errors through TensorFlow's Python logger. A numeric level of 3
# would sit *below* logging.DEBUG (10) and therefore enable every message,
# the opposite of the intended silencing.
tf.get_logger().setLevel("ERROR")
# Allocate GPU memory on demand instead of reserving it all up front.
# Apply the setting to every visible GPU: TensorFlow requires the same
# memory-growth setting on all GPUs. The loop is simply a no-op on CPU-only
# machines, where indexing [0] would raise IndexError.
physical_devices = tf.config.list_physical_devices("GPU")
for _gpu in physical_devices:
    tf.config.experimental.set_memory_growth(_gpu, True)
@hydra.main(config_path="src/conf", config_name="train")
def run_experiments(config: DictConfig) -> None:
    """Train one model for the given Hydra config and log the run to W&B.

    Steps:
      1. Start a Weights & Biases run in project ``config.project``.
      2. Pick the data fold: ``config.fold``, or a random fold in ``0..4``
         when ``config.fold == -1``.
      3. Build the datasets (``generate_data``) and the model (``build_model``).
      4. Create a timestamped directory under
         ``<config.raw_output_base>/checkpoints`` and store the config
         (``config.yaml``) and the chosen fold (``fold``) in it.
      5. Fit the model, keeping the weights with the highest
         ``val_balanced_accuracy`` and stopping once that metric has not
         improved for 5 epochs.

    Args:
        config: Hydra/OmegaConf configuration composed from ``src/conf``
            (config name ``train``).
    """
    wandb.init(project=config.project, entity="elte-ai4covid")

    fold = np.random.randint(0, 5) if config.fold == -1 else config.fold
    datasets = generate_data(config, fold=fold)
    model = build_model(config)

    # NOTE(review): ':' in the directory name is not valid on Windows.
    chkpt_dir = (
        Path(config.raw_output_base)
        / "checkpoints"
        / datetime.now().strftime("%Y-%m-%d_%H:%M:%S.%f")
    )
    chkpt_dir.mkdir(exist_ok=True, parents=True)
    chkpt_path = chkpt_dir / "cp.ckpt"
    # Persist what is needed to reproduce / reload this run next to the weights.
    omegaconf.OmegaConf.save(config=config, f=chkpt_dir / "config.yaml")
    (chkpt_dir / "fold").write_text(str(fold))

    model.summary()

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(chkpt_path),
        save_weights_only=True,
        save_best_only=True,
        verbose=0,
        monitor="val_balanced_accuracy",
        save_freq="epoch",
        mode="max",
    )
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_balanced_accuracy",
        min_delta=0,
        patience=5,
        verbose=0,
        # Balanced accuracy is higher-is-better; must agree with the
        # checkpoint callback. With "min", every increase counted as "no
        # improvement" and training stopped after 5 epochs of real progress.
        mode="max",
        baseline=None,
    )
    model.fit(
        datasets["train_dataset"],
        validation_data=datasets["validation_dataset"],
        epochs=config.epochs,
        steps_per_epoch=config.steps_per_epoch,
        callbacks=[WandbCallback(), cp_callback, early_stopping],
        verbose=1,
    )

    # Close the W&B run explicitly so that repeated calls in one process
    # (e.g. a Hydra multirun sweep) each start a fresh run; presumably a later
    # wandb.init() would otherwise reuse this one -- verify against the
    # installed wandb version.
    wandb.finish()


if __name__ == "__main__":
    run_experiments()