From 72a635a8595cc82d349c4b53bf93cf4b96ca377d Mon Sep 17 00:00:00 2001 From: Yanzhang Wang Date: Fri, 29 Jun 2018 18:25:50 +0800 Subject: [PATCH] fix: readObject for ModelInfo (#2571) --- .../bigdl/models/utils/ModelBroadcast.scala | 7 +++++++ .../bigdl/models/utils/ModelBroadcastSpec.scala | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcast.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcast.scala index 4bc7530875d..cf412f436c1 100644 --- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcast.scala +++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcast.scala @@ -191,6 +191,13 @@ private[bigdl] class ModelInfo[T: ClassTag](val uuid: String, @transient var mod out.writeObject(cloned) CachedModels.add(uuid, cloned) } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = { + in.defaultReadObject() + model = in.readObject().asInstanceOf[Module[T]] + CachedModels.add(uuid, model) + } } private[bigdl] object ModelInfo { diff --git a/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcastSpec.scala b/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcastSpec.scala index 17001e62509..7fe7f2e1748 100644 --- a/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcastSpec.scala +++ b/spark/dl/src/test/scala/com/intel/analytics/bigdl/models/utils/ModelBroadcastSpec.scala @@ -21,6 +21,7 @@ import com.intel.analytics.bigdl.models.lenet.LeNet5 import com.intel.analytics.bigdl.nn.tf.Const import com.intel.analytics.bigdl.nn._ import com.intel.analytics.bigdl.tensor.Tensor +import org.apache.commons.lang3.SerializationUtils import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkConf, SparkContext} import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers} @@ -138,6 +139,18 @@ class ModelBroadcastSpec extends FlatSpec with Matchers with BeforeAndAfter { modelBroadCast.value().parameters()._1 should be(model.parameters()._1) } + "model info serialized" should "not be null" in { + val model = LeNet5(10).cloneModule() + val info = ModelInfo[Float]("124339", model) + + val newInfo = SerializationUtils.clone(info) + + newInfo.model should not be (null) + info.model.toString() should be (newInfo.model.toString()) + info.model.parameters()._1 should be (newInfo.model.parameters()._1) + info.model.parameters()._2 should be (newInfo.model.parameters()._2) + } + after { if (sc != null) { sc.stop()