Skip to content

Commit

Permalink
fix BCE return Nan (#2474)
Browse files Browse the repository at this point in the history
  • Loading branch information
qiuxin2012 authored Mar 29, 2018
1 parent 57d2efd commit 73ebf25
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class BCECriterion[@specialized(Float, Double) T: ClassTag]
// cmul support broadcasting
buffer.cmul(weights)
sum += ev.toType[Double](buffer.dot(target))
buffer.fill(ev.fromType(1.0 + eps)).sub(input).log().cmul(weights)
buffer.fill(ev.one).sub(input).add(ev.fromType(eps)).log().cmul(weights)
sum -= ev.toType[Double](buffer.dot(target))
if (onesBuffer.nElement() != buffer.nElement()) {
onesBuffer.resizeAs(buffer).fill(ev.one)
Expand All @@ -81,7 +81,7 @@ class BCECriterion[@specialized(Float, Double) T: ClassTag]
} else {
buffer.resizeAs(input).copy(input).add(ev.fromType(eps)).log()
sum += ev.toType[Double](buffer.dot(target))
buffer.fill(ev.fromType(1.0 + eps)).sub(input).log()
buffer.fill(ev.one).sub(input).add(ev.fromType(eps)).log()
sum -= ev.toType[Double](buffer.dot(target))
if (onesBuffer.nElement() != buffer.nElement()) {
onesBuffer.resizeAs(buffer).fill(ev.one)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,37 @@ class BCECriterionSpec extends FlatSpec with Matchers {

}

"BCECriterion's eps" should "works" in {
val criterion = BCECriterion[Float]()
val output = Tensor[Float](3)
output.setValue(1, 0f)
output.setValue(2, 1f)
output.setValue(3, 0.5f)
val target = Tensor[Float](3)
target.setValue(1, 0)
target.setValue(2, 1)
target.setValue(3, 1)

val loss = criterion.forward(output, target)
java.lang.Float.isNaN(loss) should be (false)
}

"BCECriterion's eps with weight" should "works" in {
val weights = Tensor[Float](3).rand()
val criterion = BCECriterion[Float](weights)
val output = Tensor[Float](3)
output.setValue(1, 0f)
output.setValue(2, 1f)
output.setValue(3, 0.5f)
val target = Tensor[Float](3)
target.setValue(1, 0)
target.setValue(2, 1)
target.setValue(3, 1)

val loss = criterion.forward(output, target)
java.lang.Float.isNaN(loss) should be (false)
}

"BCECriterion with more than two dimensions small input" should "" +
"return return right output and gradInput" in {

Expand Down

0 comments on commit 73ebf25

Please sign in to comment.