-
Notifications
You must be signed in to change notification settings - Fork 9
/
model.py
executable file
·697 lines (605 loc) · 25.7 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Provides DeepLab model definition and helper functions.
DeepLab is a deep learning system for semantic image segmentation with
the following features:
(1) Atrous convolution to explicitly control the resolution at which
feature responses are computed within Deep Convolutional Neural Networks.
(2) Atrous spatial pyramid pooling (ASPP) to robustly segment objects at
multiple scales with filters at multiple sampling rates and effective
fields-of-views.
(3) ASPP module augmented with image-level feature and batch normalization.
(4) A simple yet effective decoder module to recover the object boundaries.
See the following papers for more details:
"Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation"
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
(https://arxiv.org/abs/1802.02611)
"Rethinking Atrous Convolution for Semantic Image Segmentation,"
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam
(https://arxiv.org/abs/1706.05587)
"DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
Atrous Convolution, and Fully Connected CRFs",
Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
Alan L Yuille (* equal contribution)
(https://arxiv.org/abs/1606.00915)
"Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected
CRFs"
Liang-Chieh Chen*, George Papandreou*, Iasonas Kokkinos, Kevin Murphy,
Alan L. Yuille (* equal contribution)
(https://arxiv.org/abs/1412.7062)
"""
import tensorflow as tf
from deeplab.core import feature_extractor
# Short alias for the TF-Slim layer library used throughout this module.
slim = tf.contrib.slim
# Variable-scope name of the final classification (logits) layers.
_LOGITS_SCOPE_NAME = 'logits'
# Dictionary key under which the logits merged across input scales are stored.
_MERGED_LOGITS_SCOPE = 'merged_logits'
# Scope of the 1x1 convolution applied to the pooled image-level feature.
_IMAGE_POOLING_SCOPE = 'image_pooling'
# Scope prefix of the ASPP branches; the branch index is appended to it.
_ASPP_SCOPE = 'aspp'
# Scope of the 1x1 convolution that projects the concatenated ASPP branches.
_CONCAT_PROJECTION_SCOPE = 'concat_projection'
# Variable scope that wraps all decoder layers.
_DECODER_SCOPE = 'decoder'
def get_extra_layer_scopes(last_layers_contain_logits_only=False):
  """Returns the scope names of the extra layers.

  Args:
    last_layers_contain_logits_only: Boolean. When True only the logits scope
      is reported; otherwise the image pooling, ASPP, concat projection and
      decoder scopes are listed as well.

  Returns:
    A list of scope-name strings, always starting with the logits scope.
  """
  scopes = [_LOGITS_SCOPE_NAME]
  if not last_layers_contain_logits_only:
    scopes.extend([
        _IMAGE_POOLING_SCOPE,
        _ASPP_SCOPE,
        _CONCAT_PROJECTION_SCOPE,
        _DECODER_SCOPE,
    ])
  return scopes
def predict_labels_multi_scale(images,
                               model_options,
                               eval_scales=(1.0,),
                               add_flipped_images=False):
  """Predicts labels by averaging softmax scores over scales and flips.

  For each scale in `eval_scales` the merged logits are computed (and, when
  requested, also for the images reversed along axis 2), resized to the
  spatial size of `images` and converted to softmax probabilities. All
  probability maps of an output are averaged before taking the argmax over
  the channel axis.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    eval_scales: The scales to resize images for evaluation.
    add_flipped_images: Add flipped images for evaluation or not.

  Returns:
    A dictionary keyed by output type (e.g., semantic prediction) whose values
    are prediction tensors (argmax over channels) of size
    [batch, height, width].
  """
  # One list of probability maps (per scale, per flip) for every output type.
  probabilities_by_output = {}
  for output_name in model_options.outputs_to_num_classes:
    probabilities_by_output[output_name] = []

  for scale_index, image_scale in enumerate(eval_scales):
    # The first scale creates the variables; every later scale reuses them.
    reuse_variables = True if scale_index else None
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
      logits_by_output = multi_scale_logits(
          images,
          model_options=model_options,
          image_pyramid=[image_scale],
          is_training=False,
          fine_tune_batch_norm=False)

    flipped_logits_by_output = None
    if add_flipped_images:
      # The un-flipped pass above already created the variables.
      with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        flipped_logits_by_output = multi_scale_logits(
            tf.reverse_v2(images, [2]),
            model_options=model_options,
            image_pyramid=[image_scale],
            is_training=False,
            fine_tune_batch_norm=False)

    for output_name in sorted(logits_by_output):
      merged = logits_by_output[output_name][_MERGED_LOGITS_SCOPE]
      resized = tf.image.resize_bilinear(
          merged, tf.shape(images)[1:3], align_corners=True)
      probabilities_by_output[output_name].append(
          tf.expand_dims(tf.nn.softmax(resized), 4))

      if add_flipped_images:
        # Reverse the logits again so they line up with the original images.
        unflipped = tf.reverse_v2(
            flipped_logits_by_output[output_name][_MERGED_LOGITS_SCOPE], [2])
        resized_unflipped = tf.image.resize_bilinear(
            unflipped, tf.shape(images)[1:3], align_corners=True)
        probabilities_by_output[output_name].append(
            tf.expand_dims(tf.nn.softmax(resized_unflipped), 4))

  for output_name in sorted(probabilities_by_output):
    # Average over the extra axis that indexes scales and flipped images.
    stacked = tf.concat(probabilities_by_output[output_name], 4)
    averaged = tf.reduce_mean(stacked, axis=4)
    probabilities_by_output[output_name] = tf.argmax(averaged, 3)

  return probabilities_by_output
def predict_labels(images, model_options, image_pyramid=None):
  """Predicts segmentation labels from the merged multi-scale logits.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.

  Returns:
    A dictionary keyed by output type (e.g., semantic prediction) whose values
    are prediction tensors (argmax over channels) of size
    [batch, height, width].
  """
  logits_by_output = multi_scale_logits(
      images,
      model_options=model_options,
      image_pyramid=image_pyramid,
      is_training=False,
      fine_tune_batch_norm=False)

  label_maps = {}
  for output_name in sorted(logits_by_output):
    merged = logits_by_output[output_name][_MERGED_LOGITS_SCOPE]
    # Bring the logits back to the spatial size of the input images.
    upsampled = tf.image.resize_bilinear(
        merged, tf.shape(images)[1:3], align_corners=True)
    label_maps[output_name] = tf.argmax(upsampled, 3)
  return label_maps
def scale_dimension(dim, scale):
  """Scales a dimension as (dim - 1) * scale + 1.

  Args:
    dim: Input dimension (a scalar or a scalar Tensor).
    scale: The amount of scaling applied to the input.

  Returns:
    The scaled dimension: a Python int for a scalar input, or an int32 Tensor
    when `dim` is a Tensor.
  """
  if not isinstance(dim, tf.Tensor):
    return int((float(dim) - 1.0) * scale + 1.0)
  scaled = (tf.to_float(dim) - 1.0) * scale + 1.0
  return tf.cast(scaled, dtype=tf.int32)
def multi_scale_logits(images,
                       model_options,
                       image_pyramid,
                       weight_decay=0.0001,
                       is_training=False,
                       fine_tune_batch_norm=False):
  """Computes the logits for every scale of an image pyramid and merges them.

  The returned logits are all downsampled (due to max-pooling layers)
  for both training and evaluation.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    image_pyramid: Input image scales for multi-scale feature extraction.
      An empty or None value is treated as [1.0].
    weight_decay: The weight decay for model variables.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

  Returns:
    outputs_to_scales_to_logits: A map of maps from output_type (e.g.,
      semantic prediction) to a dictionary of multi-scale logits names to
      logits. With a single scale the inner dictionary only holds
      'merged_logits'. With several scales, e.g. [1.0, 1.5], it holds
      'merged_logits', 'logits_1.00' and 'logits_1.50'.

  Raises:
    ValueError: If model_options doesn't specify crop_size and its
      add_image_level_feature = True, since add_image_level_feature requires
      crop_size information.
  """
  # Fall back to the original resolution when no pyramid is supplied.
  if not image_pyramid:
    image_pyramid = [1.0]

  crop_size_missing = model_options.crop_size is None
  if crop_size_missing and model_options.add_image_level_feature:
    raise ValueError(
        'Crop size must be specified for using image-level feature.')

  if (model_options.model_variant == 'mobilenet_v2' and
      (model_options.atrous_rates is not None or
       model_options.decoder_output_stride is not None)):
    # Only a warning: the caller has to decide whether the setting is wanted.
    tf.logging.warning('Our provided mobilenet_v2 checkpoint does not '
                       'include ASPP and decoder modules.')

  # Use the static crop size when given, the dynamic image shape otherwise.
  if model_options.crop_size:
    crop_height = model_options.crop_size[0]
  else:
    crop_height = tf.shape(images)[1]
  if model_options.crop_size:
    crop_width = model_options.crop_size[1]
  else:
    crop_width = tf.shape(images)[2]

  # Spatial size shared by the logits of all scales.
  stride = model_options.decoder_output_stride or model_options.output_stride
  largest_scale = max(1.0, max(image_pyramid))
  logits_height = scale_dimension(crop_height, largest_scale / stride)
  logits_width = scale_dimension(crop_width, largest_scale / stride)

  outputs_to_scales_to_logits = {}
  for output_name in model_options.outputs_to_num_classes:
    outputs_to_scales_to_logits[output_name] = {}

  single_scale = len(image_pyramid) == 1
  for count, image_scale in enumerate(image_pyramid):
    if image_scale == 1.0:
      scaled_crop_size = model_options.crop_size
      scaled_images = images
    else:
      scaled_height = scale_dimension(crop_height, image_scale)
      scaled_width = scale_dimension(crop_width, image_scale)
      scaled_crop_size = [scaled_height, scaled_width]
      scaled_images = tf.image.resize_bilinear(
          images, scaled_crop_size, align_corners=True)
      if model_options.crop_size:
        # Static sizes are only known when a crop size was configured.
        scaled_images.set_shape([None, scaled_height, scaled_width, 3])

    scale_options = model_options._replace(crop_size=scaled_crop_size)
    outputs_to_logits = _get_logits(
        scaled_images,
        scale_options,
        weight_decay=weight_decay,
        # Variables are created at the first scale and reused afterwards.
        reuse=True if count else None,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

    # All scales must share one spatial size before they can be merged.
    for output_name in sorted(outputs_to_logits):
      outputs_to_logits[output_name] = tf.image.resize_bilinear(
          outputs_to_logits[output_name], [logits_height, logits_width],
          align_corners=True)

    if single_scale:
      # Nothing to merge: the only scale is the merged result.
      for output_name in sorted(model_options.outputs_to_num_classes):
        scale_map = outputs_to_scales_to_logits[output_name]
        scale_map[_MERGED_LOGITS_SCOPE] = outputs_to_logits[output_name]
      return outputs_to_scales_to_logits

    scale_key = 'logits_%.2f' % image_scale
    for output_name in sorted(model_options.outputs_to_num_classes):
      outputs_to_scales_to_logits[output_name][scale_key] = (
          outputs_to_logits[output_name])

  # Merge the per-scale logits along a new trailing axis.
  for output_name in sorted(model_options.outputs_to_num_classes):
    scale_map = outputs_to_scales_to_logits[output_name]
    expanded = []
    for scale_logits in scale_map.values():
      expanded.append(tf.expand_dims(scale_logits, axis=4))
    stacked = tf.concat(expanded, 4)
    if model_options.merge_method == 'max':
      merge_fn = tf.reduce_max
    else:
      merge_fn = tf.reduce_mean
    scale_map[_MERGED_LOGITS_SCOPE] = merge_fn(stacked, axis=4)

  return outputs_to_scales_to_logits
def _extract_features(images,
                      model_options,
                      weight_decay=0.0001,
                      reuse=None,
                      is_training=False,
                      fine_tune_batch_norm=False):
  """Extracts features and optionally applies batch-normalized ASPP.

  The feature extractor selected by `model_options.model_variant` is run
  first. When `model_options.aspp_with_batch_norm` is False its output is
  returned unchanged. Otherwise an optional image-level feature, a 1x1
  convolution and one 3x3 atrous convolution per atrous rate are computed,
  concatenated, projected with a 1x1 convolution and passed through dropout.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

  Returns:
    concat_logits: A tensor of size [batch, feature_height, feature_width,
      feature_channels], where feature_height/feature_width are determined by
      the images height/width and output_stride.
    end_points: A dictionary from components of the network to the
      corresponding activation.
  """
  features, end_points = feature_extractor.extract_features(
      images,
      output_stride=model_options.output_stride,
      multi_grid=model_options.multi_grid,
      model_variant=model_options.model_variant,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  # Without batch-normalized ASPP the extractor output is returned as is.
  if not model_options.aspp_with_batch_norm:
    return features, end_points

  # Batch norm statistics are only updated when fine-tuning while training.
  bn_defaults = dict(
      is_training=is_training and fine_tune_batch_norm,
      decay=0.9997,
      epsilon=1e-5,
      scale=True)
  conv_defaults = dict(
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      padding='SAME',
      stride=1,
      reuse=reuse)

  with slim.arg_scope([slim.conv2d, slim.separable_conv2d], **conv_defaults):
    with slim.arg_scope([slim.batch_norm], **bn_defaults):
      aspp_depth = 256
      branches = []

      if model_options.add_image_level_feature:
        # Pool over the whole feature map, then resize back to its size.
        pool_height = scale_dimension(model_options.crop_size[0],
                                      1. / model_options.output_stride)
        pool_width = scale_dimension(model_options.crop_size[1],
                                     1. / model_options.output_stride)
        pool_size = [pool_height, pool_width]
        pooled = slim.avg_pool2d(
            features, pool_size, [pool_height, pool_width], padding='VALID')
        pooled = slim.conv2d(
            pooled, aspp_depth, 1, scope=_IMAGE_POOLING_SCOPE)
        pooled = tf.image.resize_bilinear(
            pooled, [pool_height, pool_width], align_corners=True)
        pooled.set_shape([None, pool_height, pool_width, aspp_depth])
        branches.append(pooled)

      # Branch 0 is a plain 1x1 convolution.
      pointwise_branch = slim.conv2d(
          features, aspp_depth, 1, scope=_ASPP_SCOPE + str(0))
      branches.append(pointwise_branch)

      if model_options.atrous_rates:
        # Branches 1..N are 3x3 convolutions, one per atrous rate.
        for branch_index, rate in enumerate(model_options.atrous_rates, 1):
          branch_scope = _ASPP_SCOPE + str(branch_index)
          if model_options.aspp_with_separable_conv:
            branch = _split_separable_conv2d(
                features,
                filters=aspp_depth,
                rate=rate,
                weight_decay=weight_decay,
                scope=branch_scope)
          else:
            branch = slim.conv2d(
                features, aspp_depth, 3, rate=rate, scope=branch_scope)
          branches.append(branch)

      # Concatenate the branches along channels and project them.
      merged = tf.concat(branches, 3)
      merged = slim.conv2d(
          merged, aspp_depth, 1, scope=_CONCAT_PROJECTION_SCOPE)
      merged = slim.dropout(
          merged,
          keep_prob=0.9,
          is_training=is_training,
          scope=_CONCAT_PROJECTION_SCOPE + '_dropout')

      return merged, end_points
def _get_logits(images,
                model_options,
                weight_decay=0.0001,
                reuse=None,
                is_training=False,
                fine_tune_batch_norm=False):
  """Gets the logits by atrous/image spatial pyramid pooling.

  Features are extracted with `_extract_features`, optionally refined by the
  decoder (when `model_options.decoder_output_stride` is set) and finally
  turned into one logits tensor per output type.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models. When its
      crop_size is unset, the decoder size is derived from the dynamic shape
      of `images`.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

  Returns:
    outputs_to_logits: A map from output_type to logits.
  """
  features, end_points = _extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride is not None:
    if model_options.crop_size:
      height = model_options.crop_size[0]
      width = model_options.crop_size[1]
    else:
      # crop_size may be None (multi_scale_logits allows it); indexing it
      # would raise, so use the run-time image size like multi_scale_logits.
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    decoder_height = scale_dimension(
        height, 1.0 / model_options.decoder_output_stride)
    decoder_width = scale_dimension(
        width, 1.0 / model_options.decoder_output_stride)
    features = refine_by_decoder(
        features,
        end_points,
        decoder_height=decoder_height,
        decoder_width=decoder_width,
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

  outputs_to_logits = {}
  for output in sorted(model_options.outputs_to_num_classes):
    outputs_to_logits[output] = _get_branch_logits(
        features,
        model_options.outputs_to_num_classes[output],
        model_options.atrous_rates,
        aspp_with_batch_norm=model_options.aspp_with_batch_norm,
        kernel_size=model_options.logits_kernel_size,
        weight_decay=weight_decay,
        reuse=reuse,
        scope_suffix=output)

  return outputs_to_logits


def refine_by_decoder(features,
                      end_points,
                      decoder_height,
                      decoder_width,
                      decoder_use_separable_conv=False,
                      model_variant=None,
                      weight_decay=0.0001,
                      reuse=None,
                      is_training=False,
                      fine_tune_batch_norm=False):
  """Adds the decoder to obtain sharper segmentation results.

  For every decoder end point of the model variant, the end point is projected
  to 48 channels with a 1x1 convolution, resized together with the current
  decoder features to [decoder_height, decoder_width], concatenated along the
  channel axis and refined by two 3x3 convolutions with 256 channels.

  Args:
    features: A tensor of size [batch, features_height, features_width,
      features_channels].
    end_points: A dictionary from components of the network to the
      corresponding activation.
    decoder_height: The height of decoder feature maps (a Python integer or
      a scalar Tensor).
    decoder_width: The width of decoder feature maps (a Python integer or
      a scalar Tensor).
    decoder_use_separable_conv: Employ separable convolution for decoder or
      not.
    model_variant: Model variant for feature extraction.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.

  Returns:
    Decoder output with size [batch, decoder_height, decoder_width,
      decoder_channels]. `features` is returned unchanged when the model
      variant has no decoder end points.
  """
  batch_norm_params = {
      # Batch norm statistics are only updated when fine-tuning in training.
      'is_training': is_training and fine_tune_batch_norm,
      'decay': 0.9997,
      'epsilon': 1e-5,
      'scale': True,
  }

  # set_shape only accepts static dimensions. A Tensor-valued size (dynamic
  # crop size) is recorded as unknown instead of raising.
  static_height = (
      None if isinstance(decoder_height, tf.Tensor) else decoder_height)
  static_width = (
      None if isinstance(decoder_width, tf.Tensor) else decoder_width)

  with slim.arg_scope(
      [slim.conv2d, slim.separable_conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm,
      padding='SAME',
      stride=1,
      reuse=reuse):
    with slim.arg_scope([slim.batch_norm], **batch_norm_params):
      with tf.variable_scope(_DECODER_SCOPE, _DECODER_SCOPE, [features]):
        feature_list = feature_extractor.networks_to_feature_maps[
            model_variant][feature_extractor.DECODER_END_POINTS]
        if feature_list is None:
          tf.logging.info('Not found any decoder end points.')
          return features
        else:
          decoder_features = features
          for i, name in enumerate(feature_list):
            decoder_features_list = [decoder_features]
            feature_name = '{}/{}'.format(
                feature_extractor.name_scope[model_variant], name)
            # Project the end point to 48 channels before concatenation.
            decoder_features_list.append(
                slim.conv2d(
                    end_points[feature_name],
                    48,
                    1,
                    scope='feature_projection' + str(i)))
            # Resize to decoder_height/decoder_width.
            for j, feature in enumerate(decoder_features_list):
              decoder_features_list[j] = tf.image.resize_bilinear(
                  feature, [decoder_height, decoder_width],
                  align_corners=True)
              decoder_features_list[j].set_shape(
                  [None, static_height, static_width, None])
            decoder_depth = 256
            if decoder_use_separable_conv:
              decoder_features = _split_separable_conv2d(
                  tf.concat(decoder_features_list, 3),
                  filters=decoder_depth,
                  rate=1,
                  weight_decay=weight_decay,
                  scope='decoder_conv0')
              decoder_features = _split_separable_conv2d(
                  decoder_features,
                  filters=decoder_depth,
                  rate=1,
                  weight_decay=weight_decay,
                  scope='decoder_conv1')
            else:
              num_convs = 2
              # The scope name is kept as is so variable names stay unchanged.
              decoder_features = slim.repeat(
                  tf.concat(decoder_features_list, 3),
                  num_convs,
                  slim.conv2d,
                  decoder_depth,
                  3,
                  scope='decoder_conv' + str(i))
          return decoder_features
def _get_branch_logits(features,
                       num_classes,
                       atrous_rates=None,
                       aspp_with_batch_norm=False,
                       kernel_size=1,
                       weight_decay=0.0001,
                       reuse=None,
                       scope_suffix=''):
  """Computes the final logits as the sum of one convolution per atrous rate.

  The underlying model is branched out in the last layer when atrous
  spatial pyramid pooling is employed, and all branches are sum-merged
  to form the final logits.

  Args:
    features: A float tensor of shape [batch, height, width, channels].
    num_classes: Number of classes to predict.
    atrous_rates: A list of atrous convolution rates for last layer.
    aspp_with_batch_norm: Use batch normalization layers for ASPP.
    kernel_size: Kernel size for convolution.
    weight_decay: Weight decay for the model variables.
    reuse: Reuse model variables or not.
    scope_suffix: Scope suffix for the model variables.

  Returns:
    Merged logits with shape [batch, height, width, num_classes].

  Raises:
    ValueError: Upon invalid input kernel_size value.
  """
  # With batch-normalized ASPP the pyramid was already applied in
  # _extract_features, so a single convolution with rate 1 is used here.
  single_branch = aspp_with_batch_norm or atrous_rates is None
  if single_branch:
    if kernel_size != 1:
      raise ValueError('Kernel size must be 1 when atrous_rates is None or '
                       'using aspp_with_batch_norm. Gets %d.' % kernel_size)
    atrous_rates = [1]

  conv_defaults = dict(
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
      reuse=reuse)
  with slim.arg_scope([slim.conv2d], **conv_defaults):
    with tf.variable_scope(_LOGITS_SCOPE_NAME, _LOGITS_SCOPE_NAME, [features]):
      per_rate_logits = []
      for index, rate in enumerate(atrous_rates):
        # The first branch uses the bare suffix, later ones append the index.
        if index:
          branch_scope = scope_suffix + ('_%d' % index)
        else:
          branch_scope = scope_suffix
        branch = slim.conv2d(
            features,
            num_classes,
            kernel_size=kernel_size,
            rate=rate,
            activation_fn=None,
            normalizer_fn=None,
            scope=branch_scope)
        per_rate_logits.append(branch)

      return tf.add_n(per_rate_logits)
def _split_separable_conv2d(inputs,
                            filters,
                            rate=1,
                            weight_decay=0.00004,
                            depthwise_weights_initializer_stddev=0.33,
                            pointwise_weights_initializer_stddev=0.06,
                            scope=None):
  """Splits a separable conv2d into depthwise and pointwise conv2d.

  This operation differs from `tf.layers.separable_conv2d` as this operation
  applies activation function between depthwise and pointwise conv2d.

  Args:
    inputs: Input tensor with shape [batch, height, width, channels].
    filters: Number of filters in the 1x1 pointwise convolution.
    rate: Atrous convolution rate for the depthwise convolution.
    weight_decay: The weight decay to use for regularizing the model. It is
      applied to the pointwise weights only; the depthwise weights are not
      regularized.
    depthwise_weights_initializer_stddev: The standard deviation of the
      truncated normal weight initializer for depthwise convolution.
    pointwise_weights_initializer_stddev: The standard deviation of the
      truncated normal weight initializer for pointwise convolution.
    scope: Optional scope for the operation. When given, the two layers use
      `scope + '_depthwise'` and `scope + '_pointwise'`. When None, no scope
      is passed to the layers.

  Returns:
    Computed features after split separable conv2d.
  """
  # `scope` defaults to None, and concatenating None with a suffix raises
  # TypeError, so the sub-scope names are only built when a scope is given.
  if scope is None:
    depthwise_scope = None
    pointwise_scope = None
  else:
    depthwise_scope = scope + '_depthwise'
    pointwise_scope = scope + '_pointwise'

  # num_outputs=None makes slim.separable_conv2d skip its own pointwise step.
  outputs = slim.separable_conv2d(
      inputs,
      None,
      3,
      depth_multiplier=1,
      rate=rate,
      weights_initializer=tf.truncated_normal_initializer(
          stddev=depthwise_weights_initializer_stddev),
      weights_regularizer=None,
      scope=depthwise_scope)
  return slim.conv2d(
      outputs,
      filters,
      1,
      weights_initializer=tf.truncated_normal_initializer(
          stddev=pointwise_weights_initializer_stddev),
      weights_regularizer=slim.l2_regularizer(weight_decay),
      scope=pointwise_scope)