Skip to content

Commit

Permalink
Move the arg list parsing to dedicated functions
Browse files Browse the repository at this point in the history
  • Loading branch information
niermann999 committed Jan 9, 2025
1 parent 3a6224b commit 9e2e844
Show file tree
Hide file tree
Showing 8 changed files with 260 additions and 215 deletions.
7 changes: 3 additions & 4 deletions tests/tools/python/impl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .plot_benchmark_results import (
read_benchmark_data,
add_track_multiplicity_column,
prepare_data,
plot_benchmark,
generate_plot_series,
prepare_benchmark_data,
plot_benchmark_case,
plot_benchmark_data,
)
from .plot_navigation_validation import (
read_scan_data,
Expand Down
27 changes: 13 additions & 14 deletions tests/tools/python/impl/plot_benchmark_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def add_track_multiplicity_column(df):
""" Read the benchmark data and prepare it for plotting """


def prepare_data(logging, input_dir, file):
def prepare_benchmark_data(logging, input_dir, file):

# Convert benchmark timings to 'ms'
unit_conversion = {"ns": 10**-6, "um": 10**-3, "ms": 1, "s": 10**3}
Expand Down Expand Up @@ -109,10 +109,13 @@ def prepare_data(logging, input_dir, file):
return context, data


""" Plot the benchmark latency for different hardware and algebra plugins """
"""
Plot the benchmark latency and throughout for different hardware backends and
algebra plugins
"""


def plot_benchmark(
def plot_benchmark_case(
context,
df,
plot_factory,
Expand All @@ -139,15 +142,11 @@ def plot_benchmark(

if plot is None:
# Create new plot
box_anchor_x = 1.0
box_anchor_y = 1.02

lgd_ops = plotting.legend_options(
loc=ldg_loc, horiz_anchor=box_anchor_x, vert_anchor=box_anchor_y
loc=ldg_loc, horiz_anchor=1.0, vert_anchor=1.02
)

labels = label_dict[data_type]

x_axis_opts = plotting.axis_options(
label=labels.x_axis, log_scale=True, tick_positions=n_tracks
)
Expand Down Expand Up @@ -186,7 +185,7 @@ def plot_benchmark(
""" Plot the data of all benchmark files given in 'data_files' """


def generate_plot_series(
def plot_benchmark_data(
logging,
input_dir,
det_name,
Expand All @@ -211,14 +210,14 @@ def generate_plot_series(
# Go through all benchmark data files for this hardware backend type
for i, file in enumerate(file_list):
# Benchmark results for the next algebra plugin
context, data = prepare_data(logging, input_dir, file)
context, data = prepare_benchmark_data(logging, input_dir, file)
marker = next(marker_style_cycle)

# Initialize plots
if i == 0:

# Plot the data against the number of tracks
latency_plot = plot_benchmark(
latency_plot = plot_benchmark_case(
context=context,
df=data,
plot_factory=plot_factory,
Expand All @@ -228,7 +227,7 @@ def generate_plot_series(
title=title,
)

throughput_plot = plot_benchmark(
throughput_plot = plot_benchmark_case(
context=context,
df=data,
plot_factory=plot_factory,
Expand All @@ -243,7 +242,7 @@ def generate_plot_series(

# Add new data to plots
else:
plot_benchmark(
plot_benchmark_case(
context=context,
df=data,
plot_factory=plot_factory,
Expand All @@ -253,7 +252,7 @@ def generate_plot_series(
plot=plots.latency,
)

plot_benchmark(
plot_benchmark_case(
context=context,
df=data,
plot_factory=plot_factory,
Expand Down
39 changes: 11 additions & 28 deletions tests/tools/python/material_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
parse_plotting_options,
)
from utils import read_detector_name
from utils import add_track_generator_args, add_propagation_args, add_detector_io_args

# python includes
import argparse
Expand Down Expand Up @@ -93,38 +94,20 @@ def __main__():

# Pass on the options for the validation tools
args_list = [
"--geometry_file",
args.geometry_file,
"--material_file",
args.material_file,
"--phi_steps",
str(args.phi_steps),
"--eta_steps",
str(args.eta_steps),
"--eta_range",
str(args.eta_range[0]),
str(args.eta_range[1]),
"--tol",
str(args.tolerance),
"--min_mask_tolerance",
str(args.min_mask_tol),
"--max_mask_tolerance",
str(args.max_mask_tol),
"--overstep_tolerance",
str(args.overstep_tol),
"--path_tolerance",
str(args.path_tol),
"--rk-tolerance",
str(args.rk_error_tol),
"--path_limit",
str(args.path_limit),
"--search_window",
str(args.search_window[0]),
str(args.search_window[1]),
]

if args.grid_file:
args_list = args_list + ["--grid_file", args.grid_file]
# Add parsed options to argument list
add_detector_io_args(args_list, args)
add_track_generator_args(args_list, args)
add_propagation_args(args_list, args)

if "--material_file" not in args_list:
logging.error(
"Detector material is required! Please add it using the '--material_file' option"
)
sys.exit(1)

# Run the host validation and produce the truth data
logging.debug("Running CPU material validation")
Expand Down
44 changes: 6 additions & 38 deletions tests/tools/python/navigation_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from plotting import pyplot_factory as plt_factory
from utils import read_detector_name
from utils import add_track_generator_args, add_propagation_args, add_detector_io_args

# python imports
import argparse
Expand Down Expand Up @@ -124,45 +125,12 @@ def __main__():
# -----------------------------------------------------------------------run

# Pass on the options for the validation tools
args_list = [
"--data_dir",
datadir,
"--geometry_file",
args.geometry_file,
"--n_tracks",
str(args.n_tracks),
"--randomize_charge",
str(args.randomize_charge),
"--pT_range",
str(args.transv_momentum_range[0]),
str(args.transv_momentum_range[1]),
"--eta_range",
str(args.eta_range[0]),
str(args.eta_range[1]),
"--min_mask_tolerance",
str(args.min_mask_tol),
"--max_mask_tolerance",
str(args.max_mask_tol),
"--mask_tolerance_scalor",
str(args.mask_tol_scalor),
"--overstep_tolerance",
str(args.overstep_tol),
"--path_tolerance",
str(args.path_tol),
"--rk-tolerance",
str(args.rk_error_tol),
"--path_limit",
str(args.path_limit),
"--search_window",
str(args.search_window[0]),
str(args.search_window[1]),
]

if args.grid_file:
args_list = args_list + ["--grid_file", args.grid_file]
args_list = ["--data_dir", datadir]

if args.material_file:
args_list = args_list + ["--material_file", args.material_file]
# Add parsed options to argument list
add_detector_io_args(args_list, args)
add_track_generator_args(args_list, args)
add_propagation_args(args_list, args)

# Run the host validation and produce the truth data
logging.debug("Running CPU validation")
Expand Down
11 changes: 2 additions & 9 deletions tests/tools/python/options/propagation_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,6 @@ def propagation_options():
default=0.0001,
type=float,
)
parser.add_argument(
"--max_n_steps",
"-n_step",
help=("Max. Runge-Kutta step updates"),
default=10000,
type=int,
)
parser.add_argument(
"--path_limit",
"-plim",
Expand All @@ -99,14 +92,14 @@ def propagation_options():
"-bethe",
help=("Use Bethe energy loss"),
action="store_true",
default=True,
default=False,
)
parser.add_argument(
"--covariance_transport",
"-cov_trnsp",
help=("Do covaraiance transport"),
action="store_true",
default=True,
default=False,
)
parser.add_argument(
"--energy_loss_grad",
Expand Down
Loading

0 comments on commit 9e2e844

Please sign in to comment.