-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Example of use of MLFlow with XGBoost and LogisticRegression
- Loading branch information
Remi Tschupp
committed
Jun 25, 2024
1 parent
f4856c5
commit b79d883
Showing
1 changed file
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,278 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 41, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import mlflow\n", | ||
"from mlflow.models import infer_signature\n", | ||
"\n", | ||
"\n", | ||
"from sklearn import datasets\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"from sklearn.linear_model import LogisticRegression\n", | ||
"from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix,roc_auc_score\n", | ||
"\n", | ||
"from xgboost import XGBClassifier\n", | ||
"\n", | ||
"import seaborn as sns\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"import pandas as pd\n", | ||
"import os" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 42, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"gen_dirname = os.path.dirname(os.path.abspath(''))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load the dataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 43, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"type_of_dataset = \"gentle\"\n", | ||
"\n", | ||
"labeled_data = pd.read_csv(os.path.join(gen_dirname,f\"data\\{type_of_dataset}\\labelled.csv\"))\n", | ||
"\n", | ||
"labels = labeled_data[\"Survived\"]\n", | ||
"inputs = labeled_data.drop(\"Survived\",axis=\"columns\")\n", | ||
"\n", | ||
"X_train, X_test, y_train, y_test = train_test_split(inputs,labels,test_size=0.3,random_state=42) # We are fixing the split so every run is comparable " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Models" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 44, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"list_models = []" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### XGBoost" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 45, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define the model hyperparameters\n", | ||
"params_xgb = {\n", | ||
" \"n_estimators\":20,\n", | ||
" \"max_depth\":100,\n", | ||
" \"learning_rate\": 0.3,\n", | ||
" \"objective\": \"binary:logistic\",\n", | ||
"}\n", | ||
"\n", | ||
"# Create model instance\n", | ||
"bst = XGBClassifier(**params_xgb)\n", | ||
"\n", | ||
"# Fit the model\n", | ||
"bst.fit(X_train, y_train)\n", | ||
"\n", | ||
"# # Infer the model signature\n", | ||
"# signature = infer_signature(X_train, bst.predict(X_train))\n", | ||
"\n", | ||
"# # Log the model\n", | ||
"# model_info = mlflow.xgboost.autolog()\n", | ||
"\n", | ||
"# Register in list \n", | ||
"list_models.append([\"XGBoost\",params_xgb,bst,mlflow.xgboost.log_model])#,model_info])\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### LogisticRegression" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 46, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"c:\\Users\\RT277831\\Documents\\Projets\\Dauphine\\ML_OPS\\venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:1247: FutureWarning: 'multi_class' was deprecated in version 1.5 and will be removed in 1.7. From then on, it will always use 'multinomial'. Leave it to its default value to avoid this warning.\n", | ||
" warnings.warn(\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Define the model hyperparameters\n", | ||
"params_lr = {\n", | ||
" \"solver\": \"lbfgs\",\n", | ||
" \"max_iter\": 1000,\n", | ||
" \"multi_class\": \"auto\",\n", | ||
" \"random_state\": 8888,\n", | ||
"}\n", | ||
"\n", | ||
"# Create model instance\n", | ||
"lr = LogisticRegression(**params_lr)\n", | ||
"\n", | ||
"# Fit the model\n", | ||
"lr.fit(X_train, y_train)\n", | ||
"\n", | ||
"# # Log the model\n", | ||
"# model_info = mlflow.sklearn.autolog()\n", | ||
"\n", | ||
"# Register in list \n", | ||
"list_models.append([\"LogisticRegression\",params_lr,lr,mlflow.sklearn.log_model])#,model_info])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Evaluation metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"We are gonna set our metrics that will help compare our different models, because it is a classification task we are gonna focus on AUC, accuracy, recall, confusion matrix." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 47, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def eval_metrics(actual, preds):\n", | ||
" # Calculate metrics\n", | ||
"\n", | ||
" # Accuracy\n", | ||
" accuracy = accuracy_score(actual, preds)\n", | ||
"\n", | ||
" # recall\n", | ||
" recall = recall_score(actual, preds)\n", | ||
"\n", | ||
" # AUC\n", | ||
" auc = roc_auc_score(actual, preds)\n", | ||
"\n", | ||
" # Confusion matrix\n", | ||
" cnf_matr = confusion_matrix(actual,preds)\n", | ||
"\n", | ||
" return accuracy, recall, auc, cnf_matr" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## MLFlow part" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 48, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Set our tracking server uri for logging\n", | ||
"mlflow.set_tracking_uri(uri=\"http://127.0.0.1:5000\")\n", | ||
"\n", | ||
"# Create a new MLflow Experiment\n", | ||
"mlflow.set_experiment(\"Titanic MLFlow demo\")\n", | ||
"\n", | ||
"\n", | ||
"for name,params,model,log_model in list_models:\n", | ||
"\n", | ||
" # Start an MLflow run\n", | ||
" with mlflow.start_run():\n", | ||
" # Log the hyperparameters\n", | ||
" mlflow.log_params(params)\n", | ||
" \n", | ||
" preds = model.predict(X_test)\n", | ||
"\n", | ||
" # Log the metric\n", | ||
" accuracy, recall, auc, cnf_matr = eval_metrics(y_test,preds)\n", | ||
" mlflow.log_metric(\"accuracy\", accuracy)\n", | ||
" mlflow.log_metric(\"recall\", recall)\n", | ||
" mlflow.log_metric(\"auc\", auc)\n", | ||
"\n", | ||
" fig, ax = plt.subplots()\n", | ||
"\n", | ||
" sns.heatmap(cnf_matr, annot=True)\n", | ||
" ax.set_title(\"Feature confusion Matrix\", fontsize=14)\n", | ||
" plt.tight_layout()\n", | ||
" plt.close(fig)\n", | ||
"\n", | ||
" mlflow.log_figure(fig, \"confusion_matrix.png\")\n", | ||
"\n", | ||
" # Set a tag that we can use to remind ourselves what this run was for\n", | ||
" mlflow.set_tag(\"Training Info\", f\"{name} model training for {type_of_dataset} titanic dataset\")\n", | ||
"\n", | ||
" mlflow.set_tag(\"mlflow.runName\", f\"{name}\")\n", | ||
"\n", | ||
"\n", | ||
" # model_info = log_model()\n", | ||
" # # Infer the model signature\n", | ||
" # signature = infer_signature(X_train, model.predict(X_train))\n", | ||
"\n", | ||
" # model_info =log_model(\n", | ||
" # artifact_path=f\"{type_of_dataset}_{name}\",\n", | ||
" # signature=signature,\n", | ||
" # input_example=X_train,\n", | ||
" # registered_model_name=f\" {name}\",\n", | ||
" # )" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |