Skip to content

Commit

Permalink
add legend option, improve points defaults.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 4, 2023
1 parent c1e4b67 commit 2d62131
Showing 1 changed file with 75 additions and 48 deletions.
123 changes: 75 additions & 48 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# under the Affero General Public License v3, see <https://www.gnu.org/licenses/>.

import collections
from logging import warn
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib as mpl
Expand Down Expand Up @@ -211,7 +212,10 @@ def diag_func(row, **kwargs):
for n, v in enumerate(samples):
if opts["diag"][n] == "hist":
plt.hist(
v[:, row], color=opts["samples_colors"][n], **opts["hist_diag"]
v[:, row],
color=opts["samples_colors"][n],
label=opts["samples_labels"][n],
**opts["hist_diag"],
)
elif opts["diag"][n] == "kde":
density = gaussian_kde(
Expand All @@ -226,7 +230,7 @@ def diag_func(row, **kwargs):
ys,
color=opts["samples_colors"][n],
)
elif "upper" in opts.keys() and opts["upper"][n] == "scatter":
elif "offdiag" in opts.keys() and opts["offdiag"][n] == "scatter":
for single_sample in v:
plt.axvline(
single_sample[row],
Expand Down Expand Up @@ -280,12 +284,12 @@ def pairplot(
] = None,
limits: Optional[Union[List, torch.Tensor]] = None,
subset: Optional[List[int]] = None,
upper: Optional[str] = "hist",
offdiag: Optional[str] = "hist",
diag: Optional[str] = "hist",
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
upper: Optional[str] = None,
fig=None,
axes=None,
**kwargs,
Expand All @@ -305,16 +309,18 @@ def pairplot(
subset: List containing the dimensions to plot. E.g. subset=[1,3] will plot
plot only the 1st and 3rd dimension but will discard the 0th and 2nd (and,
if they exist, the 4th, 5th and so on).
upper: Plotting style for upper diagonal, {hist, scatter, contour, cond, None}.
offdiag: Plotting style for upper diagonal, {hist, scatter, contour, cond,
None}.
upper: deprecated, use offdiag instead.
diag: Plotting style for diagonal, {hist, cond, None}.
figsize: Size of the entire figure.
labels: List of strings specifying the names of the parameters.
ticks: Position of the ticks.
points_colors: Colors of the `points`.
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
**kwargs: Additional arguments to adjust the plot, see the source code in
`_get_default_opts()` in `sbi.utils.plot` for more details.
**kwargs: Additional arguments to adjust the plot, e.g., `samples_colors`,
`points_colors` and many more, see the source code in `_get_default_opts()`
in `sbi.analysis.plot` for details.
Returns: figure and axis of posterior distribution plot
"""
Expand All @@ -332,21 +338,30 @@ def pairplot(

samples, dim, limits = prepare_for_plot(samples, limits)

# checks.
if opts["legend"]:
assert len(opts["samples_labels"]) >= len(
samples
), "Provide at least as many labels as samples."
if opts["upper"] is not None:
warn("upper is deprecated, use offdiag instead.")
opts["offdiag"] = opts["upper"]

# Prepare diag/upper/lower
if type(opts["diag"]) is not list:
opts["diag"] = [opts["diag"] for _ in range(len(samples))]
if type(opts["upper"]) is not list:
opts["upper"] = [opts["upper"] for _ in range(len(samples))]
if type(opts["offdiag"]) is not list:
opts["offdiag"] = [opts["offdiag"] for _ in range(len(samples))]
# if type(opts['lower']) is not list:
# opts['lower'] = [opts['lower'] for _ in range(len(samples))]
opts["lower"] = None

diag_func = get_diag_func(samples, limits, opts, **kwargs)

def upper_func(row, col, limits, **kwargs):
def offdiag_func(row, col, limits, **kwargs):
if len(samples) > 0:
for n, v in enumerate(samples):
if opts["upper"][n] == "hist" or opts["upper"][n] == "hist2d":
if opts["offdiag"][n] == "hist" or opts["offdiag"][n] == "hist2d":
hist, xedges, yedges = np.histogram2d(
v[:, col],
v[:, row],
Expand All @@ -368,7 +383,7 @@ def upper_func(row, col, limits, **kwargs):
aspect="auto",
)

elif opts["upper"][n] in [
elif opts["offdiag"][n] in [
"kde",
"kde2d",
"contour",
Expand All @@ -393,7 +408,7 @@ def upper_func(row, col, limits, **kwargs):
positions = np.vstack([X.ravel(), Y.ravel()])
Z = np.reshape(density(positions).T, X.shape)

if opts["upper"][n] == "kde" or opts["upper"][n] == "kde2d":
if opts["offdiag"][n] == "kde" or opts["offdiag"][n] == "kde2d":
plt.imshow(
Z,
extent=(
Expand All @@ -405,7 +420,7 @@ def upper_func(row, col, limits, **kwargs):
origin="lower",
aspect="auto",
)
elif opts["upper"][n] == "contour":
elif opts["offdiag"][n] == "contour":
if opts["contour_offdiag"]["percentile"]:
Z = probs2contours(Z, opts["contour_offdiag"]["levels"])
else:
Expand All @@ -426,14 +441,14 @@ def upper_func(row, col, limits, **kwargs):
)
else:
pass
elif opts["upper"][n] == "scatter":
elif opts["offdiag"][n] == "scatter":
plt.scatter(
v[:, col],
v[:, row],
color=opts["samples_colors"][n],
**opts["scatter_offdiag"],
)
elif opts["upper"][n] == "plot":
elif opts["offdiag"][n] == "plot":
plt.plot(
v[:, col],
v[:, row],
Expand All @@ -444,7 +459,7 @@ def upper_func(row, col, limits, **kwargs):
pass

return _arrange_plots(
diag_func, upper_func, dim, limits, points, opts, fig=fig, axes=axes
diag_func, offdiag_func, dim, limits, points, opts, fig=fig, axes=axes
)


Expand All @@ -459,7 +474,6 @@ def marginal_plot(
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
fig=None,
axes=None,
**kwargs,
Expand All @@ -485,8 +499,9 @@ def marginal_plot(
points_colors: Colors of the `points`.
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
**kwargs: Additional arguments to adjust the plot, see the source code in
`_get_default_opts()` in `sbi.utils.plot` for more details.
**kwargs: Additional arguments to adjust the plot, e.g., `samples_colors`,
`points_colors` and many more, see the source code in `_get_default_opts()`
in `sbi.analysis.plot` for details.
Returns: figure and axis of posterior distribution plot
"""
Expand Down Expand Up @@ -523,7 +538,6 @@ def conditional_marginal_plot(
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
fig=None,
axes=None,
**kwargs,
Expand Down Expand Up @@ -559,8 +573,9 @@ def conditional_marginal_plot(
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
**kwargs: Additional arguments to adjust the plot, see the source code in
`_get_default_opts()` in `sbi.utils.plot` for more details.
**kwargs: Additional arguments to adjust the plot, e.g., `samples_colors`,
`points_colors` and many more, see the source code in `_get_default_opts()`
in `sbi.analysis.plot` for details.
Returns: figure and axis of posterior distribution plot
"""
Expand Down Expand Up @@ -596,7 +611,6 @@ def conditional_pairplot(
figsize: Tuple = (10, 10),
labels: Optional[List[str]] = None,
ticks: Union[List, torch.Tensor] = [],
points_colors: List[str] = plt.rcParams["axes.prop_cycle"].by_key()["color"],
fig=None,
axes=None,
**kwargs,
Expand Down Expand Up @@ -634,8 +648,9 @@ def conditional_pairplot(
fig: matplotlib figure to plot on.
axes: matplotlib axes corresponding to fig.
**kwargs: Additional arguments to adjust the plot, see the source code in
`_get_default_opts()` in `sbi.utils.plot` for more details.
**kwargs: Additional arguments to adjust the plot, e.g., `samples_colors`,
`points_colors` and many more, see the source code in `_get_default_opts()`
in `sbi.analysis.plot` for details.
Returns: figure and axis of posterior distribution plot
"""
Expand All @@ -644,7 +659,7 @@ def conditional_pairplot(
# Setting these is required because _pairplot_scaffold will check if opts['diag'] is
# `None`. This would break if opts has no key 'diag'. Same for 'upper'.
diag = "cond"
upper = "cond"
offdiag = "cond"

opts = _get_default_opts()
# update the defaults dictionary by the current values of the variables (passed by
Expand All @@ -656,7 +671,7 @@ def conditional_pairplot(
dim, limits, eps_margins = prepare_for_conditional_plot(condition, opts)
diag_func = get_conditional_diag_func(opts, limits, eps_margins, resolution)

def upper_func(row, col, **kwargs):
def offdiag_func(row, col, **kwargs):
p_image = (
eval_conditional_density(
opts["density"],
Expand All @@ -675,21 +690,21 @@ def upper_func(row, col, **kwargs):
p_image.T,
origin="lower",
extent=(
limits[col, 0],
limits[col, 1],
limits[row, 0],
limits[row, 1],
limits[col, 0].item(),
limits[col, 1].item(),
limits[row, 0].item(),
limits[row, 1].item(),
),
aspect="auto",
)

return _arrange_plots(
diag_func, upper_func, dim, limits, points, opts, fig=fig, axes=axes
diag_func, offdiag_func, dim, limits, points, opts, fig=fig, axes=axes
)


def _arrange_plots(
diag_func, upper_func, dim, limits, points, opts, fig=None, axes=None
diag_func, offdiag_func, dim, limits, points, opts, fig=None, axes=None
):
"""
Arranges the plots for any function that plots parameters either in a row of 1D
Expand All @@ -700,7 +715,7 @@ def _arrange_plots(
the plot (or the columns of a row of 1D marginals). It will be passed the
current `row` (i.e. which parameter that is to be plotted) and the `limits`
for all dimensions.
upper_func: Plotting function that will be executed for the upper-diagonal
offdiag_func: Plotting function that will be executed for the upper-diagonal
elements of the plot. It will be passed the current `row` and `col` (i.e.
which parameters are to be plotted and the `limits` for all dimensions. None
if we are in a 1D setting.
Expand Down Expand Up @@ -755,7 +770,7 @@ def _arrange_plots(
else:
raise NotImplementedError
rows = cols = len(subset)
flat = upper_func is None
flat = offdiag_func is None
if flat:
rows = 1
opts["lower"] = None
Expand Down Expand Up @@ -798,7 +813,7 @@ def _arrange_plots(
elif row == col:
current = "diag"
elif row < col:
current = "upper"
current = "offdiag"
else:
current = "lower"

Expand Down Expand Up @@ -876,11 +891,14 @@ def _arrange_plots(
extent,
color=opts["points_colors"][n],
**opts["points_diag"],
label=opts["points_labels"][n],
)
if opts["legend"] and col == 0:
plt.legend(**opts["legend_kwargs"])

# Off-diagonals
else:
upper_func(
offdiag_func(
row=row,
col=col,
limits=limits,
Expand Down Expand Up @@ -923,22 +941,28 @@ def _arrange_plots(

def _get_default_opts():
"""Return default values for plotting specs."""

return {
# 'lower': None, # hist/scatter/None # TODO: implement
# title and legend
"title": None,
"legend": False,
"legend_kwargs": {},
# labels
"labels_points": [], # for points
"labels_samples": [], # for samples
# colors
"samples_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"],
"points_labels": [f"points_{idx}" for idx in range(10)], # for points
"samples_labels": [f"samples_{idx}" for idx in range(10)], # for samples
# colors: take even colors for samples, odd colors for points
"samples_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][0::2],
"points_colors": plt.rcParams["axes.prop_cycle"].by_key()["color"][1::2],
# ticks
"tickformatter": mpl.ticker.FormatStrFormatter("%g"),
"tick_labels": None,
# options for hist
"hist_diag": {"alpha": 1.0, "bins": 50, "density": False, "histtype": "step"},
"hist_diag": {
"alpha": 1.0,
"bins": 50,
"density": False,
"histtype": "step",
},
"hist_offdiag": {
# 'edgecolor': 'none',
# 'linewidth': 0.0,
Expand All @@ -962,10 +986,10 @@ def _get_default_opts():
"points_diag": {},
"points_offdiag": {
"marker": ".",
"markersize": 20,
"markersize": 10,
},
# other options
"fig_bg_colors": {"upper": None, "diag": None, "lower": None},
"fig_bg_colors": {"offdiag": None, "diag": None, "lower": None},
"fig_subplots_adjust": {
"top": 0.9,
},
Expand Down Expand Up @@ -1124,7 +1148,10 @@ def _sbc_rank_plot(
if params_in_subplots:
if fig is None or ax is None:
fig, ax = plt.subplots(
num_rows, min(num_parameters, num_cols), figsize=figsize, sharey=sharey
num_rows,
min(num_parameters, num_cols),
figsize=figsize,
sharey=sharey,
)
ax = np.atleast_1d(ax) # type: ignore
else:
Expand Down

0 comments on commit 2d62131

Please sign in to comment.