Ming Dai, Lingfeng Yang, Yihao Xu, Zhenhua Feng, Wankou Yang*
Southeast University
- 2024.09.26: Our work has been accepted by NeurIPS 2024.
- 2024.10.13: The code and model are released.
Visual grounding is a common vision task that involves grounding descriptive sentences to the corresponding regions of an image. Most existing methods use independent image-text encoding and apply complex hand-crafted modules or encoder-decoder architectures for modal interaction and query reasoning. However, their performance significantly drops when dealing with complex textual expressions. This is because the former paradigm only utilizes limited downstream data to fit the multi-modal feature fusion. Therefore, it is only effective when the textual expressions are relatively simple. In contrast, given the wide diversity of textual expressions and the uniqueness of downstream training data, the existing fusion module, which extracts multimodal content from a visual-linguistic context, has not been fully investigated. In this paper, we present a simple yet robust transformer-based framework, SimVG, for visual grounding. Specifically, we decouple visual-linguistic feature fusion from downstream tasks by leveraging existing multimodal pre-trained models and incorporating additional object tokens to facilitate deep integration of downstream and pre-training tasks. Furthermore, we design a dynamic weight-balance distillation method in the multi-branch synchronous learning process to enhance the representation capability of the simpler branch. This branch only consists of a lightweight MLP, which simplifies the structure and improves reasoning speed. Experiments on six widely used VG datasets, *i.e.*, RefCOCO/+/g, ReferIt, Flickr30K, and GRefCOCO, demonstrate the superiority of SimVG. Finally, the proposed method not only achieves improvements in efficiency and convergence speed but also attains new state-of-the-art performance on these benchmarks.
Requirements: CUDA 11.8, torch 2.0.0, torchvision 0.15.0.
pip install -r requirements.txt
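For reference, a full environment setup might look like the following. The conda workflow, the environment name simvg, and the PyTorch cu118 index URL are assumptions and not part of the original instructions; the version pins follow the requirements above.

# create and activate an isolated environment (name is illustrative)
conda create -n simvg python=3.10 -y
conda activate simvg
# install torch/torchvision built against CUDA 11.8
pip install torch==2.0.0 torchvision==0.15.0 --index-url https://download.pytorch.org/whl/cu118
# install the remaining Python dependencies
pip install -r requirements.txt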
Our code depends on parts of detrex and detectron2, so you need to install and compile them.
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
git clone https://github.com/IDEA-Research/detrex.git
cd detrex
git submodule init
git submodule update
pip install -e .
Then, back in the SimVG root directory, install the SimVG package in editable mode:
pip install -e .
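To quickly verify the installation, an import check such as the following should succeed (assuming the editable install exposes the package under the import name simvg):

python -c "import torch, detectron2, detrex, simvg; print(torch.__version__)"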
For the construction of the datasets (flickr30k, referit, refcoco/+/g), please refer to the SeqTR repository. For the construction of the grefcoco dataset, please refer to the GREC repository. In addition, we provide a converted version of the grefcoco dataset here.
The data structure should look like the following:
| -- data
    | -- annotations
        | -- flickr30k
        | -- referitgame-berkeley
        | -- refcoco-unc
        | -- refcocoplus-unc
        | -- refcocog-umd
        | -- refcocog-google
        | -- grefcoco
            | -- instances.json
        | -- mixed
    | -- images
        | -- mscoco
        | -- saiaprtc12
        | -- flickr30k
        | -- visual-genome
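A quick sanity check that your layout matches the tree above (paths are relative to the repository root; adjust if your data lives elsewhere):

ls data/annotations            # expect flickr30k, referitgame-berkeley, refcoco-unc, refcocoplus-unc, refcocog-umd, refcocog-google, grefcoco, mixed
ls data/annotations/grefcoco   # expect instances.json
ls data/images                 # expect mscoco, saiaprtc12, flickr30k, visual-genome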
SimVG utilizes the BEiT-3 model as both the backbone and the multi-modality fusion module. The pre-trained weights can be downloaded from this link. Additionally, you will need to download the tokenizer for BEiT-3.
First, create a directory for the pre-trained weights:
mkdir pretrain_weights
Place the BEiT checkpoints and tokenizer within this directory.
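For example, the directory might then contain files like the following; the exact filenames are assumptions based on the public BEiT-3 release, so use whichever checkpoint and tokenizer you downloaded:

pretrain_weights
├── beit3_base_patch16_224.pth   # BEiT-3 checkpoint (assumed filename)
└── beit3.spm                    # BEiT-3 sentencepiece tokenizer (assumed filename)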
The final directory structure of SimVG should resemble the following:
SimVG
├── configs
├── data
├── docs
├── pretrain_weights
├── simvg
└── tools
We train SimVG on a single RTX3090 GPU with 24 GB of memory. In the one-stage setting, training the decoder branch (DB) and distilling the token branch (TB) are both completed in a single training run:
python tools/train.py configs/simvg/single/ViT-base/[DATASET_NAME]/[DATASET_NAME]_onestage.py
[DATASET_NAME] is one of "flickr30k", "referit", "refcoco", "refcocoplus", "refcocog", and "refcocoggoogle".
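For example, one-stage training on refcoco (obtained by substituting refcoco for [DATASET_NAME] in the template above):

# trains the decoder branch and distills the token branch in a single run
python tools/train.py configs/simvg/single/ViT-base/refcoco/refcoco_onestage.py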
This setting can further improve model performance, but it requires more training time. First, train the decoder branch:
python tools/train.py configs/simvg/single/ViT-base/[DATASET_NAME]/[DATASET_NAME]_twostage_1.py
Then load the resulting weights to distill the token branch (TB):
python tools/train.py configs/simvg/single/ViT-base/[DATASET_NAME]/[DATASET_NAME]_twostage_2.py --load-from <pth/of/stage1>
You can either add the "--load-from" option or change the "load-from" setting in the config file.
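For example, the two-stage schedule on refcoco; the work_dirs checkpoint path below is illustrative, so point --load-from at wherever your stage-1 checkpoint was actually saved:

# stage 1: train the decoder branch
python tools/train.py configs/simvg/single/ViT-base/refcoco/refcoco_twostage_1.py
# stage 2: distill the token branch from the stage-1 weights
python tools/train.py configs/simvg/single/ViT-base/refcoco/refcoco_twostage_2.py --load-from work_dirs/refcoco_twostage_1/latest.pth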
We pre-train SimVG on 8 RTX3090 GPUs with 24 GB memory.
For pre-training on the mixed flickr30k/referit/refcoco/+/g datasets (174K images):
bash tools/dist_train.sh configs/simvg/mix/ViT-base/pretrian-mixed.py 8
For pre-training on the refcoco/+/g cocoall dataset (28K images):
bash tools/dist_train.sh configs/simvg/mix/ViT-base/pretrain-cocoall.py 8
We fine-tune SimVG on a single RTX3090 GPU with 24 GB of memory.
For fine-tuning with the model pre-trained on the mixed datasets (174K images):
python tools/train.py configs/simvg/mix/ViT-base/finetune_mix/noema#finetune#[[DATASET_NAME]].py --load-from <pth/of/pretrian-mixed>
For fine-tuning with the model pre-trained on the cocoall dataset (28K images):
python tools/train.py configs/simvg/mix/ViT-base/finetune_coco_all/noema#finetune#[[DATASET_NAME]].py --load-from <pth/of/pretrain-cocoall>
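For example, to fine-tune on refcoco from the mixed-pretrained checkpoint, assuming the per-dataset config follows the naming pattern above (the same substitution applies to the cocoall configs), and keeping the placeholder for your pre-trained checkpoint path:

python tools/train.py configs/simvg/mix/ViT-base/finetune_mix/noema#finetune#refcoco.py --load-from <pth/of/pretrian-mixed>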
Fine-tuning only further improves the decoder branch (DB) performance. If you want to use the simpler token branch (TB), use this mode to distill the model.
For distillation with the model pre-trained on the mixed datasets (174K images):
python tools/train.py configs/simvg/mix/ViT-base/two-stage_distill_mix/noema#finetune#[[DATASET_NAME]].py --load-from <pth/of/pretrian-mixed>
For distillation with the model pre-trained on the cocoall dataset (28K images):
python tools/train.py configs/simvg/mix/ViT-base/two-stage_distill_coco_all/noema#finetune#[[DATASET_NAME]].py --load-from <pth/of/pretrain-cocoall>
You can use the following command to test all types of models:
python tools/test.py [PATH_TO_CONFIG_FILE] --load-from [PATH_TO_CHECKPOINT_FILE]
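For example, evaluating a one-stage refcoco model; the checkpoint path below is illustrative, so point --load-from at your trained or downloaded weights:

python tools/test.py configs/simvg/single/ViT-base/refcoco/refcoco_onestage.py --load-from work_dirs/refcoco_onestage/latest.pth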
| RefCOCO | val | testA | testB | url |
|---|---|---|---|---|
| SimVG(ViT-L/32, DB) | 90.51 | 92.37 | 87.07 | model & log |
| SimVG(ViT-L/32, refcocoallpretrain, two-stage distillation, TB) | 92.99 | 94.86 | 90.12 | model & log |
| RefCOCO+ | val | testA | testB | url |
|---|---|---|---|---|
| SimVG(ViT-L/32, DB) | 84.88 | 88.50 | 78.66 | model & log |
| SimVG(ViT-L/32, refcocoallpretrain, two-stage distillation, TB) | 87.43 | 91.02 | 82.10 | model & log |
| RefCOCOg | val-g | url | val-u | test-u | url |
|---|---|---|---|---|---|
| SimVG(ViT-L/32, DB) | 80.42 | - | 85.72 | 86.70 | model & log |
| SimVG(ViT-L/32, refcocoallpretrain, two-stage distillation, TB) | - | - | 87.99 | 89.15 | model & log |
| Model | flickr30k | url | referit | url |
|---|---|---|---|---|
| SimVG(ViT-L/32, DB) | 78.75 | model & log | 83.15 | model & log |
This codebase is partially based on SeqTR and BEiT-3.
@misc{simvg,
title={SimVG: A Simple Framework for Visual Grounding with Decoupled Multi-modal Fusion},
author={Ming Dai and Lingfeng Yang and Yihao Xu and Zhenhua Feng and Wankou Yang},
year={2024},
eprint={2409.17531},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2409.17531},
}