-
Notifications
You must be signed in to change notification settings - Fork 12
/
tfa_image.py
263 lines (218 loc) · 10.1 KB
/
tfa_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image warping using per-pixel flow vectors."""
import tensorflow as tf
# import tensor_types as types
from typing import Optional
@tf.function
def interpolate_bilinear(
    grid,
    query_points,
    indexing="ij",
    name=None,
):
    """Sample `grid` at fractional `query_points` using bilinear interpolation.

    Behaves like Matlab's `interp2`: every query point is resolved from the
    four grid cells that surround it.

    Args:
      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
      query_points: a 3-D float `Tensor` of N points with shape
        `[batch, N, 2]`.
      indexing: "ij" when each query point is given as (row, column), or
        "xy" when it is given as Cartesian (x, y) coordinates.
      name: optional name for the operation.

    Returns:
      values: a 3-D `Tensor` with shape `[batch, N, channels]`.

    Raises:
      ValueError: if `indexing` is neither "ij" nor "xy".
      Inputs with the wrong rank or a grid smaller than 2x2 are rejected by
      `tf.debugging` assertions; NOTE(review): those presumably surface as
      `tf.errors.InvalidArgumentError` rather than `ValueError` — confirm.
    """
    # Input validation is delegated to an undecorated helper.
    return _interpolate_bilinear_with_checks(
        grid=grid,
        query_points=query_points,
        indexing=indexing,
        name=name,
    )
def _interpolate_bilinear_with_checks(
    grid,
    query_points,
    indexing,
    name=None,
):
    """Validate the inputs of `interpolate_bilinear`, then run the kernel.

    Deliberately left without a `tf.function` decorator, as the original
    note on this helper says, to avoid flakiness.

    Args:
      grid: grid to sample from (anything `tf.convert_to_tensor` accepts).
      query_points: points to sample at (anything `tf.convert_to_tensor`
        accepts).
      indexing: "ij" or "xy"; any other value is rejected immediately.
      name: optional name scope forwarded to `_interpolate_bilinear_impl`.

    Returns:
      The output of `_interpolate_bilinear_impl`, built under control
      dependencies on the shape assertions below.

    Raises:
      ValueError: if `indexing` is not "ij" or "xy".
    """
    if indexing not in ("ij", "xy"):
        raise ValueError("Indexing mode must be 'ij' or 'xy'")

    grid = tf.convert_to_tensor(grid)
    query_points = tf.convert_to_tensor(query_points)
    grid_dims = tf.shape(grid)
    query_dims = tf.shape(query_points)

    # Kept in the original order so the first failing check, and therefore
    # the reported message, is unchanged.
    shape_checks = [
        tf.debugging.assert_equal(tf.rank(grid), 4, "Grid must be 4D Tensor"),
        tf.debugging.assert_greater_equal(
            grid_dims[1], 2, "Grid height must be at least 2."
        ),
        tf.debugging.assert_greater_equal(
            grid_dims[2], 2, "Grid width must be at least 2."
        ),
        tf.debugging.assert_equal(
            tf.rank(query_points), 3, "Query points must be 3 dimensional."
        ),
        tf.debugging.assert_equal(
            query_dims[2], 2, "Query points last dimension must be 2."
        ),
    ]

    with tf.control_dependencies(shape_checks):
        return _interpolate_bilinear_impl(grid, query_points, indexing, name)
def _interpolate_bilinear_impl(
    grid,
    query_points,
    indexing,
    name,
):
    """tf.function implementation of interpolate_bilinear.

    Performs no input validation itself; `_interpolate_bilinear_with_checks`
    wraps it with rank/shape assertions.

    Args:
      grid: tensor indexed as `[batch, height, width, channels]` (the four
        entries of `tf.shape(grid)` are unpacked in that order).
      query_points: tensor whose last axis holds two coordinates (unstacked
        with `num=2` on axis 2); `tf.shape(query_points)[1]` is the number
        of queries per batch element.
      indexing: "ij" treats the first coordinate as the row; any other value
        (the caller only passes "xy") treats the second coordinate as the row.
      name: name scope for the ops; defaults to "interpolate_bilinear".

    Returns:
      Tensor of shape `[batch, num_queries, channels]` with the grid's dtype.
      Query points outside the grid take the nearest edge values, because
      both the floor index and the interpolation weight are clamped.
    """
    with tf.name_scope(name or "interpolate_bilinear"):
        grid_shape = tf.shape(grid)
        query_shape = tf.shape(query_points)
        batch_size, height, width, channels = (
            grid_shape[0],
            grid_shape[1],
            grid_shape[2],
            grid_shape[3],
        )
        num_queries = query_shape[1]
        query_type = query_points.dtype
        grid_type = grid.dtype
        alphas = []
        floors = []
        ceils = []
        # Entry i selects the query component that indexes grid axis i + 1
        # (i = 0 -> rows, i = 1 -> columns). For "xy" the y component
        # (index 1) addresses rows.
        index_order = [0, 1] if indexing == "ij" else [1, 0]
        unstacked_query_points = tf.unstack(query_points, axis=2, num=2)
        for i, dim in enumerate(index_order):
            with tf.name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]
                size_in_indexing_dimension = grid_shape[i + 1]
                # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
                # is still a valid index into the grid.
                max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
                min_floor = tf.constant(0.0, dtype=query_type)
                # Clamp floor(query) to [0, size - 2] so the two-cell stencil
                # always lies inside the grid.
                floor = tf.math.minimum(
                    tf.math.maximum(min_floor, tf.math.floor(queries)), max_floor
                )
                int_floor = tf.cast(floor, tf.dtypes.int32)
                floors.append(int_floor)
                # Deliberately floor + 1 rather than math.ceil: integer-valued
                # queries still get a two-cell stencil (with alpha == 0).
                ceil = int_floor + 1
                ceils.append(ceil)
                # alpha has the same type as the grid, as we will directly use alpha
                # when taking linear combinations of pixel values from the image.
                alpha = tf.cast(queries - floor, grid_type)
                min_alpha = tf.constant(0.0, dtype=grid_type)
                max_alpha = tf.constant(1.0, dtype=grid_type)
                # Clamping alpha to [0, 1] is what makes out-of-range queries
                # return the boundary value instead of extrapolating.
                alpha = tf.math.minimum(tf.math.maximum(min_alpha, alpha), max_alpha)
                # Expand alpha to [b, n, 1] so we can use broadcasting
                # (since the alpha values don't depend on the channel).
                alpha = tf.expand_dims(alpha, 2)
                alphas.append(alpha)
        # Collapse batch, row and column into one axis so a single tf.gather
        # can fetch pixels; batch_offsets shifts each batch element's indices
        # into its own height * width slab.
        # NOTE(review): these linear indices are int32 (tf.shape / int_floor);
        # batch * height * width above 2**31 - 1 would overflow — confirm if
        # very large inputs matter.
        flattened_grid = tf.reshape(grid, [batch_size * height * width, channels])
        batch_offsets = tf.reshape(
            tf.range(batch_size) * height * width, [batch_size, 1]
        )
        # This wraps tf.gather. We reshape the image data such that the
        # batch, y, and x coordinates are pulled into the first dimension.
        # Then we gather. Finally, we reshape the output back. It's possible this
        # code would be made simpler by using tf.gather_nd.
        # (The inner `name` is only a scope suffix; it shadows the outer one.)
        def gather(y_coords, x_coords, name):
            with tf.name_scope("gather-" + name):
                linear_coordinates = batch_offsets + y_coords * width + x_coords
                gathered_values = tf.gather(flattened_grid, linear_coordinates)
                return tf.reshape(gathered_values, [batch_size, num_queries, channels])
        # grab the pixel values in the 4 corners around each query point
        top_left = gather(floors[0], floors[1], "top_left")
        top_right = gather(floors[0], ceils[1], "top_right")
        bottom_left = gather(ceils[0], floors[1], "bottom_left")
        bottom_right = gather(ceils[0], ceils[1], "bottom_right")
        # now, do the actual interpolation
        # First lerp along columns (alphas[1]) on the top and bottom rows,
        # then lerp between those two results along rows (alphas[0]).
        with tf.name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top
        return interp
def _get_dim(x, idx):
if x.shape.ndims is None:
return tf.shape(x)[idx]
return x.shape[idx] or tf.shape(x)[idx]
@tf.function
def dense_image_warp(
    image, flow, name=None
):
    """Warp `image` by a dense field of per-pixel offset vectors.

    The flow gives, for every output pixel, the offset back to the source
    location it is read from:
    `output[b, j, i, c] = image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]`.
    Source locations are generally fractional. They are therefore resolved
    by bilinear interpolation of the 4 surrounding pixels. Locations outside
    the image take the nearest pixel values on the image boundary.

    NOTE: This is not the usual optical-flow convention. The function expects
    the negative forward flow from the output image to the source image.
    Given images `I_1`, `I_2` and the optical flow `F_12` from `I_1` to `I_2`,
    `I_1` can be reconstructed with `I_1_rec = dense_image_warp(I_2, -F_12)`.

    Args:
      image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
      flow: 4-D float `Tensor` with shape `[batch, height, width, 2]`.
      name: optional name for the operation.
      `image` and `flow` may each be `tf.half`, `tf.float32` or `tf.float64`,
      and need not share a dtype.

    Returns:
      A 4-D float `Tensor` with shape `[batch, height, width, channels]` and
      the same dtype as `image`.

    Raises:
      Errors from the interpolation helper's shape assertions when
      `height < 2`, `width < 2`, or the inputs have the wrong rank. These
      were historically documented as `ValueError`; NOTE(review): the
      `tf.debugging` assertions may instead raise
      `tf.errors.InvalidArgumentError` — confirm.
    """
    with tf.name_scope(name or "dense_image_warp"):
        image = tf.convert_to_tensor(image)
        flow = tf.convert_to_tensor(flow)
        batch_size = _get_dim(image, 0)
        height = _get_dim(image, 1)
        width = _get_dim(image, 2)
        channels = _get_dim(image, 3)

        # Identity sampling grid of shape [height, width, 2] holding (row, col)
        # for every pixel, in the flow's dtype.
        cols, rows = tf.meshgrid(tf.range(width), tf.range(height))
        identity_coords = tf.cast(tf.stack([rows, cols], axis=2), flow.dtype)

        # Shift each pixel by its (negated) flow, broadcasting the grid over the
        # batch, then flatten to the [batch, N, 2] layout the interpolator uses.
        source_coords = tf.expand_dims(identity_coords, axis=0) - flow
        source_coords = tf.reshape(source_coords, [batch_size, height * width, 2])

        sampled = interpolate_bilinear(image, source_coords)
        return tf.reshape(sampled, [batch_size, height, width, channels])
@tf.function(experimental_implements="addons:DenseImageWarp")
def dense_image_warp_annotated(
    image, flow, name=None
):
    """Variant of `dense_image_warp` tagged with `experimental_implements`.

    IMPORTANT: This is a temporary function and will be removed after
    TensorFlow's next release.

    The annotation makes the serialized function detectable by the TFLite
    MLIR converter, which allows the converter to convert it to the
    corresponding TFLite op. However, with the annotation, this function
    cannot be used with backprop under `tf.GradientTape` objects.

    Args:
      image: image tensor, as accepted by `dense_image_warp`.
      flow: flow tensor, as accepted by `dense_image_warp`.
      name: optional name for the operation.

    Returns:
      The warped image produced by `dense_image_warp`.
    """
    warped = dense_image_warp(image=image, flow=flow, name=name)
    return warped