diff --git a/pyiqa/archs/arch_util.py b/pyiqa/archs/arch_util.py index f3a869c..6ae25da 100644 --- a/pyiqa/archs/arch_util.py +++ b/pyiqa/archs/arch_util.py @@ -1,16 +1,11 @@ -import collections.abc import math + import torch -import torchvision -import warnings -from distutils.version import LooseVersion -from itertools import repeat from torch import nn as nn from torch.nn import functional as F from torch.nn import init as init from torch.nn.modules.batchnorm import _BatchNorm -from pyiqa.utils import get_root_logger from pyiqa.utils.download_util import load_file_from_url # -------------------------------------------- diff --git a/pyiqa/archs/brisque_arch.py b/pyiqa/archs/brisque_arch.py index 3252105..1a4a7a6 100644 --- a/pyiqa/archs/brisque_arch.py +++ b/pyiqa/archs/brisque_arch.py @@ -16,8 +16,8 @@ import torch import torch.nn.functional as F from pyiqa.utils.color_util import to_y_channel -from pyiqa.utils.matlab_functions import fspecial_gauss, imresize -from .func_util import estimate_ggd_param, estimate_aggd_param, safe_sqrt, normalize_img_with_guass +from pyiqa.utils.matlab_functions import imresize +from .func_util import estimate_ggd_param, estimate_aggd_param, normalize_img_with_guass from pyiqa.utils.download_util import load_file_from_url from pyiqa.utils.registry import ARCH_REGISTRY diff --git a/pyiqa/archs/ckdn_arch.py b/pyiqa/archs/ckdn_arch.py index b24d6e3..9de0399 100644 --- a/pyiqa/archs/ckdn_arch.py +++ b/pyiqa/archs/ckdn_arch.py @@ -14,7 +14,6 @@ import torchvision as tv from pyiqa.utils.registry import ARCH_REGISTRY from pyiqa.archs.arch_util import load_pretrained_network -from pyiqa.utils.download_util import load_file_from_url try: from torch.hub import load_state_dict_from_url diff --git a/pyiqa/archs/dbcnn_arch.py b/pyiqa/archs/dbcnn_arch.py index ec0fe8f..ae2fa90 100644 --- a/pyiqa/archs/dbcnn_arch.py +++ b/pyiqa/archs/dbcnn_arch.py @@ -6,8 +6,6 @@ """ -import os - import torch import torchvision import torch.nn as nn diff --git a/pyiqa/archs/func_util.py b/pyiqa/archs/func_util.py index 3be1134..58b2658 100644 --- a/pyiqa/archs/func_util.py +++ b/pyiqa/archs/func_util.py @@ -8,8 +8,8 @@ def torch_cov(tensor, rowvar=True, bias=False): - """Estimate a covariance matrix (np.cov) - https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110 + r"""Estimate a covariance matrix (np.cov) + Ref: https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110 """ tensor = tensor if rowvar else tensor.transpose(-1, -2) tensor = tensor - tensor.mean(dim=-1, keepdim=True) @@ -18,6 +18,11 @@ def torch_cov(tensor, rowvar=True, bias=False): def safe_sqrt(x: torch.Tensor) -> torch.Tensor: + r"""Safe sqrt with EPS to ensure numeric stability. 
+ + Args: + x (torch.Tensor): should be non-negative + """ EPS = torch.finfo(x.dtype).eps return torch.sqrt(x + EPS) diff --git a/pyiqa/archs/inception.py b/pyiqa/archs/inception.py deleted file mode 100644 index de1abef..0000000 --- a/pyiqa/archs/inception.py +++ /dev/null @@ -1,307 +0,0 @@ -# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501 -# For FID metric - -import os -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.model_zoo import load_url -from torchvision import models - -# Inception weights ported to Pytorch from -# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz -FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 -LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 - - -class InceptionV3(nn.Module): - """Pretrained InceptionV3 network returning feature maps""" - - # Index of default block of inception to return, - # corresponds to output of final average pooling - DEFAULT_BLOCK_INDEX = 3 - - # Maps feature dimensionality to their output blocks indices - BLOCK_INDEX_BY_DIM = { - 64: 0, # First max pooling features - 192: 1, # Second max pooling features - 768: 2, # Pre-aux classifier features - 2048: 3 # Final average pooling features - } - - def __init__(self, - output_blocks=(DEFAULT_BLOCK_INDEX), - resize_input=True, - normalize_input=True, - requires_grad=False, - use_fid_inception=True): - """Build pretrained InceptionV3. - - Args: - output_blocks (list[int]): Indices of blocks to return features of. - Possible values are: - - 0: corresponds to output of first max pooling - - 1: corresponds to output of second max pooling - - 2: corresponds to output which is fed to aux classifier - - 3: corresponds to output of final average pooling - resize_input (bool): If true, bilinearly resizes input to width and - height 299 before feeding input to model. As the network - without fully connected layers is fully convolutional, it - should be able to handle inputs of arbitrary size, so resizing - might not be strictly needed. Default: True. - normalize_input (bool): If true, scales the input from range (0, 1) - to the range the pretrained Inception network expects, - namely (-1, 1). Default: True. - requires_grad (bool): If true, parameters of the model require - gradients. Possibly useful for finetuning the network. - Default: False. - use_fid_inception (bool): If true, uses the pretrained Inception - model used in Tensorflow's FID implementation. - If false, uses the pretrained Inception model available in - torchvision. The FID Inception model has different weights - and a slightly different structure from torchvision's - Inception model. If you want to compute FID scores, you are - strongly advised to set this parameter to true to get - comparable results. Default: True. 
- """ - super(InceptionV3, self).__init__() - - self.resize_input = resize_input - self.normalize_input = normalize_input - self.output_blocks = sorted(output_blocks) - self.last_needed_block = max(output_blocks) - - assert self.last_needed_block <= 3, ('Last possible output block index is 3') - - self.blocks = nn.ModuleList() - - if use_fid_inception: - inception = fid_inception_v3() - else: - try: - inception = models.inception_v3(pretrained=True, init_weights=False) - except TypeError: - # pytorch < 1.5 does not have init_weights for inception_v3 - inception = models.inception_v3(pretrained=True) - - # Block 0: input to maxpool1 - block0 = [ - inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3, - nn.MaxPool2d(kernel_size=3, stride=2) - ] - self.blocks.append(nn.Sequential(*block0)) - - # Block 1: maxpool1 to maxpool2 - if self.last_needed_block >= 1: - block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)] - self.blocks.append(nn.Sequential(*block1)) - - # Block 2: maxpool2 to aux classifier - if self.last_needed_block >= 2: - block2 = [ - inception.Mixed_5b, - inception.Mixed_5c, - inception.Mixed_5d, - inception.Mixed_6a, - inception.Mixed_6b, - inception.Mixed_6c, - inception.Mixed_6d, - inception.Mixed_6e, - ] - self.blocks.append(nn.Sequential(*block2)) - - # Block 3: aux classifier to final avgpool - if self.last_needed_block >= 3: - block3 = [ - inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c, - nn.AdaptiveAvgPool2d(output_size=(1, 1)) - ] - self.blocks.append(nn.Sequential(*block3)) - - for param in self.parameters(): - param.requires_grad = requires_grad - - def forward(self, x): - """Get Inception feature maps. - - Args: - x (Tensor): Input tensor of shape (b, 3, h, w). - Values are expected to be in range (-1, 1). You can also input - (0, 1) with setting normalize_input = True. - - Returns: - list[Tensor]: Corresponding to the selected output block, sorted - ascending by index. - """ - output = [] - - if self.resize_input: - x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False) - - if self.normalize_input: - x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) - - for idx, block in enumerate(self.blocks): - x = block(x) - if idx in self.output_blocks: - output.append(x) - - if idx == self.last_needed_block: - break - - return output - - -def fid_inception_v3(): - """Build pretrained Inception model for FID computation. - - The Inception model for FID computation uses a different set of weights - and has a slightly different structure than torchvision's Inception. - - This method first constructs torchvision's Inception and then patches the - necessary parts that are different in the FID Inception model. 
- """ - try: - inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False) - except TypeError: - # pytorch < 1.5 does not have init_weights for inception_v3 - inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False) - - inception.Mixed_5b = FIDInceptionA(192, pool_features=32) - inception.Mixed_5c = FIDInceptionA(256, pool_features=64) - inception.Mixed_5d = FIDInceptionA(288, pool_features=64) - inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) - inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) - inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) - inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) - inception.Mixed_7b = FIDInceptionE_1(1280) - inception.Mixed_7c = FIDInceptionE_2(2048) - - if os.path.exists(LOCAL_FID_WEIGHTS): - state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage) - else: - state_dict = load_url(FID_WEIGHTS_URL, progress=True) - - inception.load_state_dict(state_dict) - return inception - - -class FIDInceptionA(models.inception.InceptionA): - """InceptionA block patched for FID computation""" - - def __init__(self, in_channels, pool_features): - super(FIDInceptionA, self).__init__(in_channels, pool_features) - - def forward(self, x): - branch1x1 = self.branch1x1(x) - - branch5x5 = self.branch5x5_1(x) - branch5x5 = self.branch5x5_2(branch5x5) - - branch3x3dbl = self.branch3x3dbl_1(x) - branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) - branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) - - # Patch: Tensorflow's average pool does not use the padded zero's in - # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) - branch_pool = self.branch_pool(branch_pool) - - outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] - return torch.cat(outputs, 1) - - -class FIDInceptionC(models.inception.InceptionC): - """InceptionC block patched for FID computation""" - - def __init__(self, in_channels, channels_7x7): - super(FIDInceptionC, self).__init__(in_channels, channels_7x7) - - def forward(self, x): - branch1x1 = self.branch1x1(x) - - branch7x7 = self.branch7x7_1(x) - branch7x7 = self.branch7x7_2(branch7x7) - branch7x7 = self.branch7x7_3(branch7x7) - - branch7x7dbl = self.branch7x7dbl_1(x) - branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) - branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) - branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) - branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) - - # Patch: Tensorflow's average pool does not use the padded zero's in - # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) - branch_pool = self.branch_pool(branch_pool) - - outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] - return torch.cat(outputs, 1) - - -class FIDInceptionE_1(models.inception.InceptionE): - """First InceptionE block patched for FID computation""" - - def __init__(self, in_channels): - super(FIDInceptionE_1, self).__init__(in_channels) - - def forward(self, x): - branch1x1 = self.branch1x1(x) - - branch3x3 = self.branch3x3_1(x) - branch3x3 = [ - self.branch3x3_2a(branch3x3), - self.branch3x3_2b(branch3x3), - ] - branch3x3 = torch.cat(branch3x3, 1) - - branch3x3dbl = self.branch3x3dbl_1(x) - branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) - branch3x3dbl = [ - self.branch3x3dbl_3a(branch3x3dbl), - self.branch3x3dbl_3b(branch3x3dbl), - ] - branch3x3dbl = torch.cat(branch3x3dbl, 1) - - # Patch: 
Tensorflow's average pool does not use the padded zero's in - # its average calculation - branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False) - branch_pool = self.branch_pool(branch_pool) - - outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] - return torch.cat(outputs, 1) - - -class FIDInceptionE_2(models.inception.InceptionE): - """Second InceptionE block patched for FID computation""" - - def __init__(self, in_channels): - super(FIDInceptionE_2, self).__init__(in_channels) - - def forward(self, x): - branch1x1 = self.branch1x1(x) - - branch3x3 = self.branch3x3_1(x) - branch3x3 = [ - self.branch3x3_2a(branch3x3), - self.branch3x3_2b(branch3x3), - ] - branch3x3 = torch.cat(branch3x3, 1) - - branch3x3dbl = self.branch3x3dbl_1(x) - branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) - branch3x3dbl = [ - self.branch3x3dbl_3a(branch3x3dbl), - self.branch3x3dbl_3b(branch3x3dbl), - ] - branch3x3dbl = torch.cat(branch3x3dbl, 1) - - # Patch: The FID Inception model uses max pooling instead of average - # pooling. This is likely an error in this specific Inception - # implementation, as other Inception models use average pooling here - # (which matches the description in the paper). - branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) - branch_pool = self.branch_pool(branch_pool) - - outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] - return torch.cat(outputs, 1) diff --git a/pyiqa/archs/nima_arch.py b/pyiqa/archs/nima_arch.py index 4e9c3a0..2d21abe 100644 --- a/pyiqa/archs/nima_arch.py +++ b/pyiqa/archs/nima_arch.py @@ -1,4 +1,7 @@ r"""NIMA model. +Reference: + Talebi, Hossein, and Peyman Milanfar. "NIMA: Neural image assessment." + IEEE transactions on image processing 27, no. 8 (2018): 3998-4011. Created by: https://github.com/yunxiaoshi/Neural-IMage-Assessment/blob/master/model/model.py @@ -28,9 +31,6 @@ class NIMA(nn.Module): default input shape: - vgg and mobilenet: (N, 3, 224, 224) - inception: (N, 3, 299, 299) - Reference: - Talebi, Hossein, and Peyman Milanfar. "NIMA: Neural image assessment." - IEEE transactions on image processing 27, no. 8 (2018): 3998-4011. 
""" def __init__(self, base_model_name='vgg16', diff --git a/pyiqa/archs/vgg_arch.py b/pyiqa/archs/vgg_arch.py deleted file mode 100644 index 146b323..0000000 --- a/pyiqa/archs/vgg_arch.py +++ /dev/null @@ -1,161 +0,0 @@ -import os -import torch -from collections import OrderedDict -from torch import nn as nn -from torchvision.models import vgg as vgg - -from pyiqa.utils.registry import ARCH_REGISTRY - -VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' -NAMES = { - 'vgg11': [ - 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', - 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', - 'pool5' - ], - 'vgg13': [ - 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', - 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', - 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' - ], - 'vgg16': [ - 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', - 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', - 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', - 'pool5' - ], - 'vgg19': [ - 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', - 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', - 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', - 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' - ] -} - - -def insert_bn(names): - """Insert bn layer after each conv. - - Args: - names (list): The list of layer names. - - Returns: - list: The list of layer names with bn layers. - """ - names_bn = [] - for name in names: - names_bn.append(name) - if 'conv' in name: - position = name.replace('conv', '') - names_bn.append('bn' + position) - return names_bn - - -@ARCH_REGISTRY.register() -class VGGFeatureExtractor(nn.Module): - """VGG network for feature extraction. - - In this implementation, we allow users to choose whether use normalization - in the input feature and the type of vgg network. Note that the pretrained - path must fit the vgg type. - - Args: - layer_name_list (list[str]): Forward function returns the corresponding - features according to the layer_name_list. - Example: {'relu1_1', 'relu2_1', 'relu3_1'}. - vgg_type (str): Set the type of vgg network. Default: 'vgg19'. - use_input_norm (bool): If True, normalize the input image. Importantly, - the input feature must in the range [0, 1]. Default: True. - range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. - Default: False. - requires_grad (bool): If true, the parameters of VGG network will be - optimized. Default: False. - remove_pooling (bool): If true, the max pooling operations in VGG net - will be removed. Default: False. - pooling_stride (int): The stride of max pooling operation. Default: 2. 
- """ - - def __init__(self, - layer_name_list, - vgg_type='vgg19', - use_input_norm=True, - range_norm=False, - requires_grad=False, - remove_pooling=False, - pooling_stride=2): - super(VGGFeatureExtractor, self).__init__() - - self.layer_name_list = layer_name_list - self.use_input_norm = use_input_norm - self.range_norm = range_norm - - self.names = NAMES[vgg_type.replace('_bn', '')] - if 'bn' in vgg_type: - self.names = insert_bn(self.names) - - # only borrow layers that will be used to avoid unused params - max_idx = 0 - for v in layer_name_list: - idx = self.names.index(v) - if idx > max_idx: - max_idx = idx - - if os.path.exists(VGG_PRETRAIN_PATH): - vgg_net = getattr(vgg, vgg_type)(pretrained=False) - state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) - vgg_net.load_state_dict(state_dict) - else: - vgg_net = getattr(vgg, vgg_type)(pretrained=True) - - features = vgg_net.features[:max_idx + 1] - - modified_net = OrderedDict() - for k, v in zip(self.names, features): - if 'pool' in k: - # if remove_pooling is true, pooling operation will be removed - if remove_pooling: - continue - else: - # in some cases, we may want to change the default stride - modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) - else: - modified_net[k] = v - - self.vgg_net = nn.Sequential(modified_net) - - if not requires_grad: - self.vgg_net.eval() - for param in self.parameters(): - param.requires_grad = False - else: - self.vgg_net.train() - for param in self.parameters(): - param.requires_grad = True - - if self.use_input_norm: - # the mean is for image with range [0, 1] - self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) - # the std is for image with range [0, 1] - self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) - - def forward(self, x): - """Forward function. - - Args: - x (Tensor): Input tensor with shape (n, c, h, w). - - Returns: - Tensor: Forward results. - """ - if self.range_norm: - x = (x + 1) / 2 - if self.use_input_norm: - x = (x - self.mean) / self.std - - output = {} - for key, layer in self.vgg_net._modules.items(): - x = layer(x) - if key in self.layer_name_list: - output[key] = x.clone() - - return output diff --git a/pyiqa/archs/wadiqam_arch.py b/pyiqa/archs/wadiqam_arch.py index 8919cf4..83baf7f 100644 --- a/pyiqa/archs/wadiqam_arch.py +++ b/pyiqa/archs/wadiqam_arch.py @@ -1,9 +1,13 @@ r"""WaDIQaM model. -Created by: https://github.com/lidq92/WaDIQaM +Reference: + Bosse, Sebastian, Dominique Maniry, Klaus-Robert Müller, Thomas Wiegand, + and Wojciech Samek. "Deep neural networks for no-reference and full-reference + image quality assessment." IEEE Transactions on image processing 27, no. 1 + (2017): 206-219. +Created by: https://github.com/lidq92/WaDIQaM Modified by: Chaofeng Chen (https://github.com/chaofengc) - Refer to: Official code from https://github.com/dmaniry/deepIQA @@ -41,12 +45,6 @@ class WaDIQaM(nn.Module): load_feature_weight_only (Boolean): Only load featureweight. eps (float): Constant value. - Reference: - Bosse, Sebastian, Dominique Maniry, Klaus-Robert Müller, Thomas Wiegand, - and Wojciech Samek. "Deep neural networks for no-reference and full-reference - image quality assessment." IEEE Transactions on image processing 27, no. 1 - (2017): 206-219. 
- """ def __init__(self, metric_mode='FR', diff --git a/pyiqa/losses/__init__.py b/pyiqa/losses/__init__.py index 74f8d6e..ee9ce92 100644 --- a/pyiqa/losses/__init__.py +++ b/pyiqa/losses/__init__.py @@ -2,14 +2,13 @@ from pyiqa.utils import get_root_logger from pyiqa.utils.registry import LOSS_REGISTRY -from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, - gradient_penalty_loss, r1_penalty) +from .losses import CharbonnierLoss, L1Loss, MSELoss, WeightedTVLoss + from .iqa_losses import EMDLoss, PLCCLoss, NiNLoss __all__ = [ - 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', - 'r1_penalty', 'g_path_regularize', + 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'EMDLoss', 'PLCCLoss', 'NiNLoss' ] diff --git a/pyiqa/losses/losses.py b/pyiqa/losses/losses.py index e90d342..5ed3988 100644 --- a/pyiqa/losses/losses.py +++ b/pyiqa/losses/losses.py @@ -4,7 +4,6 @@ from torch import nn as nn from torch.nn import functional as F -from pyiqa.archs.vgg_arch import VGGFeatureExtractor from pyiqa.utils.registry import LOSS_REGISTRY from .loss_util import weighted_loss @@ -145,348 +144,3 @@ def forward(self, pred, weight=None): return loss -@LOSS_REGISTRY.register() -class PerceptualLoss(nn.Module): - """Perceptual loss with commonly used style loss. - - Args: - layer_weights (dict): The weight for each layer of vgg feature. - Here is an example: {'conv5_4': 1.}, which means the conv5_4 - feature layer (before relu5_4) will be extracted with weight - 1.0 in calculating losses. - vgg_type (str): The type of vgg network used as feature extractor. - Default: 'vgg19'. - use_input_norm (bool): If True, normalize the input image in vgg. - Default: True. - range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. - Default: False. - perceptual_weight (float): If `perceptual_weight > 0`, the perceptual - loss will be calculated and the loss will multiplied by the - weight. Default: 1.0. - style_weight (float): If `style_weight > 0`, the style loss will be - calculated and the loss will multiplied by the weight. - Default: 0. - criterion (str): Criterion used for perceptual loss. Default: 'l1'. - """ - - def __init__(self, - layer_weights, - vgg_type='vgg19', - use_input_norm=True, - range_norm=False, - perceptual_weight=1.0, - style_weight=0., - criterion='l1'): - super(PerceptualLoss, self).__init__() - self.perceptual_weight = perceptual_weight - self.style_weight = style_weight - self.layer_weights = layer_weights - self.vgg = VGGFeatureExtractor( - layer_name_list=list(layer_weights.keys()), - vgg_type=vgg_type, - use_input_norm=use_input_norm, - range_norm=range_norm) - - self.criterion_type = criterion - if self.criterion_type == 'l1': - self.criterion = torch.nn.L1Loss() - elif self.criterion_type == 'l2': - self.criterion = torch.nn.L2loss() - elif self.criterion_type == 'fro': - self.criterion = None - else: - raise NotImplementedError(f'{criterion} criterion has not been supported.') - - def forward(self, x, gt): - """Forward function. - - Args: - x (Tensor): Input tensor with shape (n, c, h, w). - gt (Tensor): Ground-truth tensor with shape (n, c, h, w). - - Returns: - Tensor: Forward results. 
- """ - # extract vgg features - x_features = self.vgg(x) - gt_features = self.vgg(gt.detach()) - - # calculate perceptual loss - if self.perceptual_weight > 0: - percep_loss = 0 - for k in x_features.keys(): - if self.criterion_type == 'fro': - percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] - else: - percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] - percep_loss *= self.perceptual_weight - else: - percep_loss = None - - # calculate style loss - if self.style_weight > 0: - style_loss = 0 - for k in x_features.keys(): - if self.criterion_type == 'fro': - style_loss += torch.norm( - self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] - else: - style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( - gt_features[k])) * self.layer_weights[k] - style_loss *= self.style_weight - else: - style_loss = None - - return percep_loss, style_loss - - def _gram_mat(self, x): - """Calculate Gram matrix. - - Args: - x (torch.Tensor): Tensor with shape of (n, c, h, w). - - Returns: - torch.Tensor: Gram matrix. - """ - n, c, h, w = x.size() - features = x.view(n, c, w * h) - features_t = features.transpose(1, 2) - gram = features.bmm(features_t) / (c * h * w) - return gram - - -@LOSS_REGISTRY.register() -class GANLoss(nn.Module): - """Define GAN loss. - - Args: - gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. - real_label_val (float): The value for real label. Default: 1.0. - fake_label_val (float): The value for fake label. Default: 0.0. - loss_weight (float): Loss weight. Default: 1.0. - Note that loss_weight is only for generators; and it is always 1.0 - for discriminators. - """ - - def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): - super(GANLoss, self).__init__() - self.gan_type = gan_type - self.loss_weight = loss_weight - self.real_label_val = real_label_val - self.fake_label_val = fake_label_val - - if self.gan_type == 'vanilla': - self.loss = nn.BCEWithLogitsLoss() - elif self.gan_type == 'lsgan': - self.loss = nn.MSELoss() - elif self.gan_type == 'wgan': - self.loss = self._wgan_loss - elif self.gan_type == 'wgan_softplus': - self.loss = self._wgan_softplus_loss - elif self.gan_type == 'hinge': - self.loss = nn.ReLU() - else: - raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') - - def _wgan_loss(self, input, target): - """wgan loss. - - Args: - input (Tensor): Input tensor. - target (bool): Target label. - - Returns: - Tensor: wgan loss. - """ - return -input.mean() if target else input.mean() - - def _wgan_softplus_loss(self, input, target): - """wgan loss with soft plus. softplus is a smooth approximation to the - ReLU function. - - In StyleGAN2, it is called: - Logistic loss for discriminator; - Non-saturating loss for generator. - - Args: - input (Tensor): Input tensor. - target (bool): Target label. - - Returns: - Tensor: wgan loss. - """ - return F.softplus(-input).mean() if target else F.softplus(input).mean() - - def get_target_label(self, input, target_is_real): - """Get target label. - - Args: - input (Tensor): Input tensor. - target_is_real (bool): Whether the target is real or fake. - - Returns: - (bool | Tensor): Target tensor. Return bool for wgan, otherwise, - return Tensor. 
- """ - - if self.gan_type in ['wgan', 'wgan_softplus']: - return target_is_real - target_val = (self.real_label_val if target_is_real else self.fake_label_val) - return input.new_ones(input.size()) * target_val - - def forward(self, input, target_is_real, is_disc=False): - """ - Args: - input (Tensor): The input for the loss module, i.e., the network - prediction. - target_is_real (bool): Whether the targe is real or fake. - is_disc (bool): Whether the loss for discriminators or not. - Default: False. - - Returns: - Tensor: GAN loss value. - """ - target_label = self.get_target_label(input, target_is_real) - if self.gan_type == 'hinge': - if is_disc: # for discriminators in hinge-gan - input = -input if target_is_real else input - loss = self.loss(1 + input).mean() - else: # for generators in hinge-gan - loss = -input.mean() - else: # other gan types - loss = self.loss(input, target_label) - - # loss_weight is always 1.0 for discriminators - return loss if is_disc else loss * self.loss_weight - - -@LOSS_REGISTRY.register() -class MultiScaleGANLoss(GANLoss): - """ - MultiScaleGANLoss accepts a list of predictions - """ - - def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): - super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) - - def forward(self, input, target_is_real, is_disc=False): - """ - The input is a list of tensors, or a list of (a list of tensors) - """ - if isinstance(input, list): - loss = 0 - for pred_i in input: - if isinstance(pred_i, list): - # Only compute GAN loss for the last layer - # in case of multiscale feature matching - pred_i = pred_i[-1] - # Safe operation: 0-dim tensor calling self.mean() does nothing - loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() - loss += loss_tensor - return loss / len(input) - else: - return super().forward(input, target_is_real, is_disc) - - -def r1_penalty(real_pred, real_img): - """R1 regularization for discriminator. The core idea is to - penalize the gradient on real data alone: when the - generator distribution produces the true data distribution - and the discriminator is equal to 0 on the data manifold, the - gradient penalty ensures that the discriminator cannot create - a non-zero gradient orthogonal to the data manifold without - suffering a loss in the GAN game. - - Ref: - Eq. 9 in Which training methods for GANs do actually converge. - """ - grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] - grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() - return grad_penalty - - -def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): - noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) - grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] - path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) - - path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) - - path_penalty = (path_lengths - path_mean).pow(2).mean() - - return path_penalty, path_lengths.detach().mean(), path_mean.detach() - - -def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): - """Calculate gradient penalty for wgan-gp. - - Args: - discriminator (nn.Module): Network for the discriminator. - real_data (Tensor): Real input data. - fake_data (Tensor): Fake input data. - weight (Tensor): Weight tensor. Default: None. - - Returns: - Tensor: A tensor for gradient penalty. 
- """ - - batch_size = real_data.size(0) - alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) - - # interpolate between real_data and fake_data - interpolates = alpha * real_data + (1. - alpha) * fake_data - interpolates = autograd.Variable(interpolates, requires_grad=True) - - disc_interpolates = discriminator(interpolates) - gradients = autograd.grad( - outputs=disc_interpolates, - inputs=interpolates, - grad_outputs=torch.ones_like(disc_interpolates), - create_graph=True, - retain_graph=True, - only_inputs=True)[0] - - if weight is not None: - gradients = gradients * weight - - gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() - if weight is not None: - gradients_penalty /= torch.mean(weight) - - return gradients_penalty - - -@LOSS_REGISTRY.register() -class GANFeatLoss(nn.Module): - """Define feature matching loss for gans - - Args: - criterion (str): Support 'l1', 'l2', 'charbonnier'. - loss_weight (float): Loss weight. Default: 1.0. - reduction (str): Specifies the reduction to apply to the output. - Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. - """ - - def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'): - super(GANFeatLoss, self).__init__() - if criterion == 'l1': - self.loss_op = L1Loss(loss_weight, reduction) - elif criterion == 'l2': - self.loss_op = MSELoss(loss_weight, reduction) - elif criterion == 'charbonnier': - self.loss_op = CharbonnierLoss(loss_weight, reduction) - else: - raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier') - - self.loss_weight = loss_weight - - def forward(self, pred_fake, pred_real): - num_d = len(pred_fake) - loss = 0 - for i in range(num_d): # for each discriminator - # last output is the final prediction, exclude it - num_intermediate_outputs = len(pred_fake[i]) - 1 - for j in range(num_intermediate_outputs): # for each layer output - unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach()) - loss += unweighted_loss / num_d - return loss * self.loss_weight diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 57accf7..a775586 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -44,8 +44,7 @@ def load_org_results(): def run_test(test_metric_names): - # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - device = torch.device('cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f'============> Testing on {device}') img_batch, ref_batch = load_test_img_batch()