Using ColumnTransformer with OneHotEncoder creates lots of disprepency between the raw predictions and onnx inference #1046

Open
AchilleSoulie opened this issue Nov 14, 2023 · 4 comments

AchilleSoulie commented Nov 14, 2023

I am trying to build a simple scikit-learn pipeline with a StandardScaler for the numerical features and a OneHotEncoder for the categorical ones.

As shown here, I use CastTransformer to reduce the discrepancies that StandardScaler introduces between the raw predictions and the ONNX inference, and that works fine.
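The float32 effect that CastTransformer works around can be sketched with NumPy alone. This is a hypothetical illustration, not code from the report: the feature values are made up, and the ONNX side is only approximated by doing the scaling arithmetic in float32. Because the cast-to-float64 path yields the correctly rounded float32 value of the float64 computation, its error against a float64 reference can never exceed that of the all-float32 path.

```python
import numpy as np

# Hypothetical feature with a large offset relative to its spread,
# which makes float32 rounding more visible.
rng = np.random.default_rng(0)
x64 = rng.normal(loc=1000.0, scale=3.0, size=10_000)
x32 = x64.astype(np.float32)  # float64 columns are fed to ONNX as float32

# Statistics as StandardScaler would store them (float64).
mean, std = x64.mean(), x64.std()

# Reference: float64 scaling of the float32 input.
ref = (x32.astype(np.float64) - mean) / std

# Path 1: everything in float32 (statistics rounded to float32 as well).
s32 = (x32 - np.float32(mean)) / np.float32(std)

# Path 2: cast to float64, scale, cast back to float32
# (the cast64 -> scaler -> cast pattern used in the pipeline).
s_cast = ((x32.astype(np.float64) - mean) / std).astype(np.float32)

err_32 = np.abs(s32.astype(np.float64) - ref).max()
err_cast = np.abs(s_cast.astype(np.float64) - ref).max()
print(err_32, err_cast)
```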

The problem arises when I use OneHotEncoder for the categorical features: it creates large discrepancies.

Here is the complete code:

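import numpy as np
import onnxruntime as rt
from xgboost import XGBClassifier
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.data_types import (
    BooleanTensorType,
    FloatTensorType,
    Int64TensorType,
    StringTensorType,
)
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes
from skl2onnx.sklapi import CastTransformer
from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost

# X_train, y_train, X_test, numeric_features and categorical_features
# are defined earlier (not shown in this report).

# Register the XGBoost converter so that convert_sklearn can handle XGBClassifier.
update_registered_converter(
    XGBClassifier,
    "XGBoostXGBClassifier",
    calculate_linear_classifier_output_shapes,
    convert_xgboost,
    options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)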


numeric_transformer = Pipeline(
    steps=[
        ("cast64", CastTransformer(dtype=np.float64)),
        ("scaler", StandardScaler()),
        ("cast", CastTransformer()),
    ]
)

categorical_transformer = Pipeline(
    steps=[
        ("onehot", OneHotEncoder(handle_unknown="ignore", dtype=np.float32, sparse_output=False)),
    ]
)

trans = [
    ("num", numeric_transformer, numeric_features),
    ("cat", categorical_transformer, categorical_features)
]
preprocessor = ColumnTransformer(transformers=trans, remainder="drop")

clf = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("classifier", XGBClassifier()),
    ]
)

clf.fit(X_train, y_train.astype(int))

def convert_dataframe_schema(df, drop=None):
    inputs = []
    for k, v in zip(df.columns, df.dtypes):
        if drop is not None and k in drop:
            continue
        if v == "int64":
            t = Int64TensorType([None, 1])
        elif (v == "float64") or (v == "float32"):
            t = FloatTensorType([None, 1])
        elif v == "bool":
            t = BooleanTensorType([None, 1])
        else:
            t = StringTensorType([None, 1])
        inputs.append((k, t))
    return inputs


initial_inputs = convert_dataframe_schema(X_train)
model_onnx = convert_sklearn(
    clf,
    "pipeline_titanic",
    initial_inputs,
    target_opset=12
)

raw_prediction = clf.predict_proba(X_test)[:, 1]

inputs = {c: X_test[c].values for c in X_test.columns}
for c in numeric_features:
    v = X_test[c].dtype
    if v == "float64":
        # float64 columns were declared as FloatTensorType in the schema
        inputs[c] = inputs[c].astype(np.float32)
for k in inputs:
    inputs[k] = inputs[k].reshape((inputs[k].shape[0], 1))

sess = rt.InferenceSession(model_onnx.SerializeToString(), providers=["CPUExecutionProvider"])
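pred_onx = sess.run(None, inputs)
# With zipmap enabled (the default), the second output is a list of
# {class_label: probability} dicts; keep the probability of class 1.
onnx_prediction = np.array([p[1] for p in pred_onx[1]])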

def diff(p1, p2):
    # Returns (max absolute difference, max difference relative to p1)
    p1 = p1.ravel()
    p2 = p2.ravel()
    d = np.abs(p2 - p1)
    return d.max(), (d / np.abs(p1)).max()

print(diff(raw_prediction, onnx_prediction))
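
The feed-building step in the script above (cast float64 numeric columns to float32, then reshape every column to (n_rows, 1) to match the [None, 1] input shapes) can be packaged as a small reusable helper. This is only a sketch: make_onnx_feed, the column names age and sex, and their values are hypothetical, and plain NumPy arrays stand in for the DataFrame columns.

```python
import numpy as np


def make_onnx_feed(columns, float_columns):
    """Build an onnxruntime input feed from per-column 1-D arrays (hypothetical helper).

    `columns` maps input names to 1-D arrays; float64 arrays whose names are in
    `float_columns` are cast to float32 to match FloatTensorType inputs.
    Every array is reshaped to (n_rows, 1), the [None, 1] shape of the schema.
    """
    feed = {}
    for name, values in columns.items():
        arr = np.asarray(values)
        if name in float_columns and arr.dtype == np.float64:
            arr = arr.astype(np.float32)
        feed[name] = arr.reshape(-1, 1)
    return feed


feed = make_onnx_feed(
    {"age": np.array([22.0, 38.0]), "sex": np.array(["male", "female"], dtype=object)},
    float_columns={"age"},
)
print(feed["age"].dtype, feed["age"].shape, feed["sex"].shape)  # float32 (2, 1) (2, 1)
```

The resulting dict can be passed to sess.run(None, feed) in the same way as inputs.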

Without the categorical columns (that is, with the ("cat", categorical_transformer, categorical_features) line commented out), I get the following differences (max absolute, max relative): (1.7136335372924805e-07, 0.09922722248309723).

With the categorical columns (and therefore OneHotEncoder), I get the following differences: (0.4294445514678955, 11.320544924924578).
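Note that the relative figure returned by diff() divides by p1, so it is sensitive to tiny reference probabilities. In the first measurement, the row with the largest relative error (0.099) cannot differ by more than the max absolute error (1.7e-7), so its reference probability is at most about 1.7e-7 / 0.099 ≈ 1.7e-6. A variant of diff() with a floor on the denominator makes this visible; diff_floored, the eps parameter and the sample probabilities are hypothetical and not part of the original code.

```python
import numpy as np


def diff_floored(p1, p2, eps=None):
    """Max absolute and max relative difference between two prediction vectors.

    With eps=None this behaves like diff() in the report. A positive eps
    (hypothetical addition) floors the denominator so that near-zero
    reference probabilities do not inflate the relative error.
    """
    p1 = np.asarray(p1, dtype=np.float64).ravel()
    p2 = np.asarray(p2, dtype=np.float64).ravel()
    d = np.abs(p2 - p1)
    denom = np.abs(p1) if eps is None else np.maximum(np.abs(p1), eps)
    return d.max(), (d / denom).max()


# Hypothetical probabilities: the second pair differs by only ~1e-7 in absolute terms.
p_ref = np.array([0.5, 1e-6])
p_onnx = np.array([0.5, 1.1e-6])
print(diff_floored(p_ref, p_onnx))            # relative error close to 0.1
print(diff_floored(p_ref, p_onnx, eps=1e-3))  # relative error close to 1e-4
```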

I expected some increase in discrepancy, but not such a large one.

Is this the expected behavior? Is there any way to reduce it?

AchilleSoulie changed the title from "Using ColumnTransformer with OneHotEncoder creates lots of difference between the true predictions and onnx inference" to "Using ColumnTransformer with OneHotEncoder creates lots of disprepency between the raw predictions and onnx inference" on Nov 14, 2023.
Collaborator

xadupre commented Nov 23, 2023

Is it possible to know whether all the predictions are wrong or only a few of them? You can pass options={"zipmap": False} to convert_sklearn to get a matrix of probabilities instead of a list of dictionaries. Did you check the discrepancies of your pipeline without the classifier? I assume there are none. In that case, the xgboost converter is probably the one introducing the errors. Could you try with xgboost<2? It seems their API changed and the converter is no longer able to extract the information it needs.
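To see whether all predictions are off or only a few, the comparison can be done row by row instead of only looking at the maxima. A minimal sketch, assuming the two arrays are already aligned by row; summarize_discrepancies and the 1e-4 tolerance are hypothetical choices, and the 2-class matrices stand in for the output of a classifier converted with zipmap disabled. The same function could be applied to the transformed features of the preprocessor alone, to check whether the discrepancy already appears before the classifier.

```python
import numpy as np


def summarize_discrepancies(p_ref, p_onnx, atol=1e-4):
    """Count rows whose values differ by more than `atol` (illustrative tolerance)."""
    p_ref = np.asarray(p_ref, dtype=np.float64)
    p_onnx = np.asarray(p_onnx, dtype=np.float64)
    d = np.abs(p_ref - p_onnx)
    if d.ndim > 1:  # 2-D input (e.g. a probability matrix): worst column per row
        d = d.max(axis=1)
    bad = np.flatnonzero(d > atol)
    return {"n_rows": d.shape[0], "n_bad": bad.size,
            "bad_rows": bad.tolist(), "max_abs": float(d.max())}


# Hypothetical 2-class probability matrices with one disagreeing row.
p_ref_mat = np.array([[0.75, 0.25], [0.25, 0.75], [0.5, 0.5]])
p_onnx_mat = np.array([[0.75, 0.25], [0.75, 0.25], [0.5, 0.5]])
print(summarize_discrepancies(p_ref_mat, p_onnx_mat))
# {'n_rows': 3, 'n_bad': 1, 'bad_rows': [1], 'max_abs': 0.5}
```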

github-project-automation bot moved this to "Can Fix but Waiting for an Answer" in Can Fix on Aug 29, 2024.
codingcyclist commented

@xadupre I can reproduce the same issue. I'm trying to convert a scikit-learn pipeline with an XGBClassifier that uses OneHotEncoder for categorical features and SimpleImputer / StandardScaler for numerical features (see the full pipeline code below).

I'm seeing large differences in predicted probabilities between the native sklearn model and the ONNX-converted model (see the chart on the left). Only when I comment out the categorical transformer do the predicted probabilities align (chart on the right).


Here's my environment:

pandas==1.5.3
numpy==1.26.4
xgboost==1.7.6
libxgboost==1.7.6
py-xgboost==1.7.6
skl2onnx==1.17.0
onnxmltools==1.12.0
onnxruntime==1.19.2
scikit-learn==1.5.2

Here's the pipeline and ONNX conversion code:

import pandas as pd
import numpy as np

from xgboost import XGBClassifier

from skl2onnx.common.data_types import FloatTensorType, StringTensorType, Int64TensorType, DoubleTensorType
from skl2onnx import convert_sklearn, to_onnx, update_registered_converter
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes

from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

def convert_dataframe_schema(df, drop=None):
    inputs = []
    for k, v in zip(df.columns, df.dtypes):
        if drop is not None and k in drop:
            continue
        if v == "int64":
            t = Int64TensorType([None, 1])
        elif v == "float64":
            t = FloatTensorType([None, 1])
        else:
            t = StringTensorType([None, 1])
        inputs.append((k, t))
    return inputs

def get_categorical_features(df: pd.DataFrame):
    dtype_ser = df.dtypes
    categorical_features = dtype_ser[dtype_ser == "object"].index.tolist()
    return categorical_features

numeric_transformer = Pipeline(
    steps=[
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler())
    ]
)

categorical_transformer = Pipeline(
    steps=[
        ("onehot", OneHotEncoder(handle_unknown="ignore"))
    ]
)

# X_train and y_train are defined earlier (not shown in this comment).
categorical_features = get_categorical_features(X_train)
numeric_features = [f for f in X_train.columns if f not in categorical_features]

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_features),
        # uncommenting the next line reproduces the discrepancy
        #("cat", categorical_transformer, categorical_features),
    ]
)

est_val = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("classifier", XGBClassifier(
            enable_categorical=False,
            random_state=42,
        ))
    ]
)


est_val.fit(
    X_train,
    y_train
)

# guess model input schema from training data
schema = convert_dataframe_schema(X_train)

# register XGBoost model converter
update_registered_converter(
    XGBClassifier,
    "XGBoostXGBClassifier",
    calculate_linear_classifier_output_shapes,
    convert_xgboost,
    options={"nocl": [True, False], "zipmap": [True, False, "columns"]},
)

# convert sklearn pipeline to ONNX
model_onnx = convert_sklearn(
    est_val,
    "xgb_pipeline",
    schema,
    target_opset={"": 12, "ai.onnx.ml": 3},
)

# save onnx model file
with open("xgb_pipeline.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())

codingcyclist commented

Quick update: after some deeper digging, the root cause appears to be the sparsity of the OneHotEncoder output. After setting OneHotEncoder(sparse_output=False), the predicted probabilities are perfectly aligned 🙌
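One plausible mechanism, stated as an assumption rather than a confirmed diagnosis: XGBoost treats entries that are not stored in a sparse input matrix as missing values, while the ONNX model receives a dense tensor in which the same entries are genuine zeros. A missing value follows each split's learned default direction instead of being compared with the threshold, so the two paths can walk different branches of the same tree. The toy split below is pure Python and is not XGBoost's implementation; route and its parameters are hypothetical.

```python
import math


def route(value, threshold=0.5, default_left=False):
    """Toy tree split (hypothetical): return 'left' or 'right'.

    NaN stands for a missing value and follows the split's default
    direction, similar in spirit to how XGBoost handles missing values.
    """
    if math.isnan(value):
        return "left" if default_left else "right"
    return "left" if value < threshold else "right"


# A one-hot column that is 0 for this row.
dense_path = route(0.0)            # dense input: 0.0 is compared with the threshold
sparse_path = route(float("nan"))  # sparse input: the unstored 0 is read as missing
print(dense_path, sparse_path)  # left right
```

If this is the cause, any setting that keeps the one-hot output dense before it reaches XGBClassifier should behave like sparse_output=False.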

Collaborator

xadupre commented Nov 14, 2024

If you see a bug, could you complete PR #1140 with some dummy data that fails?

Projects
Status: Can Fix but Waiting for an Answer
Development

No branches or pull requests

3 participants