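-- sample3.lua
-- Sample an image pixel-by-pixel from a trained pixel-model checkpoint, reporting the
-- sampling loss and the loss on a reference image along the way.
-- Note: the script assumes a CUDA-capable GPU; several buffers below are created with
-- :cuda() regardless of the -gpuid setting.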
require 'torch'
require 'nn'
require 'nngraph'
require 'image'
-- local imports
require 'pm'
require 'gmms'
local utils = require 'misc.utils'
--require 'misc.DataLoader'
require 'misc.DataLoaderRaw'
local net_utils = require 'misc.net_utils'
-------------------------------------------------------------------------------
-- Input arguments and options
-------------------------------------------------------------------------------
cmd = torch.CmdLine()
cmd:text()
cmd:text('Sampling an Image from a Pixel Model')
cmd:text()
cmd:text('Options')
-- Input paths
cmd:option('-model','','path to model to evaluate')
cmd:option('-img_size', 256, 'size of the sampled image')
-- Sampling options
cmd:option('-batch_size', 1, 'if > 0 then override the checkpoint value, otherwise load the batch size from the checkpoint.')
cmd:option('-sample_max', 1, '1 = take the argmax sample. 0 = sample from the predicted distribution.')
cmd:option('-beam_size', 2, 'used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
cmd:option('-temperature', 1.0, 'temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.')
-- For evaluation on a folder of images:
cmd:option('-image_folder', '', 'If nonempty, run prediction on the images in this folder')
cmd:option('-image_root', '', 'In case the image paths have to be prepended with a root path to an image folder')
-- For evaluation on MSCOCO images from some split:
cmd:option('-split', 'test', 'if running on MSCOCO images, which split to use: val|test|train')
-- misc
cmd:option('-backend', 'cudnn', 'nn|cudnn')
cmd:option('-id', 'evalscript', 'an id identifying this run/job, used for naming any intermediate files')
cmd:option('-seed', 123, 'random number generator seed to use')
cmd:option('-gpuid', 0, 'which gpu to use. -1 = use CPU')
cmd:text()
-------------------------------------------------------------------------------
-- Basic Torch initializations
-------------------------------------------------------------------------------
local opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
torch.setdefaulttensortype('torch.FloatTensor') -- for CPU
if opt.gpuid >= 0 then
  require 'cutorch'
  require 'cunn'
  if opt.backend == 'cudnn' then require 'cudnn' end
  cutorch.manualSeed(opt.seed)
  cutorch.setDevice(opt.gpuid + 1) -- note +1 because lua is 1-indexed
end
-------------------------------------------------------------------------------
-- Load the model checkpoint to evaluate
-------------------------------------------------------------------------------
assert(string.len(opt.model) > 0, 'must provide a model')
local checkpoint = torch.load(opt.model)
local batch_size = opt.batch_size
if opt.batch_size == 0 then batch_size = checkpoint.opt.batch_size end
-- change it to evaluation mode
local protos = checkpoint.protos
local patch_size = checkpoint.opt['patch_size']
local border = checkpoint.opt['border_init']
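-- the sampled canvas is the requested image size plus a patch_size margin, which is cropped away before saving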
protos.pm.recurrent_stride = patch_size + opt.img_size
protos.pm.seq_length = protos.pm.recurrent_stride * protos.pm.recurrent_stride
if opt.gpuid >= 0 then for k,v in pairs(protos) do v:cuda() end end
local pm = protos.pm
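-- criterion over the mixture-of-Gaussians outputs; its sample() routine is used below to draw pixels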
local crit = nn.PixelModelCriterion(pm.pixel_size, pm.num_mixtures)
pm.core:evaluate()
print('The loaded model was trained with patch size: ', patch_size)
-- prepare the empty states
local init_state = {}
for L = 1,checkpoint.opt.num_layers do
  -- c and h for all layers
  local h_init = torch.zeros(batch_size, pm.rnn_size):double()
  if opt.gpuid >= 0 then h_init = h_init:cuda() end
  table.insert(init_state, h_init:clone()) -- for lstm c
  table.insert(init_state, h_init:clone()) -- for lstm h
end
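-- states is indexed by column: index 0 holds the all-zero initial state, and states[c] is
-- overwritten as each row is traversed, so it holds the most recent state computed at column c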
local states = {[0] = init_state}
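-- canvas that accumulates the sampled pixels (margin included)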
local images = torch.Tensor(batch_size, pm.pixel_size, pm.recurrent_stride, pm.recurrent_stride):cuda()
------------------ debug ------------------------
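-- load a reference image; in the current code its pixels are only used to report a
-- "training" loss at each location alongside the sampling loss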
local img = image.load('imgs/D1.png', pm.pixel_size, 'float')
img = image.scale(img, 256, 256):resize(1, pm.pixel_size, 256, 256)
img = torch.repeatTensor(img, batch_size, 1, 1, 1)
img = img:cuda()
local loss_sum = 0
local train_loss_sum = 0
-- random seed the zero-th pixel
-- local pixel = torch.rand(batch_size, pm.pixel_size):cuda()
local pixel
local gmms
-- loop through each timestep
for h=1,pm.recurrent_stride do
  for w=1,pm.recurrent_stride do
    -- serpentine scan: odd rows run left-to-right, even rows right-to-left
    local ww = w -- actual column coordinate
    if h % 2 == 0 then ww = pm.recurrent_stride + 1 - w end

    -- gather the three neighbouring pixels (left, up, right); an index of 0 means the
    -- neighbour is outside the canvas or not sampled yet, so the border value (zero or
    -- random, depending on border_init) and the initial state are used instead
    local pixel_left, pixel_up, pixel_right
    local pl, pr, pu
    if ww == 1 or h % 2 == 0 then
      if border == 0 then
        pixel_left = torch.zeros(batch_size, pm.pixel_size):cuda()
      else
        pixel_left = torch.rand(batch_size, pm.pixel_size):cuda()
      end
      pl = 0
    else
      pixel_left = images[{{}, {}, h, ww-1}]
      pl = ww - 1
    end
    if ww == pm.recurrent_stride or h % 2 == 1 then
      if border == 0 then
        pixel_right = torch.zeros(batch_size, pm.pixel_size):cuda()
      else
        pixel_right = torch.rand(batch_size, pm.pixel_size):cuda()
      end
      pr = 0
    else
      pixel_right = images[{{}, {}, h, ww+1}]
      pr = ww + 1
    end
    if h == 1 then
      if border == 0 then
        pixel_up = torch.zeros(batch_size, pm.pixel_size):cuda()
      else
        pixel_up = torch.rand(batch_size, pm.pixel_size):cuda()
      end
      pu = 0
    else
      pixel_up = images[{{}, {}, h-1, ww}]
      pu = ww
    end

    -- inputs to the LSTM: {input, states[t, t-1], states[t-1, t], states[t, t+1]}
    -- Need to fix this for the new model
    local inputs = {torch.cat(torch.cat(pixel_left, pixel_up, 2), pixel_right, 2), unpack(states[pl])}
    -- insert the states[t-1,t]
    for i,v in ipairs(states[pu]) do table.insert(inputs, v) end
    -- insert the states[t,t+1]
    for i,v in ipairs(states[pr]) do table.insert(inputs, v) end

    -- forward the network; outputs are {next_c, next_h, next_c, next_h, ..., output_vec}
    local lsts = pm.core:forward(inputs)
    -- save the state at this column for use by the next row
    states[ww] = {}
    for i=1,pm.num_state do table.insert(states[ww], lsts[i]:clone()) end
    gmms = lsts[#lsts]

    -- sampling: draw a pixel from the predicted mixture, and also record the loss of the
    -- corresponding reference pixel from the debug image
    --pixel = img[{{}, {}, h, ww}]
    --images[{{},{},h,ww}] = pixel
    local train_pixel = img[{{}, {}, h, ww}]:clone()
    local loss, train_loss
    pixel, loss, train_loss = crit:sample(gmms, train_pixel)
    --pixel = train_pixel
    images[{{},{},h,ww}] = pixel
    loss_sum = loss_sum + loss
    train_loss_sum = train_loss_sum + train_loss
  end
  collectgarbage()
end
-- output the sampled images
local images_cpu = images:float()
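-- crop away the patch_size margin, keeping only the img_size x img_size sampled region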
images_cpu = images_cpu[{{}, {}, {patch_size+1, pm.recurrent_stride},{patch_size+1, pm.recurrent_stride}}]
images_cpu = images_cpu:clamp(0,1):mul(255):type('torch.ByteTensor')
for i=1,batch_size do
  local filename = path.join('samples', i .. '.png')
  image.save(filename, images_cpu[{i,1,{},{}}])
end
--loss_sum = loss_sum / (opt.img_size * opt.img_size)
--train_loss_sum = train_loss_sum / (opt.img_size * opt.img_size)
loss_sum = loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
train_loss_sum = train_loss_sum / (pm.recurrent_stride * pm.recurrent_stride)
print('testing loss: ', loss_sum)
print('training loss: ', train_loss_sum)