-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
97 lines (86 loc) · 3.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import os
import sys
import io
import tensorflow as tf
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import time
import spectraGAN
def plot_to_image(figure):
    """Render a matplotlib figure to a PNG image tensor with a batch axis.

    The figure is saved to an in-memory PNG, closed, and the PNG bytes are
    decoded with TensorFlow so the result can be handed to
    ``tf.summary.image``.

    Args:
        figure: The matplotlib figure to render. It is closed by this call
            and must not be used afterwards.

    Returns:
        The output of ``tf.image.decode_png(..., channels=4)`` with one
        extra leading axis added by ``tf.expand_dims(image, 0)``.
    """
    buf = io.BytesIO()
    # Save through the figure object itself, not pyplot's implicit "current
    # figure": otherwise a different figure is captured whenever `figure`
    # is not the most recently active one.
    figure.savefig(buf, format='png')
    # Closing the figure releases it and keeps it from being displayed
    # directly by an interactive front end.
    plt.close(figure)
    # getvalue() returns the whole buffer regardless of the stream
    # position, so no seek is needed before decoding.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)
# The run identifier is the first command-line argument; it is used only to
# name the experiment directory.
if len(sys.argv) < 2:
    sys.exit('usage: %s RUN_NUM' % sys.argv[0])
run_num = sys.argv[1]
continuing = False  # not read anywhere in the visible part of this script

# Each run writes into its own directory: ./expts/run<RUN_NUM>/
baseDir = './expts/'
expDir = baseDir+'run'+str(run_num)+'/'
# makedirs also creates baseDir when it is missing, and exist_ok=True avoids
# the check-then-create race of separate isdir()/mkdir() calls.
os.makedirs(expDir, exist_ok=True)

# Build GAN
GAN = spectraGAN.spectraGAN(expDir)

# Lowest chi / Pchi values seen so far. The large starting value means the
# first comparison against a real score always counts as an improvement.
bestchi = 1e15
bestPchi = 1e15
# Attribute names of the loss metric objects on GAN, in the order they are
# written to the summary log. Each name doubles as the summary tag.
LOSS_METRICS = ('G_loss', 'G_loss_gan', 'G_loss_spect', 'G_loss_spect_var', 'D_loss')

# NOTE(review): the nesting below was reconstructed from the statements
# themselves. Checkpointing is placed inside the every-100-iterations branch
# because chi and Pchi are only assigned there -- confirm against the
# original file.
for epoch in range(30):
    start = time.time()
    train_size = len(GAN.real_imgs)
    # Shuffle in place so each epoch visits the samples in a new order.
    np.random.shuffle(GAN.real_imgs)
    with GAN.train_summary_writer.as_default():
        # Floor division: samples left over after the last full batch are
        # not used in this epoch.
        for batchnum in range(train_size // GAN.batchsize):
            first = batchnum * GAN.batchsize
            last = (batchnum + 1) * GAN.batchsize
            GAN.train_step(GAN.real_imgs[first:last])

            # Diagnostics and checkpointing only run every 100 iterations.
            step = GAN.generator_optimizer.iterations
            if not tf.equal(step % 100, 0):
                continue

            # Generate sample imgs. Each figure is converted before the next
            # one is created.
            fig = GAN.generate_images()
            tf.summary.image("genimg", plot_to_image(fig), step=step)
            fig, chi = GAN.pix_hist()
            tf.summary.image("pixhist", plot_to_image(fig), step=step)
            fig, Pchi = GAN.pspect()
            tf.summary.image("powerspect", plot_to_image(fig), step=step)

            # Log scalars
            for metric_name in LOSS_METRICS:
                tf.summary.scalar(metric_name, getattr(GAN, metric_name).result(),
                                  step=step)
            tf.summary.scalar('chi', chi, step=step)
            tf.summary.scalar('Pchi', Pchi, step=step)

            # Start a fresh accumulation window for the next 100 iterations.
            for metric_name in LOSS_METRICS:
                getattr(GAN, metric_name).reset_states()

            # Save model if chi is good, but only after the first 10000
            # iterations.
            if step > 10000:
                if chi < bestchi:
                    GAN.checkpoint.write(
                        file_prefix=os.path.join(GAN.checkpoint_dir, 'BESTCHI'))
                    bestchi = chi
                    print('BESTCHI: iter=%d, chi=%f'%(step, chi))
                if Pchi < bestPchi:
                    GAN.checkpoint.write(
                        file_prefix=os.path.join(GAN.checkpoint_dir, 'BESTPCHI'))
                    bestPchi = Pchi
                    print('BESTPCHI: iter=%d, Pchi=%f'%(step, Pchi))

    elapsed = time.time() - start
    print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, elapsed))

print('DONE')