-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathmain.py
94 lines (78 loc) · 3.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# coding: utf-8
# --------------------------------------------------------
# FNM
# Written by Yichen Qian
# --------------------------------------------------------
import os
import tensorflow as tf
from config import cfg
from WGAN_GP import WGAN_GP
from utils import loadData
def main(_):
    """Train (or finetune) the FNM WGAN-GP face-normalization network.

    Pins CUDA device visibility from ``cfg``, ensures the output directories
    exist, builds the network graph on the training input queue, then runs the
    epoch/minibatch WGAN-GP loop: several critic (discriminator) updates per
    generator update, with periodic summary writing, checkpointing, and a
    small test pass.

    Args:
        _: unused positional argument supplied by ``tf.app.run()``.
    """
    # Environment setting: select the GPU(s) configured in cfg.device_id.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.device_id
    # exist_ok avoids the racy "check-then-create" of the original code.
    os.makedirs(cfg.results, exist_ok=True)
    os.makedirs(cfg.checkpoint, exist_ok=True)
    os.makedirs(cfg.summary_dir, exist_ok=True)

    # Construct Networks: build the graph on the shuffled training queue.
    net = WGAN_GP()
    data_feed = loadData(batch_size=cfg.batch_size, train_shuffle=True)  # False
    profile, front = data_feed.get_train()
    net.build_up(profile, front)

    # Train or Test
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # claim GPU memory on demand
    with tf.Session(config=config, graph=net.graph) as sess:
        sess.run(tf.global_variables_initializer())
        # Start Thread: queue runners feed the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver = tf.train.Saver(max_to_keep=0)  # keep every checkpoint
        if cfg.is_finetune:
            saver.restore(sess, cfg.checkpoint_ft)
            print('Load Finetuned Model Successfully!')
        num_batch = int(cfg.dataset_size / cfg.batch_size)
        writer = tf.summary.FileWriter(cfg.summary_dir, sess.graph)
        # Train by minibatch and critic. The finally-block guarantees the
        # queue threads are stopped and joined even if training raises,
        # which the original code only did on a clean run.
        try:
            for epoch in range(cfg.epoch):
                for step in range(num_batch):
                    # Discriminator Part: warm up the critic with 25 extra
                    # updates at the very start of a fresh (non-finetune)
                    # run — standard WGAN practice — else use cfg.critic.
                    if step < 25 and epoch == 0 and not cfg.is_finetune:
                        critic = 25
                    else:
                        critic = cfg.critic
                    for _i in range(critic):
                        sess.run(net.train_dis, {net.is_train: True})
                    # Generative Part: one generator step, fetching losses,
                    # generated images, gradient norms, and summaries.
                    _, fl, gl, dl, gen, g1, g2, g4, summary = sess.run(
                        [net.train_gen, net.feature_loss, net.g_loss,
                         net.d_loss, net.gen_p, net.grad1, net.grad2,
                         net.grad4, net.train_summary],
                        {net.is_train: True})
                    print('%d-%d, Fea Loss:%.2f, D Loss:%4.1f, G Loss:%4.1f, g1/2/4:%.5f/%.5f/%.5f ' %
                          (epoch, step, fl, dl, gl, g1 * cfg.lambda_fea, g2, g4))
                    # Save Model and Summary and Test every cfg.save_freq steps.
                    if step % cfg.save_freq == 0:
                        writer.add_summary(summary, epoch * num_batch + step)
                        print("Saving Model....")
                        saver.save(sess, os.path.join(cfg.checkpoint, 'ck-%02d' % (epoch)))
                        # test: accumulate losses over 50 test batches and
                        # dump the generated images for inspection.
                        fl, dl, gl = 0., 0., 0.
                        for _i in range(50):  # 25791 / 16
                            te_profile, te_front = data_feed.get_test_batch(cfg.batch_size)
                            dl_, gl_, fl_, images = sess.run(
                                [net.d_loss, net.g_loss, net.feature_loss, net.gen_p],
                                {profile: te_profile, front: te_front, net.is_train: False})
                            data_feed.save_images(images, epoch)
                            dl += dl_; gl += gl_; fl += fl_
                        print('Testing: Fea Loss:%.1f, D Loss:%.1f, G Loss:%.1f' % (fl, dl, gl))
        finally:
            # Close Threads — always, even on error or KeyboardInterrupt.
            coord.request_stop()
            coord.join(threads)
# Script entry point: tf.app.run() parses command-line flags and then
# invokes this module's main(argv).
if __name__ == "__main__":
    tf.app.run()