Skip to content

Commit

Permalink
[jvm-packages] Allow supression of Rabit output in Booster::train in …
Browse files Browse the repository at this point in the history
…xgboost4j (dmlc#4262)

* Make train in xgboost4j respect print params

Previously no setting in params argument of Booster::train would prevent
the Rabit.trackerPrint call. This can fill up a lot of screen space in
the case that many folds are being trained.
* Setting "silent" in this map to "true", "True", a non-zero integer, or
  a string that can be parsed to such an int will prevent printing.
* Setting "verbose_eval" to "False" or "false" will prevent printing.
* Setting "verbose_eval" to an int (or a String parseable to an int) n
  will result in printing every n steps, or no printing is n is zero.

This is to match the python behaviour described here:
https://www.kaggle.com/c/rossmann-store-sales/discussion/17499

* Fixed 'slient' typo in xgboost4j test

* private access on two methods
  • Loading branch information
harrybraviner authored and CodingCat committed Mar 21, 2019
1 parent 45c89a6 commit b374e0a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,58 @@ public static Booster train(
break;
}
}
if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n');
if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
if (shouldPrint(params, iter)){
Rabit.trackerPrint(evalInfo + '\n');
}
}
}
booster.saveRabitCheckpoint();
}
return booster;
}

private static Integer tryGetIntFromObject(Object o) {
if (o instanceof Integer) {
return (int)o;
} else if (o instanceof String) {
try {
return Integer.parseInt((String)o);
} catch (NumberFormatException e) {
return null;
}
} else {
return null;
}
}

private static boolean shouldPrint(Map<String, Object> params, int iter) {
Object silent = params.get("silent");
Integer silentInt = tryGetIntFromObject(silent);
if (silent != null) {
if (silent.equals("true") || silent.equals("True")
|| (silentInt != null && silentInt != 0)) {
return false; // "silent" will stop printing, otherwise go look at "verbose_eval"
}
}

Object verboseEval = params.get("verbose_eval");
Integer verboseEvalInt = tryGetIntFromObject(verboseEval);
if (verboseEval == null) {
return true; // Default to printing evalInfo
} else if (verboseEval.equals("false") || verboseEval.equals("False")) {
return false;
} else if (verboseEvalInt != null) {
if (verboseEvalInt == 0) {
return false;
} else {
return iter % verboseEvalInt == 0;
}
} else {
return true; // Don't understand the option, default to printing
}
}

static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
return iter - bestIteration >= earlyStoppingRounds;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class ScalaBoosterImplSuite extends FunSuite {

test("cross validation") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
val params = List("eta" -> "1.0", "max_depth" -> "3", "silent" -> "1", "nthread" -> "6",
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
Expand Down

0 comments on commit b374e0a

Please sign in to comment.