Skip to content

Commit

Permalink
enable quantile regression on RangerForestRegressor (#47)
Browse files Browse the repository at this point in the history
* add matplotlib to dev

* add quantile regression

* update docs

* add default to docs
  • Loading branch information
crflynn authored Oct 28, 2020
1 parent b08e005 commit 26ffe86
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 21 deletions.
33 changes: 26 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestClassifier
X, y = load_iris(True)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
rfc = RangerForestClassifier()
Expand All @@ -72,27 +72,46 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla
RangerForestRegressor
~~~~~~~~~~~~~~~~~~~~~

The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class.
The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class. It also supports quantile regression using the ``predict_quantiles`` method.

.. code-block:: python
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestRegressor
X, y = load_boston(True)
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
rfr = RangerForestRegressor()
rfr.fit(X_train, y_train)
predictions = rfr.predict(X_test)
print(predictions)
# [20.01270808 24.65041667 11.97722067 20.10345 26.48676667 42.19045952
# 19.821 31.51163333 8.34169603 18.94511667 20.21901915 16.01440705
# [18.39205325 21.41698333 14.29509221 35.34981667 27.64378333 20.98569135
# 21.15996673 14.0288093 9.44657947 29.99185 19.3774 11.88189465
# ...
# 18.37752952 19.34765 20.13355 21.19648333 18.91611667 15.58964837
# 31.4223 ]
# 11.08502822 36.80993636 18.29633154 12.90448354 20.94311667 11.45154934
# 41.44466667]
# enable quantile regression on instantiation
rfr = RangerForestRegressor(quantiles=True)
rfr.fit(X_train, y_train)
quantile_lower = rfr.predict_quantiles(X_test, quantiles=[0.1])
print(quantile_lower)
# [12.9 17. 8. 28. 22. 10.9 7. 8. 5. 20.8 16.9 7. 8. 18.
# 22. 19. 29. 21. 19. 19. 22. 10.9 20. 16. 14. 20. 9.8 22.9
# ...
# 16. 17. 12. 20. 13. 26. 19. 21.9 7. 14.9 13. 8. 17.9 7.9
# 29. ]
quantile_upper = rfr.predict_quantiles(X_test, quantiles=[0.9])
print(quantile_upper)
# [23. 27. 21. 44. 32.1 50. 50. 18.2 12. 43. 22. 17. 17. 24.
# 31.1 25. 37. 28. 23. 24. 28. 18. 28. 23. 23. 26. 17.1 43.
# ...
# 22. 24. 20. 28. 18. 44.2 24. 33.4 15.1 50. 21. 17. 25. 13.
# 50. ]
RangerForestSurvival
Expand Down
33 changes: 26 additions & 7 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestClassifier
X, y = load_iris(True)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
rfc = RangerForestClassifier()
Expand All @@ -84,27 +84,46 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla
RangerForestRegressor
~~~~~~~~~~~~~~~~~~~~~

The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class.
The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class. It also supports quantile regression using the ``predict_quantiles`` method.

.. code-block:: python
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestRegressor
X, y = load_boston(True)
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
rfr = RangerForestRegressor()
rfr.fit(X_train, y_train)
predictions = rfr.predict(X_test)
print(predictions)
# [20.01270808 24.65041667 11.97722067 20.10345 26.48676667 42.19045952
# 19.821 31.51163333 8.34169603 18.94511667 20.21901915 16.01440705
# [18.39205325 21.41698333 14.29509221 35.34981667 27.64378333 20.98569135
# 21.15996673 14.0288093 9.44657947 29.99185 19.3774 11.88189465
# ...
# 18.37752952 19.34765 20.13355 21.19648333 18.91611667 15.58964837
# 31.4223 ]
# 11.08502822 36.80993636 18.29633154 12.90448354 20.94311667 11.45154934
# 41.44466667]
# enable quantile regression on instantiation
rfr = RangerForestRegressor(quantiles=True)
rfr.fit(X_train, y_train)
quantile_lower = rfr.predict_quantiles(X_test, quantiles=[0.1])
print(quantile_lower)
# [12.9 17. 8. 28. 22. 10.9 7. 8. 5. 20.8 16.9 7. 8. 18.
# 22. 19. 29. 21. 19. 19. 22. 10.9 20. 16. 14. 20. 9.8 22.9
# ...
# 16. 17. 12. 20. 13. 26. 19. 21.9 7. 14.9 13. 8. 17.9 7.9
# 29. ]
quantile_upper = rfr.predict_quantiles(X_test, quantiles=[0.9])
print(quantile_upper)
# [23. 27. 21. 44. 32.1 50. 50. 18.2 12. 43. 22. 17. 17. 24.
# 31.1 25. 37. 28. 23. 24. 28. 18. 28. 23. 23. 26. 17.1 43.
# ...
# 22. 24. 20. 28. 18. 44.2 24. 33.4 15.1 50. 21. 17. 25. 13.
# 50. ]
RangerForestSurvival
Expand Down
6 changes: 5 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ chardet==3.0.4
click==7.1.2
colorama==0.4.3; python_version >= "3.3" and sys_platform == "win32" or sys_platform == "win32"
coverage==5.2
cycler==0.10.0
cython==3.0a5
decorator==4.4.2
defusedxml==0.6.0
Expand All @@ -33,7 +34,9 @@ jupyter==1.0.0
jupyter-client==6.1.5
jupyter-console==6.1.0
jupyter-core==4.6.3
kiwisolver==1.2.0
markupsafe==1.1.1
matplotlib==3.3.2
mistune==0.8.4
more-itertools==8.4.0
nbconvert==5.6.1
Expand All @@ -47,10 +50,11 @@ parso==0.7.0
pathspec==0.8.0
pexpect==4.8.0; python_version >= "3.3" and sys_platform != "win32" or sys_platform != "win32"
pickleshare==0.7.5
pillow==8.0.1
pluggy==0.13.1
prometheus-client==0.8.0
prompt-toolkit==3.0.5
ptyprocess==0.6.0; sys_platform != "win32" or os_name != "nt" or python_version >= "3.3" and sys_platform != "win32"
ptyprocess==0.6.0; python_version >= "3.3" and sys_platform != "win32" or sys_platform != "win32" or os_name != "nt" or python_version >= "3.3" and sys_platform != "win32" and (python_version >= "3.3" and sys_platform != "win32" or sys_platform != "win32")
py==1.9.0
pygments==2.6.1
pyparsing==2.4.7
Expand Down
Loading

0 comments on commit 26ffe86

Please sign in to comment.