[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989

Open: wants to merge 30 commits into base: master

Changes from all commits (30 commits)
741b889  vit base (sineeli, Nov 14, 2024)
13dae08  Add vit backbone, classifier and preprocessor layers (sineeli, Nov 15, 2024)
b64b137  update args (sineeli, Nov 15, 2024)
429d635  add default args (sineeli, Nov 15, 2024)
6d69abc  correct build method (sineeli, Nov 15, 2024)
2e87884  fix build issues (sineeli, Nov 15, 2024)
bd3cce0  fix bugs (sineeli, Nov 16, 2024)
4232a06  Update backbone args and configs (sineeli, Nov 18, 2024)
32b08c5  correct position ids dtype (sineeli, Nov 18, 2024)
cc938c6  build token layer (sineeli, Nov 18, 2024)
78812de  token layer build (sineeli, Nov 18, 2024)
8a20465  assign correct dtype to TokenLayer (sineeli, Nov 18, 2024)
de754cc  fix build shape of token layer (sineeli, Nov 18, 2024)
84ba896  correct mlp dens var names (sineeli, Nov 18, 2024)
7a70e16  use default norm mean and std as per hugging face config (sineeli, Nov 18, 2024)
81e3021  correct position_ids (sineeli, Nov 19, 2024)
d3061d6  remove separate token layer (sineeli, Nov 19, 2024)
618e163  correct position ids (sineeli, Nov 19, 2024)
2338637  Checkpoint conversion script and minor changes (sineeli, Nov 21, 2024)
95e5868  correct flag type (sineeli, Nov 21, 2024)
9d2e5bd  correct key name (sineeli, Nov 21, 2024)
ac7d1d3  use flat list later we can extract in between layers if needed (sineeli, Nov 21, 2024)
8065c01  Add test cases and correct dtype polciy for model (sineeli, Nov 21, 2024)
a8be824  add proper docstrings (sineeli, Nov 21, 2024)
3f027a0  correct test cases (sineeli, Nov 22, 2024)
05acb70  use numpy for test data (sineeli, Nov 25, 2024)
521df6f  nit (sineeli, Nov 25, 2024)
ae2b800  nit (sineeli, Nov 27, 2024)
26c2224  Merge branch 'master' into sineeli/ViT (sineeli, Dec 2, 2024)
92149d5  add presets (sineeli, Dec 2, 2024)
1 change: 1 addition & 0 deletions keras_hub/api/layers/__init__.py
@@ -63,6 +63,7 @@
SegFormerImageConverter,
)
from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
from keras_hub.src.models.whisper.whisper_audio_converter import (
WhisperAudioConverter,
)
5 changes: 5 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -330,6 +330,11 @@
from keras_hub.src.models.vgg.vgg_image_classifier_preprocessor import (
VGGImageClassifierPreprocessor,
)
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
ViTImageClassifierPreprocessor,
)
from keras_hub.src.models.vit_det.vit_det_backbone import ViTDetBackbone
from keras_hub.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_hub.src.models.whisper.whisper_tokenizer import WhisperTokenizer
9 changes: 7 additions & 2 deletions keras_hub/src/models/image_classifier.py
@@ -117,10 +117,12 @@ def __init__(
dtype=head_dtype,
name="pooler",
)
elif pooling == "token":
self.pooler = None
else:
raise ValueError(
"Unknown `pooling` type. Polling should be either `'avg'` or "
f"`'max'`. Received: pooling={pooling}."
f"`'max' or 'token'`. Received: pooling={pooling}."
)
self.output_dropout = keras.layers.Dropout(
dropout,
@@ -137,7 +139,10 @@ def __init__(
# === Functional Model ===
inputs = self.backbone.input
x = self.backbone(inputs)
x = self.pooler(x)
if pooling == "token": # used for Vision Transformer (ViT)
Review comment (Member): "token" feels like a bit of a weird name here, especially when compared to "avg" or "max". Maybe "first"?

x = x[:, 0]
else:
x = self.pooler(x)
x = self.output_dropout(x)
outputs = self.output_dense(x)
super().__init__(
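For readers skimming the diff, a minimal sketch of what the new "token" branch computes: it keeps the embedding of the first sequence position (the ViT class token) instead of pooling over all positions. The shapes below are illustrative assumptions, not values taken from this PR.

import numpy as np
from keras import ops

# Assumed ViT backbone output: (batch=2, sequence=50, hidden_dim=48),
# where sequence position 0 holds the prepended class token.
features = np.ones((2, 50, 48), dtype="float32")

# "token" pooling (the new branch above): keep only the class-token embedding.
class_token = features[:, 0]  # shape (2, 48)

# For comparison, average pooling over the whole sequence axis:
avg_pooled = ops.mean(features, axis=1)  # shape (2, 48)

print(class_token.shape, avg_pooled.shape)
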
5 changes: 5 additions & 0 deletions keras_hub/src/models/vit/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, ViTBackbone)
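Registering the presets wires ViTBackbone into the standard `from_preset` workflow. A hedged sketch of what that enables; the preset handle below is purely hypothetical, since `vit_presets.py` is not part of this diff view.

import keras_hub

# "vit_base_patch16_224_hypothetical" is a placeholder name for illustration;
# the real handles are defined in keras_hub/src/models/vit/vit_presets.py.
backbone = keras_hub.models.ViTBackbone.from_preset(
    "vit_base_patch16_224_hypothetical"
)
classifier = keras_hub.models.ViTImageClassifier.from_preset(
    "vit_base_patch16_224_hypothetical", num_classes=1000
)
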
152 changes: 152 additions & 0 deletions keras_hub/src/models/vit/vit_backbone.py
@@ -0,0 +1,152 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.vit.vit_layers import ViTEncoder
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.ViTBackbone")
class ViTBackbone(Backbone):
"""Vision Transformer (ViT) backbone.

This backbone implements the Vision Transformer architecture as described in
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929).
It transforms the input image into a sequence of patches, embeds them, and
then processes them through a series of Transformer encoder layers.

Args:
image_shape: A tuple or list of 3 integers representing the shape of the
input image `(height, width, channels)`; `height` and `width` must
be equal.
patch_size: int. The size of each image patch; the input image will be
divided into patches of shape `(patch_size, patch_size)`.
num_layers: int. The number of Transformer encoder layers.
num_heads: int. The number of attention heads in each
Transformer encoder layer.
hidden_dim: int. The dimensionality of the hidden representations.
mlp_dim: int. The dimensionality of the intermediate MLP layer in
each Transformer encoder layer.
dropout_rate: float. The dropout rate for the Transformer encoder
layers.
attention_dropout: float. The dropout rate for the attention mechanism
in each Transformer encoder layer.
layer_norm_epsilon: float. Value used for numerical stability in
layer normalization.
use_mha_bias: bool. Whether to use bias in the multi-head
attention layers.
use_mlp_bias: bool. Whether to use bias in the MLP layers.
data_format: str. `"channels_last"` or `"channels_first"`, specifying
the data format for the input image. If `None`, defaults to
`"channels_last"`.
dtype: The dtype of the layer weights. Defaults to None.
**kwargs: Additional keyword arguments to be passed to the parent
`Backbone` class.
"""

def __init__(
self,
image_shape,
patch_size,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout_rate=0.0,
attention_dropout=0.0,
layer_norm_epsilon=1e-6,
use_mha_bias=True,
Review comment (Collaborator): these args are missing in get_config

use_mlp_bias=True,
data_format=None,
dtype=None,
**kwargs,
):
# === Layers ===
data_format = standardize_data_format(data_format)
Review comment (Collaborator): add comment === Layers ===

h_axis, w_axis, channels_axis = (
(-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
)
# Check that the input image is well specified.
if image_shape[h_axis] is None or image_shape[w_axis] is None:
raise ValueError(
f"Image shape must have defined height and width. Found `None` "
f"at index {h_axis} (height) or {w_axis} (width). "
f"Image shape: {image_shape}"
)
if image_shape[h_axis] != image_shape[w_axis]:
raise ValueError(
f"Image height and width must be equal. Found height: "
f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at "
f"indices {h_axis} and {w_axis} respectively. Image shape: "
f"{image_shape}"
)

num_channels = image_shape[channels_axis]

# === Functional Model ===
inputs = keras.layers.Input(shape=image_shape)

x = ViTPatchingAndEmbedding(
image_size=image_shape[h_axis],
patch_size=patch_size,
hidden_dim=hidden_dim,
num_channels=num_channels,
data_format=data_format,
dtype=dtype,
name="vit_patching_and_embedding",
)(inputs)

output = ViTEncoder(
num_layers=num_layers,
num_heads=num_heads,
hidden_dim=hidden_dim,
mlp_dim=mlp_dim,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
layer_norm_epsilon=layer_norm_epsilon,
use_mha_bias=use_mha_bias,
use_mlp_bias=use_mlp_bias,
dtype=dtype,
name="vit_encoder",
)(x)

super().__init__(
inputs=inputs,
outputs=output,
dtype=dtype,
**kwargs,
)

# === Config ===
self.image_shape = image_shape
self.patch_size = patch_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.dropout_rate = dropout_rate
self.attention_dropout = attention_dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.use_mha_bias = use_mha_bias
self.use_mlp_bias = use_mlp_bias
self.data_format = data_format

def get_config(self):
config = super().get_config()
config.update(
{
"image_shape": self.image_shape,
"patch_size": self.patch_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"mlp_dim": self.mlp_dim,
"dropout_rate": self.dropout_rate,
"attention_dropout": self.attention_dropout,
"layer_norm_epsilon": self.layer_norm_epsilon,
"use_mha_bias": self.use_mha_bias,
"use_mlp_bias": self.use_mlp_bias,
}
)
return config
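For orientation, a minimal forward-pass sketch of the backbone defined above, using the same small hyperparameters as the unit test that follows (illustrative only, not a published preset configuration).

import numpy as np
from keras_hub.src.models.vit.vit_backbone import ViTBackbone

# Toy configuration mirroring the unit test below.
backbone = ViTBackbone(
    image_shape=(28, 28, 3),
    patch_size=4,
    num_layers=3,
    num_heads=6,
    hidden_dim=48,
    mlp_dim=48 * 4,
)

images = np.ones((2, 28, 28, 3), dtype="float32")
features = backbone(images)
# Sequence length = (28 / 4) ** 2 = 49 patches plus 1 class token = 50,
# so the output is (batch, 50, hidden_dim) = (2, 50, 48).
print(features.shape)
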
37 changes: 37 additions & 0 deletions keras_hub/src/models/vit/vit_backbone_test.py
@@ -0,0 +1,37 @@
import pytest
from keras import ops

from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.tests.test_case import TestCase


class ViTBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"image_shape": (28, 28, 3),
"patch_size": 4,
"num_layers": 3,
"hidden_dim": 48,
"num_heads": 6,
"mlp_dim": 48 * 4,
"use_mha_bias": True,
}
self.input_size = 28
self.input_data = ops.ones((2, self.input_size, self.input_size, 3))

def test_backbone_basics(self):
self.run_backbone_test(
cls=ViTBackbone,
init_kwargs={**self.init_kwargs},
input_data=self.input_data,
expected_output_shape=(2, 50, 48),
run_quantization_check=False,
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=ViTBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
15 changes: 15 additions & 0 deletions keras_hub/src/models/vit/vit_image_classifier.py
@@ -0,0 +1,15 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier import ImageClassifier
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
ViTImageClassifierPreprocessor,
)


@keras_hub_export("keras_hub.models.ViTImageClassifier")
class ViTImageClassifier(ImageClassifier):
backbone_cls = ViTBackbone
preprocessor_cls = ViTImageClassifierPreprocessor

def __init__(self, pooling="token", **kwargs):
super().__init__(pooling=pooling, **kwargs)
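The class body is intentionally thin: everything except the `pooling="token"` default comes from `ImageClassifier`. A minimal end-to-end sketch using the same toy configuration as the tests in this PR (no preprocessor attached, so raw arrays are fed straight to the backbone).

import numpy as np
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier

backbone = ViTBackbone(
    image_shape=(28, 28, 3),
    patch_size=4,
    num_layers=3,
    num_heads=6,
    hidden_dim=48,
    mlp_dim=48 * 4,
)
# pooling defaults to "token", so logits come from the class-token embedding.
classifier = ViTImageClassifier(backbone=backbone, num_classes=2)

logits = classifier.predict(np.ones((2, 28, 28, 3), dtype="float32"))
print(logits.shape)  # (2, 2)
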
12 changes: 12 additions & 0 deletions keras_hub/src/models/vit/vit_image_classifier_preprocessor.py
@@ -0,0 +1,12 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.image_classifier_preprocessor import (
ImageClassifierPreprocessor,
)
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter


@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor")
class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor):
backbone_cls = ViTBackbone
image_converter_cls = ViTImageConverter
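This preprocessor simply pairs the ViT backbone with its image converter; resizing and rescaling live in the converter. A hedged sketch of standalone preprocessing, reusing the converter arguments from the test below (the exact normalization applied by `ViTImageConverter` is defined elsewhere in the PR and not shown in this section).

import numpy as np
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
    ViTImageClassifierPreprocessor,
)
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter

converter = ViTImageConverter(image_size=(28, 28), scale=1 / 255.0)
preprocessor = ViTImageClassifierPreprocessor(image_converter=converter)

# Larger raw images are resized to (28, 28) and rescaled before the backbone.
raw_images = np.random.randint(0, 256, size=(2, 64, 64, 3)).astype("float32")
processed = preprocessor(raw_images)
print(processed.shape)  # expected (2, 28, 28, 3)
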
57 changes: 57 additions & 0 deletions keras_hub/src/models/vit/vit_image_classifier_test.py
@@ -0,0 +1,57 @@
import numpy as np
import pytest

from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import (
ViTImageClassifierPreprocessor,
)
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter
from keras_hub.src.tests.test_case import TestCase


class ViTImageClassifierTest(TestCase):
def setUp(self):
self.images = np.ones((2, 28, 28, 3))
self.labels = [0, 1]
self.backbone = ViTBackbone(
image_shape=(28, 28, 3),
patch_size=4,
num_layers=3,
num_heads=6,
hidden_dim=48,
mlp_dim=48 * 4,
)
image_converter = ViTImageConverter(
image_size=(28, 28),
scale=1 / 255.0,
)
preprocessor = ViTImageClassifierPreprocessor(
image_converter=image_converter
)
self.init_kwargs = {
"backbone": self.backbone,
"num_classes": 2,
"preprocessor": preprocessor,
}
self.train_data = (self.images, self.labels)

def test_classifier_basics(self):
self.run_task_test(
cls=ViTImageClassifier,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
expected_output_shape=(2, 2),
)

def test_head_dtype(self):
model = ViTImageClassifier(**self.init_kwargs, head_dtype="bfloat16")
self.assertEqual(model.output_dense.compute_dtype, "bfloat16")

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=ViTImageClassifier,
init_kwargs=self.init_kwargs,
input_data=self.images,
)