forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
svhn-disturb.py
executable file
·62 lines (52 loc) · 1.91 KB
/
svhn-disturb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: svhn-disturb.py
# Author: Yuxin Wu <[email protected]>
import tensorflow as tf
import argparse
import numpy as np
import os
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from disturb import DisturbLabel
import imp
# Load the sibling example script as a module by file path. A regular
# `import` statement cannot be used because the file name contains hyphens.
# (The original call had one closing parenthesis too many, which made the
# whole file a SyntaxError.)
svhn_example = imp.load_source(
    'svhn_example',
    os.path.join(os.path.dirname(__file__), '..', 'svhn-digit-convnet.py'))
# Convenience aliases for the pieces of the example that this script reuses.
Model = svhn_example.Model
get_config = svhn_example.get_config
def get_data():
    """Build the training and test dataflows for SVHN digits.

    The training flow mixes the 'train' and 'extra' splits, wraps them in
    DisturbLabel using the module-level ``args.prob``, applies resize /
    brightness / contrast augmentors, batches by 128 and prefetches.
    The test flow is only resized to 40x40 and batched by 128 with
    ``remainder=True``.

    Note: reads the global ``args``, which is only assigned when this file
    is run as a script.

    Returns:
        A ``(data_train, data_test)`` tuple.
    """
    train_flow = RandomMixData([
        dataset.SVHNDigit('train'),
        dataset.SVHNDigit('extra'),
    ])
    train_flow = DisturbLabel(train_flow, args.prob)
    test_flow = dataset.SVHNDigit('test')

    train_augs = [
        imgaug.Resize((40, 40)),
        imgaug.Brightness(30),
        imgaug.Contrast((0.5, 1.5)),
    ]
    # Wrapping order: augment -> batch -> prefetch.
    train_flow = PrefetchData(
        BatchData(AugmentImageComponent(train_flow, train_augs), 128),
        5, 5)

    test_flow = AugmentImageComponent(test_flow, [imgaug.Resize((40, 40))])
    test_flow = BatchData(test_flow, 128, remainder=True)
    return train_flow, test_flow


# Replace the loaded example module's get_data with the version above.
# NOTE(review): presumably svhn_example.get_config() looks up get_data at call
# time and therefore picks up this pipeline -- confirm against that script.
svhn_example.get_data = get_data
if __name__ == '__main__':
    # Command-line entry point: parse options, pick GPUs, build the training
    # config and start training.
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
    cli.add_argument('--load', help='load model')
    cli.add_argument('--prob', help='disturb prob', type=float, required=True)
    # `args` must stay a module-level name: get_data() reads args.prob.
    args = cli.parse_args()

    # Fall back to the first GPU when --gpu is missing or empty.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu or '0'

    config = get_config(args.prob)
    if args.load:
        config.session_init = SaverRestore(args.load)
    if args.gpu:
        # One tower per comma-separated entry in --gpu.
        config.nr_tower = len(args.gpu.split(','))
    QueueInputTrainer(config).train()