-
Notifications
You must be signed in to change notification settings - Fork 60
/
warp.py
91 lines (76 loc) · 2.31 KB
/
warp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import tensorflow as tf
def get_pixel_value(img, x, y):
    """
    Gather pixel values from a 4D image tensor at integer (x, y) coordinates.

    For every (b, i, j) the result holds img[b, y[b, i, j], x[b, i, j], :].

    Input
    -----
    - img: tensor of shape (B, H_img, W_img, C)
    - x: int32 tensor of shape (B, H, W) with column indices into `img`
    - y: int32 tensor of shape (B, H, W) with row indices into `img`
      (int32 is required because they are stacked with an int32 batch
      index built from tf.range.)
      Indices are expected to be already clipped to the image bounds:
      tf.gather_nd raises on out-of-range indices on CPU and writes zeros
      on GPU.
    Returns
    -------
    - output: tensor of shape (B, H, W, C) with the dtype of `img`.
    """
    shape = tf.shape(x)
    batch_size, height, width = shape[0], shape[1], shape[2]

    # Batch index for every sample: (B, 1, 1) tiled out to (B, H, W).
    batch_idx = tf.reshape(tf.range(0, batch_size), (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    # (B, H, W, 3) index triples in (batch, row, col) order for gather_nd.
    indices = tf.stack([b, y, x], 3)
    return tf.gather_nd(img, indices)
def tf_warp(img, flow, H=None, W=None):
    """
    Backward-warp a batch of images by a dense displacement field using
    bilinear interpolation.

    Output pixel (i, j) samples the image at the source location
    (x, y) = (j + flow[..., 0], i + flow[..., 1]), blending the four
    surrounding pixels. Source locations outside the image take the value
    of the nearest border pixel (corner indices are clamped before
    sampling). A zero flow reproduces `img` exactly.

    Input
    -----
    - img: tensor of shape (B, H, W, C); assumed float32 because it is
      multiplied by float32 interpolation weights.
    - flow: float32 tensor of shape (B, H, W, 2); channel 0 is the x
      (column) displacement, channel 1 the y (row) displacement.
    - H, W: spatial size of `flow` (and of `img`). When omitted they are
      read from `tf.shape(flow)`.
    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    if H is None:
        H = tf.shape(flow)[1]
    if W is None:
        W = tf.shape(flow)[2]

    # Base pixel-coordinate grid, shape (1, H, W, 2): channel 0 = x
    # (column index), channel 1 = y (row index).
    grid_x, grid_y = tf.meshgrid(tf.range(W), tf.range(H))
    grid = tf.cast(tf.stack([grid_x, grid_y], axis=-1)[tf.newaxis], tf.float32)
    flows = grid + flow

    x = flows[:, :, :, 0]
    y = flows[:, :, :, 1]

    # Corner coordinates. tf.floor is required: a plain int cast truncates
    # toward zero, placing x0 above x for x in (-1, 0) and producing
    # negative interpolation weights.
    x0 = tf.floor(x)
    y0 = tf.floor(y)
    x1 = x0 + 1.0
    y1 = y0 + 1.0

    # Bilinear weights from the *unclipped* corners, so they are always
    # non-negative and sum to 1. Deriving them from clipped corners makes
    # every weight 0 whenever x >= W-1 or y >= H-1, which blanked the last
    # row/column even for an identity warp.
    wa = (x1 - x) * (y1 - y)
    wb = (x1 - x) * (y - y0)
    wc = (x - x0) * (y1 - y)
    wd = (x - x0) * (y - y0)

    # Clamp the integer corner indices to the image so the gathers stay in
    # bounds; out-of-range samples therefore replicate the border.
    max_y = tf.cast(H - 1, tf.int32)
    max_x = tf.cast(W - 1, tf.int32)
    zero = tf.zeros([], dtype=tf.int32)
    x0_idx = tf.clip_by_value(tf.cast(x0, tf.int32), zero, max_x)
    x1_idx = tf.clip_by_value(tf.cast(x1, tf.int32), zero, max_x)
    y0_idx = tf.clip_by_value(tf.cast(y0, tf.int32), zero, max_y)
    y1_idx = tf.clip_by_value(tf.cast(y1, tf.int32), zero, max_y)

    # Pixel values at the four corners, each (B, H, W, C).
    Ia = get_pixel_value(img, x0_idx, y0_idx)
    Ib = get_pixel_value(img, x0_idx, y1_idx)
    Ic = get_pixel_value(img, x1_idx, y0_idx)
    Id = get_pixel_value(img, x1_idx, y1_idx)

    # Add a trailing channel axis so the (B, H, W) weights broadcast over C.
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    return tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])