Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix number of neighbours in KNNImputer #1167

Merged
merged 4 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## 1.19.0

* Refactors KNNImputer converter with local function to match
scikit-learn's implementation, the code was partially
automatically generated from an equivalent implementation
in pytorch and exported into ONNX
[#1167](https://github.com/onnx/sklearn-onnx/issues/1167),
[#1165](https://github.com/onnx/sklearn-onnx/issues/1165)
* Add support to sklearn TargetEncoder
[#1137](https://github.com/onnx/sklearn-onnx/issues/1137)
* Fixes missing WhiteKernel with return_std=True #1163
Expand Down
14 changes: 8 additions & 6 deletions skl2onnx/operator_converters/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ def make_knn_imputer_column_nan_found(
dist_chunk = gr.make_tensor_input("dist_chunk")
_fit_x = gr.make_tensor_input("_fit_x")
row_missing_idx = gr.make_tensor_input("row_missing_idx")
n_neighours = gr.make_tensor_input("n_neighbors")

op = gr.get_op_builder(scope)

Expand All @@ -1336,10 +1337,6 @@ def make_knn_imputer_column_nan_found(
value=from_array(np.array(1, dtype=np.int64), name="value"),
outputs=["init7_s_1"],
)
c_lifted_tensor_1 = op.Constant(
value=from_array(np.array(3, dtype=np.int64), name="value"),
outputs=["c_lifted_tensor_1"],
)
c_lifted_tensor_2 = op.Constant(
value=from_array(np.array([1], dtype=np.int64), name="value"),
outputs=["c_lifted_tensor_2"],
Expand Down Expand Up @@ -1434,8 +1431,8 @@ def make_knn_imputer_column_nan_found(
outputs=["select_scatter", "dist_subset_updated", "receivers_idx_updated"],
)

lt = op.Less(c_lifted_tensor_1, sym_size_int_23, outputs=["lt"])
where_1 = op.Where(lt, c_lifted_tensor_1, sym_size_int_23, outputs=["where_1"])
lt = op.Less(n_neighours, sym_size_int_23, outputs=["lt"])
where_1 = op.Where(lt, n_neighours, sym_size_int_23, outputs=["where_1"])
le = op.LessOrEqual(where_1, init7_s_0, outputs=["le"])
# c_lifted_tensor_2 -> init7_s_1 to have a zero time, onnxruntime crashes wher shapes are
# (), (1,), ()
Expand Down Expand Up @@ -1502,6 +1499,7 @@ def make_knn_imputer_column(g: ModelComponentContainer, scope: Scope, itype: int
mask = gr.make_tensor_input("mask")
row_missing_idx = gr.make_tensor_input("row_missing_idx")
_fit_x = gr.make_tensor_input("_fit_x")
n_neighbors = gr.make_tensor_input("n_neighbors")

op = gr.get_op_builder(scope)
zero32 = op.Constant(
Expand Down Expand Up @@ -1549,6 +1547,7 @@ def make_knn_imputer_column(g: ModelComponentContainer, scope: Scope, itype: int
dist_chunk,
_fit_x,
row_missing_idx,
n_neighbors,
],
["X"],
domain="local_domain",
Expand All @@ -1572,6 +1571,7 @@ def _knn_imputer_builder(
_mask_fit_x: "BOOL[s0, 2]", # noqa: F821
_valid_mask: "BOOL[2]", # noqa: F821
_fit_x: "DOUBLE[s1, 2]", # noqa: F821
n_neighbors: "INT", # noqa: F821
x: "FLOAT[s2, 2]", # noqa: F821
itype: int,
):
Expand Down Expand Up @@ -1614,6 +1614,7 @@ def _knn_imputer_builder(
isnan,
nonzero_numpy__0,
_fit_x,
n_neighbors,
domain="local_domain",
)
return op.Compress(x, _valid_mask, axis=1, outputs=["output_0"])
Expand Down Expand Up @@ -1656,6 +1657,7 @@ def convert_knn_imputer(
knn_op._mask_fit_X,
knn_op._valid_mask,
training_data,
np.array(knn_op.n_neighbors, dtype=np.int64),
operator.inputs[0].full_name,
itype=proto_type,
)
Expand Down
Binary file removed tests/bug.onnx
Binary file not shown.
55 changes: 54 additions & 1 deletion tests/test_sklearn_nearest_neighbour_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,10 @@
)

def _get_torch_knn_imputer(self):
import torch
try:
import torch
except ImportError:
return None, None

def _get_weights(dist, weights):
"""Get the weights from an array of distances and a parameter ``weights``.
Expand Down Expand Up @@ -1730,6 +1733,56 @@
backend="onnxruntime",
)

@unittest.skipIf(KNNImputer is None, reason="new in 0.22")
@unittest.skipIf(
pv.Version(ort_version) <= pv.Version("1.16.0"),
reason="onnxruntime not recent enough",
)
@ignore_warnings(category=DeprecationWarning)
@unittest.skipIf(
pv.Version(ort_version) < pv.Version("1.20.0"),
reason="onnxruntime not recent enough",
)
@unittest.skipIf(
sys.platform != "linux" and pv.Version(skl_version) < pv.Version("1.6.0"),
"investigate why topk returns different results",
)
def test_knn_imputer_one_nan(self):
import numpy as np
import onnxruntime as rt

np.random.seed(42)
data = np.random.randn(8, 2).astype(np.float32)
data[0, -1] = np.nan
imputer = KNNImputer(n_neighbors=5)
imputer.fit(data)
dataft = imputer.transform(data)

tmodel_cls, _ = self._get_torch_knn_imputer()
if tmodel_cls is not None:
import torch

tmodel = tmodel_cls(imputer)
ty = tmodel.transform(
torch.from_numpy(imputer._mask_fit_X),
torch.from_numpy(imputer._valid_mask),
torch.from_numpy(imputer._fit_X.astype(numpy.float32)),
torch.from_numpy(data),
)
assert_almost_equal(dataft[:2], ty[:2].numpy())

input_data = data.astype(np.float32)
initial_type = [("float_input", FloatTensorType([None, data.shape[1]]))]
onnx_model = convert_sklearn(imputer, initial_types=initial_type)
# with open("test_knn_imputer_one_nan.onnx", "wb") as f:
# f.write(onnx_model.SerializeToString())
Comment on lines +1777 to +1778

Check notice

Code scanning / CodeQL

Commented-out code Note test

This comment appears to contain commented-out code.

sess = rt.InferenceSession(onnx_model.SerializeToString())
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
res = sess.run([output_name], {input_name: input_data})
assert_almost_equal(dataft[:2], res[0][:2])


if __name__ == "__main__":
unittest.main(verbosity=2)
Loading