From 19a868bea65018ac35b1ef1cd839b02cea7bd61b Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 14 Nov 2024 18:27:01 +0100 Subject: [PATCH 1/2] Add unit test to investigate issue 1046 Signed-off-by: xadupre --- tests/test_issues_2024.py | 132 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/tests/test_issues_2024.py b/tests/test_issues_2024.py index 15c01469d..3f8bfb145 100644 --- a/tests/test_issues_2024.py +++ b/tests/test_issues_2024.py @@ -271,6 +271,138 @@ def Classifier(features: list[str]) -> base.BaseEstimator: ) assert modelengine is not None + @ignore_warnings(category=(ConvergenceWarning, FutureWarning)) + def test_issue_1046(self): + import pandas as pd + import numpy as np + + from xgboost import XGBClassifier + + from skl2onnx.common.data_types import ( + FloatTensorType, + StringTensorType, + Int64TensorType, + DoubleTensorType, + ) + from skl2onnx import convert_sklearn, to_onnx, update_registered_converter + from skl2onnx.common.shape_calculator import ( + calculate_linear_classifier_output_shapes, + ) + + from onnxmltools.convert.xgboost.operator_converters.XGBoost import ( + convert_xgboost, + ) + + from sklearn.impute import SimpleImputer + from sklearn.preprocessing import StandardScaler, OneHotEncoder + from sklearn.compose import ColumnTransformer + from sklearn.pipeline import Pipeline + + def convert_dataframe_schema(df, drop=None): + inputs = [] + for k, v in zip(df.columns, df.dtypes): + if drop is not None and k in drop: + continue + if v == "int64": + t = Int64TensorType([None, 1]) + elif v == "float32": + t = FloatTensorType([None, 1]) + elif v == "float64": + t = DoubleTensorType([None, 1]) + else: + t = StringTensorType([None, 1]) + inputs.append((k, t)) + return inputs + + def get_categorical_features(df: pd.DataFrame): + dtype_ser = df.dtypes + categorical_features = dtype_ser[dtype_ser == "object"].index.tolist() + return categorical_features + + numeric_transformer = Pipeline( + steps=[ + ("imputer", SimpleImputer(strategy="median")), + ("scaler", StandardScaler()), + ] + ) + + categorical_transformer = Pipeline( + steps=[("onehot", OneHotEncoder(handle_unknown="ignore"))] + ) + + X_train = pd.DataFrame( + [ + {"x1": 0.5, "c1": "A", "c2": "B"}, + {"x1": 0.6, "c1": "B", "c2": "B"}, + {"x1": 0.7, "c1": "A", "c2": "C"}, + {"x1": 0.8, "c1": "A", "c2": "B"}, + {"x1": 0.9, "c1": "B", "c2": "B"}, + {"x1": 1.1, "c1": "A", "c2": "C"}, + {"x1": 1.2, "c1": "B", "c2": "B"}, + {"x1": 1.3, "c1": "A", "c2": "B"}, + ] + ) + y_train = (np.arange(X_train.shape[0]) % 3) % 2 + X_train["x1"] = X_train["x1"].astype(np.float32) + + categorical_features = get_categorical_features(X_train) + numeric_features = [f for f in X_train.columns if not f in categorical_features] + + preprocessor = ColumnTransformer( + transformers=[ + ("num", numeric_transformer, numeric_features), + ("cat", categorical_transformer, categorical_features), + ] + ) + + est_val = Pipeline( + steps=[ + ("preprocessor", preprocessor), + ( + "classifier", + XGBClassifier( + enable_categorical=False, + random_state=42, + ), + ), + ] + ) + + est_val.fit(X_train, y_train) + + # guess model input schema from training data + schema = convert_dataframe_schema(X_train) + + # register XGBoost model converter + update_registered_converter( + XGBClassifier, + "XGBoostXGBClassifier", + calculate_linear_classifier_output_shapes, + convert_xgboost, + options={"nocl": [True, False], "zipmap": [True, False, "columns"]}, + ) + + # convert sklearn pipeline to ONNX + model_onnx = convert_sklearn( + est_val, + "xgb_pipeline", + schema, + target_opset={"": 18, "ai.onnx.ml": 3}, + options={"zipmap": False}, + ) + + import onnxruntime + + sess = onnxruntime.InferenceSession( + model_onnx.SerializeToString(), + providers=["CPUExecutionProvider"], + ) + + expected = est_val.predict_proba(X_train) + feeds = {c: X_train[c].values.reshape((-1, 1)) for c in X_train.columns} + got = sess.run(None, feeds) + np.testing.assert_allclose(expected, got[1], atol=1e-5) + if __name__ == "__main__": unittest.main(verbosity=2) From 2164398ad9b51057557977b27730a2b92e5d7975 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 14 Nov 2024 18:32:21 +0100 Subject: [PATCH 2/2] black --- tests/test_issues_2024.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_issues_2024.py b/tests/test_issues_2024.py index 3f8bfb145..32e9ff257 100644 --- a/tests/test_issues_2024.py +++ b/tests/test_issues_2024.py @@ -284,7 +284,7 @@ def test_issue_1046(self): Int64TensorType, DoubleTensorType, ) - from skl2onnx import convert_sklearn, to_onnx, update_registered_converter + from skl2onnx import convert_sklearn, update_registered_converter from skl2onnx.common.shape_calculator import ( calculate_linear_classifier_output_shapes, ) @@ -346,7 +346,7 @@ def get_categorical_features(df: pd.DataFrame): X_train["x1"] = X_train["x1"].astype(np.float32) categorical_features = get_categorical_features(X_train) - numeric_features = [f for f in X_train.columns if not f in categorical_features] + numeric_features = [f for f in X_train.columns if f not in categorical_features] preprocessor = ColumnTransformer( transformers=[