
EVA: Object Detection & Instance Segmentation


We provide code for fine-tuning and single-scale evaluation on COCO and LVIS, starting from EVA pre-trained on Objects365. All model weights for object detection and instance segmentation are publicly available.

Setup

# recommended environment: torch1.9 + cuda11.1
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.6.1 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html  # for soft-nms

# build EVA det / Detectron2 from source
cd /path/to/EVA/det
python -m pip install -e .
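Before building, it can help to confirm that the installed versions match the recommended torch 1.9 + CUDA 11.1 environment and that the compiled soft-NMS op from mmcv-full imports. The sketch below is a hypothetical helper (`check_env` is not part of the repository); it only reports what it finds and does not fail when a package is missing.

```python
import importlib.util


def check_env() -> dict:
    """Report versions relevant to the recommended setup (hypothetical helper)."""
    report = {
        "torch": None,
        "cuda": None,
        "gpu_available": False,
        "detectron2": importlib.util.find_spec("detectron2") is not None,
        "mmcv_soft_nms": False,
    }
    try:
        import torch

        report["torch"] = torch.__version__       # expect 1.9.0+cu111
        report["cuda"] = torch.version.cuda       # expect 11.1
        report["gpu_available"] = torch.cuda.is_available()
    except Exception:
        pass  # torch missing or broken: leave the defaults in place
    try:
        from mmcv.ops import soft_nms  # noqa: F401  compiled op shipped with mmcv-full
        report["mmcv_soft_nms"] = True
    except Exception:
        pass
    return report


if __name__ == "__main__":
    for key, value in check_env().items():
        print(f"{key:>15}: {value}")
```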

Data preparation

Prepare the COCO 2017 and LVIS v1.0 datasets following the Detectron2 dataset guidelines.
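Detectron2 looks for datasets under the directory named by the `DETECTRON2_DATASETS` environment variable (default `./datasets`). LVIS v1.0 reuses the COCO 2017 images, so both benchmarks need the `coco/` image folders plus their own annotation files. The sketch below checks the layout this assumes, following Detectron2's builtin dataset conventions; `missing_paths` is a hypothetical helper, so verify the list against the Detectron2 guide for your version.

```python
import os
from pathlib import Path

# Paths Detectron2's builtin COCO 2017 and LVIS v1.0 datasets expect,
# relative to $DETECTRON2_DATASETS (assumed layout; see the Detectron2 guide).
REQUIRED = [
    "coco/annotations/instances_train2017.json",
    "coco/annotations/instances_val2017.json",
    "coco/train2017",
    "coco/val2017",
    "lvis/lvis_v1_train.json",
    "lvis/lvis_v1_val.json",
]


def missing_paths(root=None):
    """Return the required dataset paths that do not exist under root."""
    root = Path(root or os.environ.get("DETECTRON2_DATASETS", "datasets"))
    return [p for p in REQUIRED if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_paths()
    if missing:
        print("Missing:", *missing, sep="\n  ")
    else:
        print("All dataset paths found.")
```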

Prepare Objects365 pre-trained EVA weights

| model name | #param. | pre-training iterations on Objects365 | weight |
|:----------:|:-------:|:-------------------------------------:|:------:|
| eva_o365 | 1.1B | 380k | 🤗 HF link (4GB) |
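The 🤗 HF links point to files in the `BAAI/EVA` Hugging Face repository. The `/blob/` URLs given in the commands of this README open a web page; the raw file is served from the matching `/resolve/` URL. The stdlib-only sketch below builds that URL and downloads a checkpoint once; `hf_resolve_url` and `download` are hypothetical helper names, and the `huggingface_hub` package's `hf_hub_download` does the same job with caching.

```python
import urllib.request
from pathlib import Path


def hf_resolve_url(repo_id: str, filename: str, revision: str = "main") -> str:
    """Direct-download URL for a file in a Hugging Face model repository."""
    return f"https://huggingface.co/{repo_id}/resolve/{revision}/{filename}"


def download(repo_id: str, filename: str, dest_dir: str = ".") -> Path:
    """Fetch the file unless it already exists; write to a .part file first
    so an interrupted download of a ~4GB checkpoint is not mistaken for a
    complete one."""
    dest = Path(dest_dir) / filename
    if not dest.exists():
        tmp = dest.with_name(dest.name + ".part")
        urllib.request.urlretrieve(hf_resolve_url(repo_id, filename), tmp)
        tmp.replace(dest)
    return dest
```

For example, `download("BAAI/EVA", "eva_o365.pth", "/path/to")` would place the Objects365 weights at `/path/to/eva_o365.pth`.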

Models and results summary

EVA uses ViTDet with a Cascade Mask R-CNN head for object detection and instance segmentation. We evaluate EVA on the COCO 2017 and LVIS v1.0 benchmarks.

COCO 2017 (single-scale evaluation on val set)

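| init. model weight | batch size | iter | AP box | AP mask | config | model weight |
|:------------------:|:----------:|:----:|:------:|:-------:|:------:|:------------:|
| eva_o365 | 64 | 35k | 64.2 | 53.9 | config | 🤗 HF link (4GB) |
| eva_o365 | 64 | 45k | 63.9 | 55.0 | config | 🤗 HF link (4GB) |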

LVIS v1.0 (single-scale evaluation on val set)

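| init. model weight | batch size | iter | AP box | AP mask | config | model weight |
|:------------------:|:----------:|:----:|:------:|:-------:|:------:|:------------:|
| eva_o365 | 64 | 75k | 62.2 | 55.0 | config | 🤗 HF link (4GB) |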
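To put the fine-tuning schedules above in perspective, the number of passes over the training set is batch size * iterations / training images. The sketch assumes approximate split sizes of 118,287 images for COCO train2017 and 100,170 for LVIS v1.0 train; Detectron2 also drops images without annotations during training, so actual epoch counts may be slightly higher.

```python
# Approximate training-split sizes (assumption: full splits, before any
# filtering of images without annotations).
COCO_TRAIN2017_IMAGES = 118_287
LVIS_V1_TRAIN_IMAGES = 100_170


def epochs(total_batch_size: int, iters: int, num_images: int) -> float:
    """Passes over the training set implied by a batch size and iteration count."""
    return total_batch_size * iters / num_images


for name, iters, n in [
    ("COCO 35k", 35_000, COCO_TRAIN2017_IMAGES),
    ("COCO 45k", 45_000, COCO_TRAIN2017_IMAGES),
    ("LVIS 75k", 75_000, LVIS_V1_TRAIN_IMAGES),
]:
    print(f"{name}: ~{epochs(64, iters, n):.1f} epochs")
```

This prints roughly 18.9, 24.3 and 47.9 epochs for the three schedules.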

Evaluation

COCO 2017

Object Detection


To evaluate the EVA object detection checkpoint on COCO 2017 val using a single node with 8 GPUs:

# weights: https://huggingface.co/BAAI/EVA/blob/main/eva_coco_det.pth
python tools/lazyconfig_train_net.py --num-gpus 8 \
    --eval-only \
    --config-file projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_eva_1536.py \
    "train.init_checkpoint=/path/to/eva_coco_det.pth" \
    "model.roi_heads.use_soft_nms=True" \
    'model.roi_heads.method="linear"' \
    "model.roi_heads.iou_threshold=0.6" \
    "model.roi_heads.override_score_thresh=0.0"

Expected results:

Evaluation results for bbox:
|   AP   |  AP50  |  AP75  |  APs   |  APm   |  APl   |
|:------:|:------:|:------:|:------:|:------:|:------:|
| 64.164 | 81.897 | 70.561 | 49.485 | 68.088 | 77.651 |
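The evaluation command switches the RoI heads to linear Soft-NMS (`use_soft_nms=True`, `method="linear"`, `iou_threshold=0.6`). Linear Soft-NMS decays, rather than discards, the score of every box whose IoU with the currently selected box exceeds the threshold, multiplying it by (1 - IoU). The pure-Python sketch below illustrates that rule only; it is not the repository's implementation (the setup installs mmcv-full for its soft-NMS op), and `iou`, `soft_nms_linear` and the `min_score` cutoff are names chosen here for illustration.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def soft_nms_linear(boxes: Sequence[Box], scores: Sequence[float],
                    iou_threshold: float = 0.6,
                    min_score: float = 1e-3) -> List[Tuple[int, float]]:
    """Return (index, rescored score) pairs in the order boxes are selected."""
    remaining = dict(enumerate(scores))  # box index -> current score
    kept = []
    while remaining:
        best = max(remaining, key=remaining.get)  # highest current score
        kept.append((best, remaining.pop(best)))
        for j in list(remaining):
            overlap = iou(boxes[best], boxes[j])
            if overlap > iou_threshold:
                remaining[j] *= 1.0 - overlap  # linear decay instead of removal
            if remaining[j] < min_score:
                del remaining[j]  # drop boxes whose score has decayed away
    return kept
```

With hard NMS at the same threshold, a heavily overlapping second box would be removed outright; here it survives with a reduced score, which can recover true positives in crowded scenes.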

Instance Segmentation


To evaluate the EVA instance segmentation checkpoint on COCO 2017 val using a single node with 8 GPUs:

# weights: https://huggingface.co/BAAI/EVA/blob/main/eva_coco_seg.pth
python tools/lazyconfig_train_net.py --num-gpus 8 \
    --eval-only \
    --config-file projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_eva_1536.py \
    "train.init_checkpoint=/path/to/eva_coco_seg.pth" \
    "model.roi_heads.use_soft_nms=True" \
    'model.roi_heads.method="linear"' \
    "model.roi_heads.iou_threshold=0.6" \
    "model.roi_heads.override_score_thresh=0.0" \
    "model.roi_heads.maskness_thresh=0.5" # use maskness to calibrate mask predictions

Expected results:

Evaluation results for segm:
|   AP   |  AP50  |  AP75  |  APs   |  APm   |  APl   |
|:------:|:------:|:------:|:------:|:------:|:------:|
| 55.024 | 79.400 | 60.872 | 37.584 | 58.435 | 72.034 |
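The `maskness_thresh=0.5` override enables maskness-based rescoring of mask predictions. One common definition (used by SOLOv2, for example) takes the mean predicted foreground probability over the pixels above the threshold and multiplies it into the classification score, so detections with confident boxes but uncertain masks rank lower. The sketch assumes that definition; the exact formula in this repository may differ, and `maskness` and `rescore` are hypothetical helpers operating on a flattened list of per-pixel probabilities rather than tensors.

```python
from typing import Sequence


def maskness(mask_probs: Sequence[float], thresh: float = 0.5) -> float:
    """Mean predicted probability over the pixels that survive binarization."""
    fg = [p for p in mask_probs if p > thresh]
    return sum(fg) / len(fg) if fg else 0.0  # empty mask gets maskness 0


def rescore(cls_score: float, mask_probs: Sequence[float], thresh: float = 0.5) -> float:
    """Calibrate a detection score by the quality of its predicted mask."""
    return cls_score * maskness(mask_probs, thresh)
```

For instance, a detection scored 0.8 whose mask pixels have probabilities `[0.9, 0.7, 0.2, 0.6]` keeps three foreground pixels with mean of about 0.733, giving a calibrated score of about 0.587.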

LVIS v1.0 val


To evaluate EVA on LVIS v1.0 val using a single node with 8 GPUs:

# weights: https://huggingface.co/BAAI/EVA/blob/main/eva_lvis.pth
# max_dets_per_image raises the per-image detection cap from the LVIS default of 300
python tools/lazyconfig_train_net.py --num-gpus 8 \
    --eval-only \
    --config-file projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_eva_1536.py \
    "train.init_checkpoint=/path/to/eva_lvis.pth" \
    "dataloader.evaluator.max_dets_per_image=1000" \
    "model.roi_heads.maskness_thresh=0.5" # use maskness to calibrate mask predictions

Expected results:

# object detection
Evaluation results for bbox:
|   AP   |  AP50  |  AP75  |  APs   |  APm   |  APl   |  APr   |  APc   |  APf   |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| 62.169 | 76.198 | 65.364 | 54.086 | 71.103 | 77.228 | 55.149 | 62.242 | 65.172 |

# instance segmentation
Evaluation results for segm:
|   AP   |  AP50  |  AP75  |  APs   |  APm   |  APl   |  APr   |  APc   |  APf   |
|:------:|:------:|:------:|:------:|:------:|:------:|:------:|:------:|:------:|
| 54.982 | 74.214 | 60.114 | 44.894 | 65.657 | 72.792 | 48.329 | 55.478 | 57.352 |
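In the LVIS tables, APr, APc and APf are AP averaged over rare, common and frequent categories. LVIS assigns each category to a group by the number of training images it appears in: rare (1 to 10), common (11 to 100) and frequent (more than 100), and the v1.0 annotation files store the group in each category's `frequency` field. The sketch below shows the bucketing rule and counts categories per group from an annotation file; `frequency_group` and `count_groups` are hypothetical helpers.

```python
import json
from collections import Counter


def frequency_group(image_count: int) -> str:
    """Map a category's training-image count to its LVIS frequency group."""
    if image_count <= 10:
        return "r"  # rare
    if image_count <= 100:
        return "c"  # common
    return "f"      # frequent


def count_groups(annotation_file: str) -> Counter:
    """Count categories per group using the 'frequency' field of an LVIS json."""
    with open(annotation_file) as f:
        categories = json.load(f)["categories"]
    return Counter(cat["frequency"] for cat in categories)
```

For example, `count_groups("datasets/lvis/lvis_v1_val.json")` reports how many categories contribute to each of APr, APc and APf.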

Training

COCO 2017

To train EVA on COCO 2017 using 8 nodes with 8 GPUs each (64 GPUs, total batch size 64):

# init weights: https://huggingface.co/BAAI/EVA/blob/main/eva_o365.pth
# run on every node; NODE_RANK is 0..NNODES-1 and MASTER_ADDR is the rank-0 node
python tools/lazyconfig_train_net.py --num-gpus 8 \
    --num-machines $NNODES --machine-rank $NODE_RANK --dist-url "tcp://$MASTER_ADDR:60900" \
    --config-file projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_eva.py \
    "train.init_checkpoint=/path/to/eva_o365.pth" \
    "train.output_dir=/path/to/output"
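The per-GPU batch is the total batch size divided by the number of GPUs, so 64 images over 8 x 8 GPUs is 1 image per GPU; Detectron2's training data loader requires the total batch size to be divisible by the GPU count. The sketch below checks that split and builds the per-node argument list from the `NNODES`, `NODE_RANK` and `MASTER_ADDR` environment variables used above. `images_per_gpu`, `build_train_cmd` and `launch` are hypothetical helpers, not part of the repository.

```python
import shlex
import subprocess
from typing import List, Mapping


def images_per_gpu(total_batch_size: int, num_machines: int, gpus_per_machine: int) -> int:
    """Images each GPU processes per iteration; the split must be exact."""
    world_size = num_machines * gpus_per_machine
    if total_batch_size % world_size != 0:
        raise ValueError(
            f"total batch size {total_batch_size} is not divisible by {world_size} GPUs")
    return total_batch_size // world_size


def build_train_cmd(config_file: str, init_checkpoint: str, output_dir: str,
                    env: Mapping[str, str], gpus_per_machine: int = 8,
                    port: int = 60900) -> List[str]:
    """Argument list for one node, mirroring the training command above."""
    return [
        "python", "tools/lazyconfig_train_net.py",
        "--num-gpus", str(gpus_per_machine),
        "--num-machines", env["NNODES"],
        "--machine-rank", env["NODE_RANK"],
        "--dist-url", f"tcp://{env['MASTER_ADDR']}:{port}",
        "--config-file", config_file,
        f"train.init_checkpoint={init_checkpoint}",
        f"train.output_dir={output_dir}",
    ]


def launch(cmd: List[str]) -> None:
    print(shlex.join(cmd))  # log the exact command for reproducibility
    subprocess.run(cmd, check=True)
```

On each node one would call `launch(build_train_cmd("projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_eva.py", "/path/to/eva_o365.pth", "/path/to/output", os.environ))`; passing the environment explicitly keeps the command builder easy to test.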

LVIS v1.0

To train EVA on LVIS v1.0 using 8 nodes with 8 GPUs each (64 GPUs, total batch size 64):

# init weights: https://huggingface.co/BAAI/EVA/blob/main/eva_o365.pth
python tools/lazyconfig_train_net.py --num-gpus 8 \
    --num-machines $NNODES --machine-rank $NODE_RANK --dist-url "tcp://$MASTER_ADDR:60900" \
    --config-file projects/ViTDet/configs/LVIS/cascade_mask_rcnn_vitdet_eva.py \
    "train.init_checkpoint=/path/to/eva_o365.pth" \
    "train.output_dir=/path/to/output"

Acknowledgment

EVA object detection and instance segmentation are built upon Detectron2. Thanks for their awesome work!