diff --git a/keras_hub/src/models/efficientnet/cba.py b/keras_hub/src/models/efficientnet/cba.py new file mode 100644 index 000000000..4e145282a --- /dev/null +++ b/keras_hub/src/models/efficientnet/cba.py @@ -0,0 +1,141 @@ +import keras + +BN_AXIS = 3 + + +class CBABlock(keras.layers.Layer): + """ + Args: + input_filters: int, the number of input filters + output_filters: int, the number of output filters + kernel_size: default 3, the kernel_size to apply to the expansion phase + convolutions + strides: default 1, the strides to apply to the expansion phase + convolutions + data_format: str, channels_last (default) or channels_first, expects + tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively + batch_norm_momentum: default 0.9, the BatchNormalization momentum + batch_norm_epsilon: default 1e-3, the BatchNormalization epsilon + activation: default "swish", the activation function used between + convolution operations + dropout: float, the optional dropout rate to apply before the output + convolution, defaults to 0.2 + nores: bool, default False, forces no residual connection if True, + otherwise allows it if False. + + Returns: + A tensor representing a feature map, passed through the ConvBNAct + block + + Note: + Not intended to be used outside of the EfficientNet architecture. + """ + + def __init__( + self, + input_filters, + output_filters, + kernel_size=3, + strides=1, + data_format="channels_last", + batch_norm_momentum=0.9, + batch_norm_epsilon=1e-3, + activation="swish", + dropout=0.2, + nores=False, + **kwargs + ): + super().__init__(**kwargs) + self.input_filters = input_filters + self.output_filters = output_filters + self.kernel_size = kernel_size + self.strides = strides + self.data_format = data_format + self.batch_norm_momentum = batch_norm_momentum + self.batch_norm_epsilon = batch_norm_epsilon + self.activation = activation + self.dropout = dropout + self.nores = nores + + padding_pixels = kernel_size // 2 + self.conv1_pad = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=self.name + "conv_pad", + ) + self.conv1 = keras.layers.Conv2D( + filters=self.output_filters, + kernel_size=kernel_size, + strides=strides, + kernel_initializer=self._conv_kernel_initializer(), + padding="valid", + data_format=data_format, + use_bias=False, + name=self.name + "conv", + ) + self.bn1 = keras.layers.BatchNormalization( + axis=BN_AXIS, + momentum=self.batch_norm_momentum, + epsilon=self.batch_norm_epsilon, + name=self.name + "bn", + ) + self.act = keras.layers.Activation( + self.activation, name=self.name + "activation" + ) + + if self.dropout: + self.dropout_layer = keras.layers.Dropout( + self.dropout, + noise_shape=(None, 1, 1, 1), + name=self.name + "drop", + ) + + def _conv_kernel_initializer( + self, + scale=2.0, + mode="fan_out", + distribution="truncated_normal", + seed=None, + ): + return keras.initializers.VarianceScaling( + scale=scale, mode=mode, distribution=distribution, seed=seed + ) + + def build(self, input_shape): + if self.name is None: + self.name = keras.backend.get_uid("block0") + + def call(self, inputs): + x = self.conv1_pad(inputs) + x = self.conv1(x) + x = self.bn1(x) + x = self.act(x) + + # Residual: + if ( + self.strides == 1 + and self.input_filters == self.output_filters + and not self.nores + ): + if self.dropout: + x = self.dropout_layer(x) + x = keras.layers.Add(name=self.name + "add")([x, inputs]) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "input_filters": self.input_filters, + "output_filters": self.output_filters, + "kernel_size": self.kernel_size, + "strides": self.strides, + "data_format": self.data_format, + "batch_norm_momentum": self.batch_norm_momentum, + "batch_norm_epsilon": self.batch_norm_epsilon, + "activation": self.activation, + "dropout": self.dropout, + "nores": self.nores, + } + ) + + return config diff --git a/keras_hub/src/models/efficientnet/cba_test.py b/keras_hub/src/models/efficientnet/cba_test.py new file mode 100644 index 000000000..ec028b123 --- /dev/null +++ b/keras_hub/src/models/efficientnet/cba_test.py @@ -0,0 +1,22 @@ +import keras + +from keras_hub.src.models.efficientnet.cba import CBABlock +from keras_hub.src.tests.test_case import TestCase + + +class CBABlockTest(TestCase): + def test_same_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = CBABlock(input_filters=32, output_filters=32) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 32)) + self.assertLen(output, 1) + + def test_different_input_output_shapes(self): + inputs = keras.random.normal(shape=(1, 64, 64, 32), dtype="float32") + layer = CBABlock(input_filters=32, output_filters=48) + + output = layer(inputs) + self.assertEquals(output.shape, (1, 64, 64, 48)) + self.assertLen(output, 1) diff --git a/keras_hub/src/models/efficientnet/efficientnet_backbone.py b/keras_hub/src/models/efficientnet/efficientnet_backbone.py index 95f434149..c71979ad0 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_backbone.py +++ b/keras_hub/src/models/efficientnet/efficientnet_backbone.py @@ -3,6 +3,7 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.efficientnet.cba import CBABlock from keras_hub.src.models.efficientnet.fusedmbconv import FusedMBConvBlock from keras_hub.src.models.efficientnet.mbconv import MBConvBlock from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone @@ -26,15 +27,12 @@ class EfficientNetBackbone(FeaturePyramidBackbone): (https://arxiv.org/abs/2104.00298) (ICML 2021) Args: - width_coefficient: float, scaling coefficient for network width. - depth_coefficient: float, scaling coefficient for network depth. - dropout: float, dropout rate at skip connections. The default - value is set to 0.2. - depth_divisor: integer, a unit of network width. The default value is - set to 8. - activation: activation function to use between each convolutional layer. - input_shape: optional shape tuple, it should have exactly 3 input - channels. + stackwise_width_coefficients: list[float], scaling coefficient + for network width. If single float, it is assumed that this value + applies to all stacks. + stackwise_depth_coefficients: list[float], scaling coefficient + for network depth. If single float, it is assumed that this value + applies to all stacks. stackwise_kernel_sizes: list of ints, the kernel sizes used for each conv block. stackwise_num_repeats: list of ints, number of times to repeat each @@ -61,8 +59,17 @@ class EfficientNetBackbone(FeaturePyramidBackbone): stackwise_nores_option: list of bools, toggles if residiual connection is not used. If False (default), the stack will use residual connections, otherwise not. + dropout: float, dropout rate at skip connections. The default + value is set to 0.2. + depth_divisor: integer, a unit of network width. The default value is + set to 8. min_depth: integer, minimum number of filters. Can be None and ignored if use_depth_divisor_as_min_depth is set to True. + activation: activation function to use between each convolutional layer. + input_shape: optional shape tuple, it should have exactly 3 input + channels. + + include_initial_padding: bool, whether to include initial zero padding (as per v1). use_depth_divisor_as_min_depth: bool, whether to use depth_divisor as @@ -99,8 +106,8 @@ class EfficientNetBackbone(FeaturePyramidBackbone): def __init__( self, *, - width_coefficient, - depth_coefficient, + stackwise_width_coefficients=None, + stackwise_depth_coefficients=None, stackwise_kernel_sizes, stackwise_num_repeats, stackwise_input_filters, @@ -124,8 +131,19 @@ def __init__( batch_norm_momentum=0.9, batch_norm_epsilon=1e-5, projection_activation=None, + num_features=1280, **kwargs, ): + num_stacks = len(stackwise_kernel_sizes) + if "depth_coefficient" in kwargs: + stackwise_depth_coefficients = [ + kwargs.pop("depth_coefficient") + ] * num_stacks + if "width_coefficient" in kwargs: + stackwise_width_coefficients = [ + kwargs.pop("width_coefficient") + ] * num_stacks + image_input = keras.layers.Input(shape=input_shape) x = image_input # Intermediate result. @@ -138,7 +156,7 @@ def __init__( # Build stem stem_filters = round_filters( filters=stackwise_input_filters[0], - width_coefficient=width_coefficient, + width_coefficient=stackwise_width_coefficients[0], min_depth=min_depth, depth_divisor=depth_divisor, use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, @@ -170,17 +188,19 @@ def __init__( self._pyramid_outputs = {} curr_pyramid_level = 1 - for i in range(len(stackwise_kernel_sizes)): + for i in range(num_stacks): num_repeats = stackwise_num_repeats[i] input_filters = stackwise_input_filters[i] output_filters = stackwise_output_filters[i] force_input_filters = stackwise_force_input_filters[i] nores = stackwise_nores_option[i] + stack_width_coefficient = stackwise_width_coefficients[i] + stack_depth_coefficient = stackwise_depth_coefficients[i] # Update block input and output filters based on depth multiplier. input_filters = round_filters( filters=input_filters, - width_coefficient=width_coefficient, + width_coefficient=stack_width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, @@ -188,7 +208,7 @@ def __init__( ) output_filters = round_filters( filters=output_filters, - width_coefficient=width_coefficient, + width_coefficient=stack_width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, @@ -197,7 +217,7 @@ def __init__( repeats = round_repeats( repeats=num_repeats, - depth_coefficient=depth_coefficient, + depth_coefficient=stack_depth_coefficient, ) strides = stackwise_strides[i] squeeze_and_excite_ratio = stackwise_squeeze_and_excite_ratios[i] @@ -216,7 +236,7 @@ def __init__( if force_input_filters > 0: input_filters = round_filters( filters=force_input_filters, - width_coefficient=width_coefficient, + width_coefficient=stack_width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, @@ -244,28 +264,40 @@ def __init__( name=block_name, ) else: - block = get_conv_constructor(stackwise_block_type)( - input_filters=input_filters, - output_filters=output_filters, - expand_ratio=stackwise_expansion_ratios[i], - kernel_size=stackwise_kernel_sizes[i], - strides=strides, - data_format=data_format, - se_ratio=squeeze_and_excite_ratio, - activation=activation, - dropout=dropout * block_id / blocks, - batch_norm_momentum=batch_norm_momentum, - batch_norm_epsilon=batch_norm_epsilon, - nores=nores, - name=block_name, - ) + constructor = get_conv_constructor(stackwise_block_type) + block_kwargs = { + "input_filters": input_filters, + "output_filters": output_filters, + "kernel_size": stackwise_kernel_sizes[i], + "strides": strides, + "data_format": data_format, + "activation": activation, + "dropout": dropout * block_id / blocks, + "batch_norm_momentum": batch_norm_momentum, + "batch_norm_epsilon": batch_norm_epsilon, + "nores": nores, + "name": block_name, + } + + if stackwise_block_type in ("fused", "unfused"): + block_kwargs["expand_ratio"] = ( + stackwise_expansion_ratios[i] + ) + block_kwargs["se_ratio"] = squeeze_and_excite_ratio + + if stackwise_block_type == "fused": + block_kwargs["projection_activation"] = ( + projection_activation + ) + + block = constructor(**block_kwargs) x = block(x) block_id += 1 # Build top top_filters = round_filters( - filters=1280, - width_coefficient=width_coefficient, + filters=num_features, + width_coefficient=stackwise_width_coefficients[-1], min_depth=min_depth, depth_divisor=depth_divisor, use_depth_divisor_as_min_depth=use_depth_divisor_as_min_depth, @@ -298,8 +330,8 @@ def __init__( super().__init__(inputs=image_input, outputs=x, **kwargs) # === Config === - self.width_coefficient = width_coefficient - self.depth_coefficient = depth_coefficient + self.stackwise_width_coefficients = stackwise_width_coefficients + self.stackwise_depth_coefficients = stackwise_depth_coefficients self.dropout = dropout self.depth_divisor = depth_divisor self.min_depth = min_depth @@ -329,8 +361,8 @@ def get_config(self): config = super().get_config() config.update( { - "width_coefficient": self.width_coefficient, - "depth_coefficient": self.depth_coefficient, + "stackwise_width_coefficients": self.stackwise_width_coefficients, + "stackwise_depth_coefficients": self.stackwise_depth_coefficients, "dropout": self.dropout, "depth_divisor": self.depth_divisor, "min_depth": self.min_depth, @@ -586,9 +618,11 @@ def get_conv_constructor(conv_type): return MBConvBlock elif conv_type == "fused": return FusedMBConvBlock + elif conv_type == "cba": + return CBABlock else: raise ValueError( "Expected `conv_type` to be " - "one of 'unfused', 'fused', but got " + "one of 'unfused', 'fused', 'cba', but got " f"`conv_type={conv_type}`" ) diff --git a/keras_hub/src/models/efficientnet/efficientnet_presets.py b/keras_hub/src/models/efficientnet/efficientnet_presets.py index 9e82a7ca5..202b20c0c 100644 --- a/keras_hub/src/models/efficientnet/efficientnet_presets.py +++ b/keras_hub/src/models/efficientnet/efficientnet_presets.py @@ -135,7 +135,7 @@ "path": "efficientnet", "model_card": "https://arxiv.org/abs/1905.11946", }, - "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet/1", + "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet", }, "efficientnet_em_ra2_imagenet": { "metadata": { @@ -148,7 +148,7 @@ "path": "efficientnet", "model_card": "https://arxiv.org/abs/1905.11946", }, - "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet/1", + "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet", }, "efficientnet_es_ra_imagenet": { "metadata": { @@ -161,7 +161,46 @@ "path": "efficientnet", "model_card": "https://arxiv.org/abs/1905.11946", }, - "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet/1", + "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet_b1_ft_imagenet", + }, + "efficientnet2_rw_m_agc_imagenet": { + "metadata": { + "description": ( + "EfficientNet-v2 Medium model trained on the ImageNet 1k " + "dataset with adaptive gradient clipping." + ), + "params": 53236442, + "official_name": "EfficientNet", + "path": "efficientnet", + "model_card": "https://arxiv.org/abs/2104.00298", + }, + "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_m_agc_imagenet", + }, + "efficientnet2_rw_s_ra2_imagenet": { + "metadata": { + "description": ( + "EfficientNet-v2 Small model trained on the ImageNet 1k " + "dataset with RandAugment2 recipe." + ), + "params": 23941296, + "official_name": "EfficientNet", + "path": "efficientnet", + "model_card": "https://arxiv.org/abs/2104.00298", + }, + "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_s_ra2_imagenet", + }, + "efficientnet2_rw_t_ra2_imagenet": { + "metadata": { + "description": ( + "EfficientNet-v2 Tiny model trained on the ImageNet 1k " + "dataset with RandAugment2 recipe." + ), + "params": 13649388, + "official_name": "EfficientNet", + "path": "efficientnet", + "model_card": "https://arxiv.org/abs/2104.00298", + }, + "kaggle_handle": "kaggle://keras/efficientnet/keras/efficientnet2_rw_t_ra2_imagenet", }, "efficientnet_lite0_ra_imagenet": { "metadata": { diff --git a/keras_hub/src/models/efficientnet/fusedmbconv.py b/keras_hub/src/models/efficientnet/fusedmbconv.py index 51a7f95fe..01934b762 100644 --- a/keras_hub/src/models/efficientnet/fusedmbconv.py +++ b/keras_hub/src/models/efficientnet/fusedmbconv.py @@ -2,15 +2,6 @@ BN_AXIS = 3 -CONV_KERNEL_INITIALIZER = { - "class_name": "VarianceScaling", - "config": { - "scale": 2.0, - "mode": "fan_out", - "distribution": "truncated_normal", - }, -} - class FusedMBConvBlock(keras.layers.Layer): """Implementation of the FusedMBConv block @@ -44,6 +35,8 @@ class FusedMBConvBlock(keras.layers.Layer): convolutions strides: default 1, the strides to apply to the expansion phase convolutions + data_format: str, channels_last (default) or channels_first, expects + tensors to be of shape (N, H, W, C) or (N, C, H, W) respectively se_ratio: default 0.0, The filters used in the Squeeze-Excitation phase, and are chosen as the maximum between 1 and input_filters*se_ratio batch_norm_momentum: default 0.9, the BatchNormalization momentum @@ -52,8 +45,14 @@ class FusedMBConvBlock(keras.layers.Layer): by 0 errors. activation: default "swish", the activation function used between convolution operations + projection_activation: default None, the activation function to use + after the output projection convoultion dropout: float, the optional dropout rate to apply before the output convolution, defaults to 0.2 + nores: bool, default False, forces no residual connection if True, + otherwise allows it if False. + projection_kernel_size: default 1, the kernel_size to apply to the + output projection phase convolution Returns: A tensor representing a feature map, passed through the FusedMBConv @@ -75,8 +74,10 @@ def __init__( batch_norm_momentum=0.9, batch_norm_epsilon=1e-3, activation="swish", + projection_activation=None, dropout=0.2, nores=False, + projection_kernel_size=1, **kwargs ): super().__init__(**kwargs) @@ -90,17 +91,24 @@ def __init__( self.batch_norm_momentum = batch_norm_momentum self.batch_norm_epsilon = batch_norm_epsilon self.activation = activation + self.projection_activation = projection_activation self.dropout = dropout self.nores = nores + self.projection_kernel_size = projection_kernel_size self.filters = self.input_filters * self.expand_ratio self.filters_se = max(1, int(input_filters * se_ratio)) + padding_pixels = kernel_size // 2 + self.conv1_pad = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=self.name + "expand_conv_pad", + ) self.conv1 = keras.layers.Conv2D( filters=self.filters, kernel_size=kernel_size, strides=strides, - kernel_initializer=CONV_KERNEL_INITIALIZER, - padding="same", + kernel_initializer=self._conv_kernel_initializer(), + padding="valid", data_format=data_format, use_bias=False, name=self.name + "expand_conv", @@ -121,7 +129,7 @@ def __init__( padding="same", data_format=data_format, activation=self.activation, - kernel_initializer=CONV_KERNEL_INITIALIZER, + kernel_initializer=self._conv_kernel_initializer(), name=self.name + "se_reduce", ) @@ -131,16 +139,21 @@ def __init__( padding="same", data_format=data_format, activation="sigmoid", - kernel_initializer=CONV_KERNEL_INITIALIZER, + kernel_initializer=self._conv_kernel_initializer(), name=self.name + "se_expand", ) + padding_pixels = projection_kernel_size // 2 + self.output_conv_pad = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=self.name + "project_conv_pad", + ) self.output_conv = keras.layers.Conv2D( filters=self.output_filters, - kernel_size=1 if expand_ratio != 1 else kernel_size, + kernel_size=projection_kernel_size, strides=1, - kernel_initializer=CONV_KERNEL_INITIALIZER, - padding="same", + kernel_initializer=self._conv_kernel_initializer(), + padding="valid", data_format=data_format, use_bias=False, name=self.name + "project_conv", @@ -153,6 +166,11 @@ def __init__( name=self.name + "project_bn", ) + if self.projection_activation: + self.projection_act = keras.layers.Activation( + self.projection_activation, name=self.name + "projection_act" + ) + if self.dropout: self.dropout_layer = keras.layers.Dropout( self.dropout, @@ -160,18 +178,27 @@ def __init__( name=self.name + "drop", ) + def _conv_kernel_initializer( + self, + scale=2.0, + mode="fan_out", + distribution="truncated_normal", + seed=None, + ): + return keras.initializers.VarianceScaling( + scale=scale, mode=mode, distribution=distribution, seed=seed + ) + def build(self, input_shape): if self.name is None: self.name = keras.backend.get_uid("block0") def call(self, inputs): # Expansion phase - if self.expand_ratio != 1: - x = self.conv1(inputs) - x = self.bn1(x) - x = self.act(x) - else: - x = inputs + x = self.conv1_pad(inputs) + x = self.conv1(x) + x = self.bn1(x) + x = self.act(x) # Squeeze and excite if 0 < self.se_ratio <= 1: @@ -194,10 +221,11 @@ def call(self, inputs): x = keras.layers.multiply([x, se], name=self.name + "se_excite") # Output phase: + x = self.output_conv_pad(x) x = self.output_conv(x) x = self.bn2(x) - if self.expand_ratio == 1: - x = self.act(x) + if self.expand_ratio == 1 and self.projection_activation: + x = self.projection_act(x) # Residual: if ( @@ -222,8 +250,10 @@ def get_config(self): "batch_norm_momentum": self.batch_norm_momentum, "batch_norm_epsilon": self.batch_norm_epsilon, "activation": self.activation, + "projection_activation": self.projection_activation, "dropout": self.dropout, "nores": self.nores, + "projection_kernel_size": self.projection_kernel_size, } base_config = super().get_config() diff --git a/keras_hub/src/models/efficientnet/mbconv.py b/keras_hub/src/models/efficientnet/mbconv.py index b4dc05f7c..20afab4e8 100644 --- a/keras_hub/src/models/efficientnet/mbconv.py +++ b/keras_hub/src/models/efficientnet/mbconv.py @@ -2,15 +2,6 @@ BN_AXIS = 3 -CONV_KERNEL_INITIALIZER = { - "class_name": "VarianceScaling", - "config": { - "scale": 2.0, - "mode": "fan_out", - "distribution": "truncated_normal", - }, -} - class MBConvBlock(keras.layers.Layer): def __init__( @@ -99,7 +90,7 @@ def __init__( filters=self.filters, kernel_size=1, strides=1, - kernel_initializer=CONV_KERNEL_INITIALIZER, + kernel_initializer=self._conv_kernel_initializer(), padding="same", data_format=data_format, use_bias=False, @@ -117,7 +108,7 @@ def __init__( self.depthwise = keras.layers.DepthwiseConv2D( kernel_size=self.kernel_size, strides=self.strides, - depthwise_initializer=CONV_KERNEL_INITIALIZER, + depthwise_initializer=self._conv_kernel_initializer(), padding="same", data_format=data_format, use_bias=False, @@ -137,7 +128,7 @@ def __init__( padding="same", data_format=data_format, activation=self.activation, - kernel_initializer=CONV_KERNEL_INITIALIZER, + kernel_initializer=self._conv_kernel_initializer(), name=self.name + "se_reduce", ) @@ -147,16 +138,22 @@ def __init__( padding="same", data_format=data_format, activation="sigmoid", - kernel_initializer=CONV_KERNEL_INITIALIZER, + kernel_initializer=self._conv_kernel_initializer(), name=self.name + "se_expand", ) + projection_kernel_size = 1 if expand_ratio != 1 else kernel_size + padding_pixels = projection_kernel_size // 2 + self.output_conv_pad = keras.layers.ZeroPadding2D( + padding=(padding_pixels, padding_pixels), + name=self.name + "project_conv_pad", + ) self.output_conv = keras.layers.Conv2D( filters=self.output_filters, - kernel_size=1 if expand_ratio != 1 else kernel_size, + kernel_size=projection_kernel_size, strides=1, - kernel_initializer=CONV_KERNEL_INITIALIZER, - padding="same", + kernel_initializer=self._conv_kernel_initializer(), + padding="valid", data_format=data_format, use_bias=False, name=self.name + "project_conv", @@ -176,6 +173,17 @@ def __init__( name=self.name + "drop", ) + def _conv_kernel_initializer( + self, + scale=2.0, + mode="fan_out", + distribution="truncated_normal", + seed=None, + ): + return keras.initializers.VarianceScaling( + scale=scale, mode=mode, distribution=distribution, seed=seed + ) + def build(self, input_shape): if self.name is None: self.name = keras.backend.get_uid("block0") @@ -214,6 +222,7 @@ def call(self, inputs): x = keras.layers.multiply([x, se], name=self.name + "se_excite") # Output phase + x = self.output_conv_pad(x) x = self.output_conv(x) x = self.bn3(x) diff --git a/keras_hub/src/utils/timm/convert_efficientnet.py b/keras_hub/src/utils/timm/convert_efficientnet.py index 0b2c1edba..fcedb2ecd 100644 --- a/keras_hub/src/utils/timm/convert_efficientnet.py +++ b/keras_hub/src/utils/timm/convert_efficientnet.py @@ -11,44 +11,44 @@ VARIANT_MAP = { "b0": { - "width_coefficient": 1.0, - "depth_coefficient": 1.0, + "stackwise_width_coefficients": [1.0] * 7, + "stackwise_depth_coefficients": [1.0] * 7, "stackwise_squeeze_and_excite_ratios": [0.25] * 7, }, "b1": { - "width_coefficient": 1.0, - "depth_coefficient": 1.1, + "stackwise_width_coefficients": [1.0] * 7, + "stackwise_depth_coefficients": [1.1] * 7, "stackwise_squeeze_and_excite_ratios": [0.25] * 7, }, "b2": { - "width_coefficient": 1.1, - "depth_coefficient": 1.2, + "stackwise_width_coefficients": [1.1] * 7, + "stackwise_depth_coefficients": [1.2] * 7, "stackwise_squeeze_and_excite_ratios": [0.25] * 7, }, "b3": { - "width_coefficient": 1.2, - "depth_coefficient": 1.4, + "stackwise_width_coefficients": [1.2] * 7, + "stackwise_depth_coefficients": [1.4] * 7, "stackwise_squeeze_and_excite_ratios": [0.25] * 7, }, "b4": { - "width_coefficient": 1.4, - "depth_coefficient": 1.8, + "stackwise_width_coefficients": [1.4] * 7, + "stackwise_depth_coefficients": [1.8] * 7, "stackwise_squeeze_and_excite_ratios": [0.25] * 7, }, "b5": { - "width_coefficient": 1.6, - "depth_coefficient": 2.2, + "stackwise_width_coefficients": [1.6] * 7, + "stackwise_depth_coefficients": [2.2] * 7, "stackwise_squeeze_and_excite_ratios": [0.25] * 7, }, "lite0": { - "width_coefficient": 1.0, - "depth_coefficient": 1.0, + "stackwise_width_coefficients": [1.0] * 7, + "stackwise_depth_coefficients": [1.0] * 7, "stackwise_squeeze_and_excite_ratios": [0] * 7, "activation": "relu6", }, "el": { - "width_coefficient": 1.2, - "depth_coefficient": 1.4, + "stackwise_width_coefficients": [1.2] * 6, + "stackwise_depth_coefficients": [1.4] * 6, "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5], "stackwise_num_repeats": [1, 2, 4, 5, 4, 2], "stackwise_input_filters": [32, 24, 32, 48, 96, 144], @@ -62,8 +62,8 @@ "activation": "relu", }, "em": { - "width_coefficient": 1.0, - "depth_coefficient": 1.1, + "stackwise_width_coefficients": [1.0] * 6, + "stackwise_depth_coefficients": [1.1] * 6, "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5], "stackwise_num_repeats": [1, 2, 4, 5, 4, 2], "stackwise_input_filters": [32, 24, 32, 48, 96, 144], @@ -77,8 +77,8 @@ "activation": "relu", }, "es": { - "width_coefficient": 1.0, - "depth_coefficient": 1.0, + "stackwise_width_coefficients": [1.0] * 6, + "stackwise_depth_coefficients": [1.0] * 6, "stackwise_kernel_sizes": [3, 3, 3, 5, 5, 5], "stackwise_num_repeats": [1, 2, 4, 5, 4, 2], "stackwise_input_filters": [32, 24, 32, 48, 96, 144], @@ -91,6 +91,53 @@ "stackwise_nores_option": [True] + [False] * 5, "activation": "relu", }, + "rw_m": { + "stackwise_width_coefficients": [1.2] * 6, + "stackwise_depth_coefficients": [1.2] * 4 + [1.6] * 2, + "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3], + "stackwise_num_repeats": [2, 4, 4, 6, 9, 15], + "stackwise_input_filters": [24, 24, 48, 64, 128, 160], + "stackwise_output_filters": [24, 48, 64, 128, 160, 272], + "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6], + "stackwise_strides": [1, 2, 2, 2, 1, 2], + "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25], + "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3, + "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0], + "stackwise_nores_option": [False] * 6, + "activation": "silu", + "num_features": 1792, + }, + "rw_s": { + "stackwise_width_coefficients": [1.0] * 6, + "stackwise_depth_coefficients": [1.0] * 6, + "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3], + "stackwise_num_repeats": [2, 4, 4, 6, 9, 15], + "stackwise_input_filters": [24, 24, 48, 64, 128, 160], + "stackwise_output_filters": [24, 48, 64, 128, 160, 272], + "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6], + "stackwise_strides": [1, 2, 2, 2, 1, 2], + "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25], + "stackwise_block_types": ["fused"] * 3 + ["unfused"] * 3, + "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0], + "stackwise_nores_option": [False] * 6, + "activation": "silu", + "num_features": 1792, + }, + "rw_t": { + "stackwise_width_coefficients": [0.8] * 6, + "stackwise_depth_coefficients": [0.9] * 6, + "stackwise_kernel_sizes": [3, 3, 3, 3, 3, 3], + "stackwise_num_repeats": [2, 4, 4, 6, 9, 15], + "stackwise_input_filters": [24, 24, 48, 64, 128, 160], + "stackwise_output_filters": [24, 48, 64, 128, 160, 256], + "stackwise_expansion_ratios": [1, 4, 4, 4, 6, 6], + "stackwise_strides": [1, 2, 2, 2, 1, 2], + "stackwise_squeeze_and_excite_ratios": [0, 0, 0, 0.25, 0.25, 0.25], + "stackwise_block_types": ["cba"] + ["fused"] * 2 + ["unfused"] * 3, + "stackwise_force_input_filters": [0, 0, 0, 0, 0, 0], + "stackwise_nores_option": [False] * 6, + "activation": "silu", + }, } @@ -199,15 +246,18 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): # Stages num_stacks = len(backbone.stackwise_kernel_sizes) + for stack_index in range(num_stacks): block_type = backbone.stackwise_block_types[stack_index] expansion_ratio = backbone.stackwise_expansion_ratios[stack_index] repeats = backbone.stackwise_num_repeats[stack_index] + stack_depth_coefficient = backbone.stackwise_depth_coefficients[ + stack_index + ] + + repeats = int(math.ceil(stack_depth_coefficient * repeats)) - repeats = int( - math.ceil(VARIANT_MAP[variant]["depth_coefficient"] * repeats) - ) se_ratio = VARIANT_MAP[variant]["stackwise_squeeze_and_excite_ratios"][ stack_index ] @@ -278,18 +328,17 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): fused_block_layer = backbone.get_layer(keras_block_prefix) # Initial Expansion Conv - if expansion_ratio != 1: - port_conv2d( - fused_block_layer.conv1, - hf_block_prefix + "conv_exp", - port_bias=False, - ) - conv_pw_count += 1 - port_batch_normalization( - fused_block_layer.bn1, - hf_block_prefix + f"bn{bn_count}", - ) - bn_count += 1 + port_conv2d( + fused_block_layer.conv1, + hf_block_prefix + "conv_exp", + port_bias=False, + ) + conv_pw_count += 1 + port_batch_normalization( + fused_block_layer.bn1, + hf_block_prefix + f"bn{bn_count}", + ) + bn_count += 1 if 0 < se_ratio <= 1: # Squeeze and Excite @@ -366,6 +415,20 @@ def port_batch_normalization(keras_layer, hf_weight_prefix): hf_block_prefix + f"bn{bn_count}", ) bn_count += 1 + elif block_type == "cba": + cba_block_layer = backbone.get_layer(keras_block_prefix) + # Initial Expansion Conv + port_conv2d( + cba_block_layer.conv1, + hf_block_prefix + "conv", + port_bias=False, + ) + conv_pw_count += 1 + port_batch_normalization( + cba_block_layer.bn1, + hf_block_prefix + f"bn{bn_count}", + ) + bn_count += 1 # Head/Top port_conv2d(backbone.get_layer("top_conv"), "conv_head", port_bias=False) diff --git a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py index 53f3c9029..1a3f3bd56 100644 --- a/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py +++ b/tools/checkpoint_conversion/convert_efficientnet_checkpoints.py @@ -27,6 +27,12 @@ --preset efficientnet_em_ra2_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ --preset efficientnet_es_ra_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet +python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ + --preset efficientnet2_rw_m_agc_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_el_ra_imagenet +python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ + --preset efficientnet2_rw_s_ra2_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_em_ra2_imagenet +python tools/checkpoint_conversion/convert_efficientnet_checkpoints.py \ + --preset efficientnet2_rw_t_ra2_imagenet --upload_uri kaggle://keras/efficientnet/keras/efficientnet_es_ra_imagenet """ import os @@ -56,6 +62,9 @@ "efficientnet_el_ra_imagenet": "timm/efficientnet_el.ra_in1k", "efficientnet_em_ra2_imagenet": "timm/efficientnet_em.ra2_in1k", "efficientnet_es_ra_imagenet": "timm/efficientnet_es.ra_in1k", + "efficientnet2_rw_m_agc_imagenet": "timm/efficientnetv2_rw_m.agc_in1k", + "efficientnet2_rw_s_ra2_imagenet": "timm/efficientnetv2_rw_s.ra2_in1k", + "efficientnet2_rw_t_ra2_imagenet": "timm/efficientnetv2_rw_t.ra2_in1k", } FLAGS = flags.FLAGS