-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
108 lines (81 loc) · 3.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from argparse import ArgumentParser
from src.dataset import AirbusDataset
from src.model import ResNet, ResNetUnet
from src.metrics import BCE_Dice_loss, f2_score, dice_coeff
from settings.config import Config
import warnings
warnings.filterwarnings('ignore')
def is_yes(value: str) -> bool:
    """Return True if a yes/no command-line value means "yes".

    Accepts 'y' or 'yes' in any letter case, ignoring surrounding
    whitespace; every other value (including the empty string) means "no".
    """
    return value.strip().lower() in ('y', 'yes')


def _banner(message: str) -> None:
    """Print *message* framed above and below by a row of 25 '#' characters."""
    print('#' * 25)
    print(message)
    print('#' * 25)


def _train_stage(model_factory, *, segmentation: bool, loss, metrics,
                 weights_file: str, monitor: str,
                 initial_value_threshold: float) -> None:
    """Build, compile and fit one model, checkpointing its best weights.

    Args:
        model_factory: Class whose instances provide ``build_model()``
            returning a Keras model (``ResNet`` or ``ResNetUnet`` here).
        segmentation: Forwarded to ``AirbusDataset.get_datasets`` to choose
            between the classification and segmentation data pipelines.
        loss: Loss (Keras name or callable) passed to ``model.compile``.
        metrics: Metrics passed to ``model.compile``; ``monitor`` must be the
            ``val_``-prefixed name of one of them.
        weights_file: File name, inside ``Config.model_path``, for the best
            weights.
        monitor: Validation metric that is maximised both for checkpointing
            and for learning-rate reduction.
        initial_value_threshold: Value the monitored metric must beat before
            any checkpoint is written.
    """
    dataset = AirbusDataset()
    train_ds, valid_ds = dataset.get_datasets(segmentation=segmentation)
    model = model_factory().build_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=Config.learning_rate),
        loss=loss,
        metrics=metrics,
    )
    # Ensure the output directory exists so the first best-weights save does
    # not fail partway through a (potentially long) training run.
    os.makedirs(Config.model_path, exist_ok=True)
    # With save_best_only, Keras only writes weights once `monitor` improves on
    # initial_value_threshold: a run that never beats it saves no file at all.
    # NOTE(review): Keras 3 requires a '.weights.h5' suffix when
    # save_weights_only=True; the '.h5' names used here assume tf.keras 2.x.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(Config.model_path, weights_file),
        monitor=monitor,
        save_best_only=True,
        save_weights_only=True,
        mode='max',
        initial_value_threshold=initial_value_threshold,
    )
    # Scale the learning rate by 0.7 after 3 epochs without improvement in
    # `monitor`, never going below 1e-6.
    lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=monitor,
        factor=0.7,
        patience=3,
        mode='max',
        min_lr=1e-6,
    )
    model.fit(train_ds, epochs=Config.epochs, validation_data=valid_ds,
              callbacks=[checkpoint, lr_scheduler])


def main() -> None:
    """Parse CLI flags, then train the classifier and/or the segmenter."""
    parser = ArgumentParser()
    parser.add_argument('--classificator_train', type=str, default='y',
                        help="train the ResNet classification model "
                             "('y'/'yes' to enable, case-insensitive)")
    parser.add_argument('--segmentator_train', type=str, default='y',
                        help="train the ResNet-UNet segmentation model "
                             "('y'/'yes' to enable, case-insensitive)")
    args = parser.parse_args()

    if tf.config.list_physical_devices('GPU'):
        print("TensorFlow has detected GPUs.")
    else:
        print("No GPUs found. TensorFlow is using CPU.")

    if is_yes(args.classificator_train):
        _banner('Start of classificator train.')
        _train_stage(
            ResNet,
            segmentation=False,
            loss='binary_crossentropy',
            metrics=['accuracy', f2_score],
            weights_file='resnet.h5',
            monitor='val_f2_score',
            initial_value_threshold=0.92,
        )
        _banner('End of classificator train.')

    if is_yes(args.segmentator_train):
        _banner('Start of segmentator train.')
        _train_stage(
            ResNetUnet,
            segmentation=True,
            loss=BCE_Dice_loss,
            metrics=[f2_score, dice_coeff],
            weights_file='unet.h5',
            monitor='val_dice_coeff',
            initial_value_threshold=0.75,
        )
        _banner('End of segmentator train.')


if __name__ == '__main__':
    main()