-
Notifications
You must be signed in to change notification settings - Fork 0
/
exporter_lib_v2.py
290 lines (234 loc) · 11 KB
/
exporter_lib_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to export object detection inference graph."""
import ast
import os
import tensorflow.compat.v2 as tf
from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.utils import config_util
# Indirection for the model builder: `export_inference_graph` looks the builder
# up in this dict at call time, so replacing the 'model_build' entry changes
# which builder is used (presumably so tests can inject a fake model — confirm).
INPUT_BUILDER_UTIL_MAP = {'model_build': model_builder.build}
def _decode_image(encoded_image_string_tensor):
  """Decodes one encoded image string into a 3-channel image tensor.

  Args:
    encoded_image_string_tensor: string tensor holding encoded image bytes
      (in this file, one element of the batch mapped over by
      `_preprocess_input`).

  Returns:
    The decoded image, with its static shape set to (None, None, 3).
  """
  decoded = tf.image.decode_image(encoded_image_string_tensor, channels=3)
  # decode_image cannot report a static rank (animated GIFs decode to 4-D),
  # so pin the rank and channel count for the code that consumes the result.
  decoded.set_shape((None, None, 3))
  return decoded
def _decode_tf_example(tf_example_string_tensor):
  """Parses a serialized tf.Example and returns only its image tensor.

  Args:
    tf_example_string_tensor: string tensor holding a serialized tf.Example
      (in this file, one element of the batch mapped over by
      `_preprocess_input`).

  Returns:
    The tensor stored under `fields.InputDataFields.image` in the decoded
    tensor dictionary.
  """
  decoder = tf_example_decoder.TfExampleDecoder()
  decoded_fields = decoder.decode(tf_example_string_tensor)
  return decoded_fields[fields.InputDataFields.image]
def _combine_side_inputs(side_input_shapes='',
side_input_types='',
side_input_names=''):
"""Zips the side inputs together.
Args:
side_input_shapes: forward-slash-separated list of comma-separated lists
describing input shapes.
side_input_types: comma-separated list of the types of the inputs.
side_input_names: comma-separated list of the names of the inputs.
Returns:
a zipped list of side input tuples.
"""
side_input_shapes = [
ast.literal_eval('[' + x + ']') for x in side_input_shapes.split('/')
]
side_input_types = eval('[' + side_input_types + ']') # pylint: disable=eval-used
side_input_names = side_input_names.split(',')
return zip(side_input_shapes, side_input_types, side_input_names)
class DetectionInferenceModule(tf.Module):
  """Detection Inference Module.

  Base class for the exported detection modules. It holds the detection model
  and implements the shared decode/preprocess and predict/postprocess steps;
  subclasses define `__call__` with a concrete input signature.
  """

  def __init__(self, detection_model,
               use_side_inputs=False,
               zipped_side_inputs=None):
    """Initializes a module for detection.

    Args:
      detection_model: the detection model to use for inference.
      use_side_inputs: whether to use side inputs. Accepted so every module
        shares one constructor signature; unused by this base class.
      zipped_side_inputs: the zipped side inputs. Accepted so every module
        shares one constructor signature; unused by this base class.
    """
    # Run tf.Module's initializer so `name` and `name_scope` are set up.
    # Attributes a subclass assigned before calling this (e.g. `__call__`)
    # are left untouched.
    super(DetectionInferenceModule, self).__init__()
    self._model = detection_model

  def _get_side_input_signature(self, zipped_side_inputs):
    """Builds one tf.TensorSpec per side input.

    Args:
      zipped_side_inputs: iterable of side-input tuples whose first three
        entries are (shape, dtype, name).

    Returns:
      A list of tf.TensorSpec, in the same order as `zipped_side_inputs`.
    """
    return [tf.TensorSpec(shape=info[0], dtype=info[1], name=info[2])
            for info in zipped_side_inputs]

  def _get_side_names_from_zip(self, zipped_side_inputs):
    """Returns the name (third entry) of each side-input tuple, in order."""
    return [side[2] for side in zipped_side_inputs]

  def _preprocess_input(self, batch_input, decode_fn):
    """Decodes and preprocesses every element of a batch.

    Args:
      batch_input: batched input tensor; `decode_fn` is applied to each
        element along its first dimension.
      decode_fn: callable mapping one batch element to an image tensor.

    Returns:
      An (images, true_shapes) pair of float32 and int32 tensors, stacked
      along the batch dimension.
    """
    # Input preprocessing happens on the CPU. We don't need to use the device
    # placement as it is automatically handled by TF.
    def _decode_and_preprocess(single_input):
      image = decode_fn(single_input)
      image = tf.cast(image, tf.float32)
      # preprocess() is given a batch of one image; strip that batch
      # dimension from both of its outputs.
      image, true_shape = self._model.preprocess(image[tf.newaxis, :, :, :])
      return image[0], true_shape[0]

    images, true_shapes = tf.map_fn(
        _decode_and_preprocess,
        elems=batch_input,
        parallel_iterations=32,
        back_prop=False,
        fn_output_signature=(tf.float32, tf.int32))
    return images, true_shapes

  def _run_inference_on_images(self, images, true_shapes, **kwargs):
    """Runs the model's predict and postprocess steps on a batch.

    Args:
      images: float32 Tensor of shape [None, None, None, 3].
      true_shapes: int32 Tensor of form [batch, 3]
      **kwargs: additional keyword arguments forwarded to the model's
        `predict` (e.g. side inputs).

    Returns:
      Tensor dictionary holding detections. Every value is cast to float32,
      and the detection classes are shifted up by `label_id_offset` (1).
    """
    # NOTE(review): presumably converts 0-based model class indices into
    # 1-based label map ids — confirm against the label map convention.
    label_id_offset = 1
    prediction_dict = self._model.predict(images, true_shapes, **kwargs)
    detections = self._model.postprocess(prediction_dict, true_shapes)
    classes_field = fields.DetectionResultFields.detection_classes
    detections[classes_field] = (
        tf.cast(detections[classes_field], tf.float32) + label_id_offset)

    # Only existing keys are reassigned (the dict never changes size), so
    # updating it while iterating over items() is safe.
    for key, val in detections.items():
      detections[key] = tf.cast(val, tf.float32)
    return detections
class DetectionFromImageModule(DetectionInferenceModule):
  """Detection Inference Module for image inputs.

  The serving signature takes a uint8 tensor of shape [1, None, None, 3]
  named 'input_tensor', followed (when enabled) by one tensor per side input.
  """

  def __init__(self, detection_model,
               use_side_inputs=False,
               zipped_side_inputs=None):
    """Initializes a module for detection.

    Args:
      detection_model: the detection model to use for inference.
      use_side_inputs: whether to use side inputs.
      zipped_side_inputs: iterable of (shape, dtype, name) side-input tuples,
        e.g. the output of `_combine_side_inputs`. They only extend the
        serving signature when `use_side_inputs` is True.
    """
    # Materialize once: the side inputs are traversed twice below (signature,
    # then names). A one-shot iterator such as a `zip` object would be
    # exhausted by the first pass, leaving no names and silently dropping the
    # side inputs.
    if zipped_side_inputs is None:
      zipped_side_inputs = []
    else:
      zipped_side_inputs = list(zipped_side_inputs)

    sig = [tf.TensorSpec(shape=[1, None, None, 3],
                         dtype=tf.uint8,
                         name='input_tensor')]
    if use_side_inputs:
      sig.extend(self._get_side_input_signature(zipped_side_inputs))
    self._side_input_names = self._get_side_names_from_zip(zipped_side_inputs)

    def call_func(input_tensor, *side_inputs):
      # Side inputs arrive positionally in signature order; pair them with
      # their names so they reach the model's predict() as keyword arguments.
      kwargs = dict(zip(self._side_input_names, side_inputs))
      images, true_shapes = self._preprocess_input(input_tensor, lambda x: x)
      return self._run_inference_on_images(images, true_shapes, **kwargs)

    self.__call__ = tf.function(call_func, input_signature=sig)

    # TODO(kaushikshiv): Check if omitting the signature also works.
    super(DetectionFromImageModule, self).__init__(detection_model,
                                                   use_side_inputs,
                                                   zipped_side_inputs)
def get_true_shapes(input_tensor):
  """Returns the shape of every image in a dense batch.

  A dense batched tensor gives all of its images the same size, so each row
  of the result is the tensor's shape with the leading dimension removed.

  Args:
    input_tensor: batched tensor whose first dimension is the batch size.

  Returns:
    A tensor of shape [batch, rank(input_tensor) - 1] in which every row
    equals `tf.shape(input_tensor)[1:]`.
  """
  full_shape = tf.shape(input_tensor)
  per_image_shape = full_shape[1:]
  num_images = full_shape[0]
  return tf.tile(per_image_shape[tf.newaxis, :], [num_images, 1])
class DetectionFromFloatImageModule(DetectionInferenceModule):
  """Detection Inference Module for float image inputs.

  The serving signature takes a rank-4 float32 tensor whose last dimension is
  3. The images are already decoded, so the decode step passes each one
  through unchanged.
  """

  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)])
  def __call__(self, input_tensor):
    def _passthrough(image):
      return image

    preprocessed, shapes = self._preprocess_input(input_tensor, _passthrough)
    return self._run_inference_on_images(preprocessed, shapes)
class DetectionFromEncodedImageModule(DetectionInferenceModule):
  """Detection Inference Module for encoded image string inputs.

  The serving signature takes a 1-D batch of encoded image strings; each one
  is decoded with `_decode_image` before preprocessing.
  """

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def __call__(self, input_tensor):
    decoded_batch, batch_shapes = self._preprocess_input(
        input_tensor, decode_fn=_decode_image)
    return self._run_inference_on_images(decoded_batch, batch_shapes)
class DetectionFromTFExampleModule(DetectionInferenceModule):
  """Detection Inference Module for TF.Example inputs.

  The serving signature takes a 1-D batch of serialized tf.Example strings;
  the image of each one is extracted with `_decode_tf_example`.
  """

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def __call__(self, input_tensor):
    batch_images, batch_true_shapes = self._preprocess_input(
        input_tensor, decode_fn=_decode_tf_example)
    return self._run_inference_on_images(batch_images, batch_true_shapes)
# Maps each supported `input_type` value of `export_inference_graph` to the
# module class whose `__call__` signature accepts that kind of input.
DETECTION_MODULE_MAP = {
    'image_tensor': DetectionFromImageModule,
    'encoded_image_string_tensor': DetectionFromEncodedImageModule,
    'tf_example': DetectionFromTFExampleModule,
    'float_image_tensor': DetectionFromFloatImageModule,
}
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_dir,
                           output_directory,
                           use_side_inputs=False,
                           side_input_shapes='',
                           side_input_types='',
                           side_input_names=''):
  """Exports inference graph for the model specified in the pipeline config.

  This function creates `output_directory` if it does not already exist,
  which will hold a copy of the pipeline config with filename `pipeline.config`,
  and two subdirectories named `checkpoint` and `saved_model`
  (containing the exported checkpoint and SavedModel respectively).

  Args:
    input_type: Type of input for the graph. Can be one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example', 'float_image_tensor'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_dir: Path to the directory holding the trained
      checkpoints; the latest checkpoint found there is restored.
    output_directory: Path to write outputs.
    use_side_inputs: boolean that determines whether side inputs should be
      included in the input signature.
    side_input_shapes: forward-slash-separated list of comma-separated lists
      describing input shapes.
    side_input_types: comma-separated list of the types of the inputs.
    side_input_names: comma-separated list of the names of the inputs.

  Raises:
    ValueError: if input_type is invalid, or if side inputs are requested for
      an input type other than 'image_tensor'.
  """
  # Validate arguments before building the model and restoring weights, so
  # a bad flag fails immediately instead of after the model setup.
  if input_type not in DETECTION_MODULE_MAP:
    raise ValueError('Unrecognized `input_type`')
  if use_side_inputs and input_type != 'image_tensor':
    raise ValueError('Side inputs supported for image_tensor input type only.')

  output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
  output_saved_model_directory = os.path.join(output_directory, 'saved_model')

  detection_model = INPUT_BUILDER_UTIL_MAP['model_build'](
      pipeline_config.model, is_training=False)

  ckpt = tf.train.Checkpoint(model=detection_model)
  manager = tf.train.CheckpointManager(
      ckpt, trained_checkpoint_dir, max_to_keep=1)
  # expect_partial() silences warnings about checkpointed values that the
  # inference model never uses.
  status = ckpt.restore(manager.latest_checkpoint).expect_partial()

  zipped_side_inputs = []
  if use_side_inputs:
    zipped_side_inputs = _combine_side_inputs(side_input_shapes,
                                              side_input_types,
                                              side_input_names)
  detection_module = DETECTION_MODULE_MAP[input_type](detection_model,
                                                      use_side_inputs,
                                                      list(zipped_side_inputs))
  # Getting the concrete function traces the graph and forces variables to
  # be constructed --- only after this can we save the checkpoint and
  # saved model.
  concrete_function = detection_module.__call__.get_concrete_function()
  # Now that the variables exist, check that each one created so far was
  # matched by a value in the restored checkpoint.
  status.assert_existing_objects_matched()

  exported_checkpoint_manager = tf.train.CheckpointManager(
      ckpt, output_checkpoint_directory, max_to_keep=1)
  exported_checkpoint_manager.save(checkpoint_number=0)

  tf.saved_model.save(detection_module,
                      output_saved_model_directory,
                      signatures=concrete_function)
  config_util.save_pipeline_config(pipeline_config, output_directory)