Skip to content

Commit

Permalink
Fix initialization check on de/serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Jun 13, 2018
1 parent c42756b commit f5d7675
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
23 changes: 20 additions & 3 deletions src/main/java/weka/classifiers/functions/Dl4jMlpClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ public class Dl4jMlpClassifier extends RandomizableClassifier
*/
protected IterationListener iterationListener = new EpochListener();

/**
* Flag indicating if initialization is finished.
*/
protected boolean isInitializationFinished = false;

/**
* Default constructor
*/
Expand Down Expand Up @@ -350,7 +355,7 @@ public Capabilities getCapabilities() {
private void writeObject(ObjectOutputStream oos) throws IOException {
// figure out size of the written network
CountingOutputStream cos = new CountingOutputStream(new NullOutputStream());
if (replaceMissingFilter != null) {
if (isInitializationFinished) {
ModelSerializer.writeModel(model, cos, false);
}
modelSize = cos.getByteCount();
Expand All @@ -359,7 +364,7 @@ private void writeObject(ObjectOutputStream oos) throws IOException {
oos.defaultWriteObject();

// actually write the network
if (replaceMissingFilter != null) {
if (isInitializationFinished) {
ModelSerializer.writeModel(model, oos, false);
}
}
Expand All @@ -377,7 +382,7 @@ private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IO
ois.defaultReadObject();

// restore the network model
if (replaceMissingFilter != null) {
if (isInitializationFinished) {
File tmpFile = File.createTempFile("restore", "multiLayer");
tmpFile.deleteOnExit();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tmpFile));
Expand All @@ -404,12 +409,22 @@ private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IO
.map(l -> Layer.create(l.conf().getLayer()))
.toArray(Layer[]::new);
System.out.println();
} else {
layers = new Layer[] {createOutputLayer()};
}
} finally {
Thread.currentThread().setContextClassLoader(origLoader);
}
}

/**
* Generate the, for this model type, typical output layer.
* @return New OutputLayer object
*/
protected Layer<? extends BaseOutputLayer> createOutputLayer(){
return new OutputLayer();
}

/**
* Get the log file
*
Expand Down Expand Up @@ -729,6 +744,8 @@ public void initializeClassifier(Instances data) throws Exception {
model.setListeners(getListener());

numEpochsPerformed = 0;

isInitializationFinished = true;
} finally {
Thread.currentThread().setContextClassLoader(origLoader);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.AbstractLSTM;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
Expand Down Expand Up @@ -295,6 +296,14 @@ public void setZooModel(ZooModel zooModel) {
// Do nothing
}

/**
* Generate the, for this model type, typical output layer.
* @return New OutputLayer object
*/
protected Layer<? extends BaseOutputLayer> createOutputLayer(){
return new weka.dl4j.layers.RnnOutputLayer();
}

/**
* Returns default capabilities of the classifier.
*
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/weka/classifiers/functions/Dl4jMlpTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,13 @@ public void testSerialization() throws Exception {

File out = Paths.get(System.getProperty("java.io.tmpdir"), "out.object").toFile();
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(out));
clf.initializeClassifier(dataMnist);
oos.writeObject(clf);

ObjectInputStream ois = new ObjectInputStream(new FileInputStream(out));
Dl4jMlpClassifier clf2 = (Dl4jMlpClassifier) ois.readObject();

clf2.setNumEpochs(1);
clf2.initializeClassifier(dataMnist);
clf2.buildClassifier(dataMnist);
}

Expand Down

0 comments on commit f5d7675

Please sign in to comment.