Skip to content

Commit

Permalink
Merge pull request #458 from DHI/feature/multiple_model_support_for_r…
Browse files Browse the repository at this point in the history
…esidual_hist_on_comparer

Feature/multiple model support for residual hist on comparer
  • Loading branch information
ryan-kipawa authored Oct 25, 2024
2 parents f752b8c + b03265f commit 22f0fca
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
55 changes: 51 additions & 4 deletions modelskill/comparison/_comparer_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def taylor(

def residual_hist(
self, bins=100, title=None, color=None, figsize=None, ax=None, **kwargs
) -> matplotlib.axes.Axes:
) -> matplotlib.axes.Axes | list[matplotlib.axes.Axes]:
"""plot histogram of residual values
Parameters
Expand All @@ -776,20 +776,67 @@ def residual_hist(
residual color, by default "#8B8D8E"
figsize : tuple, optional
figure size, by default None
ax : matplotlib.axes.Axes, optional
ax : matplotlib.axes.Axes | list[matplotlib.axes.Axes], optional
axes to plot on, by default None
**kwargs
other keyword arguments to plt.hist()
Returns
-------
matplotlib.axes.Axes
matplotlib.axes.Axes | list[matplotlib.axes.Axes]
"""
cmp = self.comparer

if cmp.n_models == 1:
return self._residual_hist_one_model(
bins=bins,
title=title,
color=color,
figsize=figsize,
ax=ax,
mod_name=cmp.mod_names[0],
**kwargs,
)

if ax is not None and len(ax) != len(cmp.mod_names):
raise ValueError("Number of axes must match number of models")

axs = ax if ax is not None else [None] * len(cmp.mod_names)

for i, mod_name in enumerate(cmp.mod_names):
cmp_model = cmp.sel(model=mod_name)
ax_mod = cmp_model.plot.residual_hist(
bins=bins,
title=title,
color=color,
figsize=figsize,
ax=axs[i],
**kwargs,
)
axs[i] = ax_mod

return axs

def _residual_hist_one_model(
self,
bins=100,
title=None,
color=None,
figsize=None,
ax=None,
mod_name=None,
**kwargs,
) -> matplotlib.axes.Axes:
"""Residual histogram for one model only"""
_, ax = _get_fig_ax(ax, figsize)

default_color = "#8B8D8E"
color = default_color if color is None else color
title = f"Residuals, {self.comparer.name}" if title is None else title
title = (
f"Residuals, Observation: {self.comparer.name}, Model: {mod_name}"
if title is None
else title
)
ax.hist(self.comparer._residual, bins=bins, color=color, **kwargs)
ax.set_title(title)
ax.set_xlabel(f"Residuals of {self.comparer._unit_text}")
Expand Down
14 changes: 12 additions & 2 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,8 @@ def test_to_dataframe_tc(tc):

# ======================== plotting ========================

PLOT_FUNCS_RETURNING_MANY_AX = ["scatter", "hist", "residual_hist"]


@pytest.fixture(
params=[
Expand All @@ -727,11 +729,20 @@ def test_to_dataframe_tc(tc):
def pc_plot_function(pc, request):
func = getattr(pc.plot, request.param)
# special cases requiring a model to be selected
if request.param in ["scatter", "hist", "residual_hist"]:
if request.param in PLOT_FUNCS_RETURNING_MANY_AX:
func = getattr(pc.sel(model=0).plot, request.param)
return func


@pytest.mark.parametrize("kind", PLOT_FUNCS_RETURNING_MANY_AX)
def test_plots_returning_multiple_axes(pc, kind):
n_models = 2
func = getattr(pc.plot, kind)
ax = func()
assert len(ax) == n_models
assert all(isinstance(a, plt.Axes) for a in ax)


def test_plot_returns_an_object(pc_plot_function):
obj = pc_plot_function()
assert obj is not None
Expand Down Expand Up @@ -824,7 +835,6 @@ def test_plots_directional(pt_df):


def test_from_matched_track_data():

df = pd.DataFrame(
{
"lat": [55.0, 55.1],
Expand Down

0 comments on commit 22f0fca

Please sign in to comment.