forked from lengstrom/fast-style-transfer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
155 lines (134 loc) Β· 6.02 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import sys
sys.path.insert(0, 'src')
import transform, numpy as np, vgg, pdb, os
import scipy.misc
import tensorflow as tf
from utils import save_img, get_img, exists, list_files
from argparse import ArgumentParser
from collections import defaultdict
import time
# Command-line defaults: BATCH_SIZE backs --batch-size and DEVICE backs
# --device (and the device_t default of ffwd_different_dimensions).
BATCH_SIZE, DEVICE = 4, '/gpu:0'
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Run the style-transfer network over a set of images and save results.

    Args:
        data_in: either a sequence (or any iterable) of input image paths, or
            an indexable collection of already-loaded images, one per output
            path (e.g. a numpy array whose first axis is the image index).
        paths_out: sequence (or any iterable) of output paths, one per input.
        checkpoint_dir: a directory holding a TensorFlow checkpoint (its
            latest checkpoint is restored) or a path to a checkpoint file.
        device_t: TensorFlow device string the graph is placed on; soft
            placement is enabled so TF may fall back to another device.
        batch_size: maximum number of images fed per session run.

    Raises:
        AssertionError: if there are no outputs, if input and output counts
            differ, or if path inputs do not all share the first image's shape.
        Exception: if ``checkpoint_dir`` is a directory with no checkpoint.
    """
    # Callers may hand us lazy iterables (e.g. Python 3 ``map`` objects);
    # everything below needs len(), indexing and slicing.
    paths_out = list(paths_out)
    if not isinstance(data_in, (list, tuple, np.ndarray)):
        data_in = list(data_in)

    assert len(paths_out) > 0
    is_paths = isinstance(data_in[0], str)
    assert len(data_in) == len(paths_out)
    if is_paths:
        img_shape = get_img(data_in[0]).shape
    else:
        # In-memory images: the per-image shape comes from the first element.
        # (The old code used ``data_in.size[0]`` -- ``.size`` is an int -- and
        # an undefined ``X``, so this branch could never run.)
        img_shape = tuple(np.shape(data_in[0]))

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (batch_size,) + img_shape
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')
        preds = transform.net(img_placeholder)

        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        # The placeholder has a fixed batch dimension, so only whole batches
        # are run here; a leftover partial batch is handled after the session.
        num_iters = len(paths_out) // batch_size
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' + \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]
            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

    # Leftover images cannot fill the fixed-size batch, so rebuild the graph
    # with a batch size of one and process them individually.
    remaining_in = data_in[num_iters * batch_size:]
    remaining_out = paths_out[num_iters * batch_size:]
    if len(remaining_in) > 0:
        ffwd(remaining_in, remaining_out, checkpoint_dir,
             device_t=device_t, batch_size=1)
def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
    """Stylize one image file and write the result to ``out_path``.

    Thin wrapper over ffwd using one-element path lists and a batch size of
    one. Note that this wrapper defaults to the CPU device.
    """
    return ffwd([in_path], [out_path], checkpoint_dir,
                device_t=device, batch_size=1)
def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
                              device_t=DEVICE, batch_size=4):
    """Stylize images of mixed sizes by bucketing them by image shape.

    ffwd builds its input placeholder from a single image shape, so the inputs
    are grouped by the shape of each loaded image and every group is passed to
    ffwd separately.

    Args:
        in_path: iterable of input image paths.
        out_path: iterable of output paths, paired positionally with in_path.
        checkpoint_dir: forwarded to ffwd.
        device_t: forwarded to ffwd.
        batch_size: forwarded to ffwd.

    Raises:
        AssertionError: if in_path and out_path have different lengths.
    """
    # Materialize so lazy iterables (e.g. Python 3 ``map`` objects) work and
    # the two sequences can be length-checked before pairing them.
    in_path = list(in_path)
    out_path = list(out_path)
    assert len(in_path) == len(out_path)

    in_path_of_shape = defaultdict(list)
    out_path_of_shape = defaultdict(list)
    for in_image, out_image in zip(in_path, out_path):
        # Join every dimension so the key does not assume exactly three axes.
        shape = 'x'.join(str(dim) for dim in get_img(in_image).shape)
        in_path_of_shape[shape].append(in_image)
        out_path_of_shape[shape].append(out_image)

    for shape in in_path_of_shape:
        print('Processing images of shape %s' % shape)
        ffwd(in_path_of_shape[shape], out_path_of_shape[shape],
             checkpoint_dir, device_t, batch_size)
def build_parser():
    """Build the command-line parser for the evaluation script.

    Three path options are required; --device, --batch-size and
    --allow-different-dimensions are optional.
    """
    parser = ArgumentParser()

    # (flag, dest, help, metavar) for the required string-valued path options.
    required_path_options = (
        ('--checkpoint', 'checkpoint_dir',
         'dir or .ckpt file to load checkpoint from', 'CHECKPOINT'),
        ('--in-path', 'in_path', 'dir or file to transform', 'IN_PATH'),
        ('--out-path', 'out_path',
         'destination (dir or file) of transformed file or files', 'OUT_PATH'),
    )
    for flag, dest_name, help_text, meta in required_path_options:
        parser.add_argument(flag, type=str, dest=dest_name, help=help_text,
                            metavar=meta, required=True)

    parser.add_argument('--device', type=str, dest='device',
                        help='device to perform compute on',
                        metavar='DEVICE', default=DEVICE)
    parser.add_argument('--batch-size', type=int, dest='batch_size',
                        help='batch size for feedforwarding',
                        metavar='BATCH_SIZE', default=BATCH_SIZE)
    parser.add_argument('--allow-different-dimensions', action='store_true',
                        dest='allow_different_dimensions',
                        help='allow different image dimensions')
    return parser
def check_opts(opts):
    """Validate parsed command-line options before any model is built.

    Uses the project helper ``exists`` from utils; presumably it fails with
    the given message when the path is missing -- confirm in utils.
    """
    exists(opts.checkpoint_dir, 'Checkpoint not found!')
    exists(opts.in_path, 'In path not found!')
    # When transforming a whole directory, every output file name is joined
    # onto out_path, so it must already exist. Guarding on
    # isdir(out_path) would make this check a no-op.
    if os.path.isdir(opts.in_path):
        exists(opts.out_path, 'out dir not found!')
    # A non-positive batch size would make ffwd divide by zero.
    assert opts.batch_size > 0
def main():
    """Command-line entry point: stylize a single image or a directory."""
    parser = build_parser()
    opts = parser.parse_args()
    check_opts(opts)

    if not os.path.isdir(opts.in_path):
        # Single image: if the destination is a directory, keep the input's
        # file name inside it (isdir is already False for missing paths).
        if os.path.isdir(opts.out_path):
            out_path = os.path.join(opts.out_path,
                                    os.path.basename(opts.in_path))
        else:
            out_path = opts.out_path
        ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
                    device=opts.device)
        return

    files = list_files(opts.in_path)
    # Build real lists: in Python 3 map() returns a lazy iterator, but the
    # ffwd helpers call len() on, index and slice these sequences.
    full_in = [os.path.join(opts.in_path, name) for name in files]
    full_out = [os.path.join(opts.out_path, name) for name in files]
    if opts.allow_different_dimensions:
        ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
                                  device_t=opts.device,
                                  batch_size=opts.batch_size)
    else:
        ffwd(full_in, full_out, opts.checkpoint_dir, device_t=opts.device,
             batch_size=opts.batch_size)


if __name__ == '__main__':
    main()