Skip to content

Commit

Permalink
fix: readObject for ModelInfo (#2571)
Browse files Browse the repository at this point in the history
  • Loading branch information
i8run authored Jun 29, 2018
1 parent 9d75c13 commit 72a635a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ private[bigdl] class ModelInfo[T: ClassTag](val uuid: String, @transient var mod
out.writeObject(cloned)
CachedModels.add(uuid, cloned)
}

@throws(classOf[IOException])
private def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
model = in.readObject().asInstanceOf[Module[T]]
CachedModels.add(uuid, model)
}
}

private[bigdl] object ModelInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.intel.analytics.bigdl.models.lenet.LeNet5
import com.intel.analytics.bigdl.nn.tf.Const
import com.intel.analytics.bigdl.nn._
import com.intel.analytics.bigdl.tensor.Tensor
import org.apache.commons.lang3.SerializationUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
Expand Down Expand Up @@ -138,6 +139,18 @@ class ModelBroadcastSpec extends FlatSpec with Matchers with BeforeAndAfter {
modelBroadCast.value().parameters()._1 should be(model.parameters()._1)
}

"model info serialized" should "not be null" in {
val model = LeNet5(10).cloneModule()
val info = ModelInfo[Float]("124339", model)

val newInfo = SerializationUtils.clone(info)

newInfo.model should not be (null)
info.model.toString() should be (newInfo.model.toString())
info.model.parameters()._1 should be (newInfo.model.parameters()._1)
info.model.parameters()._2 should be (newInfo.model.parameters()._2)
}

after {
if (sc != null) {
sc.stop()
Expand Down

0 comments on commit 72a635a

Please sign in to comment.