Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] conv2d #228

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
64 changes: 54 additions & 10 deletions benchmark/benchmarks_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
import os
import sys
from argparse import ArgumentParser
from dataclasses import dataclass

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

DATA_PATH = "data/all_benchmark_data.csv"
VISUALIZATIONS_PATH = "visualizations/"
DATA_PATH = os.path.join(os.path.dirname(__file__), "data/all_benchmark_data.csv")
VISUALIZATIONS_PATH = os.path.join(os.path.dirname(__file__), "visualizations/")


@dataclass
Expand All @@ -32,27 +33,42 @@ class VisualizationsConfig:
overwrite: bool = False


def get_available_options():
csv_path = os.path.join(os.path.dirname(__file__), DATA_PATH)
df = pd.read_csv(csv_path)
return {
"kernel_name": df["kernel_name"].unique().tolist(),
"metric_name": df["metric_name"].unique().tolist(),
"kernel_operation_mode": df["kernel_operation_mode"].unique().tolist(),
}


def parse_args() -> VisualizationsConfig:
"""Parse command line arguments into a configuration object.

Returns:
VisualizationsConfig: Configuration object for the visualizations script.
"""
parser = ArgumentParser()
available_options = get_available_options()

parser = ArgumentParser(description="Visualize benchmark data", add_help=False)
parser.add_argument(
"--kernel-name", type=str, required=True, help="Kernel name to benchmark"
"-h", "--help", action="store_true", help="Show this help message and exit"
)
parser.add_argument(
"--kernel-name",
type=str,
help=f"Kernel name to benchmark. Options: {', '.join(available_options['kernel_name'])}",
)
parser.add_argument(
"--metric-name",
type=str,
required=True,
help="Metric name to visualize (speed/memory)",
help=f"Metric name to visualize. Options: {', '.join(available_options['metric_name'])}",
)
parser.add_argument(
"--kernel-operation-mode",
type=str,
required=True,
help="Kernel operation mode to visualize (forward/backward/full)",
help=f"Kernel operation mode to visualize. Options: {', '.join(available_options['kernel_operation_mode'])}",
)
parser.add_argument(
"--display", action="store_true", help="Display the visualization"
Expand All @@ -65,7 +81,35 @@ def parse_args() -> VisualizationsConfig:

args = parser.parse_args()

return VisualizationsConfig(**dict(args._get_kwargs()))
if args.help or len(sys.argv) == 1:
parser.print_help()
print("\nAvailable options:")
for arg, options in available_options.items():
print(f" {arg}: {', '.join(options)}")
sys.exit(0)

if not all([args.kernel_name, args.metric_name, args.kernel_operation_mode]):
parser.error(
"--kernel-name, --metric-name, and --kernel-operation-mode are required arguments"
)

if args.kernel_name not in available_options["kernel_name"]:
parser.error(
f"Invalid kernel name. Choose from: {', '.join(available_options['kernel_name'])}"
)
if args.metric_name not in available_options["metric_name"]:
parser.error(
f"Invalid metric name. Choose from: {', '.join(available_options['metric_name'])}"
)
if args.kernel_operation_mode not in available_options["kernel_operation_mode"]:
parser.error(
f"Invalid kernel operation mode. Choose from: {', '.join(available_options['kernel_operation_mode'])}"
)

args_dict = vars(args)
args_dict.pop("help", None)

return VisualizationsConfig(**args_dict)


def load_data(config: VisualizationsConfig) -> pd.DataFrame:
Expand Down Expand Up @@ -119,7 +163,7 @@ def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
hue="kernel_provider",
marker="o",
palette="tab10",
errorbar=("ci", None),
errorbar=None,
)

# Seaborn can't plot pre-computed error bars, so we need to do it manually
Expand Down
Loading
Loading