Skip to content

Commit

Permalink
Make seed independent of num.threads and add legacy option (#1447)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikcs authored Sep 25, 2024
1 parent 8b08d83 commit 7bffa33
Show file tree
Hide file tree
Showing 42 changed files with 2,245 additions and 2,104 deletions.
8 changes: 3 additions & 5 deletions REFERENCE.md
Original file line number Diff line number Diff line change
Expand Up @@ -403,13 +403,11 @@ While the algorithm in `regression_forest` is very similar to that of classic ra

Overall, GRF is designed to produce the same estimates across platforms when using a consistent value for the random seed through the training option seed. However, there are still some cases where GRF can produce different estimates across platforms. When it comes to cross-platform predictions, the output of GRF will depend on a few factors beyond the forest seed.

One such factor is the compiler that was used to build GRF. Different compilers may have different default behavior around floating-point rounding, and these could lead to slightly different forest splits if the data requires numerical precision. Another factor is how the forest construction is distributed across different threads. Right now, our forest splitting algorithm can give different results depending on the number of threads that were used to build the forest.
One such factor is the compiler that was used to build GRF. Different compilers may have different default behavior around floating-point behavior and instruction optimizations, and these could lead to slightly different forest splits if the data requires numerical precision. In addition to setting the seed argument, rounding all input data to at most 8 significant digits may help.

Therefore, in order to ensure consistent results, we provide the following recommendations.
- Make sure arguments `seed` and `num.threads` are the same across platforms
- Round data to 8 significant digits
Even though the compiler is the same, different CPU architectures may produce slightly different output. One such example is GRF compiled with clang and run on x86 (Intel) vs. ARM (Apple Silicon).

Also, please note that we have not done extensive testing on Windows platforms, although we do not expect random number generation issues there to be different from Linux/Mac. Regardless of the platform, if results are still not consistent please help us by submitting a Github issue.
Prior to GRF version 2.4.0, another factor was how the forest construction was distributed across different threads. In these versions, our forest splitting algorithm can give different results depending on the number of threads used to build the forest, meaning that the num.threads argument had to be the same for cross-platform reproducibility. To restore this behavior in current versions of GRF, you can set the global R option `options(grf.legacy.seed=TRUE)` and exactly recover results produced with past versions of the package.


## References
Expand Down
8 changes: 7 additions & 1 deletion core/src/forest/ForestOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ ForestOptions::ForestOptions(uint num_trees,
double imbalance_penalty,
uint num_threads,
uint random_seed,
bool legacy_seed,
const std::vector<size_t>& sample_clusters,
uint samples_per_cluster):
ci_group_size(ci_group_size),
sample_fraction(sample_fraction),
tree_options(mtry, min_node_size, honesty, honesty_fraction, honesty_prune_leaves, alpha, imbalance_penalty),
sampling_options(samples_per_cluster, sample_clusters),
random_seed(random_seed) {
random_seed(random_seed),
legacy_seed(legacy_seed) {

this->num_threads = validate_num_threads(num_threads);

Expand Down Expand Up @@ -85,6 +87,10 @@ uint ForestOptions::get_random_seed() const {
return random_seed;
}

bool ForestOptions::get_legacy_seed() const {
return legacy_seed;
}

uint ForestOptions::validate_num_threads(uint num_threads) {
if (num_threads == DEFAULT_NUM_THREADS) {
return std::thread::hardware_concurrency();
Expand Down
4 changes: 4 additions & 0 deletions core/src/forest/ForestOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ForestOptions {
double imbalance_penalty,
uint num_threads,
uint random_seed,
bool legacy_seed,
const std::vector<size_t>& sample_clusters,
uint samples_per_cluster);

Expand All @@ -55,6 +56,8 @@ class ForestOptions {

uint get_num_threads() const;
uint get_random_seed() const;
// Toggle between seed and num_threads dependence to reproduce behavior prior to grf 2.4.0.
bool get_legacy_seed() const;

private:
uint num_trees;
Expand All @@ -66,6 +69,7 @@ class ForestOptions {

uint num_threads;
uint random_seed;
bool legacy_seed;
};

} // namespace grf
Expand Down
7 changes: 6 additions & 1 deletion core/src/forest/ForestTrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,12 @@ std::vector<std::unique_ptr<Tree>> ForestTrainer::train_batch(
trees.reserve(num_trees * ci_group_size);

for (size_t i = 0; i < num_trees; i++) {
uint tree_seed = udist(random_number_generator);
uint tree_seed;
if (options.get_legacy_seed()) {
tree_seed = udist(random_number_generator);
} else {
tree_seed = static_cast<uint>(options.get_random_seed() + start + i);
}
RandomSampler sampler(tree_seed, options.get_sampling_options());

if (ci_group_size == 1) {
Expand Down
2 changes: 1 addition & 1 deletion core/test/forest/ForestSmokeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TEST_CASE("forests don't crash when there are fewer trees than threads", "[fores
uint samples_per_cluster = 0;

ForestOptions options(num_trees, ci_group_size, sample_fraction, mtry, min_node_size, honesty, honesty_fraction,
prune, alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
prune, alpha, imbalance_penalty, num_threads, seed, true, empty_clusters, samples_per_cluster);

Forest forest = trainer.train(data, options);
ForestPredictor predictor = regression_predictor(4);
Expand Down
4 changes: 2 additions & 2 deletions core/test/forest/LocalLinearForestTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ TEST_CASE("LLF gives reasonable prediction on friedman data", "[local linear], [
ForestOptions options (
num_trees, ci_group_size, sample_fraction,
mtry, min_node_size, honesty, honesty_fraction, prune,
alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
alpha, imbalance_penalty, num_threads, seed, true, empty_clusters, samples_per_cluster);
ForestTrainer trainer = regression_trainer();
Forest forest = trainer.train(data, options);

Expand Down Expand Up @@ -136,7 +136,7 @@ TEST_CASE("local linear forests give reasonable variance estimates", "[regressio
ForestOptions options (
num_trees, ci_group_size, sample_fraction,
mtry, min_node_size, honesty, honesty_fraction, prune,
alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
alpha, imbalance_penalty, num_threads, seed, true, empty_clusters, samples_per_cluster);
ForestTrainer trainer = regression_trainer();
Forest forest = trainer.train(data, options);

Expand Down
3 changes: 2 additions & 1 deletion core/test/utilities/ForestTestUtilities.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ ForestOptions ForestTestUtilities::default_options(bool honesty,
uint samples_per_cluster = 0;
uint num_threads = 4;
uint seed = 42;
bool legacy_seed = true;

return ForestOptions(num_trees,
ci_group_size, sample_fraction, mtry, min_node_size, honesty, honesty_fraction,
prune, alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
prune, alpha, imbalance_penalty, num_threads, seed, legacy_seed, empty_clusters, samples_per_cluster);
}
2 changes: 1 addition & 1 deletion r-package/grf/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Imports:
methods,
Rcpp (>= 0.12.15),
sandwich (>= 2.4-0)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
Suggests:
DiagrammeR,
MASS,
Expand Down
1 change: 1 addition & 0 deletions r-package/grf/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export(get_leaf_node)
export(get_sample_weights)
export(get_scores)
export(get_tree)
export(grf_options)
export(instrumental_forest)
export(ll_regression_forest)
export(lm_forest)
Expand Down
Loading

0 comments on commit 7bffa33

Please sign in to comment.