-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor benchmark architecture (#477)
This PR refines the benchmarking classes and fixes a few other things along the way: * Most importantly, the convergence-test-specific `optimal_function_inputs` and `best_possible_result` attributes are moved to a separate `ConvergenceBenchmark` class, with corresponding `ConvergenceBenchmarkSettings`. * The remaining benchmark attributes / properties are cleaned up.
- Loading branch information
Showing
12 changed files
with
188 additions
and
150 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,17 @@ | ||
"""Benchmarking module for performance tracking.""" | ||
|
||
from benchmarks.definition import Benchmark | ||
from benchmarks.definition import ( | ||
Benchmark, | ||
BenchmarkSettings, | ||
ConvergenceBenchmark, | ||
ConvergenceBenchmarkSettings, | ||
) | ||
from benchmarks.result import Result | ||
|
||
__all__ = [ | ||
"Result", | ||
"Benchmark", | ||
"BenchmarkSettings", | ||
"ConvergenceBenchmark", | ||
"ConvergenceBenchmarkSettings", | ||
"Result", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,17 @@ | ||
"""Benchmark task definitions.""" | ||
"""Benchmark definitions.""" | ||
|
||
from benchmarks.definition.config import ( | ||
from benchmarks.definition.base import ( | ||
Benchmark, | ||
BenchmarkSettings, | ||
ConvergenceExperimentSettings, | ||
) | ||
from benchmarks.definition.convergence import ( | ||
ConvergenceBenchmark, | ||
ConvergenceBenchmarkSettings, | ||
) | ||
|
||
__all__ = [ | ||
"ConvergenceExperimentSettings", | ||
"Benchmark", | ||
"BenchmarkSettings", | ||
"ConvergenceBenchmark", | ||
"ConvergenceBenchmarkSettings", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
"""Basic benchmark configuration.""" | ||
|
||
import time | ||
from abc import ABC | ||
from collections.abc import Callable | ||
from datetime import datetime, timedelta, timezone | ||
from typing import Generic, TypeVar | ||
|
||
from attrs import define, field | ||
from attrs.validators import instance_of | ||
from cattrs import override | ||
from cattrs.gen import make_dict_unstructure_fn | ||
from pandas import DataFrame | ||
|
||
from baybe.utils.random import temporary_seed | ||
from benchmarks.result import Result, ResultMetadata | ||
from benchmarks.serialization import BenchmarkSerialization, converter | ||
|
||
|
||
@define(frozen=True, kw_only=True) | ||
class BenchmarkSettings(ABC, BenchmarkSerialization): | ||
"""The basic benchmark configuration.""" | ||
|
||
random_seed: int = field(validator=instance_of(int), default=1337) | ||
"""The used random seed.""" | ||
|
||
|
||
BenchmarkSettingsType = TypeVar("BenchmarkSettingsType", bound=BenchmarkSettings) | ||
|
||
|
||
@define(frozen=True) | ||
class Benchmark(Generic[BenchmarkSettingsType], BenchmarkSerialization): | ||
"""The base class for all benchmark definitions.""" | ||
|
||
function: Callable[[BenchmarkSettingsType], DataFrame] = field() | ||
"""The callable containing the benchmarking logic.""" | ||
|
||
settings: BenchmarkSettingsType = field() | ||
"""The benchmark configuration.""" | ||
|
||
@function.validator | ||
def _validate_function(self, _, function) -> None: | ||
if function.__doc__ is None: | ||
raise ValueError("The benchmark function must have a docstring.") | ||
|
||
@property | ||
def name(self) -> str: | ||
"""The name of the benchmark function.""" | ||
return self.function.__name__ | ||
|
||
@property | ||
def description(self) -> str: | ||
"""The description of the benchmark function.""" | ||
assert self.function.__doc__ is not None | ||
return self.function.__doc__ | ||
|
||
def __call__(self) -> Result: | ||
"""Execute the benchmark and return the result.""" | ||
start_datetime = datetime.now(timezone.utc) | ||
|
||
with temporary_seed(self.settings.random_seed): | ||
start_sec = time.perf_counter() | ||
result = self.function(self.settings) | ||
stop_sec = time.perf_counter() | ||
|
||
duration = timedelta(seconds=stop_sec - start_sec) | ||
|
||
metadata = ResultMetadata( | ||
start_datetime=start_datetime, | ||
duration=duration, | ||
) | ||
|
||
return Result(self.name, result, metadata) | ||
|
||
|
||
@converter.register_unstructure_hook | ||
def unstructure_benchmark(benchmark: Benchmark) -> dict: | ||
"""Unstructure a benchmark instance.""" | ||
fn = make_dict_unstructure_fn( | ||
type(benchmark), converter, function=override(omit=True) | ||
) | ||
return { | ||
"name": benchmark.name, | ||
"description": benchmark.description, | ||
**fn(benchmark), | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Convergence benchmark configuration.""" | ||
|
||
from typing import Any | ||
|
||
from attrs import define, field | ||
from attrs.validators import deep_mapping, instance_of, optional | ||
|
||
from benchmarks.definition.base import Benchmark, BenchmarkSettings | ||
|
||
|
||
@define(frozen=True, kw_only=True) | ||
class ConvergenceBenchmarkSettings(BenchmarkSettings): | ||
"""Benchmark configuration for recommender convergence analyses.""" | ||
|
||
batch_size: int = field(validator=instance_of(int)) | ||
"""The recommendation batch size.""" | ||
|
||
n_doe_iterations: int = field(validator=instance_of(int)) | ||
"""The number of Design of Experiment iterations.""" | ||
|
||
n_mc_iterations: int = field(validator=instance_of(int)) | ||
"""The number of Monte Carlo iterations.""" | ||
|
||
|
||
@define(frozen=True) | ||
class ConvergenceBenchmark(Benchmark[ConvergenceBenchmarkSettings]): | ||
"""A class for defining convergence benchmarks.""" | ||
|
||
optimal_target_values: dict[str, Any] | None = field( | ||
default=None, | ||
validator=optional( | ||
deep_mapping( | ||
key_validator=instance_of(str), | ||
mapping_validator=instance_of(dict), | ||
value_validator=lambda *_: None, | ||
) | ||
), | ||
) | ||
"""The optimal values that can be achieved for the targets **individually**.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.