Skip to content

Commit

Permalink
Better string representation for algorithm classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg committed Nov 12, 2024
1 parent 058c79e commit d16c36f
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 23 deletions.
12 changes: 10 additions & 2 deletions .tools/create_algo_selection_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,19 @@ def _available(self) -> list[Type[Algorithm]]:
]
@property
def All(self) -> list[str]:
def All(self) -> list[Type[Algorithm]]:
return self._all()
@property
def Available(self) -> list[Type[Algorithm]]:
return self._available()
@property
def AllNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._all()] # type: ignore
@property
def Available(self) -> list[str]:
def AvailableNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._available()] # type: ignore
@property
Expand Down
12 changes: 10 additions & 2 deletions src/optimagic/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,19 @@ def _available(self) -> list[Type[Algorithm]]:
]

@property
def All(self) -> list[str]:
def All(self) -> list[Type[Algorithm]]:
return self._all()

Check warning on line 98 in src/optimagic/algorithms.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/algorithms.py#L98

Added line #L98 was not covered by tests

@property
def Available(self) -> list[Type[Algorithm]]:
return self._available()

Check warning on line 102 in src/optimagic/algorithms.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/algorithms.py#L102

Added line #L102 was not covered by tests

@property
def AllNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._all()] # type: ignore

@property
def Available(self) -> list[str]:
def AvailableNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._available()] # type: ignore

@property
Expand Down
34 changes: 32 additions & 2 deletions src/optimagic/optimization/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
import warnings
from abc import ABC, abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass, replace
from typing import Any

Expand Down Expand Up @@ -143,8 +143,38 @@ def __post_init__(self) -> None:
raise TypeError(msg)


class AlgorithmMeta(ABCMeta):
"""Metaclass to get repr, algo_info and name for classes, not just instances."""

def __repr__(self) -> str:
if hasattr(self, "__algo_info__") and self.__algo_info__ is not None:
out = f"om.algos.{self.__algo_info__.name}"
else:
out = self.__class__.__name__

Check warning on line 153 in src/optimagic/optimization/algorithm.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/algorithm.py#L153

Added line #L153 was not covered by tests
return out

@property
def name(self) -> str:
if hasattr(self, "__algo_info__") and self.__algo_info__ is not None:
out = self.__algo_info__.name

Check warning on line 159 in src/optimagic/optimization/algorithm.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/algorithm.py#L158-L159

Added lines #L158 - L159 were not covered by tests
else:
out = self.__class__.__name__
return out

Check warning on line 162 in src/optimagic/optimization/algorithm.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/algorithm.py#L161-L162

Added lines #L161 - L162 were not covered by tests

@property
def algo_info(self) -> AlgoInfo:
if not hasattr(self, "__algo_info__") or self.__algo_info__ is None:
msg = (

Check warning on line 167 in src/optimagic/optimization/algorithm.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/algorithm.py#L166-L167

Added lines #L166 - L167 were not covered by tests
f"The algorithm {self.name} does not have have the __algo_info__ "
"attribute. Use the `mark.minimizer` decorator to add this attribute."
)
raise AttributeError(msg)

Check warning on line 171 in src/optimagic/optimization/algorithm.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/algorithm.py#L171

Added line #L171 was not covered by tests

return self.__algo_info__

Check warning on line 173 in src/optimagic/optimization/algorithm.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/optimization/algorithm.py#L173

Added line #L173 was not covered by tests


@dataclass(frozen=True)
class Algorithm(ABC):
class Algorithm(ABC, metaclass=AlgorithmMeta):
@abstractmethod
def _solve_internal_problem(
self, problem: InternalOptimizationProblem, x0: NDArray[np.float64]
Expand Down
61 changes: 44 additions & 17 deletions src/optimagic/visualization/history_plots.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import inspect
import itertools
from pathlib import Path
from typing import Any

import numpy as np
import plotly.graph_objects as go
from pybaum import leaf_names, tree_flatten, tree_just_flatten, tree_unflatten

from optimagic.config import PLOTLY_PALETTE, PLOTLY_TEMPLATE
from optimagic.logging.logger import LogReader, SQLiteLogOptions
from optimagic.optimization.algorithm import Algorithm
from optimagic.optimization.history_tools import get_history_arrays
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.parameters.tree_registry import get_registry
Expand Down Expand Up @@ -50,23 +53,7 @@ def criterion_plot(
# Process inputs
# ==================================================================================

if not isinstance(names, list) and names is not None:
names = [names]

if not isinstance(results, dict):
if isinstance(results, list):
names = range(len(results)) if names is None else names
if len(names) != len(results):
raise ValueError("len(results) needs to be equal to len(names).")
results = dict(zip(names, results, strict=False))
else:
name = 0 if names is None else names
if isinstance(name, list):
if len(name) > 1:
raise ValueError("len(results) needs to be equal to len(names).")
else:
name = name[0]
results = {name: results}
results = _harmonize_inputs_to_dict(results, names)

if not isinstance(palette, list):
palette = [palette]
Expand Down Expand Up @@ -180,6 +167,46 @@ def criterion_plot(
return fig


def _harmonize_inputs_to_dict(results, names):
"""Convert all valid inputs for results and names to dict[str, OptimizeResult]."""
# convert scalar case to list case
if not isinstance(names, list) and names is not None:
names = [names]

if isinstance(results, OptimizeResult):
results = [results]

if names is not None and len(names) != len(results):
raise ValueError("len(results) needs to be equal to len(names).")

# handle dict case
if isinstance(results, dict):
if names is not None:
results_dict = dict(zip(names, results, strict=False))

Check warning on line 185 in src/optimagic/visualization/history_plots.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/history_plots.py#L184-L185

Added lines #L184 - L185 were not covered by tests
else:
results_dict = results

Check warning on line 187 in src/optimagic/visualization/history_plots.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/history_plots.py#L187

Added line #L187 was not covered by tests

# unlabeled iterable of results
else:
names = range(len(results)) if names is None else names
results_dict = dict(zip(names, results, strict=False))

# convert keys to strings
results_dict = {_convert_key_to_str(k): v for k, v in results_dict.items()}

return results_dict


def _convert_key_to_str(key: Any) -> str:
if inspect.isclass(key) and issubclass(key, Algorithm):
out = key.__algo_info__.name # type: ignore

Check warning on line 202 in src/optimagic/visualization/history_plots.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/history_plots.py#L202

Added line #L202 was not covered by tests
elif isinstance(key, Algorithm):
out = key.__algo_info__.name # type: ignore

Check warning on line 204 in src/optimagic/visualization/history_plots.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/history_plots.py#L204

Added line #L204 was not covered by tests
else:
out = str(key)
return out


def params_plot(
result,
selector=None,
Expand Down

0 comments on commit d16c36f

Please sign in to comment.