Commit 416fca8, authored and committed by fcdl94 on Jul 17, 2020 (1 parent: 3d054b1). Showing 55 changed files with 4,352 additions and 0 deletions.
@@ -0,0 +1,15 @@
# FSS
## Few Shot Learning in Semantic Segmentation

# How to download data

> cd <target folder>
> ../data/download_voc.sh

# How to run the training

> python -m torch.distributed.launch --nproc_per_node="total GPUs" train.py --data_root "folder where you downloaded the data" --name "name of exp" --batch_size=4 --num_workers=1 --other_args

The default folder for the logs is logs/"name of exp". Logs are written in TensorBoard format.

By default, a pretrained backbone is used, looked up in the project's pretrained folder. If you do not want to use pretrained weights, pass --no_pretrained.
@@ -0,0 +1,126 @@
```python
import argparse
import task


def modify_command_options(opts):
    if opts.dataset == 'voc':
        opts.num_classes = 21
    if opts.dataset == 'ade':
        opts.num_classes = 150

    if not opts.visualize:
        opts.sample_num = 0

    opts.no_cross_val = not opts.cross_val
    opts.pooling = round(opts.crop_size / opts.output_stride)

    return opts


def get_argparser():
    parser = argparse.ArgumentParser()

    # Performance Options
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--random_seed", type=int, default=42,
                        help="random seed (default: 42)")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="number of workers (default: 1)")
    parser.add_argument("--opt_level", type=str, choices=['O0', 'O1', 'O2', 'O3'], default='O0')

    # Dataset Options
    parser.add_argument("--data_root", type=str, default="data",
                        help="path to Dataset")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc'], help="Name of dataset")
    parser.add_argument("--num_classes", type=int, default=None,
                        help="num classes (default: None)")

    # Task Options
    parser.add_argument("--step", type=int, default=0,
                        help="Step (0 is base)")
    parser.add_argument("--task", type=str, default="15-5", choices=task.get_task_list(),
                        help="Task to be executed (default: 15-5)")
    parser.add_argument("--nshot", type=int, default=5,
                        help="If step>0, the number of shots to use for FSL (default: 5)")
    parser.add_argument("--ishot", type=int, default=0,
                        help="First index where to sample shots")
    parser.add_argument("--use_bkg", default=False, action='store_true',
                        help="Whether to use the background as a class (default: no)")
    parser.add_argument("--input_mix", default="novel", choices=['novel', 'both'],
                        help="Which classes to use for FSL")

    # Train Options
    parser.add_argument("--epochs", type=int, default=30,
                        help="epoch number (default: 30)")

    parser.add_argument("--fix_bn", action='store_true', default=False,
                        help="fix batch normalization during training (default: False)")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="batch size (default: 4)")
    parser.add_argument("--crop_size", type=int, default=512,
                        help="crop size (default: 512)")

    parser.add_argument("--lr", type=float, default=0.007,
                        help="learning rate (default: 0.007)")
    parser.add_argument("--momentum", type=float, default=0.9,
                        help="momentum for SGD (default: 0.9)")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help="weight decay (default: 1e-4)")

    parser.add_argument("--lr_policy", type=str, default='poly',
                        choices=['poly', 'step'], help="lr schedule policy (default: poly)")
    parser.add_argument("--lr_decay_step", type=int, default=5000,
                        help="decay step for stepLR (default: 5000)")
    parser.add_argument("--lr_decay_factor", type=float, default=0.1,
                        help="decay factor for stepLR (default: 0.1)")
    parser.add_argument("--lr_power", type=float, default=0.9,
                        help="power for polyLR (default: 0.9)")

    # Logging Options
    parser.add_argument("--logdir", type=str, default='./logs',
                        help="path to Log directory (default: ./logs)")
    parser.add_argument("--name", type=str, default='Experiment',
                        help="name of the experiment - appended to the log directory (default: Experiment)")
    parser.add_argument("--sample_num", type=int, default=4,
                        help="number of samples for visualization (default: 4)")
    parser.add_argument("--debug", action='store_true', default=False,
                        help="verbose option")
    parser.add_argument("--visualize", action='store_false', default=True,
                        help="visualization on tensorboard (default: True)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=1,
                        help="epoch interval for eval (default: 1)")

    # Segmentation Architecture Options
    parser.add_argument("--backbone", type=str, default='resnet101',
                        choices=['resnet50', 'resnet101'], help="backbone for the body (default: resnet101)")
    parser.add_argument("--output_stride", type=int, default=16,
                        choices=[8, 16], help="output stride for the backbone (default: 16)")
    parser.add_argument("--no_pretrained", action='store_true', default=False,
                        help="do not use a pretrained backbone (default: use pretrained)")
    parser.add_argument("--norm_act", type=str, default="iabn_sync",
                        choices=['iabn_sync', 'iabn', 'abn', 'std'], help="Which BN to use (default: iabn_sync)")
    parser.add_argument("--fusion-mode", metavar="NAME", type=str, choices=["mean", "voting", "max"], default="mean",
                        help="How to fuse the outputs. Options: 'mean', 'voting', 'max'")

    # Test and Checkpoint options
    parser.add_argument("--test", action='store_true', default=False,
                        help="test only, without training (default: train and test)")
    parser.add_argument("--ckpt", default=None, type=str,
                        help="path to trained model. Leave it None if you want to retrain your model")
    parser.add_argument("--continue_ckpt", default=False, action='store_true',
                        help="Restart from the ckpt. Name taken automatically from the method name.")
    parser.add_argument("--ckpt_interval", type=int, default=1,
                        help="epoch interval for saving model (default: 1)")
    parser.add_argument("--cross_val", action='store_true', default=False,
                        help="validate on a split of the training set instead of the validation set (default: False)")

    # Method
    parser.add_argument("--method", type=str, default='FT',
                        choices=['FT', 'SPN'],
                        help="The method you want to use.")
    parser.add_argument("--embedding", type=str, default="fastnvec", choices=['word2vec', 'fasttext', 'fastnvec'])

    return parser
```
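As a quick illustration (not part of this commit; the module name `argparser` and the flag values are assumptions), the two helpers above would typically be combined in a training entry point like this:

```python
# Hypothetical usage sketch: parse the options and post-process them.
# The module name "argparser" and the flag values are assumptions.
from argparser import get_argparser, modify_command_options

parser = get_argparser()
opts = parser.parse_args(["--dataset", "voc", "--crop_size", "512", "--output_stride", "16"])
opts = modify_command_options(opts)

print(opts.num_classes)  # 21, filled in automatically for 'voc'
print(opts.pooling)      # round(512 / 16) == 32
```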
@@ -0,0 +1,14 @@
```bash
#!/usr/bin/env bash

# Run this from the destination folder (e.g. cd /vandal/dataset first),
# or pass the destination folder as the first argument.
if [ $# -eq 1 ]; then
    dest=$1
else
    dest="."
fi

wget http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip -P $dest
unzip $dest/ADEChallengeData2016.zip

echo "Copy the files in data/ade_splits in the ade main folder."
cp data/ade_splits/* $dest
```
@@ -0,0 +1,22 @@
```bash
#!/usr/bin/env bash

# Use this script in the destination folder,
# e.g. add cd /home/datasets first.

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xf VOCtrainval_11-May-2012.tar
mkdir PascalVOC12
mv VOCdevkit/VOC2012/* PascalVOC12
cd PascalVOC12
wget http://cs.jhu.edu/~cxliu/data/SegmentationClassAug.zip
wget http://cs.jhu.edu/~cxliu/data/SegmentationClassAug_Visualization.zip
wget http://cs.jhu.edu/~cxliu/data/list.zip
unzip SegmentationClassAug.zip
unzip SegmentationClassAug_Visualization.zip
unzip list.zip
mv list splits

# then link SegmentationClassAug into annotations
# and link JPEGImages into images
ln -s JPEGImages images
ln -s SegmentationClassAug annotations
```
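Presumably the dataset loader reads from the images, annotations, and splits subfolders that the script sets up under PascalVOC12. A small sanity check (a sketch based on that assumption, not part of the commit):

```python
# Hypothetical sanity check of the layout produced by the VOC script above;
# the root "PascalVOC12" and the expected subfolders are assumptions.
import os

root = "PascalVOC12"
for sub in ("images", "annotations", "splits"):
    path = os.path.join(root, sub)
    print(path, "ok" if os.path.isdir(path) else "missing")
```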
@@ -0,0 +1,21 @@
```
0 background
1 aeroplane
2 bicycle
3 bird
4 boat
5 bottle
6 bus
7 car
8 cat
9 chair
10 cow
11 diningtable
12 dog
13 horse
14 motorbike
15 person
16 potted_plant
17 sheep
18 sofa
19 train
20 tv
```
11 binary files not shown.
@@ -0,0 +1,46 @@
```python
from .voc import VOCFSSDataset
from .transform import Compose, RandomScale, RandomCrop, RandomHorizontalFlip, ToTensor, Normalize, CenterCrop
import torch


def get_dataset(opts, task):
    """ Dataset And Augmentation
    """
    train_transform = Compose([
        RandomScale((0.75, 1.5)),
        RandomCrop(opts.crop_size, pad_if_needed=True),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])

    val_transform = Compose([
        CenterCrop(size=opts.crop_size),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])
    test_transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])

    if opts.dataset == 'voc':
        dataset = VOCFSSDataset
    else:
        raise NotImplementedError

    train_dst = dataset(root=opts.data_root, task=task, train=True, transform=train_transform)

    if opts.cross_val:
        train_len = int(0.8 * len(train_dst))
        val_len = len(train_dst) - train_len
        train_dst, val_dst = torch.utils.data.random_split(train_dst, [train_len, val_len])
    else:  # don't use cross_val
        val_dst = dataset(root=opts.data_root, task=task, train=False, transform=val_transform)

    test_dst = dataset(root=opts.data_root, task=task, train=False, transform=test_transform)

    return train_dst, val_dst, test_dst
```
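A minimal sketch of how these splits would typically be consumed (not from the commit; `opts` and `task` are assumed to come from the argument parser and the repo's task module):

```python
# Hypothetical wiring: wrap the splits returned by get_dataset in DataLoaders.
# "opts" and "task" are assumed to be built by the training script and the
# task module shown elsewhere in this commit.
from torch.utils.data import DataLoader

train_dst, val_dst, test_dst = get_dataset(opts, task)

train_loader = DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True,
                          num_workers=opts.num_workers, drop_last=True)
val_loader = DataLoader(val_dst, batch_size=1, shuffle=False,
                        num_workers=opts.num_workers)
```

Note that when --cross_val is set, the validation split is carved out of the training set with random_split, so it inherits the training-time augmentations; otherwise a separate validation set with CenterCrop is used.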