Skip to content

Commit

Permalink
More plots
Browse files Browse the repository at this point in the history
  • Loading branch information
robinholzi committed Sep 24, 2024
1 parent 53d3b24 commit 17058cf
Show file tree
Hide file tree
Showing 6 changed files with 778 additions and 90 deletions.
50 changes: 32 additions & 18 deletions analytics/plotting/common/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

def build_heatmap(
heatmap_data: pd.DataFrame,
y_ticks: list[int] | None = None,
y_ticks: list[int] | list[str] | None = None,
y_ticks_bins: int | None = None,
x_ticks: list[int] | None = None,
x_custom_ticks: list[tuple[int, str]] | None = None, # (position, label)
y_custom_ticks: list[tuple[int, str]] | None = None, # (position, label)
reverse_col: bool = False,
y_label: str = "Reference Year",
x_label: str = "Current Year",
Expand All @@ -34,6 +36,7 @@ def build_heatmap(
policy: list[tuple[int, int, int]] = [],
cmap: Any | None = None,
linewidth: int = 2,
grid_alpha: float = 0.0,
) -> Figure | Axes:
init_plot()
setup_font(small_label=True, small_title=True)
Expand All @@ -53,7 +56,7 @@ def build_heatmap(
heatmap_data,
cmap=("RdBu" + ("_r" if reverse_col else "")) if not cmap else cmap,
linewidths=0.0,
linecolor="black",
linecolor="white",
# color bar from 0 to 1
cbar_kws={
"label": color_label,
Expand Down Expand Up @@ -84,22 +87,34 @@ def build_heatmap(

# Adjust x-axis tick labels
ax.set_xlabel(x_label)
if not x_ticks:
if not x_ticks and not x_custom_ticks:
ax.set_xticks(
ticks=[x + 0.5 for x in range(0, 2010 - 1930 + 1, 20)],
labels=[x for x in range(1930, 2010 + 1, 20)],
rotation=0,
# ha='right'
)
else:
ax.set_xticks(
ticks=[x - 1930 + 0.5 for x in x_ticks],
labels=[x for x in x_ticks],
rotation=0,
# ha='right'
)
if x_custom_ticks:
ax.set_xticks(
ticks=[x[0] for x in x_custom_ticks],
labels=[x[1] for x in x_custom_ticks],
rotation=0,
# ha='right'
)
else:
assert x_ticks is not None
ax.set_xticks(
ticks=[x - 1930 + 0.5 for x in x_ticks],
labels=[x for x in x_ticks],
rotation=0,
# ha='right'
)
ax.invert_yaxis()

ax.grid(axis="y", linestyle="--", alpha=grid_alpha, color="white")
ax.grid(axis="x", linestyle="--", alpha=grid_alpha, color="white")

if y_ticks is not None:
ax.set_yticks(
ticks=[y + 0.5 - 1930 for y in y_ticks],
Expand All @@ -109,21 +124,20 @@ def build_heatmap(
elif y_ticks_bins is not None:
ax.yaxis.set_major_locator(MaxNLocator(nbins=y_ticks_bins))
ax.set_yticklabels([int(i) + min(heatmap_data.index) for i in ax.get_yticks()], rotation=0)
else:
if y_custom_ticks:
ax.set_yticks(
ticks=[y[0] for y in y_custom_ticks],
labels=[y[1] for y in y_custom_ticks],
rotation=0,
# ha='right'
)

ax.set_ylabel(y_label)

if title_label:
ax.set_title(title_label)

# drift_pipeline = []

# TODO visualize policy
# Draft training boxes
# if drift_pipeline:
# x_start = active_[1][f"_start"].year - 1930
# x_end = active_[1][f"{type_}_end"].year - 1930
# y = active_[1]["model_idx"]

previous_y = 0
for x_start, x_end, y in policy:
# main box
Expand Down
100 changes: 100 additions & 0 deletions analytics/plotting/common/linear_regression_scatterplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Any

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from analytics.plotting.common.color import main_color
from analytics.plotting.common.common import init_plot
from analytics.plotting.common.font import setup_font

# Scatter plot with a fitted linear regression line


def scatter_linear_regression(
    data: pd.DataFrame,
    x: str,
    y: str,
    hue: str,
    y_ticks: list[int] | list[str] | None = None,
    x_ticks: list[int] | None = None,
    y_label: str = "Reference Year",
    x_label: str = "Current Year",
    height_factor: float = 1.0,
    width_factor: float = 1.0,
    legend_label: str = "Number Samples",
    title_label: str = "",
    target_ax: Axes | None = None,
    palette: Any = None,
) -> Figure | tuple[Axes, Axes]:
    """Draw a scatter plot of ``y`` over ``x`` with a linear regression overlay.

    A seaborn ``regplot`` (regression line in ``main_color(0)``) and a seaborn
    ``scatterplot`` (large "X" markers coloured by ``hue``) are drawn onto the
    same axes.

    Args:
        data: Frame holding the columns named by ``x``, ``y`` and ``hue``.
        x: Column name plotted on the x-axis.
        y: Column name plotted on the y-axis.
        hue: Column name used to colour the scatter markers.
        y_ticks: Optional y tick positions; the same values are used as labels.
        x_ticks: Optional x tick positions; the same values are used as labels.
        y_label: Label of the y-axis.
        x_label: Label of the x-axis.
        height_factor: Multiplier for the base figure height (2 * 3.5 inches).
            Only used when a new figure is created.
        width_factor: Multiplier for the base figure width (10 inches).
            Only used when a new figure is created.
        legend_label: Title of the legend.
        title_label: Axes title; no title is set when empty.
        target_ax: Existing axes to draw on. When ``None`` a new 300 dpi
            figure is created and the plots are drawn on its current axes.
        palette: Forwarded unchanged to ``sns.scatterplot``.

    Returns:
        The newly created ``Figure`` when ``target_ax`` is ``None``; otherwise
        the tuple ``(regplot_axes, scatterplot_axes)``.
    """
    sns.set_style("whitegrid")

    # NOTE(review): project helpers; presumably they configure global
    # matplotlib / font settings - confirm against their definitions.
    init_plot()
    setup_font(small_label=True, small_title=True)

    DOUBLE_FIG_WIDTH = 10
    DOUBLE_FIG_HEIGHT = 3.5

    fig: Figure | None = None
    if target_ax is None:
        fig = plt.figure(
            edgecolor="black",
            frameon=True,
            figsize=(
                DOUBLE_FIG_WIDTH * width_factor,
                2 * DOUBLE_FIG_HEIGHT * height_factor,
            ),
            dpi=300,
        )

    # ax=None lets seaborn fall back to the current axes (those of the figure
    # created above); otherwise the caller-supplied axes are honoured.
    ax1 = sns.regplot(
        data=data,
        x=x,
        y=y,
        color=main_color(0),
        ax=target_ax,
    )

    ax2 = sns.scatterplot(
        data=data,
        x=x,
        y=y,
        hue=hue,
        palette=palette,
        s=200,
        legend=True,
        marker="X",
        ax=target_ax,
    )

    ax2.legend(title=legend_label, ncol=2, handletextpad=0, columnspacing=0.5, fontsize="x-small")

    # Axis labels and optional explicit ticks (tick values double as labels).
    ax2.set_xlabel(x_label)
    if x_ticks is not None:
        ax2.set_xticks(
            ticks=x_ticks,
            labels=x_ticks,
            rotation=0,
        )

    if y_ticks is not None:
        ax2.set_yticks(
            ticks=y_ticks,
            labels=y_ticks,
            rotation=0,
        )

    ax2.set_ylabel(y_label)

    if title_label:
        ax2.set_title(title_label)

    print("Number of plotted items", data.shape[0])

    plt.tight_layout()

    if fig is not None:
        return fig
    return ax1, ax2
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
"source": [
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"\n",
"from analytics.app.data.load import list_pipelines\n",
"from analytics.plotting.common.common import init_plot\n",
"from analytics.plotting.common.font import setup_font\n",
"from analytics.plotting.common.color import discrete_colors\n",
"from analytics.plotting.common.linear_regression_scatterplot import scatter_linear_regression\n",
"from modyn.supervisor.internal.grpc.enums import PipelineStage\n",
"from modyn.supervisor.internal.pipeline_executor.models import StageLog\n",
"\n",
Expand All @@ -30,11 +28,13 @@
"source": [
"# INPUTS\n",
"\n",
"# pipelines_dir = Path(\"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/yearbook/11_baselines_amount\")\n",
"pipelines_dir = Path(\n",
" \"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/arxiv/11_baselines_amount\"\n",
" \"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/huffpost/11_baselines_amount\"\n",
")\n",
"# pipelines_dir = Path(\"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/huffpost/11_baselines_amount\")\n",
"# pipelines_dir = Path(\"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/yearbook/11_baselines_amount\")\n",
"# pipelines_dir = Path(\n",
"# \"/Users/robinholzinger/robin/dev/eth/modyn-robinholzi-data/data/triggering/arxiv/11_baselines_amount\"\n",
"# )\n",
"output_dir = Path(\"/Users/robinholzinger/robin/dev/eth/modyn-2/.analytics.log/.data/_plots\")\n",
"assert pipelines_dir.exists()\n",
"assert output_dir.exists()"
Expand Down Expand Up @@ -62,6 +62,25 @@
"pipeline_logs = {p_id: load_pipeline_logs(p_id, pipelines_dir) for (p_id, (_, p_path)) in pipelines.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extract number of epochs\n",
"num_epochs: int | None = None\n",
"\n",
"for p_id, logs in pipeline_logs.items():\n",
" for log in logs:\n",
" if num_epochs is None:\n",
" num_epochs = logs.config.pipeline.training.epochs_per_trigger\n",
" else:\n",
" assert num_epochs == logs.config.pipeline.training.epochs_per_trigger\n",
"\n",
"assert num_epochs"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -107,7 +126,7 @@
"\n",
"\n",
"def pipeline_name_cleaner(name: str):\n",
" return re.sub(r\".*_dataamount_(\\d+)\", \"trigger every \\\\1 samples\", name)\n",
" return re.sub(r\".*dataamount_(\\d+)\", r\"\\1\", name)\n",
"\n",
"\n",
"df_train[\"pipeline_id\"] = df_train[\"pipeline_id\"].apply(pipeline_name_cleaner)\n",
Expand All @@ -125,6 +144,18 @@
"# df_train[\"duration\"] = df_train[\"duration\"].dt.total_seconds()\n",
"# df_train[\"train_time_at_trainer\"] = df_train[\"train_time_at_trainer\"] / 1_000 # millis to seconds\n",
"df_train[\"train_time_at_trainer\"] = df_train[\"train_time_at_trainer\"] / 1_000 / 60 # millis to minutes\n",
"\n",
"# vs. number of passed sample: num_samples\n",
"df_train[\"num_input_samples\"] = df_train[\"num_samples\"] / num_epochs\n",
"\n",
"\n",
"dataset = pipelines_dir.parent.name\n",
"\n",
"if dataset != \"yearbook\":\n",
" df_train[\"num_input_samples\"] = df_train[\"num_input_samples\"] / 1_000\n",
" df_train[\"pipeline_id\"] = (df_train[\"pipeline_id\"].astype(int) // 1_000).astype(str) + \"k\"\n",
"\n",
"\n",
"df_train"
]
},
Expand All @@ -151,61 +182,36 @@
"metadata": {},
"outputs": [],
"source": [
"from analytics.plotting.common.color import discrete_colors, main_color\n",
"from analytics.plotting.common.save import save_plot\n",
"\n",
"sns.set_style(\"whitegrid\")\n",
"\n",
"init_plot()\n",
"setup_font(small_label=True, small_title=True)\n",
"\n",
"\n",
"FONTSIZE = 20\n",
"DOUBLE_FIG_WIDTH = 10\n",
"DOUBLE_FIG_HEIGHT = 3.5\n",
"DOUBLE_FIG_SIZE = (DOUBLE_FIG_WIDTH, 1.5 * DOUBLE_FIG_HEIGHT)\n",
"\n",
"width_factor = 0.5\n",
"height_factor = 0.5\n",
"\n",
"fig = plt.figure(\n",
" edgecolor=\"black\",\n",
" frameon=True,\n",
" figsize=(\n",
" DOUBLE_FIG_WIDTH * width_factor,\n",
" 2 * DOUBLE_FIG_HEIGHT * height_factor,\n",
" ),\n",
" dpi=300,\n",
")\n",
"\n",
"ax1 = sns.regplot(\n",
"fig = scatter_linear_regression(\n",
" df_train,\n",
" x=\"num_samples\",\n",
" y=\"train_time_at_trainer\", # duration\n",
" color=main_color(0),\n",
")\n",
"\n",
"ax2 = sns.scatterplot(\n",
" df_train,\n",
" x=\"num_samples\",\n",
" y=\"train_time_at_trainer\", # duration\n",
" x=\"num_input_samples\",\n",
" y=\"train_time_at_trainer\", # duration is broken due to bug in grpc interface\n",
" hue=\"pipeline_id\",\n",
" palette=(\n",
" discrete_colors(14)[0:5] + discrete_colors(14)[9:14]\n",
" discrete_colors(14)[0:4] + discrete_colors(14)[10:14]\n",
" if \"yearbook\" in str(pipelines_dir)\n",
" else (\n",
" discrete_colors(8)[0:3] + discrete_colors(8)[6:8]\n",
" discrete_colors(12)[0:4] + discrete_colors(12)[9:12]\n",
" if \"huffpost\" in str(pipelines_dir)\n",
" else discrete_colors(8)[0:3] + discrete_colors(8)[6:8]\n",
" )\n",
" ),\n",
" s=200,\n",
" legend=True,\n",
" marker=\"X\",\n",
" title_label=\"Training Size (Samples) vs. Cost (Time)\",\n",
" x_label=\"#Trained Samples (k) / #Epochs\",\n",
" y_label=\"Duration (min)\",\n",
" legend_label=\"Trigger every\",\n",
" height_factor=0.5,\n",
" width_factor=0.575,\n",
" # x_ticks=[],\n",
" # y_ticks=[],\n",
")\n",
"\n",
"# Display the plot\n",
"plt.tight_layout()\n",
"plt.show()"
"save_plot(\n",
" fig=fig,\n",
" name=dataset + \"_training_size_vs_cost\",\n",
")"
]
},
{
Expand All @@ -218,6 +224,13 @@
"# TODO: plot / add number of datapoints to thesis so that the significance of regression line is clear\n",
"# State in thesis that there are no outliers to be expected!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 17058cf

Please sign in to comment.