forked from vlfeat/matconvnet-fcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSegmentationLoss.m
28 lines (25 loc) · 952 Bytes
/
SegmentationLoss.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
classdef SegmentationLoss < dagnn.Loss
  %SEGMENTATIONLOSS  DagNN loss layer with per-image label-count weighting.
  %   Every image in the batch is given the instance weight
  %       1 / (1 + number of label entries > 0 in that image),
  %   where the count is summed over the first two dimensions of the
  %   label array. The "+ 1" keeps the weight finite for an image with
  %   no positive labels. The loss type comes from obj.loss and is
  %   forwarded, together with the weights, to vl_nnloss.
  %
  %   NOTE(review): presumably labels are H x W x 1 x N with 0 marking
  %   pixels to ignore -- confirm against the data loader.

  methods
    function obj = SegmentationLoss(varargin)
      %SEGMENTATIONLOSS  Build the layer; name/value options are handed
      %   to the inherited obj.load.
      obj.load(varargin) ;
    end

    function outputs = forward(obj, inputs, params)
      %FORWARD  Weighted loss of inputs{1} (predictions) vs inputs{2} (labels).
      %   Returns a 1x1 cell holding the vl_nnloss result and folds that
      %   value into the running obj.average, counting size(inputs{1}, 4)
      %   new samples. params is unused.
      weights = segmentationInstanceWeights(inputs{2}) ;
      lossValue = vl_nnloss(inputs{1}, inputs{2}, [], ...
                            'loss', obj.loss, ...
                            'instanceWeights', weights) ;
      outputs = {lossValue} ;

      % Running average: seen * average recovers the previous total, this
      % batch's loss is added, and the sum is divided by the new count.
      seen = obj.numAveraged ;
      total = seen + size(inputs{1}, 4) ;
      obj.average = (seen * obj.average + double(gather(lossValue))) / total ;
      obj.numAveraged = total ;
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
      %BACKWARD  Backward pass using the same per-image weights as FORWARD.
      %   vl_nnloss is called with derOutputs{1} as its third argument to
      %   produce the derivative for inputs{1}. The labels (inputs{2})
      %   receive an empty derivative, and the layer has no parameters,
      %   so derParams is empty. params is unused.
      weights = segmentationInstanceWeights(inputs{2}) ;
      dPredictions = vl_nnloss(inputs{1}, inputs{2}, derOutputs{1}, ...
                               'loss', obj.loss, ...
                               'instanceWeights', weights) ;
      derInputs = {dPredictions, []} ;
      derParams = {} ;
    end
  end
end

function weights = segmentationInstanceWeights(labels)
%SEGMENTATIONINSTANCEWEIGHTS  1 ./ (1 + count of labels > 0 over dims 1 and 2).
%   Shared by SegmentationLoss.forward and .backward so both passes use
%   identical weights. The result keeps every dimension of labels beyond
%   the second (e.g. 1x1x1xN for a 4-D label array).
positives = sum(sum(labels > 0, 2), 1) ;
weights = 1 ./ (positives + 1) ;
end