diff --git a/argparser.py b/argparser.py
index 2c1ac6b..0787614 100644
--- a/argparser.py
+++ b/argparser.py
@@ -19,7 +19,9 @@ def modify_command_options(opts):
     if opts.method == "GIFS":
         opts.method = "WI"
         opts.norm_act = "iabr"
-        opts.l2_loss = 0.1 if opts.l2_loss == 0 else opts.l2_loss
+        # opts.l2_loss = 0.1 if opts.l2_loss == 0 else opts.l2_loss
+        opts.loss_kd = 10
+        opts.dist_warm_start = True
     elif opts.method == 'LWF':
         opts.loss_kd = 10 if opts.loss_kd == 0 else opts.loss_kd
         opts.method = "FT"
@@ -36,7 +38,7 @@ def modify_command_options(opts):
         opts.train_only_novel = True
         opts.train_only_classifier = True
         opts.method = "FT"
-        opts.lr_cls = 10
+        opts.lr_cls = 1  # need to check!
     elif opts.method == 'AFHN' and opts.step > 0:
         opts.train_only_novel = True
         opts.train_only_classifier = True
@@ -85,6 +87,8 @@ def get_argparser():
                         help="First index where to sample shots")
     parser.add_argument("--input_mix", default="novel", choices=['novel', 'both'],
                         help="Which class to use for FSL")
+    parser.add_argument("--masking", action='store_true', default=False,
+                        help='Mask old classes in incremental steps (def: False)')
 
     # Train Options
     parser.add_argument("--epochs", type=int, default=30,
@@ -192,6 +196,8 @@ def get_argparser():
                         help='The MiB distillation loss strength (Def 0.)')
     parser.add_argument("--loss_kd", default=0, type=float,
                         help='The distillation loss strength (Def 0.)')
+    parser.add_argument("--ort_proto", default=0, type=float,
+                        help='The ORT*PROTO loss strength (Def 0.)')
     parser.add_argument("--kd_alpha", default=1, type=float,
                         help='The temperature vale (Def 1.)')
     parser.add_argument("--l2_loss", default=0, type=float,
@@ -254,4 +260,7 @@ def get_argparser():
                         help="Use only a pixel for imprinting when with WI (default: False)")
     parser.add_argument("--weight_mix", action='store_true', default=False,
                         help="When doing WI, sum to proto the mix of old weights (default: False)")
+
+
+
     return parser
diff --git a/dataset/__init__.py b/dataset/__init__.py
index d2a6bcd..cfa9e58 100644
--- a/dataset/__init__.py
+++ b/dataset/__init__.py
@@ -77,8 +77,8 @@ def get_dataset(opts, task, train=True):
         val_dst = Subset(train_dst, idx[train_len:], val_transform)
         train_dst_noaug = Subset(train_dst, idx[:train_len], test_transform)
     else:
-        train_dst = dataset(root=opts.data_root, task=task, train=True, transform=train_transform)
-        train_dst_noaug = dataset(root=opts.data_root, task=task, train=True, transform=test_transform)
+        train_dst = dataset(root=opts.data_root, task=task, train=True, transform=train_transform, masking=opts.masking)
+        train_dst_noaug = dataset(root=opts.data_root, task=task, train=True, transform=test_transform, masking=opts.masking)
         val_dst = dataset(root=opts.data_root, task=task, train=False, transform=val_transform)
 
     return train_dst, val_dst, train_dst_noaug
diff --git a/methods/__init__.py b/methods/__init__.py
index ff20f9d..fcac4b1 100644
--- a/methods/__init__.py
+++ b/methods/__init__.py
@@ -4,7 +4,7 @@
 from .generative_AFHN import AFHN
 from .generative import FGI
 
-methods = {"FT", "SPN", "COS", "WI", 'DWI', "TWI", "AMP", "WG",
+methods = {"FT", "SPN", "COS", "WI", 'DWI', "MWI", "AMP", "WG",
            "GIFS", "LWF", "MIB", "ILT", "RT", "AFHN", "FGI", "FGI2"}
 
 
@@ -18,9 +18,9 @@ def get_method(opts, task, device, logger):
     elif opts.method == 'DWI':
         opts.method = 'COS'
         return DynamicWI(task=task, device=device, logger=logger, opts=opts)
-    elif opts.method == 'TWI':
+    elif opts.method == 'MWI':
         opts.method = 'COS'
-        return TrainedWI(task=task, device=device, logger=logger, opts=opts)
+        return MaskedWI(task=task, device=device, logger=logger, opts=opts)
     elif opts.method == 'AMP':
         opts.method = 'FT'
         return AMP(task=task, device=device, logger=logger, opts=opts)
diff --git a/methods/imprinting.py b/methods/imprinting.py
index 589279f..2a1f853 100644
--- a/methods/imprinting.py
+++ b/methods/imprinting.py
@@ -9,6 +9,10 @@
 import random
 from .utils import get_batch, get_prototype
 import math
+import copy
+from modules.deeplab import DeeplabV3
+from functools import partial
+from modules.custom_bn import InPlaceABR
 
 
 class AMP(Trainer):
@@ -48,10 +52,70 @@ def warm_up_(self, dataset, epochs=1):
                     model.cls.imprint_weights_class(features=features, cl=c, alpha=self.amp_alpha)
 
 
+class MaskedWI(Trainer):
+
+    def compute_weight(self, inp):
+        # input is a D dimensional prototype
+        inp = F.normalize(inp, dim=0)
+        weight = inp.mean(dim=0).view(-1, 1, 1)
+        return weight
+
+    @staticmethod
+    def get_proto(model, ds, cl, device, interpolate_label=True, return_all=False):
+        protos = []
+        for img, lbl in ds:
+            img, lbl = img.to(device), lbl.to(device)
+            mask = (lbl == cl).float()
+            img = img*mask
+            out = model(img.unsqueeze(0), use_classifier=False)
+            if interpolate_label:  # to match output size
+                lbl = F.interpolate(lbl.float().view(1, 1, lbl.shape[0], lbl.shape[1]),
+                                    size=out.shape[-2:], mode="nearest").view(out.shape[-2:]).type(torch.uint8)
+            else:  # interpolate output to match label size
+                out = F.interpolate(out, size=img.shape[-2:], mode="bilinear", align_corners=False)
+            out = out.squeeze(0)
+            out = out.view(out.shape[0], -1).t()  # (HxW) x F
+            lbl = lbl.flatten()  # Now it is (HxW)
+            if mask.sum() > 0:
+                protos.append(F.normalize((out[lbl == cl, :]), dim=1).mean(dim=0, keepdim=True))
+
+        if len(protos) > 0:
+            protos = torch.cat(protos, dim=0)
+            if return_all:
+                return protos
+            return protos.mean(dim=0)
+
+    def warm_up_(self, dataset, epochs=1):
+        model = self.model.module if self.distributed else self.model
+        model.eval()
+        classes = self.task.get_n_classes()
+        old_classes = 0
+        for c in classes[:-1]:
+            old_classes += c
+        new_classes = np.arange(old_classes, old_classes + classes[-1])
+        for c in new_classes:
+            weight = None
+            count = 0
+            with torch.no_grad():
+                while weight is None and count < 10:
+                    ds = dataset.get_k_image_of_class(cl=c, k=self.task.nshot)  # get K images of class c
+
+                    wc = get_prototype(model, ds, c, self.device, interpolate_label=False, return_all=True)
+                    if wc is not None:
+                        weight = self.compute_weight(wc)
+                    count += 1
+
+                if weight is not None:
+                    model.cls.imprint_weights_class(weight, c)
+                else:
+                    raise Exception(f"Unable to imprint weight of class {c} after {count} trials.")
+
+
 class WeightImprinting(Trainer):
     def __init__(self, task, device, logger, opts):
         super().__init__(task, device, logger, opts)
         self.pixel = opts.pixel_imprinting
+        self.masking = True
         if opts.weight_mix:
             self.normalize_weight = True
             self.compute_score = True
@@ -76,6 +140,7 @@ def warm_up_(self, dataset, epochs=5):
             images, labels = dataset[idx]
             images = images.to(self.device, dtype=torch.float32)
             labels = labels.to(self.device, dtype=torch.long)
+
             out = model(images.unsqueeze(0), use_classifier=False)  # .squeeze_(0)
             out_size = images.shape[-2:]
             out = F.interpolate(out, size=out_size, mode="bilinear", align_corners=False).squeeze_(0)
@@ -290,6 +355,14 @@ def __init__(self, task, device, logger, opts):
         super(TrainedWI, self).__init__(task, device, logger, opts)
         self.dim = self.n_channels
 
+        if self.step > 0:
+            self.generator = DeeplabV3(2048, 256, 256,
+                                       norm_act=partial(InPlaceABR, activation="leaky_relu", activation_param=.01),
+                                       out_stride=opts.output_stride, pooling_size=opts.pooling,
+                                       pooling=False, last_relu=opts.relu).to(self.device)
+        else:
+            self.generator = None
+
         self.LR = opts.dyn_lr
         self.ITER = opts.dyn_iter
         self.BATCH_SIZE = 4
@@ -304,6 +377,7 @@ def compute_weight(self, inp):
         return weight
 
     def warm_up_(self, dataset, epochs=1):
+        assert self.generator is not None
         model = self.model.module if self.distributed else self.model
         model.eval()
         classes = self.task.get_n_classes()
@@ -331,20 +405,30 @@ def cool_down(self, dataset, epochs=1):
         if self.step == 0:
             # instance a new model without DDP!
             model = make_model(self.opts, CosineClassifier(self.task.get_n_classes(), channels=self.n_channels))
+            scaler = model.cls.scaler
             state = {}
             for k, v in self.model.state_dict().items():
                 state[k[7:]] = v
             model.load_state_dict(state, strict=True)
+
+            self.generator = model.head
+            #self.generator.pooling = False
+
             model = model.to(self.device)
             model.eval()
             for p in model.body.parameters():
                 p.requires_grad = False
+            for p in model.head.parameters():
+                p.requires_grad = False
+
+            model2 = copy.deepcopy(model)
+            model2.head = self.generator.to(self.device)
 
             # instance optimizer, criterion and data
             classes = np.arange(0, self.task.get_n_classes()[0])
             classes = classes[1:]  # remove bkg ONLY from sampling
-            params = [{"params": model.cls.cls.parameters()}, {"params": model.head.parameters()}]
+            params = [{"params": model.cls.cls.parameters()}, {"params": self.generator.parameters()}]
             optimizer = torch.optim.SGD(params, lr=self.LR)
             criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
 
@@ -371,7 +455,7 @@ def cool_down(self, dataset, epochs=1):
                             # print("WC is None!!")
                            weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
                        else:
-                            class_weight = self.compute_weight(wc).detach()
+                            class_weight = self.compute_weight(wc)
                            weight[c] = F.normalize(class_weight, dim=0)
                    else:
                        weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
@@ -412,10 +496,12 @@ def cool_down(self, dataset, epochs=1):
 
     def load_state_dict(self, checkpoint, strict=True):
         super().load_state_dict(checkpoint, strict)
+        if self.step > 0:
+            self.generator.load_state_dict(checkpoint['generator'])
 
     def state_dict(self):
         state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
-                 "scheduler": self.scheduler.state_dict()}
+                 "scheduler": self.scheduler.state_dict(), "generator": self.generator.state_dict()}
         return state
 
 
diff --git a/methods/trainer.py b/methods/trainer.py
index 38d0c81..43088e9 100644
--- a/methods/trainer.py
+++ b/methods/trainer.py
@@ -2,7 +2,7 @@
 from torch import distributed
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
-from utils.loss import HardNegativeMining, FocalLoss, KnowledgeDistillationLoss, CosineLoss, \
+from utils.loss import HardNegativeMining, FocalLoss, KnowledgeDistillationLoss, CosineLoss, OrthPrototypeLoss, \
     UnbiasedKnowledgeDistillationLoss, UnbiasedCrossEntropy, CosineKnowledgeDistillationLoss, ClassBkgLoss
 from .segmentation_module import make_model
 from modules.classifier import IncrementalClassifier, CosineClassifier, SPNetClassifier
@@ -104,6 +104,12 @@ def __init__(self, task, device, logger, opts):
         else:
             self.feat_criterion = None
 
+        if opts.ort_proto > 0:
+            self.ort_proto = opts.ort_proto
+            self.ort_proto_crit = OrthPrototypeLoss(self.task.get_n_classes())
+        else:
+            self.ort_proto_crit = None
+
         # Output distillation
         if opts.loss_kd > 0 or opts.mib_kd > 0:
             assert self.model_old is not None, "Error, model old is None but distillation specified"
@@ -254,6 +260,11 @@ def train(self, cur_epoch, train_loader, metrics=None, print_int=10, n_iter=1):
 
                 if self.bkg_dist_crit is not None:
                     rloss += self.bkg_dist * self.bkg_dist_crit(outputs, labels)
+                if self.ort_proto_crit is not None:
+                    if self.distributed:
+                        rloss += self.ort_proto * self.ort_proto_crit(self.model.module.cls)
+                    else:
+                        rloss += self.ort_proto * self.ort_proto_crit(self.model.cls)
 
                 gloss += self.generative_loss(images, labels)
 
                 loss = self.reduction(criterion(outputs, labels), labels)
diff --git a/utils/loss.py b/utils/loss.py
index 25ebdb7..cbc02f3 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -79,7 +79,7 @@ def forward(self, inputs, targets):
         inputs = inputs.narrow(1, 0, targets.shape[1])
 
         if self.norm == "L2":
-            loss = (inputs - targets).norm(dim=1) / inputs.shape[1]
+            loss = ((inputs - targets)**2).mean(dim=1)
         else:
             loss = (inputs - targets).mean(dim=1)
 
@@ -93,6 +93,58 @@ def forward(self, inputs, targets):
         return outputs
 
 
+class OrthPrototypeLoss(nn.Module):
+    def __init__(self, classes):
+        super().__init__()
+        self.classes = classes
+        self.steps = len(classes)
+        self.novel = classes[-1]
+
+    def forward(self, classifier):
+        weights = []
+        novel_weights = classifier.cls[self.steps-1].weight.view(self.novel, -1)
+        loss = 0.
+
+        for s, clx in enumerate(self.classes):
+            if s < self.steps-1:
+                xx = classifier.cls[s].weight.detach()
+                xx = xx.view(xx.shape[0], xx.shape[1])
+                weights.append(xx)
+        weights = torch.cat(weights, dim=0)
+
+        for c in range(self.novel):
+            cl_loss = 0.
+            for i in range(len(weights)):
+                cl_loss += max(0., F.cosine_similarity(weights[i], novel_weights[c], dim=0))
+            for i in range(self.novel):
+                if i != c:
+                    cl_loss += max(0., F.cosine_similarity(novel_weights[i], novel_weights[c], dim=0))
+            loss += cl_loss / (len(weights) + len(novel_weights) - 1)
+        return loss/self.novel
+
+
+class UltimateLoss(nn.Module):
+    def __init__(self, alpha=1):
+        super().__init__()
+        self.alpha = alpha
+
+    def forward(self, inputs, body, targets, body_old):
+        inputs = inputs.narrow(1, 0, targets.shape[1])
+
+        inputs = F.interpolate(inputs, size=(body.shape[-2:]))
+        targets = F.interpolate(targets, size=(body.shape[-2:]))
+
+        outputs = torch.log_softmax(inputs, dim=1)
+        labels = torch.softmax(targets / self.alpha, dim=1)
+
+        los1 = -(outputs * labels).mean(dim=1)
+        los2 = (body - body_old).norm(dim=1)
+
+        loss = los1 * los2
+
+        return loss.mean()
+
+
 class KnowledgeDistillationLoss(nn.Module):
     def __init__(self, reduction='mean', kl=False, alpha=1.):
         super().__init__()
@@ -103,7 +155,7 @@ def __init__(self, reduction='mean', kl=False, alpha=1.):
     def forward(self, inputs, targets):
         inputs = inputs.narrow(1, 0, targets.shape[1])
 
-        outputs = torch.log_softmax(inputs / self.alpha, dim=1)
+        outputs = torch.log_softmax(inputs, dim=1)
         labels = torch.softmax(targets / self.alpha, dim=1)
 
         if not self.kl: