From bbbfb2ea8bc7981ded043a5da493bb0ce5258d27 Mon Sep 17 00:00:00 2001
From: William Song
Date: Mon, 1 Dec 2014 22:35:04 -0800
Subject: [PATCH] L1 layer and tests, some python eval stuff

---
 include/caffe/loss_layers.hpp         | 103 ++++++++++++++++++++-
 models/brody/solver.prototxt          |  10 +-
 models/brody/train_val_brody.prototxt | 124 +++++++++++++++++++-----
 python/evaluate_result.py             |  16 +---
 python/store_results.py               |  80 +++++++++++++++++
 src/caffe/layer_factory.cpp           |   2 +
 src/caffe/layers/l1_loss_layer.cpp    |  61 +++++++++++++
 src/caffe/layers/l1_loss_layer.cu     |  45 ++++++++++
 src/caffe/proto/caffe.proto           |   3 +-
 src/caffe/test/test_l1_loss_layer.cpp |  91 +++++++++++++++++++
 src/caffe/util/upgrade_proto.cpp      |   2 +
 11 files changed, 491 insertions(+), 46 deletions(-)
 create mode 100644 python/store_results.py
 create mode 100644 src/caffe/layers/l1_loss_layer.cpp
 create mode 100644 src/caffe/layers/l1_loss_layer.cu
 create mode 100644 src/caffe/test/test_l1_loss_layer.cpp

diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index 08aa7752d4a..991c35a59a7 100644
--- a/include/caffe/loss_layers.hpp
+++ b/include/caffe/loss_layers.hpp
@@ -174,8 +174,8 @@ class ContrastiveLossLayer : public LossLayer<Dtype> {
 
   /**
    * @brief Computes the Contrastive error gradient w.r.t. the inputs.
-   * 
-   * Computes the gradients with respect to the two input vectors (bottom[0] and 
+   *
+   * Computes the gradients with respect to the two input vectors (bottom[0] and
    * bottom[1]), but not the similarity label (bottom[2]).
    *
    * @param top output Blob vector (length 1), providing the error gradient with
@@ -194,7 +194,7 @@
    *      the features @f$a@f$; Backward fills their diff with
    *      gradients if propagate_down[0]
    *   -# @f$ (N \times C \times 1 \times 1) @f$
-   *      the features @f$b@f$; Backward fills their diff with gradients if 
+   *      the features @f$b@f$; Backward fills their diff with gradients if
    *      propagate_down[1]
    */
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
@@ -763,6 +763,103 @@ class SoftmaxWithLossLayer : public LossLayer<Dtype> {
   vector<Blob<Dtype>*> softmax_top_vec_;
 };
 
+/**
+ * @brief Computes the L1 loss @f$
+ *          E = \frac{1}{N} \sum\limits_{n=1}^N \left| \hat{y}_n - y_n
+ *          \right| @f$ for real-valued regression tasks.
+ *
+ * @param bottom input Blob vector (length 2)
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the predictions @f$ \hat{y} \in [-\infty, +\infty]@f$
+ *   -# @f$ (N \times C \times H \times W) @f$
+ *      the targets @f$ y \in [-\infty, +\infty]@f$
+ * @param top output Blob vector (length 1)
+ *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+ *      the computed L1 loss: @f$ E =
+ *          \frac{1}{N} \sum\limits_{n=1}^N \left| \hat{y}_n - y_n
+ *          \right| @f$
+ *
+ * This can be used for least absolute deviations regression tasks. An
+ * InnerProductLayer input to an L1LossLayer exactly formulates a linear
+ * least absolute deviations regression problem, which is more robust to
+ * outliers than the linear least squares problem formulated with
+ * EuclideanLossLayer.
+ *
+ * (Note: Caffe, and SGD in general, is certainly \b not the best way to solve
+ * such regression problems! We use it only as an instructive example.)
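+ *
+ * As a small worked example of the forward computation (illustrative only):
+ * with @f$ N = 2 @f$, predictions @f$ \hat{y} = (1.0, -0.5) @f$ and targets
+ * @f$ y = (0.0, 0.5) @f$, the loss is
+ * @f$ E = (|1.0 - 0.0| + |-0.5 - 0.5|) / 2 = 1 @f$.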
+ */
+template <typename Dtype>
+class L1LossLayer : public LossLayer<Dtype> {
+ public:
+  explicit L1LossLayer(const LayerParameter& param)
+      : LossLayer<Dtype>(param), diff_() {}
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual inline LayerParameter_LayerType type() const {
+    return LayerParameter_LayerType_L1_LOSS;
+  }
+
+  /**
+   * Unlike most loss layers, in the L1LossLayer we can backpropagate
+   * to both inputs -- override to return true and always allow force_backward.
+   */
+  virtual inline bool AllowForceBackward(const int bottom_index) const {
+    return true;
+  }
+
+ protected:
+  /// @copydoc L1LossLayer
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  /**
+   * @brief Computes the L1 error gradient w.r.t. the inputs.
+   *
+   * Unlike other children of LossLayer, L1LossLayer \b can compute
+   * gradients with respect to the label inputs bottom[1] (but still only
+   * will if propagate_down[1] is set, e.g. because the labels are produced
+   * by learnable parameters, or if force_backward is set). In fact, this
+   * layer is "commutative" -- the result is the same regardless of the order
+   * of the two bottoms.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (1 \times 1 \times 1 \times 1) @f$
+   *      This Blob's diff will simply contain the loss_weight* @f$ \lambda @f$,
+   *      as @f$ \lambda @f$ is the coefficient of this layer's output
+   *      @f$\ell_i@f$ in the overall Net loss
+   *      @f$ E = \lambda_i \ell_i + \mbox{other loss terms}@f$; hence
+   *      @f$ \frac{\partial E}{\partial \ell_i} = \lambda_i @f$.
+   *      (*Assuming that this top Blob is not used as a bottom (input) by any
+   *      other layer of the Net.)
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 2)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the predictions @f$\hat{y}@f$; Backward fills their diff with
+   *      gradients @f$
+   *        \frac{\partial E}{\partial \hat{y}_n} =
+   *            \frac{1}{N} \, \mathrm{sign}(\hat{y}_n - y_n)
+   *      @f$ if propagate_down[0]
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the targets @f$y@f$; Backward fills their diff with gradients
+   *      @f$ \frac{\partial E}{\partial y_n} =
+   *            \frac{1}{N} \, \mathrm{sign}(y_n - \hat{y}_n)
+   *      @f$ if propagate_down[1]
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  Blob<Dtype> diff_;
+  Blob<Dtype> sign_;
+};
+
+
 }  // namespace caffe
 
 #endif  // CAFFE_LOSS_LAYERS_HPP_
diff --git a/models/brody/solver.prototxt b/models/brody/solver.prototxt
index 25793c049b7..7f8fe6bd192 100644
--- a/models/brody/solver.prototxt
+++ b/models/brody/solver.prototxt
@@ -1,14 +1,14 @@
 net: "models/brody/train_val_brody.prototxt"
-test_iter: 1000
-test_interval: 1000
-base_lr: 0.00000001
+test_iter: 20
+test_interval: 5000
+base_lr: 0.0000001
 lr_policy: "step"
 gamma: 0.1
 stepsize: 100000
 display: 20
 max_iter: 1450000
-momentum: 0.2
+momentum: 0.9
 weight_decay: 0.00005
-snapshot: 2000
+snapshot: 10000
 snapshot_prefix: "models/brody/caffe_brody_train"
 solver_mode: GPU
diff --git a/models/brody/train_val_brody.prototxt b/models/brody/train_val_brody.prototxt
index 42ce2488e1c..51ba17ef565 100644
--- a/models/brody/train_val_brody.prototxt
+++ b/models/brody/train_val_brody.prototxt
@@ -8,7 +8,7 @@ layers {
   data_param {
     source: "driving_img_train"
     backend: LMDB
-    batch_size: 25
+    batch_size: 5
   }
   transform_param {
     mean_file: "driving_img_mean.binaryproto"
@@ -24,7 +24,7 @@ layers {
   data_param {
     source: "driving_label_train"
     backend: LMDB
-    batch_size: 25
+    batch_size: 5
   }
   include: { phase: TRAIN }
 }
@@ -37,7 +37,7 @@ layers {
   data_param {
     source: "driving_img_test"
     backend: LMDB
-    batch_size: 25
+    batch_size: 5
   }
   transform_param {
     mean_file: "driving_img_mean.binaryproto"
@@ -53,7 +53,7 @@ layers {
   data_param {
     source: "driving_label_test"
     backend: LMDB
-    batch_size: 25
+    batch_size: 5
   }
   include: { phase: TEST }
 }
@@ -65,9 +65,13 @@ layers {
   bottom: "label"
   top: "pixel-label"
   top: "bb-label"
+  top: "height-label"
+  top: "norm-label"
   slice_param {
     slice_dim: 1
     slice_point: 16
+    slice_point: 80
+    slice_point: 96
   }
 }
 
@@ -86,6 +90,19 @@ layers {
   }
 }
 
+layers {
+  name: "height-block"
+  type: CONCAT
+  bottom: "height-label"
+  bottom: "height-label"
+  bottom: "height-label"
+  bottom: "height-label"
+  top: "height-block"
+  concat_param {
+    concat_dim: 1
+  }
+}
+
 layers {
   name: "conv1"
   type: CONVOLUTION
@@ -105,7 +122,7 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0
+      value: 0.1
     }
   }
 }
@@ -150,14 +167,14 @@ layers {
     num_output: 256
     pad: 2
     kernel_size: 5
-    group: 1
+    group: 2
     weight_filler {
       type: "gaussian"
      std: 0.01
     }
     bias_filler {
       type: "constant"
-      value: 1
+      value: 0.1
     }
   }
 }
@@ -208,7 +225,7 @@ layers {
     }
     bias_filler {
       type: "constant"
-      value: 0
+      value: 0.1
     }
   }
 }
@@ -231,14 +248,14 @@ layers {
     num_output: 384
     pad: 1
     kernel_size: 3
-    group: 1
+    group: 2
     weight_filler {
       type: "gaussian"
       std: 0.01
     }
     bias_filler {
       type: "constant"
-      value: 1
+      value: 0.1
     }
   }
 }
@@ -261,14 +278,14 @@ layers {
     num_output: 256
     pad: 1
     kernel_size: 3
-    group: 1
+    group: 2
     weight_filler {
       type: "gaussian"
       std: 0.01
     }
     bias_filler {
       type: "constant"
-      value: 1
+      value: 0.1
     }
   }
 }
@@ -304,7 +321,7 @@ layers {
     pad: 3
     weight_filler {
       type: "gaussian"
-      std: 0.005
+      std: 0.01
     }
     bias_filler {
       type: "constant"
@@ -324,7 +341,7 @@ layers {
   bottom: "fc6-conv"
   top: "fc6-conv"
   dropout_param {
-    dropout_ratio: 0.5
+    dropout_ratio: 0.0
   }
 }
 layers {
@@ -341,7 +358,7 @@ layers {
     kernel_size: 1
     weight_filler {
       type: "gaussian"
-      std: 0.005
+      std: 0.01
     }
     bias_filler {
       type: "constant"
@@ -361,7 +378,7 @@ layers {
   bottom: "fc7-conv"
   top: "fc7-conv"
   dropout_param {
-    dropout_ratio: 0.5
+    dropout_ratio: 0.0
  }
 }
 
@@ -370,16 +387,16 @@ layers {
   type: CONVOLUTION
   bottom: "fc7-conv"
   top: "bb-output"
-  blobs_lr: 0.001
-  blobs_lr: 2
-  weight_decay: 1
+  blobs_lr: 100
+  blobs_lr: 200
+  weight_decay: 0.00001
   weight_decay: 0
   convolution_param {
     num_output: 64
     kernel_size: 1
     weight_filler {
       type: "gaussian"
-      std: 0.005
+      std: 0.01
     }
     bias_filler {
       type: "constant"
@@ -440,12 +457,73 @@ layers {
   }
 }
 
-# Squared loss on the bounding boxes.
 layers {
   name: "bb-loss"
-  type: EUCLIDEAN_LOSS
+  type: L1_LOSS
   bottom: "bb-masked-output"
   bottom: "bb-label"
   top: "bb-loss"
-  loss_weight: 0.1
+  loss_weight: 1
 }
+
+# Squared loss on the bounding boxes.
+#layers {
+#  name: "bb-loss"
+#  type: EUCLIDEAN_LOSS
+#  bottom: "bb-masked-output"
+#  bottom: "bb-label"
+#  top: "bb-loss"
+#  loss_weight: 0.01
+#}
+
+# L1 error loss
+#layers {
+#  name: "bb-diff"
+#  type: ELTWISE
+#  bottom: "bb-masked-output"
+#  bottom: "bb-label"
+#  eltwise_param {
+#    operation: SUM
+#    coeff: 1.0
+#    coeff: -1.0
+#  }
+#  top: "bb-diff"
+#}
+
+#layers {
+#  name: "bb-loss"
+#  type: ABSVAL
+#  bottom: "bb-diff"
+#  top: "bb-loss"
+#  # 1 / (20 * 15 * 64)
+#  loss_weight: 0.00000000001
+#}
+
+#layers {
+#  name: "bb-loss-pow2"
+#  type: POWER
+#  bottom: "bb-diff"
+#  top: "bb-loss-pow2"
+#  # 1 / (20 * 15 * 64)
+#  power_param {
+#    power: 2
+#  }
+#}
+
+#layers {
+#  name: "bb-loss-height-normalize"
+#  type: ELTWISE
+#  bottom: "bb-loss-pow2"
+#  bottom: "height-block"
+#  eltwise_param {
+#    operation: PROD
+#  }
+#  top: "bb-loss"
+#  loss_weight: 0.1
+#}
+
+#layers {
+#  name: "bb-loss-silence"
+#  type: SILENCE
+#  bottom: "bb-loss"
+#}
diff --git a/python/evaluate_result.py b/python/evaluate_result.py
index a97bfcaf999..c691ae836a4 100644
--- a/python/evaluate_result.py
+++ b/python/evaluate_result.py
@@ -37,7 +37,6 @@ def main():
         bbs = tokens[2:]
         gt_bbs = get_gt_bbs(bbs)
 
-        img_name = fname.split('/')[-1]
         # print img_name, '...',
         start = time.time()
         scores = net.ff([caffe.io.load_image(fname)])
@@ -47,19 +46,8 @@ def main():
         rects = get_rects(net.blobs['bb-output'].data[4], mask)
 
         if args.dump_images:
-            assert output_path != ''
-            image = net.deprocess('data', net.blobs['data'].data[4])
-            zoomed_mask = np.empty((480, 640))
-            zoomed_mask = scipy.ndimage.zoom(mask, 8, order=0)
-            masked_image = image.transpose((2, 0, 1))
-            masked_image[0, :, :] += zoomed_mask
-            masked_image = np.clip(masked_image, 0, 1)
-            masked_image = masked_image.transpose((1, 2, 0))
-            boxed_image = np.copy(masked_image)
-            if len(rects) > 0:
-                boxed_image = draw_rects(boxed_image, rects)
-            Image.fromarray(
-                (boxed_image * 255).astype('uint8')).save(args.output_path + '/' + img_name)
+            img_name = fname.split('/')[-1]
+            dump_image(net, mask, rects, args.output_path + img_name)
 
         used_rect = set()
         for bb in gt_bbs:
diff --git a/python/store_results.py b/python/store_results.py
new file mode 100644
index 00000000000..8e685b6eed9
--- /dev/null
+++ b/python/store_results.py
@@ -0,0 +1,80 @@
+import numpy as np
+import scipy
+import matplotlib.pyplot as plt
+import caffe
+import sys
+import Image
+import time
+import cv2
+import argparse
+import pickle
+import os
+
+from driving_utils import *
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gt_label', required=True)
+    parser.add_argument('--output_prefix', required=True)
+    parser.add_argument('--model_prefix', required=True)
+    parser.add_argument('--deploy', required=True)
+    parser.add_argument('--iter', required=True)
+    parser.add_argument('--dump_images', action='store_true')
+    parser.add_argument('--image_output_path')
+    parser.add_argument('--interval', type=int, default=10000)
+    parser.add_argument('--sampling_increment', type=int, default=1)
+    args = parser.parse_args()
+
+    if args.dump_images:
+        assert args.image_output_path is not None
+
+    if '-' not in args.iter:
+        iters = [args.iter]
+    else:
+        begin, end = args.iter.split('-')
+        iters = [str(i) for i in range(int(begin), int(end), args.interval)]
+
+    print 'Generating results for iterations', ' '.join(iters)
+
+    for iter in iters:
+        print '#### Begin to generate detection results for iteration', iter
+        net = caffe.Classifier(args.deploy, args.model_prefix + '_iter_' + iter + '.caffemodel')
+#        net = caffe.Classifier('/deep/u/willsong/caffe/models/brody/deploy.prototxt',
+#                               '/deep/u/willsong/caffe/models/brody/caffe_brody_train_iter_200000.caffemodel')
+        net.set_phase_test()
+        net.set_mode_gpu()
+        net.set_mean('data', np.load('/deep/u/willsong/caffe/python/driving_mean.npy'))  # driving dataset mean
+        net.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
+        net.set_channel_swap('data', (2, 1, 0))  # the reference model has channels in BGR order instead of RGB
+
+        all_rects = []
+        if args.dump_images:
+            image_folder_name = args.image_output_path + '_' + iter + '/'
+            if not os.path.exists(image_folder_name):
+                os.makedirs(image_folder_name)
+        inc = 0
+        for i, line in enumerate(open(args.gt_label).readlines()):
+            inc += 1
+            if inc < args.sampling_increment:
+                continue
+            else:
+                inc = 0
+            tokens = line.split()
+            fname = tokens[0]
+
+            # print img_name, '...',
+            # start = time.time()
+            scores = net.ff([caffe.io.load_image(fname)])
+            # print 'done ff, took %f seconds' % (time.time() - start)
+
+            mask = get_mask(net.blobs['pixel-prob'].data[4])
+            rects = get_rects(net.blobs['bb-output'].data[4], mask)
+            all_rects.append(rects)
+
+            if args.dump_images:
+                img_name = fname.split('/')[-1]
+                dump_image(net, mask, rects, image_folder_name + img_name)
+
+        pickle.dump(all_rects, open(args.output_prefix + '_' + iter + '.pkl', 'w'))
+
+if __name__ == '__main__':
+    main()
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index a20f2c6e636..0d78a8fb148 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -254,6 +254,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new VideoDataLayer<Dtype>(param);
   case LayerParameter_LayerType_WINDOW_DATA:
     return new WindowDataLayer<Dtype>(param);
+  case LayerParameter_LayerType_L1_LOSS:
+    return new L1LossLayer<Dtype>(param);
   case LayerParameter_LayerType_NONE:
     LOG(FATAL) << "Layer " << name << " has unspecified type.";
   default:
diff --git a/src/caffe/layers/l1_loss_layer.cpp b/src/caffe/layers/l1_loss_layer.cpp
new file mode 100644
index 00000000000..702a870c02d
--- /dev/null
+++ b/src/caffe/layers/l1_loss_layer.cpp
@@ -0,0 +1,61 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void L1LossLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  LossLayer<Dtype>::Reshape(bottom, top);
+  CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
+  CHECK_EQ(bottom[0]->height(), bottom[1]->height());
+  CHECK_EQ(bottom[0]->width(), bottom[1]->width());
+  diff_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  sign_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
+void L1LossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  int count = bottom[0]->count();
+  caffe_sub(
+      count,
+      bottom[0]->cpu_data(),
+      bottom[1]->cpu_data(),
+      diff_.mutable_cpu_data());
+  caffe_cpu_sign(count, diff_.cpu_data(), sign_.mutable_cpu_data());
+  Dtype abs_sum = caffe_cpu_asum(count, diff_.cpu_data());
+  Dtype loss = abs_sum / bottom[0]->num();
+  (*top)[0]->mutable_cpu_data()[0] = loss;
+}
+
+template <typename Dtype>
+void L1LossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  for (int i = 0; i < 2; ++i) {
+    if (propagate_down[i]) {
+      const Dtype sign = (i == 0) ? 1 : -1;
+      const Dtype alpha = sign * top[0]->cpu_diff()[0] / (*bottom)[i]->num();
+      caffe_cpu_axpby(
+          (*bottom)[i]->count(),              // count
+          alpha,                              // alpha
+          sign_.cpu_data(),                   // a
+          Dtype(0),                           // beta
+          (*bottom)[i]->mutable_cpu_diff());  // b
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(L1LossLayer);
+#endif
+
+INSTANTIATE_CLASS(L1LossLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/l1_loss_layer.cu b/src/caffe/layers/l1_loss_layer.cu
new file mode 100644
index 00000000000..bdaf4035c88
--- /dev/null
+++ b/src/caffe/layers/l1_loss_layer.cu
@@ -0,0 +1,45 @@
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/util/io.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void L1LossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  int count = bottom[0]->count();
+  caffe_gpu_sub(
+      count,
+      bottom[0]->gpu_data(),
+      bottom[1]->gpu_data(),
+      diff_.mutable_gpu_data());
+  Dtype abs_sum;
+  caffe_gpu_asum(count, diff_.gpu_data(), &abs_sum);
+  caffe_gpu_sign(count, diff_.gpu_data(), sign_.mutable_gpu_data());
+  Dtype loss = abs_sum / bottom[0]->num();
+  (*top)[0]->mutable_cpu_data()[0] = loss;
+}
+
+template <typename Dtype>
+void L1LossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  for (int i = 0; i < 2; ++i) {
+    if (propagate_down[i]) {
+      const Dtype sign = (i == 0) ? 1 : -1;
+      const Dtype alpha = sign * top[0]->cpu_diff()[0] / (*bottom)[i]->num();
+      caffe_gpu_axpby(
+          (*bottom)[i]->count(),              // count
+          alpha,                              // alpha
+          sign_.gpu_data(),                   // a
+          Dtype(0),                           // beta
+          (*bottom)[i]->mutable_gpu_diff());  // b
+    }
+  }
+}
+
+INSTANTIATE_CLASS(L1LossLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 3bbf33f70c0..9e93c1ee1ee 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -219,7 +219,7 @@ message LayerParameter {
   // line above the enum. Update the next available ID when you add a new
   // LayerType.
   //
-  // LayerType next available ID: 40 (last added: MULTILANE_LABEL)
+  // LayerType next available ID: 41 (last added: L1_LOSS)
   enum LayerType {
     // "NONE" layer type is 0th enum element so that we don't cause confusion
     // by defaulting to an existent LayerType (instead, should usually error if
@@ -236,6 +236,7 @@ message LayerParameter {
     DROPOUT = 6;
     DUMMY_DATA = 32;
     EUCLIDEAN_LOSS = 7;
+    L1_LOSS = 40;
     ELTWISE = 25;
     FLATTEN = 8;
     HDF5_DATA = 9;
diff --git a/src/caffe/test/test_l1_loss_layer.cpp b/src/caffe/test/test_l1_loss_layer.cpp
new file mode 100644
index 00000000000..d4e64a79213
--- /dev/null
+++ b/src/caffe/test/test_l1_loss_layer.cpp
@@ -0,0 +1,91 @@
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class L1LossLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  L1LossLayerTest()
+      : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
+        blob_bottom_label_(new Blob<Dtype>(10, 5, 1, 1)),
+        blob_top_loss_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_data_);
+    blob_bottom_vec_.push_back(blob_bottom_data_);
+    filler.Fill(this->blob_bottom_label_);
+    blob_bottom_vec_.push_back(blob_bottom_label_);
+    blob_top_vec_.push_back(blob_top_loss_);
+  }
+  virtual ~L1LossLayerTest() {
+    delete blob_bottom_data_;
+    delete blob_bottom_label_;
+    delete blob_top_loss_;
+  }
+
+  void TestForward() {
+    // Get the loss without a specified objective weight -- should be
+    // equivalent to explicitly specifying a weight of 1.
+    LayerParameter layer_param;
+    L1LossLayer<Dtype> layer_weight_1(layer_param);
+    layer_weight_1.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+    const Dtype loss_weight_1 =
+        layer_weight_1.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+
+    // Get the loss again with a different objective weight; check that it is
+    // scaled appropriately.
+    const Dtype kLossWeight = 3.7;
+    layer_param.add_loss_weight(kLossWeight);
+    L1LossLayer<Dtype> layer_weight_2(layer_param);
+    layer_weight_2.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+    const Dtype loss_weight_2 =
+        layer_weight_2.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+    const Dtype kErrorMargin = 1e-5;
+    EXPECT_NEAR(loss_weight_1 * kLossWeight, loss_weight_2, kErrorMargin);
+    // Make sure the loss is non-trivial.
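+    // (With Gaussian-filled bottoms, the expected per-element |a - b| is
+    // 2 / sqrt(pi), roughly 1.13, so this threshold is comfortably low.)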
+    const Dtype kNonTrivialAbsThresh = 1e-1;
+    EXPECT_GE(fabs(loss_weight_1), kNonTrivialAbsThresh);
+  }
+
+  Blob<Dtype>* const blob_bottom_data_;
+  Blob<Dtype>* const blob_bottom_label_;
+  Blob<Dtype>* const blob_top_loss_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(L1LossLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(L1LossLayerTest, TestForward) {
+  this->TestForward();
+}
+
+TYPED_TEST(L1LossLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  const Dtype kLossWeight = 3.7;
+  layer_param.add_loss_weight(kLossWeight);
+  L1LossLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+  GradientChecker<Dtype> checker(1e-4, 1e-2, 1701);
+  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+}  // namespace caffe
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index c69c58eb340..52d21d75194 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -466,6 +466,8 @@ LayerParameter_LayerType UpgradeV0LayerType(const string& type) {
     return LayerParameter_LayerType_DROPOUT;
   } else if (type == "euclidean_loss") {
     return LayerParameter_LayerType_EUCLIDEAN_LOSS;
+  } else if (type == "l1_loss") {
+    return LayerParameter_LayerType_L1_LOSS;
   } else if (type == "flatten") {
     return LayerParameter_LayerType_FLATTEN;
   } else if (type == "hdf5_data") {
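
For reference, the forward and backward passes of the new layer reduce to a few
lines of NumPy. The standalone sketch below (illustrative only, not part of the
patch; all names are hypothetical) mirrors L1LossLayer's CPU path and can be
used to sanity-check the layer's output:

    import numpy as np

    def l1_loss_forward(pred, label):
        # E = sum(|pred - label|) / num, as in L1LossLayer<Dtype>::Forward_cpu.
        return np.abs(pred - label).sum() / pred.shape[0]

    def l1_loss_backward(pred, label, loss_weight=1.0):
        # dE/dpred = loss_weight * sign(pred - label) / num; the gradient
        # w.r.t. label is its negation, as in L1LossLayer<Dtype>::Backward_cpu.
        grad = loss_weight * np.sign(pred - label) / pred.shape[0]
        return grad, -grad

    pred = np.random.randn(10, 5, 1, 1)
    label = np.random.randn(10, 5, 1, 1)
    print l1_loss_forward(pred, label)  # roughly 5.6 for these shapes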