diff --git a/models/custom_model.py b/models/custom_model.py
index 18f056da1..0362a5827 100644
--- a/models/custom_model.py
+++ b/models/custom_model.py
@@ -3,6 +3,7 @@
 import tensorflow as tf
 import tensorflow.contrib.slim as slim
 import numpy as np
+import frontend_builder
 
 def conv_block(inputs, n_filters, filter_size=[3, 3], dropout_p=0.0):
     """
@@ -16,95 +17,36 @@ def conv_block(inputs, n_filters, filter_size=[3, 3], dropout_p=0.0):
         out = slim.dropout(out, keep_prob=(1.0-dropout_p))
     return out
 
-def conv_transpose_block(inputs, n_filters, filter_size=[3, 3], dropout_p=0.0):
+def conv_transpose_block(inputs, n_filters, strides=2, filter_size=[3, 3], dropout_p=0.0):
     """
     Basic conv transpose block for Encoder-Decoder upsampling
     Apply successivly Transposed Convolution, BatchNormalization, ReLU nonlinearity
     Dropout (if dropout_p > 0) on the inputs
     """
-    conv = slim.conv2d_transpose(inputs, n_filters, kernel_size=[3, 3], stride=[2, 2])
+    conv = slim.conv2d_transpose(inputs, n_filters, kernel_size=[3, 3], stride=[strides, strides])
     out = tf.nn.relu(slim.batch_norm(conv, fused=True))
     if dropout_p != 0.0:
         out = slim.dropout(out, keep_prob=(1.0-dropout_p))
     return out
 
-def build_encoder_decoder_skip(inputs, num_classes, dropout_p=0.5, scope=None):
-    """
-    Builds the Encoder-Decoder-Skip model. Inspired by SegNet with some modifications
-    Includes skip connections
-
-    Arguments:
-      inputs: the input tensor
-      n_classes: number of classes
-      dropout_p: dropout rate applied after each convolution (0. for not using)
-
-    Returns:
-      Encoder-Decoder model
-    """
-
-    #####################
-    # Downsampling path #
-    #####################
-    net = conv_block(inputs, 64)
-    net = conv_block(net, 64)
-    net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
-    skip_1 = net
-
-    net = conv_block(net, 128)
-    net = conv_block(net, 128)
-    net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
-    skip_2 = net
-
-    net = conv_block(net, 256)
-    net = conv_block(net, 256)
-    net = conv_block(net, 256)
-    net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
-    skip_3 = net
-
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
-    skip_4 = net
-
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
+def build_custom(inputs, num_classes, frontend="ResNet101", weight_decay=1e-5, is_training=True, pretrained_dir="models"):
+
+    logits, end_points, frontend_scope, init_fn = frontend_builder.build_frontend(inputs, frontend, is_training=is_training)
 
-    #####################
-    #  Upsampling path  #
-    #####################
-    net = conv_transpose_block(net, 512)
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = tf.add(net, skip_4)
+    up_1 = conv_transpose_block(end_points["pool2"], strides=4, n_filters=64)
+    up_2 = conv_transpose_block(end_points["pool3"], strides=8, n_filters=64)
+    up_3 = conv_transpose_block(end_points["pool4"], strides=16, n_filters=64)
+    up_4 = conv_transpose_block(end_points["pool5"], strides=32, n_filters=64)
 
-    net = conv_transpose_block(net, 512)
-    net = conv_block(net, 512)
-    net = conv_block(net, 512)
-    net = conv_block(net, 256)
-    net = tf.add(net, skip_3)
+    features = tf.concat([up_1, up_2, up_3, up_4], axis=-1)
 
-    net = conv_transpose_block(net, 256)
-    net = conv_block(net, 256)
-    net = conv_block(net, 256)
-    net = conv_block(net, 128)
-    net = tf.add(net, skip_2)
+    features = conv_block(inputs=features, n_filters=256, filter_size=[1, 1])
 
-    net = conv_transpose_block(net, 128)
-    net = conv_block(net, 128)
-    net = conv_block(net, 64)
-    net = tf.add(net, skip_1)
+    features = conv_block(inputs=features, n_filters=64, filter_size=[3, 3])
+    features = conv_block(inputs=features, n_filters=64, filter_size=[3, 3])
+    features = conv_block(inputs=features, n_filters=64, filter_size=[3, 3])
 
-    net = conv_transpose_block(net, 64)
-    net = conv_block(net, 64)
-    net = conv_block(net, 64)
-    #####################
-    #      Softmax      #
-    #####################
-    net = slim.conv2d(net, num_classes, [1, 1], scope='logits')
+    net = slim.conv2d(features, num_classes, [1, 1], scope='logits')
     return net
\ No newline at end of file
diff --git a/models/frontend_builder.py b/models/frontend_builder.py
index b7870e459..9ac97a473 100644
--- a/models/frontend_builder.py
+++ b/models/frontend_builder.py
@@ -3,7 +3,6 @@
 import resnet_v2
 import mobilenet_v2
 import inception_v4
-import nasnet
 import os
 
 
diff --git a/models/resnet_v2.py b/models/resnet_v2.py
index 6f8927bdc..f6d86839b 100644
--- a/models/resnet_v2.py
+++ b/models/resnet_v2.py
@@ -199,6 +199,7 @@ def resnet_v2(inputs,
                             activation_fn=None, normalizer_fn=None):
           net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
         net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
+        net = slim.utils.collect_named_outputs(end_points_collection, 'pool2', net)
      net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
      # This is needed because the pre-activation variant does not have batch
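
For context, a minimal usage sketch of the new build_custom head. This is not part of the diff; the placeholder shape, the num_classes value, and the assumption that models/ is on sys.path (as the flat "import frontend_builder" above implies) are all illustrative.

    import tensorflow as tf
    import custom_model  # assumes models/ is on sys.path, matching the flat imports above

    # NHWC input; height/width should be divisible by 32 so that the
    # pool2..pool5 end points (output strides 4/8/16/32) upsample back
    # to the input resolution through the strided transpose-conv blocks.
    inputs = tf.placeholder(tf.float32, shape=[None, 512, 512, 3])

    # Builds the ResNet101 frontend via frontend_builder, upsamples the four
    # pooling end points to full resolution, fuses them by concatenation, and
    # returns per-pixel logits of shape [batch, 512, 512, num_classes].
    logits = custom_model.build_custom(inputs, num_classes=12, frontend="ResNet101")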