-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvgg_16.py
116 lines (95 loc) · 4.39 KB
/
vgg_16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#%matplotlib inline
# Script setup: imports, environment configuration and constants used by the
# VGG-16 inference code that follows.
from matplotlib import pyplot as plt
import sys
import os
# Must be set before TensorFlow is imported below.
# NOTE(review): presumably restricts TensorFlow to GPU 0 -- confirm on the
# target machine.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Makes the `datasets`, `nets` and `preprocessing` packages imported below
# resolvable. Machine-specific path -- presumably a checkout of the
# tensorflow/models "slim" directory; adjust for your environment.
sys.path.append("/Users/fuchao/github/models/slim")
import numpy as np
import tensorflow as tf
# NOTE(review): urllib2 exists only in Python 2; this script will not import
# under Python 3 as written.
import urllib2
from datasets import imagenet
from nets import vgg
from preprocessing import vgg_preprocessing
# Directory that holds the 'vgg_16.ckpt' checkpoint file (machine-specific).
checkpoints_dir = '/Users/fuchao/checkpoints'
# Short alias for the TF-Slim API.
slim = tf.contrib.slim
# We need default size of image for a particular network.
# The network was trained on images of that size -- so we
# resize input image later in the code.
image_size = vgg.vgg_16.default_image_size
# Download one JPEG, classify it with a pre-trained VGG-16 checkpoint and
# print the five most probable ImageNet labels.
with tf.Graph().as_default():
    url = ("http://img04.tooopen.com/images/20130123/tooopen_09221274.jpg")
    # Open the URL and load the image as a raw byte string. The response is
    # closed explicitly (urllib2 responses are not context managers) so the
    # connection is released even if the read fails.
    response = urllib2.urlopen(url)
    try:
        image_string = response.read()
    finally:
        response.close()
    # Decode the JPEG bytes into a 3-channel image tensor.
    image = tf.image.decode_jpeg(image_string, channels=3)
    # Run the VGG evaluation-time preprocessing (is_training=False) to get an
    # image of the network's default input size.
    processed_image = vgg_preprocessing.preprocess_image(image,
                                                         image_size,
                                                         image_size,
                                                         is_training=False)
    # Networks accept images in batches.
    # The first dimension represents the batch size; here it is one.
    processed_images = tf.expand_dims(processed_image, 0)
    # Create the model under its default arg scope. arg_scope is a very
    # convenient feature of the slim library -- it lets you define default
    # parameters for layers.
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, _ = vgg.vgg_16(processed_images,
                               num_classes=1000,
                               is_training=False)
    # In order to get probabilities we apply softmax on the output.
    # Kept under a separate name so the tensor is not shadowed by the numpy
    # result fetched from the session below.
    probabilities_op = tf.nn.softmax(logits)
    # Create a function that reads the network weights
    # from the checkpoint file that you downloaded.
    # We will run it in the session later.
    init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'vgg_16.ckpt'),
        slim.get_model_variables('vgg_16'))
    with tf.Session() as sess:
        # Load weights
        init_fn(sess)
        # Fetch the predictions, the decoded image as a numpy matrix, and
        # the resized and cropped piece that is actually fed to the network.
        np_image, network_input, probabilities = sess.run(
            [image, processed_image, probabilities_op])
    # Drop the batch dimension (batch size is one).
    probabilities = probabilities[0]
    # Class indices ordered from most to least probable. The stable sort
    # keeps tied classes in ascending index order; tolist() yields plain
    # Python ints for the label lookup below.
    sorted_inds = np.argsort(-probabilities, kind='mergesort').tolist()
    # Optional visual check (disabled): show the downloaded image.
    # plt.figure()
    # plt.imshow(np_image.astype(np.uint8))
    # plt.suptitle("Downloaded image", fontsize=14, fontweight='bold')
    # plt.axis('off')
    # plt.show()
    #
    # Optional visual check (disabled): show the image that is actually fed
    # to the network, i.e. the output of the preprocessing step above. It is
    # rescaled by its value range for display.
    # plt.imshow( network_input / (network_input.max() - network_input.min()) )
    # plt.suptitle("Resized, Cropped and Mean-Centered input to the network",
    #              fontsize=14, fontweight='bold')
    # plt.axis('off')
    # plt.show()
    names = imagenet.create_readable_names_for_imagenet_labels()
    # Print the top-5 predictions with their probabilities. Pay attention
    # that the index into the class names is shifted by 1 -- this is because
    # some networks were trained on 1000 classes and others on 1001. VGG-16
    # was trained on 1000 classes.
    for index in sorted_inds[:5]:
        print('Probability %0.2f => [%s]' % (probabilities[index], names[index+1]))
    # NOTE(review): kept inside the graph context so it lists this graph's
    # model variables; the original indentation of this line is unknown --
    # confirm. `res` is not used anywhere else in the script.
    res = slim.get_model_variables()