-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
110 lines (95 loc) · 3.72 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# USAGE
# python train.py --checkpoints output/checkpoints
# python train.py --checkpoints output/checkpoints \
# --model output/checkpoints/epoch_40.hdf5 --start-epoch 40
# python train.py --checkpoints output/checkpoints \
# --model output/checkpoints/epoch_50.hdf5 --start-epoch 50
# set the matplotlib backend so figures can be saved in the background
import matplotlib
matplotlib.use("Agg")
# import the necessary packages
from pyimagesearch.callbacks.epochcheckpoint import EpochCheckpoint
from pyimagesearch.callbacks.trainingmonitor import TrainingMonitor
from pyimagesearch.nn.resnet import ResNet
from sklearn.preprocessing import LabelBinarizer
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.datasets import fashion_mnist
from keras.models import load_model
import keras.backend as K
import numpy as np
import argparse
import cv2
import sys
import os
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-c", "--checkpoints", required=True,
help="path to output checkpoint directory")
ap.add_argument("-m", "--model", type=str,
help="path to *specific* model checkpoint to load")
ap.add_argument("-s", "--start-epoch", type=int, default=0,
help="epoch to restart training at")
args = vars(ap.parse_args())
# grab the Fashion MNIST dataset (if this is your first time running
# this the dataset will be automatically downloaded)
print("[INFO] loading Fashion MNIST...")
((trainX, trainY), (testX, testY)) = fashion_mnist.load_data()
# Fashion MNIST images are 28x28 but the network we will be training
# is expecting 32x32 images
trainX = np.array([cv2.resize(x, (32, 32)) for x in trainX])
testX = np.array([cv2.resize(x, (32, 32)) for x in testX])
# scale data to the range of [0, 1]
trainX = trainX.astype("float32") / 255.0
testX = testX.astype("float32") / 255.0
# reshape the data matrices to include a channel dimension (required
# for training)
trainX = trainX.reshape((trainX.shape[0], 32, 32, 1))
testX = testX.reshape((testX.shape[0], 32, 32, 1))
# convert the labels from integers to vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# construct the image generator for data augmentation
aug = ImageDataGenerator(width_shift_range=0.1,
height_shift_range=0.1, horizontal_flip=True,
fill_mode="nearest")
# if there is no specific model checkpoint supplied, then initialize
# the network (ResNet-56) and compile the model
if args["model"] is None:
print("[INFO] compiling model...")
opt = SGD(lr=1e-1)
model = ResNet.build(32, 32, 1, 10, (9, 9, 9),
(64, 64, 128, 256), reg=0.0001)
model.compile(loss="categorical_crossentropy", optimizer=opt,
metrics=["accuracy"])
# otherwise, we're using a checkpoint model
else:
# load the checkpoint from disk
print("[INFO] loading {}...".format(args["model"]))
model = load_model(args["model"])
# update the learning rate
print("[INFO] old learning rate: {}".format(
K.get_value(model.optimizer.lr)))
K.set_value(model.optimizer.lr, 1e-2)
print("[INFO] new learning rate: {}".format(
K.get_value(model.optimizer.lr)))
# build the path to the training plot and training history
plotPath = os.path.sep.join(["output", "resnet_fashion_mnist.png"])
jsonPath = os.path.sep.join(["output", "resnet_fashion_mnist.json"])
# construct the set of callbacks
callbacks = [
EpochCheckpoint(args["checkpoints"], every=5,
startAt=args["start_epoch"]),
TrainingMonitor(plotPath,
jsonPath=jsonPath,
startAt=args["start_epoch"])]
# train the network
print("[INFO] training network...")
model.fit_generator(
aug.flow(trainX, trainY, batch_size=128),
validation_data=(testX, testY),
steps_per_epoch=len(trainX) // 128,
epochs=80,
callbacks=callbacks,
verbose=1)