Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds efficientnet2 presets #1983

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions keras_hub/src/models/efficientnet/cba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import keras

BN_AXIS = 3

CONV_KERNEL_INITIALIZER = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kinda funky to return this as a dict, any particular reason for it?

More in line with other models would be to just add a function here...

def conv_kernel_initializer(scale=2.):
    return keras.initializers.VarianceScaling(
        scale=scale,
        ...
    )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No particular reason, I was following the original style of the fusedmbconv/mbconv blocks. Done.

"class_name": "VarianceScaling",
"config": {
"scale": 2.0,
"mode": "fan_out",
"distribution": "truncated_normal",
},
}


class CBABlock(keras.layers.Layer):
"""
Args:
input_filters: int, the number of input filters
output_filters: int, the number of output filters
kernel_size: default 3, the kernel_size to apply to the expansion phase
convolutions
strides: default 1, the strides to apply to the expansion phase
convolutions
data_format: str, channels_last (default) or channels_first, expects
tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively
batch_norm_momentum: default 0.9, the BatchNormalization momentum
batch_norm_epsilon: default 1e-3, the BatchNormalization epsilon
activation: default "swish", the activation function used between
convolution operations
dropout: float, the optional dropout rate to apply before the output
convolution, defaults to 0.2
nores: bool, default False, forces no residual connection if True,
otherwise allows it if False.

Returns:
A tensor representing a feature map, passed through the ConvBNAct
block

Note:
Not intended to be used outside of the EfficientNet architecture.
"""

def __init__(
self,
input_filters,
output_filters,
kernel_size=3,
strides=1,
data_format="channels_last",
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-3,
activation="swish",
dropout=0.2,
nores=False,
**kwargs
):
super().__init__(**kwargs)
self.input_filters = input_filters
self.output_filters = output_filters
self.kernel_size = kernel_size
self.strides = strides
self.data_format = data_format
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_epsilon = batch_norm_epsilon
self.activation = activation
self.dropout = dropout
self.nores = nores

padding_pixels = kernel_size // 2
self.conv1_pad = keras.layers.ZeroPadding2D(
padding=(padding_pixels, padding_pixels),
name=self.name + "conv_pad",
)
self.conv1 = keras.layers.Conv2D(
filters=self.output_filters,
kernel_size=kernel_size,
strides=strides,
kernel_initializer=CONV_KERNEL_INITIALIZER,
padding="valid",
data_format=data_format,
use_bias=False,
name=self.name + "conv",
)
self.bn1 = keras.layers.BatchNormalization(
axis=BN_AXIS,
momentum=self.batch_norm_momentum,
epsilon=self.batch_norm_epsilon,
name=self.name + "bn",
)
self.act = keras.layers.Activation(
self.activation, name=self.name + "activation"
)

if self.dropout:
self.dropout_layer = keras.layers.Dropout(
self.dropout,
noise_shape=(None, 1, 1, 1),
name=self.name + "drop",
)

def build(self, input_shape):
if self.name is None:
self.name = keras.backend.get_uid("block0")

def call(self, inputs):
x = self.conv1_pad(inputs)
x = self.conv1(x)
x = self.bn1(x)
x = self.act(x)

# Residual:
if (
self.strides == 1
and self.input_filters == self.output_filters
and not self.nores
):
if self.dropout:
x = self.dropout_layer(x)
x = keras.layers.Add(name=self.name + "add")([x, inputs])
return x

def get_config(self):
config = super().get_config()
config.update(
{
"input_filters": self.input_filters,
"output_filters": self.output_filters,
"kernel_size": self.kernel_size,
"strides": self.strides,
"data_format": self.data_format,
"batch_norm_momentum": self.batch_norm_momentum,
"batch_norm_epsilon": self.batch_norm_epsilon,
"activation": self.activation,
"dropout": self.dropout,
"nores": self.nores,
}
)

return config
22 changes: 22 additions & 0 deletions keras_hub/src/models/efficientnet/cba_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import keras

from keras_hub.src.models.efficientnet.cba import CBABlock
from keras_hub.src.tests.test_case import TestCase


class CBABlockTest(TestCase):
def test_same_input_output_shapes(self):
inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32")
layer = CBABlock(input_filters=32, output_filters=32)

output = layer(inputs)
self.assertEquals(output.shape, (1, 64, 64, 32))
self.assertLen(output, 1)

def test_different_input_output_shapes(self):
inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32")
layer = CBABlock(input_filters=32, output_filters=48)

output = layer(inputs)
self.assertEquals(output.shape, (1, 64, 64, 48))
self.assertLen(output, 1)
112 changes: 73 additions & 39 deletions keras_hub/src/models/efficientnet/efficientnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.efficientnet.cba import CBABlock
from keras_hub.src.models.efficientnet.fusedmbconv import FusedMBConvBlock
from keras_hub.src.models.efficientnet.mbconv import MBConvBlock
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
Expand All @@ -26,15 +27,12 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
(https://arxiv.org/abs/2104.00298) (ICML 2021)

Args:
width_coefficient: float, scaling coefficient for network width.
depth_coefficient: float, scaling coefficient for network depth.
dropout: float, dropout rate at skip connections. The default
value is set to 0.2.
depth_divisor: integer, a unit of network width. The default value is
set to 8.
activation: activation function to use between each convolutional layer.
input_shape: optional shape tuple, it should have exactly 3 input
channels.
stackwise_width_coefficient: list[float] or float, scaling coefficient
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't this need to be a list[float]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct, done

for network width. If single float, it is assumed that this value
applies to all stacks.
stackwise_depth_coefficient: list[float] or float, scaling coefficient
for network depth. If single float, it is assumed that this value
applies to all stacks.
stackwise_kernel_sizes: list of ints, the kernel sizes used for each
conv block.
stackwise_num_repeats: list of ints, number of times to repeat each
Expand All @@ -61,8 +59,17 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
stackwise_nores_option: list of bools, toggles if residiual connection
is not used. If False (default), the stack will use residual
connections, otherwise not.
dropout: float, dropout rate at skip connections. The default
value is set to 0.2.
depth_divisor: integer, a unit of network width. The default value is
set to 8.
min_depth: integer, minimum number of filters. Can be None and ignored
if use_depth_divisor_as_min_depth is set to True.
activation: activation function to use between each convolutional layer.
input_shape: optional shape tuple, it should have exactly 3 input
channels.


include_initial_padding: bool, whether to include initial zero padding
(as per v1).
use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as
Expand Down Expand Up @@ -99,8 +106,8 @@ class EfficientNetBackbone(FeaturePyramidBackbone):
def __init__(
self,
*,
width_coefficient,
depth_coefficient,
stackwise_width_coefficient=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably make these plural, in keeping with other args? stackwise_width_coefficients and stackwise_depth_coefficients

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

stackwise_depth_coefficient=None,
stackwise_kernel_sizes,
stackwise_num_repeats,
stackwise_input_filters,
Expand All @@ -124,8 +131,19 @@ def __init__(
batch_norm_momentum=0.9,
batch_norm_epsilon=1e-5,
projection_activation=None,
num_features=1280,
**kwargs,
):
num_stacks = len(stackwise_kernel_sizes)
if "depth_coefficient" in kwargs:
stackwise_depth_coefficient = [
kwargs.pop("depth_coefficient")
] * num_stacks
if "width_coefficient" in kwargs:
stackwise_width_coefficient = [
kwargs.pop("width_coefficient")
] * num_stacks

image_input = keras.layers.Input(shape=input_shape)

x = image_input # Intermediate result.
Expand All @@ -138,7 +156,7 @@ def __init__(
# Build stem
stem_filters = round_filters(
filters=stackwise_input_filters[0],
width_coefficient=width_coefficient,
width_coefficient=stackwise_width_coefficient[0],
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
Expand Down Expand Up @@ -170,25 +188,27 @@ def __init__(
self._pyramid_outputs = {}
curr_pyramid_level = 1

for i in range(len(stackwise_kernel_sizes)):
for i in range(num_stacks):
num_repeats = stackwise_num_repeats[i]
input_filters = stackwise_input_filters[i]
output_filters = stackwise_output_filters[i]
force_input_filters = stackwise_force_input_filters[i]
nores = stackwise_nores_option[i]
stack_width_coefficient = stackwise_width_coefficient[i]
stack_depth_coefficient = stackwise_depth_coefficient[i]

# Update block input and output filters based on depth multiplier.
input_filters = round_filters(
filters=input_filters,
width_coefficient=width_coefficient,
width_coefficient=stack_width_coefficient,
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
cap_round_filter_decrease=cap_round_filter_decrease,
)
output_filters = round_filters(
filters=output_filters,
width_coefficient=width_coefficient,
width_coefficient=stack_width_coefficient,
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
Expand All @@ -197,7 +217,7 @@ def __init__(

repeats = round_repeats(
repeats=num_repeats,
depth_coefficient=depth_coefficient,
depth_coefficient=stack_depth_coefficient,
)
strides = stackwise_strides[i]
squeeze_and_excite_ratio = stackwise_squeeze_and_excite_ratios[i]
Expand All @@ -216,7 +236,7 @@ def __init__(
if force_input_filters > 0:
input_filters = round_filters(
filters=force_input_filters,
width_coefficient=width_coefficient,
width_coefficient=stack_width_coefficient,
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
Expand Down Expand Up @@ -244,28 +264,40 @@ def __init__(
name=block_name,
)
else:
block = get_conv_constructor(stackwise_block_type)(
input_filters=input_filters,
output_filters=output_filters,
expand_ratio=stackwise_expansion_ratios[i],
kernel_size=stackwise_kernel_sizes[i],
strides=strides,
data_format=data_format,
se_ratio=squeeze_and_excite_ratio,
activation=activation,
dropout=dropout * block_id / blocks,
batch_norm_momentum=batch_norm_momentum,
batch_norm_epsilon=batch_norm_epsilon,
nores=nores,
name=block_name,
)
constructor = get_conv_constructor(stackwise_block_type)
block_kwargs = {
"input_filters": input_filters,
"output_filters": output_filters,
"kernel_size": stackwise_kernel_sizes[i],
"strides": strides,
"data_format": data_format,
"activation": activation,
"dropout": dropout * block_id / blocks,
"batch_norm_momentum": batch_norm_momentum,
"batch_norm_epsilon": batch_norm_epsilon,
"nores": nores,
"name": block_name,
}

if stackwise_block_type in ("fused", "unfused"):
block_kwargs["expand_ratio"] = (
stackwise_expansion_ratios[i]
)
block_kwargs["se_ratio"] = squeeze_and_excite_ratio

if stackwise_block_type == "fused":
block_kwargs["projection_activation"] = (
projection_activation
)

block = constructor(**block_kwargs)
x = block(x)
block_id += 1

# Build top
top_filters = round_filters(
filters=1280,
width_coefficient=width_coefficient,
filters=num_features,
width_coefficient=stackwise_width_coefficient[-1],
min_depth=min_depth,
depth_divisor=depth_divisor,
use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth,
Expand Down Expand Up @@ -298,8 +330,8 @@ def __init__(
super().__init__(inputs=image_input, outputs=x, **kwargs)

# === Config ===
self.width_coefficient = width_coefficient
self.depth_coefficient = depth_coefficient
self.stackwise_width_coefficient = stackwise_width_coefficient
self.stackwise_depth_coefficient = stackwise_depth_coefficient
self.dropout = dropout
self.depth_divisor = depth_divisor
self.min_depth = min_depth
Expand Down Expand Up @@ -329,8 +361,8 @@ def get_config(self):
config = super().get_config()
config.update(
{
"width_coefficient": self.width_coefficient,
"depth_coefficient": self.depth_coefficient,
"stackwise_width_coefficient": self.stackwise_width_coefficient,
"stackwise_depth_coefficient": self.stackwise_depth_coefficient,
"dropout": self.dropout,
"depth_divisor": self.depth_divisor,
"min_depth": self.min_depth,
Expand Down Expand Up @@ -586,9 +618,11 @@ def get_conv_constructor(conv_type):
return MBConvBlock
elif conv_type == "fused":
return FusedMBConvBlock
elif conv_type == "cba":
return CBABlock
else:
raise ValueError(
"Expected `conv_type` to be "
"one of 'unfused', 'fused', but got "
"one of 'unfused', 'fused', 'cba', but got "
f"`conv_type={conv_type}`"
)
Loading
Loading