-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path inception_score_mnist.py
38 lines (34 loc) · 1.12 KB
/
inception_score_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import math
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import inception_v3
import numpy as np
from utils import get_mnist_classifer
def get_inception_score(imgs, use_cuda=None, net=None):
    """Compute the Inception Score of a set of images using a classifier.

    IS = exp(E_x[KL(p(y|x) || p(y))]) (Salimans et al., 2016), where p(y|x) is
    the softmax of the classifier output for image x and p(y) is the average of
    p(y|x) over all images.

    Args:
        imgs: iterable of image batches. Each batch must be a 4-D tensor the
            classifier accepts (presumably (N, C, H, W) -- confirm against
            get_mnist_classifer). All batches are concatenated before scoring.
        use_cuda: True to require the GPU, False to force the CPU, None
            (default) to use the GPU only when one is available.
        net: optional classifier whose output is treated as unnormalised
            logits of shape (batch, num_classes). Defaults to
            ``get_mnist_classifer()``. It is switched to eval mode and moved to
            the selected device in place.

    Returns:
        The Inception Score as a float (1.0 means no confident/diverse
        predictions; the maximum is the number of classes), or None if
        ``use_cuda`` is True but CUDA is not available.

    Raises:
        AssertionError: if a batch is not 4-dimensional.
        ValueError: if ``imgs`` contains no batches.
    """
    cuda_available = torch.cuda.is_available()
    if use_cuda and not cuda_available:
        print("CUDA not available but use_cuda is True")
        return None
    if use_cuda is None:
        # Auto-detect instead of failing on CPU-only machines.
        use_cuda = cuda_available
    if not use_cuda:
        print("not using cuda")
    device = torch.device("cuda" if use_cuda else "cpu")

    if net is None:
        net = get_mnist_classifer()
    net.eval()
    net = net.to(device)

    scores = []
    # Evaluation only: without no_grad every batch's autograd graph would be
    # retained until the end of the function.
    with torch.no_grad():
        for batch in imgs:
            assert len(np.shape(batch)) == 4, (
                "Batches of imgs had incorrect number of dimensions. "
                "Expected 4. Received shape: " + str(np.shape(batch))
            )
            # Inputs must live on the same device as the classifier.
            scores.append(net(batch.to(device)))
    if not scores:
        raise ValueError("imgs must contain at least one batch")
    print("scores calculated")

    log_p_yx = F.log_softmax(torch.cat(scores, 0), dim=1)
    p_yx = log_p_yx.exp()
    # log p(y) = log(mean_x p(y|x)); logsumexp keeps it finite even when some
    # probabilities underflow to 0 (which would otherwise give 0 * inf = nan).
    log_p_y = torch.logsumexp(log_p_yx, dim=0, keepdim=True) - math.log(log_p_yx.size(0))
    # KL is summed over classes, averaged over images, then exponentiated.
    kl = (p_yx * (log_p_yx - log_p_y)).sum(dim=1).mean()
    final_score = math.exp(float(kl.item()))
    print("inception score", final_score)
    return final_score