From fcb29f1249bb3f3f6d16c20b9d098e3d6c740274 Mon Sep 17 00:00:00 2001
From: fcdl94
Date: Wed, 10 Feb 2021 10:32:05 +0100
Subject: [PATCH] Update of Dyn Weight Imprinting

Added BornAgain for base step
Few other fixes
---
 argparser.py          | 14 +++++--
 dataset/__init__.py   |  8 ++--
 methods/__init__.py   |  2 +-
 methods/imprinting.py | 90 +++++++++++++++++++++++++++++++------------
 methods/trainer.py    | 75 ++++++++++++++++++++++--------------
 methods/utils.py      |  5 ++-
 modules/custom_bn.py  |  7 ++--
 run.py                |  2 +-
 utils/loss.py         | 24 ++++++------
 9 files changed, 151 insertions(+), 76 deletions(-)

diff --git a/argparser.py b/argparser.py
index b58a934..9596e6a 100644
--- a/argparser.py
+++ b/argparser.py
@@ -32,6 +32,10 @@ def modify_command_options(opts):
         opts.loss_kd = 10 if opts.loss_kd == 0 else opts.loss_kd
         opts.loss_de = 10 if opts.loss_de == 0 else opts.loss_de
         opts.method = "FT"
+    elif opts.method == 'RT':
+        opts.train_only_novel = True
+        opts.method = "FT"
+        opts.lr_cls = 10
 
     if opts.train_only_classifier or opts.train_only_novel:
         opts.freeze = True
@@ -189,17 +193,21 @@ def get_argparser():
                         help='The L1 feature loss strength (Def 0.)')
     parser.add_argument("--cos_loss", default=0, type=float,
                         help='The feature loss strength (Def 0.)')
+    parser.add_argument("--kl_div", default=False, action='store_true',
+                        help='Use true KL loss and not the CE loss.')
+    parser.add_argument("--dist_warm_start", default=False, action='store_true',
+                        help='Use warm start for distillation.')
+    parser.add_argument("--born_again", default=False, action='store_true',
+                        help='Use born again strategy (use --ckpt as model old).')
 
     parser.add_argument("--train_only_classifier", action='store_true', default=False,
                         help="Freeze body and head of network (default: False)")
     parser.add_argument("--train_only_novel", action='store_true', default=False,
                         help="Train only the classifier of current step (default: False)")
     parser.add_argument("--bn_momentum", default=None, type=float,
-                        help="The BN momentum (Set to 0 to avoid update of running stats.)")
+                        help="The BN momentum (Set to 0.1 to update of running stats of ABR.)")
 
     # to remove
-    parser.add_argument("--strong_scale", action='store_true', default=False,
-                        help="Use strong scale augmentation (default: False)")
     parser.add_argument("--pixel_imprinting", action='store_true', default=False,
                         help="Use only a pixel for imprinting when with WI (default: False)")
     parser.add_argument("--weight_mix", action='store_true', default=False,
diff --git a/dataset/__init__.py b/dataset/__init__.py
index ac094b0..36d8804 100644
--- a/dataset/__init__.py
+++ b/dataset/__init__.py
@@ -2,9 +2,11 @@
 from .cityscapes import CityscapesFSSDataset
 from .coco import COCOFSS, COCO, COCOStuffFSS
 from .ade import AdeSegmentation
-from .transform import Compose, RandomScale, RandomCrop, RandomHorizontalFlip, ToTensor, Normalize, CenterCrop, Resize
+from .transform import Compose, RandomScale, RandomCrop, RandomHorizontalFlip, ToTensor, Normalize, \
+    CenterCrop, Resize, RandomResizedCrop, ColorJitter
 import random
 from .utils import Subset, MyImageFolder, RandomDataset
+
 
 
 TRAIN_CV = 0.8
@@ -40,9 +42,9 @@ def get_dataset(opts, task, train=True):
             dataset = COCOStuffFSS
         else:
             dataset = COCOFSS
-    scale = RandomScale((0.5, 2)) if not opts.strong_scale else RandomScale((0.5, 4))
+
     train_transform = Compose([
-        scale,
+        RandomScale((0.5, 2)),
         RandomCrop(opts.crop_size, pad_if_needed=True),
         RandomHorizontalFlip(),
         ToTensor(),
diff --git a/methods/__init__.py b/methods/__init__.py
index 1243f5f..0c19ec6 100644
--- a/methods/__init__.py
+++ b/methods/__init__.py
@@ -2,7 +2,7 @@
 from .trainer import Trainer
 from .imprinting import *
 
-methods = {"FT", "SPN", "COS", "WI", 'DWI', 'WM', "AMP", "WG", "GIFS", "LWF", "MIB", "ILT"}
+methods = {"FT", "SPN", "COS", "WI", 'DWI', 'WM', "AMP", "WG", "GIFS", "LWF", "MIB", "ILT", "RT"}
 
 
 def get_method(opts, task, device, logger):
diff --git a/methods/imprinting.py b/methods/imprinting.py
index fedb0e3..40453cf 100644
--- a/methods/imprinting.py
+++ b/methods/imprinting.py
@@ -19,7 +19,7 @@ def initialize(self, opts):
         super(AMP, self).initialize(opts)
         self.amp_alpha = opts.amp_alpha
 
-    def warm_up(self, dataset, epochs=1):
+    def warm_up_(self, dataset, epochs=1):
         model = self.model.module if self.distributed else self.model
         model.eval()
         classes = len(self.task.order)
@@ -60,7 +60,7 @@ def __init__(self, task, device, logger, opts):
         self.normalize_weight = False
         self.compute_score = False
 
-    def warm_up(self, dataset, epochs=5):
+    def warm_up_(self, dataset, epochs=5):
         model = self.model.module if self.distributed else self.model
         model.eval()
         classes = self.task.get_n_classes()
@@ -122,7 +122,7 @@ def warm_up(self, dataset, epochs=5):
 class WeightMixing(Trainer):
     use_bkg = False
 
-    def warm_up(self, dataset, epochs=1):
+    def warm_up_(self, dataset, epochs=1):
         model = self.model.module if self.distributed else self.model
         model.eval()
         start_from = 0 if WeightMixing.use_bkg else 1
@@ -158,16 +158,49 @@ def warm_up(self, dataset, epochs=1):
 
 
 class DynamicWI(Trainer):
-    LR = 0.01
-    ITER = 500
-    BATCH_SIZE = 10
-    EPISODE = 5
+    LR = 1
+    ITER = 200
+    BATCH_SIZE = 5
+    EPISODE = 2
 
     def __init__(self, task, device, logger, opts):
         super(DynamicWI, self).__init__(task, device, logger, opts)
-        self.weight = nn.Parameter(F.normalize(torch.ones((self.n_channels, 1, 1), device=self.device), dim=0))
+        self.dim = self.n_channels
+
+        self.weights = nn.Module()
+        self.weight_a = nn.Parameter(F.normalize(torch.ones((self.n_channels, 1, 1), device=self.device), dim=0))
+        self.weights.register_parameter("weight_a", self.weight_a)
+
+        self.weight_b = nn.Parameter(F.normalize(torch.ones((self.n_channels, 1, 1), device=self.device), dim=0))
+        self.weights.register_parameter("weight_b", self.weight_b)
+
+        self.keys = nn.Parameter(torch.randn(self.task.get_n_classes()[0], self.n_channels))
+        self.weights.register_parameter("keys", self.keys)
+
+        self.att_weight = nn.Parameter(torch.randn(self.n_channels, self.n_channels))
+        self.weights.register_parameter("att_weight", self.att_weight)
+
+        self.use_attention = True
+
+    def compute_weight(self, inp, cls_weight):
+        # input is a D dimensional prototype
+        if self.use_attention:
+            sum_weight = torch.zeros(self.dim, 1, 1)
+            count_weight = 0
+            for x in inp:
+                x = self.att_weight @ x  # DxD x D = DxD
+                x = x / x.norm(dim=1)
+                keys = self.keys / self.keys.norm(dim=1)  # CxD
+                x = (keys @ x).softmax(dim=0)  # C
+                sum_weight += (x.view(-1, 1, 1, 1) * cls_weight).sum(dim=0)
+                count_weight += 1
+            att_weight = sum_weight / count_weight
+            weight = self.weight_a * inp.mean(dim=0) + self.weight_b * att_weight
+        else:
+            weight = self.weight_a * inp.mean(dim=0)
+        return weight
 
-    def warm_up(self, dataset, epochs=1):
+    def warm_up_(self, dataset, epochs=1):
         model = self.model.module if self.distributed else self.model
         model.eval()
         classes = self.task.get_n_classes()
@@ -178,12 +211,19 @@ def warm_up_(self, dataset, epochs=1):
         for c in new_classes:
             ds = dataset.get_k_image_of_class(cl=c, k=self.task.nshot)  # get K images of class c
             wc = get_prototype(model, ds, c, self.device)
+            weight = None
             count = 0
             while wc is None and count < 10:
                 ds = dataset.get_k_image_of_class(cl=c, k=self.task.nshot)  # get K images of class c
-                wc = get_prototype(model, ds, c, self.device)
+                wc = get_prototype(model, ds, c, self.device, return_all=False)
+                if wc is not None:
+                    weight = self.compute_weight(wc, model.cls.cls[0].weight[1:])
                 count += 1
-            model.cls.imprint_weights_class(F.normalize(self.weight * wc.view(self.n_channels, 1, 1), dim=0), c)
+
+            if weight is not None:
+                model.cls.imprint_weights_class(weight, c)
+            else:
+                raise Exception(f"Unable to imprint weight of class {c} after {count} trials.")
 
     def cool_down(self, dataset, epochs=1):
         if self.step == 0:
@@ -200,8 +240,7 @@ def cool_down(self, dataset, epochs=1):
             # instance optimizer, criterion and data
            classes = np.arange(0, self.task.get_n_classes()[0])
            classes = classes[1:] if self.task.use_bkg else classes  # remove bkg if present
-            params = [  # {"params": model.cls.cls[0].weight, "lr": DynamicWI.LR*0.1},
-                      {"params": self.weight, "lr": DynamicWI.LR}]
+            params = [{"params": model.cls.cls.parameters()}, {"params": self.weights.parameters()}]
            optimizer = torch.optim.SGD(params, lr=DynamicWI.LR)
            criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

@@ -219,14 +258,16 @@ def cool_down(self, dataset, epochs=1):
                 K = random.choice([1, 5, 10])
                 cls = random.choice(classes)  # sample N classes
                 for c in range(self.task.get_n_classes()[0]):
-                    wc = None
                     if c == cls:
-                        ds = dataset.get_k_image_of_class(cl=c, k=K)  # get K images of class c
-                        wc = get_prototype(self.model, ds, c, self.device)
-                    if wc is None:
-                        weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
+                        ds = dataset.get_k_image_of_class(cl=cls, k=K)  # get K images of class c
+                        wc = get_prototype(self.model, ds, cls, self.device, return_all=True)
+                        if wc is None:
+                            print("WC is None!!")
+                            weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
+                        else:
+                            weight[c] = F.normalize(self.compute_weight(wc, model.cls.cls[0].weight))
                     else:
-                        weight[c] = F.normalize(self.weight * wc.view(self.n_channels, 1, 1), dim=0)
+                        weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
 
                 # get a batch of images from dataloader
                 it, batch = get_batch(it, dataloader)
@@ -252,7 +293,7 @@ def cool_down(self, dataset, epochs=1):
                 if (i % 50) == 0:
                     self.logger.info(f"Cool down loss at iter {i + 1}: {loss_tot / (i + 1)}")
 
-            self.logger.debug(self.weight)
+            #self.logger.debug(self.weight)
             state = {}
             for k, v in model.state_dict().items():
                 state["module." + k] = v
@@ -261,12 +302,13 @@ def cool_down(self, dataset, epochs=1):
 
     def load_state_dict(self, checkpoint, strict=True):
         super().load_state_dict(checkpoint, strict)
         if self.step > 0:
-            device = self.weight.device
-            self.weight.data = checkpoint['weight'].to(device)
+            self.weights.load_state_dict(checkpoint['weights'])
+            self.weights.to(self.device)
 
     def state_dict(self):
         state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
-                 "scheduler": self.scheduler.state_dict(), "weight": self.weight.data}
+                 "scheduler": self.scheduler.state_dict(),
+                 "weights": self.weights.state_dict()}
         return state
 
@@ -317,7 +359,7 @@ def gen_weight(self, model, ds, cl):
 
         return w
 
-    def warm_up(self, dataset, epochs=1):
+    def warm_up_(self, dataset, epochs=1):
         self.train = False
         model = self.model.module if self.distributed else self.model
         model.eval()
diff --git a/methods/trainer.py b/methods/trainer.py
index 9703b98..70112fb 100644
--- a/methods/trainer.py
+++ b/methods/trainer.py
@@ -7,9 +7,11 @@
 from .segmentation_module import make_model
 from modules.classifier import IncrementalClassifier, CosineClassifier, SPNetClassifier
 from .utils import get_scheduler, get_batch, MeanReduction
+import copy
 
 CLIP = 10
 
+
 class Trainer:
     def __init__(self, task, device, logger, opts):
         self.logger = logger
@@ -18,20 +20,15 @@ def __init__(self, task, device, logger, opts):
         self.opts = opts
         self.novel_classes = self.task.get_n_classes()[-1]
         self.step = task.step
-        self.need_model_old = task.step > 0 and (opts.mib_kd > 0 or opts.loss_kd > 0 or opts.l2_loss > 0 or opts.l1_loss > 0 or opts.cos_loss > 0)
+
+        self.need_model_old = (opts.born_again or opts.mib_kd > 0 or opts.loss_kd > 0 or
+                               opts.l2_loss > 0 or opts.l1_loss > 0 or opts.cos_loss > 0)
 
         self.n_channels = -1  # features size, will be initialized in make model
         self.model = self.make_model()
         self.model = self.model.to(device)
         self.distributed = False
         self.model_old = None
-        if self.need_model_old:
-            self.model_old = self.make_model(is_old=True)
-            # put the old model into distributed memory and freeze it
-            for par in self.model_old.parameters():
-                par.requires_grad = False
-            self.model_old.to(device)
-            self.model_old.eval()
 
         if opts.fix_bn:
             self.model.fix_bn()
@@ -39,6 +36,19 @@ def __init__(self, task, device, logger, opts):
         if opts.bn_momentum is not None:
             self.model.bn_set_momentum(opts.bn_momentum)
 
+        self.initialize(opts)  # initialize model parameters (e.g. perform WI)
+
+        self.born_again = opts.born_again
+        self.dist_warm_start = opts.dist_warm_start
+        model_old_as_new = opts.born_again or opts.dist_warm_start
+        if self.need_model_old:
+            self.model_old = self.make_model(is_old=not model_old_as_new)
+            # put the old model into distributed memory and freeze it
+            for par in self.model_old.parameters():
+                par.requires_grad = False
+            self.model_old.to(device)
+            self.model_old.eval()
+
         # xxx Set up optimizer
         params = []
         if not opts.freeze:
@@ -87,11 +97,12 @@ def __init__(self, task, device, logger, opts):
             self.feat_criterion = None
 
         # Output distillation
+        self.it_kd = opts.iterative_kd
         if task.step > 0 and (opts.loss_kd > 0 or opts.mib_kd > 0):
-            assert self.model_old is not None, "Error, model old is None but distillation specified"
+            assert self.model_old is not None and not self.it_kd, "Error, model old is None but distillation specified"
             if opts.loss_kd > 0:
+                self.kd_criterion = KnowledgeDistillationLoss(reduction="mean", kl=opts.kl_div)
                 self.kd_loss = opts.loss_kd
-                self.kd_criterion = KnowledgeDistillationLoss(reduction="mean")
             if opts.mib_kd > 0:
                 self.kd_loss = opts.mib_kd
                 self.kd_criterion = UnbiasedKnowledgeDistillationLoss(reduction="mean")
@@ -106,8 +117,6 @@ def __init__(self, task, device, logger, opts):
         else:
             self.de_criterion = None
 
-        self.initialize(opts)  # setup the model, optimizer, scheduler and criterion
-
     def make_model(self, is_old=False):
         classifier, self.n_channels = self.get_classifier(is_old)
         model = make_model(self.opts, classifier)
@@ -157,6 +166,12 @@ def initialize(self, opts):
                 classifier.cls[0].bias[0].data.copy_(new_bias.squeeze(0))
 
     def warm_up(self, dataset, epochs=1):
+        self.warm_up_(dataset, epochs)
+        # warm start means make KD after weight imprinting or similar
+        if self.dist_warm_start:
+            self.model_old.load_state_dict(self.model.state_dict())
+
+    def warm_up_(self, dataset, epochs=1):
         pass
 
     def cool_down(self, dataset, epochs=1):
@@ -215,7 +230,7 @@ def train(self, cur_epoch, train_loader, metrics=None, print_int=10, n_iter=1):
                 if rloss <= CLIP:
                     loss_tot = loss + rloss
                 else:
-                    print(f"Warning, rloss is {rloss}! Term ignored Skipped")
+                    print(f"Warning, rloss is {rloss}! Term ignored")
                     loss_tot = loss
 
             loss_tot.backward()
@@ -316,28 +331,32 @@ def validate(self, loader, metrics, ret_samples_ids=None, novel=False):
 
     def load_state_dict(self, checkpoint, strict=True):
         state = {}
-        if (self.need_model_old and not strict) or not self.distributed:
+        if self.need_model_old or not self.distributed:
             for k, v in checkpoint["model"].items():
                 state[k[7:]] = v
 
-        state = state if not self.distributed else checkpoint['model']
-        if self.need_model_old and not strict:
-            self.model_old.load_state_dict(state, strict=True)  # we are loading the old model
+        model_state = state if not self.distributed else checkpoint['model']
+
+        if self.born_again and strict:
+            self.model_old.load_state_dict(state)
+            self.model.load_state_dict(model_state)
+        else:
+            if self.need_model_old and not strict:
+                self.model_old.load_state_dict(state, strict=not self.dist_warm_start)  # we are loading the old model
 
-        if 'module.cls.class_emb' in state and not strict:  # if distributed
-            # remove from checkpoint since SPNClassifier is not incremental
-            del state['module.cls.class_emb']
+            if 'module.cls.class_emb' in state and not strict:  # if distributed
+                # remove from checkpoint since SPNClassifier is not incremental
+                del state['module.cls.class_emb']
 
-        if 'cls.class_emb' in state and not strict:  # if not distributed
-            # remove from checkpoint since SPNClassifier is not incremental
-            del state['cls.class_emb']
+            if 'cls.class_emb' in state and not strict:  # if not distributed
+                # remove from checkpoint since SPNClassifier is not incremental
+                del state['cls.class_emb']
 
-        model_state = state
-        self.model.load_state_dict(model_state, strict=strict)
+            self.model.load_state_dict(model_state, strict=strict)
 
-        if strict:  # if strict, we are in ckpt (not step) so load also optim and scheduler
-            self.optimizer.load_state_dict(checkpoint["optimizer"])
-            self.scheduler.load_state_dict(checkpoint["scheduler"])
+        if not self.born_again and strict:  # if strict, we are in ckpt (not step) so load also optim and scheduler
+            self.optimizer.load_state_dict(checkpoint["optimizer"])
+            self.scheduler.load_state_dict(checkpoint["scheduler"])
 
     def state_dict(self):
         state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(),
diff --git a/methods/utils.py b/methods/utils.py
index ce0d225..46ff8f3 100644
--- a/methods/utils.py
+++ b/methods/utils.py
@@ -30,7 +30,7 @@ def get_batch(it, dataloader):
     return it, batch
 
 
-def get_prototype(model, ds, cl, device, interpolate_label=True):
+def get_prototype(model, ds, cl, device, interpolate_label=True, return_all=False):
     protos = []
     with torch.no_grad():
         for img, lbl in ds:
@@ -46,8 +46,11 @@ def get_prototype(model, ds, cl, device, interpolate_label=True):
             lbl = lbl.flatten()  # Now it is (HxW)
             if (lbl == cl).float().sum() > 0:
                 protos.append(norm_mean(out[lbl == cl, :]))
+
     if len(protos) > 0:
         protos = torch.cat(protos, dim=0)
+        if len(protos) > 1 and return_all:
+            return protos
         return protos.mean(dim=0)
     else:
         return None
diff --git a/modules/custom_bn.py b/modules/custom_bn.py
index de65199..45bc5c5 100644
--- a/modules/custom_bn.py
+++ b/modules/custom_bn.py
@@ -124,7 +124,7 @@ class ABR(nn.Module):
     activation_param : float
         Negative slope for the `leaky_relu` activation.
     """
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
+    def __init__(self, num_features, eps=1e-5, momentum=0.0, affine=True, activation="leaky_relu",
                  activation_param=0.01, group=distributed.group.WORLD, renorm=True):
         super(ABR, self).__init__()
         self.num_features = num_features
@@ -146,8 +146,7 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation
 
         self.group = group
         self.renorm = renorm
-        if renorm:
-            self.momentum = 0.
+        self.momentum = momentum
 
     def reset_parameters(self):
         nn.init.constant_(self.running_mean, 0)
@@ -201,7 +200,7 @@ def extra_repr(self):
 
 
 class InPlaceABR(ABR):
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
+    def __init__(self, num_features, eps=1e-5, momentum=0.0, affine=True, activation="leaky_relu",
                  activation_param=0.01):
         super().__init__(num_features, eps, momentum, affine, activation, activation_param)
 
diff --git a/run.py b/run.py
index 25bdf01..0bfa875 100644
--- a/run.py
+++ b/run.py
@@ -165,7 +165,7 @@ def main(opts):
     if opts.ckpt is not None:
         assert os.path.isfile(opts.ckpt), "Error, ckpt not found. Check the correct directory"
         checkpoint = torch.load(opts.ckpt, map_location="cpu")
-        cur_epoch = checkpoint["epoch"] + 1
+        cur_epoch = checkpoint["epoch"] + 1 if not opts.born_again else 0
         model.load_state_dict(checkpoint["model_state"])
         logger.info("[!] Model restored from %s" % opts.ckpt)
         del checkpoint
diff --git a/utils/loss.py b/utils/loss.py
index 2bc67bd..3da1d39 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -55,28 +55,30 @@ def forward(self, inputs):
 
 
 class KnowledgeDistillationLoss(nn.Module):
-    def __init__(self, reduction='mean', alpha=1.):
+    def __init__(self, reduction='mean', kl=False, alpha=1.):
         super().__init__()
         self.reduction = reduction
         self.alpha = alpha
+        self.kl = kl
 
-    def forward(self, inputs, targets, mask=None):
+    def forward(self, inputs, targets):
         inputs = inputs.narrow(1, 0, targets.shape[1])
 
-        outputs = torch.log_softmax(inputs, dim=1)
-        labels = torch.softmax(targets * self.alpha, dim=1)
+        outputs = torch.log_softmax(inputs / self.alpha, dim=1)
+        labels = torch.softmax(targets / self.alpha, dim=1)
 
-        loss = (outputs * labels).mean(dim=1)
-
-        if mask is not None:
-            loss = loss * mask.float()
+        if not self.kl:
+            loss = -(outputs * labels).mean(dim=1)
+        else:
+            loss = F.kl_div(outputs, labels, reduction='none') * (self.alpha ** 2)
+            loss = loss.sum(dim=1)
 
         if self.reduction == 'mean':
-            outputs = -torch.mean(loss)  # torch.masked_select(loss, mask).mean()
+            outputs = torch.mean(loss)
         elif self.reduction == 'sum':
-            outputs = -torch.sum(loss)
+            outputs = torch.sum(loss)
         else:
-            outputs = -loss
+            outputs = loss
 
         return outputs