diff --git a/src/estimagic/visualization/profile_plot.py b/src/estimagic/visualization/profile_plot.py index 992b16e3f..be3fb87a3 100644 --- a/src/estimagic/visualization/profile_plot.py +++ b/src/estimagic/visualization/profile_plot.py @@ -160,13 +160,12 @@ def create_solution_times(df, runtime_measure, converged_info, return_tidy=True) problem, algorithm and runtime_measure. The values are either the number of evaluations or the walltime each algorithm needed to achieve the desired precision. If the desired precision was not achieved the value is - set to np.inf (for n_evaluations) or 7000 days (for walltime since there - no infinite value is allowed). + set to np.inf. """ solution_times = df.groupby(["problem", "algorithm"])[runtime_measure].max() - solution_times = solution_times.unstack() - solution_times[~converged_info] = np.inf + solution_times = solution_times.unstack().astype(float) + solution_times = solution_times.where(converged_info, other=np.inf) if not return_tidy: solution_times = solution_times.stack().reset_index() diff --git a/tests/visualization/test_profile_plot.py b/tests/visualization/test_profile_plot.py index 2d3a7fabb..ff1cc393a 100644 --- a/tests/visualization/test_profile_plot.py +++ b/tests/visualization/test_profile_plot.py @@ -57,8 +57,8 @@ def test_create_solution_times_n_evaluations(): ) expected = pd.DataFrame( { - "algo1": [1, 5], - "algo2": [3, np.inf], + "algo1": [1.0, 5], + "algo2": [3.0, np.inf], }, index=pd.Index(["prob1", "prob2"], name="problem"), ) @@ -95,8 +95,8 @@ def test_create_solution_times_n_batches(): ) expected = pd.DataFrame( { - "algo1": [1, 1], - "algo2": [2, np.inf], + "algo1": [1.0, 1], + "algo2": [2.0, np.inf], }, index=pd.Index(["prob1", "prob2"], name="problem"), ) @@ -131,8 +131,8 @@ def test_create_solution_times_walltime(): ) expected = pd.DataFrame( { - "algo1": [1, 5], - "algo2": [3, np.inf], + "algo1": [1.0, 5], + "algo2": [3.0, np.inf], }, index=pd.Index(["prob1", "prob2"], name="problem"), )