Commit
WIP: First implementation of RFNN with sklearn RF estimators
This commit implements RFNN using only `RandomForestRegressor` objects from `scikit-learn`, mirroring the previous implementation built on `rpy2` and R's `randomForest` package. The goal is to show that results from RF-NN correspond closely regardless of the random forest implementation. At present, however, the tests are failing. The best guess is that the two random forest implementations and their hyperparameters differ enough that we can't expect the same trees to be built in each forest. Testing predicted values will not work until neighbors and distances correspond closely between the two implementations. So far, neighbors for the training data roughly correspond between the two implementations but often don't share the same rank. We have tried Spearman's rank correlation and the footrule distance to measure the difference in neighbor position, but do not yet have a good standard for correctness.
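The rank comparison described in the commit message can be sketched as follows. This is a minimal illustration, not the code used in the tests: the function name `compare_rankings` is hypothetical, and it assumes both implementations return the same set of neighbor IDs for a sample, differing only in order. The footrule distance is divided by its maximum possible value, floor(n^2 / 2), so that 0 means identical order and 1 means maximally displaced.

```python
import numpy as np


def compare_rankings(neighbors_a, neighbors_b):
    """
    Compare two orderings of the same neighbor IDs (hypothetical helper).

    Returns (spearman_rho, normalized_footrule).
    Assumes both lists hold exactly the same IDs, with no repeats.
    """
    n = len(neighbors_a)
    if n < 2 or len(neighbors_b) != n:
        raise ValueError("need two rankings of equal length >= 2")
    if len(set(neighbors_a)) != n or set(neighbors_a) != set(neighbors_b):
        raise ValueError("rankings must contain the same unique IDs")

    # Position of each ID in the second ranking.
    pos_b = {nid: pos for pos, nid in enumerate(neighbors_b)}

    # Per-ID difference in rank position between the two rankings.
    d = np.array(
        [pos_a - pos_b[nid] for pos_a, nid in enumerate(neighbors_a)],
        dtype=float,
    )

    # Spearman's rho for rankings without ties.
    rho = 1.0 - 6.0 * np.sum(d**2) / (n * (n**2 - 1))

    # Spearman's footrule: sum of absolute displacements, scaled to [0, 1].
    footrule = np.sum(np.abs(d)) / ((n * n) // 2)

    return float(rho), float(footrule)
```

Neither statistic sets a pass/fail threshold on its own, which is the open question noted above. If the two implementations return different sets of neighbors, the rankings would first need to be restricted to the shared IDs, or compared with a measure designed for partial lists.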
Showing 4 changed files with 134 additions and 8 deletions.
@@ -0,0 +1,19 @@
from sklearn.ensemble import RandomForestRegressor


def sklearn_get_forest(X, y, n_tree, mt):
    """
    Train a random forest regression model in sklearn.
    """
    rf = RandomForestRegressor(
        n_estimators=n_tree, max_features=mt, random_state=42, min_samples_leaf=5
    )
    rf.fit(X, y)
    return rf


def sklearn_get_nodeset(rf, X):
    """
    Get the nodes associated with X of the random forest regression model in sklearn.
    """
    return rf.apply(X)
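
The node matrix returned by `rf.apply(X)` has shape `(n_samples, n_estimators)`, one leaf index per sample per tree. The sketch below shows one way such a matrix could be turned into neighbors and distances. It is an assumption for illustration, not the method this commit implements: the function name `proximity_neighbors` is hypothetical, and the distance is taken to be one minus the fraction of trees in which two samples land in the same leaf.

```python
import numpy as np


def proximity_neighbors(train_nodes, query_nodes, k):
    """
    Rank training samples by how often they share a leaf with each query.

    train_nodes: (n_train, n_trees) leaf indices, e.g. from rf.apply(X_train)
    query_nodes: (n_query, n_trees) leaf indices, e.g. from rf.apply(X_query)
    Returns (indices, distances), each of shape (n_query, k).
    """
    train_nodes = np.asarray(train_nodes)
    query_nodes = np.asarray(query_nodes)
    n_trees = train_nodes.shape[1]

    # shared[i, j] = number of trees where query i and training sample j
    # fall into the same leaf.
    shared = (query_nodes[:, None, :] == train_nodes[None, :, :]).sum(axis=2)

    # Assumed distance: 1 - fraction of trees with a shared leaf.
    distances = 1.0 - shared / n_trees

    # A stable sort keeps tied neighbors in training order.
    order = np.argsort(distances, axis=1, kind="stable")[:, :k]
    return order, np.take_along_axis(distances, order, axis=1)
```

With a forest from `sklearn_get_forest`, a call might look like `proximity_neighbors(sklearn_get_nodeset(rf, X), sklearn_get_nodeset(rf, X), k=5)`. Under this definition a distance can only take `n_trees + 1` distinct values, so ties are common, and ties resolved differently could be one source of rank disagreement between implementations.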