-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_lib.py
1139 lines (1018 loc) · 48.2 KB
/
model_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Constructs model, inputs, and training environment."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import functools
import os
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
import tf_slim as slim
from object_detection import eval_util
from object_detection import exporter as exporter_lib
from object_detection import inputs
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import learn as contrib_learn
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
# Registry of the callables used to assemble a model, looked up by name.
# Entries cover pipeline-config handling, input_fn construction and the
# detection model builder.
MODEL_BUILD_UTIL_MAP = {
    # Pipeline configuration helpers.
    'get_configs_from_pipeline_file': (
        config_util.get_configs_from_pipeline_file),
    'create_pipeline_proto_from_configs': (
        config_util.create_pipeline_proto_from_configs),
    'merge_external_params_with_configs': (
        config_util.merge_external_params_with_configs),
    # Input function factories.
    'create_train_input_fn': inputs.create_train_input_fn,
    'create_eval_input_fn': inputs.create_eval_input_fn,
    'create_predict_input_fn': inputs.create_predict_input_fn,
    # Detection model factory.
    'detection_model_fn_base': model_builder.build,
}
def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
                                  max_number_of_boxes):
  """Gathers the groundtruth held by `detection_model` into batched tensors.

  Each groundtruth field that the model reports through
  `groundtruth_has_field` is read with `groundtruth_lists` and stacked into a
  single tensor along a new leading (batch) dimension.

  Args:
    detection_model: A `DetectionModel` object.
    class_agnostic: Whether the detections are class_agnostic. When true, the
      one-hot class encodings are replaced by a tensor of ones with a single
      class channel, so every box ends up with class id 1.
    max_number_of_boxes: Max number of groundtruth boxes.

  Returns:
    A dictionary keyed by `fields.InputDataFields` attributes. The boxes,
    classes and num_groundtruth_boxes entries are always present; every other
    entry is present only if the model holds that groundtruth:
      groundtruth_boxes: [batch_size, num_boxes, 4] float32 tensor of boxes,
        in normalized coordinates.
      groundtruth_classes: [batch_size, num_boxes] int64 tensor of 1-indexed
        classes.
      groundtruth_instance_masks: 4D float32 tensor of instance masks.
      groundtruth_is_crowd: [batch_size, num_boxes] bool tensor indicating
        is_crowd annotations.
      groundtruth_area: [batch_size, num_boxes] float32 tensor indicating
        the area (in the original absolute coordinates) of annotations.
      groundtruth_keypoints: [batch_size, num_boxes, num_keypoints, 2] float32
        tensor of keypoints.
      groundtruth_keypoint_depths, groundtruth_keypoint_depth_weights: stacked
        keypoint depths and their weights; both are read whenever depths are
        provided.
      groundtruth_keypoint_visibilities: stacked keypoint visibilities.
      groundtruth_group_of: [batch_size, num_boxes] bool tensor indicating
        group_of annotations.
      groundtruth_verified_neg_classes: [batch_size, num_classes] float32
        K-hot representation of 1-indexed classes which were verified as not
        present in the image.
      groundtruth_not_exhaustive_classes: [batch_size, num_classes] K-hot
        representation of 1-indexed classes which don't have all of their
        instances marked exhaustively.
      groundtruth_dp_num_points: [batch_size, num_boxes] int32 tensor with the
        number of DensePose points for each instance.
      groundtruth_dp_part_ids: [batch_size, num_boxes, max_sampled_points]
        int32 tensor with the part ids for each DensePose sampled point.
      groundtruth_dp_surface_coords: [batch_size, num_boxes,
        max_sampled_points, 4] tensor containing the DensePose surface
        coordinates for each sampled point.
      groundtruth_track_ids: [batch_size, num_boxes] int32 tensor with track
        ID for each instance.
      groundtruth_labeled_classes: [batch_size, num_classes] int64 tensor of
        1-indexed classes.
      num_groundtruth_boxes: [batch_size] tensor in which every entry equals
        `max_number_of_boxes`.
  """
  input_data_fields = fields.InputDataFields()
  box_fields = fields.BoxListFields

  def _has(field):
    return detection_model.groundtruth_has_field(field)

  def _stacked(field):
    return tf.stack(detection_model.groundtruth_lists(field))

  boxes = _stacked(box_fields.boxes)
  boxes_shape = tf.shape(boxes)

  # For class-agnostic models, groundtruth one-hot encodings collapse to all
  # ones.
  if class_agnostic:
    one_hot_classes = tf.ones([boxes_shape[0], boxes_shape[1], 1])
  else:
    one_hot_classes = _stacked(box_fields.classes)
  label_id_offset = 1  # Applying label id offset (b/63711816)
  classes = tf.argmax(one_hot_classes, axis=2) + label_id_offset

  groundtruth = {
      input_data_fields.groundtruth_boxes: boxes,
      input_data_fields.groundtruth_classes: classes,
  }

  def _copy_if_present(model_field, output_key):
    # Stack the per-image tensors of `model_field` under `output_key`.
    if _has(model_field):
      groundtruth[output_key] = _stacked(model_field)

  for model_field, output_key in (
      (box_fields.masks, input_data_fields.groundtruth_instance_masks),
      (box_fields.is_crowd, input_data_fields.groundtruth_is_crowd),
      (input_data_fields.groundtruth_area,
       input_data_fields.groundtruth_area),
      (box_fields.keypoints, input_data_fields.groundtruth_keypoints),
  ):
    _copy_if_present(model_field, output_key)

  # Depth weights are read together with the depths; only the presence of the
  # depths themselves is checked.
  if _has(box_fields.keypoint_depths):
    groundtruth[input_data_fields.groundtruth_keypoint_depths] = _stacked(
        box_fields.keypoint_depths)
    groundtruth[input_data_fields.groundtruth_keypoint_depth_weights] = (
        _stacked(box_fields.keypoint_depth_weights))

  for model_field, output_key in (
      (box_fields.keypoint_visibilities,
       input_data_fields.groundtruth_keypoint_visibilities),
      (box_fields.group_of, input_data_fields.groundtruth_group_of),
  ):
    _copy_if_present(model_field, output_key)

  # Prepends one column along axis 1 of the per-class tensors below, matching
  # the label id offset applied to the class ids.
  label_id_offset_paddings = tf.constant([[0, 0], [1, 0]])

  def _copy_padded_if_present(field):
    # These fields use the same name on the model and in the output dict.
    if _has(field):
      groundtruth[field] = tf.pad(_stacked(field), label_id_offset_paddings)

  _copy_padded_if_present(input_data_fields.groundtruth_verified_neg_classes)
  _copy_padded_if_present(
      input_data_fields.groundtruth_not_exhaustive_classes)

  for model_field, output_key in (
      (box_fields.densepose_num_points,
       input_data_fields.groundtruth_dp_num_points),
      (box_fields.densepose_part_ids,
       input_data_fields.groundtruth_dp_part_ids),
      (box_fields.densepose_surface_coords,
       input_data_fields.groundtruth_dp_surface_coords),
      (box_fields.track_ids, input_data_fields.groundtruth_track_ids),
  ):
    _copy_if_present(model_field, output_key)

  _copy_padded_if_present(input_data_fields.groundtruth_labeled_classes)

  # Every image reports the same (maximum) box count.
  groundtruth[input_data_fields.num_groundtruth_boxes] = tf.tile(
      [max_number_of_boxes], multiples=[boxes_shape[0]])
  return groundtruth
def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
  """Splits every tensor in `tensor_dict` into a list along dimension 0.

  The result maps each key to the list of per-example tensors obtained by
  unstacking the batched tensor. Tensors in the `tensor_dict` are expected to
  be of one of the three shapes:
  1. [batch_size]
  2. [batch_size, height, width, channels]
  3. [batch_size, num_boxes, d1, d2, ... dn]

  When unpad_groundtruth_tensors is set to true, unstacked tensors of form 3
  above are additionally sliced along the `num_boxes` dimension, keeping only
  the number of entries given by the matching element of
  fields.InputDataFields.num_groundtruth_boxes.

  Note that this function has a static list of input data fields and has to be
  kept in sync with the InputDataFields defined in core/standard_fields.py

  Args:
    tensor_dict: A dictionary of batched groundtruth tensors.
    unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
      dimension of the groundtruth tensors.

  Returns:
    A dictionary where the keys are from fields.InputDataFields and values are
    a list of unstacked (optionally unpadded) tensors.

  Raises:
    ValueError: If unpad_groundtruth_tensors is True and `tensor_dict` does not
      contain `num_groundtruth_boxes` tensor.
  """
  input_fields = fields.InputDataFields
  unstacked = {
      name: tf.unstack(batched) for name, batched in tensor_dict.items()
  }
  if not unpad_groundtruth_tensors:
    return unstacked

  if input_fields.num_groundtruth_boxes not in unstacked:
    raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. '
                     'Keys available: {}'.format(unstacked.keys()))

  # Input data fields that are padded along the num_boxes dimension. This set
  # has to be kept in sync with InputDataFields in standard_fields.py.
  padded_field_names = set([
      input_fields.groundtruth_instance_masks,
      input_fields.groundtruth_classes,
      input_fields.groundtruth_boxes,
      input_fields.groundtruth_keypoints,
      input_fields.groundtruth_keypoint_depths,
      input_fields.groundtruth_keypoint_depth_weights,
      input_fields.groundtruth_keypoint_visibilities,
      input_fields.groundtruth_dp_num_points,
      input_fields.groundtruth_dp_part_ids,
      input_fields.groundtruth_dp_surface_coords,
      input_fields.groundtruth_track_ids,
      input_fields.groundtruth_group_of,
      input_fields.groundtruth_difficult,
      input_fields.groundtruth_is_crowd,
      input_fields.groundtruth_area,
      input_fields.groundtruth_weights
  ])
  keys_to_unpad = padded_field_names.intersection(set(unstacked.keys()))

  def _strip_padding(num_gt, padded_tensor):
    # Keep the first `num_gt` entries of dimension 0 and the full extent of
    # every remaining dimension.
    shape = shape_utils.combined_static_and_dynamic_shape(padded_tensor)
    begin = tf.zeros([len(shape)], dtype=tf.int32)
    size = tf.stack(
        [num_gt] + [-1 if dim is None else dim for dim in shape[1:]])
    return tf.slice(padded_tensor, begin, size)

  unpadded = {}
  for key in keys_to_unpad:
    unpadded[key] = [
        _strip_padding(num_gt, padded_tensor)
        for num_gt, padded_tensor in zip(
            unstacked[input_fields.num_groundtruth_boxes], unstacked[key])
    ]
  unstacked.update(unpadded)
  return unstacked
def provide_groundtruth(model, labels):
  """Forwards the groundtruth found in `labels` to a detection model.

  Boxes and classes are always read from `labels`. Every other supported
  entry (masks, keypoints, DensePose data, track ids, weights, confidences,
  crowd/group flags, area and the image-level class lists) is passed along
  when `labels` contains it and passed as None otherwise.

  Args:
    model: The detection model to provide groundtruth to.
    labels: The labels for the training or evaluation inputs, keyed by
      `fields.InputDataFields` attributes.
  """
  input_fields = fields.InputDataFields

  def _optional(key):
    # Absent entries are handed to the model as None.
    return labels[key] if key in labels else None

  # Mandatory groundtruth.
  boxes = labels[input_fields.groundtruth_boxes]
  classes = labels[input_fields.groundtruth_classes]

  # Depth weights are read together with the depths: only the presence of the
  # depths is checked before both entries are looked up.
  if input_fields.groundtruth_keypoint_depths in labels:
    keypoint_depths = labels[input_fields.groundtruth_keypoint_depths]
    keypoint_depth_weights = labels[
        input_fields.groundtruth_keypoint_depth_weights]
  else:
    keypoint_depths = None
    keypoint_depth_weights = None

  model.provide_groundtruth(
      groundtruth_boxes_list=boxes,
      groundtruth_classes_list=classes,
      groundtruth_confidences_list=_optional(
          input_fields.groundtruth_confidences),
      groundtruth_labeled_classes=_optional(
          input_fields.groundtruth_labeled_classes),
      groundtruth_masks_list=_optional(
          input_fields.groundtruth_instance_masks),
      groundtruth_keypoints_list=_optional(
          input_fields.groundtruth_keypoints),
      groundtruth_keypoint_visibilities_list=_optional(
          input_fields.groundtruth_keypoint_visibilities),
      groundtruth_dp_num_points_list=_optional(
          input_fields.groundtruth_dp_num_points),
      groundtruth_dp_part_ids_list=_optional(
          input_fields.groundtruth_dp_part_ids),
      groundtruth_dp_surface_coords_list=_optional(
          input_fields.groundtruth_dp_surface_coords),
      groundtruth_weights_list=_optional(input_fields.groundtruth_weights),
      groundtruth_is_crowd_list=_optional(input_fields.groundtruth_is_crowd),
      groundtruth_group_of_list=_optional(input_fields.groundtruth_group_of),
      groundtruth_area_list=_optional(input_fields.groundtruth_area),
      groundtruth_track_ids_list=_optional(
          input_fields.groundtruth_track_ids),
      groundtruth_verified_neg_classes=_optional(
          input_fields.groundtruth_verified_neg_classes),
      groundtruth_not_exhaustive_classes=_optional(
          input_fields.groundtruth_not_exhaustive_classes),
      groundtruth_keypoint_depths_list=keypoint_depths,
      groundtruth_keypoint_depth_weights_list=keypoint_depth_weights)
def create_model_fn(detection_model_fn, configs, hparams=None, use_tpu=False,
                    postprocess_on_cpu=False):
  """Creates a model function for `Estimator`.

  The returned `model_fn` builds the detection graph for TRAIN, EVAL or
  PREDICT mode and wraps the result in an estimator spec.

  Args:
    detection_model_fn: Function that returns a `DetectionModel` instance. It
      is called with `is_training` and `add_summaries` keyword arguments.
    configs: Dictionary of pipeline config objects. The 'train_config',
      'eval_input_config' and 'eval_config' entries are read unconditionally;
      'graph_rewriter_config' is used when present.
    hparams: `HParams` object. Only `hparams.load_pretrained` is read here;
      when `hparams` is None (or otherwise falsy) no fine-tune checkpoint is
      loaded.
    use_tpu: Boolean indicating whether model should be constructed for
      use on TPU.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu is true, postprocess
      is scheduled on the host cpu.

  Returns:
    `model_fn` for `Estimator`.
  """
  train_config = configs['train_config']
  eval_input_config = configs['eval_input_config']
  eval_config = configs['eval_config']

  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator. It is replaced
        by an empty dict when falsy and is not otherwise used in this
        function.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
      configurations. When `use_tpu` is set and `mode` is not EVAL, a
      `TPUEstimatorSpec` is returned instead.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    # Set policy for mixed-precision training with Keras-based models.
    if use_tpu and train_config.use_bfloat16:
      from tensorflow.python.keras.engine import base_layer_utils  # pylint: disable=g-import-not-at-top
      # Enable v2 behavior, as `mixed_bfloat16` is only supported in TF 2.0.
      base_layer_utils.enable_v2_dtype_behavior()
      tf2.keras.mixed_precision.experimental.set_policy(
          'mixed_bfloat16')
    # Model summaries are disabled when building for TPU.
    detection_model = detection_model_fn(
        is_training=is_training, add_summaries=(not use_tpu))
    # Only set on the TPU fine-tuning path below; passed to TPUEstimatorSpec.
    scaffold_fn = None

    # Labels arrive batched; split them into per-image lists before they are
    # handed to the model as groundtruth.
    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      # For evaling on train data, it is necessary to check whether groundtruth
      # must be unpadded.
      boxes_shape = (
          labels[fields.InputDataFields.groundtruth_boxes].get_shape()
          .as_list())
      # Unpad only when the num_boxes dimension is statically known and the
      # model is not being built for TPU.
      unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      provide_groundtruth(detection_model, labels)

    preprocessed_images = features[fields.InputDataFields.image]

    side_inputs = detection_model.get_side_inputs(features)

    # Forward pass. On TPU with bfloat16 enabled, predict() runs inside the
    # bfloat16 scope and its outputs are then passed through
    # ops.bfloat16_to_float32_nested.
    if use_tpu and train_config.use_bfloat16:
      with tf.tpu.bfloat16_scope():
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape], **side_inputs)
        prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
    else:
      prediction_dict = detection_model.predict(
          preprocessed_images,
          features[fields.InputDataFields.true_image_shape], **side_inputs)

    # Takes both postprocess inputs packed into one tuple; used by both the
    # TPU outside-compilation path and the direct call below.
    def postprocess_wrapper(args):
      return detection_model.postprocess(args[0], args[1])

    # Detections are only computed for EVAL and PREDICT; in TRAIN mode
    # `detections` stays None.
    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
      if use_tpu and postprocess_on_cpu:
        detections = tf.tpu.outside_compilation(
            postprocess_wrapper,
            (prediction_dict,
             features[fields.InputDataFields.true_image_shape]))
      else:
        detections = postprocess_wrapper((
            prediction_dict,
            features[fields.InputDataFields.true_image_shape]))

    # Optionally initialize variables from a fine-tune checkpoint. This needs
    # both a configured checkpoint path and hparams.load_pretrained.
    if mode == tf.estimator.ModeKeys.TRAIN:
      load_pretrained = hparams.load_pretrained if hparams else False
      if train_config.fine_tune_checkpoint and load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map,
                train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:

          # On TPU the checkpoint initialization is deferred into a
          # scaffold_fn, which is handed to TPUEstimatorSpec below.
          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()

          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      if (mode == tf.estimator.ModeKeys.EVAL and
          eval_config.use_dummy_loss_in_eval):
        # Skip the real loss computation and report a constant instead.
        total_loss = tf.constant(1.0)
        losses_dict = {'Loss/total_loss': total_loss}
      else:
        losses_dict = detection_model.loss(
            prediction_dict, features[fields.InputDataFields.true_image_shape])
        losses = [loss_tensor for loss_tensor in losses_dict.values()]
        if train_config.add_regularization_loss:
          regularization_losses = detection_model.regularization_losses()
          if use_tpu and train_config.use_bfloat16:
            regularization_losses = ops.bfloat16_to_float32_nested(
                regularization_losses)
          # Only add a regularization term when the model reports any.
          if regularization_losses:
            regularization_loss = tf.add_n(
                regularization_losses, name='regularization_loss')
            losses.append(regularization_loss)
            losses_dict['Loss/regularization_loss'] = regularization_loss
        total_loss = tf.add_n(losses, name='total_loss')
        losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
      # can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        training_optimizer = tf.tpu.CrossShardOptimizer(training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      # NOTE(review): this None assignment is overwritten a few lines below.
      trainable_variables = None
      # Empty pattern lists are passed to filter_variables as None.
      include_variables = (
          train_config.update_trainable_variables
          if train_config.update_trainable_variables else None)
      exclude_variables = (
          train_config.freeze_variables
          if train_config.freeze_variables else None)
      trainable_variables = slim.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_variables,
          exclude_patterns=exclude_variables)

      # Gradient clipping is applied only for a positive configured norm.
      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      # An empty list is passed on TPU and None otherwise; presumably None
      # selects optimize_loss's default summaries -- TODO confirm.
      summaries = [] if use_tpu else None
      if train_config.summarize_gradients:
        summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
      train_op = slim.optimizers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          update_ops=detection_model.updates(),
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      exported_output = exporter_lib.add_output_tensor_nodes(detections)
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(exported_output)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
      # Detections without a detection_classes entry are treated as
      # class-agnostic.
      class_agnostic = (
          fields.DetectionResultFields.detection_classes not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic,
          eval_input_config.max_number_of_boxes)
      # Use the original images for evaluation when the input pipeline
      # provides them; otherwise fall back to the preprocessed images.
      use_original_images = fields.InputDataFields.original_image in features
      if use_original_images:
        eval_images = features[fields.InputDataFields.original_image]
        true_image_shapes = tf.slice(
            features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
        original_image_spatial_shapes = features[fields.InputDataFields
                                                 .original_image_spatial_shape]
      else:
        eval_images = features[fields.InputDataFields.image]
        true_image_shapes = None
        original_image_spatial_shapes = None

      eval_dict = eval_util.result_dict_for_batched_example(
          eval_images,
          features[inputs.HASH_KEY],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True,
          original_image_spatial_shapes=original_image_spatial_shapes,
          true_image_shapes=true_image_shapes)

      if fields.InputDataFields.image_additional_channels in features:
        eval_dict[fields.InputDataFields.image_additional_channels] = features[
            fields.InputDataFields.image_additional_channels]

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      # Visualization metrics are only built off-TPU and only when the
      # original images are available.
      vis_metric_ops = None
      if not use_tpu and use_original_images:
        keypoint_edges = [
            (kp.start, kp.end) for kp in eval_config.keypoint_edge]

        eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
            category_index,
            max_examples_to_draw=eval_config.num_visualizations,
            max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
            min_score_thresh=eval_config.min_score_threshold,
            use_normalized_coordinates=False,
            keypoint_edges=keypoint_edges or None)
        vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
            eval_dict)

      # Eval metrics on a single example.
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_config, list(category_index.values()), eval_dict)
      # Report each loss term as a mean metric.
      for loss_key, loss_tensor in iter(losses_dict.items()):
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      # Report optimizer summary variables as metrics: the variable itself is
      # paired with a no-op as the second tuple element.
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if vis_metric_ops is not None:
        eval_metric_ops.update(vis_metric_ops)
      # Normalize all metric names to str.
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      if eval_config.use_moving_averages:
        # The ExponentialMovingAverage object is only used to obtain the
        # variables_to_restore mapping for the saver; presumably the 0.0 decay
        # value is irrelevant here -- TODO confirm.
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      # Unless the moving-average scaffold was created above, build a default
      # scaffold with a sharded saver and register that saver in the SAVERS
      # collection.
      if scaffold is None:
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)

  return model_fn
def create_estimator_and_inputs(run_config,
                                hparams=None,
                                pipeline_config_path=None,
                                config_override=None,
                                train_steps=None,
                                sample_1_of_n_eval_examples=1,
                                sample_1_of_n_eval_on_train_examples=1,
                                model_fn_creator=create_model_fn,
                                use_tpu_estimator=False,
                                use_tpu=False,
                                num_shards=1,
                                params=None,
                                override_eval_num_epochs=True,
                                save_final_config=False,
                                postprocess_on_cpu=False,
                                export_to_tpu=None,
                                **kwargs):
  """Builds an estimator plus its train, eval and predict input functions.

  The pipeline config is loaded, merged with `hparams` and the keyword
  overrides, and the merged result drives the construction of the input
  functions, the `model_fn` and finally the estimator itself. Every builder
  is looked up through `MODEL_BUILD_UTIL_MAP`.

  Args:
    run_config: A `RunConfig`. Passed to the estimator; its `is_chief` flag
      gates writing the final config.
    hparams: (optional) A `HParams`. Merged into the configs, forwarded to
      `model_fn_creator`, and consulted for `export_to_tpu` when that argument
      is None.
    pipeline_config_path: A path to a pipeline config file.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None and the `TrainConfig`
      proto carries a non-zero `num_steps`, that value is used instead.
    sample_1_of_n_eval_examples: Integer representing how often an eval
      example should be sampled. If 1, will sample all examples. Only
      forwarded as an override when it is >= 1.
    sample_1_of_n_eval_on_train_examples: Similar to
      `sample_1_of_n_eval_examples`, except controls the sampling of training
      data for evaluation.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      It is called positionally with `(detection_model_fn, configs, hparams,
      use_tpu, postprocess_on_cpu)` and must return the `model_fn`.
    use_tpu_estimator: Whether a `TPUEstimator` should be returned. If False,
      an `Estimator` will be returned.
    use_tpu: Boolean TPU switch. Forwarded to `model_fn_creator` and to
      `TPUEstimator`; it also gates the `use_bfloat16` override and the eval
      batch size.
    num_shards: Number of shards (TPU cores). Only used if `use_tpu_estimator`
      is True.
    params: Parameter dictionary passed to the estimator. Only used if
      `use_tpu_estimator` is True; a falsy value is replaced with `{}`.
    override_eval_num_epochs: Whether to overwrite the number of epochs to 1
      for eval_input.
    save_final_config: Whether to save final config (obtained after applying
      overrides) to `estimator.model_dir`.
    postprocess_on_cpu: Forwarded to `model_fn_creator`. When use_tpu and
      postprocess_on_cpu are true, postprocess is scheduled on the host cpu.
    export_to_tpu: Forwarded to `TPUEstimator`. When use_tpu and export_to_tpu
      are true, `export_savedmodel()` exports a metagraph for serving on TPU
      besides the one on CPU.
    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    A dictionary with the following fields:
    'estimator': An `Estimator` or `TPUEstimator`.
    'train_input_fn': A training input function.
    'eval_input_fns': A list of all evaluation input functions.
    'eval_input_names': A list of names for each evaluation input.
    'eval_on_train_input_fn': An evaluation-on-train input function.
    'predict_input_fn': A prediction input function.
    'train_steps': Number of training steps. Either directly from input or
      from configuration.
  """
  # Resolve all builders up front so a missing map entry fails immediately.
  load_configs = MODEL_BUILD_UTIL_MAP['get_configs_from_pipeline_file']
  merge_overrides = MODEL_BUILD_UTIL_MAP['merge_external_params_with_configs']
  configs_to_proto = MODEL_BUILD_UTIL_MAP['create_pipeline_proto_from_configs']
  make_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
  make_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
  make_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
  build_detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base']

  configs = load_configs(pipeline_config_path, config_override=config_override)

  # Fold the explicit arguments into the override dictionary before merging.
  kwargs['train_steps'] = train_steps
  kwargs['use_bfloat16'] = configs['train_config'].use_bfloat16 and use_tpu
  if sample_1_of_n_eval_examples >= 1:
    kwargs['sample_1_of_n_eval_examples'] = sample_1_of_n_eval_examples
  if override_eval_num_epochs:
    kwargs['eval_num_epochs'] = 1
    tf.logging.warning(
        'Forced number of epochs for all eval validations to be 1.')
  configs = merge_overrides(configs, hparams, kwargs_dict=kwargs)

  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']
  eval_config = configs['eval_config']
  eval_input_configs = configs['eval_input_configs']

  # Evaluation on training data works on a private copy of the train input
  # config so the sampling/epoch tweaks never leak into training.
  train_as_eval_config = copy.deepcopy(train_input_config)
  train_as_eval_config.sample_1_of_n_examples = (
      sample_1_of_n_eval_on_train_examples)
  if override_eval_num_epochs and train_as_eval_config.num_epochs != 1:
    tf.logging.warning(
        'Expected number of evaluation epochs is 1, but instead encountered '
        '`eval_on_train_input_config.num_epochs` = {}. Overwriting '
        '`num_epochs` to 1.'.format(train_as_eval_config.num_epochs))
    train_as_eval_config.num_epochs = 1

  # Fall back to the configured step count, but only when it is non-zero.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  detection_model_fn = functools.partial(
      build_detection_model, model_config=model_config)

  # Input functions for TRAIN / EVAL / PREDICT.
  train_input_fn = make_train_input_fn(
      train_config=train_config,
      train_input_config=train_input_config,
      model_config=model_config)
  eval_input_fns = [
      make_eval_input_fn(
          eval_config=eval_config,
          eval_input_config=input_config,
          model_config=model_config)
      for input_config in eval_input_configs
  ]
  eval_input_names = [input_config.name for input_config in eval_input_configs]
  eval_on_train_input_fn = make_eval_input_fn(
      eval_config=eval_config,
      eval_input_config=train_as_eval_config,
      model_config=model_config)
  # Prediction reuses the first eval input config.
  predict_input_fn = make_predict_input_fn(
      model_config=model_config, predict_input_config=eval_input_configs[0])

  # An explicit export_to_tpu argument wins; otherwise consult hparams.
  if export_to_tpu is None and hparams is not None:
    export_to_tpu = hparams.get('export_to_tpu', False)
  tf.logging.info('create_estimator_and_inputs: use_tpu %s, export_to_tpu %s',
                  use_tpu, export_to_tpu)

  model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
                              postprocess_on_cpu)

  if not use_tpu_estimator:
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
  else:
    # Only batch size 1 per core is supported for eval.
    eval_batch_size = num_shards * 1 if use_tpu else 1
    estimator = tf.estimator.tpu.TPUEstimator(
        model_fn=model_fn,
        train_batch_size=train_config.batch_size,
        eval_batch_size=eval_batch_size,
        use_tpu=use_tpu,
        config=run_config,
        export_to_tpu=export_to_tpu,
        eval_on_tpu=False,  # Eval runs on CPU, so disable eval on TPU
        params=params or {})

  # Persist the as-run pipeline config (chief worker only).
  if run_config.is_chief and save_final_config:
    final_pipeline_proto = configs_to_proto(configs)
    config_util.save_pipeline_config(final_pipeline_proto, estimator.model_dir)

  return {
      'estimator': estimator,
      'train_input_fn': train_input_fn,
      'eval_input_fns': eval_input_fns,
      'eval_input_names': eval_input_names,
      'eval_on_train_input_fn': eval_on_train_input_fn,
      'predict_input_fn': predict_input_fn,
      'train_steps': train_steps,
  }
def create_train_and_eval_specs(train_input_fn,
                                eval_input_fns,
                                eval_on_train_input_fn,
                                predict_input_fn,
                                train_steps,
                                eval_on_train_data=False,
                                final_exporter_name='Servo',
                                eval_spec_names=None):
  """Builds one `TrainSpec` and an `EvalSpec` per evaluation input.

  Args:
    train_input_fn: Function that produces features and labels on train data.
    eval_input_fns: A list of functions that produce features and labels on
      eval data.
    eval_on_train_input_fn: Function that produces features and labels for
      evaluation on train data.
    predict_input_fn: Function that produces features for inference. It is
      used as the `serving_input_receiver_fn` of every `FinalExporter`.
    train_steps: Number of training steps, used as `max_steps` of the
      `TrainSpec`.
    eval_on_train_data: Whether to evaluate model on training data. Default is
      False.
    final_exporter_name: String name given to `FinalExporter`. The first eval
      spec uses it verbatim; later ones use `<final_exporter_name>_<name>`.
    eval_spec_names: A list of string names for each `EvalSpec`. When None,
      the stringified positions '0', '1', ... are used.

  Returns:
    Tuple of `TrainSpec` and list of `EvalSpec`s. The list holds one entry per
    (name, input function) pair on evaluation data. If `eval_on_train_data`
    is True, an extra `EvalSpec` named 'eval_on_train' is appended last.
  """
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)

  spec_names = eval_spec_names
  if spec_names is None:
    spec_names = list(map(str, range(len(eval_input_fns))))

  eval_specs = []
  for position, (spec_name, input_fn) in enumerate(
      zip(spec_names, eval_input_fns)):
    # The first eval spec keeps the bare exporter name for backward
    # compatibility; the others get the spec name as a suffix.
    exporter_name = (
        final_exporter_name if position == 0 else
        '{}_{}'.format(final_exporter_name, spec_name))
    final_exporter = tf.estimator.FinalExporter(
        name=exporter_name, serving_input_receiver_fn=predict_input_fn)
    spec = tf.estimator.EvalSpec(
        name=spec_name,
        input_fn=input_fn,
        steps=None,
        exporters=final_exporter)
    eval_specs.append(spec)

  if eval_on_train_data:
    train_data_spec = tf.estimator.EvalSpec(
        name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None)
    eval_specs.append(train_data_spec)

  return train_spec, eval_specs
def _evaluate_checkpoint(estimator,
input_fn,
checkpoint_path,
name,
max_retries=0):
"""Evaluates a checkpoint.
Args:
estimator: Estimator object to use for evaluation.
input_fn: Input function to use for evaluation.
checkpoint_path: Path of the checkpoint to evaluate.
name: Namescope for eval summary.
max_retries: Maximum number of times to retry the evaluation on encountering
a tf.errors.InvalidArgumentError. If negative, will always retry the
evaluation.
Returns:
Estimator evaluation results.
"""
always_retry = True if max_retries < 0 else False
retries = 0
while always_retry or retries <= max_retries:
try:
return estimator.evaluate(
input_fn=input_fn,
steps=None,
checkpoint_path=checkpoint_path,
name=name)
except tf.errors.InvalidArgumentError as e:
if always_retry or retries < max_retries:
tf.logging.info('Retrying checkpoint evaluation after exception: %s', e)
retries += 1
else:
raise e
def continuous_eval_generator(estimator,
model_dir,
input_fn,
train_steps,
name,
max_retries=0):
"""Perform continuous evaluation on checkpoints written to a model directory.