Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support to sklearn TargetEncoder #1137

Merged
merged 14 commits into from
Feb 20, 2025
2 changes: 2 additions & 0 deletions skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@
LabelEncoder,
Normalizer,
OneHotEncoder,
TargetEncoder,
)

try:
Expand Down Expand Up @@ -511,6 +512,7 @@ def build_sklearn_operator_name_map():
RidgeClassifierCV: "SklearnLinearClassifier",
SGDRegressor: "SklearnLinearRegressor",
StandardScaler: "SklearnScaler",
TargetEncoder: "SklearnTargetEncoder",
TheilSenRegressor: "SklearnLinearRegressor",
}
)
Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from . import sgd_oneclass_svm
from . import stacking
from . import support_vector_machines
from . import target_encoder
from . import text_vectoriser
from . import tfidf_transformer
from . import tfidf_vectoriser
Expand Down Expand Up @@ -128,6 +129,7 @@
sgd_oneclass_svm,
stacking,
support_vector_machines,
target_encoder,
text_vectoriser,
tfidf_transformer,
tfidf_vectoriser,
Expand Down
99 changes: 99 additions & 0 deletions skl2onnx/operator_converters/target_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np

from ..common._apply_operation import apply_cast, apply_concat, apply_reshape
from ..common._container import ModelComponentContainer
from ..common.data_types import (
FloatTensorType,
Int64TensorType,
)
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..proto import onnx_proto


def convert_sklearn_target_encoder(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Convert a fitted scikit-learn ``TargetEncoder`` into ONNX nodes.

    For every encoded feature the converter isolates the feature column
    (with an ``ArrayFeatureExtractor`` node when the input holds several
    columns), maps each known category to its learned encoding with an
    ``ai.onnx.ml`` ``LabelEncoder`` node, and reshapes the result to a
    single column. The per-feature columns are then concatenated along
    axis 1 and cast to float to produce the operator output.

    :param scope: scope used to generate unique variable and node names
    :param operator: operator wrapping the fitted encoder
        (``operator.raw_operator``) together with its inputs and outputs
    :param container: container receiving the initializers and nodes
    :raises NotImplementedError: if the encoder was fitted on a
        multiclass target
    """
    op = operator.raw_operator

    # Only one encoding per category is emitted below, so encoders fitted on
    # a multiclass target are rejected. ``classes_`` is read defensively:
    # for non-classification targets it is expected to be None (or possibly
    # absent), and dereferencing it unconditionally would fail with an
    # AttributeError before any conversion happens.
    classes = getattr(op, "classes_", None)
    if op.target_type_ == "multiclass" or (
        classes is not None and len(classes) > 2
    ):
        raise NotImplementedError("multiclass TargetEncoder is not supported")

    result = []
    input_idx = 0
    dimension_idx = 0

    for categories, encodings in zip(op.categories_, op.encodings_):
        if len(categories) == 0:
            continue

        current_input = operator.inputs[input_idx]
        if current_input.get_second_dimension() == 1:
            # Single-column input: use it directly and move to the next input.
            feature_column = current_input
            input_idx += 1
        else:
            # Multi-column input: extract column ``dimension_idx``.
            index_name = scope.get_unique_variable_name("index")
            container.add_initializer(
                index_name, onnx_proto.TensorProto.INT64, [], [dimension_idx]
            )

            feature_column = scope.declare_local_variable(
                "feature_column",
                current_input.type.__class__(
                    [current_input.get_first_dimension(), 1]
                ),
            )

            container.add_node(
                "ArrayFeatureExtractor",
                [current_input.onnx_name, index_name],
                feature_column.onnx_name,
                op_domain="ai.onnx.ml",
                name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
            )

            # Advance to the next column; once every column of this input
            # has been consumed, continue with the next input.
            dimension_idx += 1
            if dimension_idx == current_input.get_second_dimension():
                dimension_idx = 0
                input_idx += 1

        # The key attribute of LabelEncoder depends on the column type.
        attrs = {"name": scope.get_unique_operator_name("LabelEncoder")}
        if isinstance(feature_column.type, FloatTensorType):
            attrs["keys_floats"] = np.array(
                [float(s) for s in categories], dtype=np.float32
            )
        elif isinstance(feature_column.type, Int64TensorType):
            attrs["keys_int64s"] = np.array(
                [int(s) for s in categories], dtype=np.int64
            )
        else:
            attrs["keys_strings"] = np.array(
                [str(s).encode("utf-8") for s in categories]
            )
        attrs["values_floats"] = encodings
        # Keys not listed above map to the target mean.
        attrs["default_float"] = op.target_mean_

        result.append(scope.get_unique_variable_name("ordinal_output"))
        label_encoder_output = scope.get_unique_variable_name("label_encoder")

        container.add_node(
            "LabelEncoder",
            feature_column.onnx_name,
            label_encoder_output,
            op_domain="ai.onnx.ml",
            **attrs,
        )
        # Force an (N, 1) column so the per-feature results can be
        # concatenated along axis 1.
        apply_reshape(
            scope,
            label_encoder_output,
            result[-1],
            container,
            desired_shape=(-1, 1),
        )

    concat_result_name = scope.get_unique_variable_name("concat_result")
    apply_concat(scope, result, concat_result_name, container, axis=1)
    apply_cast(
        scope,
        concat_result_name,
        operator.output_full_names,
        container,
        to=onnx_proto.TensorProto.FLOAT,
    )


# Register convert_sklearn_target_encoder as the converter for the
# "SklearnTargetEncoder" alias.
register_converter("SklearnTargetEncoder", convert_sklearn_target_encoder)
2 changes: 2 additions & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from . import sgd_oneclass_svm
from . import svd
from . import support_vector_machines
from . import target_encoder
from . import text_vectorizer
from . import tuned_threshold_classifier
from . import tfidf_transformer
Expand Down Expand Up @@ -99,6 +100,7 @@
sgd_oneclass_svm,
svd,
support_vector_machines,
target_encoder,
text_vectorizer,
tfidf_transformer,
tuned_threshold_classifier,
Expand Down
29 changes: 29 additions & 0 deletions skl2onnx/shape_calculators/target_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0


from ..common._registration import register_shape_calculator
from ..common.data_types import FloatTensorType
from ..common.data_types import Int64TensorType, StringTensorType
from ..common.utils import check_input_and_output_numbers
from ..common.utils import check_input_and_output_types


def calculate_sklearn_target_encoder_output_shapes(operator):
    """
    Set the output type of a TargetEncoder operator.

    The operator must have exactly one output, and its inputs must be
    float, int64 or string tensors. The output is declared as a float
    tensor whose first dimension is taken from the first input and whose
    second dimension is the number of entries in the fitted encoder's
    ``categories_``.
    """
    accepted_input_types = [FloatTensorType, Int64TensorType, StringTensorType]

    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(operator, good_input_types=accepted_input_types)

    n_rows = operator.inputs[0].get_first_dimension()
    n_encoded_features = len(operator.raw_operator.categories_)

    operator.outputs[0].type = FloatTensorType(shape=[n_rows, n_encoded_features])


# Register calculate_sklearn_target_encoder_output_shapes as the shape
# calculator for the "SklearnTargetEncoder" alias.
register_shape_calculator(
"SklearnTargetEncoder", calculate_sklearn_target_encoder_output_shapes
)
Loading
Loading