Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add converter for QuantileTransformer #705

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@
MaxAbsScaler,
MinMaxScaler,
PolynomialFeatures,
QuantileTransformer,
RobustScaler,
StandardScaler,
)
Expand Down Expand Up @@ -436,6 +437,7 @@ def build_sklearn_operator_name_map():
PolynomialFeatures,
PowerTransformer,
QuadraticDiscriminantAnalysis,
QuantileTransformer,
RadiusNeighborsClassifier,
RadiusNeighborsRegressor,
RandomForestClassifier,
Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from . import polynomial_features
from . import power_transformer
from . import quadratic_discriminant_analysis
from . import quantile_transformer
from . import random_forest
from . import random_projection
from . import random_trees_embedding
Expand Down Expand Up @@ -116,6 +117,7 @@
polynomial_features,
power_transformer,
quadratic_discriminant_analysis,
quantile_transformer,
random_forest,
random_projection,
random_trees_embedding,
Expand Down
57 changes: 57 additions & 0 deletions skl2onnx/operator_converters/quantile_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0

import numpy as np
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common.data_types import guess_numpy_type


def convert_quantile_transformer(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """Converter for *QuantileTransformer* (work in progress).

    The conversion strategy is sketched below but no ONNX nodes are
    emitted yet, so the function always raises.

    :param scope: scope used to declare intermediate variables
    :param operator: operator wrapping the fitted scikit-learn model
    :param container: container receiving the ONNX nodes
    :raises RuntimeError: if ``output_distribution`` is not ``'uniform'``
        (the only distribution the planned conversion targets)
    :raises NotImplementedError: always, the converter is not finished
    """
    op = operator.raw_operator
    # Fail early on unsupported configurations so users get a clear
    # message instead of a generic "not implemented" error.
    if op.output_distribution != "uniform":
        raise RuntimeError(
            "Conversion of QuantileTransformer with output_distribution=%r "
            "is not supported." % op.output_distribution
        )

    # Reference: code of QuantileTransformer.transform (scikit-learn)
    # for a single column X_col with fitted quantiles_ / references_:
    #   lower_bound_x = quantiles[0]; upper_bound_x = quantiles[-1]
    #   lower_bound_y = 0; upper_bound_y = 1
    #   lower_bounds_idx = (X_col == lower_bound_x)
    #   upper_bounds_idx = (X_col == upper_bound_x)
    #   isfinite_mask = ~np.isnan(X_col)
    #   xcolf = X_col[isfinite_mask]
    #   X_col[isfinite_mask] = .5 * (
    #       np.interp(xcolf, quantiles, self.references_)
    #       - np.interp(-xcolf, -quantiles[::-1], -self.references_[::-1]))
    #   X_col[upper_bounds_idx] = upper_bound_y
    #   X_col[lower_bounds_idx] = lower_bound_y

    # Planned strategy: implement the interpolation in ONNX
    # * use 2 trees to determine the quantile x (qx, dx)
    # * use 2 trees to determine the quantile y (qy, dy)
    # then compute: (x - qx) * dx * dy + qy
    raise NotImplementedError(
        "Conversion of QuantileTransformer is not implemented yet."
    )


# Register under the alias produced by build_sklearn_operator_name_map.
register_converter("SklearnQuantileTransformer", convert_quantile_transformer)
2 changes: 2 additions & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from . import polynomial_features
from . import power_transformer
from . import quadratic_discriminant_analysis
from . import quantile_transformer
from . import random_projection
from . import random_trees_embedding
from . import replace_op
Expand Down Expand Up @@ -90,6 +91,7 @@
polynomial_features,
power_transformer,
quadratic_discriminant_analysis,
quantile_transformer,
random_projection,
random_trees_embedding,
replace_op,
Expand Down
24 changes: 24 additions & 0 deletions skl2onnx/shape_calculators/quantile_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0

import copy
from ..common._registration import register_shape_calculator
from ..common.utils import check_input_and_output_numbers, check_input_and_output_types
from ..common.data_types import FloatTensorType, Int64TensorType, DoubleTensorType


def quantile_transformer_shape_calculator(operator):
    """Shape calculator for QuantileTransformer.

    The output keeps the input tensor type; its shape is
    ``[N, n_features]`` where ``n_features`` comes from the fitted
    ``quantiles_`` array (one column per feature).
    """
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(
        operator,
        good_input_types=[FloatTensorType, Int64TensorType, DoubleTensorType],
    )

    input_var = operator.inputs[0]
    output_var = operator.outputs[0]
    batch_size = input_var.get_first_dimension()
    n_features = operator.raw_operator.quantiles_.shape[1]
    output_type = copy.deepcopy(input_var.type)
    output_type.shape = [batch_size, n_features]
    output_var.type = output_type


# Register under the alias produced by build_sklearn_operator_name_map.
register_shape_calculator(
    "SklearnQuantileTransformer", quantile_transformer_shape_calculator
)
39 changes: 39 additions & 0 deletions tests/test_sklearn_quantile_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0

"""
Tests scikit-learn's QuantileTransformer converter.
"""
import unittest
from distutils.version import StrictVersion
import numpy as np
import onnx
from sklearn.preprocessing import QuantileTransformer
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from test_utils import dump_data_and_model


class TestSklearnQuantileTransformer(unittest.TestCase):
    """Unit tests for the QuantileTransformer converter."""

    @unittest.skipIf(
        StrictVersion(onnx.__version__) < StrictVersion("1.4.0"),
        reason="ConstantOfShape not available",
    )
    def test_quantile_transformer(self):
        # Two monotonically increasing columns so the fitted quantiles
        # are well defined and distinct.
        X = np.empty((100, 2), dtype=np.float32)
        X[:, 0] = np.arange(X.shape[0])
        X[:, 1] = np.arange(X.shape[0]) * 2
        model = QuantileTransformer(n_quantiles=6).fit(X)
        model_onnx = convert_sklearn(
            model, "test", [("input", FloatTensorType([None, X.shape[1]]))]
        )
        # assertIsNotNone gives an informative failure message, unlike
        # assertTrue(x is not None).
        self.assertIsNotNone(model_onnx)
        dump_data_and_model(
            X.astype(np.float32),
            model,
            model_onnx,
            basename="SklearnQuantileTransformer",
        )


# Allow running this test file directly: python test_sklearn_quantile_converter.py
if __name__ == "__main__":
    unittest.main()
Loading