-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add experiments with dvc metrics to latex
- Loading branch information
Showing
1 changed file
with
336 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import dvc.api\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"pd.set_option(\"display.max_columns\", 500)\n", | ||
"\n", | ||
"metrics = dvc.api.metrics_show(\"../output/energy_metrics.json\")\n", | ||
"metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"exps = dvc.api.exp_show(param_deps=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"exps" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# tags = dvc.api.scm.all_tags()\n", | ||
"# exps = dvc.api.exp_show(revs=tags)\n", | ||
"df = pd.DataFrame(exps)\n", | ||
"df.head()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"energy_metrics = [col for col in df.columns if \"energy_metrics.json\" in col]\n", | ||
"# bikes_metrics = [col for col in df.columns if \"bikes_metrics.json:\" in col]\n", | ||
"params_col = [col for col in df.columns if \"energy.selected\" in col]\n", | ||
"df = pd.DataFrame(exps, columns=energy_metrics + params_col)\n", | ||
"df = df.dropna()\n", | ||
"# strip the prefix in the col name until : (but not for Created col)\n", | ||
"df.columns = [col.split(\":\")[1] if \":\" in col else col for col in df.columns]\n", | ||
"df = df.iloc[2:].sort_values(\"train.energy.selected\")\n", | ||
"df.head()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df = df.drop(columns=[\"avg_fit_time\", \"avg_pred_time\"])\n", | ||
"df.head()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a function to format metrics as mean (±std)\n", | ||
"def format_table_with_metrics():\n", | ||
" # Define display name mapping\n", | ||
" display_names = {\n", | ||
" \"pinball_loss\": \"Pinball Loss\",\n", | ||
" \"interval_score_50\": \"50\\% Interval Score\",\n", | ||
" \"interval_score_95\": \"95\\% Interval Score\",\n", | ||
" \"coverage_50\": \"50\\% Coverage\",\n", | ||
" \"coverage_95\": \"95\\% Coverage\",\n", | ||
" }\n", | ||
"\n", | ||
" # Create new formatted DataFrame\n", | ||
" formatted_table = pd.DataFrame()\n", | ||
" formatted_table[\"Model\"] = df[\"train.energy.selected\"]\n", | ||
"\n", | ||
" # Format each metric with mean and std\n", | ||
" for metric in metrics:\n", | ||
" means = df[f\"{metric}.mean\"]\n", | ||
" stds = df[f\"{metric}.std\"]\n", | ||
" # Use the display name mapping for column names\n", | ||
" formatted_table[display_names[metric]] = (\n", | ||
" means.map(\"{:.2f}\".format) + \" (± \" + stds.map(\"{:.2f}\".format) + \")\"\n", | ||
" )\n", | ||
"\n", | ||
" return formatted_table.to_latex(index=False)\n", | ||
"\n", | ||
"\n", | ||
"print(format_table_with_metrics())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# adjust names of models\n", | ||
"df[\"train.energy.selected\"] = df[\"train.energy.selected\"].replace(\n", | ||
" {\n", | ||
" \"lgbm\": \"LightGBM\",\n", | ||
" \"xgb-custom\": \"XGBoost\",\n", | ||
" \"catboost\": \"CatBoost\",\n", | ||
" \"quantreg\": \"Quantile Regression\",\n", | ||
" \"benchmark\": \"Benchmark\",\n", | ||
" }\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a function to format metrics as mean (±std) with highlighting\n", | ||
"def format_table_with_metrics_and_highlighting(df, target=\"energy\"):\n", | ||
" # Define display name mapping\n", | ||
" display_names = {\n", | ||
" \"pinball_loss\": \"Pinball Loss\",\n", | ||
" \"interval_score_50\": r\"50\\% Interval Score\",\n", | ||
" \"interval_score_95\": r\"95\\% Interval Score\",\n", | ||
" \"coverage_50\": r\"50\\% PI Coverage\",\n", | ||
" \"coverage_95\": r\"95\\% PI Coverage\",\n", | ||
" }\n", | ||
"\n", | ||
" # Create new formatted DataFrame\n", | ||
" formatted_table = pd.DataFrame()\n", | ||
" formatted_table[\"Model\"] = df[f\"train.{target}.selected\"]\n", | ||
"\n", | ||
" # Format metrics with highlighting\n", | ||
" for metric in [\"pinball_loss\", \"interval_score_50\", \"interval_score_95\"]:\n", | ||
" means = df[f\"{metric}.mean\"]\n", | ||
" stds = df[f\"{metric}.std\"]\n", | ||
" min_idx = means.idxmin()\n", | ||
"\n", | ||
" values = []\n", | ||
" for idx in means.index:\n", | ||
" if idx == min_idx:\n", | ||
" values.append(f\"\\\\textbf{{{means[idx]:.2f} (± {stds[idx]:.2f})}}\")\n", | ||
" else:\n", | ||
" values.append(f\"{means[idx]:.2f} (± {stds[idx]:.2f})\")\n", | ||
" formatted_table[display_names[metric]] = values\n", | ||
"\n", | ||
" # Handle coverage metrics\n", | ||
" for metric, target in [(\"coverage_50\", 0.5), (\"coverage_95\", 0.95)]:\n", | ||
" means = df[f\"{metric}.mean\"]\n", | ||
" stds = df[f\"{metric}.std\"]\n", | ||
" # Find index closest to target\n", | ||
" closest_idx = (means - target).abs().idxmin()\n", | ||
"\n", | ||
" values = []\n", | ||
" for idx in means.index:\n", | ||
" if idx == closest_idx:\n", | ||
" values.append(f\"\\\\textbf{{{means[idx]:.2f} (± {stds[idx]:.2f})}}\")\n", | ||
" else:\n", | ||
" values.append(f\"{means[idx]:.2f} (± {stds[idx]:.2f})\")\n", | ||
" formatted_table[display_names[metric]] = values\n", | ||
"\n", | ||
" return formatted_table\n", | ||
"\n", | ||
"\n", | ||
"print(format_table_with_metrics_and_highlighting(df))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"formatted_table = format_table_with_metrics_and_highlighting(df).drop(\n", | ||
" columns=[r\"50\\% Interval Score\", r\"95\\% Interval Score\"]\n", | ||
")\n", | ||
"print(\n", | ||
" formatted_table.to_latex(\n", | ||
" index=False,\n", | ||
" escape=False,\n", | ||
" caption=\"Results of Timeseries Cross-Validation on the Hourly Electricity Demand Dataset. Best values are highlighted in bold.\",\n", | ||
" label=\"tab:energy_results\",\n", | ||
" position=\"htp\",\n", | ||
" )\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def format_dvc_experiments(\n", | ||
" exps: dict,\n", | ||
" target=\"energy\",\n", | ||
" allow_duplicated_models=False,\n", | ||
" caption: str | None = None,\n", | ||
" label: str | None = None,\n", | ||
") -> tuple[pd.DataFrame, pd.DataFrame, str]:\n", | ||
" df = pd.DataFrame(exps)\n", | ||
" metrics = [col for col in df.columns if f\"{target}_metrics.json\" in col]\n", | ||
" params_col = [f\"train.{target}.selected\"]\n", | ||
"\n", | ||
" df = pd.DataFrame(exps, columns=metrics + params_col)\n", | ||
" df = df.dropna()\n", | ||
" if not allow_duplicated_models:\n", | ||
" df = df.drop_duplicates(subset=params_col)\n", | ||
"\n", | ||
" # strip the prefix in the col name until :\n", | ||
" df.columns = [col.split(\":\")[1] if \":\" in col else col for col in df.columns]\n", | ||
" df = df.drop(columns=[\"avg_fit_time\", \"avg_pred_time\"])\n", | ||
" df = df.sort_values(\"pinball_loss.mean\", ascending=False, ignore_index=True)\n", | ||
"\n", | ||
" # adjust names of models\n", | ||
" df[f\"train.{target}.selected\"] = df[f\"train.{target}.selected\"].replace(\n", | ||
" {\n", | ||
" \"lgbm\": \"LightGBM\",\n", | ||
" \"xgb-custom\": \"XGBoost\",\n", | ||
" \"catboost\": \"CatBoost\",\n", | ||
" \"quantreg\": \"Quantile Regression\",\n", | ||
" \"benchmark\": \"Benchmark\",\n", | ||
" }\n", | ||
" )\n", | ||
"\n", | ||
" formatted_table = format_table_with_metrics_and_highlighting(df, target=target)\n", | ||
" formatted_table = formatted_table.drop(\n", | ||
" columns=[r\"50\\% Interval Score\", r\"95\\% Interval Score\"]\n", | ||
" )\n", | ||
"\n", | ||
" latex_code = formatted_table.to_latex(\n", | ||
" index=False, escape=False, caption=caption, label=label, position=\"htp\"\n", | ||
" )\n", | ||
" return df, formatted_table, latex_code" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"_, _, bikes_table = format_dvc_experiments(\n", | ||
" exps,\n", | ||
" target=\"bikes\",\n", | ||
" caption=\"Results of Timeseries Cross-Validation on the Daily Bike Count Dataset. Best values are highlighted in bold.\",\n", | ||
" label=\"tab:bikes_results\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"_, _, energy_table = format_dvc_experiments(\n", | ||
" exps,\n", | ||
" target=\"energy\",\n", | ||
" caption=\"Results of Timeseries Cross-Validation on the Hourly Electricity Demand Dataset. Best values are highlighted in bold.\",\n", | ||
" label=\"tab:energy_results\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"print(bikes_table)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"print(energy_table)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |