Add imvoxelnet model for vic3d late fusion
haibao-yu committed Nov 25, 2022
1 parent 93be0e3 commit 8cc5bc8
Showing 6 changed files with 515 additions and 7 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -46,7 +46,7 @@ we open source the OpenDAIRV2X, towards serving VICAD research community.

It will directly support detectors of different modalities, including image-modality, pointcloud-modality and image-pointcloud fusion detectors.
It now supports the image-modality detector ImvoxelNet and the pointcloud-modality detector PointPillars.
- [ ] Image-modality
- [x] Image-modality
- [x] Pointcloud-modality
- [ ] Multi-modality

@@ -66,9 +66,9 @@ VIC3D Benchmark is as following:
| Modality | Fusion | Model | Dataset | AP-3D (IoU=0.5) | | | | AP-BEV (IoU=0.5) | | | | AB |
| :-------: | :-----: | :--------: | :-------: | :----: | :----: | :----: | :-----: | :-----: | :---: | :----: | :-----: | :----: |
| | | | | Overall | 0-30m | 30-50m | 50-100m | Overall | 0-30m | 30-50m | 50-100m | |
| Image | VehOnly | ImvoxelNet | VIC-Sync | | | | | | | | | |
| | InfOnly | ImvoxelNet | VIC-Sync | | | | | | | | | |
| | Late-Fusion | ImvoxelNet | VIC-Sync | | | | | | | | | |
| Image | VehOnly | ImvoxelNet | VIC-Sync | 9.13 | 19.06 | 5.23 | 0.41 | 10.96 | 21.93 | 7.28 | 0.78 | 0 |
| | InfOnly | ImvoxelNet | VIC-Sync | 14.02 | 20.56 | 8.89 | 10.57 | 22.10 | 27.33 | 17.45 | 18.92 | 309.38 |
| | Late-Fusion | ImvoxelNet | VIC-Sync | 18.77 | 33.47 | 9.43 | 8.62 | 24.85 | 39.49 | 14.68 | 14.96 | 309.38 |
|Pointcloud | VehOnly | PointPillars | VIC-Sync | 48.06 | 47.62 | 63.51 | 44.37 | 52.24 | 30.55 | 66.03 | 48.36 | 0 |
| | InfOnly | PointPillars | VIC-Sync | 17.58 | 23.00 | 13.96 | 9.17 | 27.26 | 29.07 | 23.92 | 26.64 | 478.61 |
| | Early Fusion | PointPillars | VIC-Sync | 62.61 | 64.82 | 68.68 | 56.57 | 68.91 | 68.92 | 73.64 | 65.66 | 1382275.75 |
125 changes: 124 additions & 1 deletion configs/vic3d/late-fusion-image/imvoxelnet/README.md
@@ -2,4 +2,127 @@

## Introduction

We implement ImvoxelNet and provide the results and checkpoints on VIC-Sync datasets with MMDetection3D.
We implement ImvoxelNet and provide the results and checkpoints on VIC-Sync datasets with MMDetection3D.

## Results and models

| Modality | Fusion | Model | Dataset | | AP-3D(IoU=0.5) | | | | AP-BEV(IoU=0.5) | | | AB(Byte) | Download |
| -------- | ----------- | ---------- | -------- | ------- | ------------- | ------ | ------- | ------- | --------------- | ------ | ------- | ------ | ------------------------------------------------------------ |
| | | | | Overall | 0-30m | 30-50m | 50-100m | Overall | 0-30m | 30-50m | 50-100m | | |
| Image | Veh Only | ImvoxelNet | VIC-Sync | 9.13 | 19.06 | 5.23 | 0.41 | 10.96 | 21.93 | 7.28 | 0.78 | 0 | [model_v](https://drive.google.com/file/d/1dNupazp9t2D6mN8cs1ER8zuf3j9ZHNd6/view?usp=sharing) |
| Image | Inf Only | ImvoxelNet | VIC-Sync | 14.02 | 20.56 | 8.89 | 10.57 | 22.10 | 27.33 | 17.45 | 18.92 | 309.38 | [model_i](https://drive.google.com/file/d/1F0QSlsGQhtMd3Q66CcXgQJKZptERYhhk/view?usp=sharing) |
| Image | Late Fusion | ImvoxelNet | VIC-Sync | 18.77 | 33.47 | 9.43 | 8.62 | 24.85 | 39.49 | 14.68 | 14.96 | 309.38 | |

## Training & Evaluation

### Data Preparation
#### Download data and organise as follows
```
# For DAIR-V2X-C Dataset located at ${DAIR-V2X-C_DATASET_ROOT}
└── cooperative-vehicle-infrastructure      <-- DAIR-V2X-C
    └──── infrastructure-side               <-- DAIR-V2X-C-I
        ├───── image
        ├───── velodyne
        ├───── calib
        ├───── label
        └───── data_info.json
    └──── vehicle-side                      <-- DAIR-V2X-C-V
        ├───── image
        ├───── velodyne
        ├───── calib
        ├───── label
        └───── data_info.json
    └──── cooperative
        ├───── label_world
        └───── data_info.json
```

#### Create a symlink to the dataset root
```
cd ${dair-v2x_root}/dair-v2x
mkdir ./data/DAIR-V2X
ln -s ${DAIR-V2X-C_DATASET_ROOT}/cooperative-vehicle-infrastructure ${dair-v2x_root}/dair-v2x/data/DAIR-V2X
```
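Before converting anything, it can help to confirm that the linked dataset matches the tree above. The following is a minimal sketch, not part of the repository: `EXPECTED` simply restates the listed entries, and `missing_entries` is a hypothetical helper name.

```python
from pathlib import Path

# Entries from the directory tree above, relative to
# cooperative-vehicle-infrastructure.
EXPECTED = {
    "infrastructure-side": ["image", "velodyne", "calib", "label", "data_info.json"],
    "vehicle-side": ["image", "velodyne", "calib", "label", "data_info.json"],
    "cooperative": ["label_world", "data_info.json"],
}


def missing_entries(root):
    """Return the expected relative paths that do not exist under `root`.

    Hypothetical helper for illustration only.
    """
    root = Path(root)
    return [
        str(Path(side) / name)
        for side, names in EXPECTED.items()
        for name in names
        if not (root / side / name).exists()
    ]
```

Calling `missing_entries("./data/DAIR-V2X/cooperative-vehicle-infrastructure")` from the repository root should return an empty list once the symlink is in place.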

#### Create KITTI-format data (optional, for model training)

Run the data creation in a GPU environment.
```commandline
# Kitti Format
cd ${dair-v2x_root}/dair-v2x
python tools/dataset_converter/dair2kitti.py --source-root ./data/DAIR-V2X/cooperative-vehicle-infrastructure/infrastructure-side \
--target-root ./data/DAIR-V2X/cooperative-vehicle-infrastructure/infrastructure-side \
--split-path ./data/split_datas/cooperative-split-data.json \
--label-type lidar --sensor-view infrastructure --no-classmerge
python tools/dataset_converter/dair2kitti.py --source-root ./data/DAIR-V2X/cooperative-vehicle-infrastructure/vehicle-side \
--target-root ./data/DAIR-V2X/cooperative-vehicle-infrastructure/vehicle-side \
--split-path ./data/split_datas/cooperative-split-data.json \
--label-type lidar --sensor-view vehicle --no-classmerge
```
In the end, the data and info files should be organized as follows:
```
└── cooperative-vehicle-infrastructure      <-- DAIR-V2X-C
    └──── infrastructure-side               <-- DAIR-V2X-C-I
        ├───── image
        ├───── velodyne
        ├───── calib
        ├───── label
        ├───── data_info.json
        ├───── ImageSets
        └──── training
            ├───── image_2
            ├───── velodyne
            ├───── label_2
            └───── calib
        └──── testing
    └──── vehicle-side                      <-- DAIR-V2X-C-V
        ├───── image
        ├───── velodyne
        ├───── calib
        ├───── label
        ├───── data_info.json
        ├───── ImageSets
        └──── training
            ├───── image_2
            ├───── velodyne
            ├───── label_2
            └───── calib
    └──── cooperative
        ├───── label_world
        └───── data_info.json
```

* VIC-Sync Dataset. The VIC-Sync dataset is extracted from DAIR-V2X-C and is composed of 9311 pairs of infrastructure and vehicle frames, with their cooperative annotations as ground truth.
We split the VIC-Sync dataset into train/valid/test parts at a ratio of 5:2:3.
Please refer to [split data](../../../data/split_datas/cooperative-split-data.json) for the split file.
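
As a rough illustration of what a 5:2:3 ratio means for 9311 pairs, the sketch below computes nominal part sizes. The rounding rule (leftover items go to the earliest parts) and the helper name `split_sizes` are assumptions made here; the split file itself remains the authoritative source for the actual frame assignment.

```python
def split_sizes(total, ratios):
    """Split `total` items by integer `ratios`.

    Leftovers from integer division are handed to the earliest parts,
    which is an assumed rule, not necessarily the one used for VIC-Sync.
    """
    denom = sum(ratios)
    sizes = [total * r // denom for r in ratios]
    for i in range(total - sum(sizes)):
        sizes[i % len(sizes)] += 1
    return sizes


train, valid, test = split_sizes(9311, (5, 2, 3))
print(train, valid, test)  # 4656 1862 2793 under this rounding rule
```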


### Training
* Implementation Framework.
We directly use MMDetection3D (v0.17.1) to train the infrastructure 3D detector and the vehicle 3D detector.
* Infrastructure detector training details.
Before training the detector, follow MMDetection3D to convert "./data/DAIR-V2X/cooperative-vehicle-infrastructure/infrastructure-side" into the required training format.
Then train ImvoxelNet with the config file [trainval_config_i.py](./trainval_config_i.py).

* Vehicle detector training details.
Before training the detector, follow MMDetection3D to convert "./data/DAIR-V2X/cooperative-vehicle-infrastructure/vehicle-side" into the required training format.
Then train ImvoxelNet with the config file [trainval_config_v.py](./trainval_config_v.py).

### Evaluation

Download the following checkpoints and place them in this directory.
* [vic3d_latefusion_inf_imvoxelnet](https://drive.google.com/file/d/1F0QSlsGQhtMd3Q66CcXgQJKZptERYhhk/view?usp=sharing)
* [vic3d_latefusion_veh_imvoxelnet](https://drive.google.com/file/d/1dNupazp9t2D6mN8cs1ER8zuf3j9ZHNd6/view?usp=sharing)

Then use the following commands to get the evaluation results.
```
# An example to get the late fusion evaluation results within [0, 100]m range on VIC-Sync dataset
# bash scripts/eval_camera_late_fusion_imvoxelnet.sh [YOUR_CUDA_DEVICE] [FUSION_METHOD] [DELAY_K] [EXTEND_RANGE_START] [EXTEND_RANGE_END] [TIME_COMPENSATION]
cd ${dair-v2x_root}/dair-v2x/v2x
bash scripts/eval_camera_late_fusion_imvoxelnet.sh 0 late_fusion 0 0 100 --no-comp
```
* FUSION_METHOD candidates: [veh_only, inf_only, late_fusion].
* DELAY_K candidates: [0, 1, 2]. 0 denotes the VIC-Sync dataset, 1 denotes the VIC-Async-1 dataset, and 2 denotes the VIC-Async-2 dataset.
* [EXTEND_RANGE_START, EXTEND_RANGE_END] candidates: [[0, 100], [0, 30], [30, 50], [50, 100]].
* TIME_COMPENSATION candidates: [, --no-comp]. Empty means that time compensation is used to alleviate the temporal asynchrony problem.
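
To fill every distance column of the results table, the script has to be run once per range. The sketch below only assembles the command lines in the argument order documented above; the helper name `eval_command` and its `time_comp` parameter are illustrative and not part of the repository.

```python
import shlex

SCRIPT = "scripts/eval_camera_late_fusion_imvoxelnet.sh"
RANGES = [(0, 100), (0, 30), (30, 50), (50, 100)]


def eval_command(device, fusion_method, delay_k, start, end, time_comp=False):
    """Build the argument list for one evaluation run (hypothetical helper).

    An empty TIME_COMPENSATION slot enables compensation, so `--no-comp`
    is appended only when `time_comp` is False.
    """
    args = ["bash", SCRIPT, str(device), fusion_method,
            str(delay_k), str(start), str(end)]
    if not time_comp:
        args.append("--no-comp")
    return args


# Print one late-fusion command per distance range on VIC-Sync (DELAY_K = 0).
for start, end in RANGES:
    print(shlex.join(eval_command(0, "late_fusion", 0, start, end)))
```

The printed lines are meant to be run from `${dair-v2x_root}/dair-v2x/v2x`, as in the example above.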
170 changes: 170 additions & 0 deletions configs/vic3d/late-fusion-image/imvoxelnet/trainval_config_i.py
@@ -0,0 +1,170 @@
dataset_type = 'KittiDataset'
data_root = '../../../../data/DAIR-V2X/cooperative-vehicle-infrastructure/infrastructure-side/'
class_names = ['Car']
input_modality = dict(use_lidar=False, use_camera=True)
point_cloud_range = [0, -39.68, -3, 92.16, 39.68, 1]
voxel_size = [0.32, 0.32, 0.33]
length = int((point_cloud_range[3] - point_cloud_range[0]) / voxel_size[0])
width = int((point_cloud_range[4] - point_cloud_range[1]) / voxel_size[1])
height = int((point_cloud_range[5] - point_cloud_range[2]) / voxel_size[2])
output_shape = [width, length, height]
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
img_scale = (960, 540)
img_resize_scale = [(912, 513), (1008, 567)]

work_dir = './work_dirs/vic3d_latefusion_inf_imvoxelnet'

model = dict(
type='ImVoxelNet',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=64,
num_outs=4),
neck_3d=dict(type='OutdoorImVoxelNeck', in_channels=64, out_channels=256),
bbox_head=dict(
type='Anchor3DHead',
num_classes=1,
in_channels=256,
feat_channels=256,
use_direction_classifier=True,
anchor_generator=dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[0, -39.68, -1.78, 92.16, 39.68, -1.78]],
sizes=[[3.9, 1.6, 1.56]],
rotations=[0, 1.57],
reshape_out=True),
diff_rad_by_sin=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
n_voxels=output_shape,
anchor_generator=dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[0, -39.68, -3.08, 92.16, 39.68, 0.76]],
rotations=[.0]),
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1),
allowed_border=0,
pos_weight=-1,
debug=False),
test_cfg=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_thr=0.01,
score_thr=0.1,
min_bbox_size=0,
nms_pre=100,
max_num=50))

train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(type='LoadImageFromFile'),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='Resize',
img_scale=img_resize_scale,
keep_ratio=True,
multiscale_mode='range'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img'])
]

data = dict(
samples_per_gpu=1,
workers_per_gpu=1,
train=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_train.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=train_pipeline,
modality=input_modality,
classes=class_names,
test_mode=False)),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
test_mode=True),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
box_type_3d="Lidar",
test_mode=True))

optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.0001,
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
lr_config = dict(policy='step', step=[8, 11])
total_epochs = 12

checkpoint_config = dict(interval=1, max_keep_ckpts=1)
log_config = dict(
interval=50,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
evaluation = dict(interval=1)
dist_params = dict(backend='nccl')
find_unused_parameters = True # only 1 of 4 FPN outputs is used
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
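
The `n_voxels` value passed to the model follows from `point_cloud_range` and `voxel_size` at the top of this config. The standalone sketch below repeats that arithmetic with one assumed difference: it rounds the quotient to six decimals before truncating, so that floating-point error cannot drop a cell, whereas the config applies `int()` directly. The helper name `cells` is hypothetical.

```python
point_cloud_range = [0, -39.68, -3, 92.16, 39.68, 1]
voxel_size = [0.32, 0.32, 0.33]


def cells(lo, hi, size):
    """Number of whole voxels of edge `size` that fit between `lo` and `hi`.

    Rounding to 6 decimals before int() is an assumption added here to
    guard against values such as 287.9999999 being truncated to 287.
    """
    return int(round((hi - lo) / size, 6))


length = cells(point_cloud_range[0], point_cloud_range[3], voxel_size[0])
width = cells(point_cloud_range[1], point_cloud_range[4], voxel_size[1])
height = cells(point_cloud_range[2], point_cloud_range[5], voxel_size[2])

# Same ordering as `output_shape` in the config: [width, length, height].
n_voxels = [width, length, height]
print(n_voxels)  # [248, 288, 12]
```

The 4 m vertical extent is not an exact multiple of 0.33 m, so the height comes out as 12 whole voxels covering 3.96 m.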