Hello. I may have found a bug in your IS calculator. I wrote a function, calc_inception_onthefly, in evaluation.py.
Then I modified train.py to compute the IS of (i) samples from a pretrained GAN model (after loading its snapshot) and (ii) the CIFAR10 training set. The modification is as follows:
if args.snapshot:
    print("Resume training with snapshot:{}".format(args.snapshot))
    chainer.serializers.load_npz(args.snapshot, trainer)

mean, std, ims = calc_inception_onthefly(gen, n_ims=5000, splits=1, path=args.inception_model_path)
np.savez('eval_images.npz', x=ims)
print('IS of sngan => mean:{}, std:{}'.format(mean, std))

# d_train is a numpy array with dtype=np.float32 and shape=(50000, 3, 32, 32)
d_train, _ = chainer.datasets.get_cifar10(ndim=3, withlabel=False, scale=255)
mean, std, _ = calc_inception_onthefly(ims=d_train, splits=10, path=args.inception_model_path)
print('IS of cifar10 => mean:{}, std:{}'.format(mean, std))
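(For completeness, here is a quick sanity check of the dtype/shape claim in the comment above. This snippet is my own addition, not part of the train.py modification:)

import numpy as np
import chainer

# Confirm that d_train really is float32 in [0, 255] with shape (50000, 3, 32, 32)
# before handing it to the IS calculator.
d_train, _ = chainer.datasets.get_cifar10(ndim=3, withlabel=False, scale=255)
assert d_train.dtype == np.float32
assert d_train.shape == (50000, 3, 32, 32)
assert d_train.min() >= 0. and d_train.max() <= 255.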
The result of (i) is around 8.2, which is very close to the one in the paper. However, your calculator gives an IS of 12.0 for the CIFAR10 training set, whereas this score should be 11.24, as you state in the paper.
To further support my claim, I computed the IS of the samples generated by the pretrained GAN model (the ones I save as eval_images.npz above) with the following TensorFlow code (this code gives an IS of 11.24 for the CIFAR10 training set). But I got roughly 7.8, which is approximately 0.4 less than it should be.
Basically, these results suggest that your scores may not be directly comparable with previous works, unless you reimplemented those models and computed the IS of all of them with the same calculator. Could you please correct me if I am wrong in this analysis?
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
import os
import sys
import tarfile

from tqdm import tqdm, trange
from six.moves import urllib
import numpy as np
import tensorflow as tf

MODEL_DIR = './'
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
softmax = None

# Call this function with numpy images. Each element should be a
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
    cfg = tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    with tf.Session(config=cfg) as sess:
        preds = []
        # Feed the images one at a time through the Inception graph.
        for i in trange(images.shape[0]):
            inp = images[[i]].astype(np.float32)
            pred = sess.run(softmax, {'ExpandDims:0': inp})
            preds.append(pred)
        preds = np.concatenate(preds, 0)
        scores = []
        for i in trange(splits):
            part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return np.mean(scores), np.std(scores)

# This function is called automatically.
def _init_inception():
    global softmax
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
    with tf.gfile.FastGFile(os.path.join(MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    # Works with an arbitrary minibatch size.  # NO IT DOES NOT!
    cfg = tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    with tf.Session(config=cfg) as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        ops = pool3.graph.get_operations()
        for op_idx, op in enumerate(ops):
            for o in op.outputs:
                shape = o.get_shape()
                shape = [s.value for s in shape]
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        # Relax the batch dimension so the graph accepts a
                        # variable batch size.
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o.set_shape(tf.TensorShape(new_shape))
        # Rebuild the softmax on top of pool_3 so it also accepts batches.
        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
        softmax = tf.nn.softmax(logits)

if softmax is None:
    _init_inception()

if __name__ == '__main__':
    from torchvision.datasets import CIFAR10
    cifar = CIFAR10(root='.', train=True, download=True)
    print(get_inception_score(cifar.train_data))
    x = np.load('eval_images.npz')['x']
    print(x.shape)           # (5000, 3, 32, 32)
    print(x.dtype)           # uint8
    print(x.min(), x.max())  # 0, 255
    print(get_inception_score(np.transpose(x, [0, 2, 3, 1])))
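For reference, the split-and-exp-KL computation in the script above boils down to the following minimal numpy sketch (my own restatement, with a hypothetical helper name):

import numpy as np

def inception_score_from_preds(preds, splits=10):
    # preds: (N, 1000) array of Inception softmax outputs, one row per image.
    scores = []
    for i in range(splits):
        part = preds[i * len(preds) // splits:(i + 1) * len(preds) // splits]
        p_y = np.mean(part, axis=0, keepdims=True)           # marginal p(y) over this split
        kl = np.sum(part * (np.log(part) - np.log(p_y)), 1)  # KL(p(y|x) || p(y)) per image
        scores.append(np.exp(np.mean(kl)))                   # IS of this split
    return np.mean(scores), np.std(scores)

Note that the number of samples (5000 vs. 50000), the number of splits, and the image preprocessing all enter this computation, which is exactly why scores from different calculators may not be comparable.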