Skip to content

Commit

Permalink
Fix number of neighbours in KNNImputer (#1167)
Browse files Browse the repository at this point in the history
* Fix number of neighbours in KNNImputer

Signed-off-by: xadupre <[email protected]>

* fix

* fix issue

Signed-off-by: xadupre <[email protected]>

* fix changelogs

Signed-off-by: xadupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Feb 27, 2025
1 parent 07c5950 commit acfa6a4
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## 1.19.0

* Refactors the KNNImputer converter with local functions to match
scikit-learn's implementation; the code was partially generated
automatically from an equivalent PyTorch implementation
exported to ONNX
[#1167](https://github.com/onnx/sklearn-onnx/issues/1167),
[#1165](https://github.com/onnx/sklearn-onnx/issues/1165)
* Add support to sklearn TargetEncoder
[#1137](https://github.com/onnx/sklearn-onnx/issues/1137)
* Fixes missing WhiteKernel with return_std=True #1163
Expand Down
14 changes: 8 additions & 6 deletions skl2onnx/operator_converters/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ def make_knn_imputer_column_nan_found(
dist_chunk = gr.make_tensor_input("dist_chunk")
_fit_x = gr.make_tensor_input("_fit_x")
row_missing_idx = gr.make_tensor_input("row_missing_idx")
n_neighours = gr.make_tensor_input("n_neighbors")

op = gr.get_op_builder(scope)

Expand All @@ -1336,10 +1337,6 @@ def make_knn_imputer_column_nan_found(
value=from_array(np.array(1, dtype=np.int64), name="value"),
outputs=["init7_s_1"],
)
c_lifted_tensor_1 = op.Constant(
value=from_array(np.array(3, dtype=np.int64), name="value"),
outputs=["c_lifted_tensor_1"],
)
c_lifted_tensor_2 = op.Constant(
value=from_array(np.array([1], dtype=np.int64), name="value"),
outputs=["c_lifted_tensor_2"],
Expand Down Expand Up @@ -1434,8 +1431,8 @@ def make_knn_imputer_column_nan_found(
outputs=["select_scatter", "dist_subset_updated", "receivers_idx_updated"],
)

lt = op.Less(c_lifted_tensor_1, sym_size_int_23, outputs=["lt"])
where_1 = op.Where(lt, c_lifted_tensor_1, sym_size_int_23, outputs=["where_1"])
lt = op.Less(n_neighours, sym_size_int_23, outputs=["lt"])
where_1 = op.Where(lt, n_neighours, sym_size_int_23, outputs=["where_1"])
le = op.LessOrEqual(where_1, init7_s_0, outputs=["le"])
# c_lifted_tensor_2 -> init7_s_1 to have a zero time, onnxruntime crashes when shapes are
# (), (1,), ()
Expand Down Expand Up @@ -1502,6 +1499,7 @@ def make_knn_imputer_column(g: ModelComponentContainer, scope: Scope, itype: int
mask = gr.make_tensor_input("mask")
row_missing_idx = gr.make_tensor_input("row_missing_idx")
_fit_x = gr.make_tensor_input("_fit_x")
n_neighbors = gr.make_tensor_input("n_neighbors")

op = gr.get_op_builder(scope)
zero32 = op.Constant(
Expand Down Expand Up @@ -1549,6 +1547,7 @@ def make_knn_imputer_column(g: ModelComponentContainer, scope: Scope, itype: int
dist_chunk,
_fit_x,
row_missing_idx,
n_neighbors,
],
["X"],
domain="local_domain",
Expand All @@ -1572,6 +1571,7 @@ def _knn_imputer_builder(
_mask_fit_x: "BOOL[s0, 2]", # noqa: F821
_valid_mask: "BOOL[2]", # noqa: F821
_fit_x: "DOUBLE[s1, 2]", # noqa: F821
n_neighbors: "INT", # noqa: F821
x: "FLOAT[s2, 2]", # noqa: F821
itype: int,
):
Expand Down Expand Up @@ -1614,6 +1614,7 @@ def _knn_imputer_builder(
isnan,
nonzero_numpy__0,
_fit_x,
n_neighbors,
domain="local_domain",
)
return op.Compress(x, _valid_mask, axis=1, outputs=["output_0"])
Expand Down Expand Up @@ -1656,6 +1657,7 @@ def convert_knn_imputer(
knn_op._mask_fit_X,
knn_op._valid_mask,
training_data,
np.array(knn_op.n_neighbors, dtype=np.int64),
operator.inputs[0].full_name,
itype=proto_type,
)
Expand Down
Binary file removed tests/bug.onnx
Binary file not shown.
55 changes: 54 additions & 1 deletion tests/test_sklearn_nearest_neighbour_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,10 @@ def test_sklearn_k_neighbours_transformer_connectivity(self):
)

def _get_torch_knn_imputer(self):
import torch
try:
import torch
except ImportError:
return None, None

def _get_weights(dist, weights):
"""Get the weights from an array of distances and a parameter ``weights``.
Expand Down Expand Up @@ -1730,6 +1733,56 @@ def test_sklearn_knn_imputer_issue_2025(self):
backend="onnxruntime",
)

@unittest.skipIf(KNNImputer is None, reason="new in 0.22")
@ignore_warnings(category=DeprecationWarning)
@unittest.skipIf(
    # This bound also covers the former, redundant ``<= 1.16.0`` check.
    pv.Version(ort_version) < pv.Version("1.20.0"),
    reason="onnxruntime not recent enough",
)
@unittest.skipIf(
    sys.platform != "linux" and pv.Version(skl_version) < pv.Version("1.6.0"),
    "investigate why topk returns different results",
)
def test_knn_imputer_one_nan(self):
    """Check KNNImputer conversion when ``n_neighbors`` is not the old constant.

    A ``KNNImputer(n_neighbors=5)`` is fitted on an 8x2 float32 matrix
    holding a single NaN. The first two rows of scikit-learn's output are
    compared with the torch reference implementation (only when torch can
    be imported) and with the converted ONNX model run by onnxruntime.
    """
    import numpy as np
    import onnxruntime as rt

    np.random.seed(42)
    data = np.random.randn(8, 2).astype(np.float32)
    data[0, -1] = np.nan
    imputer = KNNImputer(n_neighbors=5)
    imputer.fit(data)
    expected = imputer.transform(data)

    # The helper returns (None, None) when torch is not installed; in that
    # case only the ONNX path is verified.
    tmodel_cls, _ = self._get_torch_knn_imputer()
    if tmodel_cls is not None:
        import torch

        tmodel = tmodel_cls(imputer)
        ty = tmodel.transform(
            torch.from_numpy(imputer._mask_fit_X),
            torch.from_numpy(imputer._valid_mask),
            # Use the local ``np`` alias rather than relying on a
            # module-level ``numpy`` name.
            torch.from_numpy(imputer._fit_X.astype(np.float32)),
            torch.from_numpy(data),
        )
        assert_almost_equal(expected[:2], ty[:2].numpy())

    initial_type = [("float_input", FloatTensorType([None, data.shape[1]]))]
    onnx_model = convert_sklearn(imputer, initial_types=initial_type)

    sess = rt.InferenceSession(onnx_model.SerializeToString())
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name
    # ``data`` is already float32, so it is fed to the session directly.
    res = sess.run([output_name], {input_name: data})
    assert_almost_equal(expected[:2], res[0][:2])


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit acfa6a4

Please sign in to comment.