From 09f471969c7e545471a01f776e617a074fe9f043 Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Tue, 12 Nov 2024 19:51:52 +0100 Subject: [PATCH] Add tests. --- src/optimagic/visualization/history_plots.py | 2 +- tests/optimagic/test_algo_selection.py | 6 +++ .../visualization/test_history_plots.py | 46 ++++++++++++++++++- 3 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/optimagic/visualization/history_plots.py b/src/optimagic/visualization/history_plots.py index 50b3f09e9..4c4797b53 100644 --- a/src/optimagic/visualization/history_plots.py +++ b/src/optimagic/visualization/history_plots.py @@ -182,7 +182,7 @@ def _harmonize_inputs_to_dict(results, names): # handle dict case if isinstance(results, dict): if names is not None: - results_dict = dict(zip(names, results, strict=False)) + results_dict = dict(zip(names, list(results.values()), strict=False)) else: results_dict = results diff --git a/tests/optimagic/test_algo_selection.py b/tests/optimagic/test_algo_selection.py index c42cdae87..31906c7c5 100644 --- a/tests/optimagic/test_algo_selection.py +++ b/tests/optimagic/test_algo_selection.py @@ -24,3 +24,9 @@ def test_scipy_cobyla_is_present(): assert hasattr(algos.NonlinearConstrained.GradientFree.Local, "scipy_cobyla") assert hasattr(algos.NonlinearConstrained.Local.GradientFree, "scipy_cobyla") assert hasattr(algos.Local.NonlinearConstrained.GradientFree, "scipy_cobyla") + + +def test_algorithm_lists(): + assert len(algos.All) >= len(algos.Available) + assert len(algos.AllNames) == len(algos.All) + assert len(algos.AvailableNames) == len(algos.Available) diff --git a/tests/optimagic/visualization/test_history_plots.py b/tests/optimagic/visualization/test_history_plots.py index 70078b137..32bab8358 100644 --- a/tests/optimagic/visualization/test_history_plots.py +++ b/tests/optimagic/visualization/test_history_plots.py @@ -7,7 +7,11 @@ from optimagic.logging import SQLiteLogOptions from optimagic.optimization.optimize import minimize from optimagic.parameters.bounds import Bounds -from optimagic.visualization.history_plots import criterion_plot, params_plot +from optimagic.visualization.history_plots import ( + _harmonize_inputs_to_dict, + criterion_plot, + params_plot, +) @pytest.fixture() @@ -130,3 +134,43 @@ def test_criterion_plot_wrong_inputs(): with pytest.raises(ValueError): criterion_plot(["bla", "bla"], names="blub") + + +def test_harmonize_inputs_to_dict_single_result(): + res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") + assert _harmonize_inputs_to_dict(results=res, names=None) == {"0": res} + + +def test_harmonize_inputs_to_dict_single_result_with_name(): + res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") + assert _harmonize_inputs_to_dict(results=res, names="bla") == {"bla": res} + + +def test_harmonize_inputs_to_dict_list_results(): + res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") + results = [res, res] + assert _harmonize_inputs_to_dict(results=results, names=None) == { + "0": res, + "1": res, + } + + +def test_harmonize_inputs_to_dict_dict_input(): + res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") + results = {"bla": res, "blub": res} + assert _harmonize_inputs_to_dict(results=results, names=None) == results + + +def test_harmonize_inputs_to_dict_dict_input_with_names(): + res = minimize(fun=lambda x: x @ x, params=np.arange(5), algorithm="scipy_lbfgsb") + results = {"bla": res, "blub": res} + got = _harmonize_inputs_to_dict(results=results, names=["a", "b"]) + expected = {"a": res, "b": res} + assert got == expected + + +def test_harmonize_inputs_to_dict_invalid_names(): + results = [None] + names = ["a", "b"] + with pytest.raises(ValueError): + _harmonize_inputs_to_dict(results=results, names=names)