-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_cls.py
94 lines (73 loc) · 2.87 KB
/
train_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import json
import pandas as pd
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
def generate(batch, shape, ptrain, pval):
    """Build the training and validation image generators.

    Images are read with ``flow_from_directory`` (one sub-directory per
    class, one-hot ``'categorical'`` labels) and rescaled from [0, 255] to
    [0, 1]; no other augmentation is applied.

    # Arguments
        batch: Integer, batch size.
        shape: Tuple of two integers, (height, width) every image is
            resized to.
        ptrain: Path of the training directory.
        pval: Path of the validation directory.

    # Returns
        train_generator: Training set generator.
        validation_generator: Validation set generator.
        count1: Integer, number of images the training generator yields.
        count2: Integer, number of images the validation generator yields.
    """
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    val_datagen = ImageDataGenerator(rescale=1. / 255)

    train_generator = train_datagen.flow_from_directory(
        ptrain,
        target_size=shape,
        batch_size=batch,
        class_mode='categorical')
    validation_generator = val_datagen.flow_from_directory(
        pval,
        target_size=shape,
        batch_size=batch,
        class_mode='categorical')

    # Use the number of images Keras actually indexed rather than counting
    # every file on disk: an os.walk count also includes non-image files
    # (e.g. .DS_Store, label files) and files outside the class
    # sub-directories, which inflates steps_per_epoch / validation_steps.
    count1 = train_generator.samples
    count2 = validation_generator.samples
    return train_generator, validation_generator, count1, count2
def _build_model(name, shape, n_class):
    """Instantiate the MobileNetV3 variant selected by ``name`` ('large' or 'small')."""
    if name == 'large':
        from model.mobilenet_v3_large import MobileNetV3_Large
        return MobileNetV3_Large(shape, n_class).build()
    if name == 'small':
        from model.mobilenet_v3_small import MobileNetV3_Small
        return MobileNetV3_Small(shape, n_class).build()
    # Previously an unknown name fell through and crashed later with an
    # UnboundLocalError on `model`; fail early with an actionable message.
    raise ValueError(
        "Unknown model type {!r}; expected 'large' or 'small'.".format(name))


def _num_steps(n_samples, batch):
    """Return the number of batches needed to cover ``n_samples`` (ceiling division)."""
    return (n_samples + batch - 1) // batch


def train():
    """Train a MobileNetV3 classifier configured by ``config/config.json``.

    Reads the config, builds the requested model, optionally loads
    pretrained weights (by layer name), compiles it with Adam, trains it on
    the directory generators, and writes per-epoch checkpoints, the training
    history (``hist.csv``) and the final weights into ``cfg['save_dir']``.

    # Raises
        ValueError: if ``cfg['model']`` is neither 'large' nor 'small', or
            if the training / validation directory contains no images.
    """
    with open('config/config.json', 'r') as f:
        cfg = json.load(f)

    save_dir = cfg['save_dir']
    shape = (int(cfg['height']), int(cfg['width']), 3)
    n_class = int(cfg['class_number'])
    batch = int(cfg['batch'])
    # Cast like the other numeric settings so a string value in the JSON works.
    epochs = int(cfg['epochs'])

    # makedirs also creates missing parents; exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(save_dir, exist_ok=True)

    model = _build_model(cfg['model'], shape, n_class)

    # Optional warm start: 'weights' may be absent, empty, or point to a file
    # that does not exist, in which case training starts from scratch.
    pre_weights = cfg.get('weights')
    if pre_weights and os.path.exists(pre_weights):
        model.load_weights(pre_weights, by_name=True)

    # The optimizer was previously created but never used, so the configured
    # learning rate had no effect. NOTE(review): build() is presumably not
    # compiling the model itself; compiling here is required either way for
    # `opt` to apply. Loss matches the one-hot labels from
    # class_mode='categorical' in generate().
    opt = Adam(lr=float(cfg['learning_rate']))
    model.compile(loss='categorical_crossentropy', optimizer=opt,
                  metrics=['accuracy'])

    # Write checkpoints next to the other artifacts instead of the CWD.
    checkpoint = ModelCheckpoint(
        os.path.join(save_dir, 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'),
        monitor='val_loss', verbose=0, save_best_only=False,
        save_weights_only=False, mode='auto', period=1)

    train_generator, validation_generator, count1, count2 = generate(
        batch, shape[:2], cfg['train_dir'], cfg['eval_dir'])
    if count1 == 0:
        raise ValueError("No training images found in {!r}.".format(cfg['train_dir']))
    if count2 == 0:
        raise ValueError("No validation images found in {!r}.".format(cfg['eval_dir']))

    # Ceiling division: the generators yield a final partial batch, so floor
    # division skipped the remainder each epoch and produced 0 steps when a
    # split had fewer images than one batch.
    hist = model.fit_generator(
        train_generator,
        validation_data=validation_generator,
        steps_per_epoch=_num_steps(count1, batch),
        validation_steps=_num_steps(count2, batch),
        epochs=epochs,
        callbacks=[checkpoint])

    df = pd.DataFrame.from_dict(hist.history)
    df.to_csv(os.path.join(save_dir, 'hist.csv'), encoding='utf-8', index=False)
    model.save_weights(os.path.join(save_dir, '{}_weights.h5'.format(cfg['model'])))
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    train()