
Commit

Added masking (and a few old things)
fcdl94 authored and fcdl94 committed Aug 25, 2021
1 parent 155d07f commit d3c0d1c
Showing 6 changed files with 171 additions and 13 deletions.
argparser.py: 13 changes (11 additions, 2 deletions)
@@ -19,7 +19,9 @@ def modify_command_options(opts):
if opts.method == "GIFS":
opts.method = "WI"
opts.norm_act = "iabr"
opts.l2_loss = 0.1 if opts.l2_loss == 0 else opts.l2_loss
# opts.l2_loss = 0.1 if opts.l2_loss == 0 else opts.l2_loss
opts.loss_kd = 10
opts.dist_warm_start = True
elif opts.method == 'LWF':
opts.loss_kd = 10 if opts.loss_kd == 0 else opts.loss_kd
opts.method = "FT"
@@ -36,7 +38,7 @@ def modify_command_options(opts):
opts.train_only_novel = True
opts.train_only_classifier = True
opts.method = "FT"
opts.lr_cls = 10
opts.lr_cls = 1 # need to check!
elif opts.method == 'AFHN' and opts.step > 0:
opts.train_only_novel = True
opts.train_only_classifier = True
@@ -85,6 +87,8 @@ def get_argparser():
help="First index where to sample shots")
parser.add_argument("--input_mix", default="novel", choices=['novel', 'both'],
help="Which class to use for FSL")
parser.add_argument("--masking", action='store_true', default=False,
help='Mask old classes in incremental steps (def: False)')

# Train Options
parser.add_argument("--epochs", type=int, default=30,
@@ -192,6 +196,8 @@ def get_argparser():
help='The MiB distillation loss strength (Def 0.)')
parser.add_argument("--loss_kd", default=0, type=float,
help='The distillation loss strength (Def 0.)')
parser.add_argument("--ort_proto", default=0, type=float,
help='The ORT*PROTO loss strength (Def 0.)')
parser.add_argument("--kd_alpha", default=1, type=float,
help='The temperature vale (Def 1.)')
parser.add_argument("--l2_loss", default=0, type=float,
@@ -254,4 +260,7 @@ def get_argparser():
help="Use only a pixel for imprinting when with WI (default: False)")
parser.add_argument("--weight_mix", action='store_true', default=False,
help="When doing WI, sum to proto the mix of old weights (default: False)")



return parser
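
For quick reference, the two options added in this file can be reproduced in isolation. The snippet below rebuilds just those two arguments on a fresh ArgumentParser rather than calling the repository's get_argparser, so the parsed values are purely illustrative:

```python
# Self-contained sketch of the two new flags (--masking and --ort_proto);
# definitions copied from the diff above, values are only examples.
import argparse

p = argparse.ArgumentParser()
p.add_argument("--masking", action='store_true', default=False,
               help='Mask old classes in incremental steps (def: False)')
p.add_argument("--ort_proto", default=0, type=float,
               help='The ORT*PROTO loss strength (Def 0.)')

opts = p.parse_args(["--masking", "--ort_proto", "0.1"])
assert opts.masking is True and opts.ort_proto == 0.1
```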
dataset/__init__.py: 4 changes (2 additions, 2 deletions)
@@ -77,8 +77,8 @@ def get_dataset(opts, task, train=True):
val_dst = Subset(train_dst, idx[train_len:], val_transform)
train_dst_noaug = Subset(train_dst, idx[:train_len], test_transform)
else:
train_dst = dataset(root=opts.data_root, task=task, train=True, transform=train_transform)
train_dst_noaug = dataset(root=opts.data_root, task=task, train=True, transform=test_transform)
train_dst = dataset(root=opts.data_root, task=task, train=True, transform=train_transform, masking=opts.masking)
train_dst_noaug = dataset(root=opts.data_root, task=task, train=True, transform=test_transform, masking=opts.masking)
val_dst = dataset(root=opts.data_root, task=task, train=False, transform=val_transform)

return train_dst, val_dst, train_dst_noaug
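
The masking flag is now forwarded to the dataset constructor in the non-subset branch; the dataset-side implementation is not part of this diff. A common reading in incremental segmentation is that, when masking is on, pixels of previously learned classes are remapped to background in the training labels. The helper below is a hypothetical sketch of such a transform, not the repository's code:

```python
import numpy as np

def mask_old_classes(label, new_classes, bkg_index=0, ignore_index=255):
    """Hypothetical label transform: keep pixels of the current-step classes
    (and the ignore index), map everything else to background."""
    masked = np.full_like(label, bkg_index)
    masked[label == ignore_index] = ignore_index
    for c in new_classes:
        masked[label == c] = c
    return masked

# Example: with new classes {16, 17}, old-class pixels (1 and 3) become background.
lbl = np.array([[1, 16], [17, 3]])
print(mask_old_classes(lbl, new_classes={16, 17}))
# [[ 0 16]
#  [17  0]]
```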
methods/__init__.py: 6 changes (3 additions, 3 deletions)
@@ -4,7 +4,7 @@
from .generative_AFHN import AFHN
from .generative import FGI

methods = {"FT", "SPN", "COS", "WI", 'DWI', "TWI", "AMP", "WG",
methods = {"FT", "SPN", "COS", "WI", 'DWI', "MWI", "AMP", "WG",
"GIFS", "LWF", "MIB", "ILT", "RT", "AFHN", "FGI", "FGI2"}


@@ -18,9 +18,9 @@ def get_method(opts, task, device, logger):
elif opts.method == 'DWI':
opts.method = 'COS'
return DynamicWI(task=task, device=device, logger=logger, opts=opts)
elif opts.method == 'TWI':
elif opts.method == 'MWI':
opts.method = 'COS'
return TrainedWI(task=task, device=device, logger=logger, opts=opts)
return MaskedWI(task=task, device=device, logger=logger, opts=opts)
elif opts.method == 'AMP':
opts.method = 'FT'
return AMP(task=task, device=device, logger=logger, opts=opts)
methods/imprinting.py: 92 changes (89 additions, 3 deletions)
@@ -9,6 +9,10 @@
import random
from .utils import get_batch, get_prototype
import math
import copy
from modules.deeplab import DeeplabV3
from functools import partial
from modules.custom_bn import InPlaceABR


class AMP(Trainer):
@@ -48,10 +52,70 @@ def warm_up_(self, dataset, epochs=1):
model.cls.imprint_weights_class(features=features, cl=c, alpha=self.amp_alpha)


class MaskedWI(Trainer):

def compute_weight(self, inp):
# input is a D dimensional prototype
inp = F.normalize(inp, dim=0)
weight = inp.mean(dim=0).view(-1, 1, 1)
return weight

@staticmethod
def get_proto(model, ds, cl, device, interpolate_label=True, return_all=False):
protos = []
for img, lbl in ds:
img, lbl = img.to(device), lbl.to(device)
mask = (lbl == cl).float()
img = img*mask
out = model(img.unsqueeze(0), use_classifier=False)
if interpolate_label: # to match output size
lbl = F.interpolate(lbl.float().view(1, 1, lbl.shape[0], lbl.shape[1]),
size=out.shape[-2:], mode="nearest").view(out.shape[-2:]).type(torch.uint8)
else: # interpolate output to match label size
out = F.interpolate(out, size=img.shape[-2:], mode="bilinear", align_corners=False)
out = out.squeeze(0)
out = out.view(out.shape[0], -1).t() # (HxW) x F
lbl = lbl.flatten() # Now it is (HxW)
if mask.sum() > 0:
protos.append(F.normalize((out[lbl == cl, :]), dim=1).mean(dim=0, keepdim=True))

if len(protos) > 0:
protos = torch.cat(protos, dim=0)
if return_all:
return protos
return protos.mean(dim=0)

def warm_up_(self, dataset, epochs=1):
model = self.model.module if self.distributed else self.model
model.eval()
classes = self.task.get_n_classes()
old_classes = 0
for c in classes[:-1]:
old_classes += c
new_classes = np.arange(old_classes, old_classes + classes[-1])
for c in new_classes:
weight = None
count = 0
with torch.no_grad():
while weight is None and count < 10:
ds = dataset.get_k_image_of_class(cl=c, k=self.task.nshot) # get K images of class c

wc = get_prototype(model, ds, c, self.device, interpolate_label=False, return_all=True)
if wc is not None:
weight = self.compute_weight(wc)
count += 1

if weight is not None:
model.cls.imprint_weights_class(weight, c)
else:
raise Exception(f"Unable to imprint weight of class {c} after {count} trials.")


class WeightImprinting(Trainer):
def __init__(self, task, device, logger, opts):
super().__init__(task, device, logger, opts)
self.pixel = opts.pixel_imprinting
self.masking = True
if opts.weight_mix:
self.normalize_weight = True
self.compute_score = True
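
The core of the new MaskedWI variant is get_proto: the support image is multiplied by the binary mask of the target class before being fed to the backbone, and the prototype is the mean of the normalized features at that class's pixels in the downsampled label. A condensed, self-contained sketch of that step (the shapes and the toy backbone are assumptions chosen for illustration):

```python
import torch
import torch.nn.functional as F

def masked_prototype(features_fn, img, lbl, cl):
    """Condensed sketch of MaskedWI.get_proto for one support image.
    img: C x H x W, lbl: H x W, cl: target class id,
    features_fn: maps 1 x C x H x W -> 1 x D x h x w."""
    mask = (lbl == cl).float()                       # 1 where the class appears
    out = features_fn((img * mask).unsqueeze(0))     # features of the masked image
    lbl_small = F.interpolate(lbl[None, None].float(), size=out.shape[-2:],
                              mode="nearest").view(out.shape[-2:]).long()
    feats = out.squeeze(0).flatten(1).t()            # (h*w) x D
    sel = feats[lbl_small.flatten() == cl]           # features at class pixels
    return F.normalize(sel, dim=1).mean(dim=0) if len(sel) > 0 else None

# Toy usage with a random stand-in backbone:
img, lbl = torch.randn(3, 8, 8), torch.randint(0, 3, (8, 8))
proto = masked_prototype(lambda x: torch.randn(1, 16, 4, 4), img, lbl, cl=2)
```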
@@ -76,6 +140,7 @@ def warm_up_(self, dataset, epochs=5):
images, labels = dataset[idx]
images = images.to(self.device, dtype=torch.float32)
labels = labels.to(self.device, dtype=torch.long)

out = model(images.unsqueeze(0), use_classifier=False) # .squeeze_(0)
out_size = images.shape[-2:]
out = F.interpolate(out, size=out_size, mode="bilinear", align_corners=False).squeeze_(0)
@@ -290,6 +355,14 @@ def __init__(self, task, device, logger, opts):
super(TrainedWI, self).__init__(task, device, logger, opts)
self.dim = self.n_channels

if self.step > 0:
self.generator = DeeplabV3(2048, 256, 256,
norm_act=partial(InPlaceABR, activation="leaky_relu", activation_param=.01),
out_stride=opts.output_stride, pooling_size=opts.pooling,
pooling=False, last_relu=opts.relu).to(self.device)
else:
self.generator = None

self.LR = opts.dyn_lr
self.ITER = opts.dyn_iter
self.BATCH_SIZE = 4
@@ -304,6 +377,7 @@ def compute_weight(self, inp):
return weight

def warm_up_(self, dataset, epochs=1):
assert self.generator is not None
model = self.model.module if self.distributed else self.model
model.eval()
classes = self.task.get_n_classes()
@@ -331,20 +405,30 @@ def cool_down(self, dataset, epochs=1):
if self.step == 0:
# instance a new model without DDP!
model = make_model(self.opts, CosineClassifier(self.task.get_n_classes(), channels=self.n_channels))

scaler = model.cls.scaler
state = {}
for k, v in self.model.state_dict().items():
state[k[7:]] = v
model.load_state_dict(state, strict=True)

self.generator = model.head
#self.generator.pooling = False

model = model.to(self.device)
model.eval()
for p in model.body.parameters():
p.requires_grad = False
for p in model.head.parameters():
p.requires_grad = False

model2 = copy.deepcopy(model)
model2.head = self.generator.to(self.device)

# instance optimizer, criterion and data
classes = np.arange(0, self.task.get_n_classes()[0])
classes = classes[1:] # remove bkg ONLY from sampling
params = [{"params": model.cls.cls.parameters()}, {"params": model.head.parameters()}]
params = [{"params": model.cls.cls.parameters()}, {"params": self.generator.parameters()}]
optimizer = torch.optim.SGD(params, lr=self.LR)
criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

@@ -371,7 +455,7 @@ def cool_down(self, dataset, epochs=1):
# print("WC is None!!")
weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
else:
class_weight = self.compute_weight(wc).detach()
class_weight = self.compute_weight(wc)
weight[c] = F.normalize(class_weight, dim=0)
else:
weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
@@ -412,10 +496,12 @@ def cool_down(self, dataset, epochs=1):

def load_state_dict(self, checkpoint, strict=True):
super().load_state_dict(checkpoint, strict)
if self.step > 0:
self.generator.load_state_dict(checkpoint['generator'])

def state_dict(self):
state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
"scheduler": self.scheduler.state_dict()}
"scheduler": self.scheduler.state_dict(), "generator": self.generator.state_dict()}
return state
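
In TrainedWI, the generator is now a dedicated DeeplabV3 head created only for incremental steps, trained in cool_down, and saved and restored through the checkpoint under the 'generator' key. The cool_down changes follow a freeze-and-swap pattern: freeze the original body and head, deep-copy the model, and route features through the generator so the optimizer only touches the classifier and the generator. A self-contained sketch of that pattern (the toy modules and the explicit re-enabling of gradients are assumptions of the sketch, not the repository's exact code):

```python
import copy
import torch
import torch.nn as nn

# Freeze-and-swap sketch: freeze body/head, deep-copy the model, and replace
# the copy's head with a trainable generator, so only the classifier and the
# generator receive gradients.
class ToySeg(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 8)
        self.head = nn.Linear(8, 8)
        self.cls = nn.Linear(8, 4)

model = ToySeg().eval()
for p in list(model.body.parameters()) + list(model.head.parameters()):
    p.requires_grad = False

generator = copy.deepcopy(model.head)       # stands in for the DeeplabV3 generator
for p in generator.parameters():            # assumption: generator stays trainable
    p.requires_grad = True

model2 = copy.deepcopy(model)
model2.head = generator                     # features now flow through the generator

params = [{"params": model.cls.parameters()}, {"params": generator.parameters()}]
optimizer = torch.optim.SGD(params, lr=0.01)
```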


methods/trainer.py: 13 changes (12 additions, 1 deletion)
@@ -2,7 +2,7 @@
from torch import distributed
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from utils.loss import HardNegativeMining, FocalLoss, KnowledgeDistillationLoss, CosineLoss, \
from utils.loss import HardNegativeMining, FocalLoss, KnowledgeDistillationLoss, CosineLoss, OrthPrototypeLoss, \
UnbiasedKnowledgeDistillationLoss, UnbiasedCrossEntropy, CosineKnowledgeDistillationLoss, ClassBkgLoss
from .segmentation_module import make_model
from modules.classifier import IncrementalClassifier, CosineClassifier, SPNetClassifier
@@ -104,6 +104,12 @@ def __init__(self, task, device, logger, opts):
else:
self.feat_criterion = None

if opts.ort_proto > 0:
self.ort_proto = opts.ort_proto
self.ort_proto_crit = OrthPrototypeLoss(self.task.get_n_classes())
else:
self.ort_proto_crit = None

# Output distillation
if opts.loss_kd > 0 or opts.mib_kd > 0:
assert self.model_old is not None, "Error, model old is None but distillation specified"
@@ -254,6 +260,11 @@ def train(self, cur_epoch, train_loader, metrics=None, print_int=10, n_iter=1):
if self.bkg_dist_crit is not None:
rloss += self.bkg_dist * self.bkg_dist_crit(outputs, labels)

if self.ort_proto_crit is not None:
if self.distributed:
rloss += self.ort_proto * self.ort_proto_crit(self.model.module.cls)
else:
rloss += self.ort_proto * self.ort_proto_crit(self.model.cls)
gloss += self.generative_loss(images, labels)

loss = self.reduction(criterion(outputs, labels), labels)
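
The trainer wires the new regularizer in the same way as the other auxiliary terms: the criterion is built only when its strength flag is positive, and at train time its weighted value is added to the running auxiliary loss, unwrapping .module when the model is under DistributedDataParallel. A self-contained sketch of that pattern with a stand-in criterion (not the repository's OrthPrototypeLoss):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cls = nn.Linear(8, 5)   # stands in for the incremental classifier

def toy_ort_proto_crit(cls):
    # Stand-in criterion: mean clamped cosine similarity between
    # different classifier weight vectors.
    w = F.normalize(cls.weight, dim=1)
    sim = (w @ w.t()).clamp(min=0)
    n = w.shape[0]
    return (sim.sum() - sim.diag().sum()) / (n * (n - 1))

ort_proto = 0.1                                   # opts.ort_proto
ort_proto_crit = toy_ort_proto_crit if ort_proto > 0 else None

model, distributed = ToyModel(), False
rloss = torch.tensor(0.0)
if ort_proto_crit is not None:
    cls = model.module.cls if distributed else model.cls   # DDP unwrap
    rloss = rloss + ort_proto * ort_proto_crit(cls)
print(rloss.item())
```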
utils/loss.py: 56 changes (54 additions, 2 deletions)
@@ -79,7 +79,7 @@ def forward(self, inputs, targets):
inputs = inputs.narrow(1, 0, targets.shape[1])

if self.norm == "L2":
loss = (inputs - targets).norm(dim=1) / inputs.shape[1]
loss = ((inputs - targets)**2).mean(dim=1)
else:
loss = (inputs - targets).mean(dim=1)
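
The L2 branch switches from the channel-wise Euclidean norm divided by the channel count to a per-pixel mean of squared differences; the two are not equivalent up to a constant (root versus square). A small numeric check (shapes chosen only for illustration):

```python
import torch

x = torch.tensor([[[[3.0]], [[4.0]]]])   # N=1, C=2, H=W=1
t = torch.zeros_like(x)

old = (x - t).norm(dim=1) / x.shape[1]    # sqrt(3^2 + 4^2) / 2 = 2.5
new = ((x - t) ** 2).mean(dim=1)          # (9 + 16) / 2     = 12.5
print(old.item(), new.item())             # 2.5 12.5
```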

@@ -93,6 +93,58 @@ def forward(self, inputs, targets):
return outputs


class OrthPrototypeLoss(nn.Module):
def __init__(self, classes):
super().__init__()
self.classes = classes
self.steps = len(classes)
self.novel = classes[-1]

def forward(self, classifier):
weights = []
novel_weights = classifier.cls[self.steps-1].weight.view(self.novel, -1)
loss = 0.

for s, clx in enumerate(self.classes):
if s < self.steps-1:
xx = classifier.cls[s].weight.detach()
xx = xx.view(xx.shape[0], xx.shape[1])
weights.append(xx)
weights = torch.cat(weights, dim=0)

for c in range(self.novel):
cl_loss = 0.
for i in range(len(weights)):
cl_loss += max(0., F.cosine_similarity(weights[i], novel_weights[c], dim=0))
for i in range(self.novel):
if i != c:
cl_loss += max(0., F.cosine_similarity(novel_weights[i], novel_weights[c], dim=0))
loss += cl_loss / (len(weights) + len(novel_weights) - 1)
return loss/self.novel
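
OrthPrototypeLoss iterates over the classifier's per-step weight banks: weights of earlier steps are detached, and each novel-class weight pays a penalty equal to its clamped cosine similarity to every old weight and to the other novel weights, averaged over pairs and over the novel classes. A self-contained toy call with a stand-in classifier mimicking the cls[s] layout (the 1x1-convolution weight shape is an assumption):

```python
import torch
import torch.nn as nn
from utils.loss import OrthPrototypeLoss

# Toy classifier mimicking the incremental layout expected above:
# cls[s] is a 1x1 conv holding the weight vectors of the classes added at step s.
class ToyCls(nn.Module):
    def __init__(self, classes, channels=16):
        super().__init__()
        self.cls = nn.ModuleList([nn.Conv2d(channels, c, 1) for c in classes])

classes = [6, 2]                    # 6 base classes, then 2 novel classes
crit = OrthPrototypeLoss(classes)
loss = crit(ToyCls(classes))
# Old-step weights are detached, so gradients would reach only the novel weights.
print(float(loss))
```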


class UltimateLoss(nn.Module):
def __init__(self, alpha=1):
super().__init__()
self.alpha = alpha

def forward(self, inputs, body, targets, body_old):
inputs = inputs.narrow(1, 0, targets.shape[1])

inputs = F.interpolate(inputs, size=(body.shape[-2:]))
targets = F.interpolate(targets, size=(body.shape[-2:]))

outputs = torch.log_softmax(inputs, dim=1)
labels = torch.softmax(targets / self.alpha, dim=1)

los1 = -(outputs * labels).mean(dim=1)
los2 = (body - body_old).norm(dim=1)

loss = los1 * los2

return loss.mean()


class KnowledgeDistillationLoss(nn.Module):
def __init__(self, reduction='mean', kl=False, alpha=1.):
super().__init__()
@@ -103,7 +155,7 @@ def __init__(self, reduction='mean', kl=False, alpha=1.):
def forward(self, inputs, targets):
inputs = inputs.narrow(1, 0, targets.shape[1])

outputs = torch.log_softmax(inputs / self.alpha, dim=1)
outputs = torch.log_softmax(inputs, dim=1)
labels = torch.softmax(targets / self.alpha, dim=1)

if not self.kl:
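
In the distillation loss, the temperature is now applied only to the teacher targets: the student logits go through a plain log_softmax instead of log_softmax(inputs / alpha). The snippet below only contrasts the two soft-target terms; it is not the full loss, whose reduction lies outside this hunk:

```python
import torch

inputs = torch.tensor([[2.0, 0.5, -1.0]])    # student logits (already narrowed)
targets = torch.tensor([[1.5, 0.2, -0.5]])   # teacher logits
alpha = 2.0

labels = torch.softmax(targets / alpha, dim=1)                  # softened teacher
before = torch.log_softmax(inputs / alpha, dim=1) * labels      # previous form
after = torch.log_softmax(inputs, dim=1) * labels               # this commit
print(before.mean(dim=1).item(), after.mean(dim=1).item())
```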
