Skip to content

Commit

Permalink
Fix KNNImputer (#1165)
Browse files Browse the repository at this point in the history
* Add failing unit test for KNNImputer

Signed-off-by: xadupre <[email protected]>

* refactoring

Signed-off-by: xadupre <[email protected]>

* fix a few things again

* fix issues

Signed-off-by: xadupre <[email protected]>

* changes

* fix knn

Signed-off-by: xadupre <[email protected]>

* changes

Signed-off-by: xadupre <[email protected]>

* fix a couple of issues

* test

* fix conversion

* fix knn imputer

Signed-off-by: xadupre <[email protected]>

* fix converter

Signed-off-by: xadupre <[email protected]>

* fix one issue

Signed-off-by: xadupre <[email protected]>

* fix one case

Signed-off-by: xadupre <[email protected]>

* better or not?

* verbose

Signed-off-by: xadupre <[email protected]>

* fix neg

Signed-off-by: xadupre <[email protected]>

* ort

Signed-off-by: xadupre <[email protected]>

* disable

Signed-off-by: xadupre <[email protected]>

* update

Signed-off-by: xadupre <[email protected]>

* dar

Signed-off-by: xadupre <[email protected]>

* issue

* fix ut

Signed-off-by: xadupre <[email protected]>

* disable

Signed-off-by: xadupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Feb 26, 2025
1 parent 11ee203 commit 07c5950
Show file tree
Hide file tree
Showing 10 changed files with 1,983 additions and 165 deletions.
52 changes: 47 additions & 5 deletions skl2onnx/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@
from scipy.sparse import coo_matrix
from onnx import SparseTensorProto, ValueInfoProto
from onnx.defs import onnx_opset_version, get_all_schemas_with_history
from onnx.helper import make_node, make_tensor, make_attribute, make_sparse_tensor
from onnx.helper import (
make_node,
make_tensor,
make_attribute,
make_sparse_tensor,
make_tensor_value_info,
)
import onnx.onnx_cpp2py_export.defs as C
from onnxconverter_common.onnx_ops import __dict__ as dict_apply_operation
from ..proto import TensorProto
from .utils import get_domain
from .graph_builder_opset import Opset


logger = getLogger("skl2onnx")
Expand Down Expand Up @@ -222,6 +229,7 @@ def __init__(
white_op=None,
black_op=None,
verbose=0,
as_function=False,
):
"""
:param target_opset: number, for example, 7 for *ONNX 1.2*, and
Expand All @@ -233,6 +241,7 @@ def __init__(
:param black_op: black list of ONNX nodes disallowed
while converting a pipeline, if empty, none are blacklisted
:param verbose: display information while converting
:param as_function: to export as a local function
"""
_WhiteBlackContainer.__init__(
self, white_op=white_op, black_op=black_op, verbose=verbose
Expand All @@ -251,6 +260,7 @@ def __init__(
# ONNX nodes (type: NodeProto) used to define computation
# structure
self.nodes = []
self.as_function = as_function
# ONNX operators' domain-version pair set. They will be added
# into opset_import field in the final ONNX model.
self.node_domain_version_pair_sets = set()
Expand All @@ -266,6 +276,14 @@ def __init__(
self.options = options
# All registered models.
self.registered_models = registered_models
self.local_functions = {}

@property
def main_opset(self):
    """
    Opset associated with this container.

    :return: the value currently stored in ``self.target_opset``
    """
    opset = self.target_opset
    return opset

def make_local_function(self, domain: str, name: str, container, optimize=False):
    """
    Registers *container* as a local function of this container.

    An entry already stored under the same ``(domain, name)`` key is
    replaced.

    :param domain: domain of the local function
    :param name: name of the local function
    :param container: object stored under the key ``(domain, name)``
        in ``self.local_functions``
    :param optimize: accepted but not used by this method
    """
    key = (domain, name)
    self.local_functions[key] = container

def swap_names(self, old_name, new_name):
"""
Expand Down Expand Up @@ -389,6 +407,14 @@ def add_input(self, variable):
"""
self.inputs.append(self._make_value_info(variable))

def make_tensor_input(self, name):
    """
    Declares *name* as an input of this container.

    The value info is built with element type ``0`` and shape ``None``,
    then appended to ``self.inputs``.

    :param name: name of the input
    :return: *name*, unchanged
    """
    value_info = make_tensor_value_info(name, 0, None)
    self.inputs.append(value_info)
    return name

def make_tensor_output(self, name):
    """
    Declares *name* as an output of this container.

    The value info is built with element type ``0`` and shape ``None``,
    then appended to ``self.outputs``.

    :param name: name of the output
    :return: *name*, unchanged
    """
    value_info = make_tensor_value_info(name, 0, None)
    self.outputs.append(value_info)
    return name

def add_output(self, variable):
"""
Adds our *Variable* object defined in *_parser.py* into the
Expand Down Expand Up @@ -621,6 +647,7 @@ def add_node(
attributes' names and attributes' values,
respectively.
"""
assert op_type != "knn_imputer_column" or op_domain
if "axes" in attrs and (
attrs["axes"] is None or not isinstance(attrs["axes"], (list, np.ndarray))
):
Expand Down Expand Up @@ -650,6 +677,12 @@ def add_node(
",".join(outputs),
name,
)
if not hasattr(self, "_added_names_"):
self._added_names_ = set()
assert all(
n not in self._added_names_ for n in outputs
), f"One output node in {outputs} was already added, added={self._added_names_}"
self._added_names_ |= set(outputs)
try:
common = set(inputs) & set(outputs)
except TypeError as e:
Expand Down Expand Up @@ -763,7 +796,7 @@ def _get_op_version(self, domain, op_type):
key = domain, op_type
vers = self._op_versions.get(key, None)
if vers is None:
if domain == "com.microsoft":
if domain in ("com.microsoft", "local_domain"):
# avoid a possibly unnecessary warning
vers = 1
else:
Expand Down Expand Up @@ -875,6 +908,7 @@ def ensure_topological_order(self):
n_iter = 0
missing_ops = []
cont = True
parent = {}
while cont and n_iter < len(self.nodes) * 2:
n_iter += 1
missing_names = set()
Expand All @@ -900,11 +934,16 @@ def ensure_topological_order(self):
order[key] = maxi
maxi += 1
for name in node.output:
if name not in parent:
parent[name] = []
parent[name].append(node)
if name in order:
raise RuntimeError(
"Unable to sort a node (cycle). An output was "
"already ordered with name %r (iteration=%r)."
"" % (name, n_iter)
f"Unable to sort a node (cycle). An output was "
f"already ordered with name {name!r} (iteration={n_iter})\n"
f"--\n{[f'{id(n)}-{n.op_type}' for n in parent[name]]}\n"
f"--\n{[n.op_type for n in self.nodes]}\n"
f"--\n{pprint.pformat(order)}"
)
order[name] = maxi
if len(missing_names) == 0:
Expand Down Expand Up @@ -1030,3 +1069,6 @@ def nstr(name):
)
map_nodes = {str(id(node)): node for node in self.nodes}
self.nodes = [map_nodes[_[-1]] for _ in topo]

def get_op_builder(self, scope):
    """
    Creates an :class:`Opset` bound to this container and *scope*.

    :param scope: scope forwarded as the second argument of
        :class:`Opset`
    :return: a new :class:`Opset` instance
    """
    builder = Opset(self, scope)
    return builder
Loading

0 comments on commit 07c5950

Please sign in to comment.