Added CosineSimilarity Method
Fixed a few small bugs
fcdl94 authored and fcdl94 committed Jul 21, 2020
1 parent 416fca8 commit 3f5c6dd
Showing 13 changed files with 299 additions and 62 deletions.
107 changes: 107 additions & 0 deletions .gitignore
@@ -0,0 +1,107 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

#PyCharm
.idea/
8 changes: 7 additions & 1 deletion argparser.py
@@ -13,6 +13,7 @@ def modify_command_options(opts):

opts.no_cross_val = not opts.cross_val
opts.pooling = round(opts.crop_size / opts.output_stride)
opts.crop_size_test = opts.crop_size if opts.crop_size_test is None else opts.crop_size_test

return opts

@@ -60,6 +61,8 @@ def get_argparser():
help='batch size (default: 4)')
parser.add_argument("--crop_size", type=int, default=512,
help="crop size (default: 512)")
parser.add_argument("--crop_size_test", type=int, default=None,
help="test crop size (default: = --crop_size)")

parser.add_argument("--lr", type=float, default=0.007,
help="learning rate (default: 0.007)")
@@ -117,9 +120,12 @@ def get_argparser():
parser.add_argument("--cross_val", action='store_true', default=False,
help="If validate on training or on validation (default: Train)")

parser.add_argument("--step_ckpt", default=None, type=str,
help="path to trained model at previous step. Leave it None if you want to use def path")

# Method
parser.add_argument("--method", type=str, default='FT',
choices=['FT', 'SPN'],
choices=['FT', 'SPN', 'COS'],
help="The method you want to use.")
parser.add_argument("--embedding", type=str, default="fastnvec", choices=['word2vec', 'fasttext', 'fastnvec'])

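
For context, a minimal sketch (not part of the commit) of how the new --crop_size_test option resolves in modify_command_options: it falls back to --crop_size when left unset. The SimpleNamespace below merely stands in for the parsed options.

from types import SimpleNamespace

# Stand-in for parsed command-line options; only the fields relevant here are set.
opts = SimpleNamespace(crop_size=512, crop_size_test=None)

# Same fallback as the line added to modify_command_options above.
opts.crop_size_test = opts.crop_size if opts.crop_size_test is None else opts.crop_size_test
print(opts.crop_size_test)  # 512 -> test-time crops match the training crop size unless overridden
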
3 changes: 2 additions & 1 deletion dataset/__init__.py
@@ -7,7 +7,7 @@ def get_dataset(opts, task):
""" Dataset And Augmentation
"""
train_transform = Compose([
RandomScale((0.75, 1.5)),
RandomScale((0.5, 1.5)),
RandomCrop(opts.crop_size, pad_if_needed=True),
RandomHorizontalFlip(),
ToTensor(),
@@ -22,6 +22,7 @@
std=[0.229, 0.224, 0.225]),
])
test_transform = Compose([
CenterCrop(size=opts.crop_size),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
4 changes: 2 additions & 2 deletions dataset/transform.py
@@ -497,12 +497,12 @@ def __call__(self, img, lbl=None):
# pad the width if needed
if self.pad_if_needed and img.size[0] < self.size[1]:
img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2))
lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2))
lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2), fill=255)

# pad the height if needed
if self.pad_if_needed and img.size[1] < self.size[0]:
img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2))
lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2))
lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2), fill=255)

i, j, h, w = self.get_params(img, self.size)

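
The padding fix above matters because the training losses in this repository use ignore_index=255: label pixels padded with fill=255 contribute nothing to the loss or the gradients, whereas zero-padded labels would be treated as background ground truth. A standalone sketch of the effect (not from the repo):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=255)

logits = torch.randn(1, 21, 4, 4)                 # 21-class logits for a 4x4 crop
labels = torch.zeros(1, 4, 4, dtype=torch.long)   # background everywhere...
labels[:, :, 2:] = 255                            # ...except a padded strip marked with the ignore index

loss = criterion(logits, labels)                  # only the non-padded half drives the loss
print(loss)
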
8 changes: 8 additions & 0 deletions methods/SPNet.py
@@ -59,3 +59,11 @@ def initialize(self, opts):

reduction = 'mean'
self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction=reduction)

def load_state_dict(self, checkpoint, strict=True):
if not strict:
del checkpoint["model"]['module.cls.class_emb']
self.model.load_state_dict(checkpoint["model"], strict=strict)
if strict:
self.optimizer.load_state_dict(checkpoint["optimizer"])
self.scheduler.load_state_dict(checkpoint["scheduler"])
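
A hedged usage sketch of the new non-strict loading path (the checkpoint path and the method variable are illustrative, not from the commit): when warm-starting SPNet from a previous step, strict=False drops the stored 'module.cls.class_emb' tensor so the class embedding can be rebuilt for the new label set, and the optimizer/scheduler states are left untouched.

import torch

# 'method' is assumed to be an already-constructed SPNet instance; the path is illustrative.
checkpoint = torch.load("checkpoints/spnet_step0.pth", map_location="cpu")

method.load_state_dict(checkpoint, strict=False)   # drops 'module.cls.class_emb', skips optimizer/scheduler
# method.load_state_dict(checkpoint, strict=True)  # exact resume: model + optimizer + scheduler
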
3 changes: 3 additions & 0 deletions methods/__init__.py
@@ -1,6 +1,7 @@
from .segmentation_module import make_model
from .method import FineTuning
from .SPNet import SPNet
from .imprinting import CosineFT


def get_method(opts, task, device, logger):
@@ -9,6 +10,8 @@ def get_method(opts, task, device, logger):
method_ = FineTuning(task=task, device=device, logger=logger, opts=opts)
elif opts.method == "SPN":
method_ = SPNet(task=task, device=device, logger=logger, opts=opts)
elif opts.method == "COS":
method_ = CosineFT(task=task, device=device, logger=logger, opts=opts)
else:
raise NotImplementedError

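
A short usage sketch (opts, task, device and logger are assumed to be built elsewhere, as in the run scripts; not part of the commit): with the new choice wired into the dispatcher, selecting the cosine method stays a single call.

from methods import get_method

# With opts.method == "COS" this now returns a CosineFT instance instead of raising NotImplementedError.
method = get_method(opts, task, device, logger)
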
50 changes: 50 additions & 0 deletions methods/imprinting.py
@@ -0,0 +1,50 @@
import torch
from .method import FineTuning
import torch.nn as nn
from .utils import get_scheduler
import torch.nn.functional as F
from .segmentation_module import make_model


class Classifier(nn.Module):
def __init__(self, channels, classes):
super().__init__()
self.cls = nn.ModuleList(
[nn.Conv2d(channels, c, 1, bias=False) for c in classes])
self.scaler = 10.

def forward(self, x):
x = F.normalize(x, p=2, dim=1)
out = []
for i, mod in enumerate(self.cls):
out.append(self.scaler * F.conv2d(x, F.normalize(mod.weight, dim=1, p=2)))
return torch.cat(out, dim=1)


class CosineFT(FineTuning):
def initialize(self, opts):

head_channels = 256
self.model = make_model(opts, head_channels, Classifier(head_channels, self.task.get_n_classes()))

if opts.fix_bn:
self.model.fix_bn()

# xxx Set up optimizer
params = []
params.append({"params": filter(lambda p: p.requires_grad, self.model.body.parameters()),
'weight_decay': opts.weight_decay})

params.append({"params": filter(lambda p: p.requires_grad, self.model.head.parameters()),
'weight_decay': opts.weight_decay, 'lr': opts.lr*10.})

params.append({"params": filter(lambda p: p.requires_grad, self.model.cls.parameters()),
'weight_decay': opts.weight_decay, 'lr': opts.lr*10.})

self.optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=False)

self.scheduler = get_scheduler(opts, self.optimizer)
self.logger.debug("Optimizer:\n%s" % self.optimizer)

reduction = 'mean'
self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction=reduction)
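
The new Classifier is a cosine classifier: both the features and the 1x1 convolution weights are L2-normalised, so the convolution reduces to a per-pixel cosine similarity scaled by 10. A small standalone check of that equivalence (not part of the commit):

import torch
import torch.nn.functional as F

channels, n_cls = 256, 16
x = torch.randn(1, channels, 8, 8)            # feature map from the decoder head
w = torch.randn(n_cls, channels, 1, 1)        # weights of one 1x1 classifier in the ModuleList

x_n = F.normalize(x, p=2, dim=1)              # unit-norm feature vector per pixel
w_n = F.normalize(w, p=2, dim=1)              # unit-norm weight vector per class
logits = 10. * F.conv2d(x_n, w_n)             # what Classifier.forward computes (scaler = 10.)

# Reference: explicit cosine similarity for one class at one pixel.
cos = F.cosine_similarity(x[0, :, 0, 0], w[3, :, 0, 0], dim=0)
assert torch.allclose(logits[0, 3, 0, 0], 10. * cos, atol=1e-5)
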
22 changes: 9 additions & 13 deletions methods/method.py
@@ -34,10 +34,11 @@ def train(self, cur_epoch, train_loader, print_int=10):
def validate(self, loader, metrics, ret_samples_ids=None):
raise NotImplementedError

def load_state_dict(self, checkpoint):
self.model.load_state_dict(checkpoint["model"], strict=True)
self.optimizer.load_state_dict(checkpoint["optimizer"])
self.scheduler.load_state_dict(checkpoint["scheduler"])
def load_state_dict(self, checkpoint, strict=True):
self.model.load_state_dict(checkpoint["model"], strict=strict)
if strict:
self.optimizer.load_state_dict(checkpoint["optimizer"])
self.scheduler.load_state_dict(checkpoint["scheduler"])

def state_dict(self):
state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
@@ -73,20 +74,15 @@ def forward(self, x):
'weight_decay': opts.weight_decay})

params.append({"params": filter(lambda p: p.requires_grad, self.model.head.parameters()),
'weight_decay': opts.weight_decay})
'weight_decay': opts.weight_decay, 'lr': opts.lr*10.})

params.append({"params": filter(lambda p: p.requires_grad, self.model.cls.parameters()),
'weight_decay': opts.weight_decay})
'weight_decay': opts.weight_decay, 'lr': opts.lr*10.})

self.optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=True)
self.optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=False)
self.scheduler = get_scheduler(opts, self.optimizer)
self.logger.debug("Optimizer:\n%s" % self.optimizer)

self.model, self.optimizer = amp.initialize(self.model.to(self.device), self.optimizer, opt_level=opts.opt_level)

# Put the model on GPU
self.model = DistributedDataParallel(self.model, delay_allreduce=True)

reduction = 'mean'
self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction=reduction)

@@ -184,7 +180,7 @@ def validate(self, loader, metrics, ret_samples_ids=None):

if ret_samples_ids is not None and i in ret_samples_ids: # get samples
ret_samples.append((images[0].detach().cpu().numpy(),
labels[0]))
labels[0], prediction[0]))

# collect statistics from multiple processes
metrics.synch(device)
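
Two of the changes above are easy to miss: the head and classifier parameter groups now train at ten times the base learning rate, and Nesterov momentum is switched off. A minimal illustration of how those groups end up in the optimizer (toy modules standing in for model.body, model.head and model.cls; not from the repo):

import torch
import torch.nn as nn

backbone = nn.Conv2d(3, 8, 3)      # stands in for model.body
head = nn.Conv2d(8, 8, 3)          # stands in for model.head
cls = nn.Conv2d(8, 21, 1)          # stands in for model.cls

base_lr = 0.007
optimizer = torch.optim.SGD(
    [{"params": backbone.parameters()},
     {"params": head.parameters(), "lr": base_lr * 10.},
     {"params": cls.parameters(), "lr": base_lr * 10.}],
    lr=base_lr, momentum=0.9, nesterov=False)

print([group["lr"] for group in optimizer.param_groups])   # [0.007, 0.07, 0.07]
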