Skip to content

Commit

Permalink
2.0.2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
shuttie committed Dec 5, 2023
1 parent 502a83d commit cc4efd5
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 137 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ crossPaths := false

autoScalaLibrary := false

version := "1.7.1-1"
version := "2.0.2-1"

//javacOptions ++= Seq("-release", "8")

Expand Down Expand Up @@ -34,7 +34,7 @@ scalacOptions ++= Seq("-feature", "-deprecation")
libraryDependencies ++= Seq(
"junit" % "junit" % "4.13.2" % "test",
"com.github.sbt" % "junit-interface" % "0.13.3" % Test,
"org.slf4j" % "slf4j-api" % "1.7.36",
"org.slf4j" % "slf4j-api" % "2.0.5",
"org.apache.hadoop" % "hadoop-hdfs" % "2.10.1" % "provided",
"org.apache.hadoop" % "hadoop-common" % "2.10.1" % "provided"
)
47 changes: 46 additions & 1 deletion src/main/java/ml/dmlc/xgboost4j/java/Booster.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,51 @@ public void setAttrs(Map<String, String> attrs) throws XGBoostError {
}
}

/**
 * Get feature names from the Booster.
 *
 * <p>An array with one slot per feature (as reported by {@code getNumFeature()}) is
 * allocated and handed to the native library, which fills it for the
 * {@code "feature_name"} field.
 *
 * @return the array passed to the native call, sized to the number of features
 * @throws XGBoostError if the native call reports an error
 */
public final String[] getFeatureNames() throws XGBoostError {
  final String[] names = new String[(int) getNumFeature()];
  XGBoostJNI.checkCall(
      XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", names));
  return names;
}

/**
 * Set feature names to the Booster.
 *
 * <p>The array is forwarded unchanged to the native library under the
 * {@code "feature_name"} field.
 *
 * @param featureNames the feature names to store on the Booster
 * @throws XGBoostError if the native call reports an error
 */
public void setFeatureNames(String[] featureNames) throws XGBoostError {
  XGBoostJNI.checkCall(
      XGBoostJNI.XGBoosterSetStrFeatureInfo(handle, "feature_name", featureNames));
}

/**
 * Get feature types from the Booster.
 *
 * <p>An array with one slot per feature (as reported by {@code getNumFeature()}) is
 * allocated and handed to the native library, which fills it for the
 * {@code "feature_type"} field.
 *
 * @return the array passed to the native call, sized to the number of features
 * @throws XGBoostError if the native call reports an error
 */
public final String[] getFeatureTypes() throws XGBoostError {
  final String[] types = new String[(int) getNumFeature()];
  XGBoostJNI.checkCall(
      XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", types));
  return types;
}

/**
 * Set feature types to the Booster.
 *
 * <p>The array is forwarded unchanged to the native library under the
 * {@code "feature_type"} field.
 *
 * @param featureTypes the feature types to store on the Booster
 * @throws XGBoostError if the native call reports an error
 */
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
  XGBoostJNI.checkCall(
      XGBoostJNI.XGBoosterSetStrFeatureInfo(handle, "feature_type", featureTypes));
}

/**
* Update the booster for one iteration.
*
Expand Down Expand Up @@ -740,7 +785,7 @@ private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeInt(version);
out.writeObject(this.toByteArray());
out.writeObject(this.toByteArray("ubj"));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
Expand Down
1 change: 0 additions & 1 deletion src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ public final String getArrayInterfaceJson() {
/**
* Get the cuda array interface of the label columns.
* The returned value must not be null or empty if we're creating
* {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
*/
public abstract String getLabelsArrayInterface();

Expand Down
45 changes: 27 additions & 18 deletions src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2023 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -57,17 +57,9 @@ public DMatrix(String dataPath) throws XGBoostError {
* @throws XGBoostError
*/
@Deprecated
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st)
throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
} else if (st == SparseType.CSC) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
} else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
public DMatrix(long[] headers, int[] indices, float[] data,
DMatrix.SparseType st) throws XGBoostError {
this(headers, indices, data, st, 0, Float.NaN, -1);
}

/**
Expand All @@ -80,15 +72,20 @@ public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType s
* row number
* @throws XGBoostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam)
throws XGBoostError {
public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st,
int shapeParam) throws XGBoostError {
this(headers, indices, data, st, shapeParam, Float.NaN, -1);
}

public DMatrix(long[] headers, int[] indices, float[] data, DMatrix.SparseType st, int shapeParam,
float missing, int nthread) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
shapeParam, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data,
shapeParam, missing, nthread, out));
} else if (st == SparseType.CSC) {
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
shapeParam, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data,
shapeParam, missing, nthread, out));
} else {
throw new UnknownError("unknow sparsetype");
}
Expand Down Expand Up @@ -403,6 +400,18 @@ public long rowNum() throws XGBoostError {
return rowNum[0];
}

/**
 * Get the number of non-missing values of this DMatrix.
 *
 * <p>The count is obtained from the native library through a one-element output array.
 *
 * @return the number of non-missing values
 * @throws XGBoostError if the native call reports an error
 */
public long nonMissingNum() throws XGBoostError {
  final long[] count = new long[1];
  XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumNonMissing(handle, count));
  return count[0];
}

/**
* save DMatrix to filePath
*/
Expand Down
68 changes: 0 additions & 68 deletions src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java

This file was deleted.

86 changes: 69 additions & 17 deletions src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import java.io.*;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -30,6 +32,11 @@
public class XGBoost {
private static final Log logger = LogFactory.getLog(XGBoost.class);

/**
 * Metric-name prefixes for which a larger score is treated as better.
 *
 * <p>{@code isMaximizeEvaluation} returns {@code true} when the resolved metric name
 * starts with one of these entries. Matching is by prefix, so {@code "mape"} is
 * excluded explicitly by that method because it would otherwise match {@code "map"}.
 */
public static final String[] MAXIMIZ_METRICES = {
"auc", "aucpr", "pre", "pre@", "map", "ndcg",
"auc@", "aucpr@", "map@", "ndcg@",
};

/**
* load model from modelPath
*
Expand Down Expand Up @@ -158,7 +165,7 @@ public static Booster trainAndSaveCheckpoint(
//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
float bestScore;
float bestScore = 1;
int bestIteration;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
Expand All @@ -175,11 +182,7 @@ public static Booster trainAndSaveCheckpoint(

evalNames = names.toArray(new String[names.size()]);
evalMats = mats.toArray(new DMatrix[mats.size()]);
if (isMaximizeEvaluation(params)) {
bestScore = -Float.MAX_VALUE;
} else {
bestScore = Float.MAX_VALUE;
}

bestIteration = 0;
metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics;

Expand All @@ -198,6 +201,8 @@ public static Booster trainAndSaveCheckpoint(
if (booster == null) {
// Start training on a new booster
booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster
Expand All @@ -208,6 +213,9 @@ public static Booster trainAndSaveCheckpoint(
checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
}

boolean initial_best_score_flag = false;
boolean max_direction = false;

// begin to train
for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
if (booster.getVersion() % 2 == 0) {
Expand All @@ -229,14 +237,26 @@ public static Booster trainAndSaveCheckpoint(
} else {
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
}

if (!initial_best_score_flag) {
if (isMaximizeEvaluation(evalInfo, evalNames, params)) {
max_direction = true;
bestScore = -Float.MAX_VALUE;
} else {
max_direction = false;
bestScore = Float.MAX_VALUE;
}
initial_best_score_flag = true;
}

for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
}

// If there is more than one evaluation dataset, the last one is used
// to determine early stopping.
float score = metricsOut[metricsOut.length - 1];
if (isMaximizeEvaluation(params)) {
if (max_direction) {
// Update best score if the current score is better (no update when equal)
if (score > bestScore) {
bestScore = score;
Expand All @@ -262,9 +282,7 @@ public static Booster trainAndSaveCheckpoint(
break;
}
if (Communicator.getRank() == 0 && shouldPrint(params, iter)) {
if (shouldPrint(params, iter)){
Communicator.communicatorPrint(evalInfo + '\n');
}
Communicator.communicatorPrint(evalInfo + '\n');
}
}
booster.saveRabitCheckpoint();
Expand Down Expand Up @@ -358,16 +376,50 @@ static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIterat
return iter - bestIteration >= earlyStoppingRounds;
}

private static boolean isMaximizeEvaluation(Map<String, Object> params) {
try {
/**
 * Extract the metric name reported for the first evaluation set from an evaluation log.
 *
 * <p>Searches {@code evalInfo} for the first occurrence of
 * {@code <evalNames[0]>-<metric>:} and returns the {@code <metric>} part.
 *
 * @param evalInfo  evaluation log text to search
 * @param evalNames names of the evaluation sets; only the first entry is used, so the
 *                  array must be non-empty
 * @return the metric name, or {@code null} when no matching entry is found
 */
private static String getMetricNameFromlog(String evalInfo, String[] evalNames) {
  // The capture must be reluctant: a greedy ".*" extends to the LAST ':' in the log,
  // so with several "name-metric:value" entries the "name" would swallow the first
  // value and every following entry (e.g. "mape:0.5\ttest-mape" instead of "mape").
  String regexPattern = Pattern.quote(evalNames[0]) + "-(.*?):";
  Pattern pattern = Pattern.compile(regexPattern);
  Matcher matcher = pattern.matcher(evalInfo);

  String metricName = null;
  if (matcher.find()) {
    metricName = matcher.group(1);
    logger.debug("Got the metric name: " + metricName);
  }
  return metricName;
}

// visible for testing
public static boolean isMaximizeEvaluation(String evalInfo,
String[] evalNames,
Map<String, Object> params) {

String metricName;

if (params.get("maximize_evaluation_metrics") != null) {
// user has forced the direction no matter what is the metric name.
String maximize = String.valueOf(params.get("maximize_evaluation_metrics"));
assert(maximize != null);
return Boolean.valueOf(maximize);
} catch (Exception ex) {
logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," +
" allowed value: true/false", ex);
throw ex;
}

if (params.get("eval_metric") != null) {
// user has special metric name
metricName = String.valueOf(params.get("eval_metric"));
} else {
// infer the metric name from log
metricName = getMetricNameFromlog(evalInfo, evalNames);
}

assert metricName != null;

if (!"mape".equals(metricName)) {
for (String x : MAXIMIZ_METRICES) {
if (metricName.startsWith(x)) {
return true;
}
}
}
return false;
}

/**
Expand Down
Loading

0 comments on commit cc4efd5

Please sign in to comment.