bambi.interpret.plot_predictions() fails when we condition and color by the same categorical variable #870

Open
tomicapretto opened this issue Jan 19, 2025 · 3 comments
Labels: bug, good first issue

Comments

@tomicapretto (Collaborator)

See this example

```python
import bambi as bmb
import numpy as np
import pandas as pd

rng = np.random.default_rng(1234)
levels = list("ABC")
df = pd.DataFrame({"y": rng.normal(size=100), "factor": rng.choice(levels, size=100)})

model = bmb.Model("y ~ factor", data=df)
idata = model.fit()

bmb.interpret.plot_predictions(
    model=model,
    idata=idata,
    conditional="factor",
    subplot_kwargs={"group": "factor"}
);
```
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[1], line 12
      9 model = bmb.Model("y ~ factor", data=df)
     10 idata = model.fit()
---> 12 bmb.interpret.plot_predictions(
     13     model=model,
     14     idata=idata,
     15     conditional="factor",
     16     subplot_kwargs={"group": "factor"}
     17 );

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/bambi/interpret/plotting.py:230, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    226     axes = plot_numeric(covariates, cap_data, transforms, legend, axes)
    227 elif is_categorical_dtype(cap_data[covariates.main]) or is_string_dtype(
    228     cap_data[covariates.main]
    229 ):
--> 230     axes = plot_categoric(covariates, cap_data, legend, axes)
    231 else:
    232     raise ValueError("Main covariate must be numeric or categoric.")

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/bambi/interpret/plot_types.py:223, in plot_categoric(covariates, plot_data, legend, axes)
    221         idx = (plot_data[color] == clr).to_numpy()
    222         idxs = idxs_main + colors_offset[i]
--> 223         ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
    224         ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
    225 elif not "group" in covariates and "panel" in covariates:

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/matplotlib/__init__.py:1473, in _preprocess_data.<locals>.inner(ax, data, *args, **kwargs)
   1470 @functools.wraps(func)
   1471 def inner(ax, *args, data=None, **kwargs):
   1472     if data is None:
-> 1473         return func(
   1474             ax,
   1475             *map(sanitize_sequence, args),
   1476             **{k: sanitize_sequence(v) for k, v in kwargs.items()})
   1478     bound = new_sig.bind(ax, *args, **kwargs)
   1479     auto_label = (bound.arguments.get(label_namer)
   1480                   or bound.kwargs.get(label_namer))

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/matplotlib/axes/_axes.py:4787, in Axes.scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, edgecolors, plotnonfinite, **kwargs)
   4785 y = np.ma.ravel(y)
   4786 if x.size != y.size:
-> 4787     raise ValueError("x and y must be the same size")
   4789 if s is None:
   4790     s = (20 if mpl.rcParams['_internal.classic_mode'] else
   4791          mpl.rcParams['lines.markersize'] ** 2.0)

ValueError: x and y must be the same size
```

These are the problematic lines, in `plot_categoric` (`bambi/interpret/plot_types.py`):

```python
idx = (plot_data[color] == clr).to_numpy()
idxs = idxs_main + colors_offset[i]
ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
```

The length of `idxs` is greater than `sum(idx)`, so `ax.scatter` receives more x positions than y values. In this case we also need to slice `idxs` with `idx`, but I'm not sure whether that would cause issues in other scenarios.
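The concern about other scenarios can be checked with a small, self-contained sketch that does not call Bambi. `color_positions` is a hypothetical helper, and the layout of `plot_data` (one row per level when the main covariate and the color group are the same column, one row per (main, color) pair otherwise) is an assumption about what `plot_categoric` receives. Under that assumption, slicing a per-level array with the row mask only works when both have the same length, whereas deriving x positions from the main-covariate codes of the selected rows keeps x and y aligned in both cases.

```python
import numpy as np
import pandas as pd


def color_positions(plot_data, main, color, clr, offset):
    """Hypothetical helper (not Bambi's API): row mask and x positions for one color level.

    x positions come from the main-covariate codes of the selected rows, so there is
    always exactly one x value per selected y value.
    """
    main_levels = pd.unique(plot_data[main])  # assumed level order: first appearance
    codes = pd.Categorical(plot_data[main], categories=main_levels).codes
    idx = (plot_data[color] == clr).to_numpy()
    return idx, codes[idx] + offset


# Case 1 (this issue): main covariate and color group are the same column.
same = pd.DataFrame({"factor": list("ABC")})
y_same = np.array([0.1, -0.2, 0.3])  # made-up posterior means
idx, x = color_positions(same, "factor", "factor", "B", 0.1)
print(x, y_same[idx])  # one x value for one y value

# Case 2: color by a different variable; assumed one row per (factor, g) pair.
grid = pd.DataFrame({"factor": list("ABC") * 2, "g": ["u"] * 3 + ["v"] * 3})
y_grid = np.arange(6, dtype=float)  # made-up posterior means
idx, x = color_positions(grid, "factor", "g", "u", -0.1)
print(x, y_grid[idx])  # three x values for three y values

# Slicing a per-level array (3 entries) with this 6-row mask fails instead.
idxs_main = np.arange(3)
try:
    idxs_main[idx]
except IndexError as err:
    print("IndexError:", err)
```

Whether `plot_categoric` actually builds `plot_data` this way would need to be checked against the Bambi source before adopting either approach.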

@GStechschulte (Collaborator)

Thanks for the issue and example, @tomicapretto. I will look into this.

@GStechschulte self-assigned this Jan 20, 2025
@GStechschulte added the good first issue label Jan 22, 2025
@jgyasu commented Jan 22, 2025

Hi @GStechschulte! Will you work on this? If not, can I be assigned? Asking since you added the good first issue label :)

@GStechschulte (Collaborator)

Hey @jgyasu, that would be great if you could give me a hand here. Feel free to open a draft PR and I will review it. Thanks!

Development

No branches or pull requests

3 participants