Handling NaN inputs #213

Open
ciberger opened this issue Oct 28, 2024 · 0 comments
Hi! Thanks for this fantastic package. I'm struggling to find a way to handle NaNs. I'm passing a CatBoost model, which handles missing values natively, as the estimator to HSTreeClassifier, but fitting fails when the input contains NaN.

from catboost import CatBoostClassifier
from imodels import HSTreeClassifier

clf = CatBoostClassifier()
model = HSTreeClassifier(estimator_=clf)
model = model.fit(X_train, y_train)

Error message

ValueError: Input contains NaN
File <command-4108227421121860>, line 6
      4 clf = CatBoostClassifier()
      5 model = HSTreeClassifier(estimator_=clf)
----> 6 model = model.fit(X_train, y_train, feature_names=_FEATURES)
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-23c4ea58-433a-40f0-89be-b3d953b89efe/lib/python3.11/site-packages/imodels/tree/hierarchical_shrinkage.py:82, in HSTree.fit(self, X, y, sample_weight, *args, **kwargs)
     78 def fit(self, X, y, sample_weight=None, *args, **kwargs):
     79     # remove feature_names if it exists (note: only works as keyword-arg)
     80     # None returned if not passed
     81     feature_names = kwargs.pop("feature_names", None)
---> 82     X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
     83     if feature_names is not None:
     84         self.feature_names = feature_names
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-23c4ea58-433a-40f0-89be-b3d953b89efe/lib/python3.11/site-packages/imodels/util/arguments.py:26, in check_fit_arguments(model, X, y, feature_names)
     24 if scipy.sparse.issparse(X):
     25     X = X.toarray()
---> 26 X, y = check_X_y(X, y)
     27 _, model.n_features_in_ = X.shape
     28 assert len(model.feature_names_) == model.n_features_in_, 'feature_names should be same size as X.shape[1]'
File /databricks/python/lib/python3.11/site-packages/sklearn/utils/validation.py:1147, in check_X_y(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)
   1142         estimator_name = _check_estimator_name(estimator)
   1143     raise ValueError(
   1144         f"{estimator_name} requires y to be passed, but the target y is None"
   1145     )
-> 1147 X = check_array(
   1148     X,
   1149     accept_sparse=accept_sparse,
   1150     accept_large_sparse=accept_large_sparse,
   1151     dtype=dtype,
   1152     order=order,
   1153     copy=copy,
   1154     force_all_finite=force_all_finite,
   1155     ensure_2d=ensure_2d,
   1156     allow_nd=allow_nd,
   1157     ensure_min_samples=ensure_min_samples,
   1158     ensure_min_features=ensure_min_features,
   1159     estimator=estimator,
   1160     input_name="X",
   1161 )
   1163 y = _check_y(y, multi_output=multi_output, y_numeric=y_numeric, estimator=estimator)
   1165 check_consistent_length(X, y)
File /databricks/python/lib/python3.11/site-packages/sklearn/utils/validation.py:959, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
    953         raise ValueError(
    954             "Found array with dim %d. %s expected <= 2."
    955             % (array.ndim, estimator_name)
    956         )
    958     if force_all_finite:
--> 959         _assert_all_finite(
    960             array,
    961             input_name=input_name,
    962             estimator_name=estimator_name,
    963             allow_nan=force_all_finite == "allow-nan",
    964         )
    966 if ensure_min_samples > 0:
    967     n_samples = _num_samples(array)
File /databricks/python/lib/python3.11/site-packages/sklearn/utils/validation.py:109, in _assert_all_finite(X, allow_nan, msg_dtype, estimator_name, input_name)
    107 if X.dtype == np.dtype("object") and not allow_nan:
    108     if _object_dtype_isnan(X).any():
--> 109         raise ValueError("Input contains NaN")
    111 # We need only consider float arrays, hence can early return for all else.
    112 if not xp.isdtype(X.dtype, ("real floating", "complex floating")):
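
From the traceback, the error seems to come from scikit-learn's `check_X_y`, which `check_fit_arguments` calls with its defaults, so the NaN check runs before the CatBoost estimator ever sees the data. Here is a minimal sketch, independent of imodels, of how that validator behaves with and without NaNs allowed. The traceback shows the parameter as `force_all_finite`; as far as I know newer scikit-learn releases call it `ensure_all_finite`, so the snippet picks whichever name the installed version exposes. The toy `X` and `y` are made up for illustration.

```python
import inspect

import numpy as np
from sklearn.utils import check_X_y

# Toy data (hypothetical): one missing value in X.
X = np.array([[1.0, np.nan], [2.0, 3.0]])
y = np.array([0, 1])

# Default validation rejects NaN, as in the traceback above.
try:
    check_X_y(X, y)
    raised = False
except ValueError:
    raised = True

# Pick the parameter name supported by the installed scikit-learn version.
params = inspect.signature(check_X_y).parameters
finite_arg = "ensure_all_finite" if "ensure_all_finite" in params else "force_all_finite"

# With "allow-nan", NaN passes validation (infinite values are still rejected).
X_ok, y_ok = check_X_y(X, y, **{finite_arg: "allow-nan"})
```

This only shows the validator's behavior. Whether `check_fit_arguments` could pass such an option through when the wrapped estimator supports missing values is a question for the maintainers.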
ciberger changed the title from "Handle NaN inputs" to "Handling NaN inputs" on Oct 28, 2024