
First Commit
fcdl94 authored and fcdl94 committed Jul 17, 2020
1 parent 3d054b1 commit 416fca8
Showing 55 changed files with 4,352 additions and 0 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -0,0 +1,15 @@
# FSS
## Few-Shot Learning in Semantic Segmentation

# How to download data

> cd <target folder>
> ../data/download_voc.sh
# How to run the training

> python -m torch.distributed.launch --nproc_per_node="total GPUs" train.py --data_root "folder where you downloaded the data" --name "name of exp" --batch_size=4 --num_workers=1 --other_args
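For example, on a machine with 2 GPUs (the data path and experiment name below are placeholders):

> python -m torch.distributed.launch --nproc_per_node=2 train.py --data_root /home/datasets/PascalVOC12 --name FT_voc_15-5 --batch_size=4 --num_workers=1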
The default folder for the logs is `logs/"name of exp"`. Logs are written in TensorBoard format.

By default, the backbone is initialized with pretrained weights, which are looked for in the `pretrained` folder of the project. If you don't want to use pretrained weights, pass `--no_pretrained`.
126 changes: 126 additions & 0 deletions argparser.py
@@ -0,0 +1,126 @@
import argparse
import task


def modify_command_options(opts):
    if opts.dataset == 'voc':
        opts.num_classes = 21
    if opts.dataset == 'ade':
        opts.num_classes = 150

    if not opts.visualize:
        opts.sample_num = 0

    opts.no_cross_val = not opts.cross_val
    # Spatial size of the backbone feature map (e.g. 512 / 16 = 32),
    # used as the pooling window.
    opts.pooling = round(opts.crop_size / opts.output_stride)

    return opts


def get_argparser():
    parser = argparse.ArgumentParser()

    # Performance Options
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--random_seed", type=int, default=42,
                        help="random seed (default: 42)")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="number of workers (default: 1)")
    parser.add_argument("--opt_level", type=str, choices=['O0', 'O1', 'O2', 'O3'], default='O0')

    # Dataset Options
    parser.add_argument("--data_root", type=str, default="data",
                        help="path to the dataset")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc'], help="name of the dataset")
    parser.add_argument("--num_classes", type=int, default=None,
                        help="number of classes (default: None)")

    # Task Options
    parser.add_argument("--step", type=int, default=0,
                        help="step (0 is the base step)")
    parser.add_argument("--task", type=str, default="15-5", choices=task.get_task_list(),
                        help="task to be executed (default: 15-5)")
    parser.add_argument("--nshot", type=int, default=5,
                        help="if step > 0, number of shots to use for FSL (default: 5)")
    parser.add_argument("--ishot", type=int, default=0,
                        help="first index from which to sample the shots")
    parser.add_argument("--use_bkg", default=False, action='store_true',
                        help="whether to use the background as a class (default: no)")
    parser.add_argument("--input_mix", default="novel", choices=['novel', 'both'],
                        help="which classes to use for FSL")

    # Train Options
    parser.add_argument("--epochs", type=int, default=30,
                        help="number of epochs (default: 30)")

    parser.add_argument("--fix_bn", action='store_true', default=False,
                        help="fix batch normalization during training (default: False)")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="batch size (default: 4)")
    parser.add_argument("--crop_size", type=int, default=512,
                        help="crop size (default: 512)")

    parser.add_argument("--lr", type=float, default=0.007,
                        help="learning rate (default: 0.007)")
    parser.add_argument("--momentum", type=float, default=0.9,
                        help="momentum for SGD (default: 0.9)")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help="weight decay (default: 1e-4)")

    parser.add_argument("--lr_policy", type=str, default='poly',
                        choices=['poly', 'step'], help="lr schedule policy (default: poly)")
    parser.add_argument("--lr_decay_step", type=int, default=5000,
                        help="decay step for stepLR (default: 5000)")
    parser.add_argument("--lr_decay_factor", type=float, default=0.1,
                        help="decay factor for stepLR (default: 0.1)")
    parser.add_argument("--lr_power", type=float, default=0.9,
                        help="power for polyLR (default: 0.9)")

    # Logging Options
    parser.add_argument("--logdir", type=str, default='./logs',
                        help="path to the log directory (default: ./logs)")
    parser.add_argument("--name", type=str, default='Experiment',
                        help="name of the experiment, appended to the log directory (default: Experiment)")
    parser.add_argument("--sample_num", type=int, default=4,
                        help="number of samples for visualization (default: 4)")
    parser.add_argument("--debug", action='store_true', default=False,
                        help="verbose option")
    parser.add_argument("--visualize", action='store_false', default=True,
                        help="pass to disable tensorboard visualization (default: enabled)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of the loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=1,
                        help="epoch interval for eval (default: 1)")

    # Segmentation Architecture Options
    parser.add_argument("--backbone", type=str, default='resnet101',
                        choices=['resnet50', 'resnet101'], help="backbone for the body (default: resnet101)")
    parser.add_argument("--output_stride", type=int, default=16,
                        choices=[8, 16], help="stride for the backbone (default: 16)")
    parser.add_argument("--no_pretrained", action='store_true', default=False,
                        help="whether to skip the pretrained backbone (default: use pretrained)")
    parser.add_argument("--norm_act", type=str, default="iabn_sync",
                        choices=['iabn_sync', 'iabn', 'abn', 'std'],
                        help="which batch normalization to use (default: iabn_sync)")
    parser.add_argument("--fusion-mode", metavar="NAME", type=str, choices=["mean", "voting", "max"], default="mean",
                        help="how to fuse the outputs; options: 'mean', 'voting', 'max'")

    # Test and Checkpoint options
    parser.add_argument("--test", action='store_true', default=False,
                        help="whether to test only (default: train and test)")
    parser.add_argument("--ckpt", default=None, type=str,
                        help="path to a trained model; leave None to train from scratch")
    parser.add_argument("--continue_ckpt", default=False, action='store_true',
                        help="restart from the checkpoint; name taken automatically from the method name")
    parser.add_argument("--ckpt_interval", type=int, default=1,
                        help="epoch interval for saving the model (default: 1)")
    parser.add_argument("--cross_val", action='store_true', default=False,
                        help="whether to validate on a split of the training set (default: use the validation set)")

    # Method
    parser.add_argument("--method", type=str, default='FT',
                        choices=['FT', 'SPN'],
                        help="the method to use")
    parser.add_argument("--embedding", type=str, default="fastnvec", choices=['word2vec', 'fasttext', 'fastnvec'])

    return parser
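
A minimal sketch of how these helpers are meant to be consumed, e.g. at the top of a training script; the driver code itself is an assumption, only `get_argparser()` and `modify_command_options()` come from this file:

```python
# Hypothetical driver snippet (not part of this commit).
from argparser import get_argparser, modify_command_options

opts = get_argparser().parse_args()
opts = modify_command_options(opts)  # fills num_classes, pooling, no_cross_val, ...
print(f"Running '{opts.name}': task {opts.task}, step {opts.step}, "
      f"{opts.num_classes} classes, pooling window {opts.pooling}")
```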
14 changes: 14 additions & 0 deletions data/download_ade.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

# Usage: ./download_ade.sh [destination folder]
# e.g. ./download_ade.sh /vandal/dataset
if [ $# -eq 1 ]; then
    dest=$1
else
    dest="."
fi

wget http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip -P "$dest"
unzip "$dest/ADEChallengeData2016.zip" -d "$dest"

echo "Copying the files in data/ade_splits into the ADE main folder."
cp data/ade_splits/* "$dest"
22 changes: 22 additions & 0 deletions data/download_voc.sh
@@ -0,0 +1,22 @@
#!/usr/bin/env bash

# Run this script from the destination folder,
# e.g. cd /home/datasets first.

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
tar -xf VOCtrainval_11-May-2012.tar
mkdir PascalVOC12
mv VOCdevkit/VOC2012/* PascalVOC12
cd PascalVOC12
wget http://cs.jhu.edu/~cxliu/data/SegmentationClassAug.zip
wget http://cs.jhu.edu/~cxliu/data/SegmentationClassAug_Visualization.zip
wget http://cs.jhu.edu/~cxliu/data/list.zip
unzip SegmentationClassAug.zip
unzip SegmentationClassAug_Visualization.zip
unzip list.zip
mv list splits

# then link SegmentationClassAug into annotations
# and link JPEGImages into images
ln -s JPEGImages images
ln -s SegmentationClassAug annotations
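
A quick, hypothetical sanity check for the resulting layout (not part of the repo; the root path is an assumption to adjust):

```python
# Verify the PascalVOC12 layout produced by download_voc.sh.
import os

root = "PascalVOC12"  # adjust to where you ran the script
for sub in ("images", "annotations", "splits"):
    status = "ok" if os.path.isdir(os.path.join(root, sub)) else "MISSING"
    print(f"{root}/{sub}: {status}")
```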
21 changes: 21 additions & 0 deletions data/voc/labels.txt
@@ -0,0 +1,21 @@
0 background
1 aeroplane
2 bicycle
3 bird
4 boat
5 bottle
6 bus
7 car
8 cat
9 chair
10 cow
11 diningtable
12 dog
13 horse
14 motorbike
15 person
16 potted_plant
17 sheep
18 sofa
19 train
20 tv
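
A hypothetical helper for reading this mapping (not part of the commit), assuming the `<id> <name>` line format above:

```python
def load_labels(path="data/voc/labels.txt"):
    """Parse labels.txt into {class_id: class_name}."""
    labels = {}
    with open(path) as f:
        for line in f:
            idx, name = line.strip().split(maxsplit=1)
            labels[int(idx)] = name
    return labels

# e.g. load_labels()[15] == 'person'
```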
Binary file added data/voc/split/inverse_dict_train.pkl
Binary file not shown.
Binary file added data/voc/split/novel_cls.npy
Binary file not shown.
Binary file added data/voc/split/seen_cls.npy
Binary file not shown.
Binary file added data/voc/split/test_ids.npy
Binary file not shown.
Binary file added data/voc/split/test_list.npy
Binary file not shown.
Binary file added data/voc/split/train_ids.npy
Binary file not shown.
Binary file added data/voc/split/train_list.npy
Binary file not shown.
Binary file added data/voc/split/val_cls.npy
Binary file not shown.
Binary file added data/voc/word_vectors/fasttext.pkl
Binary file not shown.
Binary file added data/voc/word_vectors/glove.pkl
Binary file not shown.
Binary file added data/voc/word_vectors/word2vec.pkl
Binary file not shown.
46 changes: 46 additions & 0 deletions dataset/__init__.py
@@ -0,0 +1,46 @@
from .voc import VOCFSSDataset
from .transform import Compose, RandomScale, RandomCrop, RandomHorizontalFlip, ToTensor, Normalize, CenterCrop
import torch


def get_dataset(opts, task):
    """ Dataset and augmentation """
    train_transform = Compose([
        RandomScale((0.75, 1.5)),
        RandomCrop(opts.crop_size, pad_if_needed=True),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])

    val_transform = Compose([
        CenterCrop(size=opts.crop_size),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])
    test_transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406],
                  std=[0.229, 0.224, 0.225]),
    ])

    if opts.dataset == 'voc':
        dataset = VOCFSSDataset
    else:
        raise NotImplementedError

    train_dst = dataset(root=opts.data_root, task=task, train=True, transform=train_transform)

    if opts.cross_val:
        train_len = int(0.8 * len(train_dst))
        val_len = len(train_dst) - train_len
        train_dst, val_dst = torch.utils.data.random_split(train_dst, [train_len, val_len])
    else:  # don't use cross_val
        val_dst = dataset(root=opts.data_root, task=task, train=False, transform=val_transform)

    test_dst = dataset(root=opts.data_root, task=task, train=False, transform=test_transform)

    return train_dst, val_dst, test_dst
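
A hedged, non-distributed sketch of how the returned datasets might be wrapped in data loaders; `opts` and `task` are assumed to come from the argparser and task setup above:

```python
# Hypothetical wiring (not part of this file); the real training script
# may use a DistributedSampler instead of shuffle=True.
from torch.utils.data import DataLoader
from dataset import get_dataset

train_dst, val_dst, test_dst = get_dataset(opts, task)
train_loader = DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True,
                          num_workers=opts.num_workers, drop_last=True)
val_loader = DataLoader(val_dst, batch_size=1, shuffle=False,
                        num_workers=opts.num_workers)
```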