Implements topology-aware loss function; resolves #133
daniel-j-h committed Oct 24, 2018
1 parent b7f6ebf commit 6c282db
Showing 3 changed files with 146 additions and 1 deletion.
10 changes: 10 additions & 0 deletions robosat/hooks.py
@@ -0,0 +1,10 @@
class FeatureHook:
    """Captures a module's output feature map via a forward hook."""

    def __init__(self, module):
        self.features = None
        self.hook = module.register_forward_hook(self.on)

    def on(self, module, inputs, outputs):
        self.features = outputs

    def close(self):
        self.hook.remove()
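
For context, a minimal usage sketch, not part of this commit (the module and shapes below are made up):

import torch
import torch.nn as nn

from robosat.hooks import FeatureHook

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
hook = FeatureHook(conv)

conv(torch.rand(1, 3, 16, 16))  # any forward pass triggers `on`
print(hook.features.shape)      # torch.Size([1, 8, 16, 16])

hook.close()                    # removes the forward hook again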
125 changes: 125 additions & 0 deletions robosat/losses.py
@@ -4,6 +4,11 @@
import torch
import torch.nn as nn

from torchvision.transforms.functional import normalize
from torchvision.models import vgg16_bn

from robosat.hooks import FeatureHook


class CrossEntropyLoss2d(nn.Module):
"""Cross-entropy.
Expand Down Expand Up @@ -117,3 +122,123 @@ def forward(self, inputs, targets):
loss += torch.dot(nn.functional.relu(errors_sorted), iou)

return loss / N


class CombinedLoss(nn.Module):
    """Weighted combination of losses.
    """

    def __init__(self, criteria, weights):
        """Creates a `CombinedLoss` instance.

        Args:
          criteria: list of criteria to combine
          weights: tensor to tune losses with
        """

        super().__init__()

        assert len(weights.size()) == 1
        assert weights.size(0) == len(criteria)

        self.criteria = criteria
        self.weights = weights

    def forward(self, inputs, targets):
        loss = 0.0

        for criterion, w in zip(self.criteria, self.weights):
            loss += w * criterion(inputs, targets)

        return loss
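
A usage sketch, not part of this commit (criteria, weights, and shapes below are arbitrary):

import torch

from robosat.losses import CrossEntropyLoss2d, LovaszLoss2d, CombinedLoss

weights = torch.tensor([1.0, 10.0])
criterion = CombinedLoss([CrossEntropyLoss2d(), LovaszLoss2d()], weights)

logits = torch.randn(4, 2, 32, 32)        # N, C, H, W model outputs
masks = torch.randint(0, 2, (4, 32, 32))  # N, H, W class indices
loss = criterion(logits, masks)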


class TopologyLoss(nn.Module):
    """Topology loss working on a pre-trained model's feature map similarities.

    See:
    - https://arxiv.org/abs/1603.08155
    - https://arxiv.org/abs/1712.02190

    Note: implementation works with single channel tensors and stacks them for VGG.
    """

    def __init__(self, blocks, weights):
        """Creates a `TopologyLoss` instance.

        Args:
          blocks: list of block indices to use, in `[0, 4]` (e.g. `[0, 1, 2]`)
          weights: tensor to tune losses per block (e.g. `[0.2, 0.6, 0.2]`)

        Note: the block indices correspond to a pre-trained VGG's feature maps to use.
        """

        super().__init__()

        assert len(weights.size()) == 1
        assert weights.size(0) == len(blocks)

        self.weights = weights

        assert len(blocks) <= 5
        assert all(i in range(5) for i in blocks)
        assert sorted(blocks) == blocks

        features = vgg16_bn(pretrained=True).features
        features.eval()

        for param in features.parameters():
            param.requires_grad = False

        # indices of the last relu in each of vgg's five conv blocks,
        # i.e. the layer right before each max pooling layer
        relus = [i - 1 for i, m in enumerate(features) if isinstance(m, nn.MaxPool2d)]

        self.hooks = [FeatureHook(features[relus[i]]) for i in blocks]

        # trim off unused layers to make forward pass more efficient
        self.features = features[0 : relus[blocks[-1]] + 1]

    def forward(self, inputs, targets):
        # model output to foreground probabilities
        inputs = nn.functional.softmax(inputs, dim=1)
        # we need to clone the tensor here before slicing otherwise pytorch
        # will lose track of information required for gradient computation
        inputs = inputs.clone()[:, 1, :, :]

        # masks are longs but vgg wants floats
        targets = targets.float()

        # N, H, W -> N, C, H, W
        inputs = inputs.unsqueeze(1)
        targets = targets.unsqueeze(1)

        # repeat channel three times for using a pre-trained three-channel VGG
        inputs = inputs.repeat(1, 3, 1, 1)
        targets = targets.repeat(1, 3, 1, 1)

        # normalize to ImageNet statistics for the pre-trained VGG; this has
        # to happen after expanding to three channels so the per-channel mean
        # and std line up with the channel dimension
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        inputs = normalize(inputs, mean, std)
        targets = normalize(targets, mean, std)

        # run both through vgg and compare the hooked feature maps

        self.features(inputs)
        # clone: the hooks get overwritten by the second forward pass below
        input_features = [hook.features.clone() for hook in self.hooks]

        self.features(targets)
        target_features = [hook.features for hook in self.hooks]

        loss = 0.0

        for lhs, rhs, w in zip(input_features, target_features, self.weights):
            lhs = lhs.view(lhs.size(0), -1)
            rhs = rhs.view(rhs.size(0), -1)
            loss += nn.functional.l1_loss(lhs, rhs) * w

        return loss

    def close(self):
        for hook in self.hooks:
            hook.close()
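
Per block, the loss is the weighted l1 distance between the hooked VGG feature maps for the predicted foreground probabilities and for the binary masks. A usage sketch, not part of this commit (shapes below are arbitrary; the pre-trained VGG weights get downloaded on first use):

import torch

from robosat.losses import TopologyLoss

weights = torch.tensor([0.2, 0.6, 0.2])
topology = TopologyLoss([0, 1, 2], weights)

logits = torch.randn(4, 2, 64, 64)        # N, C, H, W model outputs
masks = torch.randint(0, 2, (4, 64, 64))  # N, H, W binary masks
loss = topology(logits, masks)

topology.close()  # removes the forward hooks again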
12 changes: 11 additions & 1 deletion robosat/tools/train.py
@@ -26,7 +26,7 @@
)
from robosat.datasets import SlippyMapTilesConcatenation
from robosat.metrics import Metrics
-from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d
+from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d, CombinedLoss, TopologyLoss
from robosat.unet import UNet
from robosat.utils import plot
from robosat.config import load_config
@@ -108,6 +108,14 @@ def map_location(storage, _):
    else:
        sys.exit("Error: Unknown [opt][loss] value !")

    # use first three vgg feature maps and weight their contribution to loss
    topology_weights = torch.tensor([0.2, 0.6, 0.2]).to(device)
    topology_loss = TopologyLoss([0, 1, 2], topology_weights).to(device)

    # combine the pixel-wise and the topology loss and weight them
    loss_weights = torch.tensor([1.0, 10.0]).to(device)
    criterion = CombinedLoss([criterion, topology_loss], loss_weights).to(device)

    train_loader, val_loader = get_dataset_loaders(model, dataset, args.workers)

    num_epochs = model["opt"]["epochs"]
@@ -163,6 +171,8 @@ def map_location(storage, _):

        torch.save(states, os.path.join(model["common"]["checkpoint"], checkpoint))

    topology_loss.close()


def train(loader, num_classes, device, net, optimizer, criterion):
    num_samples = 0
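
For reference, a sketch of how the combined criterion is consumed inside the train loop; the loop body below is not part of this diff and assumes the loader yields image and mask batches:

for images, masks in loader:
    images, masks = images.to(device), masks.to(device)

    optimizer.zero_grad()

    outputs = net(images)
    loss = criterion(outputs, masks)

    loss.backward()
    optimizer.step()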
