From 26ffe863b45a28e547a54967ee23adf548c1b369 Mon Sep 17 00:00:00 2001 From: Flynn Date: Wed, 28 Oct 2020 17:21:48 -0400 Subject: [PATCH] enable quantile regression on RangerForestRegressor (#47) * add matplotlib to dev * add quantile regression * update docs * add default to docs --- README.rst | 33 ++++- docs/index.rst | 33 ++++- docs/requirements.txt | 6 +- poetry.lock | 120 +++++++++++++++++- pyproject.toml | 1 + skranger/ensemble/ranger_forest_regressor.py | 102 ++++++++++++++- .../ensemble/test_ranger_forest_regressor.py | 18 +++ 7 files changed, 292 insertions(+), 21 deletions(-) diff --git a/README.rst b/README.rst index 6dd069e..fefb688 100644 --- a/README.rst +++ b/README.rst @@ -50,7 +50,7 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla from sklearn.model_selection import train_test_split from skranger.ensemble import RangerForestClassifier - X, y = load_iris(True) + X, y = load_iris(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y) rfc = RangerForestClassifier() @@ -72,7 +72,7 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla RangerForestRegressor ~~~~~~~~~~~~~~~~~~~~~ -The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class. +The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class. It also supports quantile regression using the ``predict_quantiles`` method. .. code-block:: python @@ -80,7 +80,7 @@ The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class from sklearn.model_selection import train_test_split from skranger.ensemble import RangerForestRegressor - X, y = load_boston(True) + X, y = load_boston(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y) rfr = RangerForestRegressor() @@ -88,11 +88,30 @@ The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class predictions = rfr.predict(X_test) print(predictions) - # [20.01270808 24.65041667 11.97722067 20.10345 26.48676667 42.19045952 - # 19.821 31.51163333 8.34169603 18.94511667 20.21901915 16.01440705 + # [18.39205325 21.41698333 14.29509221 35.34981667 27.64378333 20.98569135 + # 21.15996673 14.0288093 9.44657947 29.99185 19.3774 11.88189465 # ... - # 18.37752952 19.34765 20.13355 21.19648333 18.91611667 15.58964837 - # 31.4223 ] + # 11.08502822 36.80993636 18.29633154 12.90448354 20.94311667 11.45154934 + # 41.44466667] + + # enable quantile regression on instantiation + rfr = RangerForestRegressor(quantiles=True) + rfr.fit(X_train, y_train) + + quantile_lower = rfr.predict_quantiles(X_test, quantiles=[0.1]) + print(quantile_lower) + # [12.9 17. 8. 28. 22. 10.9 7. 8. 5. 20.8 16.9 7. 8. 18. + # 22. 19. 29. 21. 19. 19. 22. 10.9 20. 16. 14. 20. 9.8 22.9 + # ... + # 16. 17. 12. 20. 13. 26. 19. 21.9 7. 14.9 13. 8. 17.9 7.9 + # 29. ] + quantile_upper = rfr.predict_quantiles(X_test, quantiles=[0.9]) + print(quantile_upper) + # [23. 27. 21. 44. 32.1 50. 50. 18.2 12. 43. 22. 17. 17. 24. + # 31.1 25. 37. 28. 23. 24. 28. 18. 28. 23. 23. 26. 17.1 43. + # ... + # 22. 24. 20. 28. 18. 44.2 24. 33.4 15.1 50. 21. 17. 25. 13. + # 50. ] RangerForestSurvival diff --git a/docs/index.rst b/docs/index.rst index 6654ad6..747c9b5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -62,7 +62,7 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla from sklearn.model_selection import train_test_split from skranger.ensemble import RangerForestClassifier - X, y = load_iris(True) + X, y = load_iris(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y) rfc = RangerForestClassifier() @@ -84,7 +84,7 @@ The ``RangerForestClassifier`` predictor uses ``ranger``'s ForestProbability cla RangerForestRegressor ~~~~~~~~~~~~~~~~~~~~~ -The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class. +The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class. It also supports quantile regression using the ``predict_quantiles`` method. .. code-block:: python @@ -92,7 +92,7 @@ The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class from sklearn.model_selection import train_test_split from skranger.ensemble import RangerForestRegressor - X, y = load_boston(True) + X, y = load_boston(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y) rfr = RangerForestRegressor() @@ -100,11 +100,30 @@ The ``RangerForestRegressor`` predictor uses ``ranger``'s ForestRegression class predictions = rfr.predict(X_test) print(predictions) - # [20.01270808 24.65041667 11.97722067 20.10345 26.48676667 42.19045952 - # 19.821 31.51163333 8.34169603 18.94511667 20.21901915 16.01440705 + # [18.39205325 21.41698333 14.29509221 35.34981667 27.64378333 20.98569135 + # 21.15996673 14.0288093 9.44657947 29.99185 19.3774 11.88189465 # ... - # 18.37752952 19.34765 20.13355 21.19648333 18.91611667 15.58964837 - # 31.4223 ] + # 11.08502822 36.80993636 18.29633154 12.90448354 20.94311667 11.45154934 + # 41.44466667] + + # enable quantile regression on instantiation + rfr = RangerForestRegressor(quantiles=True) + rfr.fit(X_train, y_train) + + quantile_lower = rfr.predict_quantiles(X_test, quantiles=[0.1]) + print(quantile_lower) + # [12.9 17. 8. 28. 22. 10.9 7. 8. 5. 20.8 16.9 7. 8. 18. + # 22. 19. 29. 21. 19. 19. 22. 10.9 20. 16. 14. 20. 9.8 22.9 + # ... + # 16. 17. 12. 20. 13. 26. 19. 21.9 7. 14.9 13. 8. 17.9 7.9 + # 29. ] + quantile_upper = rfr.predict_quantiles(X_test, quantiles=[0.9]) + print(quantile_upper) + # [23. 27. 21. 44. 32.1 50. 50. 18.2 12. 43. 22. 17. 17. 24. + # 31.1 25. 37. 28. 23. 24. 28. 18. 28. 23. 23. 26. 17.1 43. + # ... + # 22. 24. 20. 28. 18. 44.2 24. 33.4 15.1 50. 21. 17. 25. 13. + # 50. ] RangerForestSurvival diff --git a/docs/requirements.txt b/docs/requirements.txt index 76b86f8..f796b2f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -12,6 +12,7 @@ chardet==3.0.4 click==7.1.2 colorama==0.4.3; python_version >= "3.3" and sys_platform == "win32" or sys_platform == "win32" coverage==5.2 +cycler==0.10.0 cython==3.0a5 decorator==4.4.2 defusedxml==0.6.0 @@ -33,7 +34,9 @@ jupyter==1.0.0 jupyter-client==6.1.5 jupyter-console==6.1.0 jupyter-core==4.6.3 +kiwisolver==1.2.0 markupsafe==1.1.1 +matplotlib==3.3.2 mistune==0.8.4 more-itertools==8.4.0 nbconvert==5.6.1 @@ -47,10 +50,11 @@ parso==0.7.0 pathspec==0.8.0 pexpect==4.8.0; python_version >= "3.3" and sys_platform != "win32" or sys_platform != "win32" pickleshare==0.7.5 +pillow==8.0.1 pluggy==0.13.1 prometheus-client==0.8.0 prompt-toolkit==3.0.5 -ptyprocess==0.6.0; sys_platform != "win32" or os_name != "nt" or python_version >= "3.3" and sys_platform != "win32" +ptyprocess==0.6.0; python_version >= "3.3" and sys_platform != "win32" or sys_platform != "win32" or os_name != "nt" or python_version >= "3.3" and sys_platform != "win32" and (python_version >= "3.3" and sys_platform != "win32" or sys_platform != "win32") py==1.9.0 pygments==2.6.1 pyparsing==2.4.7 diff --git a/poetry.lock b/poetry.lock index 579f4f6..a6d2f1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -142,6 +142,17 @@ version = "5.2" [package.extras] toml = ["toml"] +[[package]] +category = "dev" +description = "Composable style cycles" +name = "cycler" +optional = false +python-versions = "*" +version = "0.10.0" + +[package.dependencies] +six = "*" + [[package]] category = "dev" description = "The Cython compiler for writing C extensions for the Python language." @@ -430,6 +441,14 @@ version = "4.6.3" pywin32 = ">=1.0" traitlets = "*" +[[package]] +category = "dev" +description = "A fast implementation of the Cassowary constraint solver" +name = "kiwisolver" +optional = false +python-versions = ">=3.6" +version = "1.2.0" + [[package]] category = "dev" description = "Safely add untrusted strings to HTML/XML markup." @@ -438,6 +457,23 @@ optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*" version = "1.1.1" +[[package]] +category = "dev" +description = "Python plotting package" +name = "matplotlib" +optional = false +python-versions = ">=3.6" +version = "3.3.2" + +[package.dependencies] +certifi = ">=2020.06.20" +cycler = ">=0.10" +kiwisolver = ">=1.0.1" +numpy = ">=1.15" +pillow = ">=6.2.0" +pyparsing = ">=2.0.3,<2.0.4 || >2.0.4,<2.1.2 || >2.1.2,<2.1.6 || >2.1.6" +python-dateutil = ">=2.1" + [[package]] category = "dev" description = "The fastest markdown parser in pure Python" @@ -608,6 +644,14 @@ optional = false python-versions = "*" version = "0.7.5" +[[package]] +category = "dev" +description = "Python Imaging Library (Fork)" +name = "pillow" +optional = false +python-versions = ">=3.6" +version = "8.0.1" + [[package]] category = "dev" description = "plugin and hook calling mechanisms for python" @@ -649,7 +693,7 @@ wcwidth = "*" [[package]] category = "dev" description = "Run a subprocess in a pseudo terminal" -marker = "sys_platform != \"win32\" or os_name != \"nt\" or python_version >= \"3.3\" and sys_platform != \"win32\"" +marker = "python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\" or os_name != \"nt\" or python_version >= \"3.3\" and sys_platform != \"win32\" and (python_version >= \"3.3\" and sys_platform != \"win32\" or sys_platform != \"win32\")" name = "ptyprocess" optional = false python-versions = "*" @@ -1132,7 +1176,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "303424c09288db72ab8423dc3bf973d7d2a3738e7a49d4062c02d1f5dda237dd" +content-hash = "3844ef91dc055c860b490bbc53d0bfba20a11e9c7596aca12908780cbc5c2401" python-versions = "^3.6.1" [metadata.files] @@ -1224,6 +1268,10 @@ coverage = [ {file = "coverage-5.2-cp39-cp39-win_amd64.whl", hash = "sha256:10f2a618a6e75adf64329f828a6a5b40244c1c50f5ef4ce4109e904e69c71bd2"}, {file = "coverage-5.2.tar.gz", hash = "sha256:1874bdc943654ba46d28f179c1846f5710eda3aeb265ff029e0ac2b52daae404"}, ] +cycler = [ + {file = "cycler-0.10.0-py2.py3-none-any.whl", hash = "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d"}, + {file = "cycler-0.10.0.tar.gz", hash = "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8"}, +] cython = [ {file = "Cython-3.0a5-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c06b76bb6e0c90129f1428049db0efa7f2bdf9d18861415fa514e66f12c182a0"}, {file = "Cython-3.0a5-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:10272614a25735edd5f8c29c44d9cb7f9683183b451dc84ba487590bdcc82bc2"}, @@ -1335,6 +1383,24 @@ jupyter-core = [ {file = "jupyter_core-4.6.3-py2.py3-none-any.whl", hash = "sha256:a4ee613c060fe5697d913416fc9d553599c05e4492d58fac1192c9a6844abb21"}, {file = "jupyter_core-4.6.3.tar.gz", hash = "sha256:394fd5dd787e7c8861741880bdf8a00ce39f95de5d18e579c74b882522219e7e"}, ] +kiwisolver = [ + {file = "kiwisolver-1.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:443c2320520eda0a5b930b2725b26f6175ca4453c61f739fef7a5847bd262f74"}, + {file = "kiwisolver-1.2.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:efcf3397ae1e3c3a4a0a0636542bcad5adad3b1dd3e8e629d0b6e201347176c8"}, + {file = "kiwisolver-1.2.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:fccefc0d36a38c57b7bd233a9b485e2f1eb71903ca7ad7adacad6c28a56d62d2"}, + {file = "kiwisolver-1.2.0-cp36-none-win32.whl", hash = "sha256:60a78858580761fe611d22127868f3dc9f98871e6fdf0a15cc4203ed9ba6179b"}, + {file = "kiwisolver-1.2.0-cp36-none-win_amd64.whl", hash = "sha256:556da0a5f60f6486ec4969abbc1dd83cf9b5c2deadc8288508e55c0f5f87d29c"}, + {file = "kiwisolver-1.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7cc095a4661bdd8a5742aaf7c10ea9fac142d76ff1770a0f84394038126d8fc7"}, + {file = "kiwisolver-1.2.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:c955791d80e464da3b471ab41eb65cf5a40c15ce9b001fdc5bbc241170de58ec"}, + {file = "kiwisolver-1.2.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:603162139684ee56bcd57acc74035fceed7dd8d732f38c0959c8bd157f913fec"}, + {file = "kiwisolver-1.2.0-cp37-none-win32.whl", hash = "sha256:03662cbd3e6729f341a97dd2690b271e51a67a68322affab12a5b011344b973c"}, + {file = "kiwisolver-1.2.0-cp37-none-win_amd64.whl", hash = "sha256:4eadb361baf3069f278b055e3bb53fa189cea2fd02cb2c353b7a99ebb4477ef1"}, + {file = "kiwisolver-1.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c31bc3c8e903d60a1ea31a754c72559398d91b5929fcb329b1c3a3d3f6e72113"}, + {file = "kiwisolver-1.2.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:d52b989dc23cdaa92582ceb4af8d5bcc94d74b2c3e64cd6785558ec6a879793e"}, + {file = "kiwisolver-1.2.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:e586b28354d7b6584d8973656a7954b1c69c93f708c0c07b77884f91640b7657"}, + {file = "kiwisolver-1.2.0-cp38-none-win32.whl", hash = "sha256:d069ef4b20b1e6b19f790d00097a5d5d2c50871b66d10075dab78938dc2ee2cf"}, + {file = "kiwisolver-1.2.0-cp38-none-win_amd64.whl", hash = "sha256:18d749f3e56c0480dccd1714230da0f328e6e4accf188dd4e6884bdd06bf02dd"}, + {file = "kiwisolver-1.2.0.tar.gz", hash = "sha256:247800260cd38160c362d211dcaf4ed0f7816afb5efe56544748b21d6ad6d17f"}, +] markupsafe = [ {file = "MarkupSafe-1.1.1-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161"}, {file = "MarkupSafe-1.1.1-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7"}, @@ -1370,6 +1436,26 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] +matplotlib = [ + {file = "matplotlib-3.3.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:27f9de4784ae6fb97679556c5542cf36c0751dccb4d6407f7c62517fa2078868"}, + {file = "matplotlib-3.3.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:06866c138d81a593b535d037b2727bec9b0818cadfe6a81f6ec5715b8dd38a89"}, + {file = "matplotlib-3.3.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:5ccecb5f78b51b885f0028b646786889f49c54883e554fca41a2a05998063f23"}, + {file = "matplotlib-3.3.2-cp36-cp36m-win32.whl", hash = "sha256:69cf76d673682140f46c6cb5e073332c1f1b2853c748dc1cb04f7d00023567f7"}, + {file = "matplotlib-3.3.2-cp36-cp36m-win_amd64.whl", hash = "sha256:371518c769d84af8ec9b7dcb871ac44f7a67ef126dd3a15c88c25458e6b6d205"}, + {file = "matplotlib-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:793e061054662aa27acaff9201cdd510a698541c6e8659eeceb31d66c16facc6"}, + {file = "matplotlib-3.3.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16b241c3d17be786966495229714de37de04472da472277869b8d5b456a8df00"}, + {file = "matplotlib-3.3.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:3fb0409754b26f48045bacd6818e44e38ca9338089f8ba689e2f9344ff2847c7"}, + {file = "matplotlib-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:548cfe81476dbac44db96e9c0b074b6fb333b4d1f12b1ae68dbed47e45166384"}, + {file = "matplotlib-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:f0268613073df055bcc6a490de733012f2cf4fe191c1adb74e41cec8add1a165"}, + {file = "matplotlib-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:57be9e21073fc367237b03ecac0d9e4b8ddbe38e86ec4a316857d8d93ac9286c"}, + {file = "matplotlib-3.3.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:be2f0ec62e0939a9dcfd3638c140c5a74fc929ee3fd1f31408ab8633db6e1523"}, + {file = "matplotlib-3.3.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:c5d0c2ae3e3ed4e9f46b7c03b40d443601012ffe8eb8dfbb2bd6b2d00509f797"}, + {file = "matplotlib-3.3.2-cp38-cp38-win32.whl", hash = "sha256:a522de31e07ed7d6f954cda3fbd5ca4b8edbfc592a821a7b00291be6f843292e"}, + {file = "matplotlib-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:8bc1d3284dee001f41ec98f59675f4d723683e1cc082830b440b5f081d8e0ade"}, + {file = "matplotlib-3.3.2-pp36-pypy36_pp73-macosx_10_9_x86_64.whl", hash = "sha256:799c421bc245a0749c1515b6dea6dc02db0a8c1f42446a0f03b3b82a60a900dc"}, + {file = "matplotlib-3.3.2-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:2f5eefc17dc2a71318d5a3496313be5c351c0731e8c4c6182c9ac3782cfc4076"}, + {file = "matplotlib-3.3.2.tar.gz", hash = "sha256:3d2edbf59367f03cd9daf42939ca06383a7d7803e3993eb5ff1bee8e8a3fbb6b"}, +] mistune = [ {file = "mistune-0.8.4-py2.py3-none-any.whl", hash = "sha256:88a1051873018da288eee8538d476dffe1262495144b33ecb586c4ab266bb8d4"}, {file = "mistune-0.8.4.tar.gz", hash = "sha256:59a3429db53c50b5c6bcc8a07f8848cb00d7dc8bdb431a4ab41920d201d4756e"}, @@ -1459,6 +1545,36 @@ pickleshare = [ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, ] +pillow = [ + {file = "Pillow-8.0.1-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:b63d4ff734263ae4ce6593798bcfee6dbfb00523c82753a3a03cbc05555a9cc3"}, + {file = "Pillow-8.0.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:5f9403af9c790cc18411ea398a6950ee2def2a830ad0cfe6dc9122e6d528b302"}, + {file = "Pillow-8.0.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:6b4a8fd632b4ebee28282a9fef4c341835a1aa8671e2770b6f89adc8e8c2703c"}, + {file = "Pillow-8.0.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:cc3ea6b23954da84dbee8025c616040d9aa5eaf34ea6895a0a762ee9d3e12e11"}, + {file = "Pillow-8.0.1-cp36-cp36m-win32.whl", hash = "sha256:d8a96747df78cda35980905bf26e72960cba6d355ace4780d4bdde3b217cdf1e"}, + {file = "Pillow-8.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:7ba0ba61252ab23052e642abdb17fd08fdcfdbbf3b74c969a30c58ac1ade7cd3"}, + {file = "Pillow-8.0.1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:795e91a60f291e75de2e20e6bdd67770f793c8605b553cb6e4387ce0cb302e09"}, + {file = "Pillow-8.0.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:0a2e8d03787ec7ad71dc18aec9367c946ef8ef50e1e78c71f743bc3a770f9fae"}, + {file = "Pillow-8.0.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:006de60d7580d81f4a1a7e9f0173dc90a932e3905cc4d47ea909bc946302311a"}, + {file = "Pillow-8.0.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bd7bf289e05470b1bc74889d1466d9ad4a56d201f24397557b6f65c24a6844b8"}, + {file = "Pillow-8.0.1-cp37-cp37m-win32.whl", hash = "sha256:95edb1ed513e68bddc2aee3de66ceaf743590bf16c023fb9977adc4be15bd3f0"}, + {file = "Pillow-8.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:e38d58d9138ef972fceb7aeec4be02e3f01d383723965bfcef14d174c8ccd039"}, + {file = "Pillow-8.0.1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:d3d07c86d4efa1facdf32aa878bd508c0dc4f87c48125cc16b937baa4e5b5e11"}, + {file = "Pillow-8.0.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:fbd922f702582cb0d71ef94442bfca57624352622d75e3be7a1e7e9360b07e72"}, + {file = "Pillow-8.0.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:92c882b70a40c79de9f5294dc99390671e07fc0b0113d472cbea3fde15db1792"}, + {file = "Pillow-8.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:7c9401e68730d6c4245b8e361d3d13e1035cbc94db86b49dc7da8bec235d0015"}, + {file = "Pillow-8.0.1-cp38-cp38-win32.whl", hash = "sha256:6c1aca8231625115104a06e4389fcd9ec88f0c9befbabd80dc206c35561be271"}, + {file = "Pillow-8.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:cc9ec588c6ef3a1325fa032ec14d97b7309db493782ea8c304666fb10c3bd9a7"}, + {file = "Pillow-8.0.1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:eb472586374dc66b31e36e14720747595c2b265ae962987261f044e5cce644b5"}, + {file = "Pillow-8.0.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:0eeeae397e5a79dc088d8297a4c2c6f901f8fb30db47795113a4a605d0f1e5ce"}, + {file = "Pillow-8.0.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:81f812d8f5e8a09b246515fac141e9d10113229bc33ea073fec11403b016bcf3"}, + {file = "Pillow-8.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:895d54c0ddc78a478c80f9c438579ac15f3e27bf442c2a9aa74d41d0e4d12544"}, + {file = "Pillow-8.0.1-cp39-cp39-win32.whl", hash = "sha256:2fb113757a369a6cdb189f8df3226e995acfed0a8919a72416626af1a0a71140"}, + {file = "Pillow-8.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:59e903ca800c8cfd1ebe482349ec7c35687b95e98cefae213e271c8c7fffa021"}, + {file = "Pillow-8.0.1-pp36-pypy36_pp73-macosx_10_10_x86_64.whl", hash = "sha256:5abd653a23c35d980b332bc0431d39663b1709d64142e3652890df4c9b6970f6"}, + {file = "Pillow-8.0.1-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:4b0ef2470c4979e345e4e0cc1bbac65fda11d0d7b789dbac035e4c6ce3f98adb"}, + {file = "Pillow-8.0.1-pp37-pypy37_pp73-win32.whl", hash = "sha256:8de332053707c80963b589b22f8e0229f1be1f3ca862a932c1bcd48dafb18dd8"}, + {file = "Pillow-8.0.1.tar.gz", hash = "sha256:11c5c6e9b02c9dac08af04f093eb5a2f84857df70a7d4a6a6ad461aca803fb9e"}, +] pluggy = [ {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, diff --git a/pyproject.toml b/pyproject.toml index 63c03db..bf2ecd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ sphinx-rtd-theme = "^0.4.3" scikit-survival = "^0.12.0" pandas = "^1.0.3" pytest-cov = "^2.9.0" +matplotlib = "^3.3.2" [build-system] requires = ["poetry>=0.12", "cython", "numpy"] diff --git a/skranger/ensemble/ranger_forest_regressor.py b/skranger/ensemble/ranger_forest_regressor.py index b12a80a..b3935c2 100644 --- a/skranger/ensemble/ranger_forest_regressor.py +++ b/skranger/ensemble/ranger_forest_regressor.py @@ -1,7 +1,4 @@ -"""Scikit-learn wrapper for ranger regression. - -TODO quantreg -""" +"""Scikit-learn wrapper for ranger regression.""" import numpy as np from sklearn.base import BaseEstimator from sklearn.base import RegressorMixin @@ -65,6 +62,8 @@ class specific values. :param bool regularization_usedepth: Whether to consider depth in regularization. :param bool holdout: Hold-out all samples with case weight 0 and use these for feature importance and prediction error. + :param bool quantiles: Enable quantile regression after fitting. This must be + set to ``True`` in order to call ``predict_quantiles`` after fitting. :param bool oob_error: Whether to calculate out-of-bag prediction error. :param int n_jobs: The number of threads. Default is number of CPU cores. :param bool save_memory: Save memory at the cost of speed growing trees. @@ -86,6 +85,8 @@ class specific values. regularization factor input parameter. :ivar int importance_mode\_: The importance mode integer corresponding to ranger enum ``ImportanceMode``. + :ivar 2darray random_node_values\_: Random training target values based on + trained forest terminal nodes for the purpose of quantile regression. """ def __init__( @@ -113,6 +114,7 @@ def __init__( regularization_factor=None, regularization_usedepth=False, holdout=False, + quantiles=False, oob_error=False, n_jobs=-1, save_memory=False, @@ -141,6 +143,7 @@ def __init__( self.regularization_factor = regularization_factor self.regularization_usedepth = regularization_usedepth self.holdout = holdout + self.quantiles = quantiles self.oob_error = oob_error self.n_jobs = n_jobs self.save_memory = save_memory @@ -220,8 +223,99 @@ def fit(self, X, y, sample_weight=None): False, # use_regularization_factor self.regularization_usedepth, ) + + if self.quantiles: + forest = self._get_terminal_node_forest(X) + terminal_nodes = np.array(forest["predictions"]).astype(int) + self.random_node_values_ = np.empty((np.max(terminal_nodes) + 1, self.n_estimators)) + self.random_node_values_[:] = np.nan + for tree in range(self.n_estimators): + idx = np.arange(X.shape[0]) + np.random.shuffle(idx) + self.random_node_values_[terminal_nodes[idx, tree], tree] = y[idx] + return self + def _get_terminal_node_forest(self, X): + """Get a terminal node forest for X. + + :param array2d X: prediction input features + """ + # many fields defaulted here which are unused + forest = ranger.ranger( + self.tree_type_, + np.asfortranarray(X.astype("float64")), + np.asfortranarray([[]]), + self.feature_names_, # variable_names + 0, # m_try + self.n_estimators, # num_trees + self.verbose, + self.seed, + self.n_jobs_, # num_threads + False, # write_forest + 0, # importance_mode + 0, # min_node_size + [], # split_select_weights + False, # use_split_select_weights + [], # always_split_feature_names + False, # use_always_split_feature_names + True, # prediction_mode + self.ranger_forest_["forest"], # loaded_forest + np.asfortranarray([[]]), # snp_data + True, # sample_with_replacement + False, # probability + [], # unordered_feature_names + False, # use_unordered_features + False, # save_memory + 1, # split_rule + [], # case_weights + False, # use_case_weights + [], # class_weights + False, # predict_all + self.keep_inbag, + self.sample_fraction_, + 0, # alpha + 0, # minprop + self.holdout, + 2, # prediction_type (terminal nodes) + 1, # num_random_splits + False, # use_sparse_data + False, # order_snps_ + False, # oob_error + 0, # max_depth + [], # inbag + False, # use_inbag + [], # regularization_factor_ + False, # use_regularization_factor_ + False, # regularization_usedepth + ) + return forest + + def predict_quantiles(self, X, quantiles=None): + """Predict quantile regression target for X. + + :param array2d X: prediction input features + :param list(float) quantiles: a list of quantiles on which to predict. + If the list contains a single quantile, the result will be a 1darray. + If there are multiple quantiles, the result will be a 2darray with + columns corresponding to respective quantiles. Default is ``[0.1, 0.5, 0.9]``. + """ + if not hasattr(self, "random_node_values_"): + raise ValueError("Must set quantiles = True for quantile predictions.") + quantiles = quantiles or [0.1, 0.5, 0.9] + check_is_fitted(self) + X = check_array(X) + + forest = self._get_terminal_node_forest(X) + terminal_nodes = np.array(forest["predictions"]).astype(int) + node_values = 0 * terminal_nodes + for tree in range(self.n_estimators): + node_values[:, tree] = self.random_node_values_[terminal_nodes[:, tree], tree] + quantile_predictions = np.nanquantile(node_values, quantiles, axis=1) + if len(quantiles) == 1: + return np.squeeze(quantile_predictions) + return quantile_predictions + def predict(self, X): """Predict regression target for X. diff --git a/tests/ensemble/test_ranger_forest_regressor.py b/tests/ensemble/test_ranger_forest_regressor.py index dbdf4e4..26e7f0c 100644 --- a/tests/ensemble/test_ranger_forest_regressor.py +++ b/tests/ensemble/test_ranger_forest_regressor.py @@ -6,6 +6,7 @@ import pytest from sklearn.base import clone from sklearn.exceptions import NotFittedError +from sklearn.model_selection import train_test_split from sklearn.utils.validation import check_is_fitted from skranger.ensemble import RangerForestRegressor @@ -220,3 +221,20 @@ def test_always_split_features(self, boston_X, boston_y): # feature 0 is in every tree split for tree in rfc.ranger_forest_["forest"]["split_var_ids"]: assert 0 in tree + + def test_quantile_regression(self, boston_X, boston_y): + X_train, X_test, y_train, y_test = train_test_split(boston_X, boston_y) + rfr = RangerForestRegressor(quantiles=False) + rfr.fit(X_train, y_train) + assert not hasattr(rfr, "random_node_values_") + with pytest.raises(ValueError): + rfr.predict_quantiles(X_test) + rfr = RangerForestRegressor(quantiles=True) + rfr.fit(X_train, y_train) + assert hasattr(rfr, "random_node_values_") + quantiles_lower = rfr.predict_quantiles(X_test, quantiles=[0.1]) + quantiles_upper = rfr.predict_quantiles(X_test, quantiles=[0.9]) + assert np.less(quantiles_lower, quantiles_upper).all() + assert quantiles_upper.ndim == 1 + quantiles = rfr.predict_quantiles(X_test, quantiles=[0.1, 0.9]) + assert quantiles.ndim == 2