diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..459d7b7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,107 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +#PyCharm +.idea/ \ No newline at end of file diff --git a/argparser.py b/argparser.py index cc25359..2c6bf84 100644 --- a/argparser.py +++ b/argparser.py @@ -13,6 +13,7 @@ def modify_command_options(opts): opts.no_cross_val = not opts.cross_val opts.pooling = round(opts.crop_size / opts.output_stride) + opts.crop_size_test = opts.crop_size if opts.crop_size_test is None else opts.crop_size_test return opts @@ -60,6 +61,8 @@ def get_argparser(): help='batch size (default: 4)') parser.add_argument("--crop_size", type=int, default=512, help="crop size (default: 512)") + parser.add_argument("--crop_size_test", type=int, default=None, + help="test crop size (default: = --crop_size)") parser.add_argument("--lr", type=float, default=0.007, help="learning rate (default: 0.007)") @@ -117,9 +120,12 @@ def get_argparser(): parser.add_argument("--cross_val", action='store_true', default=False, help="If validate on training or on validation (default: Train)") + parser.add_argument("--step_ckpt", default=None, type=str, + help="path to trained model at previous step. Leave it None if you want to use def path") + # Method parser.add_argument("--method", type=str, default='FT', - choices=['FT', 'SPN'], + choices=['FT', 'SPN', 'COS'], help="The method you want to use.") parser.add_argument("--embedding", type=str, default="fastnvec", choices=['word2vec', 'fasttext', 'fastnvec']) diff --git a/dataset/__init__.py b/dataset/__init__.py index 0490145..fe110ea 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -7,7 +7,7 @@ def get_dataset(opts, task): """ Dataset And Augmentation """ train_transform = Compose([ - RandomScale((0.75, 1.5)), + RandomScale((0.5, 1.5)), RandomCrop(opts.crop_size, pad_if_needed=True), RandomHorizontalFlip(), ToTensor(), @@ -22,6 +22,7 @@ def get_dataset(opts, task): std=[0.229, 0.224, 0.225]), ]) test_transform = Compose([ + CenterCrop(size=opts.crop_size), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), diff --git a/dataset/transform.py b/dataset/transform.py index 51fcf85..d74a766 100644 --- a/dataset/transform.py +++ b/dataset/transform.py @@ -497,12 +497,12 @@ def __call__(self, img, lbl=None): # pad the width if needed if self.pad_if_needed and img.size[0] < self.size[1]: img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2)) - lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2)) + lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2), fill=255) # pad the height if needed if self.pad_if_needed and img.size[1] < self.size[0]: img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2)) - lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2)) + lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2), fill=255) i, j, h, w = self.get_params(img, self.size) diff --git a/methods/SPNet.py b/methods/SPNet.py index 4de8919..0b0103e 100644 --- a/methods/SPNet.py +++ b/methods/SPNet.py @@ -59,3 +59,11 @@ def initialize(self, opts): reduction = 'mean' self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction=reduction) + + def load_state_dict(self, checkpoint, strict=True): + if not strict: + del checkpoint["model"]['module.cls.class_emb'] + self.model.load_state_dict(checkpoint["model"], strict=strict) + if strict: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + self.scheduler.load_state_dict(checkpoint["scheduler"]) diff --git a/methods/__init__.py b/methods/__init__.py index cc34d73..fb5dfaa 100644 --- a/methods/__init__.py +++ b/methods/__init__.py @@ -1,6 +1,7 @@ from .segmentation_module import make_model from .method import FineTuning from .SPNet import SPNet +from .imprinting import CosineFT def get_method(opts, task, device, logger): @@ -9,6 +10,8 @@ def get_method(opts, task, device, logger): method_ = FineTuning(task=task, device=device, logger=logger, opts=opts) elif opts.method == "SPN": method_ = SPNet(task=task, device=device, logger=logger, opts=opts) + elif opts.method == "COS": + method_ = CosineFT(task=task, device=device, logger=logger, opts=opts) else: raise NotImplementedError diff --git a/methods/imprinting.py b/methods/imprinting.py new file mode 100644 index 0000000..d2e2213 --- /dev/null +++ b/methods/imprinting.py @@ -0,0 +1,50 @@ +import torch +from .method import FineTuning +import torch.nn as nn +from .utils import get_scheduler +import torch.nn.functional as F +from .segmentation_module import make_model + + +class Classifier(nn.Module): + def __init__(self, channels, classes): + super().__init__() + self.cls = nn.ModuleList( + [nn.Conv2d(channels, c, 1, bias=False) for c in classes]) + self.scaler = 10. + + def forward(self, x): + x = F.normalize(x, p=2, dim=1) + out = [] + for i, mod in enumerate(self.cls): + out.append(self.scaler * F.conv2d(x, F.normalize(mod.weight, dim=1, p=2))) + return torch.cat(out, dim=1) + + +class CosineFT(FineTuning): + def initialize(self, opts): + + head_channels = 256 + self.model = make_model(opts, head_channels, Classifier(head_channels, self.task.get_n_classes())) + + if opts.fix_bn: + self.model.fix_bn() + + # xxx Set up optimizer + params = [] + params.append({"params": filter(lambda p: p.requires_grad, self.model.body.parameters()), + 'weight_decay': opts.weight_decay}) + + params.append({"params": filter(lambda p: p.requires_grad, self.model.head.parameters()), + 'weight_decay': opts.weight_decay, 'lr': opts.lr*10.}) + + params.append({"params": filter(lambda p: p.requires_grad, self.model.cls.parameters()), + 'weight_decay': opts.weight_decay, 'lr': opts.lr*10.}) + + self.optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=False) + + self.scheduler = get_scheduler(opts, self.optimizer) + self.logger.debug("Optimizer:\n%s" % self.optimizer) + + reduction = 'mean' + self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction=reduction) diff --git a/methods/method.py b/methods/method.py index b6f42cd..14b2117 100644 --- a/methods/method.py +++ b/methods/method.py @@ -34,10 +34,11 @@ def train(self, cur_epoch, train_loader, print_int=10): def validate(self, loader, metrics, ret_samples_ids=None): raise NotImplementedError - def load_state_dict(self, checkpoint): - self.model.load_state_dict(checkpoint["model"], strict=True) - self.optimizer.load_state_dict(checkpoint["optimizer"]) - self.scheduler.load_state_dict(checkpoint["scheduler"]) + def load_state_dict(self, checkpoint, strict=True): + self.model.load_state_dict(checkpoint["model"], strict=strict) + if strict: + self.optimizer.load_state_dict(checkpoint["optimizer"]) + self.scheduler.load_state_dict(checkpoint["scheduler"]) def state_dict(self): state = {"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), @@ -73,20 +74,15 @@ def forward(self, x): 'weight_decay': opts.weight_decay}) params.append({"params": filter(lambda p: p.requires_grad, self.model.head.parameters()), - 'weight_decay': opts.weight_decay}) + 'weight_decay': opts.weight_decay, 'lr': opts.lr*10.}) params.append({"params": filter(lambda p: p.requires_grad, self.model.cls.parameters()), - 'weight_decay': opts.weight_decay}) + 'weight_decay': opts.weight_decay, 'lr': opts.lr*10.}) - self.optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=True) + self.optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=False) self.scheduler = get_scheduler(opts, self.optimizer) self.logger.debug("Optimizer:\n%s" % self.optimizer) - self.model, self.optimizer = amp.initialize(self.model.to(self.device), self.optimizer, opt_level=opts.opt_level) - - # Put the model on GPU - self.model = DistributedDataParallel(self.model, delay_allreduce=True) - reduction = 'mean' self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction=reduction) @@ -184,7 +180,7 @@ def validate(self, loader, metrics, ret_samples_ids=None): if ret_samples_ids is not None and i in ret_samples_ids: # get samples ret_samples.append((images[0].detach().cpu().numpy(), - labels[0])) + labels[0], prediction[0])) # collect statistics from multiple processes metrics.synch(device) diff --git a/metrics/stream_metrics.py b/metrics/stream_metrics.py index fecc6f6..91dfbcb 100644 --- a/metrics/stream_metrics.py +++ b/metrics/stream_metrics.py @@ -35,6 +35,7 @@ class StreamSegMetrics(_StreamMetrics): """ Stream Metrics for Semantic Segmentation Task """ + def __init__(self, n_classes): super().__init__() self.n_classes = n_classes @@ -48,17 +49,23 @@ def update(self, label_trues, label_preds): def to_str(self, results): string = "\n" + ignore = ["Class IoU", "Class Acc", "Class Prec", + "Confusion Matrix Pred", "Confusion Matrix", "Confusion Matrix Text"] for k, v in results.items(): - if k!="Class IoU" and k!="Class Acc" and k!="Confusion Matrix": - string += "%s: %f\n"%(k, v) - - string+='Class IoU:\n' + if k not in ignore: + string += "%s: %f\n" % (k, v) + + string += 'Class IoU:\n' for k, v in results['Class IoU'].items(): - string += "\tclass %d: %s\n"%(k, str(v)) + string += "\tclass %d: %s\n" % (k, str(v)) - string+='Class Acc:\n' + string += 'Class Acc:\n' for k, v in results['Class Acc'].items(): - string += "\tclass %d: %s\n"%(k, str(v)) + string += "\tclass %d: %s\n" % (k, str(v)) + + string += 'Class Prec:\n' + for k, v in results['Class Prec'].items(): + string += "\tclass %d: %s\n" % (k, str(v)) return string @@ -87,24 +94,32 @@ def get_results(self): acc = diag.sum() / hist.sum() acc_cls_c = diag / (gt_sum + EPS) acc_cls = np.mean(acc_cls_c[mask]) + precision_cls_c = diag / (hist.sum(axis=0) + EPS) + precision_cls = np.mean(precision_cls_c) iu = diag / (gt_sum + hist.sum(axis=0) - diag + EPS) mean_iu = np.mean(iu[mask]) freq = hist.sum(axis=1) / hist.sum() fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() + cls_iu = dict(zip(range(self.n_classes), [iu[i] if m else "X" for i, m in enumerate(mask)])) cls_acc = dict(zip(range(self.n_classes), [acc_cls_c[i] if m else "X" for i, m in enumerate(mask)])) + cls_prec = dict(zip(range(self.n_classes), [precision_cls_c[i] if m else "X" for i, m in enumerate(mask)])) return { - "Total samples": self.total_samples, - "Overall Acc": acc, - "Mean Acc": acc_cls, - "FreqW Acc": fwavacc, - "Mean IoU": mean_iu, - "Class IoU": cls_iu, - "Class Acc": cls_acc, - "Confusion Matrix": self.confusion_matrix_to_fig() - } - + "Total samples": self.total_samples, + "Overall Acc": acc, + "Mean Acc": acc_cls, + "Mean Precision": precision_cls, + "FreqW Acc": fwavacc, + "Mean IoU": mean_iu, + "Class IoU": cls_iu, + "Class Acc": cls_acc, + "Class Prec": cls_prec, + "Confusion Matrix Text": self.confusion_matrix_to_text(), + "Confusion Matrix": self.confusion_matrix_to_fig(), + "Confusion Matrix Pred": self.confusion_matrix_to_fig(norm_gt=False) + } + def reset(self): self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) self.total_samples = 0 @@ -121,8 +136,12 @@ def synch(self, device): self.confusion_matrix = confusion_matrix.cpu().numpy() self.total_samples = samples.cpu().numpy() - def confusion_matrix_to_fig(self): - cm = self.confusion_matrix.astype('float') / (self.confusion_matrix.sum(axis=1)+0.000001)[:, np.newaxis] + def confusion_matrix_to_fig(self, norm_gt=True): + if norm_gt: + div = (self.confusion_matrix.sum(axis=1) + 0.000001)[:, np.newaxis] + else: + div = (self.confusion_matrix.sum(axis=0) + 0.000001)[np.newaxis, :] + cm = self.confusion_matrix.astype('float') / div fig, ax = plt.subplots() im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) ax.figure.colorbar(im, ax=ax) @@ -134,6 +153,12 @@ def confusion_matrix_to_fig(self): fig.tight_layout() return fig + def confusion_matrix_to_text(self): + string = [] + for i in range(self.n_classes): + string.append(f"{i} : {self.confusion_matrix[i].tolist()}") + return "\n" + "\n".join(string) + class AverageMeter(object): """Computes average values""" diff --git a/run.py b/run.py index 2120707..6ed60be 100644 --- a/run.py +++ b/run.py @@ -15,8 +15,8 @@ from metrics import StreamSegMetrics from task import Task - from methods import get_method +import time def save_ckpt(path, model, epoch, best_score): @@ -31,15 +31,41 @@ def save_ckpt(path, model, epoch, best_score): torch.save(state, path) +def get_step_ckpt(opts, logger, task_name): + # xxx Get step checkpoint + step_checkpoint = None + if opts.step_ckpt is not None: + path = opts.step_ckpt + else: + path = f"checkpoints/step/{task_name}/{opts.name}_{opts.step - 1}.pth" + + # generate model from path + if os.path.exists(path): + step_checkpoint = torch.load(path, map_location="cpu") + step_checkpoint['path'] = path + elif opts.debug: + logger.info( + f"[!] WARNING: Unable to find of step {opts.step - 1}! Do you really want to do from scratch?") + else: + raise FileNotFoundError(f"Step checkpoint not found in {path}") + + return step_checkpoint + + def main(opts): distributed.init_process_group(backend='nccl', init_method='env://') device_id, device = opts.local_rank, torch.device(opts.local_rank) rank, world_size = distributed.get_rank(), distributed.get_world_size() torch.cuda.set_device(device_id) + task = Task(opts) + # Initialize logging task_name = f"{opts.task}-{opts.dataset}" - logdir_full = f"{opts.logdir}/{task_name}/{opts.name}/" + if task.nshot != -1: + logdir_full = f"{opts.logdir}/{task_name}/{opts.name}-s{task.nshot}/" + else: + logdir_full = f"{opts.logdir}/{task_name}/{opts.name}/" if rank == 0: logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=opts.visualize, step=opts.step) else: @@ -56,7 +82,6 @@ def main(opts): np.random.seed(opts.random_seed) random.seed(opts.random_seed) - task = Task(opts) train_dst, val_dst, test_dst = get_dataset(opts, task) logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, Val set: {len(val_dst)}," f" Test set: {len(test_dst)}, n_classes {opts.num_classes}") @@ -78,6 +103,13 @@ def main(opts): model = get_method(opts, task, device, logger) logger.info(f"[!] Model made with{'out' if opts.no_pretrained else ''} pre-trained") + # IF step > 0 you need to reload pretrained + if task.step > 0: + step_ckpt = get_step_ckpt(opts, logger, task_name) + model.load_state_dict(step_ckpt['model_state'], strict=False) # False because of incr. classifiers + logger.info(f"[!] Previous model loaded from {step_ckpt['path']}") + # clean memory + del step_ckpt logger.debug(model) @@ -111,7 +143,7 @@ def main(opts): denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # de-normalization for original images - val_metrics = StreamSegMetrics(opts.num_classes) + val_metrics = StreamSegMetrics(len(task.get_order())) val_score = None results = {} @@ -121,10 +153,13 @@ def main(opts): while cur_epoch < opts.epochs and not opts.test: # ===== Train ===== train_loader.sampler.set_epoch(cur_epoch) # setup dataloader sampler + start = time.time() epoch_loss = model.train(cur_epoch=cur_epoch, train_loader=train_loader, print_int=opts.print_interval) + end = time.time() logger.info(f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0] + epoch_loss[1]}," - f" Class Loss={epoch_loss[0]}, Reg Loss={epoch_loss[1]}") + f" Class Loss={epoch_loss[0]}, Reg Loss={epoch_loss[1]} " + f"-- time: {int(end-start)//60}:{int(end-start)%60}") # ===== Log metrics on Tensorboard ===== logger.add_scalar("E-Loss", epoch_loss[0] + epoch_loss[1], cur_epoch) @@ -139,7 +174,7 @@ def main(opts): logger.print("Done validation") logger.info(f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss[0] + val_loss[1]}," - f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}") + f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]} ") logger.info(val_metrics.to_str(val_score)) @@ -154,11 +189,12 @@ def main(opts): logger.add_table("Val_Acc_IoU", val_score['Class Acc'], cur_epoch) # logger.add_figure("Val_Confusion_Matrix", val_score['Confusion Matrix'], cur_epoch) - for k, (img, target, reconst) in enumerate(ret_samples): + for k, (img, target, pred) in enumerate(ret_samples): img = (denorm(img) * 255).astype(np.uint8) target = label2color(target).transpose(2, 0, 1).astype(np.uint8) + pred = label2color(pred).transpose(2, 0, 1).astype(np.uint8) - concat_img = np.concatenate((img, target), axis=2) # concat along width + concat_img = np.concatenate((img, target, pred), axis=2) # concat along width logger.add_image(f'Sample_{k}', concat_img, cur_epoch) # keep the metric to print them at the end of training @@ -175,13 +211,6 @@ def main(opts): cur_epoch += 1 - # ===== Save Best Model at the end of training ===== - if rank == 0: # save best model at the last iteration - # best model to build incremental steps - if not opts.debug: - save_ckpt(checkpoint_path, model, cur_epoch, best_score) - logger.info("[!] Checkpoint saved.") - torch.distributed.barrier() # xxx Test code! @@ -193,23 +222,24 @@ def main(opts): # load best model if opts.test: - model = get_method(opts, task, device, logger) # Put the model on GPU checkpoint = torch.load(checkpoint_path, map_location="cpu") model.load_state_dict(checkpoint["model_state"]) logger.info(f"*** Model restored from {checkpoint_path}") del checkpoint - val_loss, val_score, _ = model.validate(loader=test_loader, metrics=val_metrics, logger=logger) + val_loss, val_score, _ = model.validate(loader=test_loader, metrics=val_metrics) logger.print("Done test") logger.info(f"*** End of Test, Total Loss={val_loss[0]+val_loss[1]}," f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}") logger.info(val_metrics.to_str(val_score)) logger.add_table("Test_Class_IoU", val_score['Class IoU']) logger.add_table("Test_Class_Acc", val_score['Class Acc']) - logger.add_figure("Test_Confusion_Matrix", val_score['Confusion Matrix']) + logger.add_figure("Test_Confusion_Matrix_Recall", val_score['Confusion Matrix']) + logger.add_figure("Test_Confusion_Matrix_Precision", val_score["Confusion Matrix Pred"]) results["T-IoU"] = val_score['Class IoU'] results["T-Acc"] = val_score['Class Acc'] + results["T-Prec"] = val_score['Class Prec'] logger.add_results(results) logger.add_scalar("T_Overall_Acc", val_score['Overall Acc']) diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..671b7ee --- /dev/null +++ b/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=1 +alias exp='python -m torch.distributed.launch --nproc_per_node=1 run.py --opt_level O1' +shopt -s expand_aliases + +met=FT +name=FT_usebg +lr=1e-4 +exp --method ${met} --name ${name} --use_bkg --fix_bn --batch_size 10 --lr ${lr} --weight_decay 5e-4 --epochs 21 --step 0 --val_interval 20 --crop_size 320 --crop_size_test 512 +exp --method ${met} --name ${name} --use_bkg --fix_bn --batch_size 10 --lr ${lr} --weight_decay 5e-4 --epochs 1000 --step 1 --input_mix both --nshot 1 --val_interval 100 --crop_size 320 --crop_size_test 512 +exp --method ${met} --name ${name} --use_bkg --fix_bn --batch_size 10 --lr ${lr} --weight_decay 5e-4 --epochs 500 --step 1 --input_mix both --nshot 2 --val_interval 50 --crop_size 320 --crop_size_test 512 +exp --method ${met} --name ${name} --use_bkg --fix_bn --batch_size 10 --lr ${lr} --weight_decay 5e-4 --epochs 200 --step 1 --input_mix both --nshot 5 --val_interval 20 --crop_size 320 --crop_size_test 512 +exp --method ${met} --name ${name} --use_bkg --fix_bn --batch_size 10 --lr ${lr} --weight_decay 5e-4 --epochs 100 --step 1 --input_mix both --nshot 10 --val_interval 10 --crop_size 320 --crop_size_test 512 +exp --method ${met} --name ${name} --use_bkg --fix_bn --batch_size 10 --lr ${lr} --weight_decay 5e-4 --epochs 50 --step 1 --input_mix both --nshot 20 --val_interval 5 --crop_size 320 --crop_size_test 512 diff --git a/scripts/run.sh b/scripts/run.sh deleted file mode 100644 index 8bccc14..0000000 --- a/scripts/run.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -## be careful, change O1!!!!! -alias exp='python -m torch.distributed.launch --nproc_per_node=2 run.py --batch_size 12 --opt_level O1' -shopt -s expand_aliases - -exp --dataset voc --name FT_exemplars --method FT --task 15-5 --step 1 --lr 0.001 --random_seed 94 --epochs 1 --exemplars 150 \ No newline at end of file diff --git a/task.py b/task.py index e18e91c..d2db2fe 100644 --- a/task.py +++ b/task.py @@ -60,4 +60,7 @@ def get_task_dict(self): return {s: self.task_dict[s] for s in range(self.step+1)} def get_n_classes(self): - return [len(self.task_dict[s]) for s in range(self.step+1)] + r = [len(self.task_dict[s]) for s in range(self.step+1)] + if self.use_bkg: + r[0] += 1 + return r