Skip to content

Commit

Permalink
heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
robinholzi committed Sep 27, 2024
1 parent 7529234 commit 7050320
Show file tree
Hide file tree
Showing 9 changed files with 738 additions and 26 deletions.
58 changes: 52 additions & 6 deletions analytics/plotting/common/heatmap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal

import matplotlib.patches as patches
import pandas as pd
Expand All @@ -14,6 +14,31 @@
from analytics.plotting.common.font import setup_font


def get_fractional_index(dates: pd.Series, query_date: pd.Timestamp, fractional: bool = True) -> float:
"""Given a list of Period objects (dates) and a query_date as a Period,
return the interpolated fractional index between two period indices if the
query_date lies between them."""
# Ensure query_date is within the bounds of the period range
if query_date < dates[0].start_time:
return -1 # -1 before first index

if query_date > dates[-1].start_time:
return len(dates) # +1 after last index

# Find the two periods where the query_date falls in between
for i in range(len(dates) - 1):
if dates[i].start_time <= query_date <= dates[i + 1].start_time:
# Perform linear interpolation, assuming equal length periods
return i + (
((query_date - dates[i].start_time) / (dates[i + 1].start_time - dates[i].start_time))
if fractional
else 0
)

# If query_date is exactly one of the dates
return dates.get_loc(query_date)


def build_heatmap(
heatmap_data: pd.DataFrame,
y_ticks: list[int] | list[str] | None = None,
Expand All @@ -39,7 +64,8 @@ def build_heatmap(
grid_alpha: float = 0.0,
disable_horizontal_grid: bool = False,
df_logs_models: pd.DataFrame | None = None,
triggers: dict[int, list[pd.Timestamp]] = {},
triggers: dict[int, pd.DataFrame] = {},
x_axis: Literal["int", "period"] = "year",
) -> Figure | Axes:
init_plot()
setup_font(small_label=True, small_title=True)
Expand Down Expand Up @@ -92,7 +118,7 @@ def build_heatmap(
ax.set_xlabel(x_label)
if not x_ticks and not x_custom_ticks:
ax.set_xticks(
ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)],
ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)], # TODO: check 0.5
labels=[x for x in range(1930, 2010 + 1, 20)],
rotation=0,
# ha='right'
Expand Down Expand Up @@ -184,11 +210,31 @@ def build_heatmap(
if df_logs_models is not None:
for type_, dashed in [("train", False), ("usage", False), ("train", True)]:
for active_ in df_logs_models.iterrows():
x_start = active_[1][f"{type_}_start"].year - 1930
x_end = active_[1][f"{type_}_end"].year - 1930
if x_axis == "year":
x_start = active_[1][f"{type_}_start"].year - 1930
x_end = active_[1][f"{type_}_end"].year - 1930
else:
# start_idx = get_fractional_index(heatmap_data.columns, start_date)
# end_idx = get_fractional_index(heatmap_data.columns, end_date)
# x_start = heatmap_data.columns.get_loc(active_[1][f"{type_}_start"])
# x_end = heatmap_data.columns.get_loc(active_[1][f"{type_}_end"])
x_start = get_fractional_index(
heatmap_data.columns,
active_[1][f"{type_}_start"],
fractional=False,
)
x_end = get_fractional_index(
heatmap_data.columns,
active_[1][f"{type_}_end"],
fractional=False,
)

y = active_[1]["model_idx"]
rect = plt.Rectangle(
(x_start, y - 1), # y: 0 based index, model_idx: 1 based index
(
x_start,
y - 1,
), # y: 0 based index, model_idx: 1 based index
x_end - x_start,
1,
edgecolor="White" if type_ == "train" else "Black",
Expand Down
16 changes: 15 additions & 1 deletion analytics/plotting/rh_thesis/TODO.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
drift:

- plot arxiv / huffpost
- plot arxiv

performance:

- 1 cost plot
- 1 single pipeline heatmap
- 1 multi pipeline heatmap for every dataset (including best of every subtype)

cost:

- 1 dummy plot

discussion:

- tradeoff plot: 1 per dataset
286 changes: 286 additions & 0 deletions analytics/plotting/rh_thesis/drift/arxiv_heatmap_single.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"\n",
"from analytics.app.data.load import list_pipelines\n",
"from analytics.app.data.transform import dfs_models_and_evals, logs_dataframe\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipelines_dir = Path(\n",
" \"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/arxiv/21_datadrift_dynamic\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipelines = list_pipelines(pipelines_dir)\n",
"max_pipeline_id = max(pipelines.keys())\n",
"pipelines"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from analytics.app.data.load import load_pipeline_logs\n",
"\n",
"pipeline_logs = {p_id: load_pipeline_logs(p_id, pipelines_dir) for (p_id, (_, p_path)) in pipelines.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# mode:\n",
"pipeline_id = 771 # hp drifttrigger_mmd-rollavg-2.0-20_int1500_win1y\n",
"\n",
"# doesn't do anything unless include_composite_model = True\n",
"composite_model_variant = \"currently_active_model\"\n",
"\n",
"patch_yearbook = True\n",
"dataset_id = \"huffpost_kaggle_test\"\n",
"eval_handler = \"periodic-current\"\n",
"metric = \"Accuracy\"\n",
"include_composite_model = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Wrangle data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipeline_log = pipeline_logs[pipeline_id]\n",
"pipeline_ref = f\"{pipeline_id}\".zfill(len(str(max_pipeline_id))) + f\" - {pipelines[pipeline_id][0]}\"\n",
"\n",
"df_all = logs_dataframe(pipeline_log, pipeline_ref)\n",
"\n",
"df_logs_models, _, df_eval_single = dfs_models_and_evals(\n",
" # subtracting would interfere with yearbook patching\n",
" pipeline_log,\n",
" df_all[\"sample_time\"].max(),\n",
" pipeline_ref,\n",
")\n",
"\n",
"df_adjusted = df_eval_single\n",
"\n",
"\n",
"df_adjusted = df_adjusted[\n",
" (df_adjusted[\"dataset_id\"] == dataset_id)\n",
" & (df_adjusted[\"eval_handler\"] == eval_handler)\n",
" & (df_adjusted[\"metric\"] == metric)\n",
"]\n",
"\n",
"# in percent (0-100)\n",
"df_adjusted[\"value\"] = df_adjusted[\"value\"] * 100"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_adjusted = df_adjusted.sort_values(by=[\"interval_center\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add composite model\n",
"\n",
"assert df_adjusted[\"pipeline_ref\"].nunique() <= 1\n",
"# add the pipeline time series which is the performance of different models stitched together dep.\n",
"# w.r.t which model was active\n",
"pipeline_composite_model = df_adjusted[df_adjusted[composite_model_variant]]\n",
"pipeline_composite_model[\"model_idx\"] = 0\n",
"pipeline_composite_model[\"id_model\"] = 0\n",
"\n",
"label_map = {k: f\"{k}\" for k, v in df_adjusted[[\"model_idx\", \"id_model\"]].values}\n",
"label_map[0] = \"Pipeline composite model\"\n",
"\n",
"if include_composite_model:\n",
" df_adjusted = pd.concat([pipeline_composite_model, df_adjusted])\n",
"else:\n",
" df_adjusted[\"model_idx\"] = df_adjusted[\"model_idx\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create Plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_adjusted = df_adjusted.sort_values(by=[\"interval_center\"])\n",
"df_adjusted[\"interval_center\"] = df_adjusted[\"interval_center\"].dt.to_period(\"M\")\n",
"df_adjusted"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_train_end_years_per_model = df_logs_models[[\"model_idx\", \"real_train_end\"]]\n",
"df_train_end_years_per_model[\"real_train_end\"] = df_train_end_years_per_model[\"real_train_end\"].dt.to_period(\"M\")\n",
"df_train_end_years_per_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_merged = df_adjusted.merge(df_train_end_years_per_model, on=\"model_idx\", how=\"left\")\n",
"df_merged"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_merged.groupby([\"real_train_end\", \"interval_center\"]).size()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# build heatmap matrix dataframe:\n",
"df_merged[\"real_train_end\"] = df_merged[\"real_train_end\"].apply(lambda x: pd.Period(x, freq=\"M\"))\n",
"heatmap_data = df_merged.pivot(index=[\"real_train_end\"], columns=\"interval_center\", values=\"value\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"heatmap_data.index.min(), heatmap_data.index.max()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"heatmap_data.index"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from analytics.plotting.common.heatmap import build_heatmap\n",
"from analytics.plotting.common.save import save_plot\n",
"\n",
"fig = build_heatmap(\n",
" heatmap_data,\n",
" reverse_col=True,\n",
" x_custom_ticks=[\n",
" (i, f\"{period.to_timestamp().strftime('%b %Y')}\".replace(\" \", \"\\n\"))\n",
" for i, period in list(enumerate(heatmap_data.columns))[::1]\n",
" if period in [pd.Period(\"Apr 2014\"), pd.Period(\"Jul 2018\"), pd.Period(\"Jan 2022\")]\n",
" ],\n",
" y_custom_ticks=[\n",
" (i + 0.5, f\"{period.to_timestamp().strftime('%b %Y')}\")\n",
" for i, period in list(enumerate(heatmap_data.index))[::1]\n",
" ],\n",
" y_label=\"Trained up to\",\n",
" x_label=\"Evaluation Year\",\n",
" title_label=\"HuffPost Dynamic Threshold\\nRolling Average: Δ +200%\",\n",
" color_label=\"Accuracy %\",\n",
" width_factor=0.6,\n",
" height_factor=0.61,\n",
" # grid_alpha=0.4,\n",
" grid_alpha=0.0,\n",
" # disable_horizontal_grid=True,\n",
" # cbar=False,\n",
" df_logs_models=df_logs_models,\n",
" x_axis=\"period\",\n",
")\n",
"save_plot(fig, \"arxiv_trigger_heatmap_drift_single_dynamic\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 7050320

Please sign in to comment.