[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989
Open
sineeli wants to merge 30 commits into keras-team:master from sineeli:sineeli/ViT

30 commits
741b889 vit base (sineeli)
13dae08 Add vit backbone, classifier and preprocessor layers (sineeli)
b64b137 update args (sineeli)
429d635 add default args (sineeli)
6d69abc correct build method (sineeli)
2e87884 fix build issues (sineeli)
bd3cce0 fix bugs (sineeli)
4232a06 Update backbone args and configs (sineeli)
32b08c5 correct position ids dtype (sineeli)
cc938c6 build token layer (sineeli)
78812de token layer build (sineeli)
8a20465 assign correct dtype to TokenLayer (sineeli)
de754cc fix build shape of token layer (sineeli)
84ba896 correct mlp dens var names (sineeli)
7a70e16 use default norm mean and std as per hugging face config (sineeli)
81e3021 correct position_ids (sineeli)
d3061d6 remove separate token layer (sineeli)
618e163 correct position ids (sineeli)
2338637 Checkpoint conversion script and minor changes (sineeli)
95e5868 correct flag type (sineeli)
9d2e5bd correct key name (sineeli)
ac7d1d3 use flat list later we can extract in between layers if needed (sineeli)
8065c01 Add test cases and correct dtype polciy for model (sineeli)
a8be824 add proper docstrings (sineeli)
3f027a0 correct test cases (sineeli)
05acb70 use numpy for test data (sineeli)
521df6f nit (sineeli)
ae2b800 nit (sineeli)
26c2224 Merge branch 'master' into sineeli/ViT (sineeli)
92149d5 add presets (sineeli)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from keras_hub.src.models.vit.vit_backbone import ViTBackbone | ||
from keras_hub.src.models.vit.vit_presets import backbone_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, ViTBackbone) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
from keras_hub.src.models.vit.vit_layers import ViTEncoder | ||
from keras_hub.src.models.vit.vit_layers import ViTPatchingAndEmbedding | ||
from keras_hub.src.utils.keras_utils import standardize_data_format | ||
|
||
|
||
@keras_hub_export("keras_hub.models.ViTBackbone") | ||
class ViTBackbone(Backbone): | ||
"""Vision Transformer (ViT) backbone. | ||
|
||
This backbone implements the Vision Transformer architecture as described in | ||
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). | ||
It transforms the input image into a sequence of patches, embeds them, and | ||
then processes them through a series of Transformer encoder layers. | ||
|
||
Args: | ||
image_shape: A tuple or list of 3 integers representing the shape of the | ||
input image `(height, width, channels)`. `height` and `width` must | ||
be equal. | ||
patch_size: int. The size of each image patch. The input image will be | ||
divided into patches of shape `(patch_size, patch_size)`. | ||
num_layers: int. The number of Transformer encoder layers. | ||
num_heads: int. The number of attention heads in each | ||
Transformer encoder layer. | ||
hidden_dim: int. The dimensionality of the hidden representations. | ||
mlp_dim: int. The dimensionality of the intermediate MLP layer in | ||
each Transformer encoder layer. | ||
dropout_rate: float. The dropout rate for the Transformer encoder | ||
layers. | ||
attention_dropout: float. The dropout rate for the attention mechanism | ||
in each Transformer encoder layer. | ||
layer_norm_epsilon: float. Value used for numerical stability in | ||
layer normalization. | ||
use_mha_bias: bool. Whether to use bias in the multi-head | ||
attention layers. | ||
use_mlp_bias: bool. Whether to use bias in the MLP layers. | ||
data_format: str. `"channels_last"` or `"channels_first"`, specifying | ||
the data format for the input image. If `None`, defaults to | ||
`"channels_last"`. | ||
dtype: The dtype of the layer weights. Defaults to None. | ||
**kwargs: Additional keyword arguments to be passed to the parent | ||
`Backbone` class. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
image_shape, | ||
patch_size, | ||
num_layers, | ||
num_heads, | ||
hidden_dim, | ||
mlp_dim, | ||
dropout_rate=0.0, | ||
attention_dropout=0.0, | ||
layer_norm_epsilon=1e-6, | ||
use_mha_bias=True, | ||
Review comment: these args are missing in |
||
use_mlp_bias=True, | ||
data_format=None, | ||
dtype=None, | ||
**kwargs, | ||
): | ||
# === Layers === | ||
data_format = standardize_data_format(data_format) | ||
Review comment: add comment === Layers === |
||
h_axis, w_axis, channels_axis = ( | ||
(-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3) | ||
) | ||
# Check that the input image is well specified. | ||
if image_shape[h_axis] is None or image_shape[w_axis] is None: | ||
raise ValueError( | ||
f"Image shape must have defined height and width. Found `None` " | ||
f"at index {h_axis} (height) or {w_axis} (width). " | ||
f"Image shape: {image_shape}" | ||
) | ||
if image_shape[h_axis] != image_shape[w_axis]: | ||
raise ValueError( | ||
f"Image height and width must be equal. Found height: " | ||
f"{image_shape[h_axis]}, width: {image_shape[w_axis]} at " | ||
f"indices {h_axis} and {w_axis} respectively. Image shape: " | ||
f"{image_shape}" | ||
) | ||
|
||
num_channels = image_shape[channels_axis] | ||
|
||
# === Functional Model === | ||
inputs = keras.layers.Input(shape=image_shape) | ||
|
||
x = ViTPatchingAndEmbedding( | ||
image_size=image_shape[h_axis], | ||
patch_size=patch_size, | ||
hidden_dim=hidden_dim, | ||
num_channels=num_channels, | ||
data_format=data_format, | ||
dtype=dtype, | ||
name="vit_patching_and_embedding", | ||
)(inputs) | ||
|
||
output = ViTEncoder( | ||
num_layers=num_layers, | ||
num_heads=num_heads, | ||
hidden_dim=hidden_dim, | ||
mlp_dim=mlp_dim, | ||
dropout_rate=dropout_rate, | ||
attention_dropout=attention_dropout, | ||
layer_norm_epsilon=layer_norm_epsilon, | ||
use_mha_bias=use_mha_bias, | ||
use_mlp_bias=use_mlp_bias, | ||
dtype=dtype, | ||
name="vit_encoder", | ||
)(x) | ||
|
||
super().__init__( | ||
inputs=inputs, | ||
outputs=output, | ||
dtype=dtype, | ||
**kwargs, | ||
) | ||
|
||
# === Config === | ||
self.image_shape = image_shape | ||
self.patch_size = patch_size | ||
self.num_layers = num_layers | ||
self.num_heads = num_heads | ||
self.hidden_dim = hidden_dim | ||
self.mlp_dim = mlp_dim | ||
self.dropout_rate = dropout_rate | ||
self.attention_dropout = attention_dropout | ||
self.layer_norm_epsilon = layer_norm_epsilon | ||
self.use_mha_bias = use_mha_bias | ||
self.use_mlp_bias = use_mlp_bias | ||
self.data_format = data_format | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"image_shape": self.image_shape, | ||
"patch_size": self.patch_size, | ||
"num_layers": self.num_layers, | ||
"num_heads": self.num_heads, | ||
"hidden_dim": self.hidden_dim, | ||
"mlp_dim": self.mlp_dim, | ||
"dropout_rate": self.dropout_rate, | ||
"attention_dropout": self.attention_dropout, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
"use_mha_bias": self.use_mha_bias, | ||
"use_mlp_bias": self.use_mlp_bias, | ||
} | ||
) | ||
return config |
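
The shape check in `__init__` can be exercised on its own. The following is a minimal sketch, not the PR's code: `check_image_shape` is a hypothetical helper that reproduces the axis selection and the two `ValueError` conditions from the diff above, assuming `data_format` has already been standardized to `"channels_last"` or `"channels_first"`.

```python
def check_image_shape(image_shape, data_format="channels_last"):
    """Hypothetical helper mirroring the validation in ViTBackbone.__init__."""
    # Negative indices pick height, width, and channels for either layout.
    h_axis, w_axis, channels_axis = (
        (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
    )
    height, width = image_shape[h_axis], image_shape[w_axis]
    if height is None or width is None:
        raise ValueError(
            f"Image shape must have defined height and width. "
            f"Image shape: {image_shape}"
        )
    if height != width:
        raise ValueError(
            f"Image height and width must be equal. Found height: {height}, "
            f"width: {width}. Image shape: {image_shape}"
        )
    # Return the square image size and the channel count.
    return height, image_shape[channels_axis]


print(check_image_shape((28, 28, 3)))                    # (28, 3)
print(check_image_shape((3, 28, 28), "channels_first"))  # (28, 3)
```

A non-square shape such as `(28, 32, 3)` raises `ValueError`, which matches the constraint stated in the docstring that `height` and `width` must be equal.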
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pytest | ||
from keras import ops | ||
|
||
from keras_hub.src.models.vit.vit_backbone import ViTBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class ViTBackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
"image_shape": (28, 28, 3), | ||
"patch_size": 4, | ||
"num_layers": 3, | ||
"hidden_dim": 48, | ||
"num_heads": 6, | ||
"mlp_dim": 48 * 4, | ||
"use_mha_bias": True, | ||
} | ||
self.input_size = 28 | ||
self.input_data = ops.ones((2, self.input_size, self.input_size, 3)) | ||
|
||
def test_backbone_basics(self): | ||
self.run_backbone_test( | ||
cls=ViTBackbone, | ||
init_kwargs={**self.init_kwargs}, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 50, 48), | ||
run_quantization_check=False, | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=ViTBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) |
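
The expected output shape `(2, 50, 48)` in `test_backbone_basics` can be checked by hand. This is a minimal sketch under two assumptions that the diff does not show directly: the image is split into non-overlapping square patches, and one extra class token is added to the patch sequence (inferred from the shape itself, since 7 x 7 patches give 49 tokens). `vit_sequence_length` is a hypothetical helper name.

```python
def vit_sequence_length(image_size, patch_size, num_class_tokens=1):
    """Hypothetical helper: number of tokens produced for a square image."""
    # Assumption: patches do not overlap, so the image size must divide evenly.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    # Assumption: class tokens are added on top of the patch tokens.
    return patches_per_side**2 + num_class_tokens


batch_size, hidden_dim = 2, 48
seq_len = vit_sequence_length(image_size=28, patch_size=4)
print((batch_size, seq_len, hidden_dim))  # (2, 50, 48)
```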
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.image_classifier import ImageClassifier | ||
from keras_hub.src.models.vit.vit_backbone import ViTBackbone | ||
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( | ||
ViTImageClassifierPreprocessor, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.ViTImageClassifier") | ||
class ViTImageClassifier(ImageClassifier): | ||
backbone_cls = ViTBackbone | ||
preprocessor_cls = ViTImageClassifierPreprocessor | ||
|
||
def __init__(self, pooling="token", **kwargs): | ||
super().__init__(pooling=pooling, **kwargs) |
12 changes: 12 additions & 0 deletions
keras_hub/src/models/vit/vit_image_classifier_preprocessor.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.image_classifier_preprocessor import ( | ||
ImageClassifierPreprocessor, | ||
) | ||
from keras_hub.src.models.vit.vit_backbone import ViTBackbone | ||
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter | ||
|
||
|
||
@keras_hub_export("keras_hub.models.ViTImageClassifierPreprocessor") | ||
class ViTImageClassifierPreprocessor(ImageClassifierPreprocessor): | ||
backbone_cls = ViTBackbone | ||
image_converter_cls = ViTImageConverter |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from keras_hub.src.models.vit.vit_backbone import ViTBackbone | ||
from keras_hub.src.models.vit.vit_image_classifier import ViTImageClassifier | ||
from keras_hub.src.models.vit.vit_image_classifier_preprocessor import ( | ||
ViTImageClassifierPreprocessor, | ||
) | ||
from keras_hub.src.models.vit.vit_image_converter import ViTImageConverter | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class ViTImageClassifierTest(TestCase): | ||
def setUp(self): | ||
self.images = np.ones((2, 28, 28, 3)) | ||
self.labels = [0, 1] | ||
self.backbone = ViTBackbone( | ||
image_shape=(28, 28, 3), | ||
patch_size=4, | ||
num_layers=3, | ||
num_heads=6, | ||
hidden_dim=48, | ||
mlp_dim=48 * 4, | ||
) | ||
image_converter = ViTImageConverter( | ||
image_size=(28, 28), | ||
scale=1 / 255.0, | ||
) | ||
preprocessor = ViTImageClassifierPreprocessor( | ||
image_converter=image_converter | ||
) | ||
self.init_kwargs = { | ||
"backbone": self.backbone, | ||
"num_classes": 2, | ||
"preprocessor": preprocessor, | ||
} | ||
self.train_data = (self.images, self.labels) | ||
|
||
def test_classifier_basics(self): | ||
self.run_task_test( | ||
cls=ViTImageClassifier, | ||
init_kwargs=self.init_kwargs, | ||
train_data=self.train_data, | ||
expected_output_shape=(2, 2), | ||
) | ||
|
||
def test_head_dtype(self): | ||
model = ViTImageClassifier(**self.init_kwargs, head_dtype="bfloat16") | ||
self.assertEqual(model.output_dense.compute_dtype, "bfloat16") | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=ViTImageClassifier, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.images, | ||
) |
Review comment: "token" feels like a bit of a weird name here, especially when compared to "avg" or "max". Maybe "first"?
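
To make the naming question concrete, here is a small sketch of the three pooling modes applied to a sequence of token vectors. It assumes that `"token"` means taking the first token in the sequence (the class token), which is what the suggested name `"first"` implies. `pool_tokens` is a hypothetical helper written in plain Python, not the `ImageClassifier` implementation.

```python
def pool_tokens(tokens, pooling="token"):
    """Hypothetical helper: reduce a list of token vectors to one vector."""
    if pooling == "token":
        # Assumption: the class token sits at index 0 of the sequence.
        return list(tokens[0])
    columns = list(zip(*tokens))  # group values by feature dimension
    if pooling == "avg":
        return [sum(col) / len(col) for col in columns]
    if pooling == "max":
        return [max(col) for col in columns]
    raise ValueError(f"Unknown pooling mode: {pooling!r}")


tokens = [[1.0, 4.0], [3.0, 0.0], [5.0, 2.0]]  # 3 tokens, 2 features each
print(pool_tokens(tokens, "token"))  # [1.0, 4.0]
print(pool_tokens(tokens, "avg"))    # [3.0, 2.0]
print(pool_tokens(tokens, "max"))    # [5.0, 4.0]
```

Under that assumption, `"avg"` and `"max"` describe a reduction over all tokens, while `"token"` (or `"first"`) describes a selection of a single position.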