Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit test to investigate issue 1046 #1140

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions tests/test_issues_2024.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,138 @@ def Classifier(features: list[str]) -> base.BaseEstimator:
)
assert modelengine is not None

@ignore_warnings(category=(ConvergenceWarning, FutureWarning))
def test_issue_1046(self):
import pandas as pd
import numpy as np

from xgboost import XGBClassifier

from skl2onnx.common.data_types import (
FloatTensorType,
StringTensorType,
Int64TensorType,
DoubleTensorType,
)
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import (
calculate_linear_classifier_output_shapes,
)

from onnxmltools.convert.xgboost.operator_converters.XGBoost import (
convert_xgboost,
)

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

def convert_dataframe_schema(df, drop=None):
inputs = []
for k, v in zip(df.columns, df.dtypes):
if drop is not None and k in drop:
continue
if v == "int64":
t = Int64TensorType([None, 1])
elif v == "float32":
t = FloatTensorType([None, 1])
elif v == "float64":
t = DoubleTensorType([None, 1])
else:
t = StringTensorType([None, 1])
inputs.append((k, t))
return inputs

def get_categorical_features(df: pd.DataFrame):
dtype_ser = df.dtypes
categorical_features = dtype_ser[dtype_ser == "object"].index.tolist()
return categorical_features

numeric_transformer = Pipeline(
steps=[
("imputer", SimpleImputer(strategy="median")),
("scaler", StandardScaler()),
]
)

categorical_transformer = Pipeline(
steps=[("onehot", OneHotEncoder(handle_unknown="ignore"))]
)

X_train = pd.DataFrame(
[
{"x1": 0.5, "c1": "A", "c2": "B"},
{"x1": 0.6, "c1": "B", "c2": "B"},
{"x1": 0.7, "c1": "A", "c2": "C"},
{"x1": 0.8, "c1": "A", "c2": "B"},
{"x1": 0.9, "c1": "B", "c2": "B"},
{"x1": 1.1, "c1": "A", "c2": "C"},
{"x1": 1.2, "c1": "B", "c2": "B"},
{"x1": 1.3, "c1": "A", "c2": "B"},
]
)
y_train = (np.arange(X_train.shape[0]) % 3) % 2
X_train["x1"] = X_train["x1"].astype(np.float32)

categorical_features = get_categorical_features(X_train)
numeric_features = [f for f in X_train.columns if f not in categorical_features]

preprocessor = ColumnTransformer(
transformers=[
("num", numeric_transformer, numeric_features),
("cat", categorical_transformer, categorical_features),
]
)

est_val = Pipeline(
steps=[
("preprocessor", preprocessor),
(
"classifier",
XGBClassifier(
enable_categorical=False,
random_state=42,
),
),
]
)

est_val.fit(X_train, y_train)

# guess model input schema from training data
schema = convert_dataframe_schema(X_train)

# register XGBoost model converter
update_registered_converter(
XGBClassifier,
"XGBoostXGBClassifier",
calculate_linear_classifier_output_shapes,
convert_xgboost,
options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)

# convert sklearn pipeline to ONNX
model_onnx = convert_sklearn(
est_val,
"xgb_pipeline",
schema,
target_opset={"": 18, "ai.onnx.ml": 3},
options={"zipmap": False},
)

import onnxruntime

sess = onnxruntime.InferenceSession(
model_onnx.SerializeToString(),
providers=["CPUExecutionProvider"],
)

expected = est_val.predict_proba(X_train)
feeds = {c: X_train[c].values.reshape((-1, 1)) for c in X_train.columns}
got = sess.run(None, feeds)
np.testing.assert_allclose(expected, got[1], atol=1e-5)


if __name__ == "__main__":
unittest.main(verbosity=2)
Loading