From 741b8895ddb9209dbf82533427e8501b934cf6ee Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 13 Nov 2024 18:16:06 -0800 Subject: [PATCH 01/35] vit base --- keras_hub/src/models/vit/vit_backbone.py | 52 +++++ keras_hub/src/models/vit/vit_layers.py | 245 +++++++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 keras_hub/src/models/vit/vit_backbone.py create mode 100644 keras_hub/src/models/vit/vit_layers.py diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py new file mode 100644 index 000000000..65a3c7377 --- /dev/null +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -0,0 +1,52 @@ +import keras +from keras import ops + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.utils.keras_utils import standardize_data_format + + + + + + +class ViTBackbone(Backbone): + def __init__( + self, + image_shape, + patch_size, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + layer_norm_epsilon=1e-6, + data_format=None, + dtype=None, + **kwargs, + ): + data_format = standardize_data_format(data_format) + h_axis, w_axis = ( + (-3, -2) if data_format == "channels_last" else (-2, -1) + ) + # Check that the input image is well specified. + if image_shape[h_axis] is None or image_shape[w_axis] is None: + raise ValueError( + f"Image shape must have defined height and width. Found `None` " + f"at index {h_axis} (height) or {w_axis} (width). " + f"Image shape: {image_shape}" + ) + if image_shape[h_axis] != image_shape[w_axis]: + raise ValueError( + f"Image height and width must be equal. Found height: " + f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at " + f"indices {h_axis} and {w_axis} respectively. Image shape: " + f"{image_shape}" + ) + + # === Layers === + patch_and_embedding = ViTPatchingAndEmbedding( + kernel_size=(patch_size, patch_size), + strides=(patch_size, patch_size), + embed_dim=hidden_dim, + ) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py new file mode 100644 index 000000000..643798522 --- /dev/null +++ b/keras_hub/src/models/vit/vit_layers.py @@ -0,0 +1,245 @@ +import keras +from keras import ops + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class TokenLayer(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def build(self, input_shape): + self.cls_token = self.add_weight( + name="cls", + shape=(1, 1, input_shape[-1]), + initializer="zeros", + dtype=self.dtype_policy, + name="cls_token", + ) + self.built = True + + def call(self, inputs): + cls_token = self.cls_token + keras.ops.zeros_like(inputs[:, 0:1]) + out = keras.ops.concatenate([cls_token, inputs], axis=1) + + return out + + +class MLP(keras.layers.Layer): + def __init__( + self, + hidden_dim, + mlp_dim, + use_bias=True, + dropout_rate=0.0, + dtype=None, + **kwargs, + ): + super().__init__(**kwargs) + + # === config === + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.use_bias = use_bias + self.dropout_rate = dropout_rate + + def build(self, input_shape): + self.dense1 = keras.layers.Dense( + units=self.mlp_dim, + use_bias=self.use_bias, + activation="gelu", + bias_initializer=( + keras.initializers.RandomNormal(stddev=1e-6) + if self.use_bias + else None + ), + dtype=self.dtype_policy, + name="dense_1", + ) + self.dense1.build(input_shape) + self.dense2 = keras.layers.Dense( + units=self.hidden_dim, + use_bias=self.use_bias, + bias_initializer=( + 
keras.initializers.RandomNormal(stddev=1e-6) + if self.use_bias + else None + ), + dtype=self.dtype_policy, + name="dense_2", + ) + self.dense2.build((None, None, self.mlp_dim)) + self.dropout = keras.layers.Dropout(self.dropout_rate) + self.built = True + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + out = self.dropout(x) + return out + + +class ViTPatchingAndEmbedding(keras.layers.Layer): + def __init__( + self, + image_size, + patch_size, + hidden_dim, + num_channels=3, + dtype=None, + **kwargs, + ): + super().__init__(**kwargs) + num_patches = (image_size // patch_size) ** 2 + num_positions = num_patches + 1 + + # === config === + self.hidden_dim = hidden_dim + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.num_positions = num_positions + self.dtype = dtype + + def build(self, input_shape): + self.patch_embedding = keras.layers.Conv2D( + filters=self.hidden_dim, + kernel_size=self.patch_size, + strides=self.patch_size, + padding="valid", + activation=None, + dtype=self.dtype_policy, + name="patch_embedding", + ) + self.patch_embedding.build(input_shape) + self.position_embedding = keras.layers.Embedding( + self.num_positions, + self.hidden_dim, + dtype=self.dtype_policy, + name="position_embedding", + ) + self.position_embedding.build([1, self.num_positions]) + self.position_ids = ops.expand_dims( + ops.arange(self.num_positions), axis=0 + ) + self.built = True + + def call(self, input_tokens): + x = self.patch_embedding(input_tokens) + input_shape = ops.shape(x) + x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]]) + x = x + self.position_embedding(self.position_ids) + return x + + def compute_output_shape(self, input_shape): + return ( + input_shape[0], + self.num_positions, + self.hidden_dim, + ) + + +class ViTEncoderBlock(keras.layers.Layer): + def __init__( + self, + num_heads, + hidden_dim, + mlp_dim, + dropout_rate, + attention_dropout, + layer_norm_epsilon, + **kwargs, + ): + super().__init__(**kwargs) + + key_dim = hidden_dim // num_heads + + # === config === + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.key_dim = key_dim + self.mlp_dim = mlp_dim + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + # Attention block + self.layer_norm_1 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, name="ln_1" + ) + self.layer_norm_1.build(input_shape) + self.mha = keras.layers.MultiHeadAttention( + num_heads=self.num_heads, + key_dim=self.key_dim, + use_bias=False, + dropout=self.attention_dropout, + name="mha", + ) + self.mha.build(input_shape, input_shape) + self.dropout = keras.layers.Dropout(self.dropout_rate) + + # MLP block + self.layer_norm_2 = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, name="ln_2" + ) + self.layer_norm_2.build((None, None, self.hidden_dim)) + self.mlp = MLP( + hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, name="mlp" + ) + self.mlp((None, None, self.hidden_dim)) + self.built = True + + def call(self, inputs): + x = self.layer_norm_1(inputs) + x = self.mha(x, x) + x = self.dropout(x) + x = x + inputs + + y = self.layer_norm_2(x) + y = self.mlp(y) + + return x + y + + +class ViTEncoder(keras.layers.Layer): + def __init__( + self, + num_layers, + num_heads, + hidden_dim, + mlp_dim, + dropout, + attention_dropout, + layer_norm_epsilon=1e-6, + **kwargs, + ): + 
super().__init__(**kwargs) + + # === config === + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, input_shape): + layers = [] + for i in range(self.num_layers): + encoder_block = ViTEncoderBlock( + num_heads=self.num_heads, + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + attention_dropout=self.attention_dropout, + layer_norm_epsilon=self.layer_norm_epsilon, + name=f"tranformer_block_{i+1}", + ) + encoder_block.build((None, None, self.hidden_dim)) + layers.append(encoder_block) + + encoder_layers = keras.Sequential(layers) + + From 13dae08f12308675b6f0b6ab308b8c63815c559f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 14:07:10 -0800 Subject: [PATCH 02/35] Add vit backbone, classifier and preprocessor layers --- keras_hub/api/layers/__init__.py | 1 + keras_hub/api/models/__init__.py | 5 ++ keras_hub/src/models/vit/__init__.py | 0 keras_hub/src/models/vit/vit_backbone.py | 35 ++++++++--- .../src/models/vit/vit_image_classifier.py | 61 +++++++++++++++++++ .../vit/vit_image_classifier_preprocessor.py | 12 ++++ .../src/models/vit/vit_image_converter.py | 8 +++ keras_hub/src/models/vit/vit_layers.py | 54 +++++++++++----- 8 files changed, 154 insertions(+), 22 deletions(-) create mode 100644 keras_hub/src/models/vit/__init__.py create mode 100644 keras_hub/src/models/vit/vit_image_classifier.py create mode 100644 keras_hub/src/models/vit/vit_image_classifier_preprocessor.py create mode 100644 keras_hub/src/models/vit/vit_image_converter.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index 2a29cdb64..09cbb4c4d 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -62,6 +62,7 @@ SegFormerImageConverter, ) from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter from keras_hub.src.models.whisper.whisper_audio_converter import ( WhisperAudioConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index dd85a97a4..e43acc136 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -325,6 +325,11 @@ from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import ( VGGImageClassifierPreprocessor, ) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer diff --git a/keras_hub/src/models/vit/__init__.py b/keras_hub/src/models/vit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 65a3c7377..7dc74997e 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -1,14 +1,13 @@ import keras -from keras import ops +from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.vit.vit_layers import ViTEncoder +from 
keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding from keras_hub.src.utils.keras_utils import standardize_data_format - - - - +@keras_hub_export("keras_hub.models.ViTBackbone") class ViTBackbone(Backbone): def __init__( self, @@ -44,9 +43,31 @@ def __init__( f"{image_shape}" ) - # === Layers === - patch_and_embedding = ViTPatchingAndEmbedding( + # === Functional Model === + inputs = keras.layers.Input(shape=image_shape) + + x = ViTPatchingAndEmbedding( kernel_size=(patch_size, patch_size), strides=(patch_size, patch_size), embed_dim=hidden_dim, + dtype=dtype, + )(inputs) + + x = ViTEncoder( + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + layer_norm_epsilon=layer_norm_epsilon, + dtype=dtype, + )(x) + + output = x[:, 0] + + super().__init__( + inputs=inputs, + outputs=output, + **kwargs, ) diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py new file mode 100644 index 000000000..e24312e5f --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -0,0 +1,61 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) + + +@keras_hub_export("keras_hub.models.ViTImageClassifier") +class ViTImageClassifier(ImageClassifier): + backbone_cls = ViTBackbone + preprocessor_cls = ViTImageClassifierPreprocessor + + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + activation=None, + head_dtype=None, + **kwargs, + ): + head_dtype = head_dtype or backbone.dtype_policy + + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + self.output_dense = keras.layers.Dense( + num_classes, + activation=activation, + dtype=head_dtype, + name="predictions", + ) + + # === Functional Model === + inputs = self.backbone.input + x = self.backbone(inputs) + outputs = self.output_dense(x) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === Config === + self.num_classes = num_classes + self.activation = activation + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + } + ) + return config diff --git a/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py b/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py new file mode 100644 index 000000000..7e50918eb --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier_preprocessor.py @@ -0,0 +1,12 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter + + +@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor") +class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = ViTBackbone + image_converter_cls = ViTImageConverter diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py new file mode 100644 index 000000000..79d3007ea --- /dev/null +++ 
b/keras_hub/src/models/vit/vit_image_converter.py @@ -0,0 +1,8 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.vit.vit_backbone import ViTBackbone + + +@keras_hub_export("keras_hub.layers.ViTImageConverter") +class ViTImageConverter(ImageConverter): + backbone_cls = ViTBackbone diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 643798522..8212e09d3 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -1,3 +1,5 @@ +import math + import keras from keras import ops @@ -10,7 +12,6 @@ def __init__(self, **kwargs): def build(self, input_shape): self.cls_token = self.add_weight( - name="cls", shape=(1, 1, input_shape[-1]), initializer="zeros", dtype=self.dtype_policy, @@ -32,7 +33,6 @@ def __init__( mlp_dim, use_bias=True, dropout_rate=0.0, - dtype=None, **kwargs, ): super().__init__(**kwargs) @@ -69,7 +69,7 @@ def build(self, input_shape): name="dense_2", ) self.dense2.build((None, None, self.mlp_dim)) - self.dropout = keras.layers.Dropout(self.dropout_rate) + self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.built = True def call(self, inputs): @@ -86,7 +86,7 @@ def __init__( patch_size, hidden_dim, num_channels=3, - dtype=None, + data_format=None, **kwargs, ): super().__init__(**kwargs) @@ -100,7 +100,7 @@ def __init__( self.num_channels = num_channels self.num_patches = num_patches self.num_positions = num_positions - self.dtype = dtype + self.data_format = standardize_data_format(data_format) def build(self, input_shape): self.patch_embedding = keras.layers.Conv2D( @@ -109,10 +109,15 @@ def build(self, input_shape): strides=self.patch_size, padding="valid", activation=None, + kernel_initializer=keras.initializers.RandomNormal( + stddev=math.sqrt(1 / (3 * self.patch_size * self.patch_size)), + ), dtype=self.dtype_policy, + data_format=self.data_format, name="patch_embedding", ) self.patch_embedding.build(input_shape) + self.token_layer = TokenLayer(dtype=self.dtype_policy) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, @@ -125,10 +130,13 @@ def build(self, input_shape): ) self.built = True - def call(self, input_tokens): - x = self.patch_embedding(input_tokens) - input_shape = ops.shape(x) + def call(self, inputs): + x = self.patch_embedding(inputs) + input_shape = ops.shape(x) # (N, H, W, C) or (N, C, H, W) + if self.data_format == "channels_first": + x = ops.transpose(x, axes=(0, 2, 3, 1)) x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]]) + x = self.token_layer(x) x = x + self.position_embedding(self.position_ids) return x @@ -167,7 +175,9 @@ def __init__( def build(self, input_shape): # Attention block self.layer_norm_1 = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, name="ln_1" + epsilon=self.layer_norm_epsilon, + name="ln_1", + dtype=self.dtype_policy, ) self.layer_norm_1.build(input_shape) self.mha = keras.layers.MultiHeadAttention( @@ -176,17 +186,23 @@ def build(self, input_shape): use_bias=False, dropout=self.attention_dropout, name="mha", + dtype=self.dtype_policy, ) self.mha.build(input_shape, input_shape) - self.dropout = keras.layers.Dropout(self.dropout_rate) + self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") # MLP block self.layer_norm_2 = keras.layers.LayerNormalization( - epsilon=self.layer_norm_epsilon, name="ln_2" + epsilon=self.layer_norm_epsilon, + 
name="ln_2", + dtype=self.dtype_policy, ) self.layer_norm_2.build((None, None, self.hidden_dim)) self.mlp = MLP( - hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, name="mlp" + hidden_dim=self.hidden_dim, + mlp_dim=self.mlp_dim, + name="mlp", + dtype=self.dtype_policy, ) self.mlp((None, None, self.hidden_dim)) self.built = True @@ -239,7 +255,15 @@ def build(self, input_shape): ) encoder_block.build((None, None, self.hidden_dim)) layers.append(encoder_block) - - encoder_layers = keras.Sequential(layers) - + self.encoder_layers = keras.Sequential(layers, name="encoder_layers") + self.layer_norm = keras.layers.Normalization( + self.layer_norm_epsilon, name="ln" + ) + self.layer_norm.build((None, None, self.hidden_dim)) + + def call(self, inputs): + x = self.dropout(inputs) + x = self.encoder_layers(x) + x = self.layer_norm(x) + return x From b64b137bb02d6981b5905962f35fae9797e22095 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:34:29 -0800 Subject: [PATCH 03/35] update args --- keras_hub/src/models/vit/vit_backbone.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 7dc74997e..66302e128 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -17,8 +17,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout, - attention_dropout, + dropout=0.0, + attention_dropout=0.0, layer_norm_epsilon=1e-6, data_format=None, dtype=None, @@ -47,9 +47,9 @@ def __init__( inputs = keras.layers.Input(shape=image_shape) x = ViTPatchingAndEmbedding( - kernel_size=(patch_size, patch_size), - strides=(patch_size, patch_size), - embed_dim=hidden_dim, + image_size=image_shape[h_axis], + patch_size=patch_size, + hidden_dim=hidden_dim, dtype=dtype, )(inputs) From 429d6357ecca43c4d9804e27b6aa9636f50c2b58 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:44:59 -0800 Subject: [PATCH 04/35] add default args --- keras_hub/src/models/vit/vit_backbone.py | 4 ++-- keras_hub/src/models/vit/vit_layers.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 66302e128..693bec4e2 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -17,7 +17,7 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout=0.0, + dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, data_format=None, @@ -58,7 +58,7 @@ def __init__( num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim, - dropout=dropout, + dropout_rate=dropout_rate, attention_dropout=attention_dropout, layer_norm_epsilon=layer_norm_epsilon, dtype=dtype, diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 8212e09d3..bfd1c35af 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -154,9 +154,9 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout_rate, - attention_dropout, - layer_norm_epsilon, + dropout_rate=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, **kwargs, ): super().__init__(**kwargs) @@ -226,8 +226,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, - dropout, - attention_dropout, + dropout_rate=0.0, + attention_dropout=0.0, layer_norm_epsilon=1e-6, **kwargs, ): @@ -238,7 +238,7 @@ def __init__( self.num_heads = num_heads self.hidden_dim = hidden_dim self.mlp_dim = mlp_dim - 
self.dropout = dropout + self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon @@ -249,13 +249,14 @@ def build(self, input_shape): num_heads=self.num_heads, hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, + dropout_rate=self.dropout_rate, attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, name=f"tranformer_block_{i+1}", ) encoder_block.build((None, None, self.hidden_dim)) layers.append(encoder_block) - + self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.encoder_layers = keras.Sequential(layers, name="encoder_layers") self.layer_norm = keras.layers.Normalization( self.layer_norm_epsilon, name="ln" From 6d69abcda2fbea6d9b3f015f2edec69cfc4acac2 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:46:59 -0800 Subject: [PATCH 05/35] correct build method --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index bfd1c35af..279655414 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -204,7 +204,7 @@ def build(self, input_shape): name="mlp", dtype=self.dtype_policy, ) - self.mlp((None, None, self.hidden_dim)) + self.mlp.build((None, None, self.hidden_dim)) self.built = True def call(self, inputs): From 2e878846d1081857a114e372d27d7533595cd2a5 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 15:51:40 -0800 Subject: [PATCH 06/35] fix build issues --- keras_hub/src/models/vit/vit_layers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 279655414..22322d3f3 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -252,14 +252,17 @@ def build(self, input_shape): dropout_rate=self.dropout_rate, attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, name=f"tranformer_block_{i+1}", ) encoder_block.build((None, None, self.hidden_dim)) layers.append(encoder_block) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.encoder_layers = keras.Sequential(layers, name="encoder_layers") - self.layer_norm = keras.layers.Normalization( - self.layer_norm_epsilon, name="ln" + self.layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="ln", ) self.layer_norm.build((None, None, self.hidden_dim)) From bd3cce0a1e4d4d69d1f42b64b7f482a474144151 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 15 Nov 2024 16:01:09 -0800 Subject: [PATCH 07/35] fix bugs --- keras_hub/src/models/image_classifier.py | 9 +++- keras_hub/src/models/vit/vit_backbone.py | 4 +- .../src/models/vit/vit_image_classifier.py | 49 ------------------- 3 files changed, 8 insertions(+), 54 deletions(-) diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index e75e39089..ceafa76cb 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -117,10 +117,12 @@ def __init__( dtype=head_dtype, name="pooler", ) + elif pooling == "token": + self.pooler = None else: raise ValueError( "Unknown `pooling` type. Polling should be either `'avg'` or " - f"`'max'`. Received: pooling={pooling}." + f"`'max' or 'token'`. 
Received: pooling={pooling}." ) self.output_dropout = keras.layers.Dropout( dropout, @@ -137,7 +139,10 @@ def __init__( # === Functional Model === inputs = self.backbone.input x = self.backbone(inputs) - x = self.pooler(x) + if pooling == "token": # used for Vision Transformer(ViT) + x = x[:, 0] + else: + x = self.pooler(x) x = self.output_dropout(x) outputs = self.output_dense(x) super().__init__( diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 693bec4e2..33e8b610b 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -53,7 +53,7 @@ def __init__( dtype=dtype, )(inputs) - x = ViTEncoder( + output = ViTEncoder( num_layers=num_layers, num_heads=num_heads, hidden_dim=hidden_dim, @@ -64,8 +64,6 @@ def __init__( dtype=dtype, )(x) - output = x[:, 0] - super().__init__( inputs=inputs, outputs=output, diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index e24312e5f..1aab26c0d 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -1,5 +1,3 @@ -import keras - from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.vit.vit_backbone import ViTBackbone @@ -12,50 +10,3 @@ class ViTImageClassifier(ImageClassifier): backbone_cls = ViTBackbone preprocessor_cls = ViTImageClassifierPreprocessor - - def __init__( - self, - backbone, - num_classes, - preprocessor=None, - activation=None, - head_dtype=None, - **kwargs, - ): - head_dtype = head_dtype or backbone.dtype_policy - - # === Layers === - self.backbone = backbone - self.preprocessor = preprocessor - - self.output_dense = keras.layers.Dense( - num_classes, - activation=activation, - dtype=head_dtype, - name="predictions", - ) - - # === Functional Model === - inputs = self.backbone.input - x = self.backbone(inputs) - outputs = self.output_dense(x) - super().__init__( - inputs=inputs, - outputs=outputs, - **kwargs, - ) - - # === Config === - self.num_classes = num_classes - self.activation = activation - - def get_config(self): - # Backbone serialized in `super` - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "pooling": self.pooling, - } - ) - return config From 4232a0659656ee0912cfb44e455785238240b334 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 12:23:00 -0800 Subject: [PATCH 08/35] Update backbone args and configs --- keras_hub/src/models/vit/vit_backbone.py | 38 +++++++++++- keras_hub/src/models/vit/vit_layers.py | 76 ++++++++++++++++++++---- 2 files changed, 102 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 33e8b610b..19840de2d 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -25,8 +25,8 @@ def __init__( **kwargs, ): data_format = standardize_data_format(data_format) - h_axis, w_axis = ( - (-3, -2) if data_format == "channels_last" else (-2, -1) + h_axis, w_axis, channels_axis = ( + (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3) ) # Check that the input image is well specified. 
if image_shape[h_axis] is None or image_shape[w_axis] is None: @@ -43,6 +43,8 @@ def __init__( f"{image_shape}" ) + num_channels = image_shape[channels_axis] + # === Functional Model === inputs = keras.layers.Input(shape=image_shape) @@ -50,7 +52,9 @@ def __init__( image_size=image_shape[h_axis], patch_size=patch_size, hidden_dim=hidden_dim, + num_channels=num_channels, dtype=dtype, + name="vit_patching_and_embedding", )(inputs) output = ViTEncoder( @@ -62,6 +66,7 @@ def __init__( attention_dropout=attention_dropout, layer_norm_epsilon=layer_norm_epsilon, dtype=dtype, + name="vit_encoder", )(x) super().__init__( @@ -69,3 +74,32 @@ def __init__( outputs=output, **kwargs, ) + + # === Config === + self.image_shape = image_shape + self.patch_size = patch_size + self.num_layers = num_layers + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.mlp_dim = mlp_dim + self.dropout_rate = dropout_rate + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.data_format = data_format + + def get_config(self): + config = super().get_config() + config.update( + { + "image_shape": self.image_shape, + "patch_size": self.patch_size, + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 22322d3f3..5afc1f46d 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -1,5 +1,3 @@ -import math - import keras from keras import ops @@ -37,7 +35,7 @@ def __init__( ): super().__init__(**kwargs) - # === config === + # === Config === self.hidden_dim = hidden_dim self.mlp_dim = mlp_dim self.use_bias = use_bias @@ -93,10 +91,10 @@ def __init__( num_patches = (image_size // patch_size) ** 2 num_positions = num_patches + 1 - # === config === - self.hidden_dim = hidden_dim + # === Config === self.image_size = image_size self.patch_size = patch_size + self.hidden_dim = hidden_dim self.num_channels = num_channels self.num_patches = num_patches self.num_positions = num_positions @@ -109,9 +107,6 @@ def build(self, input_shape): strides=self.patch_size, padding="valid", activation=None, - kernel_initializer=keras.initializers.RandomNormal( - stddev=math.sqrt(1 / (3 * self.patch_size * self.patch_size)), - ), dtype=self.dtype_policy, data_format=self.data_format, name="patch_embedding", @@ -122,6 +117,7 @@ def build(self, input_shape): self.num_positions, self.hidden_dim, dtype=self.dtype_policy, + embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) @@ -147,6 +143,20 @@ def compute_output_shape(self, input_shape): self.hidden_dim, ) + def get_config(self): + config = super().get_config() + config.update( + { + "image_size": self.image_size, + "patch_size": self.patch_size, + "hidden_dim": self.hidden_dim, + "num_channels": self.num_channels, + "num_patches": self.num_patches, + "num_positions": self.num_positions, + } + ) + return config + class ViTEncoderBlock(keras.layers.Layer): def __init__( @@ -154,6 +164,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, + use_mha_bias=True, + use_mlp_bias=True, dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, @@ -163,11 +175,13 @@ def __init__( key_dim = 
hidden_dim // num_heads - # === config === + # === Config === self.num_heads = num_heads self.hidden_dim = hidden_dim self.key_dim = key_dim self.mlp_dim = mlp_dim + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon @@ -183,7 +197,7 @@ def build(self, input_shape): self.mha = keras.layers.MultiHeadAttention( num_heads=self.num_heads, key_dim=self.key_dim, - use_bias=False, + use_bias=self.use_mha_bias, dropout=self.attention_dropout, name="mha", dtype=self.dtype_policy, @@ -201,6 +215,7 @@ def build(self, input_shape): self.mlp = MLP( hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, + use_bias=self.use_mlp_bias, name="mlp", dtype=self.dtype_policy, ) @@ -218,6 +233,23 @@ def call(self, inputs): return x + y + def get_config(self): + config = super().get_config() + config.update( + { + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "key_dim": self.key_dim, + "mlp_dim": self.mlp_dim, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config + class ViTEncoder(keras.layers.Layer): def __init__( @@ -226,6 +258,8 @@ def __init__( num_heads, hidden_dim, mlp_dim, + use_mha_bias=True, + use_mlp_bias=True, dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, @@ -238,6 +272,8 @@ def __init__( self.num_heads = num_heads self.hidden_dim = hidden_dim self.mlp_dim = mlp_dim + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon @@ -250,6 +286,8 @@ def build(self, input_shape): hidden_dim=self.hidden_dim, mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, + use_mha_bias=self.use_mha_bias, + use_mlp_bias=self.use_mlp_bias, attention_dropout=self.attention_dropout, layer_norm_epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, @@ -265,9 +303,27 @@ def build(self, input_shape): name="ln", ) self.layer_norm.build((None, None, self.hidden_dim)) + self.built = True def call(self, inputs): x = self.dropout(inputs) x = self.encoder_layers(x) x = self.layer_norm(x) return x + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "num_heads": self.num_heads, + "hidden_dim": self.hidden_dim, + "mlp_dim": self.mlp_dim, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, + "dropout_rate": self.dropout_rate, + "attention_dropout": self.attention_dropout, + "layer_norm_epsilon": self.layer_norm_epsilon, + } + ) + return config From 32b08c5bfecde6ad8e0c1c06a20859f2c1902ed9 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 12:32:27 -0800 Subject: [PATCH 09/35] correct position ids dtype --- keras_hub/src/models/vit/vit_layers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 5afc1f46d..af782d664 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -121,8 +121,15 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = ops.expand_dims( - ops.arange(self.num_positions), axis=0 + self.position_ids = 
self.add_weight( + shape=(1, self.num_positions), + initializer="zeros", + # Let the backend determine the int dtype. For example, tf + # requires int64 for correct device placement, whereas jax and torch + # don't. + dtype=int, + trainable=False, + name="position_ids", ) self.built = True From cc938c68839c7047c7dfe4fb6317089bac757024 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:33:09 -0800 Subject: [PATCH 10/35] build token layer --- keras_hub/src/models/vit/vit_layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index af782d664..46cad86e6 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -113,6 +113,7 @@ def build(self, input_shape): ) self.patch_embedding.build(input_shape) self.token_layer = TokenLayer(dtype=self.dtype_policy) + self.build(input_shape) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, From 78812ded1f54514c437c32f8030a3d00c4f56cb6 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:33:41 -0800 Subject: [PATCH 11/35] token layer build --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 46cad86e6..5da27df65 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -113,7 +113,7 @@ def build(self, input_shape): ) self.patch_embedding.build(input_shape) self.token_layer = TokenLayer(dtype=self.dtype_policy) - self.build(input_shape) + self.token_layer.build(input_shape) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, From 8a2046525118a5ea5d2a1c26b5e2f3b66a752362 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:38:27 -0800 Subject: [PATCH 12/35] assign correct dtype to TokenLayer --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 5da27df65..6d60396bf 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -12,7 +12,7 @@ def build(self, input_shape): self.cls_token = self.add_weight( shape=(1, 1, input_shape[-1]), initializer="zeros", - dtype=self.dtype_policy, + dtype=self.dtype, name="cls_token", ) self.built = True From de754cca099c7858b50ee92cca4f4d387b3becef Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 13:53:34 -0800 Subject: [PATCH 13/35] fix build shape of token layer --- keras_hub/src/models/vit/vit_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 6d60396bf..237a93ffd 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -113,7 +113,7 @@ def build(self, input_shape): ) self.patch_embedding.build(input_shape) self.token_layer = TokenLayer(dtype=self.dtype_policy) - self.token_layer.build(input_shape) + self.token_layer.build((None, None, self.hidden_dim)) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, From 84ba8968617c87d92e9713779fd429cc2680da46 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 15:09:06 -0800 Subject: [PATCH 14/35] correct mlp dens var names --- 
keras_hub/src/models/vit/vit_layers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 237a93ffd..26ae7c5ce 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -42,7 +42,7 @@ def __init__( self.dropout_rate = dropout_rate def build(self, input_shape): - self.dense1 = keras.layers.Dense( + self.dense_1 = keras.layers.Dense( units=self.mlp_dim, use_bias=self.use_bias, activation="gelu", @@ -54,8 +54,8 @@ def build(self, input_shape): dtype=self.dtype_policy, name="dense_1", ) - self.dense1.build(input_shape) - self.dense2 = keras.layers.Dense( + self.dense_1.build(input_shape) + self.dense_2 = keras.layers.Dense( units=self.hidden_dim, use_bias=self.use_bias, bias_initializer=( @@ -66,13 +66,13 @@ def build(self, input_shape): dtype=self.dtype_policy, name="dense_2", ) - self.dense2.build((None, None, self.mlp_dim)) + self.dense_2.build((None, None, self.mlp_dim)) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.built = True def call(self, inputs): - x = self.dense1(inputs) - x = self.dense2(x) + x = self.dense_1(inputs) + x = self.dense_2(x) out = self.dropout(x) return out From 7a70e161bedca0233040d9145a190b761617f471 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 15:44:02 -0800 Subject: [PATCH 15/35] use default norm mean and std as per hugging face config --- .../src/models/vit/vit_image_converter.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py index 79d3007ea..705c8a8b4 100644 --- a/keras_hub/src/models/vit/vit_image_converter.py +++ b/keras_hub/src/models/vit/vit_image_converter.py @@ -1,8 +1,37 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.layers.preprocessing.image_converter import ImageConverter from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.utils.tensor_utils import preprocessing_function @keras_hub_export("keras_hub.layers.ViTImageConverter") class ViTImageConverter(ImageConverter): backbone_cls = ViTBackbone + + def __init__( + self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs + ): + super().__init__(**kwargs) + self.norm_mean = norm_mean + self.norm_std = norm_std + + @preprocessing_function + def call(self, inputs): + x = super().call(inputs) + # By default normalize using imagenet mean and std + if self.norm_mean: + x = x - self._expand_non_channel_dims(self.norm_mean, x) + if self.norm_std: + x = x / self._expand_non_channel_dims(self.norm_std, x) + + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "norm_mean": self.norm_mean, + "norm_std": self.norm_std, + } + ) + return config From 81e3021fc6284b01338f969e736185e3f4e57964 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 18 Nov 2024 16:36:16 -0800 Subject: [PATCH 16/35] correct position_ids --- keras_hub/src/models/vit/vit_layers.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 26ae7c5ce..4be066288 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -122,15 +122,8 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = 
self.add_weight( - shape=(1, self.num_positions), - initializer="zeros", - # Let the backend determine the int dtype. For example, tf - # requires int64 for correct device placement, whereas jax and torch - # don't. - dtype=int, - trainable=False, - name="position_ids", + self.position_ids = keras.ops.expand_dims( + keras.ops.arange(self.num_positions), axis=0 ) self.built = True From d3061d6210c55639a9b9d090995641011c547d3f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 19 Nov 2024 11:18:24 -0800 Subject: [PATCH 17/35] remove separate token layer --- keras_hub/src/models/vit/vit_layers.py | 61 +++++++++++++------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 4be066288..3ebf10a9f 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -4,26 +4,6 @@ from keras_hub.src.utils.keras_utils import standardize_data_format -class TokenLayer(keras.layers.Layer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def build(self, input_shape): - self.cls_token = self.add_weight( - shape=(1, 1, input_shape[-1]), - initializer="zeros", - dtype=self.dtype, - name="cls_token", - ) - self.built = True - - def call(self, inputs): - cls_token = self.cls_token + keras.ops.zeros_like(inputs[:, 0:1]) - out = keras.ops.concatenate([cls_token, inputs], axis=1) - - return out - - class MLP(keras.layers.Layer): def __init__( self, @@ -101,6 +81,12 @@ def __init__( self.data_format = standardize_data_format(data_format) def build(self, input_shape): + self.class_token = self.add_weight( + shape=(self.hidden_dim,), + initializer="random_normal", + dtype=self.variable_dtype, + name="class_token", + ) self.patch_embedding = keras.layers.Conv2D( filters=self.hidden_dim, kernel_size=self.patch_size, @@ -112,8 +98,6 @@ def build(self, input_shape): name="patch_embedding", ) self.patch_embedding.build(input_shape) - self.token_layer = TokenLayer(dtype=self.dtype_policy) - self.token_layer.build((None, None, self.hidden_dim)) self.position_embedding = keras.layers.Embedding( self.num_positions, self.hidden_dim, @@ -122,20 +106,35 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = keras.ops.expand_dims( - keras.ops.arange(self.num_positions), axis=0 + self.position_ids = self.add_weight( + shape=(1, self.num_positions), + initializer="zeros", + # Let the backend determine the int dtype. For example, tf + # requires int64 for correct device placement, whereas jax and torch + # don't. 
+ dtype=int, + trainable=False, + name="position_ids", ) self.built = True def call(self, inputs): - x = self.patch_embedding(inputs) - input_shape = ops.shape(x) # (N, H, W, C) or (N, C, H, W) + patch_embeddings = self.patch_embedding(inputs) + input_shape = ops.shape( + patch_embeddings + ) # (N, H, W, C) or (N, C, H, W) if self.data_format == "channels_first": - x = ops.transpose(x, axes=(0, 2, 3, 1)) - x = ops.reshape(x, [input_shape[0], -1, input_shape[-1]]) - x = self.token_layer(x) - x = x + self.position_embedding(self.position_ids) - return x + patch_embeddings = ops.transpose( + patch_embeddings, axes=(0, 2, 3, 1) + ) + patch_embeddings = ops.reshape( + patch_embeddings, [input_shape[0], -1, input_shape[-1]] + ) + class_token = ops.expand_dims(self.class_token, axis=(0, 1)) + class_token = ops.tile(class_token, (input_shape[0], 1, 1)) + position_embeddings = self.position_embedding(self.position_ids) + embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) + return ops.add(embeddings, position_embeddings) def compute_output_shape(self, input_shape): return ( From 618e163cb5b8bcc6d36fcc5e44c89acbd4560d48 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Tue, 19 Nov 2024 11:25:17 -0800 Subject: [PATCH 18/35] correct position ids --- keras_hub/src/models/vit/vit_layers.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 3ebf10a9f..7ea04d979 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -106,15 +106,8 @@ def build(self, input_shape): name="position_embedding", ) self.position_embedding.build([1, self.num_positions]) - self.position_ids = self.add_weight( - shape=(1, self.num_positions), - initializer="zeros", - # Let the backend determine the int dtype. For example, tf - # requires int64 for correct device placement, whereas jax and torch - # don't. 
- dtype=int, - trainable=False, - name="position_ids", + self.position_ids = keras.ops.expand_dims( + keras.ops.arange(self.num_positions), axis=0 ) self.built = True From 2338637658d1faef98e69601788bdea931ce1966 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:27:59 -0800 Subject: [PATCH 19/35] Checkpoint conversion script and minor changes --- keras_hub/src/models/vit/vit_backbone.py | 4 + .../src/models/vit/vit_image_classifier.py | 3 + keras_hub/src/models/vit/vit_layers.py | 16 +- .../convert_vit_checkpoints.py | 321 ++++++++++++++++++ 4 files changed, 337 insertions(+), 7 deletions(-) create mode 100644 tools/checkpoint_conversion/convert_vit_checkpoints.py diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 19840de2d..027be5aa2 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -20,6 +20,8 @@ def __init__( dropout_rate=0.0, attention_dropout=0.0, layer_norm_epsilon=1e-6, + use_mha_bias=True, + use_mlp_bias=True, data_format=None, dtype=None, **kwargs, @@ -65,6 +67,8 @@ def __init__( dropout_rate=dropout_rate, attention_dropout=attention_dropout, layer_norm_epsilon=layer_norm_epsilon, + use_mha_bias=use_mha_bias, + use_mlp_bias=use_mlp_bias, dtype=dtype, name="vit_encoder", )(x) diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index 1aab26c0d..579538b6b 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -10,3 +10,6 @@ class ViTImageClassifier(ImageClassifier): backbone_cls = ViTBackbone preprocessor_cls = ViTImageClassifierPreprocessor + + def __init__(self, pooling="token", **kwargs): + super().__init__(pooling=pooling, **kwargs) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 7ea04d979..eef58ca49 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -82,7 +82,11 @@ def __init__( def build(self, input_shape): self.class_token = self.add_weight( - shape=(self.hidden_dim,), + shape=( + 1, + 1, + self.hidden_dim, + ), initializer="random_normal", dtype=self.variable_dtype, name="class_token", @@ -105,7 +109,7 @@ def build(self, input_shape): embeddings_initializer=keras.initializers.RandomNormal(stddev=0.02), name="position_embedding", ) - self.position_embedding.build([1, self.num_positions]) + self.position_embedding.build((1, self.num_positions)) self.position_ids = keras.ops.expand_dims( keras.ops.arange(self.num_positions), axis=0 ) @@ -123,8 +127,7 @@ def call(self, inputs): patch_embeddings = ops.reshape( patch_embeddings, [input_shape[0], -1, input_shape[-1]] ) - class_token = ops.expand_dims(self.class_token, axis=(0, 1)) - class_token = ops.tile(class_token, (input_shape[0], 1, 1)) + class_token = ops.tile(self.class_token, (input_shape[0], 1, 1)) position_embeddings = self.position_embedding(self.position_ids) embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) return ops.add(embeddings, position_embeddings) @@ -272,7 +275,7 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon def build(self, input_shape): - layers = [] + self.encoder_layers = keras.Sequential(name="encoder_layers") for i in range(self.num_layers): encoder_block = ViTEncoderBlock( num_heads=self.num_heads, @@ -287,9 +290,8 @@ def build(self, input_shape): name=f"tranformer_block_{i+1}", ) encoder_block.build((None, None, 
self.hidden_dim)) - layers.append(encoder_block) + self.encoder_layers.add(encoder_block) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") - self.encoder_layers = keras.Sequential(layers, name="encoder_layers") self.layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py new file mode 100644 index 000000000..109f80212 --- /dev/null +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -0,0 +1,321 @@ +"""Convert ViT checkpoints. + +export KAGGLE_USERNAME=XXX +export KAGGLE_KEY=XXX + +python tools/checkpoint_conversion/convert_vit_checkpoints.py \ + --preset vit_base_patch16_224 +""" + +import os +import shutil + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image +from transformers import ViTForImageClassification +from transformers import ViTImageProcessor + +import keras_hub +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "vit_base_patch16_224": "google/vit-base-patch16-224", + "vit_base_patch16_384": "google/vit-base-patch16-384", + "vit_large_patch16_224": "google/vit-large-patch16-224", + "vit_large_patch16_384": "google/vit-large-patch16-384", +} + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + +flags.DEFINE_string( + "backbone_conversion_only", + False, + "Set to `True` when you want to convert only backbone when classification " + "head weights are not available", +) + + +def convert_model(hf_model): + config = hf_model.config.to_dict() + image_size = config["image_size"] + backbone = ViTBackbone( + image_shape=(image_size, image_size, 3), + patch_size=config["patch_size"], + num_layers=config["num_hidden_layers"], + num_heads=config["num_heads"], + hidden_dim=config["hidden_size"], + mlp_dim=config["intermediate_size"], + dropout_rate=config["hidden_dropout_prob"], + attention_dropout=config["attention_probs_dropout_prob"], + use_mha_bias=config["qkv_bias"], + ) + if FLAGS.backbone_conversion_only: + return backbone + + return ViTImageClassifier( + backbone=backbone, + num_classes=1000, # num classes in ImageNet + ) + + +def convert_weights(keras_hub_model, hf_model): + state_dict = hf_model.state_dict() + state_dict.update(hf_model.named_buffers()) + + # Helper functions. 
+ def port_weights(keras_variable, weight_key, hook_fn=None): + torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + keras_variable.assign(torch_tensor) + + def port_ln(keras_variable, weight_key): + port_weights(keras_variable.gamma, f"{weight_key}.weight") + port_weights(keras_variable.beta, f"{weight_key}.bias") + + def port_dense(keras_variable, weight_key): + port_weights( + keras_variable.kernel, + f"{weight_key}.weight", + hook_fn=lambda x, _: x.T, + ) + if keras_variable.bias is not None: + port_weights(keras_variable.bias, f"{weight_key}.bias") + + def port_mha(keras_variable, weight_key, num_heads, hidden_dim): + # query + port_weights( + keras_variable.query_dense.kernel, + f"{weight_key}.attention.query.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.query_dense.bias, + f"{weight_key}.attention.query.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # key + port_weights( + keras_variable.key_dense.kernel, + f"{weight_key}.attention.key.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.key_dense.bias, + f"{weight_key}.attention.key.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # value + port_weights( + keras_variable.value_dense.kernel, + f"{weight_key}.attention.value.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + port_weights( + keras_variable.value_dense.bias, + f"{weight_key}.attention.value.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # output + port_weights( + keras_variable.output_dense.kernel, + f"{weight_key}.output.dense.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (num_heads, hidden_dim // num_heads, hidden_dim) + ), + ) + port_weights( + keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias" + ) + + port_weights( + keras_hub_model.backbone.layers[1].patch_embedding.kernel, + "vit.embeddings.patch_embeddings.projection.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + port_weights( + keras_hub_model.backbone.layers[1].patch_embedding.bias, + "vit.embeddings.patch_embeddings.projection.bias", + ) + + port_weights( + keras_hub_model.backbone.layers[1].class_token, + "vit.embeddings.cls_token", + ) + + port_weights( + keras_hub_model.backbone.layers[1].position_embedding.embeddings, + "vit.embeddings.position_embeddings", + hook_fn=lambda x, _: x[0], + ) + encoder_layers = keras_hub_model.backbone.layers[2].encoder_layers + for i, encoder_block in enumerate(encoder_layers): + prefix = "vit.encoder.layer" + num_heads = encoder_block.num_heads + hidden_dim = encoder_block.hidden_dim + + port_mha( + encoder_block.mha, + f"{prefix}.{i}.attention", + num_heads, + hidden_dim, + ) + port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before") + port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after") + + port_dense( + encoder_block.mlp.dense_1, f"{prefix}.{i}.intermediate.dense" + ) + port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") + + port_ln(keras_hub_model.backbone.layers[2].layer_norm, "vit.layernorm") + if not FLAGS.backbone_conversion_only: + port_dense(keras_hub_model.output_dense, "classifier") + + +def 
convert_image_converter(hf_image_processor): + config = hf_image_processor.to_dict() + image_size = (config["size"]["height"], config["size"]["width"]) + std = config["image_std"] + mean = config["image_mean"] + return ViTImageConverter( + image_size=image_size, + scale=config["rescale_factor"], + norm_mean=mean, + norm_std=std, + interpolation="bilinear", # ViT defaults to bilinear resampling. + ) + + +def validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_image_processor, +): + file = keras.utils.get_file( + origin=("http://images.cocodataset.org/val2017/000000039769.jpg") + ) + image = Image.open(file) + + # Preprocess with hf. + hf_inputs = hf_image_processor( + image, + return_tensors="pt", + ) + hf_preprocessed = hf_inputs["pixel_values"].detach().cpu().numpy() + + # Preprocess with keras. + images = np.expand_dims(np.array(image).astype("float32"), axis=0) + images = np.concatenate([images, images], axis=0) + images = keras_image_converter(images) + keras_preprocessed = keras.ops.convert_to_numpy(images) + + # Call with hf. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + hf_inputs["pixel_values"] = torch.from_numpy( + keras.ops.convert_to_numpy( + keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) + ) + ) + hf_outputs = hf_model(**hf_inputs) + hf_vision_logits = hf_outputs.logits.detach().cpu().numpy() + + # Call with keras. + keras_outputs = keras_model(keras_preprocessed) + keras_vision_logits = keras.ops.convert_to_numpy(keras_outputs) + + print("🔶 Keras output:", keras_vision_logits[0, :10]) + print("🔶 HF output:", hf_vision_logits[0, :10]) + modeling_diff = np.mean(np.abs(keras_vision_logits - hf_vision_logits)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean( + np.abs(keras_preprocessed - np.transpose(hf_preprocessed, (0, 2, 3, 1))) + ) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + print(f"🏃 Coverting {preset}") + + # Load huggingface model. 
+ hf_model = ViTForImageClassification.from_pretrained(hf_preset) + hf_preprocessor = ViTImageProcessor.from_pretrained(hf_preset) + hf_model.eval() + + keras_model = convert_model(hf_model) + keras_image_converter = convert_image_converter(hf_preprocessor) + keras_image_preprocessor = ViTImageClassifierPreprocessor( + image_converter=keras_image_converter + ) + print("✅ KerasHub model loaded.") + + convert_weights(keras_model, hf_model) + print("✅ Weights converted.") + + validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_preprocessor, + ) + print("✅ Output validated.") + + keras_model.save_to_preset(f"./{preset}") + keras_image_preprocessor.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}.") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main) From 95e58681401f28463b7876d98a36eec9ad24c2dd Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:38:46 -0800 Subject: [PATCH 20/35] correct flag type --- tools/checkpoint_conversion/convert_vit_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index 109f80212..393bad956 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -49,7 +49,7 @@ required=False, ) -flags.DEFINE_string( +flags.DEFINE_bool( "backbone_conversion_only", False, "Set to `True` when you want to convert only backbone when classification " From 9d2e5bdd73699eb7a957ddd6940e4abc040044d7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:40:51 -0800 Subject: [PATCH 21/35] correct key name --- tools/checkpoint_conversion/convert_vit_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index 393bad956..d7282d2df 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -64,7 +64,7 @@ def convert_model(hf_model): image_shape=(image_size, image_size, 3), patch_size=config["patch_size"], num_layers=config["num_hidden_layers"], - num_heads=config["num_heads"], + num_heads=config["num_attention_heads"], hidden_dim=config["hidden_size"], mlp_dim=config["intermediate_size"], dropout_rate=config["hidden_dropout_prob"], From ac7d1d3d1f830fd90127d179210ddfd5bf90741c Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 10:45:05 -0800 Subject: [PATCH 22/35] use flat list later we can extract in between layers if needed --- keras_hub/src/models/vit/vit_layers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index eef58ca49..449dca0f2 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -275,7 +275,7 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon def build(self, input_shape): - self.encoder_layers = keras.Sequential(name="encoder_layers") + self.encoder_layers = [] for i in range(self.num_layers): encoder_block = ViTEncoderBlock( num_heads=self.num_heads, @@ -290,7 +290,7 @@ def build(self, input_shape): name=f"tranformer_block_{i+1}", ) 
encoder_block.build((None, None, self.hidden_dim)) - self.encoder_layers.add(encoder_block) + self.encoder_layers.append(encoder_block) self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") self.layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, @@ -302,7 +302,8 @@ def build(self, input_shape): def call(self, inputs): x = self.dropout(inputs) - x = self.encoder_layers(x) + for i in range(self.num_layers): + x = self.encoder_layers[i](x) x = self.layer_norm(x) return x From 8065c01b293545598edda709f44c6df4a8729145 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 14:58:29 -0800 Subject: [PATCH 23/35] Add test cases and correct dtype polciy for model --- keras_hub/src/models/vit/vit_backbone.py | 1 + keras_hub/src/models/vit/vit_backbone_test.py | 36 ++++++++++++ .../models/vit/vit_image_classifier_test.py | 57 +++++++++++++++++++ keras_hub/src/models/vit/vit_layers.py | 12 +++- .../convert_vit_checkpoints.py | 16 +----- 5 files changed, 106 insertions(+), 16 deletions(-) create mode 100644 keras_hub/src/models/vit/vit_backbone_test.py create mode 100644 keras_hub/src/models/vit/vit_image_classifier_test.py diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 027be5aa2..fa6b18ec7 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -76,6 +76,7 @@ def __init__( super().__init__( inputs=inputs, outputs=output, + dtype=dtype, **kwargs, ) diff --git a/keras_hub/src/models/vit/vit_backbone_test.py b/keras_hub/src/models/vit/vit_backbone_test.py new file mode 100644 index 000000000..9a0368402 --- /dev/null +++ b/keras_hub/src/models/vit/vit_backbone_test.py @@ -0,0 +1,36 @@ +import pytest +from keras import ops + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.tests.test_case import TestCase + + +class ViTBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "image_shape": (28, 28, 3), + "patch_size": 4, + "num_layers": 3, + "hidden_dim": 48, + "num_heads": 6, + "mlp_dim": 48 * 4, + "use_mha_bias": True, + } + self.input_size = 28 + self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=ViTBackbone, + init_kwargs={**self.init_kwargs}, + input_data=self.input_data, + expected_output_shape=(2, 50, 48), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py new file mode 100644 index 000000000..a2e608594 --- /dev/null +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -0,0 +1,57 @@ +import pytest +from keras import ops + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier +from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( + ViTImageClassifierPreprocessor, +) +from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter +from keras_hub.src.tests.test_case import TestCase + + +class ViTImageClassifierTest(TestCase): + def setUp(self): + self.images = ops.ones((2, 28, 28, 3)) + self.labels = [0, 1] + self.backbone = ViTBackbone( + image_shape=(28, 28, 3), + patch_size=4, + num_layers=3, + num_heads=6, + hidden_dim=48, + mlp_dim=48 * 4, 
+ ) + image_converter = ViTImageConverter( + image_size=(28, 28), + scale=1 / 255.0, + ) + preprocessor = ViTImageClassifierPreprocessor( + image_converter=image_converter + ) + self.init_kwargs = { + "backbone": self.backbone, + "num_classes": 2, + "preprocessor": preprocessor, + } + self.train_data = (self.images, self.labels) + + def test_classifier_basics(self): + self.run_task_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 2), + ) + + def test_head_dtype(self): + model = ViTImageClassifier(**self.init_kwargs, head_dtype="bfloat16") + self.assertEqual(model.output_dense.compute_dtype, "bfloat16") + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=ViTImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 449dca0f2..36a4b9e8b 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -47,7 +47,9 @@ def build(self, input_shape): name="dense_2", ) self.dense_2.build((None, None, self.mlp_dim)) - self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) self.built = True def call(self, inputs): @@ -199,7 +201,9 @@ def build(self, input_shape): dtype=self.dtype_policy, ) self.mha.build(input_shape, input_shape) - self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) # MLP block self.layer_norm_2 = keras.layers.LayerNormalization( @@ -291,7 +295,9 @@ def build(self, input_shape): ) encoder_block.build((None, None, self.hidden_dim)) self.encoder_layers.append(encoder_block) - self.dropout = keras.layers.Dropout(self.dropout_rate, name="dropout") + self.dropout = keras.layers.Dropout( + self.dropout_rate, dtype=self.dtype_policy, name="dropout" + ) self.layer_norm = keras.layers.LayerNormalization( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index d7282d2df..4868a7b04 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -49,13 +49,6 @@ required=False, ) -flags.DEFINE_bool( - "backbone_conversion_only", - False, - "Set to `True` when you want to convert only backbone when classification " - "head weights are not available", -) - def convert_model(hf_model): config = hf_model.config.to_dict() @@ -71,8 +64,6 @@ def convert_model(hf_model): attention_dropout=config["attention_probs_dropout_prob"], use_mha_bias=config["qkv_bias"], ) - if FLAGS.backbone_conversion_only: - return backbone return ViTImageClassifier( backbone=backbone, @@ -204,8 +195,7 @@ def port_mha(keras_variable, weight_key, num_heads, hidden_dim): port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") port_ln(keras_hub_model.backbone.layers[2].layer_norm, "vit.layernorm") - if not FLAGS.backbone_conversion_only: - port_dense(keras_hub_model.output_dense, "classifier") + port_dense(keras_hub_model.output_dense, "classifier") def convert_image_converter(hf_image_processor): @@ -306,9 +296,9 @@ def main(_): hf_preprocessor, ) print("✅ Output validated.") - + keras_model.preprocessor = 
keras_image_preprocessor keras_model.save_to_preset(f"./{preset}") - keras_image_preprocessor.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}.") upload_uri = FLAGS.upload_uri From a8be82408f26c2dfff7d9964084cbde9bcac8571 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 21 Nov 2024 15:46:26 -0800 Subject: [PATCH 24/35] add proper docstrings --- keras_hub/src/models/vit/vit_backbone.py | 36 +++++++++++ .../src/models/vit/vit_image_converter.py | 36 +++++++++++ keras_hub/src/models/vit/vit_layers.py | 62 +++++++++++++++++++ 3 files changed, 134 insertions(+) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index fa6b18ec7..8fe3d2586 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -9,6 +9,42 @@ @keras_hub_export("keras_hub.models.ViTBackbone") class ViTBackbone(Backbone): + """Vision Transformer (ViT) backbone. + + This backbone implements the Vision Transformer architecture as described in + [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). + It transforms the input image into a sequence of patches, embeds them, and + then processes them through a series of Transformer encoder layers. + + Args: + image_shape: A tuple or list of 3 integers representing the shape of the + input image `(height, width, channels)`, `height` and `width` must + be equal. + patch_size: int. The size of each image patch, the input image will be + divided into patches of shape `(patch_size, patch_size)`. + num_layers: int. The number of transformer encoder layers. + num_heads: int. specifying the number of attention heads in each + Transformer encoder layer. + hidden_dim: int. The dimensionality of the hidden representations. + mlp_dim: int. The dimensionality of the intermediate MLP layer in + each Transformer encoder layer. + dropout_rate: float. The dropout rate for the Transformer encoder + layers. + attention_dropout: float. The dropout rate for the attention mechanism + in each Transformer encoder layer. + layer_norm_epsilon: float. Value used for numerical stability in + layer normalization. + use_mha_bias: bool. Whether to use bias in the multi-head + attention layers. + use_mlp_bias: bool. Whether to use bias in the MLP layers. + data_format: str. `"channels_last"` or `"channels_first"`, specifying + the data format for the input image. If `None`, defaults to + `"channels_last"`. + dtype: The dtype of the layer weights. Defaults to None. + **kwargs: Additional keyword arguments to be passed to the parent + `Backbone` class. + """ + def __init__( self, image_shape, diff --git a/keras_hub/src/models/vit/vit_image_converter.py b/keras_hub/src/models/vit/vit_image_converter.py index 705c8a8b4..b1699640c 100644 --- a/keras_hub/src/models/vit/vit_image_converter.py +++ b/keras_hub/src/models/vit/vit_image_converter.py @@ -6,6 +6,42 @@ @keras_hub_export("keras_hub.layers.ViTImageConverter") class ViTImageConverter(ImageConverter): + """Converts images to the format expected by a ViT model. + + This layer performs image normalization using mean and standard deviation values. + By default, it uses the same normalization as the + "google/vit-large-patch16-224" model on Hugging Face: + `norm_mean=[0.5, 0.5, 0.5]` and `norm_std=[0.5, 0.5, 0.5]` + ([reference](https://huggingface.co/google/vit-large-patch16-224/blob/main/preprocessor_config.json)). + These defaults are suitable for models pretrained using this normalization. 
+ + Args: + norm_mean: list or tuple of floats. Mean values for image normalization. + Defaults to `[0.5, 0.5, 0.5]`. + norm_std: list or tuple of floats. Standard deviation values for + image normalization. Defaults to `[0.5, 0.5, 0.5]`. + **kwargs: Additional keyword arguments passed to + `keras_hub.layers.preprocessing.ImageConverter`. + + Examples: + ```python + import keras + import numpy as np + from keras_hub.src.layers import ViTImageConverter + + # Example image (replace with your actual image data) + image = np.random.rand(1, 224, 224, 3) # Example: (B, H, W, C) + + # Create a ViTImageConverter instance + converter = ViTImageConverter( + image_size=(28,28), + scale=1/255. + ) + # Preprocess the image + preprocessed_image = converter(image) + ``` + """ + backbone_cls = ViTBackbone def __init__( diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index 36a4b9e8b..cec21e751 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -5,6 +5,17 @@ class MLP(keras.layers.Layer): + """Multi-Layer Perceptron (MLP) block. + + Args: + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_bias: bool. Whether to use bias in the dense layers. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, hidden_dim, @@ -60,6 +71,20 @@ def call(self, inputs): class ViTPatchingAndEmbedding(keras.layers.Layer): + """Patches the image and embeds the patches. + + Args: + image_size: int. Size of the input image (height or width). + Assumed to be square. + patch_size: int. Size of each image patch. + hidden_dim: int. Dimensionality of the patch embeddings. + num_channels: int. Number of channels in the input image. Defaults to + `3`. + data_format: str. `"channels_last"` or `"channels_first"`. Defaults to + `None` (which uses `"channels_last"`). + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, image_size, @@ -157,6 +182,24 @@ def get_config(self): class ViTEncoderBlock(keras.layers.Layer): + """Transformer encoder block. + + Args: + num_heads: int. Number of attention heads. + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_mha_bias: bool. Whether to use bias in the multi-head attention + layer. Defaults to `True`. + use_mlp_bias: bool. Whether to use bias in the MLP layer. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + attention_dropout: float. Dropout rate for the attention mechanism. + Between 0 and 1. Defaults to `0.0`. + layer_norm_epsilon: float. Small float value for layer normalization + stability. Defaults to `1e-6`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, num_heads, @@ -252,6 +295,25 @@ def get_config(self): class ViTEncoder(keras.layers.Layer): + """Vision Transformer (ViT) encoder. + + Args: + num_layers: int. Number of Transformer encoder blocks. + num_heads: int. Number of attention heads. + hidden_dim: int. Dimensionality of the hidden representations. + mlp_dim: int. Dimensionality of the intermediate MLP layer. + use_mha_bias: bool. Whether to use bias in the multi-head attention + layers. Defaults to `True`. + use_mlp_bias: bool. 
Whether to use bias in the MLP layers. Defaults to + `True`. + dropout_rate: float. Dropout rate. Between 0 and 1. Defaults to `0.0`. + attention_dropout: float. Dropout rate for the attention mechanism. + Between 0 and 1. Defaults to `0.0`. + layer_norm_epsilon: float. Small float value for layer normalization + tability. Defaults to `1e-6`. + **kwargs: Additional keyword arguments passed to `keras.layers.Layer` + """ + def __init__( self, num_layers, From 3f027a0ae6b5f8d5abf32acdb3301c7dd4d3c265 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Fri, 22 Nov 2024 14:14:15 -0800 Subject: [PATCH 25/35] correct test cases --- keras_hub/src/models/vit/vit_backbone.py | 1 + keras_hub/src/models/vit/vit_backbone_test.py | 3 ++- keras_hub/src/models/vit/vit_layers.py | 8 +++----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index 8fe3d2586..d044f8def 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -91,6 +91,7 @@ def __init__( patch_size=patch_size, hidden_dim=hidden_dim, num_channels=num_channels, + data_format=data_format, dtype=dtype, name="vit_patching_and_embedding", )(inputs) diff --git a/keras_hub/src/models/vit/vit_backbone_test.py b/keras_hub/src/models/vit/vit_backbone_test.py index 9a0368402..0ab0b389c 100644 --- a/keras_hub/src/models/vit/vit_backbone_test.py +++ b/keras_hub/src/models/vit/vit_backbone_test.py @@ -20,11 +20,12 @@ def setUp(self): self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) def test_backbone_basics(self): - self.run_vision_backbone_test( + self.run_backbone_test( cls=ViTBackbone, init_kwargs={**self.init_kwargs}, input_data=self.input_data, expected_output_shape=(2, 50, 48), + run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/models/vit/vit_layers.py b/keras_hub/src/models/vit/vit_layers.py index cec21e751..8cdc52ca7 100644 --- a/keras_hub/src/models/vit/vit_layers.py +++ b/keras_hub/src/models/vit/vit_layers.py @@ -144,17 +144,15 @@ def build(self, input_shape): def call(self, inputs): patch_embeddings = self.patch_embedding(inputs) - input_shape = ops.shape( - patch_embeddings - ) # (N, H, W, C) or (N, C, H, W) if self.data_format == "channels_first": patch_embeddings = ops.transpose( patch_embeddings, axes=(0, 2, 3, 1) ) + embeddings_shape = ops.shape(patch_embeddings) patch_embeddings = ops.reshape( - patch_embeddings, [input_shape[0], -1, input_shape[-1]] + patch_embeddings, [embeddings_shape[0], -1, embeddings_shape[-1]] ) - class_token = ops.tile(self.class_token, (input_shape[0], 1, 1)) + class_token = ops.tile(self.class_token, (embeddings_shape[0], 1, 1)) position_embeddings = self.position_embedding(self.position_ids) embeddings = ops.concatenate([class_token, patch_embeddings], axis=1) return ops.add(embeddings, position_embeddings) From 05acb706a3fee3778442f9810d305449d68830f0 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 25 Nov 2024 13:48:04 -0800 Subject: [PATCH 26/35] use numpy for test data --- keras_hub/src/models/vit/vit_image_classifier_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py index a2e608594..b50e51196 100644 --- a/keras_hub/src/models/vit/vit_image_classifier_test.py +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -1,5 +1,5 @@ import pytest -from keras import ops +import 
numpy as np from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier @@ -12,7 +12,7 @@ class ViTImageClassifierTest(TestCase): def setUp(self): - self.images = ops.ones((2, 28, 28, 3)) + self.images = np.ones((2, 28, 28, 3)) self.labels = [0, 1] self.backbone = ViTBackbone( image_shape=(28, 28, 3), From 521df6fb3e8248954c530d864f41460faea6f3d7 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 25 Nov 2024 13:55:10 -0800 Subject: [PATCH 27/35] nit --- keras_hub/src/models/vit/vit_image_classifier_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/vit/vit_image_classifier_test.py b/keras_hub/src/models/vit/vit_image_classifier_test.py index b50e51196..29e3d6692 100644 --- a/keras_hub/src/models/vit/vit_image_classifier_test.py +++ b/keras_hub/src/models/vit/vit_image_classifier_test.py @@ -1,5 +1,5 @@ -import pytest import numpy as np +import pytest from keras_hub.src.models.vit.vit_backbone import ViTBackbone from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier From ae2b800fac4d1366f9d635bf66939509a009291f Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 27 Nov 2024 12:20:13 -0800 Subject: [PATCH 28/35] nit --- keras_hub/src/models/vit/vit_backbone.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras_hub/src/models/vit/vit_backbone.py b/keras_hub/src/models/vit/vit_backbone.py index d044f8def..c34ab7d49 100644 --- a/keras_hub/src/models/vit/vit_backbone.py +++ b/keras_hub/src/models/vit/vit_backbone.py @@ -62,6 +62,7 @@ def __init__( dtype=None, **kwargs, ): + # === Laters === data_format = standardize_data_format(data_format) h_axis, w_axis, channels_axis = ( (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3) @@ -127,6 +128,8 @@ def __init__( self.dropout_rate = dropout_rate self.attention_dropout = attention_dropout self.layer_norm_epsilon = layer_norm_epsilon + self.use_mha_bias = use_mha_bias + self.use_mlp_bias = use_mlp_bias self.data_format = data_format def get_config(self): @@ -142,6 +145,8 @@ def get_config(self): "dropout_rate": self.dropout_rate, "attention_dropout": self.attention_dropout, "layer_norm_epsilon": self.layer_norm_epsilon, + "use_mha_bias": self.use_mha_bias, + "use_mlp_bias": self.use_mlp_bias, } ) return config From 92149d5da76b33376718ab7963df60d9b582c062 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 2 Dec 2024 13:57:30 -0800 Subject: [PATCH 29/35] add presets --- keras_hub/src/models/vit/__init__.py | 5 +++ keras_hub/src/models/vit/vit_presets.py | 57 +++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 keras_hub/src/models/vit/vit_presets.py diff --git a/keras_hub/src/models/vit/__init__.py b/keras_hub/src/models/vit/__init__.py index e69de29bb..e4b42de07 100644 --- a/keras_hub/src/models/vit/__init__.py +++ b/keras_hub/src/models/vit/__init__.py @@ -0,0 +1,5 @@ +from keras_hub.src.models.vit.vit_backbone import ViTBackbone +from keras_hub.src.models.vit.vit_presets import backbone_presets +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, ViTBackbone) diff --git a/keras_hub/src/models/vit/vit_presets.py b/keras_hub/src/models/vit/vit_presets.py new file mode 100644 index 000000000..445372bd4 --- /dev/null +++ b/keras_hub/src/models/vit/vit_presets.py @@ -0,0 +1,57 @@ +"""ViT model preset configurations.""" + +# Metadata for loading pretrained model weights. 
+backbone_presets = { + "vit_base_patch16_224_imagenet": { + "metadata": { + "description": ( + "ViT-B16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 224x224 " + ), + "params": 85798656, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/1", + }, + "vit_base_patch16_384_imagenet": { + "metadata": { + "description": ( + "ViT-B16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 384x384 " + ), + "params": 86090496, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/1", + }, + "vit_large_patch16_224_imagenet": { + "metadata": { + "description": ( + "ViT-L16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 224x224 " + ), + "params": 303301632, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/1", + }, + "vit_large_patch16_384_imagenet": { + "metadata": { + "description": ( + "ViT-L16 model pre-trained on the ImageNet 1k dataset with " + "image resolution of 384x384 " + ), + "params": 303690752, + "official_name": "ViT", + "path": "vit", + "model_card": "https://www.kaggle.com/models/keras/vit", + }, + "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1", + }, +} From 5374c704ac08117576432d25debcfe26df136139 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 4 Dec 2024 16:33:32 -0800 Subject: [PATCH 30/35] load vit preset from hugging face directly --- keras_hub/src/models/vit/vit_presets.py | 8 - .../src/utils/transformers/convert_vit.py | 151 ++++++++++++++++++ .../src/utils/transformers/preset_loader.py | 17 ++ 3 files changed, 168 insertions(+), 8 deletions(-) create mode 100644 keras_hub/src/utils/transformers/convert_vit.py diff --git a/keras_hub/src/models/vit/vit_presets.py b/keras_hub/src/models/vit/vit_presets.py index 445372bd4..16a6f694e 100644 --- a/keras_hub/src/models/vit/vit_presets.py +++ b/keras_hub/src/models/vit/vit_presets.py @@ -9,9 +9,7 @@ "image resolution of 224x224 " ), "params": 85798656, - "official_name": "ViT", "path": "vit", - "model_card": "https://www.kaggle.com/models/keras/vit", }, "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_224_imagenet/1", }, @@ -22,9 +20,7 @@ "image resolution of 384x384 " ), "params": 86090496, - "official_name": "ViT", "path": "vit", - "model_card": "https://www.kaggle.com/models/keras/vit", }, "kaggle_handle": "kaggle://keras/vit/keras/vit_base_patch16_384_imagenet/1", }, @@ -35,9 +31,7 @@ "image resolution of 224x224 " ), "params": 303301632, - "official_name": "ViT", "path": "vit", - "model_card": "https://www.kaggle.com/models/keras/vit", }, "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_224_imagenet/1", }, @@ -48,9 +42,7 @@ "image resolution of 384x384 " ), "params": 303690752, - "official_name": "ViT", "path": "vit", - "model_card": "https://www.kaggle.com/models/keras/vit", }, "kaggle_handle": "kaggle://keras/vit/keras/vit_large_patch16_384_imagenet/1", }, diff --git a/keras_hub/src/utils/transformers/convert_vit.py b/keras_hub/src/utils/transformers/convert_vit.py new file mode 100644 index 000000000..101e95b04 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_vit.py @@ -0,0 +1,151 @@ +import 
numpy as np + +from keras_hub.src.models.vit.vit_backbone import ViTBackbone + +backbone_cls = ViTBackbone + + +def convert_backbone_config(transformers_config): + image_size = transformers_config["image_size"] + return { + "image_shape": (image_size, image_size, 3), + "patch_size": transformers_config["patch_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_heads": transformers_config["num_attention_heads"], + "hidden_dim": transformers_config["hidden_size"], + "mlp_dim": transformers_config["intermediate_size"], + "dropout_rate": transformers_config["hidden_dropout_prob"], + "attention_dropout": transformers_config[ + "attention_probs_dropout_prob" + ], + "use_mha_bias": transformers_config["qkv_bias"], + } + + +def convert_weights(backbone, loader, transformers_config): + + def port_ln(keras_variable, weight_key): + loader.port_weight(keras_variable.gamma, f"{weight_key}.weight") + loader.port_weight(keras_variable.beta, f"{weight_key}.bias") + + def port_dense(keras_variable, weight_key): + loader.port_weight( + keras_variable.kernel, + f"{weight_key}.weight", + hook_fn=lambda x, _: x.T, + ) + if keras_variable.bias is not None: + loader.port_weight(keras_variable.bias, f"{weight_key}.bias") + + def port_mha(keras_variable, weight_key, num_heads, hidden_dim): + # query + loader.port_weight( + keras_variable.query_dense.kernel, + f"{weight_key}.attention.query.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + loader.port_weight( + keras_variable.query_dense.bias, + f"{weight_key}.attention.query.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # key + loader.port_weight( + keras_variable.key_dense.kernel, + f"{weight_key}.attention.key.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + loader.port_weight( + keras_variable.key_dense.bias, + f"{weight_key}.attention.key.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # value + loader.port_weight( + keras_variable.value_dense.kernel, + f"{weight_key}.attention.value.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (hidden_dim, num_heads, hidden_dim // num_heads) + ), + ) + loader.port_weight( + keras_variable.value_dense.bias, + f"{weight_key}.attention.value.bias", + hook_fn=lambda x, _: np.reshape( + x, (num_heads, hidden_dim // num_heads) + ), + ) + # output + loader.port_weight( + keras_variable.output_dense.kernel, + f"{weight_key}.output.dense.weight", + hook_fn=lambda x, _: np.reshape( + x.T, (num_heads, hidden_dim // num_heads, hidden_dim) + ), + ) + loader.port_weight( + keras_variable.output_dense.bias, f"{weight_key}.output.dense.bias" + ) + + loader.port_weight( + keras_variable=backbone.layers[1].patch_embedding.kernel, + hf_weight_key="vit.embeddings.patch_embeddings.projection.weight", + hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), + ) + + loader.port_weight( + backbone.layers[1].patch_embedding.bias, + "vit.embeddings.patch_embeddings.projection.bias", + ) + + loader.port_weight( + backbone.layers[1].class_token, + "vit.embeddings.cls_token", + ) + + loader.port_weight( + backbone.layers[1].position_embedding.embeddings, + "vit.embeddings.position_embeddings", + hook_fn=lambda x, _: x[0], + ) + encoder_layers = backbone.layers[2].encoder_layers + for i, encoder_block in enumerate(encoder_layers): + prefix = "vit.encoder.layer" + num_heads = encoder_block.num_heads + hidden_dim = 
encoder_block.hidden_dim + + port_mha( + encoder_block.mha, + f"{prefix}.{i}.attention", + num_heads, + hidden_dim, + ) + port_ln(encoder_block.layer_norm_1, f"{prefix}.{i}.layernorm_before") + port_ln(encoder_block.layer_norm_2, f"{prefix}.{i}.layernorm_after") + + port_dense( + encoder_block.mlp.dense_1, f"{prefix}.{i}.intermediate.dense" + ) + port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") + port_ln(backbone.layers[2].layer_norm, "vit.layernorm") + + +def convert_head(task, loader, transformers_config): + prefix = "classifier." + loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index b285a3c09..0f402e8d8 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -1,5 +1,6 @@ """Convert huggingface models to KerasHub.""" +from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.transformers import convert_albert @@ -11,6 +12,7 @@ from keras_hub.src.utils.transformers import convert_llama3 from keras_hub.src.utils.transformers import convert_mistral from keras_hub.src.utils.transformers import convert_pali_gemma +from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -37,6 +39,8 @@ def __init__(self, preset, config): self.converter = convert_mistral elif model_type == "paligemma": self.converter = convert_pali_gemma + elif model_type == "vit": + self.converter = convert_vit else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " @@ -55,6 +59,19 @@ def load_backbone(self, cls, load_weights, **kwargs): self.converter.convert_weights(backbone, loader, self.config) return backbone + def load_task(self, cls, load_weights, load_task_weights, **kwargs): + if not load_task_weights or not issubclass(cls, ImageClassifier): + return super().load_task( + cls, load_weights, load_task_weights, **kwargs + ) + # Support loading the classification head for classifier models. 
+ kwargs["num_classes"] = self.config["num_classes"] + task = super().load_task(cls, load_weights, load_task_weights, **kwargs) + if load_task_weights: + with SafetensorLoader(self.preset, prefix="") as loader: + self.converter.convert_head(task, loader, self.config) + return task + def load_tokenizer(self, cls, config_name="tokenizer.json", **kwargs): return self.converter.convert_tokenizer(cls, self.preset, **kwargs) From ebee9ef23969f7e513536f2609220a6c78eafbfa Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Wed, 4 Dec 2024 16:59:00 -0800 Subject: [PATCH 31/35] nit --- keras_hub/src/utils/transformers/convert_vit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/utils/transformers/convert_vit.py b/keras_hub/src/utils/transformers/convert_vit.py index 101e95b04..9ce3b3d3a 100644 --- a/keras_hub/src/utils/transformers/convert_vit.py +++ b/keras_hub/src/utils/transformers/convert_vit.py @@ -143,7 +143,7 @@ def convert_head(task, loader, transformers_config): loader.port_weight( task.output_dense.kernel, hf_weight_key=prefix + "weight", - hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + hook_fn=lambda x, _: x.T, ) loader.port_weight( task.output_dense.bias, From 93064bd50a4523e24db850c806ffc6451e2aa469 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Thu, 5 Dec 2024 14:37:13 -0800 Subject: [PATCH 32/35] handle num classes case for ViT --- keras_hub/src/utils/transformers/preset_loader.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 0f402e8d8..a3c46f4cf 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -60,12 +60,18 @@ def load_backbone(self, cls, load_weights, **kwargs): return backbone def load_task(self, cls, load_weights, load_task_weights, **kwargs): - if not load_task_weights or not issubclass(cls, ImageClassifier): + architecture = self.config["architectures"][0] + if ( + not load_task_weights + or not issubclass(cls, ImageClassifier) + or architecture == "ViTModel" + ): return super().load_task( cls, load_weights, load_task_weights, **kwargs ) # Support loading the classification head for classifier models. - kwargs["num_classes"] = self.config["num_classes"] + if architecture == "ViTForImageClassification": + kwargs["num_classes"] = len(self.config["id2label"]) task = super().load_task(cls, load_weights, load_task_weights, **kwargs) if load_task_weights: with SafetensorLoader(self.preset, prefix="") as loader: From e206e7b1762c610bfb39ed9a947f95062b922058 Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 9 Dec 2024 10:01:30 -0800 Subject: [PATCH 33/35] replace toke with first --- keras_hub/src/models/image_classifier.py | 6 +++--- keras_hub/src/models/vit/vit_image_classifier.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py index ceafa76cb..169418fad 100644 --- a/keras_hub/src/models/image_classifier.py +++ b/keras_hub/src/models/image_classifier.py @@ -117,12 +117,12 @@ def __init__( dtype=head_dtype, name="pooler", ) - elif pooling == "token": + elif pooling == "first": self.pooler = None else: raise ValueError( "Unknown `pooling` type. Polling should be either `'avg'` or " - f"`'max' or 'token'`. Received: pooling={pooling}." + f"`'max' or 'first'`. Received: pooling={pooling}." 
) self.output_dropout = keras.layers.Dropout( dropout, @@ -139,7 +139,7 @@ def __init__( # === Functional Model === inputs = self.backbone.input x = self.backbone(inputs) - if pooling == "token": # used for Vision Transformer(ViT) + if pooling == "first": # used for Vision Transformer(ViT) x = x[:, 0] else: x = self.pooler(x) diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py index 579538b6b..c8bf594da 100644 --- a/keras_hub/src/models/vit/vit_image_classifier.py +++ b/keras_hub/src/models/vit/vit_image_classifier.py @@ -11,5 +11,5 @@ class ViTImageClassifier(ImageClassifier): backbone_cls = ViTBackbone preprocessor_cls = ViTImageClassifierPreprocessor - def __init__(self, pooling="token", **kwargs): + def __init__(self, pooling="first", **kwargs): super().__init__(pooling=pooling, **kwargs) From 7a39d5bd0f055323f909e3802c9849e1dbc4d80c Mon Sep 17 00:00:00 2001 From: Sravana Neeli Date: Mon, 9 Dec 2024 16:25:36 -0800 Subject: [PATCH 34/35] convert all vit checkpoints using tools --- .../convert_vit_checkpoints.py | 109 ++++++++++++++---- 1 file changed, 84 insertions(+), 25 deletions(-) diff --git a/tools/checkpoint_conversion/convert_vit_checkpoints.py b/tools/checkpoint_conversion/convert_vit_checkpoints.py index 4868a7b04..077753522 100644 --- a/tools/checkpoint_conversion/convert_vit_checkpoints.py +++ b/tools/checkpoint_conversion/convert_vit_checkpoints.py @@ -32,8 +32,15 @@ PRESET_MAP = { "vit_base_patch16_224": "google/vit-base-patch16-224", "vit_base_patch16_384": "google/vit-base-patch16-384", + "vit_base_patch32_384": "google/vit-base-patch32-384", "vit_large_patch16_224": "google/vit-large-patch16-224", "vit_large_patch16_384": "google/vit-large-patch16-384", + "vit_large_patch32_384": "google/vit-large-patch32-384", + "vit_base_patch16_224_in21k": "google/vit-base-patch16-224-in21k", + "vit_base_patch32_224_in21k": "google/vit-base-patch32-224-in21k", + "vit_large_patch16_224_in21k": "google/vit-large-patch16-224-in21k", + "vit_large_patch32_224_in21k": "google/vit-large-patch32-224-in21k", + "vit_huge_patch14_224_in21k": "google/vit-huge-patch14-224-in21k", } flags.DEFINE_string( @@ -65,13 +72,10 @@ def convert_model(hf_model): use_mha_bias=config["qkv_bias"], ) - return ViTImageClassifier( - backbone=backbone, - num_classes=1000, # num classes in ImageNet - ) + return backbone, config -def convert_weights(keras_hub_model, hf_model): +def convert_backbone_weights(backbone, hf_model): state_dict = hf_model.state_dict() state_dict.update(hf_model.named_buffers()) @@ -154,27 +158,27 @@ def port_mha(keras_variable, weight_key, num_heads, hidden_dim): ) port_weights( - keras_hub_model.backbone.layers[1].patch_embedding.kernel, + backbone.layers[1].patch_embedding.kernel, "vit.embeddings.patch_embeddings.projection.weight", hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)), ) port_weights( - keras_hub_model.backbone.layers[1].patch_embedding.bias, + backbone.layers[1].patch_embedding.bias, "vit.embeddings.patch_embeddings.projection.bias", ) port_weights( - keras_hub_model.backbone.layers[1].class_token, + backbone.layers[1].class_token, "vit.embeddings.cls_token", ) port_weights( - keras_hub_model.backbone.layers[1].position_embedding.embeddings, + backbone.layers[1].position_embedding.embeddings, "vit.embeddings.position_embeddings", hook_fn=lambda x, _: x[0], ) - encoder_layers = keras_hub_model.backbone.layers[2].encoder_layers + encoder_layers = backbone.layers[2].encoder_layers for i, encoder_block in 
enumerate(encoder_layers): prefix = "vit.encoder.layer" num_heads = encoder_block.num_heads @@ -194,8 +198,31 @@ def port_mha(keras_variable, weight_key, num_heads, hidden_dim): ) port_dense(encoder_block.mlp.dense_2, f"{prefix}.{i}.output.dense") - port_ln(keras_hub_model.backbone.layers[2].layer_norm, "vit.layernorm") - port_dense(keras_hub_model.output_dense, "classifier") + port_ln(backbone.layers[2].layer_norm, "vit.layernorm") + # port_dense(keras_hub_model.output_dense, "classifier") + + +def convert_head_weights(keras_model, hf_model): + state_dict = hf_model.state_dict() + state_dict.update(hf_model.named_buffers()) + + def port_weights(keras_variable, weight_key, hook_fn=None): + torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + keras_variable.assign(torch_tensor) + + prefix = "classifier." + + port_weights( + keras_model.output_dense.kernel, + prefix + "weight", + hook_fn=lambda x, _: x.T, + ) + port_weights( + keras_model.output_dense.bias, + prefix + "bias", + ) def convert_image_converter(hf_image_processor): @@ -217,6 +244,7 @@ def validate_output( keras_image_converter, hf_model, hf_image_processor, + head_weights=False, ): file = keras.utils.get_file( origin=("http://images.cocodataset.org/val2017/000000039769.jpg") @@ -244,7 +272,11 @@ def validate_output( ) ) hf_outputs = hf_model(**hf_inputs) - hf_vision_logits = hf_outputs.logits.detach().cpu().numpy() + if head_weights: + hf_vision_logits = hf_outputs.logits.detach().cpu().numpy() + + else: + hf_vision_logits = hf_outputs.last_hidden_state.detach().cpu().numpy() # Call with keras. keras_outputs = keras_model(keras_preprocessed) @@ -252,6 +284,15 @@ def validate_output( print("🔶 Keras output:", keras_vision_logits[0, :10]) print("🔶 HF output:", hf_vision_logits[0, :10]) + if head_weights: + print( + "🔶 HF top 5 ImageNet outputs:", + keras_hub.utils.decode_imagenet_predictions(hf_vision_logits), + ) + print( + "🔶 Keras top 5 ImageNet outputs:", + keras_hub.utils.decode_imagenet_predictions(keras_outputs), + ) modeling_diff = np.mean(np.abs(keras_vision_logits - hf_vision_logits)) print("🔶 Modeling difference:", modeling_diff) preprocessing_diff = np.mean( @@ -279,25 +320,43 @@ def main(_): hf_preprocessor = ViTImageProcessor.from_pretrained(hf_preset) hf_model.eval() - keras_model = convert_model(hf_model) + keras_backbone, hf_config = convert_model(hf_model) keras_image_converter = convert_image_converter(hf_preprocessor) keras_image_preprocessor = ViTImageClassifierPreprocessor( image_converter=keras_image_converter ) print("✅ KerasHub model loaded.") - convert_weights(keras_model, hf_model) - print("✅ Weights converted.") + convert_backbone_weights(keras_backbone, hf_model) + print("✅ Backbone weights converted.") - validate_output( - keras_model, - keras_image_converter, - hf_model, - hf_preprocessor, - ) - print("✅ Output validated.") - keras_model.preprocessor = keras_image_preprocessor - keras_model.save_to_preset(f"./{preset}") + if hf_config["architectures"][0] == "ViTForImageClassification": + keras_model = ViTImageClassifier( + backbone=keras_backbone, num_classes=len(hf_config["id2label"]) + ) + convert_head_weights(keras_model, hf_model) + print("✅ Head weights converted.") + validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_preprocessor, + head_weights=True, + ) + print("✅ Output validated.") + keras_model.preprocessor = keras_image_preprocessor + keras_model.save_to_preset(f"./{preset}") + else: + 
hf_model = hf_model.vit
+        validate_output(
+            keras_backbone,
+            keras_image_converter,
+            hf_model,
+            hf_preprocessor,
+        )
+        print("✅ Output validated.")
+        keras_backbone.save_to_preset(f"./{preset}")
+
     keras_image_preprocessor.save_to_preset(f"./{preset}")
     print(f"🏁 Preset saved to ./{preset}.")
 

From 0827954b9335a15b0925e34a374a311a5a9457e8 Mon Sep 17 00:00:00 2001
From: Sravana Neeli
Date: Tue, 10 Dec 2024 12:25:41 -0800
Subject: [PATCH 35/35] Add custom ImageClassifier for ViT

---
 keras_hub/src/models/image_classifier.py     |   9 +-
 .../src/models/vit/vit_image_classifier.py   | 179 +++++++++++++++++-
 2 files changed, 179 insertions(+), 9 deletions(-)

diff --git a/keras_hub/src/models/image_classifier.py b/keras_hub/src/models/image_classifier.py
index 169418fad..e75e39089 100644
--- a/keras_hub/src/models/image_classifier.py
+++ b/keras_hub/src/models/image_classifier.py
@@ -117,12 +117,10 @@ def __init__(
                 dtype=head_dtype,
                 name="pooler",
             )
-        elif pooling == "first":
-            self.pooler = None
         else:
             raise ValueError(
                 "Unknown `pooling` type. Polling should be either `'avg'` or "
-                f"`'max' or 'first'`. Received: pooling={pooling}."
+                f"`'max'`. Received: pooling={pooling}."
             )
         self.output_dropout = keras.layers.Dropout(
             dropout,
@@ -139,10 +137,7 @@ def __init__(
         # === Functional Model ===
         inputs = self.backbone.input
         x = self.backbone(inputs)
-        if pooling == "first":  # used for Vision Transformer(ViT)
-            x = x[:, 0]
-        else:
-            x = self.pooler(x)
+        x = self.pooler(x)
         x = self.output_dropout(x)
         outputs = self.output_dense(x)
         super().__init__(
diff --git a/keras_hub/src/models/vit/vit_image_classifier.py b/keras_hub/src/models/vit/vit_image_classifier.py
index c8bf594da..3bb3463da 100644
--- a/keras_hub/src/models/vit/vit_image_classifier.py
+++ b/keras_hub/src/models/vit/vit_image_classifier.py
@@ -1,5 +1,9 @@
+import keras
+from keras import ops
+
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.image_classifier import ImageClassifier
+from keras_hub.src.models.task import Task
 from keras_hub.src.models.vit.vit_backbone import ViTBackbone
 from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
     ViTImageClassifierPreprocessor,
@@ -8,8 +12,179 @@
 
 @keras_hub_export("keras_hub.models.ViTImageClassifier")
 class ViTImageClassifier(ImageClassifier):
+    """ViT image classification task.
+
+    `ViTImageClassifier` tasks wrap a `keras_hub.models.ViTBackbone` and
+    a `keras_hub.models.Preprocessor` to create a model that can be used for
+    image classification. `ViTImageClassifier` tasks take an additional
+    `num_classes` argument, controlling the number of predicted output classes.
+
+    To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)`
+    where `x` is a batch of images and `y` is an integer from `[0, num_classes)`.
+
+    Note that unlike `keras_hub.models.ImageClassifier`, `ViTImageClassifier`
+    plucks out the `cls_token`, the first token in the sequence returned by
+    the backbone.
+
+    Args:
+        backbone: A `keras_hub.models.ViTBackbone` instance or a `keras.Model`.
+        num_classes: int. The number of classes to predict.
+        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
+            a `keras.Layer` instance, or a callable. If `None` no preprocessing
+            will be applied to the inputs.
+        pooling: String specifying the classification strategy. The choice
+            impacts the dimensionality and nature of the feature vector used for
+            classification.
+            `"token"`: A single vector (class token) representing the
+                overall image features.
+            `"gap"`: A single vector representing the average features
+                across the spatial dimensions.
+            `"token_unpooled"`: Outputs tokens from `ViTBackbone` directly.
+        representation_size: Optional dimensionality of the intermediate
+            representation layer before the final classification layer.
+            If `None`, the output of the transformer is directly used.
+            Defaults to `None`.
+        activation: `None`, str, or callable. The activation function to use on
+            the `Dense` layer. Set `activation=None` to return the output
+            logits. Defaults to `None`.
+        head_dtype: `None`, str, or `keras.mixed_precision.DTypePolicy`. The
+            dtype to use for the classification head's computations and weights.
+
+    Examples:
+
+    Call `predict()` to run inference.
+    ```python
+    # Load preset and run inference.
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224_imagenet"
+    )
+    classifier.predict(images)
+    ```
+
+    Call `fit()` on a single batch.
+    ```python
+    # Load preset and train.
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224_imagenet"
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Call `fit()` with custom loss, optimizer and backbone.
+    ```python
+    classifier = keras_hub.models.ViTImageClassifier.from_preset(
+        "vit_base_patch16_224_imagenet"
+    )
+    classifier.compile(
+        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=keras.optimizers.Adam(5e-5),
+    )
+    classifier.backbone.trainable = False
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+
+    Custom backbone.
+    ```python
+    images = np.random.randint(0, 256, size=(2, 224, 224, 3))
+    labels = [0, 3]
+    backbone = keras_hub.models.ViTBackbone(
+        image_shape=(224, 224, 3),
+        patch_size=16,
+        num_layers=6,
+        num_heads=3,
+        hidden_dim=768,
+        mlp_dim=2048
+    )
+    classifier = keras_hub.models.ViTImageClassifier(
+        backbone=backbone,
+        num_classes=4,
+    )
+    classifier.fit(x=images, y=labels, batch_size=2)
+    ```
+    """
+
+    backbone_cls = ViTBackbone
+    preprocessor_cls = ViTImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        pooling="token",
+        representation_size=None,
+        activation=None,
+        dropout=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+
+        if representation_size is not None:
+            self.representation_layer = keras.layers.Dense(
+                representation_size, activation="tanh", name="pre_logits"
+            )
+
+        self.dropout = keras.layers.Dropout(
+            rate=dropout,
+            dtype=head_dtype,
+            name="output_dropout",
+        )
+        self.output_dense = keras.layers.Dense(
+            num_classes,
+            activation=activation,
+            dtype=head_dtype,
+            name="predictions",
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        if pooling == "token":
+            x = x[:, 0]
+        elif pooling == "gap":
+            ndim = len(ops.shape(x))
+            x = ops.mean(x, axis=list(range(1, ndim - 1)))  # (1,) or (1,2)
+        elif pooling == "token_unpooled":
+            pass
+
+        if representation_size is not None:
+            x = self.representation_layer(x)
+
+        x = self.dropout(x)
+        outputs = self.output_dense(x)
+
+        # Skip the parent class functional model.
+ Task.__init__( + self, + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + # === config === + self.num_classes = num_classes + self.pooling = pooling + self.representation_size = representation_size + self.activation = activation + self.dropout = dropout + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "pooling": self.pooling, + "representation_size": self.representation_size, + "activation": self.activation, + "dropout": self.dropout, + } + ) + return config
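
With the presets registered in `vit_presets.py` and the Transformers converter wired into `preset_loader.py`, a converted checkpoint can be loaded end to end. A minimal usage sketch follows; the preset name comes from `vit_presets.py`, while the `hf://` handle and the 1000-class output shape are assumptions based on the `google/vit-base-patch16-224` classifier used by the conversion script.

```python
import numpy as np

import keras_hub

# Load the registered KerasHub preset (weights produced by the conversion
# tool above).
classifier = keras_hub.models.ViTImageClassifier.from_preset(
    "vit_base_patch16_224_imagenet"
)

# Alternatively, load straight from Hugging Face through the converter in
# keras_hub/src/utils/transformers/convert_vit.py (assumed `hf://` handle).
# classifier = keras_hub.models.ViTImageClassifier.from_preset(
#     "hf://google/vit-base-patch16-224"
# )

images = np.random.randint(0, 256, size=(2, 224, 224, 3)).astype("float32")
preds = classifier.predict(images)  # Expected shape: (2, 1000) class scores.
```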