diff --git a/build.sbt b/build.sbt index e281398..06920fb 100644 --- a/build.sbt +++ b/build.sbt @@ -4,7 +4,7 @@ crossPaths := false autoScalaLibrary := false -version := "1.7.1-1" +version := "2.0.2-1" //javacOptions ++= Seq("-release", "8") @@ -34,7 +34,7 @@ scalacOptions ++= Seq("-feature", "-deprecation") libraryDependencies ++= Seq( "junit" % "junit" % "4.13.2" % "test", "com.github.sbt" % "junit-interface" % "0.13.3" % Test, - "org.slf4j" % "slf4j-api" % "1.7.36", + "org.slf4j" % "slf4j-api" % "2.0.5", "org.apache.hadoop" % "hadoop-hdfs" % "2.10.1" % "provided", "org.apache.hadoop" % "hadoop-common" % "2.10.1" % "provided" ) diff --git a/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 6fc13eb..3b491c3 100644 --- a/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -158,6 +158,51 @@ public void setAttrs(Map attrs) throws XGBoostError { } } + /** + * Get feature names from the Booster. + * @return + * @throws XGBoostError + */ + public final String[] getFeatureNames() throws XGBoostError { + int numFeature = (int) getNumFeature(); + String[] out = new String[numFeature]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", out)); + return out; + } + + /** + * Set feature names to the Booster. + * + * @param featureNames + * @throws XGBoostError + */ + public void setFeatureNames(String[] featureNames) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo( + handle, "feature_name", featureNames)); + } + + /** + * Get feature types from the Booster. + * @return + * @throws XGBoostError + */ + public final String[] getFeatureTypes() throws XGBoostError { + int numFeature = (int) getNumFeature(); + String[] out = new String[numFeature]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", out)); + return out; + } + + /** + * Set feature types to the Booster. + * @param featureTypes + * @throws XGBoostError + */ + public void setFeatureTypes(String[] featureTypes) throws XGBoostError { + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo( + handle, "feature_type", featureTypes)); + } + /** * Update the booster for one iteration. * @@ -740,7 +785,7 @@ private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { private void writeObject(java.io.ObjectOutputStream out) throws IOException { try { out.writeInt(version); - out.writeObject(this.toByteArray()); + out.writeObject(this.toByteArray("ubj")); } catch (XGBoostError ex) { ex.printStackTrace(); logger.error(ex.getMessage()); diff --git a/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java b/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java index c151fc7..7fa3901 100644 --- a/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java +++ b/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java @@ -71,7 +71,6 @@ public final String getArrayInterfaceJson() { /** * Get the cuda array interface of the label columns. * The returned value must not be null or empty if we're creating - * {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)} */ public abstract String getLabelsArrayInterface(); diff --git a/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index d25c520..b93b3b7 100644 --- a/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -57,17 +57,9 @@ public DMatrix(String dataPath) throws XGBoostError { * @throws XGBoostError */ @Deprecated - public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st) - throws XGBoostError { - long[] out = new long[1]; - if (st == SparseType.CSR) { - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out)); - } else if (st == SparseType.CSC) { - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out)); - } else { - throw new UnknownError("unknow sparsetype"); - } - handle = out[0]; + public DMatrix(long[] headers, int[] indices, float[] data, + DMatrix.SparseType st) throws XGBoostError { + this(headers, indices, data, st, 0, Float.NaN, -1); } /** @@ -80,15 +72,20 @@ public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType s * row number * @throws XGBoostError */ - public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam) - throws XGBoostError { + public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, + int shapeParam) throws XGBoostError { + this(headers, indices, data, st, shapeParam, Float.NaN, -1); + } + + public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam, + float missing, int nthread) throws XGBoostError { long[] out = new long[1]; if (st == SparseType.CSR) { - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, - shapeParam, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, + shapeParam, missing, nthread, out)); } else if (st == SparseType.CSC) { - XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, - shapeParam, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, + shapeParam, missing, nthread, out)); } else { throw new UnknownError("unknow sparsetype"); } @@ -403,6 +400,18 @@ public long rowNum() throws XGBoostError { return rowNum[0]; } + /** + * Get the number of non-missing values of DMatrix. + * + * @return The number of non-missing values + * @throws XGBoostError native error + */ + public long nonMissingNum() throws XGBoostError { + long[] n = new long[1]; + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumNonMissing(handle, n)); + return n[0]; + } + /** * save DMatrix to filePath */ diff --git a/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java b/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java deleted file mode 100644 index 849e7a7..0000000 --- a/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java +++ /dev/null @@ -1,68 +0,0 @@ -package ml.dmlc.xgboost4j.java; - -import java.util.Iterator; - -/** - * DeviceQuantileDMatrix will only be used to train - */ -public class DeviceQuantileDMatrix extends DMatrix { - /** - * Create DeviceQuantileDMatrix from iterator based on the cuda array interface - * @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface - * @param missing the missing value - * @param maxBin the max bin - * @param nthread the parallelism - * @throws XGBoostError - */ - public DeviceQuantileDMatrix( - Iterator iter, - float missing, - int maxBin, - int nthread) throws XGBoostError { - super(0); - long[] out = new long[1]; - XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback( - iter, missing, maxBin, nthread, out)); - handle = out[0]; - } - - @Override - public void setLabel(Column column) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel."); - } - - @Override - public void setWeight(Column column) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight."); - } - - @Override - public void setBaseMargin(Column column) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin."); - } - - @Override - public void setLabel(float[] labels) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel."); - } - - @Override - public void setWeight(float[] weights) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight."); - } - - @Override - public void setBaseMargin(float[] baseMargin) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin."); - } - - @Override - public void setBaseMargin(float[][] baseMargin) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin."); - } - - @Override - public void setGroup(int[] group) throws XGBoostError { - throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup."); - } -} diff --git a/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index 75e1895..bcd0b1b 100644 --- a/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -17,6 +17,8 @@ import java.io.*; import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +32,11 @@ public class XGBoost { private static final Log logger = LogFactory.getLog(XGBoost.class); + public static final String[] MAXIMIZ_METRICES = { + "auc", "aucpr", "pre", "pre@", "map", "ndcg", + "auc@", "aucpr@", "map@", "ndcg@", + }; + /** * load model from modelPath * @@ -158,7 +165,7 @@ public static Booster trainAndSaveCheckpoint( //collect eval matrixs String[] evalNames; DMatrix[] evalMats; - float bestScore; + float bestScore = 1; int bestIteration; List names = new ArrayList(); List mats = new ArrayList(); @@ -175,11 +182,7 @@ public static Booster trainAndSaveCheckpoint( evalNames = names.toArray(new String[names.size()]); evalMats = mats.toArray(new DMatrix[mats.size()]); - if (isMaximizeEvaluation(params)) { - bestScore = -Float.MAX_VALUE; - } else { - bestScore = Float.MAX_VALUE; - } + bestIteration = 0; metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics; @@ -198,6 +201,8 @@ public static Booster trainAndSaveCheckpoint( if (booster == null) { // Start training on a new booster booster = new Booster(params, allMats); + booster.setFeatureNames(dtrain.getFeatureNames()); + booster.setFeatureTypes(dtrain.getFeatureTypes()); booster.loadRabitCheckpoint(); } else { // Start training on an existing booster @@ -208,6 +213,9 @@ public static Booster trainAndSaveCheckpoint( checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds)); } + boolean initial_best_score_flag = false; + boolean max_direction = false; + // begin to train for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { if (booster.getVersion() % 2 == 0) { @@ -229,6 +237,18 @@ public static Booster trainAndSaveCheckpoint( } else { evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut); } + + if (!initial_best_score_flag) { + if (isMaximizeEvaluation(evalInfo, evalNames, params)) { + max_direction = true; + bestScore = -Float.MAX_VALUE; + } else { + max_direction = false; + bestScore = Float.MAX_VALUE; + } + initial_best_score_flag = true; + } + for (int i = 0; i < metricsOut.length; i++) { metrics[i][iter] = metricsOut[i]; } @@ -236,7 +256,7 @@ public static Booster trainAndSaveCheckpoint( // If there is more than one evaluation datasets, the last one would be used // to determinate early stop. float score = metricsOut[metricsOut.length - 1]; - if (isMaximizeEvaluation(params)) { + if (max_direction) { // Update best score if the current score is better (no update when equal) if (score > bestScore) { bestScore = score; @@ -262,9 +282,7 @@ public static Booster trainAndSaveCheckpoint( break; } if (Communicator.getRank() == 0 && shouldPrint(params, iter)) { - if (shouldPrint(params, iter)){ - Communicator.communicatorPrint(evalInfo + '\n'); - } + Communicator.communicatorPrint(evalInfo + '\n'); } } booster.saveRabitCheckpoint(); @@ -358,16 +376,50 @@ static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIterat return iter - bestIteration >= earlyStoppingRounds; } - private static boolean isMaximizeEvaluation(Map params) { - try { + private static String getMetricNameFromlog(String evalInfo, String[] evalNames) { + String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):"; + Pattern pattern = Pattern.compile(regexPattern); + Matcher matcher = pattern.matcher(evalInfo); + + String metricName = null; + if (matcher.find()) { + metricName = matcher.group(1); + logger.debug("Got the metric name: " + metricName); + } + return metricName; + } + + // visiable for testing + public static boolean isMaximizeEvaluation(String evalInfo, + String[] evalNames, + Map params) { + + String metricName; + + if (params.get("maximize_evaluation_metrics") != null) { + // user has forced the direction no matter what is the metric name. String maximize = String.valueOf(params.get("maximize_evaluation_metrics")); - assert(maximize != null); return Boolean.valueOf(maximize); - } catch (Exception ex) { - logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," + - " allowed value: true/false", ex); - throw ex; } + + if (params.get("eval_metric") != null) { + // user has special metric name + metricName = String.valueOf(params.get("eval_metric")); + } else { + // infer the metric name from log + metricName = getMetricNameFromlog(evalInfo, evalNames); + } + + assert metricName != null; + + if (!"mape".equals(metricName)) { + for (String x : MAXIMIZ_METRICES) { + if (metricName.startsWith(x)) { + return true; + } + } + } + return false; } /** diff --git a/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 2fdca3e..e69e635 100644 --- a/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -53,11 +53,15 @@ static void checkCall(int ret) throws XGBoostError { public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); - public final static native int XGDMatrixCreateFromCSREx(long[] indptr, int[] indices, float[] data, - int shapeParam, long[] out); + public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, + float[] data, int shapeParam, + float missing, int nthread, + long[] out); - public final static native int XGDMatrixCreateFromCSCEx(long[] colptr, int[] indices, float[] data, - int shapeParam, long[] out); + public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, + float[] data, int shapeParam, + float missing, int nthread, + long[] out); public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out); @@ -93,6 +97,7 @@ public final static native int XGDMatrixGetStrFeatureInfo(long handle, String fi long[] outLength, String[][] outValues); public final static native int XGDMatrixNumRow(long handle, long[] row); + public final static native int XGDMatrixNumNonMissing(long handle, long[] nonMissings); public final static native int XGBoosterCreate(long[] handles, long[] out); @@ -146,10 +151,18 @@ final static native int CommunicatorAllreduce(ByteBuffer sendrecvbuf, int count, public final static native int XGDMatrixSetInfoFromInterface( long handle, String field, String json); + @Deprecated public final static native int XGDeviceQuantileDMatrixCreateFromCallback( java.util.Iterator iter, float missing, int nthread, int maxBin, long[] out); + public final static native int XGQuantileDMatrixCreateFromCallback( + java.util.Iterator iter, java.util.Iterator ref, String config, long[] out); + public final static native int XGDMatrixCreateFromArrayInterfaceColumns( String featureJson, float missing, int nthread, long[] out); + public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features); + + public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out); + } diff --git a/src/main/resources/lib/linux/x86_64/libxgboost4j.so b/src/main/resources/lib/linux/x86_64/libxgboost4j.so index be1e1d0..6fe961c 100644 --- a/src/main/resources/lib/linux/x86_64/libxgboost4j.so +++ b/src/main/resources/lib/linux/x86_64/libxgboost4j.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:00aeadc500abfc63978745d1d03feeb375474cf58d15409f06fe1be888b0bacc -size 7592000 +oid sha256:56e60ae2a46b4cd39617b083409a0819613fc399ed6607c0e706847f38b29f66 +size 12411608 diff --git a/src/main/resources/lib/macos/aarch64/libxgboost4j.dylib b/src/main/resources/lib/macos/aarch64/libxgboost4j.dylib old mode 100644 new mode 100755 index 53f723e..1a24bbe --- a/src/main/resources/lib/macos/aarch64/libxgboost4j.dylib +++ b/src/main/resources/lib/macos/aarch64/libxgboost4j.dylib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ea2834e2f48a21f70365faaaf5d826d3501db8cc9f7818f0a165f22618a41710 -size 3863999 +oid sha256:1206b4fa7b4ec27eea166be3ef50d4d27cf757f44a1bbebbd4165713bc7b73f5 +size 5085360 diff --git a/src/main/resources/lib/macos/x86_64/libxgboost4j.dylib b/src/main/resources/lib/macos/x86_64/libxgboost4j.dylib index 7074cd7..6cfbf67 100644 --- a/src/main/resources/lib/macos/x86_64/libxgboost4j.dylib +++ b/src/main/resources/lib/macos/x86_64/libxgboost4j.dylib @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:68c85cd0063462ec3368564d404efeaf6b1f38561f8d3b79e653cfcf24cce56e -size 4268352 +oid sha256:930c74410be917f6ee560432e32bf09eda69dccda9d023e32d5774b59609dd35 +size 5375104 diff --git a/src/main/resources/lib/windows/x86_64/xgboost4j.dll b/src/main/resources/lib/windows/x86_64/xgboost4j.dll index 14d45f0..43987d6 100644 --- a/src/main/resources/lib/windows/x86_64/xgboost4j.dll +++ b/src/main/resources/lib/windows/x86_64/xgboost4j.dll @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b3a652a16e7f72a4f0d6dd6745ecfbef0b6b77df293a627bfc50252d27e9dc6f -size 3208704 +oid sha256:4ed5df82bc1d9649ef8812839eae15eaa864f31af4b05297445ef2e5ba94decd +size 3227136 diff --git a/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 3cc54bb..d15016a 100644 --- a/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -1,5 +1,5 @@ /* - Copyright (c) 2014-2022 by Contributors + Copyright (c) 2014-2023 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,23 +15,27 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.*; -import java.util.Arrays; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; - import junit.framework.TestCase; import org.junit.Test; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.util.*; +import java.util.concurrent.*; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.fail; + /** - * test cases for Booster + * test cases for Booster Inplace Predict * - * @author hzx + * @author hzx and Sovrn */ public class BoosterImplTest { - private String train_uri = "src/test/resources/agaricus.txt.train?indexing_mode=1"; - private String test_uri = "src/test/resources/agaricus.txt.test?indexing_mode=1"; + private String train_uri = "src/test/resources/agaricus.txt.train?indexing_mode=1&format=libsvm"; + private String test_uri = "src/test/resources/agaricus.txt.test?indexing_mode=1&format=libsvm"; public static class EvalError implements IEvaluation { @Override @@ -102,6 +106,15 @@ public void testBoosterBasic() throws XGBoostError, IOException { TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f); } + private float[] generateRandomDataSet(int size) { + float[] newSet = new float[size]; + Random random = new Random(); + for(int i = 0; i < size; i++) { + newSet[i] = random.nextFloat(); + } + return newSet; + } + @Test public void saveLoadModelWithPath() throws XGBoostError, IOException { DMatrix trainMat = new DMatrix(this.train_uri); @@ -122,6 +135,39 @@ public void saveLoadModelWithPath() throws XGBoostError, IOException { TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f); } + @Test + public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException { + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); + IEvaluation eval = new EvalError(); + + String[] featureNames = new String[126]; + String[] featureTypes = new String[126]; + for(int i = 0; i < 126; i++) { + featureNames[i] = "test_feature_name_" + i; + featureTypes[i] = "q"; + } + trainMat.setFeatureNames(featureNames); + testMat.setFeatureNames(featureNames); + trainMat.setFeatureTypes(featureTypes); + testMat.setFeatureTypes(featureTypes); + + Booster booster = trainBooster(trainMat, testMat); + // save and load, only json format save and load feature_name and feature_type + File temp = File.createTempFile("temp", ".json"); + temp.deleteOnExit(); + booster.saveModel(temp.getAbsolutePath()); + + String modelString = new String(booster.toByteArray("json")); + + Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath()); + assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj"))); + assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json"))); + assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated"))); + float[][] predicts2 = bst2.predict(testMat, true, 0); + TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f); + } + @Test public void saveLoadModelWithStream() throws XGBoostError, IOException { DMatrix trainMat = new DMatrix(this.train_uri); @@ -634,14 +680,12 @@ public void testTrainFromExistingModel() throws XGBoostError, IOException { float tempBoosterError = eval.eval(tempBooster.predict(testMat, true, 0), testMat); // Save tempBooster to bytestream and load back - int prevVersion = tempBooster.getVersion(); ByteArrayInputStream in = new ByteArrayInputStream(tempBooster.toByteArray()); tempBooster = XGBoost.loadModel(in); in.close(); - tempBooster.setVersion(prevVersion); // Continue training using tempBooster - round = 4; + round = 2; Booster booster2 = XGBoost.train(trainMat, paramMap, round, watches, null, null, null, 0, tempBooster); float booster2error = eval.eval(booster2.predict(testMat, true, 0), testMat); TestCase.assertTrue(booster1error == booster2error); diff --git a/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java index b431668..a220e6a 100644 --- a/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java +++ b/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -36,12 +36,11 @@ public class DMatrixTest { - @Test public void testCreateFromFile() throws XGBoostError { //create DMatrix from file String filePath = writeResourceIntoTempFile("/agaricus.txt.test"); - DMatrix dmat = new DMatrix(filePath); + DMatrix dmat = new DMatrix(filePath + "?format=libsvm"); //get label float[] labels = dmat.getLabel(); //check length