From 438861c1609a10080f17d8dec04081d9a269d72f Mon Sep 17 00:00:00 2001 From: Yanzhang Wang Date: Tue, 26 Mar 2019 18:43:16 +0800 Subject: [PATCH] fix: dropout should init primitive (#2788) --- .../intel/analytics/bigdl/nn/mkldnn/Dropout.scala | 1 + .../analytics/bigdl/nn/mkldnn/DropoutSpec.scala | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Dropout.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Dropout.scala index 7e881efd24f..67f178e2bb3 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Dropout.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/nn/mkldnn/Dropout.scala @@ -49,6 +49,7 @@ class Dropout( _gradOutputFormats = grad.map(x => HeapData(x.shape, format(x.shape))) _gradOutputFormatsForWeight = grad.map(x => HeapData(x.shape, format(x.shape))) _gradInputFormats = grad.map(x => HeapData(x.shape, format(x.shape))) + _gradInputFormats.map(_.getPrimitive(runtime)) gradInput = initTensor(_gradInputFormats.head) (_gradOutputFormats, _gradInputFormats) } diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/DropoutSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/DropoutSpec.scala index 5a45c236e7e..8fdf58b3e3e 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/DropoutSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/nn/mkldnn/DropoutSpec.scala @@ -67,4 +67,18 @@ class DropoutSpec extends FlatSpec with Matchers { val ratio = notEqZeros.toDouble / total ratio should be (1.0) } + + "dropout in sequential" should "work correctly" in { + val shape = Array(2, 3, 4, 4) + val dropout = Dropout() + val seq = Sequential().add(Input(shape, Memory.Format.nchw)) + .add(dropout) + .add(Output(Memory.Format.nchw)) + + seq.compile(TrainingPhase) + + val input = Tensor[Float](shape).rand(-1, 1) + seq.forward(input) + seq.backward(input, seq.output) + } }