forked from AI100-CSDN/quiz_slim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train_eval_image_classifier.py
69 lines (59 loc) · 3.94 KB
/
train_eval_image_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
def parse_args(check=True):
    """Parse command-line flags for the train/eval driver script.

    Args:
        check: unused; kept for backward compatibility with existing callers.

    Returns:
        (FLAGS, unparsed): the parsed ``argparse.Namespace`` and the list of
        unrecognized argument strings (``parse_known_args`` semantics, so
        unknown flags do not abort the run).
    """
    parser = argparse.ArgumentParser()
    # --- training flags ---
    parser.add_argument('--dataset_name', type=str, default='quiz')
    parser.add_argument('--dataset_dir', type=str)
    parser.add_argument('--checkpoint_path', type=str)
    parser.add_argument('--model_name', type=str, default='inception_v4')
    parser.add_argument('--checkpoint_exclude_scopes', type=str,
                        default='InceptionV4/Logits,InceptionV4/AuxLogits/Aux_logits')
    parser.add_argument('--train_dir', type=str)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--optimizer', type=str, default='rmsprop')
    parser.add_argument('--batch_size', type=int, default=32)
    # BUG FIX: argparse `type=bool` is broken -- bool('False') is True, so the
    # flag could never be turned off from the command line. Parse the string
    # explicitly instead; the default stays False as before.
    parser.add_argument('--clone_on_cpu',
                        type=lambda s: s.strip().lower() in ('1', 'true', 'yes'),
                        default=False)
    # --- evaluation flags ---
    parser.add_argument('--dataset_split_name', type=str, default='validation')
    parser.add_argument('--eval_dir', type=str, default='validation')
    # BUG FIX: was `type=str, default='validation'` (copy-paste error); the
    # downstream eval script expects an integer number of batches.
    parser.add_argument('--max_num_batches', type=int, default=128)
    FLAGS, unparsed = parser.parse_known_args()
    return FLAGS, unparsed
# Shell command templates; the {placeholder} fields are filled with
# str.format(**flags) by the driver loop below.
#
# BUG FIX: the training template had dropped --checkpoint_path,
# --checkpoint_exclude_scopes, --max_number_of_steps and --clone_on_cpu even
# though the driver loop supplies all of those keys. Most importantly,
# without --max_number_of_steps each spawned training process has no step
# limit, so the per-epoch train/eval alternation could never advance past the
# first training run. Restored the full template (it was still present as a
# commented-out line above the truncated one).
train_cmd = ('python3 ./train_image_classifier.py'
             ' --dataset_name={dataset_name}'
             ' --dataset_dir={dataset_dir}'
             ' --checkpoint_path={checkpoint_path}'
             ' --model_name={model_name}'
             ' --checkpoint_exclude_scopes={checkpoint_exclude_scopes}'
             ' --train_dir={train_dir}'
             ' --learning_rate={learning_rate}'
             ' --optimizer={optimizer}'
             ' --batch_size={batch_size}'
             ' --max_number_of_steps={max_number_of_steps}'
             ' --clone_on_cpu={clone_on_cpu}')
eval_cmd = ('python3 ./eval_image_classifier.py'
            ' --dataset_name={dataset_name}'
            ' --dataset_dir={dataset_dir}'
            ' --dataset_split_name={dataset_split_name}'
            ' --model_name={model_name}'
            ' --checkpoint_path={checkpoint_path}'
            ' --eval_dir={eval_dir}'
            ' --batch_size={batch_size}'
            ' --max_num_batches={max_num_batches}')
def _run(cmd):
    """Run *cmd* through a shell via os.popen and echo its stdout line by line.

    NOTE(review): os.popen interpolates flag values straight into a shell
    string; fine for trusted local flags, but subprocess.run([...]) would be
    the safer replacement.
    """
    with os.popen(cmd) as pipe:
        for line in pipe:
            # BUG FIX: the original printed p.strip() -- .strip() on the pipe
            # object itself (AttributeError) -- instead of the line just read.
            print(line.strip())


if __name__ == '__main__':
    FLAGS, unparsed = parse_args()
    print('current working dir [{0}]'.format(os.getcwd()))
    # Run relative to this script's directory so the relative paths
    # ./train_image_classifier.py and ./eval_image_classifier.py resolve.
    w_d = os.path.dirname(os.path.abspath(__file__))
    print('change working dir to [{0}]'.format(w_d))
    os.chdir(w_d)
    # Steps per "epoch" -- assumes a 70000-sample training set (TODO confirm).
    step_per_epoch = 70000 // FLAGS.batch_size
    for i in range(30):
        # Cumulative step target: the training script resumes from the latest
        # checkpoint in train_dir, so the limit must grow every round.
        steps = int(step_per_epoch * (i + 1))
        # train one epoch
        print('################ train ################')
        _run(train_cmd.format(
            dataset_name=FLAGS.dataset_name,
            dataset_dir=FLAGS.dataset_dir,
            checkpoint_path=FLAGS.checkpoint_path,
            model_name=FLAGS.model_name,
            checkpoint_exclude_scopes=FLAGS.checkpoint_exclude_scopes,
            train_dir=FLAGS.train_dir,
            learning_rate=FLAGS.learning_rate,
            optimizer=FLAGS.optimizer,
            batch_size=FLAGS.batch_size,
            max_number_of_steps=steps,
            clone_on_cpu=FLAGS.clone_on_cpu))
        # evaluate on the validation split against the newest checkpoint
        print('################ eval ################')
        _run(eval_cmd.format(
            dataset_name=FLAGS.dataset_name,
            dataset_dir=FLAGS.dataset_dir,
            dataset_split_name='validation',
            model_name=FLAGS.model_name,
            checkpoint_path=FLAGS.train_dir,
            eval_dir=FLAGS.eval_dir,
            batch_size=FLAGS.batch_size,
            max_num_batches=FLAGS.max_num_batches))