Kadir Yilmaz committed Jan 12, 2024 (0 parents, commit 3252ede).
Showing 124 changed files with 11,586 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
.DS_STORE | ||
/saved | ||
/logs | ||
/data/semantic_kitti | ||
/.vscode | ||
.python-version | ||
__pycache__/ | ||
*.out |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# MASK4D: Mask Transformer for 4D Panoptic Segmentation | ||
<div align="center"> | ||
<a href="https://github.com/YilmazKadir/">Kadir Yilmaz</a>, | ||
<a href="https://jonasschult.github.io/">Jonas Schult</a>, | ||
<a href="https://nekrasov.dev/">Alexey Nekrasov</a>, | ||
<a href="https://www.vision.rwth-aachen.de/person/1/">Bastian Leibe</a> | ||
|
||
RWTH Aachen University | ||
|
||
MASK4D is a transformer-based model for 4D panoptic segmentation that achieves new state-of-the-art performance on the SemanticKITTI test set. | ||
|
||
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a> | ||
<a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a> | ||
<a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a> | ||
<a href="https://black.readthedocs.io/en/stable/"><img alt="Code style: black" src="https://img.shields.io/badge/code%20style-black-black.svg"></a> | ||
|
||
![teaser](./docs/github_teaser.jpg) | ||
|
||
</div> | ||
<br><br> | ||
|
||
[[Project Webpage](https://vision.rwth-aachen.de/mask4d)] [[arXiv](https://arxiv.org/abs/2309.16133)] | ||
|
||
## News | ||
|
||
* **2023-09-28**: Paper on arXiv | ||
|
||
### Dependencies | ||
The main dependencies of the project are the following: | ||
```yaml | ||
python: 3.8 | ||
cuda: 11.7 | ||
``` | ||
You can set up a conda environment as follows: | ||
```bash | ||
conda create --name mask4d python=3.8 | ||
conda activate mask4d | ||
pip install -r requirements.txt | ||
|
||
pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 | ||
|
||
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cu117.html | ||
|
||
pip install 'git+https://github.com/facebookresearch/detectron2.git@710e7795d0eeadf9def0e7ef957eea13532e34cf' --no-deps | ||
|
||
cd third_party/pointnet2 && python setup.py install | ||
|
||
cd .. | ||
git clone https://github.com/NVIDIA/MinkowskiEngine.git | ||
cd MinkowskiEngine | ||
python setup.py install | ||
cd ../.. | ||
``` | ||
|
||
### Data preprocessing | ||
After installing the dependencies, we preprocess the SemanticKITTI dataset. | ||
|
||
```bash | ||
python -m datasets.preprocessing.semantic_kitti_preprocessing preprocess \ | ||
--data_dir "PATH_TO_RAW_SEMKITTI_DATASET" \ | ||
--save_dir "data/semantic_kitti" | ||
python -m datasets.preprocessing.semantic_kitti_preprocessing make_instance_database \ | ||
--data_dir "PATH_TO_RAW_SEMKITTI_DATASET" \ | ||
--save_dir "data/semantic_kitti" | ||
``` | ||
|
||
### Training and testing | ||
Train MASK4D: | ||
```bash | ||
python main_panoptic.py | ||
``` | ||
|
||
In the simplest case the inference command looks as follows: | ||
```bash | ||
python main_panoptic.py \ | ||
general.mode="validate" \ | ||
general.ckpt_path='PATH_TO_CHECKPOINT.ckpt' | ||
``` | ||
|
||
Alternatively, enable DBSCAN post-processing, which raises the LSTQ score slightly (see the checkpoint results below): | ||
```bash | ||
python main_panoptic.py \ | ||
general.mode="validate" \ | ||
general.ckpt_path='PATH_TO_CHECKPOINT.ckpt' \ | ||
general.dbscan_eps=1.0 | ||
``` | ||
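
As a rough illustration of what `general.dbscan_eps=1.0` controls: DBSCAN groups points whose neighbours lie within a radius `eps`, so a predicted instance mask that covers two spatially separate objects can be split into two. The sketch below is a minimal pure-Python DBSCAN and is not the repository's post-processing code; the `min_samples` value and the sample points are assumptions made for the example.

```python
import math


def dbscan(points, eps=1.0, min_samples=2):
    """Return one cluster label per point; -1 marks noise."""
    labels = [None] * len(points)

    def neighbours(i):
        # Indices of all points within eps of point i (including i itself).
        return [j for j, q in enumerate(points) if math.dist(points[i], q) <= eps]

    cluster = -1
    for i in range(len(points)):
        if labels[i] is not None:
            continue
        seeds = neighbours(i)
        if len(seeds) < min_samples:
            labels[i] = -1  # provisional noise; may become a border point later
            continue
        cluster += 1
        labels[i] = cluster
        queue = [j for j in seeds if j != i]
        while queue:
            j = queue.pop()
            if labels[j] == -1:
                labels[j] = cluster  # noise reached from a core point joins the cluster
            if labels[j] is not None:
                continue
            labels[j] = cluster
            nbrs = neighbours(j)
            if len(nbrs) >= min_samples:
                queue.extend(nbrs)  # j is a core point, so keep expanding
    return labels


# Hypothetical points of ONE predicted mask that actually covers two objects.
mask_points = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (0.9, 0.2, 0.0),
               (5.0, 5.0, 0.0), (5.4, 5.0, 0.0)]
split_labels = dbscan(mask_points, eps=1.0)
```

With `eps=1.0` the two groups are about 6 m apart, so they end up in different clusters; a larger `eps` would merge them again.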
## Trained checkpoint | ||
[MASK4D](https://omnomnom.vision.rwth-aachen.de/data/mask4d/mask4d.ckpt) | ||
|
||
The provided model, trained after the submission, achieves 71.1 LSTQ without DBSCAN and 71.5 with DBSCAN post-processing. | ||
|
||
## BibTeX | ||
``` | ||
@article{yilmaz2023mask4d, | ||
title = {{MASK4D: Mask Transformer for 4D Panoptic Segmentation}}, | ||
author = {Yilmaz, Kadir and Schult, Jonas and Nekrasov, Alexey and Leibe, Bastian}, | ||
journal = {arXiv preprint arXiv:2309.16133}, | ||
year = {2023} | ||
} | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# pi = 3.14159265358979 | ||
# pi/2 = 1.57079632679489 | ||
# pi/3 = 1.04719755119659 | ||
# pi/6 = 0.52359877559829 | ||
# pi/12 = 0.26179938779914 | ||
# pi/24 = 0.13089969389957 | ||
|
||
__version__: 0.1.6 | ||
transform: | ||
  __class_fullname__: volumentations.core.composition.Compose | ||
  additional_targets: {} | ||
  p: 1.0 | ||
  transforms: | ||
  - __class_fullname__: volumentations.augmentations.transforms.Scale3d | ||
    always_apply: true | ||
    p: 0.5 | ||
    scale_limit: | ||
    - - -0.1 | ||
      - 0.1 | ||
    - - -0.1 | ||
      - 0.1 | ||
    - - -0.1 | ||
      - 0.1 | ||
  - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d | ||
    always_apply: true | ||
    axis: | ||
    - 0 | ||
    - 0 | ||
    - 1 | ||
    p: 0.5 | ||
    rotation_limit: | ||
    - -3.141592653589793 | ||
    - 3.141592653589793 | ||
  - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d | ||
    always_apply: true | ||
    axis: | ||
    - 0 | ||
    - 1 | ||
    - 0 | ||
    p: 0.5 | ||
    rotation_limit: | ||
    - -0.13089969389957 | ||
    - 0.13089969389957 | ||
  - __class_fullname__: volumentations.augmentations.transforms.RotateAroundAxis3d | ||
    always_apply: true | ||
    axis: | ||
    - 1 | ||
    - 0 | ||
    - 0 | ||
    p: 0.5 | ||
    rotation_limit: | ||
    - -0.13089969389957 | ||
    - 0.13089969389957 |
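
The rotation limits in this file are given in radians: the z-axis rotation spans the full circle (±π), while the x- and y-axis rotations are limited to ±π/24. The sketch below converts the limits to degrees and shows how a rotation about the z axis could be sampled. It assumes the angle is drawn uniformly from the limit range, which is an assumption about the augmentation library rather than something this file states; the function names are made up for the example.

```python
import math
import random

# Limits copied from the config above (radians), keyed by rotation axis.
ROTATION_LIMITS = {
    "z": 3.141592653589793,
    "y": 0.13089969389957,
    "x": 0.13089969389957,
}


def limits_in_degrees(limits):
    """Convert each rotation limit from radians to degrees."""
    return {axis: math.degrees(value) for axis, value in limits.items()}


def rotate_about_z(point, angle):
    """Rotate a 3D point counter-clockwise about the z axis by `angle` radians."""
    x, y, z = point
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y, z)


def random_z_rotation(point, limit, rng=random):
    """Sample an angle uniformly in [-limit, limit] (assumed) and rotate the point."""
    return rotate_about_z(point, rng.uniform(-limit, limit))


degrees = limits_in_degrees(ROTATION_LIMITS)
rotated = rotate_about_z((1.0, 0.0, 0.5), math.pi / 2)
```

In degrees, the z limit is 180 and the tilt limits are 7.5, so scans are rotated freely around the vertical axis but only slightly tilted.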
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# @package _group_ | ||
- _target_: pytorch_lightning.callbacks.ModelCheckpoint | ||
  monitor: val_mean_lstq | ||
  save_last: true | ||
  save_top_k: 1 | ||
  mode: max | ||
  dirpath: ${general.save_dir} | ||
  filename: "{epoch}-{val_mean_lstq:.3f}" | ||
  every_n_epochs: 1 | ||
|
||
- _target_: pytorch_lightning.callbacks.LearningRateMonitor |
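
To make the `filename` template concrete: by default, PyTorch Lightning's `ModelCheckpoint` prefixes each formatted value with its metric name, so the template above yields names such as `epoch=12-val_mean_lstq=0.711.ckpt`. The sketch below only emulates that naming with the standard library; it does not call Lightning, and the metric values are hypothetical.

```python
import re


def format_checkpoint_name(template, metrics):
    """Emulate the default naming: '{name...}' becomes 'name=<formatted value>'."""
    with_names = re.sub(r"\{(\w+)", r"\1={\1", template)
    return with_names.format(**metrics) + ".ckpt"


def best_checkpoint(history, monitor="val_mean_lstq"):
    """With mode 'max' and save_top_k 1, only the highest-scoring epoch is kept."""
    return max(history, key=lambda metrics: metrics[monitor])


# Hypothetical validation results for three epochs.
history = [
    {"epoch": 10, "val_mean_lstq": 0.6984},
    {"epoch": 11, "val_mean_lstq": 0.7031},
    {"epoch": 12, "val_mean_lstq": 0.7112},
]
best_name = format_checkpoint_name("{epoch}-{val_mean_lstq:.3f}", best_checkpoint(history))
```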
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
general: | ||
  mode: "train" | ||
  seed: null | ||
  ckpt_path: null | ||
  project_name: mask4d | ||
  workspace: kadiryilmaz | ||
  instance_population: 20 | ||
  dbscan_eps: null | ||
  experiment_name: ${now:%Y-%m-%d_%H%M%S} | ||
  save_dir: saved/${general.experiment_name} | ||
  gpus: 1 | ||
|
||
defaults: | ||
  - data: kitti | ||
  - data/data_loaders: simple_loader | ||
  - data/datasets: semantic_kitti | ||
  - data/collation_functions: voxelize_collate | ||
  - logging: full | ||
  - model: mask4d | ||
  - optimizer: adamw | ||
  - scheduler: onecyclelr | ||
  - trainer: trainer30 | ||
  - callbacks: callbacks_panoptic | ||
  - matcher: hungarian_matcher | ||
  - loss: set_criterion | ||
  - metric: lstq | ||
|
||
hydra: | ||
  run: | ||
    dir: saved/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} | ||
  sweep: | ||
    dir: saved/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S} | ||
    # dir: ${general.save_dir} | ||
    subdir: ${hydra.job.num}_${hydra.job.id} |
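
The `${now:...}` interpolations are resolved by Hydra with `strftime`-style patterns, so each run gets a timestamped experiment name and output directory. The sketch below reproduces the resulting paths with the standard library only; it does not invoke Hydra, and the function name and example timestamp are hypothetical.

```python
from datetime import datetime


def resolve_run_paths(now):
    """Expand the timestamp patterns used in the config for a given start time."""
    experiment_name = now.strftime("%Y-%m-%d_%H%M%S")
    return {
        "experiment_name": experiment_name,
        "save_dir": f"saved/{experiment_name}",
        "hydra_run_dir": "saved/hydra_logs/{}/{}".format(
            now.strftime("%Y-%m-%d"), now.strftime("%H-%M-%S")
        ),
    }


# Example: a run started on 12 January 2024 at 09:30:05.
paths = resolve_run_paths(datetime(2024, 1, 12, 9, 30, 5))
```

Because `save_dir` depends on the start time, checkpoints from separate runs land in separate directories unless `general.experiment_name` is overridden on the command line.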
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# @package data | ||
|
||
train_collation: | ||
  _target_: datasets.utils.VoxelizeCollate | ||
  ignore_label: ${data.ignore_label} | ||
  voxel_size: ${data.voxel_size} | ||
|
||
validation_collation: | ||
  _target_: datasets.utils.VoxelizeCollate | ||
  ignore_label: ${data.ignore_label} | ||
  voxel_size: ${data.voxel_size} | ||
|
||
test_collation: | ||
  _target_: datasets.utils.VoxelizeCollate | ||
  ignore_label: ${data.ignore_label} | ||
  voxel_size: ${data.voxel_size} |
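
For intuition about `voxel_size`: voxelization typically maps each point to an integer grid cell by dividing its coordinates by the voxel size and flooring, then keeps one entry per occupied cell. The sketch below shows that idea in plain Python. `VoxelizeCollate` itself is not shown in this section, so this is an assumed, simplified stand-in and not its implementation.

```python
import math


def voxelize(points, voxel_size=0.05):
    """Map points to integer voxel coordinates.

    Returns (first_point_per_voxel, voxel_of_each_point).
    """
    first_point_per_voxel = {}
    voxel_of_each_point = []
    for index, point in enumerate(points):
        key = tuple(math.floor(coordinate / voxel_size) for coordinate in point)
        first_point_per_voxel.setdefault(key, index)  # keep the first point seen
        voxel_of_each_point.append(key)
    return first_point_per_voxel, voxel_of_each_point


# Hypothetical points in metres; the first two fall into the same 5 cm voxel.
points = [(0.01, 0.01, 0.01), (0.04, 0.02, 0.0), (0.06, 0.0, 0.0), (-0.01, 0.0, 0.0)]
voxels, assignment = voxelize(points, voxel_size=0.05)
```

The per-point assignment is what allows voxel-level predictions to be mapped back to the original points.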
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# @package data | ||
|
||
train_dataloader: | ||
  _target_: torch.utils.data.DataLoader | ||
  shuffle: true | ||
  pin_memory: ${data.pin_memory} | ||
  num_workers: ${data.num_workers} | ||
  batch_size: ${data.batch_size} | ||
|
||
validation_dataloader: | ||
  _target_: torch.utils.data.DataLoader | ||
  shuffle: false | ||
  pin_memory: ${data.pin_memory} | ||
  num_workers: ${data.num_workers} | ||
  batch_size: ${data.test_batch_size} | ||
|
||
test_dataloader: | ||
  _target_: torch.utils.data.DataLoader | ||
  shuffle: false | ||
  pin_memory: ${data.pin_memory} | ||
  num_workers: ${data.num_workers} | ||
  batch_size: ${data.test_batch_size} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# @package data | ||
train_dataset: | ||
  _target_: datasets.lidar.LidarDataset | ||
  data_dir: data/semantic_kitti | ||
  mode: ${data.train_mode} | ||
  add_distance: ${data.add_distance} | ||
  sweep: ${data.sweep} | ||
  instance_population: ${data.instance_population} | ||
  ignore_label: ${data.ignore_label} | ||
  volume_augmentations_path: conf/augmentation/volumentations_aug.yaml | ||
|
||
validation_dataset: | ||
  _target_: datasets.lidar.LidarDataset | ||
  data_dir: data/semantic_kitti | ||
  mode: ${data.validation_mode} | ||
  add_distance: ${data.add_distance} | ||
  sweep: ${data.sweep} | ||
  instance_population: 0 | ||
  ignore_label: ${data.ignore_label} | ||
  volume_augmentations_path: null | ||
|
||
test_dataset: | ||
  _target_: datasets.lidar.LidarDataset | ||
  data_dir: data/semantic_kitti | ||
  mode: ${data.test_mode} | ||
  add_distance: ${data.add_distance} | ||
  sweep: ${data.sweep} | ||
  instance_population: 0 | ||
  ignore_label: ${data.ignore_label} | ||
  volume_augmentations_path: null |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# @package _group_ | ||
|
||
# these parameters are inherited by datasets, data_loaders and collators | ||
# but they might be overwritten | ||
|
||
# splits | ||
train_mode: train | ||
validation_mode: validation | ||
test_mode: test | ||
|
||
# dataset | ||
ignore_label: 255 | ||
add_distance: true | ||
in_channels: 2 | ||
num_labels: 19 | ||
instance_population: ${general.instance_population} | ||
sweep: 2 | ||
min_stuff_cls_id: 9 | ||
min_points: 50 | ||
class_names: ['car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 'person', 'bicyclist', | ||
'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', | ||
'trunk', 'terrain', 'pole', 'traffic-sign'] | ||
|
||
# data loader | ||
pin_memory: true | ||
num_workers: 4 | ||
batch_size: 4 | ||
test_batch_size: 2 | ||
|
||
# collation | ||
voxel_size: 0.05 |
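
The `min_stuff_cls_id` setting separates "thing" classes (countable objects that receive instance IDs) from "stuff" classes (amorphous regions such as road). The sketch below splits `class_names` accordingly. It assumes class IDs start at 1, with 0 reserved for unlabeled points; the config does not state this, but under that assumption the split matches SemanticKITTI's eight thing classes. The function name and the `first_id` parameter are hypothetical.

```python
CLASS_NAMES = [
    "car", "bicycle", "motorcycle", "truck", "other-vehicle", "person", "bicyclist",
    "motorcyclist", "road", "parking", "sidewalk", "other-ground", "building", "fence",
    "vegetation", "trunk", "terrain", "pole", "traffic-sign",
]
MIN_STUFF_CLS_ID = 9


def split_things_and_stuff(class_names, min_stuff_cls_id, first_id=1):
    """Split class names into things and stuff.

    Assumes IDs are assigned in list order, starting at `first_id`.
    """
    things, stuff = [], []
    for class_id, name in enumerate(class_names, start=first_id):
        (stuff if class_id >= min_stuff_cls_id else things).append(name)
    return things, stuff


things, stuff = split_things_and_stuff(CLASS_NAMES, MIN_STUFF_CLS_ID)
```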
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# @package _group_ | ||
- _target_: pytorch_lightning.loggers.WandbLogger | ||
  project: ${general.project_name} | ||
  name: ${general.experiment_name} | ||
  save_dir: ${general.save_dir} | ||
  entity: "kadiryilmaz193" | ||
  id: ${general.experiment_name} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# @package _group_ | ||
_target_: models.criterion.SetCriterion | ||
num_classes: ${data.num_labels} | ||
eos_coef: 0.1 | ||
losses: | ||
- "labels" | ||
- "masks" | ||
- "bboxs" |
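
In DETR-style set criteria, `eos_coef` is the weight given to the extra "no object" class in the classification loss, so that the many unmatched queries do not dominate training. The sketch below illustrates that convention in plain Python. `models.criterion.SetCriterion` is not shown in this section, so treating `eos_coef` this way is an assumption, and the probabilities are made up.

```python
import math


def class_weights(num_classes, eos_coef):
    """One weight per class plus a down-weighted final 'no object' class."""
    weights = [1.0] * (num_classes + 1)
    weights[-1] = eos_coef
    return weights


def weighted_cross_entropy(probabilities, targets, weights):
    """Weighted mean of the negative log-likelihood over queries."""
    total = sum(-weights[t] * math.log(p[t]) for p, t in zip(probabilities, targets))
    return total / sum(weights[t] for t in targets)


# Toy setup: 2 real classes, so index 2 is 'no object'.
weights = class_weights(num_classes=2, eos_coef=0.1)
probabilities = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]  # hypothetical softmax outputs
targets = [0, 2]  # first query matched to class 0, second query unmatched
loss = weighted_cross_entropy(probabilities, targets, weights)
```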
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# @package _group_ | ||
_target_: models.matcher.HungarianMatcher | ||
cost_class: 2. | ||
cost_mask: 5. | ||
cost_dice: 2. | ||
cost_box: 5. |
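
The four `cost_*` values weight the terms of the matching cost between each predicted query and each ground-truth instance; Hungarian matching then picks the one-to-one assignment with the lowest total cost. The sketch below combines hypothetical per-term costs with these weights and finds the best assignment by brute force, which is only feasible for tiny examples. It is not the code in `models.matcher`, and real implementations usually rely on an optimized solver.

```python
from itertools import permutations

# Weights copied from the config above.
COST_CLASS, COST_MASK, COST_DICE, COST_BOX = 2.0, 5.0, 2.0, 5.0


def combined_cost(class_cost, mask_cost, dice_cost, box_cost):
    """Weighted sum of the four matching-cost terms."""
    return (COST_CLASS * class_cost + COST_MASK * mask_cost
            + COST_DICE * dice_cost + COST_BOX * box_cost)


def match(cost_matrix):
    """Brute-force optimal assignment; cost_matrix[query][target], queries >= targets."""
    num_queries, num_targets = len(cost_matrix), len(cost_matrix[0])
    best_pairs, best_cost = None, float("inf")
    for queries in permutations(range(num_queries), num_targets):
        cost = sum(cost_matrix[q][t] for t, q in enumerate(queries))
        if cost < best_cost:
            best_pairs = [(q, t) for t, q in enumerate(queries)]
            best_cost = cost
    return best_pairs, best_cost


# Hypothetical per-term costs for 3 queries (rows) and 2 targets (columns).
cost_matrix = [
    [combined_cost(0.1, 0.2, 0.3, 0.1), combined_cost(0.9, 0.8, 0.9, 0.7)],
    [combined_cost(0.8, 0.9, 0.8, 0.9), combined_cost(0.7, 0.6, 0.7, 0.5)],
    [combined_cost(0.5, 0.5, 0.5, 0.5), combined_cost(0.2, 0.1, 0.2, 0.1)],
]
pairs, total_cost = match(cost_matrix)
```

Here query 0 is matched to target 0 and query 2 to target 1; query 1 stays unmatched and would be supervised as "no object".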
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# @package _group_ | ||
_target_: models.metrics.Panoptic4DEval | ||
n_classes: ${data.num_labels} | ||
min_stuff_cls_id: ${data.min_stuff_cls_id} | ||
min_points: ${data.min_points} |
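
LSTQ, the metric monitored during validation, is the geometric mean of a classification score (mean IoU over classes) and an association score that measures how consistently instances are tracked over time. The sketch below shows only that final combination, using hypothetical scores and counts; the full association score computed by `Panoptic4DEval` is not reproduced here.

```python
import math


def mean_iou(per_class_counts):
    """Classification score: mean IoU from (true_pos, false_pos, false_neg) per class."""
    ious = [tp / (tp + fp + fn) for tp, fp, fn in per_class_counts if tp + fp + fn > 0]
    return sum(ious) / len(ious)


def lstq(association_score, classification_score):
    """LSTQ is the geometric mean of the association and classification scores."""
    return math.sqrt(association_score * classification_score)


# Hypothetical counts for two classes and a hypothetical association score.
s_cls = mean_iou([(80, 10, 10), (60, 20, 20)])
score = lstq(association_score=0.64, classification_score=0.81)
```

Because of the geometric mean, a model cannot reach a high LSTQ by doing well on semantics alone; poor association pulls the score down just as strongly.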
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# @package _group_ | ||
_target_: models.Mask4D | ||
|
||
# backbone | ||
backbone: | ||
  _target_: models.Res16UNet34C | ||
  config: | ||
    dialations: [ 1, 1, 1, 1 ] | ||
    conv1_kernel_size: 5 | ||
    bn_momentum: 0.02 | ||
  in_channels: ${data.in_channels} | ||
  out_channels: ${data.num_labels} | ||
|
||
# transformer parameters | ||
num_queries: 100 | ||
num_heads: 8 | ||
num_decoders: 3 | ||
num_levels: 4 | ||
sample_sizes: [4000, 8000, 16000, 32000] | ||
mask_dim: 128 | ||
dim_feedforward: 1024 | ||
num_labels: ${data.num_labels} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# @package _group_ | ||
_target_: torch.optim.AdamW | ||
lr: 0.0002 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# @package _group_ | ||
scheduler: | ||
  _target_: torch.optim.lr_scheduler.OneCycleLR | ||
  max_lr: ${optimizer.lr} | ||
  epochs: ${trainer.max_epochs} | ||
  # need to set to number because of tensorboard logger | ||
  steps_per_epoch: -1 | ||
|
||
pytorch_lightning_params: | ||
  interval: step |
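
Since `interval: step` makes the scheduler advance once per optimizer step, the one-cycle schedule spans `epochs * steps_per_epoch` steps in total. The `-1` above is presumably a placeholder that the training code replaces with the real number of batches per epoch; that is an assumption, as the code is not shown here. The sketch below reproduces the shape of the schedule in plain Python using the default settings of `torch.optim.lr_scheduler.OneCycleLR` (30% warm-up, cosine annealing, `div_factor=25`, `final_div_factor=1e4`). The total of 100 steps is a hypothetical value.

```python
import math


def cosine_anneal(start, end, fraction):
    """Cosine interpolation from `start` (fraction 0) to `end` (fraction 1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * fraction) + 1.0)


def one_cycle_lr(step, total_steps, max_lr,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step` for a two-phase one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1
    if step <= warmup_end:
        return cosine_anneal(initial_lr, max_lr, step / warmup_end)
    decay_steps = (total_steps - 1) - warmup_end
    return cosine_anneal(max_lr, min_lr, (step - warmup_end) / decay_steps)


MAX_LR = 0.0002      # from the optimizer config (lr)
TOTAL_STEPS = 100    # hypothetical: epochs * steps_per_epoch
schedule = [one_cycle_lr(step, TOTAL_STEPS, MAX_LR) for step in range(TOTAL_STEPS)]
```

The rate starts at `max_lr / 25`, peaks at `max_lr` after 30% of the steps, and decays to a value four orders of magnitude below the starting rate.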