Skip to content

Commit

Permalink
Refactor benchmark architecture (#477)
Browse files Browse the repository at this point in the history
This PR refines the benchmarking classes and fixes a few other things
along the way:
* Most importantly, the convergence-test-specific
`optimal_function_inputs` and `best_possible_result` attributes are
moved to a separate `ConvergenceBenchmark` class, with corresponding
`ConvergenceBenchmarkSettings`.
* The remaining benchmark attributes / properties are cleaned up.
  • Loading branch information
AdrianSosic authored Feb 7, 2025
2 parents ac8c518 + 2322b2a commit ae64ca1
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 150 deletions.
38 changes: 19 additions & 19 deletions .lockfiles/py310-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ anyio==4.4.0
# via
# httpx
# jupyter-server
appnope==0.1.4 ; platform_system == 'Darwin'
appnope==0.1.4 ; sys_platform == 'darwin'
# via ipykernel
argon2-cffi==23.1.0
# via jupyter-server
Expand Down Expand Up @@ -61,7 +61,7 @@ cachetools==5.4.0
# via
# streamlit
# tox
cattrs==23.2.3
cattrs==24.1.2
# via baybe (pyproject.toml)
certifi==2024.7.4
# via
Expand Down Expand Up @@ -240,7 +240,7 @@ importlib-metadata==7.1.0
# opentelemetry-api
iniconfig==2.0.0
# via pytest
intel-openmp==2021.4.0 ; platform_system == 'Windows'
intel-openmp==2021.4.0 ; sys_platform == 'win32'
# via mkl
interface-meta==1.3.0
# via formulaic
Expand Down Expand Up @@ -393,7 +393,7 @@ mdurl==0.1.2
# via markdown-it-py
mistune==3.0.2
# via nbconvert
mkl==2021.4.0 ; platform_system == 'Windows'
mkl==2021.4.0 ; sys_platform == 'win32'
# via torch
mmh3==5.0.1
# via e3fp
Expand Down Expand Up @@ -487,36 +487,36 @@ numpy==1.26.4
# types-seaborn
# xarray
# xyzpy
nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
onnx==1.16.1
# via
Expand Down Expand Up @@ -922,7 +922,7 @@ sympy==1.13.1
# via
# onnxruntime
# torch
tbb==2021.13.0 ; platform_system == 'Windows'
tbb==2021.13.0 ; sys_platform == 'win32'
# via mkl
tenacity==8.5.0
# via
Expand Down Expand Up @@ -1007,7 +1007,7 @@ traitlets==5.14.3
# nbclient
# nbconvert
# nbformat
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux'
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
typeguard==2.13.3
# via
Expand Down Expand Up @@ -1050,7 +1050,7 @@ virtualenv==20.26.3
# via
# pre-commit
# tox
watchdog==4.0.1 ; platform_system != 'Darwin'
watchdog==4.0.1 ; sys_platform != 'darwin'
# via streamlit
wcwidth==0.2.13
# via prompt-toolkit
Expand Down
12 changes: 10 additions & 2 deletions benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Benchmarking module for performance tracking."""

from benchmarks.definition import Benchmark
from benchmarks.definition import (
Benchmark,
BenchmarkSettings,
ConvergenceBenchmark,
ConvergenceBenchmarkSettings,
)
from benchmarks.result import Result

__all__ = [
"Result",
"Benchmark",
"BenchmarkSettings",
"ConvergenceBenchmark",
"ConvergenceBenchmarkSettings",
"Result",
]
12 changes: 8 additions & 4 deletions benchmarks/definition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Benchmark task definitions."""
"""Benchmark definitions."""

from benchmarks.definition.config import (
from benchmarks.definition.base import (
Benchmark,
BenchmarkSettings,
ConvergenceExperimentSettings,
)
from benchmarks.definition.convergence import (
ConvergenceBenchmark,
ConvergenceBenchmarkSettings,
)

__all__ = [
"ConvergenceExperimentSettings",
"Benchmark",
"BenchmarkSettings",
"ConvergenceBenchmark",
"ConvergenceBenchmarkSettings",
]
86 changes: 86 additions & 0 deletions benchmarks/definition/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Basic benchmark configuration."""

import time
from abc import ABC
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import Generic, TypeVar

from attrs import define, field
from attrs.validators import instance_of
from cattrs import override
from cattrs.gen import make_dict_unstructure_fn
from pandas import DataFrame

from baybe.utils.random import temporary_seed
from benchmarks.result import Result, ResultMetadata
from benchmarks.serialization import BenchmarkSerialization, converter


@define(frozen=True, kw_only=True)
class BenchmarkSettings(ABC, BenchmarkSerialization):
    """The basic benchmark configuration.

    Abstract base class for all benchmark settings; concrete benchmark
    types subclass this to add their own configuration fields.
    """

    # Applied via ``temporary_seed`` around the benchmark call so that
    # runs are reproducible.
    random_seed: int = field(validator=instance_of(int), default=1337)
    """The used random seed."""


# Type variable used to parameterize ``Benchmark`` over its settings type,
# constrained to subclasses of ``BenchmarkSettings``.
BenchmarkSettingsType = TypeVar("BenchmarkSettingsType", bound=BenchmarkSettings)


@define(frozen=True)
class Benchmark(Generic[BenchmarkSettingsType], BenchmarkSerialization):
    """The base class for all benchmark definitions."""

    function: Callable[[BenchmarkSettingsType], DataFrame] = field()
    """The callable containing the benchmarking logic."""

    settings: BenchmarkSettingsType = field()
    """The benchmark configuration."""

    @function.validator
    def _validate_function(self, _, function) -> None:
        # A docstring is mandatory since it doubles as the benchmark
        # description (see the ``description`` property).
        if function.__doc__ is None:
            raise ValueError("The benchmark function must have a docstring.")

    @property
    def name(self) -> str:
        """The name of the benchmark function."""
        return self.function.__name__

    @property
    def description(self) -> str:
        """The description of the benchmark function."""
        docstring = self.function.__doc__
        assert docstring is not None  # guaranteed by the field validator
        return docstring

    def __call__(self) -> Result:
        """Execute the benchmark and return the result."""
        started_at = datetime.now(timezone.utc)

        # Fix the random seed for the duration of the call and time only
        # the benchmark logic itself (not the metadata bookkeeping).
        with temporary_seed(self.settings.random_seed):
            tic = time.perf_counter()
            benchmark_output = self.function(self.settings)
            toc = time.perf_counter()

        meta = ResultMetadata(
            start_datetime=started_at,
            duration=timedelta(seconds=toc - tic),
        )
        return Result(self.name, benchmark_output, meta)


@converter.register_unstructure_hook
def unstructure_benchmark(benchmark: Benchmark) -> dict:
    """Unstructure a benchmark instance."""
    # The benchmark callable itself is not serializable, so it is omitted
    # from the generated mapping and represented instead by the derived
    # ``name`` and ``description`` entries.
    unstructure = make_dict_unstructure_fn(
        type(benchmark), converter, function=override(omit=True)
    )
    out = {"name": benchmark.name, "description": benchmark.description}
    out.update(unstructure(benchmark))
    return out
103 changes: 0 additions & 103 deletions benchmarks/definition/config.py

This file was deleted.

39 changes: 39 additions & 0 deletions benchmarks/definition/convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Convergence benchmark configuration."""

from typing import Any

from attrs import define, field
from attrs.validators import deep_mapping, instance_of, optional

from benchmarks.definition.base import Benchmark, BenchmarkSettings


@define(frozen=True, kw_only=True)
class ConvergenceBenchmarkSettings(BenchmarkSettings):
    """Benchmark configuration for recommender convergence analyses.

    Inherits ``random_seed`` from :class:`BenchmarkSettings`.
    """

    # Number of recommendations requested per DOE iteration.
    batch_size: int = field(validator=instance_of(int))
    """The recommendation batch size."""

    n_doe_iterations: int = field(validator=instance_of(int))
    """The number of Design of Experiment iterations."""

    # Monte Carlo repetitions of the whole experiment.
    n_mc_iterations: int = field(validator=instance_of(int))
    """The number of Monte Carlo iterations."""


@define(frozen=True)
class ConvergenceBenchmark(Benchmark[ConvergenceBenchmarkSettings]):
    """A class for defining convergence benchmarks."""

    # Optional mapping from target name to its optimal value. Keys are
    # validated to be strings; values are deliberately unconstrained
    # (the value validator is a no-op).
    optimal_target_values: dict[str, Any] | None = field(
        default=None,
        validator=optional(
            deep_mapping(
                key_validator=instance_of(str),
                mapping_validator=instance_of(dict),
                value_validator=lambda *_: None,
            )
        ),
    )
    """The optimal values that can be achieved for the targets **individually**."""
2 changes: 1 addition & 1 deletion benchmarks/domains/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Benchmark domains."""

from benchmarks.definition.config import Benchmark
from benchmarks.definition.base import Benchmark
from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark

BENCHMARKS: list[Benchmark] = [
Expand Down
Loading

0 comments on commit ae64ca1

Please sign in to comment.