Global Context Vision Transformer (GC ViT)

This repository is the official PyTorch implementation of Global Context Vision Transformers.

Global Context Vision Transformers
Ali Hatamizadeh, Hongxu (Danny) Yin, Jan Kautz, and Pavlo Molchanov.

GC ViT achieves state-of-the-art results across image classification, object detection and semantic segmentation tasks. On the ImageNet-1K classification dataset, the tiny, small, base and large variants of GC ViT, with 28M, 51M, 90M and 201M parameters, achieve 83.4%, 83.9%, 84.5% and 84.8% Top-1 accuracy, respectively, surpassing comparably-sized prior art such as the CNN-based ConvNeXt and the ViT-based Swin Transformer by a large margin. Pre-trained GC ViT backbones also consistently outperform prior work on the downstream tasks of object detection, instance segmentation, and semantic segmentation on the MS COCO and ADE20K datasets, sometimes by large margins.


The overall architecture of GC ViT is illustrated below:

[GC ViT architecture diagram]

Updates

01/12/2023

  1. Updated pre-trained weights for the GC ViT Large model. Please see below for the download link.

  2. ImageNet-v2 benchmarks released for all GC ViT models.

11/15/2022

  1. Pre-trained weights for the GC ViT Large model are released. Please see below for the download link.

08/28/2022

  1. GC ViT and its pre-trained weights are now available as part of the timm library.

08/11/2022

  1. New pre-trained model weights with improved performance have been released. Please see below for the download link.
  2. The GC ViT model has been updated with an enhanced global query generator.

06/23/2022

  1. Pre-trained model weights released. Please see below for the download link.

06/17/2022

  1. GC ViT model, training and validation scripts released for ImageNet-1K classification.
  2. Pre-trained model checkpoints will be released soon.

Introduction

GC ViT leverages global context self-attention modules, in conjunction with local self-attention, to model both long- and short-range spatial interactions effectively and efficiently, without the need for expensive operations such as computing attention masks or shifting local windows.

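For intuition, the following is a minimal, self-contained PyTorch sketch of the idea (not the official module): a set of global query tokens is derived from the whole feature map and attends to keys and values taken from each local window. In the actual model, the global query generator is a learned CNN-based downsampler; the simple average pooling used here is only a stand-in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGlobalContextAttention(nn.Module):
    """Toy sketch: global queries (pooled from the full feature map) attend to
    local keys/values inside each non-overlapping window."""
    def __init__(self, dim, window_size, num_heads=4):
        super().__init__()
        self.ws = window_size
        self.nh = num_heads
        self.hd = dim // num_heads
        self.scale = self.hd ** -0.5
        self.to_q = nn.Linear(dim, dim, bias=False)      # projects global tokens to queries
        self.to_kv = nn.Linear(dim, dim * 2, bias=False) # projects window tokens to keys/values
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (B, H, W, C), H and W divisible by window_size
        B, H, W, C = x.shape
        ws, nh, hd = self.ws, self.nh, self.hd
        # Global queries: pool the entire feature map down to one window's resolution.
        qg = F.adaptive_avg_pool2d(x.permute(0, 3, 1, 2), ws).permute(0, 2, 3, 1)
        q = self.to_q(qg).reshape(B, ws * ws, nh, hd)
        # Local keys/values: partition x into non-overlapping windows.
        xw = x.reshape(B, H // ws, ws, W // ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        xw = xw.reshape(-1, ws * ws, C)                  # (B * num_windows, ws*ws, C)
        k, v = self.to_kv(xw).reshape(-1, ws * ws, 2, nh, hd).unbind(2)
        # Broadcast the shared global queries to every window of the same image.
        nW = xw.shape[0] // B
        q = q.unsqueeze(1).expand(-1, nW, -1, -1, -1).reshape(-1, ws * ws, nh, hd)
        attn = (q.transpose(1, 2) @ k.transpose(1, 2).transpose(-2, -1)) * self.scale
        out = (attn.softmax(-1) @ v.transpose(1, 2)).transpose(1, 2).reshape(-1, ws * ws, C)
        out = self.proj(out)
        # Merge windows back into the (B, H, W, C) layout.
        out = out.reshape(B, H // ws, W // ws, ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        return out.reshape(B, H, W, C)

# Example: a 56x56 feature map with 96 channels and 7x7 windows.
x = torch.randn(2, 56, 56, 96)
y = ToyGlobalContextAttention(dim=96, window_size=7, num_heads=4)(x)
print(y.shape)  # torch.Size([2, 56, 56, 96])
```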

ImageNet Benchmarks

ImageNet-1K Pretrained Models

| Name | Acc@1 | Acc@5 | Resolution | #Params (M) | FLOPs (G) | Summary | Download |
|------|-------|-------|------------|-------------|-----------|---------|----------|
| GC ViT-XXT | 79.8 | 95.1 | 224x224 | 12 | 2.1 | summary | model |
| GC ViT-XT | 82.0 | 96.0 | 224x224 | 20 | 2.6 | summary | model |
| GC ViT-T | 83.4 | 96.4 | 224x224 | 28 | 4.7 | summary | model |
| GC ViT-S | 83.9 | 96.6 | 224x224 | 51 | 8.5 | summary | model |
| GC ViT-B | 84.5 | 96.8 | 224x224 | 90 | 14.8 | summary | model |
| GC ViT-L | 84.8 | 97.1 | 224x224 | 201 | 32.6 | summary | model |
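
As noted in the Updates section, these checkpoints are also distributed through the timm library. Below is a minimal sketch of loading one that way; the model name 'gcvit_tiny' is an assumption about timm's registry, and timm may not include every variant listed above.

```python
import timm
import torch

# Load a pre-trained GC ViT-Tiny from timm (downloads ImageNet-1K weights).
model = timm.create_model('gcvit_tiny', pretrained=True)
model.eval()

# Dummy 224x224 input; a real image should be resized and normalized with
# the transform timm provides for this model.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])
```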

ImageNet-v2 Benchmarks

| Name | Acc@1 (ImageNet-v2) | Acc@1 (ImageNet-1K) | Resolution |
|------|---------------------|---------------------|------------|
| GC ViT-XXT | 69.3 | 79.8 | 224x224 |
| GC ViT-XT | 71.3 | 82.0 | 224x224 |
| GC ViT-T | 73.1 | 83.4 | 224x224 |
| GC ViT-S | 73.8 | 83.9 | 224x224 |
| GC ViT-B | 74.4 | 84.5 | 224x224 |
| GC ViT-L | 74.9 | 84.8 | 224x224 |

Installation

This repository is compatible with the NVIDIA PyTorch docker container (nvcr >= 21.06), which can be obtained from the NVIDIA NGC catalog.

The dependencies can be installed by running:

pip install -r requirements.txt
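
For example, one way to set up the environment is to pull the NGC container and install the dependencies inside it (the container tag below is illustrative; any release >= 21.06 should work):

docker pull nvcr.io/nvidia/pytorch:21.06-py3
docker run --gpus all -it --rm -v <imagenet-path>:/data nvcr.io/nvidia/pytorch:21.06-py3
pip install -r requirements.txt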

Data Preparation

Please download the ImageNet dataset from its official website. The training and validation directories need to contain a sub-folder for each class, with the following structure:

  imagenet
  ├── train
  │   ├── class1
  │   │   ├── img1.jpeg
  │   │   ├── img2.jpeg
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.jpeg
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.jpeg
      │   ├── img5.jpeg
      │   └── ...
      ├── class2
      │   ├── img6.jpeg
      │   └── ...
      └── ...
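
As a quick sanity check, the layout above can be verified with torchvision's ImageFolder (a sketch; adjust the root path to wherever the dataset was extracted):

```python
from torchvision import datasets

# Both splits should expose the same 1000 class folders.
train_set = datasets.ImageFolder('imagenet/train')
val_set = datasets.ImageFolder('imagenet/val')
print(f'train: {len(train_set)} images across {len(train_set.classes)} classes')
print(f'val:   {len(val_set)} images across {len(val_set.classes)} classes')
```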
 

Commands

Training on ImageNet-1K From Scratch (Multi-GPU)

The GC ViT model can be trained from scratch on the ImageNet-1K dataset by running:

python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 11223 train.py \
--config <config-file> --data_dir <imagenet-path> --batch-size <batch-size-per-gpu> --amp --tag <run-tag> --model-ema
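
For instance, training GC ViT-Tiny on 8 GPUs might look like the following (the config file name is illustrative; use the corresponding file shipped in this repository's configs directory):

python -m torch.distributed.launch --nproc_per_node 8 --master_port 11223 train.py \
--config configs/gc_vit_tiny.yml --data_dir /data/imagenet --batch-size 128 --amp --tag gcvit_tiny_scratch --model-ema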

To resume training from a pre-trained checkpoint:

python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 11223 train.py \
--resume <checkpoint-path> --config <config-file> --amp --data_dir <imagenet-path> --batch-size <batch-size-per-gpu> --tag <run-tag> --model-ema

Evaluation

To evaluate a pre-trained checkpoint using the ImageNet-1K validation set on a single GPU:

python validate.py --model <model-name> --checkpoint <checkpoint-path> --data_dir <imagenet-path> --batch-size <batch-size-per-gpu>
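
For example, evaluating a downloaded GC ViT-Tiny checkpoint (the model name and checkpoint filename below are illustrative; use the name under which the model is registered in this repository and the path to your downloaded weights):

python validate.py --model gc_vit_tiny --checkpoint gcvit_tiny_best.pth.tar --data_dir /data/imagenet --batch-size 64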

Citation

Please consider citing the GC ViT paper if it is useful for your work:

@article{hatamizadeh2022global,
  title={Global Context Vision Transformers},
  author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
  journal={arXiv preprint arXiv:2206.09959},
  year={2022}
}

Third-party Implementations and Resources

In this section, we list third-party contributions by other users. If you would like to have your work included here, please raise an issue in this repository.

| Name | Link | Contributor | Note |
|------|------|-------------|------|
| timm | Link | @rwightman | PyTorch |
| tfgcvit | Link | @shkarupa-alex | TensorFlow 2.0 (Keras) |
| gcvit-tf | Link | @awsaf49 | TensorFlow 2.0 (Keras) |
| GCViT-TensorFlow | Link | @EMalagoli92 | TensorFlow 2.0 (Keras) |
| keras_cv_attention_models | Link | @leondgarse | Keras |
| Paper Explanation | Link | @awsaf49 | Annotated GC ViT |
| Colab Notebook | Link | @awsaf49 | Flower classification |
| Kaggle Notebook | Link | @awsaf49 | Flower classification |
| Live Demo | Link | @awsaf49 | Hugging Face demo |

Acknowledgement

  • This repository is built upon the timm library.

  • We would like to sincerely thank the community, especially GitHub users @rwightman, @shkarupa-alex, @awsaf49, and @leondgarse, who have provided insightful feedback that has helped us further improve GC ViT and achieve even better benchmarks.

Licenses

Copyright © 2022, NVIDIA Corporation. All rights reserved.

This work is made available under the NVIDIA Source Code License-NC. Please refer to the license file in this repository for a copy of this license.

The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.

For license information regarding the timm repository, please refer to its official website.

For license information regarding the ImageNet dataset, please refer to its official website.
