Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KNNImputer #1165

Merged
merged 25 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions skl2onnx/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,18 @@
from scipy.sparse import coo_matrix
from onnx import SparseTensorProto, ValueInfoProto
from onnx.defs import onnx_opset_version, get_all_schemas_with_history
from onnx.helper import make_node, make_tensor, make_attribute, make_sparse_tensor
from onnx.helper import (
make_node,
make_tensor,
make_attribute,
make_sparse_tensor,
make_tensor_value_info,
)
import onnx.onnx_cpp2py_export.defs as C
from onnxconverter_common.onnx_ops import __dict__ as dict_apply_operation
from ..proto import TensorProto
from .utils import get_domain
from .graph_builder_opset import Opset


logger = getLogger("skl2onnx")
Expand Down Expand Up @@ -222,6 +229,7 @@ def __init__(
white_op=None,
black_op=None,
verbose=0,
as_function=False,
):
"""
:param target_opset: number, for example, 7 for *ONNX 1.2*, and
Expand All @@ -233,6 +241,7 @@ def __init__(
:param black_op: black list of ONNX nodes allowed
while converting a pipeline, if empty, none are blacklisted
:param verbose: display information while converting
:param as_function: to export as a local function
"""
_WhiteBlackContainer.__init__(
self, white_op=white_op, black_op=black_op, verbose=verbose
Expand All @@ -251,6 +260,7 @@ def __init__(
# ONNX nodes (type: NodeProto) used to define computation
# structure
self.nodes = []
self.as_function = as_function
# ONNX operators' domain-version pair set. They will be added
# into opset_import field in the final ONNX model.
self.node_domain_version_pair_sets = set()
Expand All @@ -266,6 +276,14 @@ def __init__(
self.options = options
# All registered models.
self.registered_models = registered_models
self.local_functions = {}

@property
def main_opset(self):
return self.target_opset

def make_local_function(self, domain: str, name: str, container, optimize=False):
self.local_functions[domain, name] = container

def swap_names(self, old_name, new_name):
"""
Expand Down Expand Up @@ -389,6 +407,14 @@ def add_input(self, variable):
"""
self.inputs.append(self._make_value_info(variable))

def make_tensor_input(self, name):
self.inputs.append(make_tensor_value_info(name, 0, None))
return name

def make_tensor_output(self, name):
self.outputs.append(make_tensor_value_info(name, 0, None))
return name

def add_output(self, variable):
"""
Adds our *Variable* object defined *_parser.py* into the the
Expand Down Expand Up @@ -621,6 +647,7 @@ def add_node(
attributes' names and attributes' values,
respectively.
"""
assert op_type != "knn_imputer_column" or op_domain
if "axes" in attrs and (
attrs["axes"] is None or not isinstance(attrs["axes"], (list, np.ndarray))
):
Expand Down Expand Up @@ -650,6 +677,12 @@ def add_node(
",".join(outputs),
name,
)
if not hasattr(self, "_added_names_"):
self._added_names_ = set()
assert all(
n not in self._added_names_ for n in outputs
), f"One output node in {outputs} was already added, added={self._added_names_}"
self._added_names_ |= set(outputs)
try:
common = set(inputs) & set(outputs)
except TypeError as e:
Expand Down Expand Up @@ -763,7 +796,7 @@ def _get_op_version(self, domain, op_type):
key = domain, op_type
vers = self._op_versions.get(key, None)
if vers is None:
if domain == "com.microsoft":
if domain in ("com.microsoft", "local_domain"):
# avoid a not necessarily necessary warning
vers = 1
else:
Expand Down Expand Up @@ -875,6 +908,7 @@ def ensure_topological_order(self):
n_iter = 0
missing_ops = []
cont = True
parent = {}
while cont and n_iter < len(self.nodes) * 2:
n_iter += 1
missing_names = set()
Expand All @@ -900,11 +934,16 @@ def ensure_topological_order(self):
order[key] = maxi
maxi += 1
for name in node.output:
if name not in parent:
parent[name] = []
parent[name].append(node)
if name in order:
raise RuntimeError(
"Unable to sort a node (cycle). An output was "
"already ordered with name %r (iteration=%r)."
"" % (name, n_iter)
f"Unable to sort a node (cycle). An output was "
f"already ordered with name {name!r} (iteration={n_iter})\n"
f"--\n{[f'{id(n)}-{n.op_type}' for n in parent[name]]}\n"
f"--\n{[n.op_type for n in self.nodes]}\n"
f"--\n{pprint.pformat(order)}"
)
order[name] = maxi
if len(missing_names) == 0:
Expand Down Expand Up @@ -1030,3 +1069,6 @@ def nstr(name):
)
map_nodes = {str(id(node)): node for node in self.nodes}
self.nodes = [map_nodes[_[-1]] for _ in topo]

def get_op_builder(self, scope):
return Opset(self, scope)
Loading
Loading