This repository has been archived by the owner on Jul 13, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
78 lines (74 loc) · 3.15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from tensorbayes.utils import progbar
from scipy.stats import mode
import numpy as np
import os.path
def stream_print(f, string, pipe_to_file=True):
    """Print ``string`` to stdout and optionally append it to a log file.

    Parameters
    ----------
    f : file-like object or None
        Destination opened for writing. ``None`` disables file output.
    string : str
        Text to emit; a trailing newline is added when writing to ``f``.
    pipe_to_file : bool, optional
        When false, only stdout receives the text even if ``f`` is given.
    """
    # Parenthesised single-argument form prints identically under the
    # Python 2 print statement and the Python 3 print function.
    print(string)
    if pipe_to_file and f is not None:
        f.write(string + '\n')
        # Flush right away so the log can be followed while training runs.
        f.flush()
def test_acc(mnist, sess, qy_logit):
    """Return clustering accuracy of ``qy_logit`` on the test split.

    Every test image is assigned to the category with the largest logit.
    Each category is then mapped to the most frequent true label among the
    images assigned to it, and the fraction of images whose mapped label
    equals the true label is returned.

    Parameters
    ----------
    mnist : dataset object
        Must expose ``test.images`` (indexable along the first axis) and
        ``test.labels`` (one row per image; ``argmax(1)`` gives the label).
    sess : session-like object
        ``sess.run(qy_logit, feed_dict={'x:0': batch})`` must return a 2-D
        array with one row of logits per image in ``batch``.
    qy_logit : fetchable
        Passed through to ``sess.run`` unchanged.

    Returns
    -------
    float
        Mean agreement between mapped predictions and true labels.
    """
    images = mnist.test.images
    true_labels = mnist.test.labels.argmax(1)
    batch_size = 100

    # Slice the image array directly instead of drawing from next_batch():
    # rows are then guaranteed to line up with ``true_labels`` and a final
    # partial batch is not dropped when the test set size is not a multiple
    # of the batch size.
    logits = np.concatenate([
        sess.run(qy_logit,
                 feed_dict={'x:0': images[start:start + batch_size]})
        for start in range(0, images.shape[0], batch_size)
    ])
    cat_pred = logits.argmax(1)

    real_pred = np.zeros_like(cat_pred)
    for cat in range(logits.shape[1]):
        members = cat_pred == cat
        if not members.any():
            # No image was assigned to this category; nothing to map.
            continue
        # Most frequent true label in the cluster. argmax returns the first
        # maximum, so ties resolve to the smallest label.
        real_pred[members] = np.bincount(true_labels[members]).argmax()
    return np.mean(real_pred == true_labels)


# The ``test_`` prefix is historical: this is a helper, not a unit test, so
# keep pytest from collecting it when the module is imported into a test file.
test_acc.__test__ = False
def open_file(fname):
    """Open a fresh, numbered log file for writing.

    The file name is ``fname`` followed by ``.N``, where ``N`` is the
    smallest non-negative integer for which no regular file with that name
    exists yet, so earlier logs are never overwritten.

    Parameters
    ----------
    fname : str or None
        Base path of the log file. ``None`` means "no log file".

    Returns
    -------
    file object or None
        Line-buffered text file opened in write mode, or ``None`` when
        ``fname`` is ``None``.
    """
    if fname is None:
        return None

    index = 0
    path = '{:s}.{:d}'.format(fname, index)
    while os.path.isfile(path):
        index += 1
        path = '{:s}.{:d}'.format(fname, index)
    # Line buffering (1) rather than unbuffered (0): Python 3 rejects
    # unbuffered text-mode files, while line buffering is accepted by both
    # Python 2 and 3 and still pushes each completed line out promptly.
    return open(path, 'w', 1)
def train(fname, mnist, sess_info, epochs):
    """Run the training loop and report metrics after every epoch.

    An epoch is 500 training steps on minibatches of 100 images. After each
    epoch the entropy and loss on random train/test subsets plus the test
    accuracy are printed and written to the log file.

    Parameters
    ----------
    fname : str or None
        Base name handed to ``open_file``; ``None`` disables file logging.
    mnist : dataset object
        Must expose ``train.next_batch``, ``train.images`` and
        ``test.images`` (plus whatever ``test_acc`` requires).
    sess_info : tuple
        ``(sess, qy_logit, nent, loss, train_step)``.
    epochs : int
        Number of epochs to train for.
    """
    (sess, qy_logit, nent, loss, train_step) = sess_info
    f = open_file(fname)
    iterep = 500
    try:
        for i in range(iterep * epochs):
            sess.run(train_step,
                     feed_dict={'x:0': mnist.train.next_batch(100)[0]})
            progbar(i, iterep)
            if (i + 1) % iterep == 0:
                # Evaluate on 1000 random rows of each split. The index
                # ranges are hard-coded: this assumes at least 50000 train
                # and 10000 test images -- TODO confirm against the dataset.
                a, b = sess.run([nent, loss], feed_dict={'x:0': mnist.train.images[np.random.choice(50000, 1000)]})
                c, d = sess.run([nent, loss], feed_dict={'x:0': mnist.test.images[np.random.choice(10000, 1000)]})
                # ``nent`` is negated before it is reported as an entropy.
                a, b, c, d = -a.mean(), b.mean(), -c.mean(), d.mean()
                e = test_acc(mnist, sess, qy_logit)
                string = ('{:>10s},{:>10s},{:>10s},{:>10s},{:>10s},{:>10s}'
                          .format('tr_ent', 'tr_loss', 't_ent', 't_loss', 't_acc', 'epoch'))
                # The header goes to the file only after the first epoch;
                # stdout gets it every time.
                stream_print(f, string, i <= iterep)
                # Floor division keeps the epoch an int so the ``d`` format
                # spec also works where ``/`` is true division.
                string = ('{:10.2e},{:10.2e},{:10.2e},{:10.2e},{:10.2e},{:10d}'
                          .format(a, b, c, d, e, (i + 1) // iterep))
                stream_print(f, string)
    finally:
        # Close the log even when a training step raises.
        if f is not None:
            f.close()