Skip to content

Commit

Permalink
Merge pull request BVLC#1874 from jeffdonahue/blob-math-test-precision
Browse files Browse the repository at this point in the history
BlobMathTest: fix precision issues
  • Loading branch information
jeffdonahue committed Feb 16, 2015
2 parents 413ee83 + 6a309c1 commit c09de35
Showing 1 changed file with 34 additions and 19 deletions.
53 changes: 34 additions & 19 deletions src/caffe/test/test_blob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ class BlobMathTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;
protected:
BlobMathTest()
: blob_(new Blob<Dtype>(2, 3, 4, 5)) {}
: blob_(new Blob<Dtype>(2, 3, 4, 5)),
epsilon_(1e-6) {}

virtual ~BlobMathTest() { delete blob_; }
Blob<Dtype>* const blob_;
Dtype epsilon_;
};

TYPED_TEST_CASE(BlobMathTest, TestDtypesAndDevices);
Expand Down Expand Up @@ -95,7 +98,8 @@ TYPED_TEST(BlobMathTest, TestSumOfSquares) {
default:
LOG(FATAL) << "Unknown device: " << TypeParam::device;
}
EXPECT_FLOAT_EQ(expected_sumsq, this->blob_->sumsq_data());
EXPECT_NEAR(expected_sumsq, this->blob_->sumsq_data(),
this->epsilon_ * expected_sumsq);
EXPECT_EQ(0, this->blob_->sumsq_diff());

// Check sumsq_diff too.
Expand All @@ -112,9 +116,12 @@ TYPED_TEST(BlobMathTest, TestSumOfSquares) {
default:
LOG(FATAL) << "Unknown device: " << TypeParam::device;
}
EXPECT_FLOAT_EQ(expected_sumsq, this->blob_->sumsq_data());
EXPECT_FLOAT_EQ(expected_sumsq * kDiffScaleFactor * kDiffScaleFactor,
this->blob_->sumsq_diff());
EXPECT_NEAR(expected_sumsq, this->blob_->sumsq_data(),
this->epsilon_ * expected_sumsq);
const Dtype expected_sumsq_diff =
expected_sumsq * kDiffScaleFactor * kDiffScaleFactor;
EXPECT_NEAR(expected_sumsq_diff, this->blob_->sumsq_diff(),
this->epsilon_ * expected_sumsq_diff);
}

TYPED_TEST(BlobMathTest, TestAsum) {
Expand Down Expand Up @@ -146,7 +153,8 @@ TYPED_TEST(BlobMathTest, TestAsum) {
default:
LOG(FATAL) << "Unknown device: " << TypeParam::device;
}
EXPECT_FLOAT_EQ(expected_asum, this->blob_->asum_data());
EXPECT_NEAR(expected_asum, this->blob_->asum_data(),
this->epsilon_ * expected_asum);
EXPECT_EQ(0, this->blob_->asum_diff());

// Check asum_diff too.
Expand All @@ -163,8 +171,11 @@ TYPED_TEST(BlobMathTest, TestAsum) {
default:
LOG(FATAL) << "Unknown device: " << TypeParam::device;
}
EXPECT_FLOAT_EQ(expected_asum, this->blob_->asum_data());
EXPECT_FLOAT_EQ(expected_asum * kDiffScaleFactor, this->blob_->asum_diff());
EXPECT_NEAR(expected_asum, this->blob_->asum_data(),
this->epsilon_ * expected_asum);
const Dtype expected_diff_asum = expected_asum * kDiffScaleFactor;
EXPECT_NEAR(expected_diff_asum, this->blob_->asum_diff(),
this->epsilon_ * expected_diff_asum);
}

TYPED_TEST(BlobMathTest, TestScaleData) {
Expand Down Expand Up @@ -193,20 +204,22 @@ TYPED_TEST(BlobMathTest, TestScaleData) {
}
const Dtype kDataScaleFactor = 3;
this->blob_->scale_data(kDataScaleFactor);
EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor,
this->blob_->asum_data());
EXPECT_NEAR(asum_before_scale * kDataScaleFactor, this->blob_->asum_data(),
this->epsilon_ * asum_before_scale * kDataScaleFactor);
EXPECT_EQ(0, this->blob_->asum_diff());

// Check scale_diff too.
const Dtype kDataToDiffScaleFactor = 7;
const Dtype* data = this->blob_->cpu_data();
caffe_cpu_scale(this->blob_->count(), kDataToDiffScaleFactor, data,
this->blob_->mutable_cpu_diff());
EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor,
this->blob_->asum_data());
const Dtype diff_asum_before_scale = this->blob_->asum_diff();
EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor * kDataToDiffScaleFactor,
diff_asum_before_scale);
const Dtype expected_asum_before_scale = asum_before_scale * kDataScaleFactor;
EXPECT_NEAR(expected_asum_before_scale, this->blob_->asum_data(),
this->epsilon_ * expected_asum_before_scale);
const Dtype expected_diff_asum_before_scale =
asum_before_scale * kDataScaleFactor * kDataToDiffScaleFactor;
EXPECT_NEAR(expected_diff_asum_before_scale, this->blob_->asum_diff(),
this->epsilon_ * expected_diff_asum_before_scale);
switch (TypeParam::device) {
case Caffe::CPU:
this->blob_->mutable_cpu_diff();
Expand All @@ -219,10 +232,12 @@ TYPED_TEST(BlobMathTest, TestScaleData) {
}
const Dtype kDiffScaleFactor = 3;
this->blob_->scale_diff(kDiffScaleFactor);
EXPECT_FLOAT_EQ(asum_before_scale * kDataScaleFactor,
this->blob_->asum_data());
EXPECT_FLOAT_EQ(diff_asum_before_scale * kDiffScaleFactor,
this->blob_->asum_diff());
EXPECT_NEAR(asum_before_scale * kDataScaleFactor, this->blob_->asum_data(),
this->epsilon_ * asum_before_scale * kDataScaleFactor);
const Dtype expected_diff_asum =
expected_diff_asum_before_scale * kDiffScaleFactor;
EXPECT_NEAR(expected_diff_asum, this->blob_->asum_diff(),
this->epsilon_ * expected_diff_asum);
}

} // namespace caffe

0 comments on commit c09de35

Please sign in to comment.