-
Notifications
You must be signed in to change notification settings - Fork 5
/
process.lua
executable file
·170 lines (150 loc) · 6.84 KB
/
process.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
local smr_dist = assert(require("libsmrdist"))

-- Geometry settings. `options` is not defined in this file; per the original
-- note it is provided by run.lua.
local downs, boxh, boxw = options.downs, options.boxh, options.boxw

-- Tracker state carried from one frame to the next (used by SMRtracker and process).
local dynamic_th = 0 -- runner-up SMR score: map maximum after the peak and a window around it are zeroed
local lifetime   = 0 -- +1 per SMRtracker call, reset on an accepted detection; widens the search window
local threshold  = 0 -- acceptance ratio, re-selected each frame from the lost/disappear flags
local lost       = 0 -- 1 when the latest detection was rejected (object lost)
local disappear  = 0 -- 1 once the box has been seen near the image border (object leaving the scene)
--- Locate the maximum element of a 2-D tensor.
-- Reduces over dim 1 (rows) first, then over dim 2 of those per-column maxima.
-- All temporaries are local; the original version leaked x, xi, y, yi,
-- x_out and y_out into the global table on every call.
-- @param a 2-D tensor (rows x cols)
-- @return 1x1 tensor holding the maximum value (read it as `v[1][1]`)
-- @return column index of the maximum, i.e. the x coordinate (1-based)
-- @return row index of the maximum, i.e. the y coordinate (1-based)
function GetMax(a)
   -- colMax[1][j] = largest value in column j; rowIdx[1][j] = its row
   local colMax, rowIdx = torch.max(a, 1)
   -- best = overall maximum; colIdx[1][1] = the column that holds it
   local best, colIdx = torch.max(colMax, 2)
   local x = colIdx[1][1] -- column -> x coordinate
   local y = rowIdx[1][x] -- row of the winning column -> y coordinate
   return best, x, y
end
--- Run the SMR matcher for `patch` inside a search window.
-- The window is built around each entry of state.resultsSMR in turn (so the
-- last entry wins): roughly 60/downs columns and 50/downs rows on each side of
-- its top-left corner, clamped to the probability map. While `lifetime` > 0
-- the window is widened by `lifetime` on every side, and `lifetime` is bumped
-- once per call.
-- When state.resultsSMR is empty (object lost), the bounds left over from the
-- previous call are reused and widened again, so the search area keeps
-- growing until a detection is accepted.
-- NOTE: begin_x/end_x/begin_y/end_y must stay global; that empty-results path
-- depends on their values surviving between calls.
-- @param patch template tensor passed straight to smr_dist.smr
function SMRtracker(patch)
   local probMap = state.SMRProb
   local nRows, nCols = probMap:size(1), probMap:size(2)
   local halfW, halfH = 60/downs, 50/downs

   for _, prev in ipairs(state.resultsSMR) do
      begin_x = math.max(0, prev.lx - halfW - 1)
      end_x   = math.min(nCols, prev.lx + halfW)
      begin_y = math.max(0, prev.ty - halfH - 1)
      end_y   = math.min(nRows, prev.ty + halfH)
   end

   -- widen the window the longer the object has gone unconfirmed
   if lifetime > 0 then
      begin_x = math.max(0, begin_x - lifetime)
      end_x   = math.min(nCols, end_x + lifetime)
      begin_y = math.max(0, begin_y - lifetime)
      end_y   = math.min(nRows, end_y + lifetime)
   end
   lifetime = lifetime + 1

   -- the native matcher writes similarity scores into the zeroed map
   probMap:fill(0)
   smr_dist.smr(probMap, state.input, patch, state.dynamic, begin_x, end_x, begin_y, end_y)
end
--- Per-frame pipeline: grab a frame, track the stored template, optionally
-- learn a new template on request, and log the current result.
-- State touched: state.input, state.SMRProb, state.lastPatch,
-- state.resultsSMR, state.dynamic, state.learn, state.dsoutfile, plus the
-- file-level lifetime / threshold / lost / disappear / dynamic_th.
-- NOTE(review): `extension`, `source`, `profiler`, `state` and `options` are
-- not defined in this file; presumably globals set up by run.lua -- confirm.
-- NOTE(review): value, px_nxt, py_nxt, window, difference, ref_lx and ref_ty
-- are assigned without `local` and therefore leak into _G. Check whether any
-- other file reads them before making them local.
local function process()
------------------------------------------------------------
-- (0) grab a frame.
-- NOTE(review): source:getframe() is expected to refresh state.input (the
-- original note said "get Y channel and resize") -- confirm in the source module.
------------------------------------------------------------
profiler:start('get-frame')
source:getframe()
profiler:lap('get-frame')
------------------------------------------------------------
-- (1) SMR probability map
------------------------------------------------------------
-- One score per candidate top-left corner: (H-boxh+1) x (W-boxw+1), zeroed.
state.SMRProb = torch.Tensor(math.floor(state.input:size(1)-boxh)+1, math.floor(state.input:size(2)-boxw)+1 ):fill(0)
-- Tracking only runs once a template exists (it is created in step (2)).
if state.lastPatch:dim() > 0 then
SMRtracker(state.lastPatch)
-- Results are rebuilt each tracked frame; if nothing is accepted below the
-- list stays empty, which SMRtracker treats as "object lost" next frame.
state.resultsSMR = {}
-- value: 1x1 tensor with the best score; px_nxt = its column (x), py_nxt = its row (y).
value, px_nxt, py_nxt = GetMax(state.SMRProb)
-- Top-left corner of the best box, clamped to the valid range
-- ((px_nxt-1)+1 is simply px_nxt).
local lx = math.min(math.max(0,(px_nxt-1)+1),state.input:size(2)-boxw+1)
local ty = math.min(math.max(0,(py_nxt-1)+1),state.input:size(1)-boxh+1)
-- Compare the two highest scores to see whether the best one clearly stands out;
-- if they are similar, the detection is not reliable.
-- The peak's neighbourhood is zeroed in place, so dynamic_th becomes the runner-up score.
-- NOTE(review): for an interior peak this zeroes columns px_nxt-window up to
-- min(px_nxt+window-1, size-2): not centred and never the last two columns.
-- When the map is narrower than 2*window and the peak is near index 1, the
-- narrow length can exceed the map size and :narrow will raise -- confirm/guard.
window =8
state.SMRProb:narrow(2, math.max(px_nxt-window, 1), math.min(2*window, state.SMRProb:size(2)-px_nxt+window-1)):
narrow(1, math.max(py_nxt-window, 1), math.min(2*window,state.SMRProb:size(1)-py_nxt+window-1)):zero()
dynamic_th = state.SMRProb:max()
-- Dynamic thresholding: pick the acceptance ratio for this frame.
--   lost                                -> 1.25
--   tracked, inside the border margin   -> 1.2 if disappear == 1, else 1
--   tracked, near the border            -> 1.02, and disappear is set to 1
-- NOTE(review): the margin test uses boxw/downs and boxh/downs while the rest
-- of this function uses boxw/boxh in input pixels -- confirm which is intended.
if (lost == 0) then
if(lx>=extension) and (ty>=extension) and (lx+boxw/downs)<=(state.input:size(2)-extension) and (ty+boxh/downs)<=(state.input:size(1)-extension-1) then
if (disappear == 1) then
threshold = 1.2
else
threshold = 1
end
else
threshold = 1.02
disappear = 1
end
else
threshold = 1.25
end
-- Accept the detection if the peak beats the runner-up by the chosen ratio,
-- or by an absolute margin of 100.
if (value[1][1]>(threshold*dynamic_th)) or (value[1][1]>dynamic_th+100) then
lifetime = 0
-- threshold 1.25 means we were lost: recovering also clears the disappear flag.
if (threshold == 1.25) then
disappear = 0
end
lost = 0
local nresult = {lx=lx, ty=ty, cx=lx+boxw/2, cy=ty+boxh/2, w=boxw, h=boxh,
class=state.classes[1], id=1, source=2}
table.insert(state.resultsSMR, nresult)
else
lost = 1
end
-- Template update.
-- Do not update the template if the object is going out of the scene.
-- A better template update mechanism is needed to handle occlusions.
for _,res in ipairs(state.resultsSMR) do
if(res.lx>=2*extension) and (res.ty>=2*extension) and (res.lx+boxw)<state.YUVFrame:size(3)+extension-1 and (res.ty+boxh)<state.YUVFrame:size(2)+extension-1 then
local patchYUV = torch.Tensor(boxh, boxw):fill(0)
patchYUV:copy(state.input[{ {res.ty, boxh+res.ty-1},{res.lx,boxw+res.lx-1}}])
if state.lastPatch:dim() > 0 then
-- NOTE: Torch7 :add(-1, t) and :abs() modify the receiver in place, so
-- `difference` IS state.lastPatch (now |old - new|). It must be consumed
-- before the :copy below overwrites it with the new patch.
difference = (state.lastPatch:add(-1, patchYUV)):abs()
-- state.dynamic = half the largest per-pixel change, kept unchanged when zero.
if (difference:max()/2)~=0 then
state.dynamic=(difference:max()/2)
end
state.lastPatch:copy(patchYUV)
end
end
end
end
------------------------------------------------------------
-- (2) capture new prototype, upon user request
------------------------------------------------------------
if state.learn then
profiler:start('learn-new-view')
-- compute the top-left corner of a box centred on the requested point
-- NOTE(review): the dataset branch shifts by `extension` and clamps to >= 1,
-- while the other branch clamps to >= 0. A 0 start index in the slice below
-- looks off for 1-based Torch indexing. boxw/2 may also be fractional -- confirm.
if options.source == 'dataset' then
ref_lx = math.min(math.max(state.learn.x+extension-boxw/2+1,1),state.input:size(2)-boxw)
ref_ty = math.min(math.max(state.learn.y+extension-boxh/2+1,1),state.input:size(1)-boxh)
else
ref_lx = math.min(math.max(state.learn.x-boxw/2,0),state.input:size(2)-boxw)
ref_ty = math.min(math.max(state.learn.y-boxh/2,0),state.input:size(1)-boxh)
end
state.logit('adding object at ' .. ref_lx
.. ',' .. ref_ty, state.learn.id)
-- and create a result
local nresult = {lx=ref_lx, ty=ref_ty, w=boxw,
h=boxh, class=state.classes[state.learn.id],
id=state.learn.id, source=6}
table.insert(state.resultsSMR, nresult)
-- save the patch under the box as the new tracking template
local patchYUV = torch.Tensor(boxh, boxw):fill(0)
patchYUV:copy(state.input[{ {ref_ty, boxh+ref_ty-1},{ref_lx,boxw+ref_lx-1}}])
state.lastPatch = patchYUV:clone()
-- done: the request is consumed
state.learn = nil
profiler:lap('learn-new-view')
end
------------------------------------------------------------
-- (3) save results: one line per frame, "lx,ty,lx+w,ty+h" for the first
-- result, or "NaN,NaN,NaN,NaN" when there is none
------------------------------------------------------------
if state.dsoutfile then
local res = state.resultsSMR[1]
if res then
state.dsoutfile:writeString(res.lx .. ',' .. res.ty .. ',' ..
res.lx+res.w .. ',' .. res.ty+res.h)
else
state.dsoutfile:writeString('NaN,NaN,NaN,NaN')
end
state.dsoutfile:writeString('\n')
end
end
-- The chunk evaluates to the per-frame function.
-- NOTE(review): presumably invoked once per frame by run.lua -- confirm.
return process