Skip to content

Commit

Permalink
Added k-fold cross validation
Browse files Browse the repository at this point in the history
  • Loading branch information
spencerldixon committed Mar 10, 2018
1 parent a5fe950 commit ba99668
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 124 deletions.
119 changes: 57 additions & 62 deletions .ipynb_checkpoints/1. Classifying Partial Permits-checkpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
},
{
"cell_type": "code",
"execution_count": 382,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -31,15 +31,15 @@
"from keras.utils import np_utils\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.model_selection import KFold\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.preprocessing import MinMaxScaler\n",
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 383,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": 384,
"execution_count": 41,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1520,7 +1520,7 @@
"[75891 rows x 16 columns]"
]
},
"execution_count": 384,
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1545,7 +1545,7 @@
},
{
"cell_type": "code",
"execution_count": 385,
"execution_count": 42,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -2884,7 +2884,7 @@
"[75891 rows x 14 columns]"
]
},
"execution_count": 385,
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -2905,7 +2905,7 @@
},
{
"cell_type": "code",
"execution_count": 386,
"execution_count": 43,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -4244,7 +4244,7 @@
"[75891 rows x 14 columns]"
]
},
"execution_count": 386,
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -4265,7 +4265,7 @@
},
{
"cell_type": "code",
"execution_count": 387,
"execution_count": 44,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -4301,7 +4301,7 @@
},
{
"cell_type": "code",
"execution_count": 388,
"execution_count": 45,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -4371,7 +4371,7 @@
"Name: Purpose, Length: 75891, dtype: object"
]
},
"execution_count": 388,
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -4391,7 +4391,7 @@
},
{
"cell_type": "code",
"execution_count": 389,
"execution_count": 46,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -5417,7 +5417,7 @@
"[75891 rows x 12 columns]"
]
},
"execution_count": 389,
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -5437,7 +5437,7 @@
},
{
"cell_type": "code",
"execution_count": 390,
"execution_count": 47,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -7336,7 +7336,7 @@
"[75891 rows x 9400 columns]"
]
},
"execution_count": 390,
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -7356,7 +7356,7 @@
},
{
"cell_type": "code",
"execution_count": 391,
"execution_count": 48,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -9003,7 +9003,7 @@
"[75891 rows x 9400 columns]"
]
},
"execution_count": 391,
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -9027,7 +9027,7 @@
},
{
"cell_type": "code",
"execution_count": 392,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -9036,7 +9036,7 @@
},
{
"cell_type": "code",
"execution_count": 393,
"execution_count": 50,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -9068,7 +9068,7 @@
},
{
"cell_type": "code",
"execution_count": 394,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -9091,7 +9091,7 @@
},
{
"cell_type": "code",
"execution_count": 395,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -9108,32 +9108,32 @@
},
{
"cell_type": "code",
"execution_count": 396,
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"60712/60712 [==============================] - 148s 2ms/step - loss: 0.7000 - acc: 0.7340\n",
"60712/60712 [==============================] - 150s 2ms/step - loss: 0.7000 - acc: 0.7341\n",
"Epoch 2/5\n",
"60712/60712 [==============================] - 146s 2ms/step - loss: 0.3237 - acc: 0.8088\n",
"60712/60712 [==============================] - 147s 2ms/step - loss: 0.3236 - acc: 0.8089\n",
"Epoch 3/5\n",
"60712/60712 [==============================] - 152s 3ms/step - loss: 0.2564 - acc: 0.8270\n",
"60712/60712 [==============================] - 143s 2ms/step - loss: 0.2564 - acc: 0.8270\n",
"Epoch 4/5\n",
"60712/60712 [==============================] - 145s 2ms/step - loss: 0.2190 - acc: 0.8382\n",
"60712/60712 [==============================] - 143s 2ms/step - loss: 0.2190 - acc: 0.8382\n",
"Epoch 5/5\n",
"60712/60712 [==============================] - 143s 2ms/step - loss: 0.1948 - acc: 0.8446\n"
"60712/60712 [==============================] - 148s 2ms/step - loss: 0.1948 - acc: 0.8447\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x1acb952b0>"
"<keras.callbacks.History at 0x118098940>"
]
},
"execution_count": 396,
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -9145,15 +9145,15 @@
},
{
"cell_type": "code",
"execution_count": 397,
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15179/15179 [==============================] - 31s 2ms/step\n",
"acc: 81.85%\n"
"15179/15179 [==============================] - 33s 2ms/step\n",
"acc: 81.84%\n"
]
}
],
Expand All @@ -9174,7 +9174,7 @@
},
{
"cell_type": "code",
"execution_count": 543,
"execution_count": 55,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -9261,7 +9261,7 @@
"[1 rows x 9400 columns]"
]
},
"execution_count": 543,
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -9275,7 +9275,7 @@
},
{
"cell_type": "code",
"execution_count": 544,
"execution_count": 56,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -9338,7 +9338,7 @@
"1 0 0 0 0 0 0 0 0 1 0 0 0"
]
},
"execution_count": 544,
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -9350,7 +9350,7 @@
},
{
"cell_type": "code",
"execution_count": 545,
"execution_count": 57,
"metadata": {},
"outputs": [
{
Expand All @@ -9376,42 +9376,37 @@
"source": [
"### Evaluating our model with K-Fold Cross Validation\n",
"\n",
"We'll use k-fold validation to get a better representation of how our model did..."
"We'll use k-fold validation to get a better representation of how our model did.\n",
"We'll first build an estimator using `KerasClassifier`, which is a wrapper that makes our model work nicely with scikit-learn's validators..."
]
},
{
"cell_type": "code",
"execution_count": 398,
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"estimator = KerasClassifier(build_fn=build_model, epochs=epochs, batch_size=batch_size, verbose=0)\n",
"k_fold = KFold(n_splits=10, shuffle=True, random_state=seed)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'[ 9400 9401 9402 ... 75888 75889 75890] not in index'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-398-5c8de40a14e6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkfold\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbuild_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mscores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"%s: %.2f%%\"\u001b[0m \u001b[0;34m%\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics_names\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscores\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 2131\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mSeries\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIndex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2132\u001b[0m \u001b[0;31m# either boolean or fancy integer index\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2133\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2134\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDataFrame\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2135\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getitem_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m_getitem_array\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 2175\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_take\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2176\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2177\u001b[0;31m \u001b[0mindexer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_to_indexer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2178\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_take\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconvert\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/site-packages/pandas/core/indexing.py\u001b[0m in \u001b[0;36m_convert_to_indexer\u001b[0;34m(self, obj, axis, is_setter)\u001b[0m\n\u001b[1;32m 1267\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1268\u001b[0m raise KeyError('{mask} not in index'\n\u001b[0;32m-> 1269\u001b[0;31m .format(mask=objarr[mask]))\n\u001b[0m\u001b[1;32m 1270\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1271\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0m_values_from_object\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindexer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyError\u001b[0m: '[ 9400 9401 9402 ... 75888 75889 75890] not in index'"
"name": "stdout",
"output_type": "stream",
"text": [
"Baseline: 82.23% (0.43%)\n"
]
}
],
"source": [
"k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)\n",
"cv_scores = []\n",
"results = cross_val_score(estimator, data, labels, cv=k_fold)\n",
"\n",
"for train, test in kfold.split(data, labels):\n",
" model = build_model()\n",
" model.fit(data[train], labels[train], epochs=epochs, batch_size=10, verbose=0)\n",
" scores = model.evaluate(data[test], labels[test], verbose=0)\n",
" print(\"%s: %.2f%%\" % (model.metrics_names[1], scores[1]*100))\n",
" \n",
" cv_scores.append(scores[1] * 100)\n",
" \n",
"print(\"%.2f%% (+/- %.2f%%)\" % (numpy.mean(cv_scores), numpy.std(cv_scores)))"
"print(\"Baseline: %.2f%% (%.2f%%)\" % (results.mean()*100, results.std()*100))"
]
},
{
Expand Down
Loading

0 comments on commit ba99668

Please sign in to comment.