diff --git a/robosat/hooks.py b/robosat/hooks.py new file mode 100644 index 00000000..8e852ead --- /dev/null +++ b/robosat/hooks.py @@ -0,0 +1,10 @@ +class FeatureHook: + def __init__(self, module): + self.features = None + self.hook = module.register_forward_hook(self.on) + + def on(self, module, inputs, outputs): + self.features = outputs + + def close(self): + self.hook.remove() diff --git a/robosat/losses.py b/robosat/losses.py index 6f2d817c..a2928aa5 100644 --- a/robosat/losses.py +++ b/robosat/losses.py @@ -4,6 +4,11 @@ import torch import torch.nn as nn +from torchvision.transforms.functional import normalize +from torchvision.models import vgg16_bn + +from robosat.hooks import FeatureHook + class CrossEntropyLoss2d(nn.Module): """Cross-entropy. @@ -117,3 +122,123 @@ def forward(self, inputs, targets): loss += torch.dot(nn.functional.relu(errors_sorted), iou) return loss / N + + +class CombinedLoss(nn.Module): + """Weighted combination of losses. + """ + + def __init__(self, criteria, weights): + """Creates a `CombinedLosses` instance. + + Args: + criteria: list of criteria to combine + weights: tensor to tune losses with + """ + + super().__init__() + + assert len(weights.size()) == 1 + assert weights.size(0) == len(criteria) + + self.criteria = criteria + self.weights = weights + + def forward(self, inputs, targets): + loss = 0.0 + + for criterion, w in zip(self.criteria, self.weights): + each = w * criterion(inputs, targets) + print(type(criterion).__name__, each.item()) # Todo: remove + loss += each + + return loss + + +class TopologyLoss(nn.Module): + """Topology loss working on a pre-trained model's feature map similarities. + + See: + - https://arxiv.org/abs/1603.08155 + - https://arxiv.org/abs/1712.02190 + + Note: implementation works with single channel tensors and stacks them for VGG. + """ + + def __init__(self, blocks, weights): + """Creates a `TopologyLoss` instance. + + Args: + blocks: list of block indices to use, in `[0, 6]` (e.g. `[0, 1, 2]`) + weights: tensor to tune losses per block (e.g. `[0.2, 0.6, 0.2]`) + + Note: the block indices correspond to a pre-trained VGG's feature maps to use. + """ + + super().__init__() + + assert len(weights.size()) == 1 + assert weights.size(0) == len(blocks) + + self.weights = weights + + assert len(blocks) <= 5 + assert all(i in range(5) for i in blocks) + assert sorted(blocks) == blocks + + features = vgg16_bn(pretrained=True).features + features.eval() + + for param in features.parameters(): + param.requires_grad = False + + relus = [i - 1 for i, m in enumerate(features) if isinstance(m, nn.MaxPool2d)] + + self.hooks = [FeatureHook(features[relus[i]]) for i in blocks] + + # Trim off unused layers to make forward pass more efficient + self.features = features[0 : relus[blocks[-1]] + 1] + + def forward(self, inputs, targets): + # model output to foreground probabilities + inputs = nn.functional.softmax(inputs, dim=1) + # we need to clone the tensor here before slicing otherwise pytorch + # will lose track of information required for gradient computation + inputs = inputs.clone()[:, 1, :, :] + + # masks are longs but vgg wants floats + targets = targets.float() + + # normalize foreground pixels to ImageNet statistics for pre-trained VGG + mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] + inputs = normalize(inputs, mean, std) + targets = normalize(targets, mean, std) + + # N, H, W -> N, C, H, W + inputs = inputs.unsqueeze(1) + targets = targets.unsqueeze(1) + + # repeat channel three times for using a pre-trained three-channel VGG + inputs = inputs.repeat(1, 3, 1, 1) + targets = targets.repeat(1, 3, 1, 1) + + # extract feature maps and compare their weighted loss + + self.features(inputs) + input_features = [hook.features.clone() for hook in self.hooks] + + self.features(targets) + target_features = [hook.features for hook in self.hooks] + + loss = 0.0 + + for lhs, rhs, w in zip(input_features, target_features, self.weights): + lhs = lhs.view(lhs.size(0), -1) + rhs = rhs.view(rhs.size(0), -1) + loss += nn.functional.mse_loss(lhs, rhs) * w + + return loss + + def close(self): + for hook in self.hooks: + hook.close() diff --git a/robosat/tools/train.py b/robosat/tools/train.py index 85245fd8..e4e54ec2 100644 --- a/robosat/tools/train.py +++ b/robosat/tools/train.py @@ -26,7 +26,7 @@ ) from robosat.datasets import SlippyMapTilesConcatenation from robosat.metrics import Metrics -from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d +from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d, CombinedLoss, TopologyLoss from robosat.unet import UNet from robosat.utils import plot from robosat.config import load_config @@ -108,6 +108,14 @@ def map_location(storage, _): else: sys.exit("Error: Unknown [opt][loss] value !") + # use first three vgg feature maps and weight their contribution to loss + topology_weights = torch.tensor([0.2, 0.6, 0.2]).to(device) + topology_loss = TopologyLoss([0, 1, 2], topology_weights).to(device) + + # combine the pixel-wise and the topology loss and weight them + loss_weights = torch.tensor([1.0, 10.0]).to(device) + criterion = CombinedLoss([criterion, topology_loss], loss_weights).to(device) + train_loader, val_loader = get_dataset_loaders(model, dataset, args.workers) num_epochs = model["opt"]["epochs"] @@ -163,6 +171,8 @@ def map_location(storage, _): torch.save(states, os.path.join(model["common"]["checkpoint"], checkpoint)) + topology_loss.close() + def train(loader, num_classes, device, net, optimizer, criterion): num_samples = 0