-CUDA_VISIBLE_DEVICES=0 python main.py test \
+CUDA_VISIBLE_DEVICES=1 python main.py test \
-c configs/${DATASET}.yaml \
-m data/models/${LOG_DIR}/deeplabv2_resnet101_msc/*/checkpoint_final.pth \
# evaluate the model with CRF post-processing
-CUDA_VISIBLE_DEVICES=0 python main.py crf \
+CUDA_VISIBLE_DEVICES=1 python main.py crf \
-c configs/${DATASET}.yaml \
+import numpy as np
+from sklearn.metrics import confusion_matrix
class _StreamMetrics(object):
    """Interface for metrics that are accumulated batch by batch.

    Every method raises ``NotImplementedError``; concrete metrics override
    all of them.
    """

    def __init__(self):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def update(self, gt, pred):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def get_results(self):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def to_str(self, metrics):
        """ Overridden by subclasses """
        raise NotImplementedError()

    def reset(self):
        """ Overridden by subclasses """
        raise NotImplementedError()


class StreamSegMetrics(_StreamMetrics):
    """
    Stream Metrics for Semantic Segmentation Task

    Accumulates an ``n_classes x n_classes`` confusion matrix (rows are the
    ground-truth class, columns the predicted class) and derives accuracy
    and IoU scores from it.
    """

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        """Add one batch of (ground truth, prediction) label maps."""
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten())

    @staticmethod
    def to_str(results):
        """Format every scalar entry of ``results`` as ``"name: value"`` lines.

        The per-class ``"Class IoU"`` entry is left out of the text.
        """
        lines = ["%s: %f\n" % (k, v) for k, v in results.items() if k != "Class IoU"]
        return "\n" + "".join(lines)

    def _fast_hist(self, label_true, label_pred):
        """Return the confusion matrix of two flat label arrays.

        Pixels whose ground-truth label lies outside ``[0, n_classes)``
        (e.g. an ignore label) are dropped.
        """
        mask = (label_true >= 0) & (label_true < self.n_classes)
        # np.bincount only accepts integer input, so cast the predictions as
        # well as the ground truth (predictions may arrive as floats).
        flat_index = (
            self.n_classes * label_true[mask].astype(int)
            + label_pred[mask].astype(int)
        )
        hist = np.bincount(
            flat_index,
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        return hist

    def get_results(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc

        Classes that never occur get a NaN per-class score and are skipped
        by the ``nanmean`` averages.
        """
        hist = self.confusion_matrix
        diag = np.diag(hist)
        true_per_class = hist.sum(axis=1)
        pred_per_class = hist.sum(axis=0)
        total = hist.sum()
        # Absent classes produce 0/0; the NaN result is intended, the
        # RuntimeWarning is not.
        with np.errstate(divide="ignore", invalid="ignore"):
            acc = diag.sum() / total
            acc_cls = np.nanmean(diag / true_per_class)
            iu = diag / (true_per_class + pred_per_class - diag)
            mean_iu = np.nanmean(iu)
            freq = true_per_class / total
            fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
            "Overall Acc": acc,
            "Mean Acc": acc_cls,
            "FreqW Acc": fwavacc,
            "Mean IoU": mean_iu,
            "Class IoU": cls_iu,
        }

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the current value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on an empty meter would otherwise divide by 0.
        self.avg = self.sum / self.count if self.count else 0
+#!/usr/bin/env python
+# coding: utf-8
+# Author: Kazuto Nakashima
+# URL: https://kazuto1011.github.io
+# Date: 07 January 2019
+from __future__ import absolute_import, division, print_function
+import random
+import os
+import argparse
+import cv2
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import yaml
+from addict import Dict
+from PIL import Image
+from libs.datasets import get_dataset
+from libs.models import *
+from libs.utils import PolynomialLR
+from libs.utils.stream_metrics import StreamSegMetrics, AverageMeter
def get_argparser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config_path``,
        ``--gt_path``, ``--log_dir``, ``--cuda``, ``--random_seed``,
        ``--amp`` and ``--val_interval``.
    """

    def str2bool(text):
        # ``type=bool`` would turn every non-empty string (including
        # "False") into True, so parse the usual spellings explicitly.
        if isinstance(text, bool):
            return text
        lowered = text.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean, got %r" % (text,))

    parser = argparse.ArgumentParser()

    # Dataset Options
    parser.add_argument("--config_path", type=str, help="config file path")
    parser.add_argument("--gt_path", type=str, help="gt label path")
    parser.add_argument("--log_dir", type=str, help="training log path")
    parser.add_argument("--cuda", type=str2bool, default=True, help="GPU")
    parser.add_argument("--random_seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument("--amp", action='store_true', default=False)
    parser.add_argument("--val_interval", type=int, default=100, help="val_interval")
    return parser
def makedirs(dirs):
    """Create ``dirs`` (with any missing parents) unless the path already exists."""
    if os.path.exists(dirs):
        return
    os.makedirs(dirs)
def get_device(cuda):
    """Pick the torch device to run on and print what was chosen.

    Args:
        cuda: truthy to request a GPU; it is only used when
            ``torch.cuda.is_available()`` also reports one.

    Returns:
        torch.device: ``"cuda"`` when requested and available, else ``"cpu"``.
    """
    use_gpu = cuda and torch.cuda.is_available()
    if not use_gpu:
        print("Device: CPU")
        return torch.device("cpu")

    print("Device:")
    for index in range(torch.cuda.device_count()):
        print(" {}:".format(index), torch.cuda.get_device_name(index))
    return torch.device("cuda")
def get_params(model, key):
    """Yield the parameters belonging to one learning-rate group.

    Args:
        model: module whose ``named_modules()`` are scanned.
        key: ``"1x"`` yields every parameter of ``nn.Conv2d`` modules whose
            name contains ``"layer"``; ``"10x"`` yields the weight and
            ``"20x"`` the bias of ``nn.Conv2d`` modules whose name contains
            ``"aspp"``. Any other key yields nothing.
    """

    def conv_modules(tag):
        # Conv2d sub-modules whose qualified name contains ``tag``.
        for name, module in model.named_modules():
            if tag in name and isinstance(module, nn.Conv2d):
                yield module

    # For Dilated FCN
    if key == "1x":
        for conv in conv_modules("layer"):
            for param in conv.parameters():
                yield param
    # For conv weight in the ASPP module
    if key == "10x":
        for conv in conv_modules("aspp"):
            yield conv.weight
    # For conv bias in the ASPP module
    if key == "20x":
        for conv in conv_modules("aspp"):
            yield conv.bias
def resize_labels(labels, size):
    """
    Downsample labels for 0.5x and 0.75x logits by nearest interpolation.
    Other nearest methods result in misaligned labels.
    -> F.interpolate(labels, shape, mode='nearest')
    -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST)

    Each label map is converted to a float array, resized through PIL with
    ``Image.NEAREST`` and the results are packed into one ``LongTensor``.

    NOTE(review): ``size`` is handed to PIL unchanged and PIL reads it as
    (width, height); confirm callers only pass square sizes.
    """

    def shrink(label):
        as_image = Image.fromarray(label.float().numpy())
        return np.asarray(as_image.resize(size, resample=Image.NEAREST))

    resized = [shrink(label) for label in labels]
    return torch.LongTensor(resized)
+def main(): # Training entry point: parse options, build data/optimizer, then loop over train_loader with periodic validation.
+ opts = get_argparser().parse_args()
+ print(opts)
+ # Seed torch, numpy and Python's random from --random_seed
+ torch.manual_seed(opts.random_seed)
+ np.random.seed(opts.random_seed)
+ random.seed(opts.random_seed)
+ """
+ Training DeepLab by v2 protocol
+ """
+ # Configuration (YAML file wrapped in an addict Dict for attribute access)
+ with open(opts.config_path) as f:
+ CONFIG = Dict(yaml.load(f)) # NOTE(review): yaml.load without an explicit Loader can construct arbitrary objects and is rejected by PyYAML >= 6; consider yaml.safe_load
+ device = get_device(opts.cuda)
+ torch.backends.cudnn.benchmark = True
+ # Dataset
+ train_dataset = get_dataset(CONFIG.DATASET.NAME)(
+ augment=True,
+ flip=True,
+ gt_path=opts.gt_path, # training labels are read from --gt_path
+ )
+ print(train_dataset)
+ print()
+ valid_dataset = get_dataset(CONFIG.DATASET.NAME)(
+ augment=False,
+ gt_path="SegmentationClassAug", # validation always uses this fixed label directory
+ )
+ print(valid_dataset)
+ # DataLoader
+ train_loader = torch.utils.data.DataLoader( # no batch_size is visible here, so DataLoader's default applies
+ dataset=train_dataset,
+ shuffle=True,
+ )
+ valid_loader = torch.utils.data.DataLoader(
+ dataset=valid_dataset,
+ shuffle=False,
+ )
+ # Model check
+ print("Model:", CONFIG.MODEL.NAME)
+ assert (
+ CONFIG.MODEL.NAME == "DeepLabV2_ResNet101_MSC"
+ ), 'Currently support only "DeepLabV2_ResNet101_MSC"'
+ # Model setup
+ print(" Init:", CONFIG.MODEL.INIT_MODEL)
+ state_dict = torch.load(CONFIG.MODEL.INIT_MODEL, map_location='cpu')
+ for m in model.base.state_dict().keys(): # NOTE(review): `model` is not assigned anywhere above in the visible source; its construction appears to be missing
+ if m not in state_dict.keys():
+ print(" Skip init:", m)
+ model.base.load_state_dict(state_dict, strict=False) # to skip ASPP
+ model = nn.DataParallel(model)
+ model.to(device)
+ # Loss definition
+ criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
+ criterion.to(device)
+ # Optimizer
+ optimizer = torch.optim.SGD( # NOTE(review): no default lr is passed and the first group sets none; confirm against the full file
+ # cf lr_mult and decay_mult in train.prototxt
+ params=[
+ {
+ "params": get_params(model.module, key="1x"),
+ },
+ {
+ "params": get_params(model.module, key="10x"),
+ "lr": 10 * CONFIG.SOLVER.LR,
+ },
+ {
+ "params": get_params(model.module, key="20x"),
+ "lr": 20 * CONFIG.SOLVER.LR,
+ "weight_decay": 0.0,
+ },
+ ],
+ )
+ # Learning rate scheduler
+ scheduler = PolynomialLR(
+ optimizer=optimizer,
+ )
+ # Path to save models: models/<log_dir>
+ checkpoint_dir = os.path.join(
+ "models",
+ opts.log_dir,
+ )
+ makedirs(checkpoint_dir)
+ print("Checkpoint dst:", checkpoint_dir)
+ model.train()
+ metrics = StreamSegMetrics(CONFIG.DATASET.N_CLASSES)
+ scaler = torch.cuda.amp.GradScaler(enabled=opts.amp) # gradient scaling is active only with --amp
+ avg_loss = AverageMeter()
+ avg_time = AverageMeter()
+ curr_iter = 0
+ best_score = 0 # best validation "Mean IoU" seen so far
+ end_time = time.time()
+ while True: # keeps cycling over train_loader until the ITER_MAX check below returns
+ for _, images, labels, cls_labels in train_loader:
+ curr_iter += 1
+ loss = 0
+ optimizer.zero_grad()
+ with torch.cuda.amp.autocast(enabled=opts.amp):
+ # Propagate forward
+ logits = model(images.to(device))
+ # Loss
+ for logit in logits:
+ # Resize labels for {100%, 75%, 50%, Max} logits
+ _, _, H, W = logit.shape
+ labels_ = resize_labels(labels, size=(H, W))
+ pseudo_labels = logit.detach() * cls_labels[:, :, None, None].to(device) # weight detached logits per class by cls_labels; assumes cls_labels is (batch, n_classes), TODO confirm
+ pseudo_labels = pseudo_labels.argmax(dim=1)
+ _loss = criterion(logit, labels_.to(device)) + criterion(logit, pseudo_labels) # loss vs. resized labels plus loss vs. the model's own pseudo labels
+ loss += _loss
+ # Propagate backward (just compute gradients wrt the loss)
+ loss = (loss / len(logits)) # average over the logit outputs
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ # Update learning rate
+ scheduler.step()
+ avg_loss.update(loss.item())
+ avg_time.update(time.time() - end_time)
+ end_time = time.time()
+ # Console progress log every 10 iterations
+ if curr_iter % 10 == 0:
+ print(" Itrs %d/%d, Loss=%6f, Time=%.2f , LR=%.8f" % # NOTE(review): five placeholders but the opening of the argument tuple is missing from the visible source
+ avg_loss.avg, avg_time.avg*1000, optimizer.param_groups[0]['lr']))
+ # validation
+ if curr_iter % opts.val_interval == 0:
+ print("... validation")
+ metrics.reset()
+ with torch.no_grad():
+ for _, images, labels, _ in valid_loader:
+ images = images.to(device)
+ # Forward propagation
+ logits = model(images)
+ # Pixel-wise labeling
+ _, H, W = labels.shape
+ logits = F.interpolate(logits, size=(H, W),
+ mode="bilinear", align_corners=False) # upsample logits to the label resolution before the argmax
+ preds = torch.argmax(logits, dim=1).cpu().numpy()
+ targets = labels.cpu().numpy()
+ metrics.update(targets, preds)
+ score = metrics.get_results()
+ print(metrics.to_str(score))
+ if score['Mean IoU'] > best_score: # save best model
+ best_score = score['Mean IoU']
+ torch.save(
+ model.module.state_dict(), os.path.join(checkpoint_dir, "checkpoint_best.pth")
+ )
+ if curr_iter > CONFIG.SOLVER.ITER_MAX: # stop once the iteration budget is exceeded
+ return
+if __name__ == "__main__":
+ main()
# Training DeepLab-V2 using pseudo segmentation labels
-CUDA_VISIBLE_DEVICES=0,1 python main.py train -c configs/${DATASET}.yaml --gt_path=${GT_DIR} --log_dir=${LOG_DIR}
+#CUDA_VISIBLE_DEVICES=1,2 python main.py train -c configs/${DATASET}.yaml --gt_path=${GT_DIR} --log_dir=${LOG_DIR}
+#CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py train -c configs/${DATASET}.yaml --gt_path=${GT_DIR} --log_dir=${LOG_DIR}
-CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py train -c configs/${DATASET}.yaml --gt_path=${GT_DIR} --log_dir=${LOG_DIR}
\ No newline at end of file
+CUDA_VISIBLE_DEVICES=0 python main_v2.py --config_path ${CONFIG} --gt_path ${GT_DIR} --log_dir ${LOG_DIR}
+MIT License
+Copyright (c) 2020 Gongfan Fang
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+from .voc import VOCSegmentation
+from .cityscapes import Cityscapes
\ No newline at end of file
+import json
+import os
+from collections import namedtuple
+import torch
+import torch.utils.data as data
+from PIL import Image
+import numpy as np
class Cityscapes(data.Dataset):
    """Cityscapes Dataset.

    **Parameters:**
        - **root** (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' are located.
        - **split** (string, optional): The image split to use, 'train', 'test' or 'val'.
        - **mode** (string, optional): Accepted for interface compatibility but ignored; the 'gtFine' annotations are always used.
        - **target_type** (string, optional): Annotation file paired with each image: 'instance', 'semantic', 'color', 'polygon' or 'depth'.
        - **transform** (callable, optional): Called as ``transform(image, target)`` on the PIL image pair and must return the transformed pair.
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # Colour per train id 0..18, followed by black at index 19 for ignored pixels.
    train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
    train_id_to_color.append([0, 0, 0])
    train_id_to_color = np.array(train_id_to_color)
    # Lookup table indexed by raw label id; id -1 resolves to the last entry.
    id_to_train_id = np.array([c.train_id for c in classes])

    def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None):
        self.root = os.path.expanduser(root)
        self.mode = 'gtFine'  # the ``mode`` argument is not consulted
        self.target_type = target_type
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)

        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.transform = transform

        self.split = split
        self.images = []
        self.targets = []

        if split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        # One sub-directory per city; every image is paired with the target
        # file that shares its name stem.
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)

            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                             self._get_target_suffix(self.mode, self.target_type))
                self.targets.append(os.path.join(target_dir, target_name))

    @classmethod
    def encode_target(cls, target):
        """Map raw label ids to train ids (255 for ignored classes)."""
        return cls.id_to_train_id[np.array(target)]

    @classmethod
    def decode_target(cls, target):
        """Map train ids to RGB colours; 255 (ignored) becomes black.

        Works on a copy, so the caller's array is left untouched.
        """
        target = np.array(target)  # copy: do not overwrite the caller's data
        target[target == 255] = 19
        return cls.train_id_to_color[target]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where image is the RGB image and target is
            the annotation mapped to train ids by ``encode_target``; both are
            passed through ``transform`` first when one was given.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transform:
            image, target = self.transform(image, target)
        target = self.encode_target(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _load_json(self, path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode, target_type):
        """Return the annotation file suffix for ``target_type``.

        Raises:
            ValueError: if ``target_type`` is not a supported annotation type
                (previously this silently produced a ``..._None`` file name).
        """
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        elif target_type == 'polygon':
            return '{}_polygons.json'.format(mode)
        elif target_type == 'depth':
            return '{}_disparity.png'.format(mode)
        raise ValueError('Unknown target_type: {!r}'.format(target_type))
\ No newline at end of file
+import os
+import os.path
+import hashlib
+import errno
+from tqdm import tqdm
def gen_bar_updater(pbar):
    """Return a ``(count, block_size, total_size)`` hook that advances ``pbar``.

    The hook sets ``pbar.total`` from ``total_size`` when the bar has no
    total yet, then updates the bar by the bytes gained since ``pbar.n``.
    """

    def bar_update(count, block_size, total_size):
        total_unknown = pbar.total is None
        if total_unknown and total_size:
            pbar.total = total_size
        downloaded = count * block_size
        pbar.update(downloaded - pbar.n)

    return bar_update
def check_integrity(fpath, md5=None):
    """Tell whether the file at ``fpath`` matches the expected MD5 digest.

    Returns True without touching the file when ``md5`` is None, False when
    the file does not exist, otherwise whether the hex digest equals ``md5``.
    """
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False

    digest = hashlib.md5()
    with open(fpath, 'rb') as handle:
        # read in 1MB chunks
        while True:
            block = handle.read(1024 * 1024)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest() == md5
def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)

    Creates ``dirpath``; an "already exists" error is ignored and every
    other ``OSError`` is re-raised.
    """
    try:
        os.makedirs(dirpath)
    except OSError as err:
        if err.errno != errno.EEXIST:
            raise
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str): Name to save the file under. If None, use the basename of the URL
        md5 (str): MD5 checksum of the download. If None, do not check

    Raises:
        OSError: if the download fails and ``url`` is not an https URL, or
            if the http retry of an https URL fails as well.
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # Skip the download when a verified copy is already on disk.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    def fetch(source):
        urllib.request.urlretrieve(
            source, fpath,
            reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
        )

    try:
        print('Downloading ' + url + ' to ' + fpath)
        fetch(url)
    except OSError:
        if url[:5] != 'https':
            # There is no fallback for this URL: report the failure instead
            # of returning as if the file had been downloaded.
            raise
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        fetch(url)
def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    names = [
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    ]
    if prefix is True:
        return [os.path.join(root, name) for name in names]
    return names
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    names = [
        entry for entry in os.listdir(root)
        if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)
    ]
    if prefix is True:
        return [os.path.join(root, name) for name in names]
    return names
\ No newline at end of file
+import os
+import sys
+import tarfile
+import collections
+import torch.utils.data as data
+import shutil
+import numpy as np
+from PIL import Image
+from torchvision.datasets.utils import download_url, check_integrity
+ '2012': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
+ 'filename': 'VOCtrainval_11-May-2012.tar',
+ 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
+ 'base_dir': 'VOCdevkit/VOC2012'
+ },
+ '2011': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
+ 'filename': 'VOCtrainval_25-May-2011.tar',
+ 'md5': '6c3384ef61512963050cb5d687e5bf1e',
+ 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
+ },
+ '2010': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
+ 'filename': 'VOCtrainval_03-May-2010.tar',
+ 'md5': 'da459979d0c395079b5c75ee67908abb',
+ 'base_dir': 'VOCdevkit/VOC2010'
+ },
+ '2009': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
+ 'filename': 'VOCtrainval_11-May-2009.tar',
+ 'md5': '59065e4b188729180974ef6572f6a212',
+ 'base_dir': 'VOCdevkit/VOC2009'
+ },
+ '2008': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
+ 'filename': 'VOCtrainval_11-May-2012.tar',
+ 'md5': '2629fa636546599198acfcfbfcf1904a',
+ 'base_dir': 'VOCdevkit/VOC2008'
+ },
+ '2007': {
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
+ 'filename': 'VOCtrainval_06-Nov-2007.tar',
+ 'md5': 'c52e279531787c972589f7e41ab4ae64',
+ 'base_dir': 'VOCdevkit/VOC2007'
+ }
def voc_cmap(N=256, normalized=False):
    """Build the Pascal VOC colour map for label indices ``0..N-1``.

    The bits of each label index are dealt out three at a time to the red,
    green and blue channels, filling each channel from its most significant
    bit downwards.

    Args:
        N (int): number of colours to generate.
        normalized (bool): if True return float32 values divided by 255,
            otherwise uint8 values.

    Returns:
        numpy.ndarray: array of shape ``(N, 3)``.
    """

    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    cmap = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')
    for label in range(N):
        channels = [0, 0, 0]
        code = label
        for shift in range(7, -1, -1):
            for ch in range(3):
                channels[ch] |= bitget(code, ch) << shift
            code >>= 3
        cmap[label] = np.array(channels)

    if normalized:
        return cmap / 255
    return cmap
+class VOCSegmentation(data.Dataset): # VOC segmentation dataset returning (image, mask) pairs, plus the file stem when ret_fname is set
+ """`Pascal VOC `_ Segmentation Dataset.
+ Args:
+ root (string): Root directory of the VOC Dataset.
+ year (string, optional): The dataset year, supports years 2007 to 2012.
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
+ download (bool, optional): If true, downloads the dataset from the internet and
+ puts it in root directory. If dataset is already downloaded, it is not
+ downloaded again.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ """
+ cmap = voc_cmap() # colour table used by decode_target
+ def __init__(self,
+ root,
+ year='2012',
+ image_set='train',
+ download=False,
+ transform=None,
+ ret_fname=False):
+ year = '2012' # overrides the `year` argument: the 2012 metadata is always used
+ self.root = os.path.expanduser(root)
+ self.year = year
+ self.url = DATASET_YEAR_DICT[year]['url']
+ self.filename = DATASET_YEAR_DICT[year]['filename']
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
+ self.transform = transform
+ self.ret_fname = ret_fname
+ self.image_set = image_set
+ #base_dir = DATASET_YEAR_DICT[year]['base_dir']
+ #voc_root = os.path.join(self.root, base_dir)
+ voc_root = self.root # root is used as the VOC directory directly; base_dir is not appended
+ image_dir = os.path.join(voc_root, 'JPEGImages')
+ if download:
+ download_extract(self.url, self.root, self.filename, self.md5)
+ if not os.path.isdir(voc_root):
+ raise RuntimeError('Dataset not found or corrupted.' +
+ ' You can use download=True to download it')
+ if image_set=='train': # training masks come from the refined pseudo-label directory
+ mask_dir = os.path.join(voc_root, 'refined_pseudo_segmentation_labels')
+ assert os.path.exists(mask_dir), "refined_pseudo_segmentation_labels not found, please refer to README.md and prepare it manually" # NOTE(review): assert is skipped under `python -O`
+ split_f = './datasets/data/train_aug.txt' # fixed path, resolved against the current working directory
+ else:
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
+ print("split_f : ", split_f, os.path.exists(split_f)) # debug output; NOTE(review): with the flattened indentation it is unclear whether this sits inside the else branch
+ if not os.path.exists(split_f):
+ raise ValueError(
+ 'Wrong image_set entered! Please use image_set="train" '
+ 'or image_set="trainval" or image_set="val"')
+ with open(os.path.join(split_f), "r") as f:
+ self.file_names = [x.strip() for x in f.readlines()] # one file stem per line of the split file
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
+ assert (len(self.images) == len(self.masks))
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is the image segmentation.
+ """
+ img = Image.open(self.images[index]).convert('RGB')
+ target = Image.open(self.masks[index])
+ if self.transform is not None:
+ img, target = self.transform(img, target) # transform receives and returns the (image, mask) pair
+ if self.ret_fname:
+ return img, target, self.file_names[index] # three-tuple variant when ret_fname is set
+ return img, target
+ def __len__(self):
+ return len(self.images)
+ @classmethod
+ def decode_target(cls, mask):
+ """decode semantic mask to RGB image"""
+ return cls.cmap[mask] # index the colour table with the label values
def download_extract(url, root, filename, md5):
    """Download the tar archive ``filename`` into ``root`` and unpack it there.

    Args:
        url (str): URL of the archive.
        root (str): Directory that receives both the archive and its contents.
        filename (str): Name the archive is stored under inside ``root``.
        md5 (str): Expected MD5 checksum handed to ``download_url``.

    Raises:
        RuntimeError: if an archive member would be written outside ``root``.
    """
    download_url(url, root, filename, md5)
    base = os.path.realpath(root)
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        # The archive comes from the network: refuse member names that
        # resolve outside the extraction directory.
        for member in tar.getmembers():
            destination = os.path.realpath(os.path.join(base, member.name))
            if destination != base and not destination.startswith(base + os.sep):
                raise RuntimeError(
                    'Refusing to extract %r outside of %r' % (member.name, root))
        tar.extractall(path=root)
\ No newline at end of file
+from tqdm import tqdm
+import network
+import utils
+import os
+import random
+import argparse
+import numpy as np
+import time
+import joblib
+import multiprocessing
+from torch.utils import data
+from datasets import VOCSegmentation, Cityscapes
+from utils import ext_transforms as et
+from metrics import StreamSegMetrics
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.visualizer import Visualizer
+from utils.utils import AverageMeter
+from PIL import Image
+import matplotlib
+import matplotlib.pyplot as plt
+from utils.crf import DenseCRF
+torch.backends.cudnn.benchmark = True
def get_argparser():
    """Build the command-line parser for the evaluation / CRF script.

    Returns:
        argparse.ArgumentParser: parser with the dataset, model, validation
        and logit-directory options listed in the table below.
    """
    model_names = ['deeplabv3_resnet50', 'deeplabv3plus_resnet50',
                   'deeplabv3_resnet101', 'deeplabv3plus_resnet101',
                   'deeplabv3_mobilenet', 'deeplabv3plus_mobilenet']
    voc_years = ['2012_aug', '2012', '2011', '2009', '2008', '2007']

    # (flag, add_argument keyword arguments), registered in this order.
    specs = [
        # Dataset options
        ("--data_root", dict(type=str, default='./datasets/data',
                             help="path to Dataset")),
        ("--num_classes", dict(type=int, default=21,
                               help="num classes 21 for VOC")),
        # Deeplab options
        ("--model", dict(type=str, default='deeplabv3plus_mobilenet',
                         choices=model_names, help='model name')),
        ("--separable_conv", dict(action='store_true', default=False,
                                  help="apply separable conv to decoder and aspp")),
        ("--output_stride", dict(type=int, default=16, choices=[8, 16])),
        # Validation options
        ("--crop_val", dict(action='store_true', default=False,
                            help='crop validation (default: False)')),
        ("--crop_size", dict(type=int, default=513)),
        ("--val_batch_size", dict(type=int, default=4,
                                  help='batch size for validation (default: 4)')),
        ("--ckpt", dict(default=None, type=str,
                        help="restore from checkpoint")),
        ("--gpu_id", dict(type=str, default='0',
                          help="GPU ID")),
        ("--random_seed", dict(type=int, default=2,
                               help="random seed (default: 2)")),
        # PASCAL VOC options
        ("--year", dict(type=str, default='2012_aug',
                        choices=voc_years, help='year of VOC')),
        # Directory for the logits saved during validation
        ("--logit_dir", dict(type=str, default='./logits')),
    ]

    parser = argparse.ArgumentParser()
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser
def get_dataset(opts):
    """ Dataset And Augmentation

    Builds the VOC ``val`` split with file names returned alongside each
    sample. With ``opts.crop_val`` the images are first resized and
    centre-cropped to ``opts.crop_size``; in both cases they are converted
    to tensors and normalized.
    """
    steps = []
    if opts.crop_val:
        steps += [
            et.ExtResize(opts.crop_size),
            et.ExtCenterCrop(opts.crop_size),
        ]
    steps += [
        et.ExtToTensor(),
        et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ]
    val_transform = et.ExtCompose(steps)

    return VOCSegmentation(root=opts.data_root, year=opts.year,
                           image_set='val', download=False,
                           transform=val_transform, ret_fname=True)
def validate(opts, model, loader, device, metrics):
    """Score ``model`` on ``loader`` and dump the raw logits of every image.

    Each sample's logits are stored as float16 in
    ``<opts.logit_dir>/<fname>.npy``.

    Returns:
        Whatever ``metrics.get_results()`` reports after all batches were
        accumulated.
    """
    metrics.reset()
    n_batches = len(loader)
    with torch.no_grad():
        for step, (images, labels, fnames) in enumerate(loader):
            print("[%04d/%04d] " % (step, n_batches), end="\r")

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            # Index of the highest-scoring class per pixel.
            _, best = outputs.detach().max(dim=1)
            metrics.update(labels.cpu().numpy(), best.cpu().numpy())

            for idx, logit in enumerate(outputs):
                out_path = os.path.join(opts.logit_dir, fnames[idx] + ".npy")
                np.save(out_path, logit.detach().cpu().numpy().astype(np.float16))

    return metrics.get_results()
def crf_inference(opts, dataset, metrics):
    """Re-score the saved logits after DenseCRF post-processing.

    For every dataset item the logits stored in ``opts.logit_dir`` are
    loaded, resized to the image size, turned into class probabilities and
    refined by the CRF together with the de-normalized image. The work is
    spread over one joblib job per CPU.

    Returns:
        Whatever ``metrics.get_results()`` reports for the refined labels.
    """
    metrics.reset()

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    postprocessor = DenseCRF(
        iter_max=10,
        pos_xy_std=1,
        pos_w=3,
        bi_xy_std=67,
        bi_rgb_std=3,
        bi_w=4,
    )

    def process(i):
        image, gt_label, fname = dataset.__getitem__(i)

        # Logits written by the validation pass, resized to the image size.
        stored = np.load(os.path.join(opts.logit_dir, fname + ".npy"))
        _, H, W = image.shape
        logit = torch.FloatTensor(stored)[None, ...]
        logit = F.interpolate(logit, size=(H, W), mode="bilinear", align_corners=False)
        prob = F.softmax(logit, dim=1)[0].numpy()

        gt_label = gt_label.cpu().numpy()

        # Undo the input normalization to get back an 8-bit image.
        image = image.permute(1, 2, 0).cpu().numpy()
        image *= std
        image += mean
        image *= 255
        image = image.astype(np.uint8)

        refined = postprocessor(image, prob)
        return np.argmax(refined, axis=0), gt_label

    # CRF in multi-process
    jobs = [joblib.delayed(process)(i) for i in range(len(dataset))]
    runner = joblib.Parallel(n_jobs=multiprocessing.cpu_count(), verbose=10, pre_dispatch="all")
    results = runner(jobs)

    preds, gts = zip(*results)
    for pred, gt in zip(preds, gts):
        metrics.update(gt, pred)

    return metrics.get_results()
def main():
    """Evaluate a restored checkpoint on the VOC val split, then re-score with CRF.

    The validation pass stores per-image logits in ``--logit_dir``; the CRF
    pass reads them back and the directory is removed at the end.

    Raises:
        FileNotFoundError: if ``--ckpt`` is unset or is not an existing file.
    """
    import shutil  # only needed for the final cleanup

    opts = get_argparser().parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    os.makedirs(opts.logit_dir, exist_ok=True)

    # Setup dataloader
    if not opts.crop_val:
        # Images keep their own sizes, so they are fed one at a time.
        opts.val_batch_size = 1

    val_dst = get_dataset(opts)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=False, num_workers=4)
    print("Dataset: voc, Val set: %d" %
          (len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Restore
    if opts.ckpt is None or not os.path.isfile(opts.ckpt):
        # The former `assert "no checkpoint"` could never fail, so a missing
        # checkpoint silently evaluated an untrained model.
        raise FileNotFoundError("no checkpoint: %r" % (opts.ckpt,))

    # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
    checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state"])
    model = nn.DataParallel(model)
    model.to(device)
    print("Model restored from %s" % opts.ckpt)
    del checkpoint  # free memory

    #========== Eval ==========#
    model.eval()
    val_score = validate(
        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
    print(metrics.to_str(val_score))

    print("\n\n----------- crf -------------")
    crf_score = crf_inference(opts, val_dst, metrics)
    print(metrics.to_str(crf_score))

    # Remove the temporary logits without sending the path through a shell.
    shutil.rmtree(opts.logit_dir, ignore_errors=True)
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
+from tqdm import tqdm
+import network
+import utils
+import os
+import random
+import argparse
+import numpy as np
+import time
+from torch.utils import data
+from datasets import VOCSegmentation, Cityscapes
+from utils import ext_transforms as et
+from metrics import StreamSegMetrics
+import torch
+import torch.nn as nn
+from utils.visualizer import Visualizer
+from utils.utils import AverageMeter
+from PIL import Image
+import matplotlib
+import matplotlib.pyplot as plt
+torch.backends.cudnn.benchmark = True
def get_argparser():
    """Build the command-line parser for training and evaluation.

    Returns:
        argparse.ArgumentParser: parser with dataset, model, training,
        PASCAL VOC and Visdom options registered.
    """
    parser = argparse.ArgumentParser()

    # Dataset Options
    parser.add_argument("--data_root", type=str, default='./datasets/data',
                        help="path to Dataset")
    parser.add_argument("--dataset", type=str, default='voc',
                        choices=['voc', 'cityscapes'], help='Name of dataset')
    parser.add_argument("--num_classes", type=int, default=None,
                        help="num classes (default: None)")

    # Deeplab Options
    parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
                        choices=['deeplabv3_resnet50', 'deeplabv3plus_resnet50',
                                 'deeplabv3_resnet101', 'deeplabv3plus_resnet101',
                                 'deeplabv3_mobilenet', 'deeplabv3plus_mobilenet'], help='model name')
    parser.add_argument("--separable_conv", action='store_true', default=False,
                        help="apply separable conv to decoder and aspp")
    parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])

    # Train Options
    parser.add_argument("--amp", action='store_true', default=False)
    parser.add_argument("--test_only", action='store_true', default=False)
    parser.add_argument("--save_val_results", action='store_true', default=False,
                        help="save segmentation results to \"./results\"")
    # argparse does not run `type` on non-string defaults, so the default
    # must already be an int (30e3 would leave a float in opts.total_itrs).
    parser.add_argument("--total_itrs", type=int, default=30000,
                        help="total number of training iterations (default: 30k)")
    parser.add_argument("--lr", type=float, default=0.01,
                        help="learning rate (default: 0.01)")
    parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
                        help="learning rate scheduler policy")
    parser.add_argument("--step_size", type=int, default=10000)
    parser.add_argument("--crop_val", action='store_true', default=False,
                        help='crop validation (default: False)')
    parser.add_argument("--batch_size", type=int, default=16,
                        help='batch size (default: 16)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=513)
    parser.add_argument("--ckpt", default=None, type=str,
                        help="restore from checkpoint")
    parser.add_argument("--continue_training", action='store_true', default=False)
    parser.add_argument("--loss_type", type=str, default='cross_entropy',
                        choices=['cross_entropy', 'focal_loss'],
                        help="loss type (default: cross_entropy)")
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--random_seed", type=int, default=2,
                        help="random seed (default: 2)")
    parser.add_argument("--print_interval", type=int, default=10,
                        help="print interval of loss (default: 10)")
    parser.add_argument("--val_interval", type=int, default=100,
                        help="iteration interval for eval (default: 100)")
    parser.add_argument("--download", action='store_true', default=False,
                        help="download datasets")

    # PASCAL VOC Options
    parser.add_argument("--year", type=str, default='2012_aug',
                        choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'], help='year of VOC')

    # Visdom options
    parser.add_argument("--enable_vis", action='store_true', default=False,
                        help="use visdom for visualization")
    parser.add_argument("--vis_port", type=str, default='13570',
                        help='port for visdom')
    parser.add_argument("--vis_env", type=str, default='main',
                        help='env for visdom')
    parser.add_argument("--vis_num_samples", type=int, default=8,
                        help='number of samples for visualization (default: 8)')
    return parser
def get_dataset(opts):
    """Build the training and validation datasets with their augmentations.

    Args:
        opts: parsed options; reads ``dataset``, ``data_root``, ``crop_size``
            and, for VOC, ``crop_val``, ``year`` and ``download``.

    Returns:
        tuple: ``(train_dst, val_dst)``.

    Raises:
        ValueError: if ``opts.dataset`` is neither ``'voc'`` nor
            ``'cityscapes'`` (previously this surfaced as an unbound-local
            error at the ``return``).
    """
    def tensor_and_normalize():
        # Shared tail of every pipeline: tensor conversion followed by
        # normalization with the mean/std used throughout this project.
        return [
            et.ExtToTensor(),
            et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
        ]

    if opts.dataset == 'voc':
        train_transform = et.ExtCompose([
            et.ExtRandomScale((0.5, 2.0)),
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
            et.ExtRandomHorizontalFlip(),
        ] + tensor_and_normalize())
        if opts.crop_val:
            val_transform = et.ExtCompose([
                et.ExtResize(opts.crop_size),
                et.ExtCenterCrop(opts.crop_size),
            ] + tensor_and_normalize())
        else:
            val_transform = et.ExtCompose(tensor_and_normalize())
        train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                    image_set='train', download=opts.download, transform=train_transform)
        val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
                                  image_set='val', download=False, transform=val_transform)
    elif opts.dataset == 'cityscapes':
        train_transform = et.ExtCompose([
            et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
        ] + tensor_and_normalize())
        val_transform = et.ExtCompose(tensor_and_normalize())
        train_dst = Cityscapes(root=opts.data_root,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=opts.data_root,
                             split='val', transform=val_transform)
    else:
        raise ValueError("unsupported dataset: %r" % (opts.dataset,))
    return train_dst, val_dst
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
    """Run the model over ``loader`` and accumulate segmentation metrics.

    Args:
        opts: parsed options; only ``opts.save_val_results`` is read.
        model: network called as ``model(images)``; the class dimension of
            its output is dim 1.
        loader: iterable of ``(images, labels)`` batches. When results are
            saved, ``loader.dataset.decode_target`` is used for colouring.
        device: device the batches are moved to.
        metrics: metric accumulator; it is reset before use.
        ret_samples_ids: optional collection of batch indices whose first
            sample is returned for visualization.

    Returns:
        tuple: ``(score, ret_samples)`` where ``score`` comes from
        ``metrics.get_results()`` and ``ret_samples`` is a list of
        ``(image, target, prediction)`` numpy arrays.
    """
    metrics.reset()
    ret_samples = []
    if opts.save_val_results:
        os.makedirs('results', exist_ok=True)
        denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
        img_id = 0

    with torch.no_grad():
        # Wrap the loader (not the enumerate object) so tqdm can read its
        # length and show real progress.
        for i, (images, labels) in enumerate(tqdm(loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()

            metrics.update(targets, preds)
            if ret_samples_ids is not None and i in ret_samples_ids:  # get vis samples
                ret_samples.append(
                    (images[0].detach().cpu().numpy(), targets[0], preds[0]))

            if opts.save_val_results:
                # Distinct loop names: the original reused `i` here and
                # shadowed the batch index.
                for image, target, pred in zip(images, targets, preds):
                    image = image.detach().cpu().numpy()

                    image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
                    target = loader.dataset.decode_target(target).astype(np.uint8)
                    pred = loader.dataset.decode_target(pred).astype(np.uint8)

                    Image.fromarray(image).save('results/%d_image.png' % img_id)
                    Image.fromarray(target).save('results/%d_target.png' % img_id)
                    Image.fromarray(pred).save('results/%d_pred.png' % img_id)

                    # Overlay of the prediction on the input image.
                    plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
                    plt.close()
                    img_id += 1

    score = metrics.get_results()
    return score, ret_samples
def main():
    """Train (or, with ``--test_only``, just evaluate) a DeepLab model.

    Parses the command line, builds datasets, model, optimizer, scheduler
    and loss, optionally restores a checkpoint, then runs the training loop.
    Every ``opts.val_interval`` iterations the model is validated and the
    best checkpoint (by mean IoU) is written to ``checkpoints/``. Training
    stops once ``opts.total_itrs`` iterations have been performed.
    """
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        # Un-cropped VOC images are batched one at a time.
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
    # No shuffling for validation: the visualization sample ids chosen below
    # are batch indices and should refer to the same images on every pass.
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=False, num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer: the backbone uses a 10x smaller learning rate than
    # the classifier head.
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save the current model and training state to ``path``."""
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #========== Train Loop ==========#
    vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    scaler = torch.cuda.amp.GradScaler(enabled=opts.amp)
    avg_loss = AverageMeter()
    avg_time = AverageMeter()

    while True:  # left via the return once cur_itrs reaches opts.total_itrs
        # ===== Train =====
        avg_loss.reset()
        avg_time.reset()
        model.train()
        cur_epochs += 1
        end_time = time.time()
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=opts.amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            avg_loss.update(loss.item())
            avg_time.update(time.time() - end_time)
            end_time = time.time()

            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, avg_loss.avg)

            # Honour --print_interval (its default of 10 matches the value
            # that used to be hard-coded here).
            if (cur_itrs) % opts.print_interval == 0:
                print(" Epoch %d, Itrs %d/%d, Loss=%6f, Time=%.2f , LR=%.8f" %
                      (cur_epochs, cur_itrs, opts.total_itrs,
                       avg_loss.avg, avg_time.avg * 1000, optimizer.param_groups[0]['lr']))

            if (cur_itrs) % opts.val_interval == 0:
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()

            if cur_itrs >= opts.total_itrs:
                return
# Start training/evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
+from .stream_metrics import StreamSegMetrics, AverageMeter
+import numpy as np
+from sklearn.metrics import confusion_matrix
+class _StreamMetrics(object):
+ def __init__(self):
+ """ Overridden by subclasses """
+ raise NotImplementedError()
+ def update(self, gt, pred):
+ """ Overridden by subclasses """
+ raise NotImplementedError()
+ def get_results(self):
+ """ Overridden by subclasses """
+ raise NotImplementedError()
+ def to_str(self, metrics):
+ """ Overridden by subclasses """
+ raise NotImplementedError()
+ def reset(self):
+ """ Overridden by subclasses """
+ raise NotImplementedError()
+class StreamSegMetrics(_StreamMetrics):
+ """
+ Stream Metrics for Semantic Segmentation Task
+ """
+ def __init__(self, n_classes):
+ self.n_classes = n_classes
+ self.confusion_matrix = np.zeros((n_classes, n_classes))
+ def update(self, label_trues, label_preds):
+ for lt, lp in zip(label_trues, label_preds):
+ self.confusion_matrix += self._fast_hist( lt.flatten(), lp.flatten() )
+ @staticmethod
+ def to_str(results):
+ string = "\n"
+ for k, v in results.items():
+ if k!="Class IoU":
+ string += "%s: %f\n"%(k, v)
+ #string+='Class IoU:\n'
+ #for k, v in results['Class IoU'].items():
+ # string += "\tclass %d: %f\n"%(k, v)
+ return string
+ def _fast_hist(self, label_true, label_pred):
+ mask = (label_true >= 0) & (label_true < self.n_classes)
+ hist = np.bincount(
+ self.n_classes * label_true[mask].astype(int) + label_pred[mask],
+ minlength=self.n_classes ** 2,
+ ).reshape(self.n_classes, self.n_classes)
+ return hist
+ def get_results(self):
+ """Returns accuracy score evaluation result.
+ - overall accuracy
+ - mean accuracy
+ - mean IU
+ - fwavacc
+ """
+ hist = self.confusion_matrix
+ acc = np.diag(hist).sum() / hist.sum()
+ acc_cls = np.diag(hist) / hist.sum(axis=1)
+ acc_cls = np.nanmean(acc_cls)
+ iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
+ mean_iu = np.nanmean(iu)
+ freq = hist.sum(axis=1) / hist.sum()
+ fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
+ cls_iu = dict(zip(range(self.n_classes), iu))
+ return {
+ "Overall Acc": acc,
+ "Mean Acc": acc_cls,
+ "FreqW Acc": fwavacc,
+ "Mean IoU": mean_iu,
+ "Class IoU": cls_iu,
+ }
+ def reset(self):
+ self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
class AverageMeter(object):
    """Running averages for any number of independently keyed quantities.

    ``book`` maps each key to a two-element list ``[running_sum, count]``.
    """

    def __init__(self):
        self.book = dict()

    def reset_all(self):
        """Forget every tracked key."""
        self.book.clear()

    def reset(self, id):
        """Zero the sum and count of ``id``; unknown keys are ignored."""
        entry = self.book.get(id, None)
        if entry is None:
            return
        entry[:] = [0, 0]

    def update(self, id, val):
        """Add ``val`` to the running sum of ``id`` and bump its count."""
        if id in self.book:
            entry = self.book[id]
            entry[0] += val
            entry[1] += 1
        else:
            self.book[id] = [val, 1]

    def get_results(self, id):
        """Return the mean of the values recorded for ``id``."""
        entry = self.book.get(id, None)
        assert entry is not None
        total, count = entry
        return total / count
+from .modeling import *
+from ._deeplab import convert_to_separable_conv
\ No newline at end of file
+import torch
+from torch import nn
+from torch.nn import functional as F
+from .utils import _SimpleSegmentationModel
+__all__ = ["DeepLabV3"]
class DeepLabV3(_SimpleSegmentationModel):
    """
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    This class adds nothing of its own; all behaviour comes from
    ``_SimpleSegmentationModel``.

    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """
    pass
class DeepLabHeadV3Plus(nn.Module):
    """DeepLabV3+ decoder head.

    Fuses the ASPP output of the high-level feature map with a 1x1-projected
    low-level feature map and classifies the concatenation.

    Args:
        in_channels: channels of ``feature['out']``.
        low_level_channels: channels of ``feature['low_level']``.
        num_classes: number of output channels.
        aspp_dilate: dilation rates handed to ``ASPP``. The default is an
            immutable tuple so it cannot be shared and mutated across calls.
    """

    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=(12, 24, 36)):
        super(DeepLabHeadV3Plus, self).__init__()
        low_level_out = 48   # channels after projecting the low-level features
        aspp_out = 256       # channels produced by ASPP
        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, low_level_out, 1, bias=False),
            nn.BatchNorm2d(low_level_out),
            nn.ReLU(inplace=True),
        )

        self.aspp = ASPP(in_channels, aspp_dilate)

        self.classifier = nn.Sequential(
            # Input is the concatenation of both branches (48 + 256 = 304).
            nn.Conv2d(low_level_out + aspp_out, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )
        self._init_weight()

    def forward(self, feature):
        """Return class scores at the resolution of ``feature['low_level']``."""
        low_level_feature = self.project(feature['low_level'])
        output_feature = self.aspp(feature['out'])
        # Upsample the ASPP output to the low-level resolution before fusing.
        output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))

    def _init_weight(self):
        """Kaiming-initialise convolutions; set norm layers to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class DeepLabHead(nn.Module):
    """DeepLabV3 head: ASPP followed by a small convolutional classifier.

    Args:
        in_channels: channels of ``feature['out']``.
        num_classes: number of output channels.
        aspp_dilate: dilation rates handed to ``ASPP``. The default is an
            immutable tuple so it cannot be shared and mutated across calls.
    """

    def __init__(self, in_channels, num_classes, aspp_dilate=(12, 24, 36)):
        super(DeepLabHead, self).__init__()

        self.classifier = nn.Sequential(
            ASPP(in_channels, aspp_dilate),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )
        self._init_weight()

    def forward(self, feature):
        """Classify ``feature['out']``; other entries of ``feature`` are unused."""
        return self.classifier(feature['out'])

    def _init_weight(self):
        """Kaiming-initialise convolutions; set norm layers to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
class AtrousSeparableConvolution(nn.Module):
    """Depthwise (optionally dilated) convolution followed by a 1x1 convolution.

    The first convolution uses ``groups=in_channels`` and carries the
    kernel size, stride, padding and dilation; the second one maps
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        # One filter per input channel.
        depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, bias=bias, groups=in_channels,
        )
        # Mixes channels and sets the output width.
        pointwise = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=bias,
        )
        self.body = nn.Sequential(depthwise, pointwise)
        self._init_weight()

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution."""
        return self.body(x)

    def _init_weight(self):
        """Kaiming-initialise convolutions; set norm layers to weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight)
                continue
            if isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
class ASPPConv(nn.Sequential):
    """3x3 dilated convolution (no bias) followed by BatchNorm and ReLU.

    Padding equals the dilation, so the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super(ASPPConv, self).__init__(
            nn.Conv2d(in_channels, out_channels, 3,
                      padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Pools the input to 1x1, applies a 1x1 convolution with BatchNorm and
    ReLU, then bilinearly resizes the result back to the input's size.
    """

    def __init__(self, in_channels, out_channels):
        layers = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        super(ASPPPooling, self).__init__(*layers)

    def forward(self, x):
        """Return the pooled features resized to ``x``'s spatial size."""
        height_width = x.shape[-2:]
        pooled = x
        for layer in self:
            pooled = layer(pooled)
        return F.interpolate(pooled, size=height_width, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling.

    Runs a 1x1 convolution branch, one ``ASPPConv`` branch per dilation
    rate and an ``ASPPPooling`` branch in parallel, concatenates their
    outputs along the channel axis and projects them to 256 channels.

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: iterable of dilation rates. Any number of rates is
            accepted; with three rates the layout is the same as before,
            when exactly three were required.
    """

    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        for rate in tuple(atrous_rates):
            modules.append(ASPPConv(in_channels, out_channels, rate))
        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            # One out_channels-wide slice per branch (5 for three rates).
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),)

    def forward(self, x):
        """Concatenate all branch outputs on dim 1 and project them."""
        res = torch.cat([conv(x) for conv in self.convs], dim=1)
        return self.project(res)
def convert_to_separable_conv(module):
    """Recursively swap convolutions with kernel size > 1 for separable ones.

    A ``nn.Conv2d`` whose kernel is larger than 1 is replaced by a freshly
    initialised ``AtrousSeparableConvolution`` with the same geometry (its
    weights are not copied). Every other module is returned as is, after
    its children have been converted and re-registered under their names.

    Args:
        module: the module (tree) to convert.

    Returns:
        The converted module: ``module`` itself unless it was a replaced
        convolution.
    """
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        new_module = AtrousSeparableConvolution(module.in_channels,
                                                module.out_channels,
                                                module.kernel_size,
                                                module.stride,
                                                module.padding,
                                                module.dilation,
                                                # `bias` is an on/off flag;
                                                # passing the Parameter itself
                                                # fails on its truth test.
                                                module.bias is not None)
    for name, child in module.named_children():
        new_module.add_module(name, convert_to_separable_conv(child))
    return new_module
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/network/backbone/__init__.py b/DeepLabV3Plus-Pytorch/network/backbone/__init__.py
new file mode 100644
index 0000000..afe983f
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/network/backbone/__init__.py
@@ -0,0 +1,2 @@
+from . import resnet
+from . import mobilenetv2
diff --git a/DeepLabV3Plus-Pytorch/network/backbone/mobilenetv2.py b/DeepLabV3Plus-Pytorch/network/backbone/mobilenetv2.py
new file mode 100644
index 0000000..46fa16a
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/network/backbone/mobilenetv2.py
@@ -0,0 +1,187 @@
+from torch import nn
+from torchvision.models.utils import load_state_dict_from_url
+import torch.nn.functional as F
+__all__ = ['MobileNetV2', 'mobilenet_v2']
# Download location of the ImageNet-pretrained weights, keyed by model name.
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
+def _make_divisible(v, divisor, min_value=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
class ConvBNReLU(nn.Sequential):
    """Bias-free convolution followed by BatchNorm and ReLU6.

    The convolution uses no padding of its own (padding is fixed at 0).
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0,
                         dilation=dilation, groups=groups, bias=False)
        norm = nn.BatchNorm2d(out_planes)
        activation = nn.ReLU6(inplace=True)
        super(ConvBNReLU, self).__init__(conv, norm, activation)
def fixed_padding(kernel_size, dilation):
    """Return the 4-tuple of paddings for a dilated kernel.

    The total padding is the effective (dilated) kernel extent minus one;
    when it is odd, the extra unit goes to the end side. The result is
    ``(begin, end, begin, end)``.
    """
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
    begin, odd = divmod(effective_kernel - 1, 2)
    end = begin + odd
    return (begin, end, begin, end)
class InvertedResidual(nn.Module):
    """MobileNetV2 block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 projection.

    The input is padded explicitly in ``forward`` (for a 3x3 kernel with the
    given dilation) because the convolutions themselves use no padding. The
    skip connection is used only when ``stride == 1`` and ``inp == oup``.
    """

    def __init__(self, inp, oup, stride, dilation, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        self.stride = stride

        expanded = int(round(inp * expand_ratio))
        self.use_res_connect = (self.stride == 1) and (inp == oup)

        # Pointwise expansion is skipped when it would not widen the input.
        stages = [ConvBNReLU(inp, expanded, kernel_size=1)] if expand_ratio != 1 else []
        # Depthwise convolution.
        stages.append(ConvBNReLU(expanded, expanded, stride=stride, dilation=dilation, groups=expanded))
        # Linear pointwise projection (no activation afterwards).
        stages.append(nn.Conv2d(expanded, oup, 1, 1, 0, bias=False))
        stages.append(nn.BatchNorm2d(oup))
        self.conv = nn.Sequential(*stages)

        self.input_padding = fixed_padding(3, dilation)

    def forward(self, x):
        """Pad, run the block, and add the input when the skip connection applies."""
        out = self.conv(F.pad(x, self.input_padding))
        if not self.use_res_connect:
            return out
        return x + out
class MobileNetV2(nn.Module):
    def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class
        Args:
            num_classes (int): Number of classes
            output_stride (int): Once the accumulated stride reaches this value,
                later stages use stride 1 and multiply the dilation by their
                stride instead of downsampling further
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        self.output_stride = output_stride
        # Accumulated downsampling factor of the layers built so far.
        current_stride = 1
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t (expand ratio), c (channels), n (repeats), s (stride)
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        current_stride *= 2
        dilation=1
        previous_dilation = 1

        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            # The first block of a stage still uses the dilation of the
            # previous stage; the remaining blocks use the updated one.
            previous_dilation = dilation
            if current_stride == output_stride:
                # Target stride reached: trade striding for dilation.
                stride = 1
                dilation *= s
            else:
                stride = s
                current_stride *= s
            # NOTE(review): this overwrites the rounded value computed at the
            # top of the loop, so round_nearest has no effect on block widths
            # when c * width_mult is not already a multiple - confirm intent.
            output_channel = int(c * width_mult)

            for i in range(n):
                if i==0:
                    features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
                else:
                    features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Extract features, average over dims 2 and 3, then classify."""
        x = self.features(x)
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """
    Build a MobileNetV2 network as described in
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, load the weights published at
            ``model_urls['mobilenet_v2']`` into the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded unchanged to the ``MobileNetV2`` constructor
    """
    net = MobileNetV2(**kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                       progress=progress)
    net.load_state_dict(weights)
    return net
diff --git a/DeepLabV3Plus-Pytorch/network/backbone/resnet.py b/DeepLabV3Plus-Pytorch/network/backbone/resnet.py
new file mode 100644
index 0000000..cebee56
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/network/backbone/resnet.py
@@ -0,0 +1,343 @@
+import torch
+import torch.nn as nn
+from torchvision.models.utils import load_state_dict_from_url
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+# Download locations of the pretrained weights, keyed by architecture name.
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
+    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
+                       groups=groups, bias=False, dilation=dilation)
+    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
+def conv1x1(in_planes, out_planes, stride=1):
+    """Build a bias-free 1x1 (pointwise) convolution."""
+    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+    return pointwise
+class BasicBlock(nn.Module):
+    """Residual block made of two 3x3 convolutions plus a shortcut.
+
+    Only ``groups=1``, ``base_width=64`` and ``dilation <= 1`` are accepted.
+    """
+
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
+        if not (groups == 1 and base_width == 64):
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # When stride != 1 the spatial reduction happens in conv1 (and in
+        # ``downsample`` when one is supplied).
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = make_norm(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = make_norm(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        """Return ``relu(branch(x) + shortcut(x))``."""
+        branch = self.relu(self.bn1(self.conv1(x)))
+        branch = self.bn2(self.conv2(branch))
+        # The shortcut is the input itself unless a projection was supplied.
+        shortcut = x if self.downsample is None else self.downsample(x)
+        branch += shortcut
+        return self.relu(branch)
+class Bottleneck(nn.Module):
+    """Residual block made of 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.
+
+    The output has ``planes * expansion`` channels; the inner width is
+    ``int(planes * (base_width / 64.)) * groups``.
+    """
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
+        inner = int(planes * (base_width / 64.)) * groups
+        outer = planes * self.expansion
+        # When stride != 1 the spatial reduction happens in conv2 (and in
+        # ``downsample`` when one is supplied).
+        self.conv1 = conv1x1(inplanes, inner)
+        self.bn1 = make_norm(inner)
+        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
+        self.bn2 = make_norm(inner)
+        self.conv3 = conv1x1(inner, outer)
+        self.bn3 = make_norm(outer)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        """Return ``relu(branch(x) + shortcut(x))``."""
+        branch = self.relu(self.bn1(self.conv1(x)))
+        branch = self.relu(self.bn2(self.conv2(branch)))
+        branch = self.bn3(self.conv3(branch))
+        # The shortcut is the input itself unless a projection was supplied.
+        shortcut = x if self.downsample is None else self.downsample(x)
+        branch += shortcut
+        return self.relu(branch)
+class ResNet(nn.Module):
+    """Residual network: stem, four residual stages, average pool, linear head.
+
+    Args:
+        block: block class (``BasicBlock`` or ``Bottleneck``) used in every stage.
+        layers (sequence of int): number of blocks in each of the four stages.
+        num_classes (int): output size of the final linear layer.
+        zero_init_residual (bool): when True, the weight of the last norm layer
+            of each residual branch is initialised to zero.
+        groups (int): forwarded to every block.
+        width_per_group (int): forwarded to every block as ``base_width``.
+        replace_stride_with_dilation (sequence of 3 bool or None): for layer2,
+            layer3 and layer4 respectively, whether the stride of 2 is replaced
+            by dilation. ``None`` means ``[False, False, False]``.
+        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.
+
+    Raises:
+        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
+    """
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
+        self._norm_layer = norm_layer
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # One flag per strided stage (layer2, layer3, layer4).
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+
+        # Stem.
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        # Residual stages, registered in order as layer1..layer4.
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        for idx, planes in enumerate((128, 256, 512)):
+            stage = self._make_layer(block, planes, layers[idx + 1], stride=2,
+                                     dilate=replace_stride_with_dilation[idx])
+            setattr(self, 'layer{}'.format(idx + 2), stage)
+
+        # Classification head.
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for module in self.modules():
+            if isinstance(module, nn.Conv2d):
+                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(module.weight, 1)
+                nn.init.constant_(module.bias, 0)
+
+        # Zero-initialising the last norm layer of each residual branch makes
+        # every block start out as an identity mapping.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for module in self.modules():
+                if isinstance(module, Bottleneck):
+                    nn.init.constant_(module.bn3.weight, 0)
+                elif isinstance(module, BasicBlock):
+                    nn.init.constant_(module.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        """Build one stage of ``blocks`` blocks and update ``self.inplanes``/``self.dilation``."""
+        norm_layer = self._norm_layer
+        out_planes = planes * block.expansion
+        first_dilation = self.dilation
+        if dilate:
+            # Trade the stride for dilation: later blocks of this stage use the new rate.
+            self.dilation *= stride
+            stride = 1
+
+        shortcut = None
+        if stride != 1 or self.inplanes != out_planes:
+            # The first block changes resolution and/or width, so its shortcut is a projection.
+            shortcut = nn.Sequential(
+                conv1x1(self.inplanes, out_planes, stride),
+                norm_layer(out_planes),
+            )
+
+        stage = [block(self.inplanes, planes, stride, shortcut, self.groups,
+                       self.base_width, first_dilation, norm_layer)]
+        self.inplanes = out_planes
+        stage.extend(
+            block(self.inplanes, planes, groups=self.groups,
+                  base_width=self.base_width, dilation=self.dilation,
+                  norm_layer=norm_layer)
+            for _ in range(1, blocks)
+        )
+        return nn.Sequential(*stage)
+
+    def forward(self, x):
+        """Run stem, the four stages, average pooling, flatten and the linear head."""
+        trunk = (self.conv1, self.bn1, self.relu, self.maxpool,
+                 self.layer1, self.layer2, self.layer3, self.layer4,
+                 self.avgpool)
+        for stage in trunk:
+            x = stage(x)
+        flat = torch.flatten(x, 1)
+        return self.fc(flat)
+def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+    """Build a ``ResNet`` and, when ``pretrained``, load the weights at ``model_urls[arch]``."""
+    net = ResNet(block, layers, **kwargs)
+    if not pretrained:
+        return net
+    weights = load_state_dict_from_url(model_urls[arch],
+                                       progress=progress)
+    net.load_state_dict(weights)
+    return net
+def resnet18(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNet-18 ("Deep Residual Learning for Image Recognition").
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        **kwargs: forwarded to ``_resnet``.
+    """
+    stage_depths = [2, 2, 2, 2]
+    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress,
+                   **kwargs)
+def resnet34(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNet-34 ("Deep Residual Learning for Image Recognition").
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        **kwargs: forwarded to ``_resnet``.
+    """
+    stage_depths = [3, 4, 6, 3]
+    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress,
+                   **kwargs)
+def resnet50(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNet-50 ("Deep Residual Learning for Image Recognition").
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        **kwargs: forwarded to ``_resnet``.
+    """
+    stage_depths = [3, 4, 6, 3]
+    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress,
+                   **kwargs)
+def resnet101(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNet-101 ("Deep Residual Learning for Image Recognition").
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        **kwargs: forwarded to ``_resnet``.
+    """
+    stage_depths = [3, 4, 23, 3]
+    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress,
+                   **kwargs)
+def resnet152(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNet-152 ("Deep Residual Learning for Image Recognition").
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        **kwargs: forwarded to ``_resnet``.
+    """
+    stage_depths = [3, 8, 36, 3]
+    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress,
+                   **kwargs)
+def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNeXt-50 32x4d
+    ("Aggregated Residual Transformation for Deep Neural Networks").
+
+    ``groups`` and ``width_per_group`` are always set to 32 and 4, overriding
+    any values passed by the caller.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs.update(groups=32, width_per_group=4)
+    stage_depths = [3, 4, 6, 3]
+    return _resnet('resnext50_32x4d', Bottleneck, stage_depths,
+                   pretrained, progress, **kwargs)
+def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+    r"""Build ResNeXt-101 32x8d
+    ("Aggregated Residual Transformation for Deep Neural Networks").
+
+    ``groups`` and ``width_per_group`` are always set to 32 and 8, overriding
+    any values passed by the caller.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs.update(groups=32, width_per_group=8)
+    stage_depths = [3, 4, 23, 3]
+    return _resnet('resnext101_32x8d', Bottleneck, stage_depths,
+                   pretrained, progress, **kwargs)
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+    r"""Build Wide ResNet-50-2 ("Wide Residual Networks").
+
+    Same as ResNet-50 except that the bottleneck (inner) channel count is
+    doubled in every block; the outer 1x1 convolutions keep their width. For
+    example the last block of ResNet-50 has 2048-512-2048 channels, while in
+    Wide ResNet-50-2 it has 2048-1024-2048. ``width_per_group`` is always set
+    to ``64 * 2``, overriding any value passed by the caller.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs.update(width_per_group=64 * 2)
+    stage_depths = [3, 4, 6, 3]
+    return _resnet('wide_resnet50_2', Bottleneck, stage_depths,
+                   pretrained, progress, **kwargs)
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+    r"""Build Wide ResNet-101-2 ("Wide Residual Networks").
+
+    Same as ResNet-101 except that the bottleneck (inner) channel count is
+    doubled in every block; the outer 1x1 convolutions keep their width. For
+    example the last block of ResNet-50 has 2048-512-2048 channels, while in
+    Wide ResNet-50-2 it has 2048-1024-2048. ``width_per_group`` is always set
+    to ``64 * 2``, overriding any value passed by the caller.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs.update(width_per_group=64 * 2)
+    stage_depths = [3, 4, 23, 3]
+    return _resnet('wide_resnet101_2', Bottleneck, stage_depths,
+                   pretrained, progress, **kwargs)
diff --git a/DeepLabV3Plus-Pytorch/network/modeling.py b/DeepLabV3Plus-Pytorch/network/modeling.py
new file mode 100644
index 0000000..b053083
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/network/modeling.py
@@ -0,0 +1,137 @@
+from .utils import IntermediateLayerGetter
+from ._deeplab import DeepLabHead, DeepLabHeadV3Plus, DeepLabV3
+from .backbone import resnet
+from .backbone import mobilenetv2
+def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
+    """Assemble a DeepLabV3 or DeepLabV3+ model on top of a ResNet backbone.
+
+    Args:
+        name (str): ``'deeplabv3'`` or ``'deeplabv3plus'``.
+        backbone_name (str): name of a constructor in the ``resnet`` module.
+        num_classes (int): number of classes passed to the head.
+        output_stride (int): 8 dilates layer3 and layer4; any other value
+            dilates layer4 only.
+        pretrained_backbone (bool): forwarded to the backbone as ``pretrained``.
+
+    Returns:
+        DeepLabV3: backbone wrapped in ``IntermediateLayerGetter`` plus the head.
+
+    Raises:
+        NotImplementedError: if ``name`` is not a supported architecture.
+    """
+    # Validate before building the backbone so an unsupported name fails fast
+    # with a clear error instead of an UnboundLocalError further down.
+    if name not in ('deeplabv3plus', 'deeplabv3'):
+        raise NotImplementedError('unsupported architecture: {!r}'.format(name))
+
+    if output_stride == 8:
+        replace_stride_with_dilation = [False, True, True]
+        aspp_dilate = [12, 24, 36]
+    else:
+        replace_stride_with_dilation = [False, False, True]
+        aspp_dilate = [6, 12, 18]
+
+    backbone = resnet.__dict__[backbone_name](
+        pretrained=pretrained_backbone,
+        replace_stride_with_dilation=replace_stride_with_dilation)
+
+    # Output widths of layer4 / layer1 for Bottleneck-based ResNets (512 * 4, 64 * 4).
+    inplanes = 2048
+    low_level_planes = 256
+
+    if name == 'deeplabv3plus':
+        return_layers = {'layer4': 'out', 'layer1': 'low_level'}
+        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
+    else:  # 'deeplabv3'
+        return_layers = {'layer4': 'out'}
+        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
+    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+    return DeepLabV3(backbone, classifier)
+def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone):
+    """Assemble a DeepLabV3 or DeepLabV3+ model on top of a MobileNetV2 backbone.
+
+    Args:
+        name (str): ``'deeplabv3'`` or ``'deeplabv3plus'``.
+        backbone_name (str): accepted for signature parity with
+            ``_segm_resnet``; not used here.
+        num_classes (int): number of classes passed to the head.
+        output_stride (int): forwarded to the backbone; 8 selects ASPP rates
+            ``[12, 24, 36]``, any other value ``[6, 12, 18]``.
+        pretrained_backbone (bool): forwarded to the backbone as ``pretrained``.
+
+    Returns:
+        DeepLabV3: backbone wrapped in ``IntermediateLayerGetter`` plus the head.
+
+    Raises:
+        NotImplementedError: if ``name`` is not a supported architecture.
+    """
+    # Validate before building the backbone so an unsupported name fails fast
+    # with a clear error instead of an UnboundLocalError further down.
+    if name not in ('deeplabv3plus', 'deeplabv3'):
+        raise NotImplementedError('unsupported architecture: {!r}'.format(name))
+
+    if output_stride == 8:
+        aspp_dilate = [12, 24, 36]
+    else:
+        aspp_dilate = [6, 12, 18]
+
+    backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride)
+
+    # Split the feature stack into a low-level and a high-level part, then
+    # drop the original attributes so only the two new children remain.
+    backbone.low_level_features = backbone.features[0:4]
+    backbone.high_level_features = backbone.features[4:-1]
+    backbone.features = None
+    backbone.classifier = None
+
+    # Channel counts handed to the head for the high-level / low-level features.
+    inplanes = 320
+    low_level_planes = 24
+
+    if name == 'deeplabv3plus':
+        return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
+        classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
+    else:  # 'deeplabv3'
+        return_layers = {'high_level_features': 'out'}
+        classifier = DeepLabHead(inplanes, num_classes, aspp_dilate)
+    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+    return DeepLabV3(backbone, classifier)
+def _load_model(arch_type, backbone, num_classes, output_stride, pretrained_backbone):
+    """Dispatch to the builder matching ``backbone``.
+
+    ``'mobilenetv2'`` selects ``_segm_mobilenet``; any name starting with
+    ``'resnet'`` selects ``_segm_resnet``; anything else raises
+    ``NotImplementedError``.
+    """
+    if backbone == 'mobilenetv2':
+        builder = _segm_mobilenet
+    elif backbone.startswith('resnet'):
+        builder = _segm_resnet
+    else:
+        raise NotImplementedError
+    return builder(arch_type, backbone, num_classes, output_stride=output_stride,
+                   pretrained_backbone=pretrained_backbone)
+# Deeplab v3
+def deeplabv3_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
+    """Build a DeepLabV3 model with a ResNet-50 backbone.
+
+    Args:
+        num_classes (int): number of classes.
+        output_stride (int): output stride for deeplab.
+        pretrained_backbone (bool): If True, use the pretrained backbone.
+    """
+    return _load_model(
+        'deeplabv3', 'resnet50', num_classes,
+        output_stride=output_stride,
+        pretrained_backbone=pretrained_backbone,
+    )
+def deeplabv3_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
+    """Build a DeepLabV3 model with a ResNet-101 backbone.
+
+    Args:
+        num_classes (int): number of classes.
+        output_stride (int): output stride for deeplab.
+        pretrained_backbone (bool): If True, use the pretrained backbone.
+    """
+    return _load_model(
+        'deeplabv3', 'resnet101', num_classes,
+        output_stride=output_stride,
+        pretrained_backbone=pretrained_backbone,
+    )
+def deeplabv3_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True, **kwargs):
+    """Build a DeepLabV3 model with a MobileNetv2 backbone.
+
+    Args:
+        num_classes (int): number of classes.
+        output_stride (int): output stride for deeplab.
+        pretrained_backbone (bool): If True, use the pretrained backbone.
+        **kwargs: accepted but not used.
+    """
+    return _load_model(
+        'deeplabv3', 'mobilenetv2', num_classes,
+        output_stride=output_stride,
+        pretrained_backbone=pretrained_backbone,
+    )
+# Deeplab v3+
+def deeplabv3plus_resnet50(num_classes=21, output_stride=8, pretrained_backbone=True):
+    """Build a DeepLabV3+ model with a ResNet-50 backbone.
+
+    Args:
+        num_classes (int): number of classes.
+        output_stride (int): output stride for deeplab.
+        pretrained_backbone (bool): If True, use the pretrained backbone.
+    """
+    return _load_model(
+        'deeplabv3plus', 'resnet50', num_classes,
+        output_stride=output_stride,
+        pretrained_backbone=pretrained_backbone,
+    )
+def deeplabv3plus_resnet101(num_classes=21, output_stride=8, pretrained_backbone=True):
+    """Build a DeepLabV3+ model with a ResNet-101 backbone.
+
+    Args:
+        num_classes (int): number of classes.
+        output_stride (int): output stride for deeplab.
+        pretrained_backbone (bool): If True, use the pretrained backbone.
+    """
+    return _load_model(
+        'deeplabv3plus', 'resnet101', num_classes,
+        output_stride=output_stride,
+        pretrained_backbone=pretrained_backbone,
+    )
+def deeplabv3plus_mobilenet(num_classes=21, output_stride=8, pretrained_backbone=True):
+    """Build a DeepLabV3+ model with a MobileNetv2 backbone.
+
+    Args:
+        num_classes (int): number of classes.
+        output_stride (int): output stride for deeplab.
+        pretrained_backbone (bool): If True, use the pretrained backbone.
+    """
+    return _load_model(
+        'deeplabv3plus', 'mobilenetv2', num_classes,
+        output_stride=output_stride,
+        pretrained_backbone=pretrained_backbone,
+    )
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/network/utils.py b/DeepLabV3Plus-Pytorch/network/utils.py
new file mode 100644
index 0000000..d6e2782
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/network/utils.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from collections import OrderedDict
+class _SimpleSegmentationModel(nn.Module):
+    """Backbone + classifier whose output is resized back to the input's spatial size."""
+
+    def __init__(self, backbone, classifier):
+        super(_SimpleSegmentationModel, self).__init__()
+        self.backbone = backbone
+        self.classifier = classifier
+
+    def forward(self, x):
+        """Return the classifier output, bilinearly interpolated to ``x.shape[-2:]``."""
+        spatial_size = x.shape[-2:]
+        logits = self.classifier(self.backbone(x))
+        return F.interpolate(logits, size=spatial_size, mode='bilinear', align_corners=False)
+class IntermediateLayerGetter(nn.ModuleDict):
+    """Wrap a model so that a forward pass returns selected intermediate outputs.
+
+    The direct children of ``model`` are copied, in registration order, up to
+    and including the last one named in ``return_layers``; ``forward`` then
+    calls them one after another. This relies on the children having been
+    registered in the order in which they are used, so a module must **not**
+    be reused twice in the original forward pass. Only direct children can be
+    queried: ``model.feature1`` works, ``model.feature1.layer2`` does not.
+
+    Arguments:
+        model (nn.Module): model on which we will extract the features
+        return_layers (Dict[name, new_name]): maps the name of each child
+            whose output should be returned to the key under which that
+            output is stored in the result.
+
+    Raises:
+        ValueError: if a key of ``return_layers`` is not a direct child of ``model``.
+
+    Examples::
+        >>> m = torchvision.models.resnet18(pretrained=True)
+        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
+        >>> new_m = IntermediateLayerGetter(m, {'layer1': 'feat1', 'layer3': 'feat2'})
+        >>> out = new_m(torch.rand(1, 3, 224, 224))
+        >>> print([(k, v.shape) for k, v in out.items()])
+        >>> [('feat1', torch.Size([1, 64, 56, 56])),
+        >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
+    """
+
+    def __init__(self, model, return_layers):
+        child_names = [child_name for child_name, _ in model.named_children()]
+        if not set(return_layers).issubset(child_names):
+            raise ValueError("return_layers are not present in model")
+
+        # Work on a copy so the caller's mapping is left untouched.
+        pending = dict(return_layers.items())
+        kept = OrderedDict()
+        for child_name, child in model.named_children():
+            kept[child_name] = child
+            pending.pop(child_name, None)
+            if not pending:
+                # Every requested layer has been collected; later children are not needed.
+                break
+
+        super(IntermediateLayerGetter, self).__init__(kept)
+        self.return_layers = return_layers
+
+    def forward(self, x):
+        """Return an ``OrderedDict`` mapping each requested new name to its activation."""
+        collected = OrderedDict()
+        for child_name, child in self.named_children():
+            x = child(x)
+            if child_name in self.return_layers:
+                collected[self.return_layers[child_name]] = x
+        return collected
diff --git a/DeepLabV3Plus-Pytorch/requirements.txt b/DeepLabV3Plus-Pytorch/requirements.txt
new file mode 100644
index 0000000..48b62a8
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/requirements.txt
@@ -0,0 +1,8 @@
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/train.sh b/DeepLabV3Plus-Pytorch/train.sh
new file mode 100755
index 0000000..28350c7
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/train.sh
@@ -0,0 +1,13 @@
+MODEL=deeplabv3plus_resnet101 # deeplabv3plus_resnet101, deeplabv3_resnet101
+# training with 2 GPUs
+CUDA_VISIBLE_DEVICES=0,1 python main.py --data_root ${ROOT} --model ${MODEL} --gpu_id 0,1 --amp --total_itrs ${ITER} --batch_size ${BATCH} --lr ${LR}
+# evaluation with crf
+CUDA_VISIBLE_DEVICES=0,1 python eval.py --gpu_id 0,1 --data_root ${ROOT} --model ${MODEL} --val_batch_size 16 --ckpt checkpoints/best_${MODEL}_voc_os16.pth
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/utils/__init__.py b/DeepLabV3Plus-Pytorch/utils/__init__.py
new file mode 100644
index 0000000..172d9f8
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/__init__.py
@@ -0,0 +1,4 @@
+from .utils import *
+from .visualizer import Visualizer
+from .scheduler import PolyLR
+from .loss import FocalLoss
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/utils/crf.py b/DeepLabV3Plus-Pytorch/utils/crf.py
new file mode 100755
index 0000000..5449883
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/crf.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# coding: utf-8
+# Author: Kazuto Nakashima
+# URL: https://kazuto1011.github.io
+# Date: 09 January 2019
+import numpy as np
+import pydensecrf.densecrf as dcrf
+import pydensecrf.utils as utils
+class DenseCRF(object):
+    """Callable wrapper around ``pydensecrf`` inference with fixed hyper-parameters.
+
+    Args:
+        iter_max: number of inference iterations.
+        pos_w: ``compat`` of the Gaussian pairwise term.
+        pos_xy_std: ``sxy`` of the Gaussian pairwise term.
+        bi_w: ``compat`` of the bilateral pairwise term.
+        bi_xy_std: ``sxy`` of the bilateral pairwise term.
+        bi_rgb_std: ``srgb`` of the bilateral pairwise term.
+    """
+
+    def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std):
+        self.iter_max = iter_max
+        # Gaussian (position-only) pairwise term.
+        self.pos_w, self.pos_xy_std = pos_w, pos_xy_std
+        # Bilateral (position + colour) pairwise term.
+        self.bi_w, self.bi_xy_std, self.bi_rgb_std = bi_w, bi_xy_std, bi_rgb_std
+
+    def __call__(self, image, probmap):
+        """Refine ``probmap`` (unpacked as ``C, H, W``) using ``image``; returns a ``(C, H, W)`` array."""
+        n_labels, height, width = probmap.shape
+
+        unary = np.ascontiguousarray(utils.unary_from_softmax(probmap))
+        rgb = np.ascontiguousarray(image)
+
+        crf = dcrf.DenseCRF2D(width, height, n_labels)
+        crf.setUnaryEnergy(unary)
+        crf.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w)
+        crf.addPairwiseBilateral(
+            sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=rgb, compat=self.bi_w
+        )
+
+        marginals = crf.inference(self.iter_max)
+        return np.array(marginals).reshape((n_labels, height, width))
diff --git a/DeepLabV3Plus-Pytorch/utils/ext_transforms.py b/DeepLabV3Plus-Pytorch/utils/ext_transforms.py
new file mode 100644
index 0000000..7a7bd9e
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/ext_transforms.py
@@ -0,0 +1,571 @@
+import torchvision
+import torch
+import torchvision.transforms.functional as F
+import random
+import numbers
+import numpy as np
+from PIL import Image
+# Extended Transforms for Semantic Segmentation
+class ExtRandomHorizontalFlip(object):
+    """Horizontally flip an image and its label together with probability ``p``.
+
+    Args:
+        p (float): probability of the pair being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+            lbl (PIL Image): Label to be flipped.
+        Returns:
+            tuple: ``(img, lbl)``, both flipped or both untouched.
+        """
+        do_flip = random.random() < self.p
+        if not do_flip:
+            return img, lbl
+        return F.hflip(img), F.hflip(lbl)
+
+    def __repr__(self):
+        suffix = '(p={})'.format(self.p)
+        return self.__class__.__name__ + suffix
+class ExtCompose(object):
+    """Chain several paired transforms; each receives and returns ``(img, lbl)``.
+
+    Args:
+        transforms (list of ``Transform`` objects): list of transforms to compose.
+
+    Example:
+        >>> ExtCompose([
+        >>>     ExtCenterCrop(10),
+        >>>     ExtToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, img, lbl):
+        pair = (img, lbl)
+        for transform in self.transforms:
+            pair = transform(*pair)
+        img, lbl = pair
+        return img, lbl
+
+    def __repr__(self):
+        # One transform per line between the opening and closing parenthesis.
+        pieces = [self.__class__.__name__ + '(']
+        for transform in self.transforms:
+            pieces.append('\n')
+            pieces.append(' {0}'.format(transform))
+        pieces.append('\n)')
+        return ''.join(pieces)
+class ExtCenterCrop(object):
+    """Center-crop an image and its label to the same size.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. A single
+            number ``s`` is expanded to ``(int(s), int(s))``; any other value
+            (e.g. ``(h, w)``) is stored unchanged.
+    """
+
+    def __init__(self, size):
+        is_scalar = isinstance(size, numbers.Number)
+        self.size = (int(size), int(size)) if is_scalar else size
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be cropped.
+            lbl (PIL Image): Label to be cropped.
+        Returns:
+            tuple: cropped image and cropped label.
+        """
+        cropped_img = F.center_crop(img, self.size)
+        cropped_lbl = F.center_crop(lbl, self.size)
+        return cropped_img, cropped_lbl
+
+    def __repr__(self):
+        suffix = '(size={0})'.format(self.size)
+        return self.__class__.__name__ + suffix
+class ExtRandomScale(object):
+    """Rescale an image/label pair by a factor drawn uniformly from ``scale_range``.
+
+    Args:
+        scale_range (sequence): ``(low, high)`` bounds handed to ``random.uniform``.
+        interpolation (int, optional): filter used for the image. Default is
+            ``PIL.Image.BILINEAR``. The label is always resized with
+            ``PIL.Image.NEAREST``.
+    """
+
+    def __init__(self, scale_range, interpolation=Image.BILINEAR):
+        self.scale_range = scale_range
+        self.interpolation = interpolation
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be scaled.
+            lbl (PIL Image): Label to be scaled.
+        Returns:
+            PIL Image: Rescaled image.
+            PIL Image: Rescaled label.
+        """
+        assert img.size == lbl.size
+        scale = random.uniform(self.scale_range[0], self.scale_range[1])
+        # PIL reports size as (width, height); the resize target is built as (H, W).
+        target_size = (int(img.size[1] * scale), int(img.size[0] * scale))
+        return F.resize(img, target_size, self.interpolation), F.resize(lbl, target_size, Image.NEAREST)
+
+    def __repr__(self):
+        # Report only attributes this class actually stores: it has no ``size``,
+        # so formatting ``self.size`` would raise AttributeError.
+        return self.__class__.__name__ + '(scale_range={0}, interpolation={1})'.format(
+            self.scale_range, self.interpolation)
+class ExtScale(object):
+    """Resize the input PIL Image and its label by a fixed scale factor.
+
+    Args:
+        scale (float or int): factor by which both sides are multiplied.
+        interpolation (int, optional): filter used for the image. Default is
+            ``PIL.Image.BILINEAR``. The label is always resized with
+            ``PIL.Image.NEAREST``.
+    """
+
+    def __init__(self, scale, interpolation=Image.BILINEAR):
+        self.scale = scale
+        self.interpolation = interpolation
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be scaled.
+            lbl (PIL Image): Label to be scaled.
+        Returns:
+            PIL Image: Rescaled image.
+            PIL Image: Rescaled label.
+        """
+        assert img.size == lbl.size
+        # PIL reports size as (width, height); the resize target is built as (H, W).
+        target_size = (int(img.size[1] * self.scale), int(img.size[0] * self.scale))
+        return F.resize(img, target_size, self.interpolation), F.resize(lbl, target_size, Image.NEAREST)
+
+    def __repr__(self):
+        # Report only attributes this class actually stores: it has no ``size``,
+        # so formatting ``self.size`` would raise AttributeError.
+        return self.__class__.__name__ + '(scale={0}, interpolation={1})'.format(
+            self.scale, self.interpolation)
+class ExtRandomRotation(object):
+    """Rotate an image and its label by the same random angle.
+
+    Args:
+        degrees (sequence or float or int): Range of degrees to select from.
+            If degrees is a number instead of sequence like (min, max), the range of degrees
+            will be (-degrees, +degrees).
+        resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
+            Resampling filter applied to the image. The label is always
+            rotated with ``PIL.Image.NEAREST`` so class ids are never blended.
+        expand (bool, optional): Optional expansion flag.
+            If true, expands the output to make it large enough to hold the entire rotated image.
+            If false or omitted, make the output image the same size as the input image.
+        center (2-tuple, optional): Optional center of rotation.
+            Origin is the upper left corner.
+            Default is the center of the image.
+
+    Raises:
+        ValueError: if ``degrees`` is a negative number or a sequence whose
+            length is not 2.
+    """
+
+    def __init__(self, degrees, resample=False, expand=False, center=None):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            if len(degrees) != 2:
+                raise ValueError("If degrees is a sequence, it must be of len 2.")
+            self.degrees = degrees
+        self.resample = resample
+        self.expand = expand
+        self.center = center
+
+    @staticmethod
+    def get_params(degrees):
+        """Draw the rotation angle uniformly from ``[degrees[0], degrees[1]]``.
+
+        Returns:
+            float: angle to be passed to ``rotate``.
+        """
+        return random.uniform(degrees[0], degrees[1])
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be rotated.
+            lbl (PIL Image): Label to be rotated.
+        Returns:
+            PIL Image: Rotated image.
+            PIL Image: Rotated label.
+        """
+        angle = self.get_params(self.degrees)
+        rotated_img = F.rotate(img, angle, self.resample, self.expand, self.center)
+        # Labels hold class ids: interpolating between them would create
+        # invalid classes, so nearest-neighbour is used regardless of
+        # ``self.resample`` (as the scale transforms in this file do).
+        rotated_lbl = F.rotate(lbl, angle, Image.NEAREST, self.expand, self.center)
+        return rotated_img, rotated_lbl
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
+        format_string += ', resample={0}'.format(self.resample)
+        format_string += ', expand={0}'.format(self.expand)
+        if self.center is not None:
+            format_string += ', center={0}'.format(self.center)
+        format_string += ')'
+        return format_string
+# NOTE(review): ExtRandomHorizontalFlip is defined twice in this file; this
+# later definition is the one bound to the name once the module is loaded.
+class ExtRandomHorizontalFlip(object):
+    """Mirror an image and its label left-to-right, together, with probability ``p``.
+
+    Args:
+        p (float): probability of the pair being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+            lbl (PIL Image): Label to be flipped.
+        Returns:
+            tuple: ``(img, lbl)``, both flipped or both untouched.
+        """
+        keep_as_is = not (random.random() < self.p)
+        if keep_as_is:
+            return img, lbl
+        flipped = (F.hflip(img), F.hflip(lbl))
+        return flipped
+
+    def __repr__(self):
+        class_name = self.__class__.__name__
+        return class_name + '(p={})'.format(self.p)
+class ExtRandomVerticalFlip(object):
+    """Vertically flip an image and its label together with probability ``p``.
+
+    Args:
+        p (float): probability of the pair being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be flipped.
+            lbl (PIL Image): Label to be flipped.
+        Returns:
+            PIL Image: Randomly flipped image.
+            PIL Image: Randomly flipped label.
+        """
+        do_flip = random.random() < self.p
+        if not do_flip:
+            return img, lbl
+        return F.vflip(img), F.vflip(lbl)
+
+    def __repr__(self):
+        suffix = '(p={})'.format(self.p)
+        return self.__class__.__name__ + suffix
+class ExtPad(object):
+    """Pad an image and its label so both sides become multiples of ``diviser``.
+
+    The padding is split as evenly as possible between the two opposite
+    borders; when it is odd, the extra pixel goes to the right / bottom.
+
+    Args:
+        diviser (int): value the padded width and height must be divisible by.
+            Default is 32.
+    """
+
+    def __init__(self, diviser=32):
+        self.diviser = diviser
+
+    def __call__(self, img, lbl):
+        """
+        Args:
+            img (PIL Image): Image to be padded.
+            lbl (PIL Image): Label to be padded by the same amounts.
+        Returns:
+            tuple: padded image and padded label.
+        """
+        # PIL reports size as (width, height).
+        w, h = img.size
+        # Distance to the next multiple of ``diviser`` (0 when already aligned).
+        ph = (-h) % self.diviser
+        pw = (-w) % self.diviser
+        # torchvision's pad reads a 4-tuple as (left, top, right, bottom).
+        padding = (pw // 2, ph // 2, pw - pw // 2, ph - ph // 2)
+        return F.pad(img, padding), F.pad(lbl, padding)
+class ExtToTensor(object):
+    """Convert an image/label pair to tensors.
+
+    With ``normalize=True`` the image goes through ``F.to_tensor``. Otherwise
+    it is converted to a float32 array and its axes are reordered with
+    ``transpose(2, 0, 1)``, without rescaling. The label is converted to a
+    tensor of dtype ``target_type`` and is never normalized.
+
+    Args:
+        normalize (bool): whether to convert the image with ``F.to_tensor``.
+        target_type (str): numpy dtype used for the label. Default is 'uint8'.
+    """
+
+    def __init__(self, normalize=True, target_type='uint8'):
+        self.normalize = normalize
+        self.target_type = target_type
+
+    def __call__(self, pic, lbl):
+        """
+        Note that labels will not be normalized to [0, 1].
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+            lbl (PIL Image or numpy.ndarray): Label to be converted to tensor.
+        Returns:
+            Tensor: Converted image and label
+        """
+        if self.normalize:
+            image = F.to_tensor(pic)
+        else:
+            raw = np.array(pic, dtype=np.float32)
+            image = torch.from_numpy(raw.transpose(2, 0, 1))
+        target = torch.from_numpy(np.array(lbl, dtype=self.target_type))
+        return image, target
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+class ExtNormalize(object):
+    """Normalize a tensor image with mean and standard deviation; pass the label through.
+
+    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, each
+    channel of the image is transformed as
+    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+
+    Args:
+        mean (sequence): Sequence of means for each channel.
+        std (sequence): Sequence of standard deviations for each channel.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, tensor, lbl):
+        """
+        Args:
+            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+            lbl (Tensor): Tensor of label. Returned unchanged; accepted so the
+                transform can be used inside ExtCompose.
+        Returns:
+            Tensor: Normalized Tensor image.
+            Tensor: Unchanged Tensor label
+        """
+        normalized = F.normalize(tensor, self.mean, self.std)
+        return normalized, lbl
+
+    def __repr__(self):
+        suffix = '(mean={0}, std={1})'.format(self.mean, self.std)
+        return self.__class__.__name__ + suffix
+class ExtRandomCrop(object):
+ """Crop the given PIL Image at a random location.
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is
+ made.
+ padding (int or sequence, optional): Optional padding on each border
+ of the image. Default is 0, i.e no padding. If a sequence of length
+ 4 is provided, it is used to pad left, top, right, bottom borders
+ respectively.
+ pad_if_needed (boolean): It will pad the image if smaller than the
+ desired size to avoid raising an exception.
+ """
+ def __init__(self, size, padding=0, pad_if_needed=False):
+ if isinstance(size, numbers.Number):
+ self.size = (int(size), int(size))
+ else:
+ self.size = size
+ self.padding = padding
+ self.pad_if_needed = pad_if_needed
+ @staticmethod
+ def get_params(img, output_size):
+ """Get parameters for ``crop`` for a random crop.
+ Args:
+ img (PIL Image): Image to be cropped.
+ output_size (tuple): Expected output size of the crop.
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
+ """
+ w, h = img.size
+ th, tw = output_size
+ if w == tw and h == th:
+ return 0, 0, h, w
+ i = random.randint(0, h - th)
+ j = random.randint(0, w - tw)
+ return i, j, th, tw
+ def __call__(self, img, lbl):
+ """
+ Args:
+ img (PIL Image): Image to be cropped.
+ lbl (PIL Image): Label to be cropped.
+ Returns:
+ PIL Image: Cropped image.
+ PIL Image: Cropped label.
+ """
+ assert img.size == lbl.size, 'size of img and lbl should be the same. %s, %s'%(img.size, lbl.size)
+ if self.padding > 0:
+ img = F.pad(img, self.padding)
+ lbl = F.pad(lbl, self.padding)
+ # pad the width if needed
+ if self.pad_if_needed and img.size[0] < self.size[1]:
+ img = F.pad(img, padding=int((1 + self.size[1] - img.size[0]) / 2))
+ lbl = F.pad(lbl, padding=int((1 + self.size[1] - lbl.size[0]) / 2))
+ # pad the height if needed
+ if self.pad_if_needed and img.size[1] < self.size[0]:
+ img = F.pad(img, padding=int((1 + self.size[0] - img.size[1]) / 2))
+ lbl = F.pad(lbl, padding=int((1 + self.size[0] - lbl.size[1]) / 2))
+ i, j, h, w = self.get_params(img, self.size)
+ return F.crop(img, i, j, h, w), F.crop(lbl, i, j, h, w)
+ def __repr__(self):
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
+class ExtResize(object):
+ """Resize the input PIL Image to the given size.
+ Args:
+ size (sequence or int): Desired output size. If size is a sequence like
+ (h, w), output size will be matched to this. If size is an int,
+ smaller edge of the image will be matched to this number.
+ i.e, if height > width, then image will be rescaled to
+ (size * height / width, size)
+ interpolation (int, optional): Desired interpolation. Default is
+ ``PIL.Image.BILINEAR``
+ """
+ def __init__(self, size, interpolation=Image.BILINEAR):
+ assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
+ self.size = size
+ self.interpolation = interpolation
+ def __call__(self, img, lbl):
+ """
+ Args:
+ img (PIL Image): Image to be scaled.
+ Returns:
+ PIL Image: Rescaled image.
+ """
+ return F.resize(img, self.size, self.interpolation), F.resize(lbl, self.size, Image.NEAREST)
+ def __repr__(self):
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+class ExtColorJitter(object):
+ """Randomly change the brightness, contrast and saturation of an image.
+ Args:
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
+ or the given [min, max]. Should be non negative numbers.
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
+ or the given [min, max]. Should be non negative numbers.
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
+ or the given [min, max]. Should be non negative numbers.
+ hue (float or tuple of float (min, max)): How much to jitter hue.
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+ """
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+ self.brightness = self._check_input(brightness, 'brightness')
+ self.contrast = self._check_input(contrast, 'contrast')
+ self.saturation = self._check_input(saturation, 'saturation')
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
+ clip_first_on_zero=False)
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
+ if isinstance(value, numbers.Number):
+ if value < 0:
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
+ value = [center - value, center + value]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError("{} values should be between {}".format(name, bound))
+ else:
+ raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
+ # or (0., 0.) for hue, do nothing
+ if value[0] == value[1] == center:
+ value = None
+ return value
+ @staticmethod
+ def get_params(brightness, contrast, saturation, hue):
+ """Get a randomized transform to be applied on image.
+ Arguments are same as that of __init__.
+ Returns:
+ Transform which randomly adjusts brightness, contrast and
+ saturation in a random order.
+ """
+ transforms = []
+ if brightness is not None:
+ brightness_factor = random.uniform(brightness[0], brightness[1])
+ transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
+ if contrast is not None:
+ contrast_factor = random.uniform(contrast[0], contrast[1])
+ transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
+ if saturation is not None:
+ saturation_factor = random.uniform(saturation[0], saturation[1])
+ transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
+ if hue is not None:
+ hue_factor = random.uniform(hue[0], hue[1])
+ transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+ random.shuffle(transforms)
+ transform = Compose(transforms)
+ return transform
+ def __call__(self, img, lbl):
+ """
+ Args:
+ img (PIL Image): Input image.
+ Returns:
+ PIL Image: Color jittered image.
+ """
+ transform = self.get_params(self.brightness, self.contrast,
+ self.saturation, self.hue)
+ return transform(img), lbl
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ format_string += 'brightness={0}'.format(self.brightness)
+ format_string += ', contrast={0}'.format(self.contrast)
+ format_string += ', saturation={0}'.format(self.saturation)
+ format_string += ', hue={0})'.format(self.hue)
+ return format_string
+class Lambda(object):
+ """Apply a user-defined lambda as a transform.
+ Args:
+ lambd (function): Lambda/function to be used for transform.
+ """
+ def __init__(self, lambd):
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
+ self.lambd = lambd
+ def __call__(self, img):
+ return self.lambd(img)
+ def __repr__(self):
+ return self.__class__.__name__ + '()'
+class Compose(object):
+ """Composes several transforms together.
+ Args:
+ transforms (list of ``Transform`` objects): list of transforms to compose.
+ Example:
+ >>> transforms.Compose([
+ >>> transforms.CenterCrop(10),
+ >>> transforms.ToTensor(),
+ >>> ])
+ """
+ def __init__(self, transforms):
+ self.transforms = transforms
+ def __call__(self, img):
+ for t in self.transforms:
+ img = t(img)
+ return img
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += ' {0}'.format(t)
+ format_string += '\n)'
+ return format_string
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/utils/loss.py b/DeepLabV3Plus-Pytorch/utils/loss.py
new file mode 100644
index 0000000..64a5f54
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/loss.py
@@ -0,0 +1,21 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+class FocalLoss(nn.Module):
+ def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
+ super(FocalLoss, self).__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.ignore_index = ignore_index
+ self.size_average = size_average
+ def forward(self, inputs, targets):
+ ce_loss = F.cross_entropy(
+ inputs, targets, reduction='none', ignore_index=self.ignore_index)
+ pt = torch.exp(-ce_loss)
+ focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
+ if self.size_average:
+ return focal_loss.mean()
+ else:
+ return focal_loss.sum()
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/utils/scheduler.py b/DeepLabV3Plus-Pytorch/utils/scheduler.py
new file mode 100644
index 0000000..65ffcec
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/scheduler.py
@@ -0,0 +1,12 @@
+from torch.optim.lr_scheduler import _LRScheduler, StepLR
+class PolyLR(_LRScheduler):
+ def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6):
+ self.power = power
+ self.max_iters = max_iters # avoid zero lr
+ self.min_lr = min_lr
+ super(PolyLR, self).__init__(optimizer, last_epoch)
+ def get_lr(self):
+ return [ max( base_lr * ( 1 - self.last_epoch/self.max_iters )**self.power, self.min_lr)
+ for base_lr in self.base_lrs]
\ No newline at end of file
diff --git a/DeepLabV3Plus-Pytorch/utils/utils.py b/DeepLabV3Plus-Pytorch/utils/utils.py
new file mode 100644
index 0000000..b5be062
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/utils.py
@@ -0,0 +1,55 @@
+from torchvision.transforms.functional import normalize
+import torch.nn as nn
+import numpy as np
+import os
+def denormalize(tensor, mean, std):
+ mean = np.array(mean)
+ std = np.array(std)
+ _mean = -mean/std
+ _std = 1/std
+ return normalize(tensor, _mean, _std)
+class Denormalize(object):
+ def __init__(self, mean, std):
+ mean = np.array(mean)
+ std = np.array(std)
+ self._mean = -mean/std
+ self._std = 1/std
+ def __call__(self, tensor):
+ if isinstance(tensor, np.ndarray):
+ return (tensor - self._mean.reshape(-1,1,1)) / self._std.reshape(-1,1,1)
+ return normalize(tensor, self._mean, self._std)
+def set_bn_momentum(model, momentum=0.1):
+ for m in model.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.momentum = momentum
+def fix_bn(model):
+ for m in model.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+def mkdir(path):
+ if not os.path.exists(path):
+ os.mkdir(path)
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
diff --git a/DeepLabV3Plus-Pytorch/utils/visualizer.py b/DeepLabV3Plus-Pytorch/utils/visualizer.py
new file mode 100644
index 0000000..d1280e2
--- /dev/null
+++ b/DeepLabV3Plus-Pytorch/utils/visualizer.py
@@ -0,0 +1,87 @@
+from visdom import Visdom
+import json
+class Visualizer(object):
+ """ Visualizer
+ """
+ def __init__(self, port='13579', env='main', id=None):
+ #self.cur_win = {}
+ self.vis = Visdom(port=port, env=env)
+ self.id = id
+ self.env = env
+ # Restore
+ #ori_win = self.vis.get_window_data()
+ #ori_win = json.loads(ori_win)
+ #print(ori_win)
+ #self.cur_win = { v['title']: k for k, v in ori_win.items() }
+ def vis_scalar(self, name, x, y, opts=None):
+ if not isinstance(x, list):
+ x = [x]
+ if not isinstance(y, list):
+ y = [y]
+ if self.id is not None:
+ name = "[%s]"%self.id + name
+ default_opts = { 'title': name }
+ if opts is not None:
+ default_opts.update(opts)
+ #win = self.cur_win.get(name, None)
+ #if win is not None:
+ self.vis.line( X=x, Y=y, win=name, opts=default_opts, update='append')
+ #else:
+ # self.cur_win[name] = self.vis.line( X=x, Y=y, opts=default_opts)
+ def vis_image(self, name, img, env=None, opts=None):
+ """ vis image in visdom
+ """
+ if env is None:
+ env = self.env
+ if self.id is not None:
+ name = "[%s]"%self.id + name
+ #win = self.cur_win.get(name, None)
+ default_opts = { 'title': name }
+ if opts is not None:
+ default_opts.update(opts)
+ #if win is not None:
+ self.vis.image( img=img, win=name, opts=opts, env=env )
+ #else:
+ # self.cur_win[name] = self.vis.image( img=img, opts=default_opts, env=env )
+ def vis_table(self, name, tbl, opts=None):
+ #win = self.cur_win.get(name, None)
+ tbl_str = " "
+ tbl_str+=" \
+ Term | \
+ Value | \
+ for k, v in tbl.items():
+ tbl_str+= " \
+ %s | \
+ %s | \
"%(k, v)
+ tbl_str+="
+ default_opts = { 'title': name }
+ if opts is not None:
+ default_opts.update(opts)
+ #if win is not None:
+ self.vis.text(tbl_str, win=name, opts=default_opts)
+ #else:
+ #self.cur_win[name] = self.vis.text(tbl_str, opts=default_opts)
+if __name__=='__main__':
+ import numpy as np
+ vis = Visualizer(port=35588, env='main')
+ tbl = {"lr": 214, "momentum": 0.9}
+ vis.vis_table("test_table", tbl)
+ tbl = {"lr": 244444, "momentum": 0.9, "haha": "hoho"}
+ vis.vis_table("test_table", tbl)
+ vis.vis_scalar(name='loss', x=0, y=1)
+ vis.vis_scalar(name='loss', x=2, y=4)
+ vis.vis_scalar(name='loss', x=4, y=6)
\ No newline at end of file
diff --git a/README.md b/README.md
index 9053dc2..24598bd 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,8 @@ Discriminative Region Suppression for Weakly-Supervised Semantic Segmentation [[
We propose the discriminative region suppression (DRS) module that is a simple yet effective method to expand object activation regions. DRS suppresses the attention on discriminative regions and spreads it to adjacent non-discriminative regions, generating dense localization maps.
+[2021.06.10] We now support the DeepLab-V3+ segmentation network!
![DRS module](https://github.com/qjadud1994/DRS/blob/main/docs/DRS_module.png)
@@ -57,6 +59,8 @@ bash run.sh
## Training the DeepLab-V2 using pseudo labels
We adopt the DeepLab-V2 pytorch implementation from https://github.com/kazuto1011/deeplab-pytorch.
+* As described in the [DeepLab-V2 pytorch implementation](https://github.com/kazuto1011/deeplab-pytorch#download-pre-trained-caffemodels), training requires initial weights [[download]](https://drive.google.com/file/d/1Wj8Maj9KGQgwtDfvIp8FChsdAIgDvliT/view?usp=sharing).
cd DeepLab-V2-PyTorch/
@@ -70,13 +74,30 @@ bash train.sh
bash eval.sh
+## Training the DeepLab-V3+ using pseudo labels
+We adopt the DeepLab-V3+ pytorch implementation from https://github.com/VainF/DeepLabV3Plus-Pytorch.
+Note that **DeepLab-V2** suffers from the small-batch issue; its authors therefore use COCO-pretrained weights and freeze the batch-normalization layers. Without COCO-pretrained weights, DeepLab-V2 cannot reproduce its reported performance even in the fully-supervised setting.
+In contrast, **DeepLab-V3 does not require the COCO-pretrained weights**, thanks to recent large-memory GPUs and Synchronized BatchNorm.
+We argue that DeepLab-V3 is therefore a more reasonable choice and better suited to measuring the quality of pseudo labels.
+cd DeepLabV3Plus-Pytorch/
+# train and evaluate DeepLab-V3+ using pseudo labels
+vi run.sh # modify the dataset path --data_root
+bash run.sh
| Model | mIoU | mIoU + CRF | pretrained |
| :----: | :----: | :----: | :----: |
| DeepLab-V2 with ResNet-101 | 69.4% | 70.4% | [[link]](https://drive.google.com/drive/folders/1zJnRI5WRnv4cL9XY5jAojwIcO7MrUwun?usp=sharing)
+| DeepLab-V3+ with ResNet-101 | 70.4% | 71.0% | [[link]](https://drive.google.com/file/d/1W1LV3gvBPRr2lIlWdvqZ-cs87qYT8Nax/view?usp=sharing)
* Note that the pretrained weight path
-* According to the [DeepLab-V2 pytorch implementation](https://github.com/kazuto1011/deeplab-pytorch#download-pre-trained-caffemodels) we used, we requires an initial weights [[download]](https://drive.google.com/file/d/1Wj8Maj9KGQgwtDfvIp8FChsdAIgDvliT/view?usp=sharing).