Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implements topology-aware loss function; resolves #133 #135

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions robosat/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class FeatureHook:
    """Records the most recent forward output of a module.

    A forward hook is registered on `module` at construction time; after every
    forward pass through that module its output is available in `features`.

    The hook can be released either explicitly via `close()` or by using the
    instance as a context manager, which guarantees removal even if the body
    raises:

        with FeatureHook(layer) as hook:
            net(batch)
            activations = hook.features
    """

    def __init__(self, module):
        """Creates a `FeatureHook` instance.

        Args:
          module: the module to register the forward hook on
        """

        # stays None until the hooked module ran its first forward pass
        self.features = None
        self.hook = module.register_forward_hook(self.on)

    def on(self, module, inputs, outputs):
        """Forward hook callback; remembers the module's latest outputs."""

        # rebinds the attribute, previously captured outputs are left untouched
        self.features = outputs

    def close(self):
        """Removes the forward hook from the module."""

        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # returns None on purpose: exceptions from the `with` body propagate
        self.close()
125 changes: 125 additions & 0 deletions robosat/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import torch
import torch.nn as nn

from torchvision.transforms.functional import normalize
from torchvision.models import vgg16_bn

from robosat.hooks import FeatureHook


class CrossEntropyLoss2d(nn.Module):
"""Cross-entropy.
Expand Down Expand Up @@ -117,3 +122,123 @@ def forward(self, inputs, targets):
loss += torch.dot(nn.functional.relu(errors_sorted), iou)

return loss / N


class CombinedLoss(nn.Module):
    """Weighted combination of losses.

    Every criterion is evaluated on the same `(inputs, targets)` pair and the
    results are summed up after scaling each by its weight.
    """

    def __init__(self, criteria, weights):
        """Creates a `CombinedLoss` instance.

        Args:
          criteria: list of criteria (callables taking `inputs, targets`) to combine
          weights: one-dimensional tensor with one weight per criterion
        """

        super().__init__()

        assert len(weights.size()) == 1, "weights must be a one-dimensional tensor"
        assert weights.size(0) == len(criteria), "need exactly one weight per criterion"

        # NOTE(review): both are stored as plain attributes, so calling `.to(device)`
        # on this module does not move them; callers have to place the criteria and
        # the weights on the right device themselves.
        self.criteria = criteria
        self.weights = weights

    def forward(self, inputs, targets):
        """Returns the weighted sum of all criteria for `inputs` and `targets`.

        Note: with an empty list of criteria the result is the Python float `0.0`.
        """

        loss = 0.0

        for criterion, w in zip(self.criteria, self.weights):
            loss += w * criterion(inputs, targets)

        return loss


class TopologyLoss(nn.Module):
    """Topology loss working on a pre-trained model's feature map similarities.

    The foreground probabilities predicted by the model and the ground truth mask
    are both run through the first blocks of a frozen, pre-trained VGG. The loss
    is the weighted sum of the mean squared errors between the feature maps the
    two produce at the selected blocks.

    See:
    - https://arxiv.org/abs/1603.08155
    - https://arxiv.org/abs/1712.02190

    Note: implementation works with single channel tensors and stacks them for VGG.
    """

    # ImageNet channel statistics the pre-trained VGG expects its inputs in
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, blocks, weights):
        """Creates a `TopologyLoss` instance.

        Args:
          blocks: sorted sequence of block indices to use, in `[0, 4]` (e.g. `[0, 1, 2]`)
          weights: tensor to tune losses per block (e.g. `[0.2, 0.6, 0.2]`)

        Note: the block indices correspond to a pre-trained VGG's feature maps to use.
        """

        super().__init__()

        # accept any sequence (e.g. tuples); the checks below compare against lists
        blocks = list(blocks)

        assert len(weights.size()) == 1, "weights must be a one-dimensional tensor"
        assert weights.size(0) == len(blocks), "need exactly one weight per block"

        self.weights = weights

        assert 0 < len(blocks) <= 5, "need between one and five blocks"
        assert all(i in range(5) for i in blocks), "block indices must be in [0, 4]"
        assert sorted(blocks) == blocks, "block indices must be sorted"

        features = vgg16_bn(pretrained=True).features
        features.eval()

        for param in features.parameters():
            param.requires_grad = False

        # index of the layer right in front of every max-pooling layer, one per block
        relus = [i - 1 for i, m in enumerate(features) if isinstance(m, nn.MaxPool2d)]

        self.hooks = [FeatureHook(features[relus[i]]) for i in blocks]

        # Trim off unused layers to make forward pass more efficient
        self.features = features[0 : relus[blocks[-1]] + 1]

    def train(self, mode=True):
        """Sets the training mode but keeps the frozen feature extractor in eval mode.

        Without this, calling `.train()` on the loss (or on a parent module) would
        also put the pre-trained network's layers into training mode.
        """

        super().train(mode)
        self.features.eval()

        return self

    def forward(self, inputs, targets):
        """Returns the weighted feature map loss.

        Args:
          inputs: raw model outputs of shape N, C, H, W; channel 1 is the foreground
          targets: ground truth masks of shape N, H, W
        """

        # model output to foreground probabilities: N, C, H, W -> N, H, W
        probs = nn.functional.softmax(inputs, dim=1)[:, 1, :, :]

        # masks are longs but vgg wants floats
        masks = targets.float()

        # Normalize to ImageNet statistics for the pre-trained VGG. The statistics
        # are per channel, so they are applied after adding the channel dimension:
        # broadcasting N, 1, H, W against 1, 3, 1, 1 gives the three-channel
        # N, 3, H, W input. All ops are out-of-place to keep autograd intact.
        mean = probs.new_tensor(self._IMAGENET_MEAN).view(1, 3, 1, 1)
        std = probs.new_tensor(self._IMAGENET_STD).view(1, 3, 1, 1)

        probs = (probs.unsqueeze(1) - mean) / std
        masks = (masks.unsqueeze(1) - mean) / std

        # extract feature maps and compare their weighted loss

        # the hooks rebind their `features` attribute on every forward pass, so the
        # references collected here are not replaced by the second pass below
        self.features(probs)
        input_features = [hook.features for hook in self.hooks]

        # the ground truth is a constant, no need to track gradients for it
        with torch.no_grad():
            self.features(masks)

        target_features = [hook.features for hook in self.hooks]

        loss = 0.0

        for lhs, rhs, w in zip(input_features, target_features, self.weights):
            loss += nn.functional.mse_loss(lhs, rhs) * w

        return loss

    def close(self):
        """Removes all forward hooks from the feature extractor."""

        for hook in self.hooks:
            hook.close()
12 changes: 11 additions & 1 deletion robosat/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from robosat.datasets import SlippyMapTilesConcatenation
from robosat.metrics import Metrics
from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d
from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d, CombinedLoss, TopologyLoss
from robosat.unet import UNet
from robosat.utils import plot
from robosat.config import load_config
Expand Down Expand Up @@ -108,6 +108,14 @@ def map_location(storage, _):
else:
sys.exit("Error: Unknown [opt][loss] value !")

# use first three vgg feature maps and weight their contribution to loss
topology_weights = torch.tensor([0.2, 0.6, 0.2]).to(device)
topology_loss = TopologyLoss([0, 1, 2], topology_weights).to(device)

# combine the pixel-wise and the topology loss and weight them
loss_weights = torch.tensor([1.0, 10.0]).to(device)
criterion = CombinedLoss([criterion, topology_loss], loss_weights).to(device)

train_loader, val_loader = get_dataset_loaders(model, dataset, args.workers)

num_epochs = model["opt"]["epochs"]
Expand Down Expand Up @@ -163,6 +171,8 @@ def map_location(storage, _):

torch.save(states, os.path.join(model["common"]["checkpoint"], checkpoint))

topology_loss.close()


def train(loader, num_classes, device, net, optimizer, criterion):
num_samples = 0
Expand Down