From 665f84f7449443bde4bcc594d4539f3b0c104a8b Mon Sep 17 00:00:00 2001 From: Tobias Boltz Date: Wed, 31 Jan 2024 15:36:45 -0800 Subject: [PATCH] Fix indexing error for single objective with show_acqusition set to false --- xopt/generators/bayesian/visualize.py | 31 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/xopt/generators/bayesian/visualize.py b/xopt/generators/bayesian/visualize.py index 90f5e638..71a86d9d 100644 --- a/xopt/generators/bayesian/visualize.py +++ b/xopt/generators/bayesian/visualize.py @@ -142,7 +142,10 @@ def visualize_generator_model( figsize = (6, 2 * nrows) else: sharex, sharey = True, True - figsize = (4 * ncols, 3.2 * nrows) + if nrows == 1: + figsize = (4 * ncols, 3.7 * nrows) + else: + figsize = (4 * ncols, 3.2 * nrows) # lazy import from matplotlib import pyplot as plt from matplotlib.legend_handler import HandlerTuple @@ -262,36 +265,38 @@ def visualize_generator_model( else: for i in range(nrows): for j in range(ncols): + ax_ij = ax[i, j] if nrows > 1 else ax[j] if i == nrows - 1: - ax[i, j].set_xlabel(variable_names[0]) + ax_ij.set_xlabel(variable_names[0]) if j == 0: - ax[i, j].set_ylabel(variable_names[1]) - ax[i, j].locator_params(axis="both", nbins=5) + ax_ij.set_ylabel(variable_names[1]) + ax_ij.locator_params(axis="both", nbins=5) for i, output_name in enumerate(output_names): for j in range(ncols): + ax_ij = ax[i, j] if nrows > 1 else ax[j] # model predictions - pcm = ax[i, j].pcolormesh( + pcm = ax_ij.pcolormesh( x_mesh[0].numpy(), x_mesh[1].numpy(), predictions[output_name][j].reshape(n_grid, n_grid), ) - divider = make_axes_locatable(ax[i, j]) + divider = make_axes_locatable(ax_ij) cax = divider.append_axes("right", size="5%", pad=0.1) cbar = fig.colorbar(pcm, cax=cax) if j == 0: - ax[i, j].set_title(f"Posterior Mean [{output_name}]") + ax_ij.set_title(f"Posterior Mean [{output_name}]") cbar.set_label(output_name) elif j == 1: - ax[i, j].set_title(f"Posterior SD [{output_name}]") + ax_ij.set_title(f"Posterior SD [{output_name}]") cbar.set_label(r"$\sigma\,$[{}]".format(output_name)) else: - ax[i, j].set_title(f"Prior Mean [{output_name}]") + ax_ij.set_title(f"Prior Mean [{output_name}]") cbar.set_label(output_name) # data samples if show_samples: if not samples[output_name][1].empty: x1_feasible, x2_feasible = samples[output_name][1].to_numpy().T - ax[i, j].scatter( + ax_ij.scatter( x1_feasible, x2_feasible, marker="o", @@ -304,7 +309,7 @@ def visualize_generator_model( x1_infeasible, x2_infeasible = ( samples[output_name][2].to_numpy().T ) - ax[i, j].scatter( + ax_ij.scatter( x1_infeasible, x2_infeasible, marker="o", @@ -314,7 +319,7 @@ def visualize_generator_model( label="Infeasible Samples", ) if i == j == 0: - handles, labels = ax[i, j].get_legend_handles_labels() + handles, labels = ax_ij.get_legend_handles_labels() if all( [ ele in labels @@ -324,7 +329,7 @@ def visualize_generator_model( labels[-2] = "In-/Feasible Samples" handles[-2] = [handles[-1], handles[-2]] del labels[-1], handles[-1] - ax[i, j].legend( + ax_ij.legend( labels=labels, handles=handles, handler_map={list: HandlerTuple(ndivide=None)},