diff --git a/.tools/create_algo_selection_code.py b/.tools/create_algo_selection_code.py index 6d28fde6d..29c0b4983 100644 --- a/.tools/create_algo_selection_code.py +++ b/.tools/create_algo_selection_code.py @@ -389,11 +389,19 @@ def _available(self) -> list[Type[Algorithm]]: ] @property - def All(self) -> list[str]: + def All(self) -> list[Type[Algorithm]]: + return self._all() + + @property + def Available(self) -> list[Type[Algorithm]]: + return self._available() + + @property + def AllNames(self) -> list[str]: return [a.__algo_info__.name for a in self._all()] # type: ignore @property - def Available(self) -> list[str]: + def AvailableNames(self) -> list[str]: return [a.__algo_info__.name for a in self._available()] # type: ignore @property diff --git a/src/optimagic/algorithms.py b/src/optimagic/algorithms.py index 748189786..4341744b4 100644 --- a/src/optimagic/algorithms.py +++ b/src/optimagic/algorithms.py @@ -94,11 +94,19 @@ def _available(self) -> list[Type[Algorithm]]: ] @property - def All(self) -> list[str]: + def All(self) -> list[Type[Algorithm]]: + return self._all() + + @property + def Available(self) -> list[Type[Algorithm]]: + return self._available() + + @property + def AllNames(self) -> list[str]: return [a.__algo_info__.name for a in self._all()] # type: ignore @property - def Available(self) -> list[str]: + def AvailableNames(self) -> list[str]: return [a.__algo_info__.name for a in self._available()] # type: ignore @property diff --git a/src/optimagic/optimization/algorithm.py b/src/optimagic/optimization/algorithm.py index 3bef1d09f..7f776cf90 100644 --- a/src/optimagic/optimization/algorithm.py +++ b/src/optimagic/optimization/algorithm.py @@ -1,6 +1,6 @@ import typing import warnings -from abc import ABC, abstractmethod +from abc import ABC, ABCMeta, abstractmethod from dataclasses import dataclass, replace from typing import Any @@ -143,8 +143,38 @@ def __post_init__(self) -> None: raise TypeError(msg) +class AlgorithmMeta(ABCMeta): + """Metaclass to get repr, algo_info and name for classes, not just instances.""" + + def __repr__(self) -> str: + if hasattr(self, "__algo_info__") and self.__algo_info__ is not None: + out = f"om.algos.{self.__algo_info__.name}" + else: + out = self.__class__.__name__ + return out + + @property + def name(self) -> str: + if hasattr(self, "__algo_info__") and self.__algo_info__ is not None: + out = self.__algo_info__.name + else: + out = self.__class__.__name__ + return out + + @property + def algo_info(self) -> AlgoInfo: + if not hasattr(self, "__algo_info__") or self.__algo_info__ is None: + msg = ( + f"The algorithm {self.name} does not have have the __algo_info__ " + "attribute. Use the `mark.minimizer` decorator to add this attribute." + ) + raise AttributeError(msg) + + return self.__algo_info__ + + @dataclass(frozen=True) -class Algorithm(ABC): +class Algorithm(ABC, metaclass=AlgorithmMeta): @abstractmethod def _solve_internal_problem( self, problem: InternalOptimizationProblem, x0: NDArray[np.float64] diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index 68a538638..1d2247e71 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -1,5 +1,7 @@ +import inspect import itertools from pathlib import Path +from typing import Any import numpy as np import plotly.graph_objects as go @@ -7,6 +9,7 @@ from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE from optimagic.logging.logger import LogReader, SQLiteLogOptions +from optimagic.optimization.algorithm import Algorithm from optimagic.optimization.history_tools import get_history_arrays from optimagic.optimization.optimize_result import OptimizeResult from optimagic.parameters.tree_registry import get_registry @@ -50,23 +53,7 @@ def criterion_plot( # Process inputs # ================================================================================== - if not isinstance(names, list) and names is not None: - names = [names] - - if not isinstance(results, dict): - if isinstance(results, list): - names = range(len(results)) if names is None else names - if len(names) != len(results): - raise ValueError("len(results) needs to be equal to len(names).") - results = dict(zip(names, results, strict=False)) - else: - name = 0 if names is None else names - if isinstance(name, list): - if len(name) > 1: - raise ValueError("len(results) needs to be equal to len(names).") - else: - name = name[0] - results = {name: results} + results = _harmonize_inputs_to_dict(results, names) if not isinstance(palette, list): palette = [palette] @@ -180,6 +167,46 @@ def criterion_plot( return fig +def _harmonize_inputs_to_dict(results, names): + """Convert all valid inputs for results and names to dict[str, OptimizeResult].""" + # convert scalar case to list case + if not isinstance(names, list) and names is not None: + names = [names] + + if isinstance(results, OptimizeResult): + results = [results] + + if names is not None and len(names) != len(results): + raise ValueError("len(results) needs to be equal to len(names).") + + # handle dict case + if isinstance(results, dict): + if names is not None: + results_dict = dict(zip(names, results, strict=False)) + else: + results_dict = results + + # unlabeled iterable of results + else: + names = range(len(results)) if names is None else names + results_dict = dict(zip(names, results, strict=False)) + + # convert keys to strings + results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()} + + return results_dict + + +def _convert_key_to_str(key: Any) -> str: + if inspect.isclass(key) and issubclass(key, Algorithm): + out = key.__algo_info__.name # type: ignore + elif isinstance(key, Algorithm): + out = key.__algo_info__.name # type: ignore + else: + out = str(key) + return out + + def params_plot( result, selector=None,