Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KNNImputer #1165

Merged
merged 25 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _get_sklearn_operator_name(model_type):
:return: A string which stands for the type of the input model in
our conversion framework
"""
if model_type not in sklearn_operator_name_map: # noqa: SIM401
if model_type not in sklearn_operator_name_map:
# No proper operator name found, it means a local operator.
alias = None
else:
Expand Down
34 changes: 31 additions & 3 deletions skl2onnx/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@
from scipy.sparse import coo_matrix
from onnx import SparseTensorProto, ValueInfoProto
from onnx.defs import onnx_opset_version, get_all_schemas_with_history
from onnx.helper import make_node, make_tensor, make_attribute, make_sparse_tensor
from onnx.helper import (
make_node,
make_tensor,
make_attribute,
make_sparse_tensor,
make_tensor_value_info,
)
import onnx.onnx_cpp2py_export.defs as C
from onnxconverter_common.onnx_ops import __dict__ as dict_apply_operation
from ..proto import TensorProto
from .utils import get_domain
from .graph_builder_opset import Opset


logger = getLogger("skl2onnx")
Expand Down Expand Up @@ -222,6 +229,7 @@ def __init__(
white_op=None,
black_op=None,
verbose=0,
as_function=False,
):
"""
:param target_opset: number, for example, 7 for *ONNX 1.2*, and
Expand All @@ -233,6 +241,7 @@ def __init__(
:param black_op: black list of ONNX nodes allowed
while converting a pipeline, if empty, none are blacklisted
:param verbose: display information while converting
:param as_function: to export as a local function
"""
_WhiteBlackContainer.__init__(
self, white_op=white_op, black_op=black_op, verbose=verbose
Expand All @@ -251,6 +260,7 @@ def __init__(
# ONNX nodes (type: NodeProto) used to define computation
# structure
self.nodes = []
self.as_function = as_function
# ONNX operators' domain-version pair set. They will be added
# into opset_import field in the final ONNX model.
self.node_domain_version_pair_sets = set()
Expand All @@ -266,6 +276,14 @@ def __init__(
self.options = options
# All registered models.
self.registered_models = registered_models
self.local_functions = {}

@property
def main_opset(self):
    # Canonical accessor for the container's target opset; the Opset
    # graph builder reads ``container.main_opset`` to choose between
    # opset-dependent node constructions (e.g. ReduceSum axes handling).
    return self.target_opset

def make_local_function(self, domain: str, name: str, container, optimize=False):
    """Register *container* as a local ONNX function.

    :param domain: function domain
    :param name: function name
    :param container: container holding the function body
    :param optimize: currently unused by this method, kept for a
        future optimization pass
    """
    key = (domain, name)
    self.local_functions[key] = container

def swap_names(self, old_name, new_name):
"""
Expand Down Expand Up @@ -389,6 +407,12 @@ def add_input(self, variable):
"""
self.inputs.append(self._make_value_info(variable))

def make_tensor_input(self, name):
    """Declare a graph input called *name* with element type 0
    (UNDEFINED in ONNX) and no shape, leaving both to be resolved
    later (e.g. by type inference)."""
    value_info = make_tensor_value_info(name, 0, None)
    self.inputs.append(value_info)

def make_tensor_output(self, name):
    """Declare a graph output called *name* with element type 0
    (UNDEFINED in ONNX) and no shape, leaving both to be resolved
    later (e.g. by type inference)."""
    value_info = make_tensor_value_info(name, 0, None)
    self.outputs.append(value_info)

def add_output(self, variable):
"""
Adds our *Variable* object defined *_parser.py* into the the
Expand Down Expand Up @@ -621,6 +645,7 @@ def add_node(
attributes' names and attributes' values,
respectively.
"""
assert op_type != "knn_imputer_column" or op_domain
if "axes" in attrs and (
attrs["axes"] is None or not isinstance(attrs["axes"], (list, np.ndarray))
):
Expand Down Expand Up @@ -732,7 +757,7 @@ def add_node(
def target_opset_any_domain(self, domain):
target_opset = self.target_opset_all
if isinstance(target_opset, dict):
if domain in target_opset: # noqa: SIM401
if domain in target_opset:
to = target_opset[domain]
else:
to = None
Expand Down Expand Up @@ -763,7 +788,7 @@ def _get_op_version(self, domain, op_type):
key = domain, op_type
vers = self._op_versions.get(key, None)
if vers is None:
if domain == "com.microsoft":
if domain in ("com.microsoft", "local_domain"):
# avoid a not necessarily necessary warning
vers = 1
else:
Expand Down Expand Up @@ -1030,3 +1055,6 @@ def nstr(name):
)
map_nodes = {str(id(node)): node for node in self.nodes}
self.nodes = [map_nodes[_[-1]] for _ in topo]

def get_op_builder(self, scope):
    # Returns an Opset helper bound to this container and *scope*, so
    # converters can add nodes with the ``builder.<OpType>(...)`` syntax.
    return Opset(self, scope)
272 changes: 272 additions & 0 deletions skl2onnx/common/graph_builder_opset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import List, Optional, Union
import numpy as np
from onnx.helper import np_dtype_to_tensor_dtype


class Opset:
    """
    Makes it easier to write an onnx graph.
    The method name is the node type: ``opset.Add("x", "y")`` adds an
    *Add* node to the container and returns the name(s) of its output(s).

    :param container: the builder or container receiving the nodes
    :param scope: scope used to generate unique intermediate result names
    :param allow_unknown: allows unknown operators, otherwise
        fails when this class does not know the expected number of outputs
    """

    # defined for opset >= 18
    # name: number of expected outputs
    _implemented = {
        "Abs": 1,
        "Add": 1,
        "And": 1,
        "ArgMax": 1,
        "ArgMin": 1,
        "Cast": 1,
        "CastLike": 1,
        "Celu": 1,
        "Concat": 1,
        "Constant": 1,
        "ConstantOfShape": 1,
        "Cos": 1,
        "Cosh": 1,
        "Div": 1,
        "Dropout": 2,
        "Elu": 1,
        "Equal": 1,
        "Exp": 1,
        "Expand": 1,
        "Flatten": 1,
        "Gather": 1,
        "GatherElements": 1,
        "GatherND": 1,
        "Gemm": 1,
        "Greater": 1,
        "GreaterOrEqual": 1,
        "Identity": 1,
        "MatMul": 1,
        "MaxPool": 2,
        "Mul": 1,
        "Less": 1,
        "LessOrEqual": 1,
        "Log": 1,
        "LogSoftmax": 1,
        "Neg": 1,
        "Not": 1,
        "Or": 1,
        "Pow": 1,
        "Range": 1,
        "Reciprocal": 1,
        "ReduceMax": 1,
        "ReduceMean": 1,
        "ReduceMin": 1,
        "ReduceSum": 1,
        "Relu": 1,
        "Reshape": 1,
        "ScatterElements": 1,
        "ScatterND": 1,
        "Shape": 1,
        "Sigmoid": 1,
        "Sin": 1,
        "Sinh": 1,
        "Slice": 1,
        "Softmax": 1,
        "Sqrt": 1,
        "Squeeze": 1,
        "Sub": 1,
        "Tile": 1,
        "TopK": 2,
        "Transpose": 1,
        "Trilu": 1,
        "Unsqueeze": 1,
        "Where": 1,
    }

    def __init__(
        self,
        container,
        scope,
        allow_unknown: bool = False,
    ):
        self.container = container
        self.scope = scope
        self.allow_unknown = allow_unknown

    def __getattr__(self, name):
        # Only reached when regular attribute lookup fails (which already
        # covers the instance __dict__), so *name* is interpreted as an
        # ONNX operator type.  Known operators go straight to make_node
        # (the number of outputs is known); anything else goes through
        # _make_node, which guesses it.
        if name in self._implemented:
            return partial(self.make_node, name)
        return partial(self._make_node, name)

    def _make_node(self, op_type, *args, outputs=None, **kwargs):
        """Adds a node whose number of outputs may need to be guessed.

        :param op_type: operator type
        :param args: input names
        :param outputs: output names or the number of outputs,
            guessed when None
        :param kwargs: node attributes
        :return: output name(s), see :meth:`make_node`
        """
        if outputs is None:
            if op_type in self._implemented:
                outputs = self._implemented[op_type]
            elif op_type == "Split" and kwargs.get("domain", "") == "":
                # Split's output count cannot be guessed; it must be given
                # explicitly through attribute num_outputs.
                assert "num_outputs" in kwargs, (
                    "Number of outputs is not implemented yet for operator "
                    f"{op_type!r} and kwargs={kwargs}"
                )
                outputs = kwargs["num_outputs"]
            else:
                # We assume there is only one output.
                outputs = 1
        return self.make_node(
            op_type, *args, outputs=outputs, allow_empty_shape=True, **kwargs
        )

    def make_node(
        self,
        op_type: str,
        *inputs: Optional[Union[str, List[str]]],
        outputs: Optional[Union[int, List[str], str]] = None,
        domain: str = "",
        name: Optional[str] = None,
        allow_empty_shape: bool = False,
        **kwargs,
    ):
        """
        Adds one node to the container.

        :param op_type: operator type
        :param inputs: input names; a numpy array is promoted to an
            initializer, ``None`` becomes an empty (optional) input
        :param outputs: explicit output names, or the number of unique
            output names to generate; when None the number comes from
            ``_implemented``
        :param domain: operator domain
        :param name: node name, defaults to the class name
        :param allow_empty_shape: allows a constant input with an
            empty shape (a dimension equal to 0)
        :param kwargs: node attributes
        :return: the single output name, or the list of output names
        """
        assert (
            op_type != "Split" or outputs != 1
        ), f"Operator Split is useless with one output, inputs={inputs}, outputs={outputs}"
        if outputs is None:
            num_outputs = self._implemented[op_type]
            outputs = [
                self.scope.get_unique_variable_name(f"_onx_{op_type.lower()}")
                for _ in range(num_outputs)
            ]
        elif isinstance(outputs, int):
            outputs = [
                self.scope.get_unique_variable_name(f"_onx_{op_type.lower()}")
                for _ in range(outputs)
            ]
        if inputs is None:
            inputs = []
        # A Reshape shape given as a non-int64 array is almost always a
        # mistake (e.g. a data tensor passed as the shape).
        assert (
            op_type != "Reshape"
            or len(inputs) != 2
            or not isinstance(inputs[1], np.ndarray)
            or inputs[1].dtype == np.int64
        ), f"Suspicious shape {inputs[1]!r} for a Reshape"
        new_inputs = []
        for i in inputs:
            assert not isinstance(
                i, (list, tuple)
            ), f"Wrong inputs for operator {op_type!r}: {inputs!r}"
            if isinstance(i, str):
                new_inputs.append(i)
            elif hasattr(i, "name") and not hasattr(i, "detach"):
                # torch.fx.Node
                assert i.name is not None, f"Unexpected name for type {type(i)}"
                new_inputs.append(i.name)
            elif i is None:
                # Optional input
                new_inputs.append("")
            elif isinstance(i, np.ndarray):
                assert allow_empty_shape or 0 not in i.shape, (
                    f"Not implemented for type(i)={type(i)}, i={i}, "
                    f"inputs={inputs!r}, op_type={op_type!r}, i.shape={i.shape}"
                    f""
                )
                # Promote the constant to an initializer and use its name.
                cst_name = self.scope.get_unique_variable_name("cst")
                self.container.add_initializer(
                    cst_name,
                    np_dtype_to_tensor_dtype(i.dtype),
                    i.shape,
                    list(i.ravel()),
                )
                new_inputs.append(cst_name)
            else:
                raise AssertionError(
                    f"Not implemented for type(i)={type(i)}, i={i}, "
                    f"inputs={inputs!r}, op_type={op_type!r}"
                )

        assert None not in new_inputs
        if self.allow_unknown and not self.container.get_opset(domain):
            self.container.add_domain(domain)
        self.container.add_node(
            op_type,
            new_inputs,
            outputs,
            op_domain=domain,
            name=name or f"{self.__class__.__name__}",
            **kwargs,
        )
        if len(outputs) == 1:
            return outputs[0]
        return outputs

    @staticmethod
    def _iaxes(op_type, axes) -> int:
        # Converts *axes* into a plain list of ints so it can be used as a
        # node attribute (pre-opset-13/18 form of the Reduce*/(Un)Squeeze
        # operators); dynamic (named) axes cannot be converted.
        if isinstance(axes, np.ndarray):
            iaxes = axes.tolist()
        elif isinstance(axes, int):
            iaxes = [axes]
        else:
            raise RuntimeError(
                f"Unable to call {op_type} on a dynamic input axis={axes}"
            )
        return iaxes

    def ReduceMaxAnyOpset(self, *args, name: str = "ReduceMaxAnyOpset", **kwargs):
        """ReduceMax with axes as input (opset >= 18) or attribute (before)."""
        if len(args) == 1:
            return self.ReduceMax(*args, name=name, **kwargs)
        assert len(args) == 2, f"ReduceMaxAnyOpset expects 2 arguments not {len(args)}"
        if self.container.main_opset >= 18:
            return self.ReduceMax(*args, name=name, **kwargs)
        return self.ReduceMax(
            args[0], axes=self._iaxes("ReduceMax", args[1]), name=name, **kwargs
        )

    def ReduceMinAnyOpset(self, *args, name: str = "ReduceMinAnyOpset", **kwargs):
        """ReduceMin with axes as input (opset >= 18) or attribute (before)."""
        if len(args) == 1:
            return self.ReduceMin(*args, name=name, **kwargs)
        assert len(args) == 2, f"ReduceMinAnyOpset expects 2 arguments not {len(args)}"
        if self.container.main_opset >= 18:
            return self.ReduceMin(*args, name=name, **kwargs)
        return self.ReduceMin(
            args[0], axes=self._iaxes("ReduceMin", args[1]), name=name, **kwargs
        )

    def ReduceMeanAnyOpset(self, *args, name: str = "ReduceMeanAnyOpset", **kwargs):
        """ReduceMean with axes as input (opset >= 18) or attribute (before)."""
        if len(args) == 1:
            return self.ReduceMean(*args, name=name, **kwargs)
        assert len(args) == 2, f"ReduceMeanAnyOpset expects 2 arguments not {len(args)}"
        if self.container.main_opset >= 18:
            return self.ReduceMean(*args, name=name, **kwargs)
        return self.ReduceMean(
            args[0], axes=self._iaxes("ReduceMean", args[1]), name=name, **kwargs
        )

    def ReduceSumAnyOpset(self, *args, name: str = "ReduceSumAnyOpset", **kwargs):
        """ReduceSum with axes as input (opset >= 13) or attribute (before)."""
        if len(args) == 1:
            return self.ReduceSum(*args, name=name, **kwargs)
        assert len(args) == 2, f"ReduceSumAnyOpset expects 2 arguments not {len(args)}"
        if self.container.main_opset >= 13:
            return self.ReduceSum(*args, name=name, **kwargs)
        return self.ReduceSum(
            args[0], axes=self._iaxes("ReduceSum", args[1]), name=name, **kwargs
        )

    def SqueezeAnyOpset(self, *args, name: str = "SqueezeAnyOpset", **kwargs):
        """Squeeze with axes as input (opset >= 13) or attribute (before)."""
        if len(args) == 1 and len(kwargs) == 0:
            return self.Squeeze(*args, name=name)
        assert len(args) == 2, f"SqueezeAnyOpset expects 2 arguments not {len(args)}"
        if self.container.main_opset >= 13:
            return self.Squeeze(*args, name=name, **kwargs)
        return self.Squeeze(
            args[0], axes=self._iaxes("Squeeze", args[1]), name=name, **kwargs
        )

    def UnsqueezeAnyOpset(self, *args, name: str = "UnsqueezeAnyOpset", **kwargs):
        """Unsqueeze with axes as input (opset >= 13) or attribute (before)."""
        if len(args) == 1 and len(kwargs) == 0:
            return self.Unsqueeze(*args, name=name)
        assert len(args) == 2, f"UnsqueezeAnyOpset expects 2 arguments not {len(args)}"
        if self.container.main_opset >= 13:
            return self.Unsqueeze(*args, name=name, **kwargs)
        return self.Unsqueeze(
            args[0], axes=self._iaxes("Unsqueeze", args[1]), name=name, **kwargs
        )
Loading
Loading