-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtiming_benchmark.lua
69 lines (54 loc) · 2.1 KB
/
timing_benchmark.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
-- Copyright 2016 Anurag Ranjan and the Max Planck Gesellschaft.
-- All rights reserved.
-- This software is provided for research purposes only.
-- By using this software you agree to the terms of the license file
-- in the root folder.
-- For commercial use, please contact [email protected].
require 'image'
require 'cutorch'
-- Benchmark script: times GPU optical-flow estimation
-- (donkey.computeInitFlowL9) on every validation pair, then reports the
-- average endpoint error (EPE) against the raw ground-truth flow together
-- with the mean and median per-pair timing.
local cmd = torch.CmdLine()
cmd:option('-data', '../FlyingChairs/data', 'Flying Chairs data directory')

-- `opt` is intentionally global (as in the original script): it is set before
-- the project-local `timing_util` module is required.
-- NOTE(review): presumably timing_util reads these fields — confirm before
-- making it local.
opt = cmd:parse(arg or {})
opt.showFlow = 0
opt.fineHeight = 384
opt.fineWidth = 512
opt.preprocess = 0
opt.level = 9
opt.polluteFlow = 0
opt.augment = 0
opt.warp = 1
opt.batchSize = 1

local donkey = require('timing_util')
-- Only the validation split is benchmarked; the training split is unused.
local _, validation_samples = donkey.getTrainValidationSplits('train_val_split.txt')
local nValidation = validation_samples:size(1)
assert(nValidation > 0, 'validation split is empty')

local errors = torch.zeros(nValidation)
local timings = torch.zeros(nValidation)

-- Host buffer (cutorch.createCudaHostTensor) that receives every estimated
-- 2-channel flow field through an asynchronous device-to-host copy. It is sized
-- from the split rather than a hard-coded 640 so larger splits do not index
-- past its end.
local flowCPU = cutorch.createCudaHostTensor(nValidation, 2, opt.fineHeight, opt.fineWidth):uniform()

-- Pass 1: time the flow estimation for every validation pair.
for i = 1, nValidation do
  collectgarbage()
  local id = validation_samples[i][1]
  local imgs = donkey.testHook(id)
  local timer = torch.Timer()
  imgs = imgs:resize(1, 6, opt.fineHeight, opt.fineWidth):cuda()
  local flow_est = donkey.computeInitFlowL9(imgs):squeeze()
  -- CUDA kernels run asynchronously: wait for them to finish so the timer
  -- measures the computation itself, not just the kernel launches.
  cutorch.synchronize()
  local time_elapsed = timer:time().real
  flowCPU[i]:copyAsync(flow_est)
  -- Make sure the async copy has landed before the slot is read in pass 2.
  cutorch.streamSynchronize(cutorch.getStream())
  print('Time Elapsed: '..time_elapsed)
  timings[i] = time_elapsed
end

-- Pass 2: endpoint error of each stored estimate against the raw ground truth.
local epeSum = 0
for i = 1, nValidation do
  local id = validation_samples[i][1]
  local _, _, raw_flow = donkey.getRawData(id)
  -- Euclidean distance over dimension 1 (the 2 flow components of flowCPU[i]).
  local err = torch.sum((raw_flow - flowCPU[i]):pow(2), 1):sqrt()
  epeSum = epeSum + err:float()
  errors[i] = err:mean()
  print(i, errors[i])
end

-- Per-pixel EPE map averaged over all pairs, then averaged over the image area.
local meanEpeMap = torch.div(epeSum, nValidation)
print('Average EPE = '..meanEpeMap:sum()/(opt.fineWidth*opt.fineHeight))
print('Mean Timing: ' ..timings:mean())
print('Median Timing: ' ..timings:median()[1])