diff --git a/example-notebook/example_classification_model.ipynb b/example-notebook/example_classification_model.ipynb index cc5d67f..213e915 100644 --- a/example-notebook/example_classification_model.ipynb +++ b/example-notebook/example_classification_model.ipynb @@ -12,8 +12,7 @@ "import inspect\n", "\n", "import numpy as np\n", - "import pandas as pd\n", - "import plotly.graph_objects as go" + "import pandas as pd" ] }, { @@ -37,10 +36,10 @@ { "data": { "text/html": [ - "
RandomForestClassifier(max_depth=6, oob_score=True)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" + "
RandomForestClassifier(max_depth=6, oob_score=True, random_state=123)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ - "RandomForestClassifier(max_depth=6, oob_score=True)" + "RandomForestClassifier(max_depth=6, oob_score=True, random_state=123)" ] }, "execution_count": 2, @@ -59,10 +58,10 @@ " random_state=12, shuffle=False, weights = [0.8, 0.2])\n", "\n", "# Train - test split\n", - "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, stratify = y, random_state=0)\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, stratify = y, random_state=123)\n", "\n", "# Train a RF classifier\n", - "cls = RandomForestClassifier(max_depth=6, oob_score=True)\n", + "cls = RandomForestClassifier(max_depth=6, oob_score=True, random_state=123)\n", "cls.fit(X_train, y_train)" ] }, @@ -271,76 +270,76 @@ { "customdata": [ [ - 1.8829064005331113 + 1.8292313717410853 ], [ - 0.8829064005331113 + 0.8292313717410853 ], [ - 0.4338412959831805 + 0.42701050390153383 ], [ - 0.39782772199062094 + 0.4163757258146578 ], [ - 0.3757061042924499 + 0.3740970508252879 ], [ - 0.37243543040526134 + 0.3197570322138689 ], [ - 0.3525129725783062 + 0.31895755890522137 ], [ - 0.3301962740255774 + 0.3180382144381978 ], [ - 0.3280321800602266 + 0.316526854126403 ], [ - 0.300888575200243 + 0.2593442209225198 ], [ - 0.2968368395264936 + 0.24975360339458408 ], [ - 0.2683259862999592 + 0.2302493601716992 ], [ - 0.2669456333638167 + 0.22852341628131667 ], [ - 0.2641184486339815 + 0.22678447137367905 ], [ - 0.25929236881564155 + 0.22459451038542572 ], [ - 0.25884609489023336 + 0.22030280143233244 ], [ - 0.24918148149241456 + 0.21882829151595085 ], [ - 0.24889171205469662 + 0.19869987552466334 ], [ - 0.2451121675467701 + 0.1980313758868512 ], [ - 0.20271656460844986 + 0.1923312854892404 ], [ - 0.1999642147144849 + 0.18919332379255235 ], [ - 0.19745487816192991 + 0.0750746554091418 ], [ - 0.1957958401674481 + 0.0730143388762927 ], [ - 0.025389097878735077 + 0.018586278192252906 ] ], "hovertemplate": "Threshold: %{customdata:.4f}
False Positive Rate: %{x:.4f}
True Positive Rate: %{y:.4f}", @@ -354,7 +353,7 @@ "symbol": "circle" }, "mode": "lines", - "name": "ROC Curve (AUC=0.977)", + "name": "ROC Curve (AUC=0.955)", "orientation": "v", "showlegend": true, "textposition": "top center", @@ -365,44 +364,44 @@ 0, 0.006289308176100629, 0.006289308176100629, - 0.012578616352201259, - 0.012578616352201259, - 0.025157232704402517, - 0.025157232704402517, + 0.031446540880503145, + 0.031446540880503145, 0.03773584905660377, 0.03773584905660377, - 0.0880503144654088, - 0.0880503144654088, - 0.09433962264150944, - 0.09433962264150944, 0.10062893081761007, 0.10062893081761007, - 0.1069182389937107, - 0.1069182389937107, - 0.16352201257861634, - 0.16352201257861634, - 0.16981132075471697, - 0.16981132075471697, + 0.1320754716981132, + 0.1320754716981132, + 0.13836477987421383, + 0.13836477987421383, + 0.1509433962264151, + 0.1509433962264151, + 0.1949685534591195, + 0.1949685534591195, + 0.2138364779874214, + 0.2138364779874214, + 0.6666666666666666, + 0.6666666666666666, 1 ], "xaxis": "x", "y": [ 0, 0.024390243902439025, - 0.6585365853658537, - 0.6585365853658537, - 0.7073170731707317, - 0.7073170731707317, + 0.5853658536585366, + 0.5853658536585366, 0.7560975609756098, 0.7560975609756098, 0.7804878048780488, 0.7804878048780488, + 0.8048780487804879, + 0.8048780487804879, 0.8292682926829268, 0.8292682926829268, - 0.8536585365853658, - 0.8536585365853658, 0.8780487804878049, 0.8780487804878049, + 0.9024390243902439, + 0.9024390243902439, 0.926829268292683, 0.926829268292683, 0.9512195121951219, @@ -1297,9 +1296,9 @@ } }, "text/html": [ - "