Skip to content

Commit

Permalink
Merge pull request xopt-org#193 from t-bz/fix_visualization
Browse files Browse the repository at this point in the history
Fix indexing error for model visualization
  • Loading branch information
roussel-ryan authored Feb 1, 2024
2 parents b2f0be4 + 665f84f commit 6cb0704
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions xopt/generators/bayesian/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def visualize_generator_model(
figsize = (6, 2 * nrows)
else:
sharex, sharey = True, True
figsize = (4 * ncols, 3.2 * nrows)
if nrows == 1:
figsize = (4 * ncols, 3.7 * nrows)
else:
figsize = (4 * ncols, 3.2 * nrows)
# lazy import
from matplotlib import pyplot as plt
from matplotlib.legend_handler import HandlerTuple
Expand Down Expand Up @@ -262,36 +265,38 @@ def visualize_generator_model(
else:
for i in range(nrows):
for j in range(ncols):
ax_ij = ax[i, j] if nrows > 1 else ax[j]
if i == nrows - 1:
ax[i, j].set_xlabel(variable_names[0])
ax_ij.set_xlabel(variable_names[0])
if j == 0:
ax[i, j].set_ylabel(variable_names[1])
ax[i, j].locator_params(axis="both", nbins=5)
ax_ij.set_ylabel(variable_names[1])
ax_ij.locator_params(axis="both", nbins=5)
for i, output_name in enumerate(output_names):
for j in range(ncols):
ax_ij = ax[i, j] if nrows > 1 else ax[j]
# model predictions
pcm = ax[i, j].pcolormesh(
pcm = ax_ij.pcolormesh(
x_mesh[0].numpy(),
x_mesh[1].numpy(),
predictions[output_name][j].reshape(n_grid, n_grid),
)
divider = make_axes_locatable(ax[i, j])
divider = make_axes_locatable(ax_ij)
cax = divider.append_axes("right", size="5%", pad=0.1)
cbar = fig.colorbar(pcm, cax=cax)
if j == 0:
ax[i, j].set_title(f"Posterior Mean [{output_name}]")
ax_ij.set_title(f"Posterior Mean [{output_name}]")
cbar.set_label(output_name)
elif j == 1:
ax[i, j].set_title(f"Posterior SD [{output_name}]")
ax_ij.set_title(f"Posterior SD [{output_name}]")
cbar.set_label(r"$\sigma\,$[{}]".format(output_name))
else:
ax[i, j].set_title(f"Prior Mean [{output_name}]")
ax_ij.set_title(f"Prior Mean [{output_name}]")
cbar.set_label(output_name)
# data samples
if show_samples:
if not samples[output_name][1].empty:
x1_feasible, x2_feasible = samples[output_name][1].to_numpy().T
ax[i, j].scatter(
ax_ij.scatter(
x1_feasible,
x2_feasible,
marker="o",
Expand All @@ -304,7 +309,7 @@ def visualize_generator_model(
x1_infeasible, x2_infeasible = (
samples[output_name][2].to_numpy().T
)
ax[i, j].scatter(
ax_ij.scatter(
x1_infeasible,
x2_infeasible,
marker="o",
Expand All @@ -314,7 +319,7 @@ def visualize_generator_model(
label="Infeasible Samples",
)
if i == j == 0:
handles, labels = ax[i, j].get_legend_handles_labels()
handles, labels = ax_ij.get_legend_handles_labels()
if all(
[
ele in labels
Expand All @@ -324,7 +329,7 @@ def visualize_generator_model(
labels[-2] = "In-/Feasible Samples"
handles[-2] = [handles[-1], handles[-2]]
del labels[-1], handles[-1]
ax[i, j].legend(
ax_ij.legend(
labels=labels,
handles=handles,
handler_map={list: HandlerTuple(ndivide=None)},
Expand Down

0 comments on commit 6cb0704

Please sign in to comment.