[ViT] Vision Transformer (ViT) backbone, layers, and image classifier #1989
base: master
Conversation
amazing work! one nit comment.
Also, can you please add a demo colab? and a colab to show the numerics is verified. Basically the validate output block from your conversion script.
        dtype=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)
Add a section comment here: === Layers ===
        dropout_rate=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-6,
        use_mha_bias=True,
These args are missing from get_config.
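To illustrate the fix the reviewer is asking for: every `__init__` argument should be serialized in `get_config` so the layer round-trips through `from_config`. This is a self-contained sketch, not the actual keras-hub code; the `Layer` stand-in replaces `keras.layers.Layer` to keep the example runnable, and the class name is illustrative.

```python
class Layer:  # minimal stand-in for keras.layers.Layer, for illustration only
    def __init__(self, **kwargs):
        pass

    def get_config(self):
        return {}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer from its serialized config.
        return cls(**config)


class ViTEncoderBlock(Layer):  # hypothetical name, mirrors the args in the diff
    def __init__(
        self,
        dropout_rate=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-6,
        use_mha_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_mha_bias = use_mha_bias

    def get_config(self):
        # Include every constructor argument so that
        # from_config(get_config()) reproduces the layer exactly.
        config = super().get_config()
        config.update(
            {
                "dropout_rate": self.dropout_rate,
                "attention_dropout": self.attention_dropout,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "use_mha_bias": self.use_mha_bias,
            }
        )
        return config
```

With this in place, saving and reloading a model no longer silently resets these four arguments to their defaults.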
Weights-transfer Colab gist for the 4 variants: https://colab.research.google.com/gist/sineeli/10a7884bef6114eade3b237b63d7f2bd/-keras-hub-vit-weights-transfer.ipynb
This looks great! Very nice work. Just a couple comments.
"image resolution of 384x384 " | ||
), | ||
"params": 86090496, | ||
"official_name": "ViT", |
We no longer need the official name or model card; we've reduced what we show on keras.io to make this simpler. Our Kaggle page will act as the new model card.
)


def convert_weights(keras_hub_model, hf_model):
Could we write this as an in-library converter? It seems very doable, and then we'd expose this to anyone wanting to convert a ViT checkpoint.
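The core of such an in-library converter is a name-mapping pass from Hugging Face ViT weight keys to keras-hub variable names. The sketch below shows only that mapping idea; the key patterns and target names are hypothetical examples, not the actual keras-hub converter API or its real weight names.

```python
import re


def map_hf_key(hf_key):
    """Translate one HF ViT weight key to a hypothetical keras-hub name.

    Returns None for keys this sketch does not cover.
    """
    # (HF key pattern, keras-hub-style replacement) pairs; illustrative only.
    rules = [
        (r"^vit\.embeddings\.cls_token$", "class_token"),
        (r"^vit\.embeddings\.position_embeddings$", "position_embedding"),
        (
            r"^vit\.encoder\.layer\.(\d+)\.attention\.attention\.query\.weight$",
            r"transformer_layer_\1/attention/query/kernel",
        ),
    ]
    for pattern, replacement in rules:
        if re.match(pattern, hf_key):
            return re.sub(pattern, replacement, hf_key)
    return None  # unmapped keys would be handled by further rules
```

A full converter would extend the rule table to cover every weight, transpose kernels where the frameworks disagree on layout, and then assign each mapped tensor onto the keras-hub model.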
@@ -137,7 +139,10 @@ def __init__(
        # === Functional Model ===
        inputs = self.backbone.input
        x = self.backbone(inputs)
        x = self.pooler(x)
        if pooling == "token":  # used for Vision Transformer (ViT)
"token" feels like a bit a weird name here, especially when compared to "avg"
or "max"
. Maybe "first"
?
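For context on the naming question, this is the dispatch being discussed: "token" (or "first", as suggested) selects the CLS token at sequence position 0, while "avg" and "max" reduce over all patch features. A minimal NumPy sketch, standing in for the backbone's tensors:

```python
import numpy as np


def pool_features(x, pooling):
    """x: (batch, sequence, hidden) features, CLS token at index 0."""
    if pooling == "token":  # a.k.a. "first": take the CLS token
        return x[:, 0, :]
    if pooling == "avg":  # mean over the sequence axis
        return x.mean(axis=1)
    if pooling == "max":  # max over the sequence axis
        return x.max(axis=1)
    raise ValueError(f"Unknown pooling: {pooling!r}")
```

Seen this way, "first" does read more consistently alongside "avg" and "max", since all three then describe the reduction rather than what the reduced element represents.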
This PR introduces a Vision Transformer (ViT) implementation