
get_variables_names() in class ModelStatsmodels does not return all variables, which causes errors #91

Closed
RoelVerbelen opened this issue Feb 24, 2024 · 3 comments · Fixed by #93

Comments

@RoelVerbelen
Contributor

As far as I'm aware, there's no easy way to extract the names of the original columns used in a patsy formula; see these open tickets here and here. So, for now, you have to rely on regular expressions.

However, the current code does not handle all the complex terms that can occur in formulas, which leads to errors in marginaleffects.
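
To see why the current approach comes up empty, here is a self-contained sketch of the same strip-and-match logic. The design-matrix names below are hypothetical: they are written to follow patsy's `term[level]` naming convention rather than captured from a fitted model. `variables_from_exog_names` is an illustrative helper, not part of marginaleffects, and it uses `dict.fromkeys` for de-duplication in place of the polars `unique()` call.

```python
import re

# Hypothetical data columns (a subset of the diamonds dataset).
columns = ["carat", "cut", "color", "depth", "price"]

# Hypothetical patsy-style design-matrix names for the complex formula.
complex_names = [
    "Intercept",
    "C(cut, Treatment('Good'))[T.Fair]",
    "C(cut, Treatment('Good'))[T.Ideal]",
    "depth:color[D]",
    "depth:color[E]",
    "cr(np.minimum(carat, 0.8), df=5, constraints='center')[0]",
]

# Hypothetical patsy-style names for a simple formula: price ~ carat + cut
simple_names = ["Intercept", "cut[T.Good]", "cut[T.Ideal]", "carat"]


def variables_from_exog_names(exog_names, columns):
    """Strip the bracketed level suffix, then keep only names that exactly
    match a data column (mirrors the current strip-and-match logic)."""
    stripped = [re.sub(r"\[.*\]", "", name) for name in exog_names]
    # dict.fromkeys de-duplicates while preserving first-seen order
    return list(dict.fromkeys(x for x in stripped if x in columns))


print(variables_from_exog_names(simple_names, columns))   # → ['cut', 'carat']
print(variables_from_exog_names(complex_names, columns))  # → []
```

Stripping the suffix recovers a bare column name only when the term *is* a column. Wrapped terms such as `C(cut, ...)`, interactions such as `depth:color`, and spline calls never match exactly, so they are silently dropped.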

I try to illustrate this in the code below and suggest a potential alternative (which I'm currently relying on): detecting whether any of the data columns, surrounded by word boundaries, occurs in the model formula. It's still not perfect, as it can capture non-model terms (such as Treatment, Good, minimum, df, constraints, and center in the example below, if columns with those names exist in the data), but at least it won't miss any of the predictors.
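
The word-boundary idea, including its false-positive caveat, can be checked without downloading data or fitting a model. The sketch below is a slight variant of the proposal: it uses a list comprehension so the result follows column order instead of a set's arbitrary order. The helper name `variables_in_formula` and both column lists are hypothetical.

```python
import re

formula = (
    "price ~ depth:color + C(cut, Treatment('Good')) "
    "+ cr(np.minimum(carat, 0.8), df=5, constraints='center')"
)


def variables_in_formula(formula, columns):
    """Return the columns that appear in the formula as whole words,
    in column order. re.escape guards against regex metacharacters."""
    return [col for col in columns if re.search(rf"\b{re.escape(col)}\b", formula)]


# Diamonds-like columns: every predictor is found, unused 'clarity' is not.
print(variables_in_formula(formula, ["carat", "cut", "color", "clarity", "depth", "price"]))
# → ['carat', 'cut', 'color', 'depth', 'price']

# Hypothetical columns named after function names, keywords, or string
# literals in the formula are matched too (the false positives noted above).
print(variables_in_formula(formula, ["carat", "df", "center", "Good", "minimum"]))
# → ['carat', 'df', 'center', 'Good', 'minimum']
```

The trade-off is deliberate: an extra column only adds noise to the variable list, whereas a missed predictor makes `predictions()` fail.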

```python
import re

import numpy as np
import pandas as pd
import polars as pl
import statsmodels.formula.api as smf
from marginaleffects import predictions
from marginaleffects.sanitize_model import sanitize_model

diamonds = pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/csv/ggplot2/diamonds.csv")

# Complex formula with interaction term only, categorical with custom reference level, and spline
model = smf.ols("price ~ depth:color + C(cut, Treatment('Good')) + cr(np.minimum(carat, 0.8), df=5, constraints='center')", data=diamonds).fit()

# Fails: ValueError: There is no valid column name in `variables`.
predictions(model, newdata=diamonds, by="cut")

# Create ModelStatsmodels object
self = sanitize_model(model)

# Variable list shows up empty
self.get_variables_names()

# Current code: Lines 53-56 in model_statsmodels.py
variables = self.model.model.exog_names
variables = [re.sub(r"\[.*\]", "", x) for x in variables]  # strip level suffixes such as "[T.Fair]"
variables = [x for x in variables if x in self.modeldata.columns]
variables = pl.Series(variables).unique().to_list()
# []

# Proposed code: keep every data column that appears as a whole word in the formula
formula = self.formula
columns = self.modeldata.columns
variables = list({var for var in columns if re.search(rf"\b{re.escape(var)}\b", formula)})
# ['price', 'carat', 'cut', 'color', 'depth'] (order may vary, since a set is used)
```

@vincentarelbundock
Owner

I like this a lot! Thanks for the suggestion.

@vincentarelbundock
Owner

Thanks again for the report. Fixed and released on PyPI as 0.0.9.

@RoelVerbelen
Contributor Author

Thank you for incorporating this, @vincentarelbundock and @LamAdr!
