Commit

abe cleaning
Abe Burton authored and Abe Burton committed Nov 27, 2023
1 parent 676a61f commit ae6bb32
Showing 1 changed file with 5 additions and 146 deletions.
151 changes: 5 additions & 146 deletions supervised_ml.ipynb
@@ -876,7 +876,7 @@
"id": "d23b5405-6491-43ed-90fb-cf7aa2056553",
"metadata": {},
"source": [
"We had to group by pickup area and dropoff area separately: daily counts of the number of trips to that community area when it was either a pickup or dropoff area"
"We had to group by pickup area and dropoff area separately, then sum to create daily counts of the number of trips to that community area when it was either a pickup or dropoff area"
]
},
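The grouping described above can be sketched in pandas for a quick runnable illustration (the notebook itself uses PySpark, and the column names here are hypothetical):

```python
import pandas as pd

# Toy trips table; the real data is the Chicago rideshare dataset (column names illustrative)
trips = pd.DataFrame({
    "date": ["2021-10-01", "2021-10-01", "2021-10-02"],
    "pickup_area": [1, 2, 1],
    "dropoff_area": [2, 1, 3],
})

# Count trips per day where each area was the pickup, then where it was the dropoff
pickups = trips.groupby(["date", "pickup_area"]).size().rename("pickups")
dropoffs = trips.groupby(["date", "dropoff_area"]).size().rename("dropoffs")
pickups.index.names = ["date", "area"]
dropoffs.index.names = ["date", "area"]

# Sum the two counts into one daily total per community area
daily = pd.concat([pickups, dropoffs], axis=1).fillna(0)
daily["total"] = daily["pickups"] + daily["dropoffs"]
```

The same two-groupby-then-sum shape carries over directly to PySpark `groupBy` calls.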
{
@@ -1384,11 +1384,9 @@
"# ML Model\n",
"\n",
"1. Create datasets for pre-program data (before Oct 2021) and for data from Oct 2021 up to, but not including, July 2023.\n",
"2. Get a cross-validated model running.\n",
"3. Train the model on the first dataset; predict for October, November, and December 2021.\n",
"4. Plot predictions (dotted line for predictions, solid line for actuals).\n",
"5. Train a new model on the second dataset.\n",
"6. Plot for July.\n"
"2. Train a model on the first dataset; predict for October, November, and December 2021.\n",
"3. Train a new model on the second dataset and make further future predictions.\n",
"4. Plot all predictions against actual ride counts over the years.\n"
]
},
{
@@ -1425,20 +1423,6 @@
"df_3 = merged_df.filter(((merged_df.year == 2023) & (merged_df.month >= 7)))"
]
},
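The date-based split in the cell above can be illustrated with plain boolean masks (shown here in pandas for a runnable sketch; the notebook does the same thing with PySpark `filter` calls on `merged_df`):

```python
import pandas as pd

# Illustrative frame; merged_df in the notebook has the same year/month columns
df = pd.DataFrame({"year": [2021, 2021, 2022, 2023, 2023],
                   "month": [9, 10, 6, 6, 7]})

# Pre-program data: before Oct 2021
df_1 = df[(df.year < 2021) | ((df.year == 2021) & (df.month < 10))]
# Post-expansion data: July 2023 onward
df_3 = df[((df.year == 2023) & (df.month >= 7)) | (df.year > 2023)]
# Program period: everything between the two cutoffs
df_2 = df[~df.index.isin(df_1.index) & ~df.index.isin(df_3.index)]
```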
{
"cell_type": "markdown",
"id": "49227805-612c-4ff6-b714-74f2de65f556",
"metadata": {},
"source": [
"## I'm trying to organize my thoughts about prediction here; hopefully this makes sense.\n",
"\n",
"The way we've been thinking about this model is that it predicts daily counts, so our predictions should be daily as well.\n",
"So I think I should take the rows from df2 that fall in months 10, 11, and 12 and make a prediction for each one, put those in a dataframe, then group by month and sum the predicted rides.\n",
"Then take the actual counts from df2, group by month, and sum. Plot those against each other.\n",
"\n",
"I'm going to do this without cross validation first, and then move to cross validation depending on time."
]
},
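The monthly-aggregation plan above can be sketched in pandas (numbers are made up for illustration; the notebook works on Spark dataframes):

```python
import pandas as pd

# Hypothetical daily predictions vs. actuals for October and November
daily = pd.DataFrame({
    "month": [10, 10, 11, 11],
    "predicted_rides": [100.0, 110.0, 90.0, 95.0],
    "actual_rides": [105, 108, 88, 99],
})

# Group the daily rows by month and sum both series, giving one row per month
# with predicted and actual totals side by side, ready to plot against each other
monthly = daily.groupby("month")[["predicted_rides", "actual_rides"]].sum()
```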
{
"cell_type": "code",
"execution_count": 11,
@@ -1505,7 +1489,7 @@
"id": "370b8f4e-f6bc-4245-889d-e1803afed90c",
"metadata": {},
"source": [
"Using df1 (all pre-program data) to predict what would have happened to the count of rides if the program had not happened at all."
"Using df1 (all pre-program data) to predict what would have happened to the count of rides if the program had not happened at all. Do this by fitting the model on pre-program data only, predicting outcomes for later dates, and comparing those predictions to the actual ride counts on those dates."
]
},
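A minimal sketch of that counterfactual idea, assuming a simple linear trend (the notebook fits a Spark regression with many more features; the numbers here are hypothetical):

```python
import numpy as np

# Hypothetical pre-program monthly ride counts, months indexed 0..5
pre_months = np.arange(6)
pre_rides = np.array([100, 102, 104, 106, 108, 110], dtype=float)

# Fit a linear trend on pre-program data only
slope, intercept = np.polyfit(pre_months, pre_rides, 1)

# Extrapolate the "no program" counterfactual into the program period (months 6..8)
future_months = np.arange(6, 9)
counterfactual = slope * future_months + intercept

# The estimated program effect is actual minus counterfactual on those dates
actual = np.array([115.0, 118.0, 121.0])
effect = actual - counterfactual
```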
{
@@ -3726,131 +3710,6 @@
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "49a4a9b4-3ff9-4bdb-81d4-ae8587546ad8",
"metadata": {},
"source": [
"### Notes for Abe on work done:\n",
"\n",
"- Added predicted pre-program trends based on pre-program data.\n",
"- Added predictions for after program expansion based on pre-program data.\n",
"- Made all the plots nicer using Seaborn\n",
"- Created a final plot for data from 2021-2023. "
]
},
{
"cell_type": "markdown",
"id": "54413c6b-49f7-418a-a9b9-97862a4976a9",
"metadata": {},
"source": [
"## Abe's notes for Harsh\n",
"\n",
"Ok, so I ran these models without cross validation and saved them to the models folder in GCS. We can do cross validation if you have time to debug and run it. If you have to rerun any of my code, I would consider repartitioning, because it's still going through 600 partitions for the tiny dataset we're working with. The plots could be improved by adding the pre-trends and giving them better titles and labels; I just ran out of time to do that. It shouldn't take long. Let me know if you're confused about anything I did. I think it's possible this is good enough for a final product once the plots have pre-trends added and things are cleaned up. That could also give you more time to check on our groupmates if necessary haha. We're getting close."
]
},
{
"cell_type": "markdown",
"id": "4a8f0db1-4c39-4a3f-a340-303747b3ae29",
"metadata": {},
"source": [
"# Pseudocode if we want to do cross validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a2a0a967-aa2f-4a97-bb0d-811bf528eb0a",
"metadata": {},
"outputs": [],
"source": [
"# taken from Ashish's notebook - to check if we need to convert anything to labels\n",
"from pyspark.sql.functions import UserDefinedFunction\n",
"from pyspark.sql.types import DoubleType\n",
"\n",
"def labelForResults(s):\n",
"    if s == 'Fail':\n",
"        return 0.0\n",
"    elif s == 'Pass w/ Conditions' or s == 'Pass':\n",
"        return 1.0\n",
"    else:\n",
"        return -1.0\n",
"\n",
"# add numeric labels to the original dataset, keeping only rows with a known result (label >= 0)\n",
"label = UserDefinedFunction(labelForResults, DoubleType())\n",
"labeledData = df.select(df.Violations, label(df.Results).alias('label')).where('label >= 0')\n",
"labeledData.show(10, truncate=False)"
]
},
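The label mapping in that snippet is pure Python, so it can be sanity-checked outside Spark before wiring it into a UDF:

```python
def labelForResults(s):
    # Same mapping as the notebook cell: fail -> 0.0, pass variants -> 1.0, unknown -> -1.0
    if s == 'Fail':
        return 0.0
    elif s == 'Pass w/ Conditions' or s == 'Pass':
        return 1.0
    else:
        return -1.0

# Rows mapped to -1.0 (unknown results) are filtered out before training
labels = [labelForResults(s) for s in ['Fail', 'Pass', 'Pass w/ Conditions', 'Out of Business']]
```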
{
"cell_type": "code",
"execution_count": 45,
"id": "115c0a40-ec37-419a-ae05-4115adad625f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"23/11/25 03:43:39 WARN org.apache.spark.sql.execution.CacheManager: Asked to cache already cached data.\n",
"23/11/25 03:43:39 WARN org.apache.spark.sql.execution.CacheManager: Asked to cache already cached data.\n"
]
},
{
"ename": "IllegalArgumentException",
"evalue": "label does not exist. Available: month, year, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, area_sums, temp, precip, snow, snowdepth, sunset, features, CrossValidator_82ced15f4dc7_rand, prediction",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIllegalArgumentException\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[45], line 14\u001b[0m\n\u001b[1;32m 8\u001b[0m crossval \u001b[38;5;241m=\u001b[39m CrossValidator(estimator\u001b[38;5;241m=\u001b[39mlr,\n\u001b[1;32m 9\u001b[0m estimatorParamMaps\u001b[38;5;241m=\u001b[39mparamGrid,\n\u001b[1;32m 10\u001b[0m evaluator\u001b[38;5;241m=\u001b[39mRegressionEvaluator(),\n\u001b[1;32m 11\u001b[0m numFolds\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m) \n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# Run cross-validation, and choose the best set of parameters.\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m cvModel \u001b[38;5;241m=\u001b[39m \u001b[43mcrossval\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_df\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/ml/base.py:161\u001b[0m, in \u001b[0;36mEstimator.fit\u001b[0;34m(self, dataset, params)\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcopy(params)\u001b[38;5;241m.\u001b[39m_fit(dataset)\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 161\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 163\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mParams must be either a param map or a list/tuple of param maps, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbut got \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mtype\u001b[39m(params))\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/ml/tuning.py:687\u001b[0m, in \u001b[0;36mCrossValidator._fit\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 684\u001b[0m train \u001b[38;5;241m=\u001b[39m datasets[i][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mcache()\n\u001b[1;32m 686\u001b[0m tasks \u001b[38;5;241m=\u001b[39m _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)\n\u001b[0;32m--> 687\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m j, metric, subModel \u001b[38;5;129;01min\u001b[39;00m pool\u001b[38;5;241m.\u001b[39mimap_unordered(\u001b[38;5;28;01mlambda\u001b[39;00m f: f(), tasks):\n\u001b[1;32m 688\u001b[0m metrics[j] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (metric \u001b[38;5;241m/\u001b[39m nFolds)\n\u001b[1;32m 689\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m collectSubModelsParam:\n",
"File \u001b[0;32m/opt/conda/miniconda3/lib/python3.8/multiprocessing/pool.py:868\u001b[0m, in \u001b[0;36mIMapIterator.next\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 866\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 867\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n\u001b[0;32m--> 868\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m value\n",
"File \u001b[0;32m/opt/conda/miniconda3/lib/python3.8/multiprocessing/pool.py:125\u001b[0m, in \u001b[0;36mworker\u001b[0;34m(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)\u001b[0m\n\u001b[1;32m 123\u001b[0m job, i, func, args, kwds \u001b[38;5;241m=\u001b[39m task\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 125\u001b[0m result \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 126\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m wrap_exception \u001b[38;5;129;01mand\u001b[39;00m func \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _helper_reraises_exception:\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/ml/tuning.py:687\u001b[0m, in \u001b[0;36mCrossValidator._fit.<locals>.<lambda>\u001b[0;34m(f)\u001b[0m\n\u001b[1;32m 684\u001b[0m train \u001b[38;5;241m=\u001b[39m datasets[i][\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mcache()\n\u001b[1;32m 686\u001b[0m tasks \u001b[38;5;241m=\u001b[39m _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)\n\u001b[0;32m--> 687\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m j, metric, subModel \u001b[38;5;129;01min\u001b[39;00m pool\u001b[38;5;241m.\u001b[39mimap_unordered(\u001b[38;5;28;01mlambda\u001b[39;00m f: \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, tasks):\n\u001b[1;32m 688\u001b[0m metrics[j] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (metric \u001b[38;5;241m/\u001b[39m nFolds)\n\u001b[1;32m 689\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m collectSubModelsParam:\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/ml/tuning.py:74\u001b[0m, in \u001b[0;36m_parallelFitTasks.<locals>.singleTask\u001b[0;34m()\u001b[0m\n\u001b[1;32m 69\u001b[0m index, model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(modelIter)\n\u001b[1;32m 70\u001b[0m \u001b[38;5;66;03m# TODO: duplicate evaluator to take extra params from input\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[38;5;66;03m# Note: Supporting tuning params in evaluator need update method\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;66;03m# `MetaAlgorithmReadWrite.getAllNestedStages`, make it return\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;66;03m# all nested stages and evaluators\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m metric \u001b[38;5;241m=\u001b[39m \u001b[43meva\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalidation\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepm\u001b[49m\u001b[43m[\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m index, metric, model \u001b[38;5;28;01mif\u001b[39;00m collectSubModel \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/ml/evaluation.py:84\u001b[0m, in \u001b[0;36mEvaluator.evaluate\u001b[0;34m(self, dataset, params)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcopy(params)\u001b[38;5;241m.\u001b[39m_evaluate(dataset)\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 84\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mParams must be a param map but got \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mtype\u001b[39m(params))\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/ml/evaluation.py:120\u001b[0m, in \u001b[0;36mJavaEvaluator._evaluate\u001b[0;34m(self, dataset)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03mEvaluates the output.\u001b[39;00m\n\u001b[1;32m 108\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[38;5;124;03m evaluation metric\u001b[39;00m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_transfer_params_to_java()\n\u001b[0;32m--> 120\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_java_obj\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_jdf\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/lib/spark/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py:1304\u001b[0m, in \u001b[0;36mJavaMember.__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1298\u001b[0m command \u001b[38;5;241m=\u001b[39m proto\u001b[38;5;241m.\u001b[39mCALL_COMMAND_NAME \u001b[38;5;241m+\u001b[39m\\\n\u001b[1;32m 1299\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcommand_header \u001b[38;5;241m+\u001b[39m\\\n\u001b[1;32m 1300\u001b[0m args_command \u001b[38;5;241m+\u001b[39m\\\n\u001b[1;32m 1301\u001b[0m proto\u001b[38;5;241m.\u001b[39mEND_COMMAND_PART\n\u001b[1;32m 1303\u001b[0m answer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgateway_client\u001b[38;5;241m.\u001b[39msend_command(command)\n\u001b[0;32m-> 1304\u001b[0m return_value \u001b[38;5;241m=\u001b[39m \u001b[43mget_return_value\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1305\u001b[0m \u001b[43m \u001b[49m\u001b[43manswer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgateway_client\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtarget_id\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1307\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m temp_arg \u001b[38;5;129;01min\u001b[39;00m temp_args:\n\u001b[1;32m 1308\u001b[0m temp_arg\u001b[38;5;241m.\u001b[39m_detach()\n",
"File \u001b[0;32m/usr/lib/spark/python/pyspark/sql/utils.py:117\u001b[0m, in \u001b[0;36mcapture_sql_exception.<locals>.deco\u001b[0;34m(*a, **kw)\u001b[0m\n\u001b[1;32m 113\u001b[0m converted \u001b[38;5;241m=\u001b[39m convert_exception(e\u001b[38;5;241m.\u001b[39mjava_exception)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(converted, UnknownException):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;66;03m# Hide where the exception came from that shows a non-Pythonic\u001b[39;00m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;66;03m# JVM exception message.\u001b[39;00m\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m converted \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n",
"\u001b[0;31mIllegalArgumentException\u001b[0m: label does not exist. Available: month, year, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, area_sums, temp, precip, snow, snowdepth, sunset, features, CrossValidator_82ced15f4dc7_rand, prediction"
]
}
],
"source": [
"# An open question for me: in Ashish's code and the online examples, the model is fit and tested\n",
"# on the same full dataset, which is confusing. I would think you would fit on the training set\n",
"# and transform the test set. Something to look into.\n",
"\n",
"from pyspark.ml.tuning import ParamGridBuilder, CrossValidator\n",
"from pyspark.ml.evaluation import RegressionEvaluator\n",
"\n",
"# Grid over the regularization strength and elastic-net mixing parameter\n",
"paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01]).addGrid(\n",
"    lr.elasticNetParam, [0.2, 0.4, 0.6, 0.8, 1.0]).build()\n",
"\n",
"# NOTE: the IllegalArgumentException in this cell's output happens because RegressionEvaluator\n",
"# defaults to labelCol='label', which train_df does not have; pass the actual target column\n",
"# (e.g. labelCol='area_sums') to the evaluator to fix it.\n",
"crossval = CrossValidator(estimator=lr,\n",
"                          estimatorParamMaps=paramGrid,\n",
"                          evaluator=RegressionEvaluator(),\n",
"                          numFolds=10)\n",
"\n",
"# Run cross-validation and choose the best set of parameters\n",
"cvModel = crossval.fit(train_df)\n"
]
},
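The same grid-search-with-cross-validation pattern can be sketched in scikit-learn for a runnable illustration (the notebook uses PySpark's `CrossValidator`; `regParam` and `elasticNetParam` correspond to `alpha` and `l1_ratio` here, and the data is synthetic):

```python
import numpy as np
from sklearn.linear_model import ElasticNet
from sklearn.model_selection import GridSearchCV

# Toy regression data standing in for the notebook's feature matrix
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=100)

# Same grid as the PySpark cell: regParam -> alpha, elasticNetParam -> l1_ratio
param_grid = {"alpha": [0.1, 0.01],
              "l1_ratio": [0.2, 0.4, 0.6, 0.8, 1.0]}

# k-fold cross-validation picks the best (alpha, l1_ratio) pair
search = GridSearchCV(ElasticNet(max_iter=10000), param_grid, cv=10)
search.fit(X, y)
best = search.best_params_
```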
{
"cell_type": "code",
"execution_count": null,
"id": "7296e8c3-ddff-4124-bf9d-01755471bda1",
"metadata": {},
"outputs": [],
"source": [
"# Save the model so we don't have to rerun this\n",
"model_path = "gs://msca-bdp-student-gcs/bdp-rideshare-project/models/supervised_model/"\n",
"cvModel.write().save(model_path)\n",
"\n",
"# Read the model back in\n",
"cvModelRead = CrossValidatorModel.read().load(model_path)\n",
"\n",
"# Make predictions on the test set; cvModel.transform uses the best model found during cross-validation\n",
"predictions = cvModel.transform(test_df)"
]
}
],
"metadata": {
