diff --git a/CHANGELOGS.md b/CHANGELOGS.md
index 4c2b1cf67..9d14b8376 100644
--- a/CHANGELOGS.md
+++ b/CHANGELOGS.md
@@ -2,6 +2,12 @@

 ## 1.19.0

+* Refactors the KNNImputer converter with a local function to match
+  scikit-learn's implementation; the code was partially generated
+  automatically from an equivalent PyTorch implementation and
+  exported into ONNX
+  [#1167](https://github.com/onnx/sklearn-onnx/issues/1167),
+  [#1165](https://github.com/onnx/sklearn-onnx/issues/1165)
 * Add support to sklearn TargetEncoder
   [#1137](https://github.com/onnx/sklearn-onnx/issues/1137)
 * Fixes missing WhiteKernel with return_std=True #1163
diff --git a/skl2onnx/operator_converters/nearest_neighbours.py b/skl2onnx/operator_converters/nearest_neighbours.py
index 1d09e1fa3..f5f69b094 100644
--- a/skl2onnx/operator_converters/nearest_neighbours.py
+++ b/skl2onnx/operator_converters/nearest_neighbours.py
@@ -1320,6 +1320,7 @@ def make_knn_imputer_column_nan_found(
     dist_chunk = gr.make_tensor_input("dist_chunk")
     _fit_x = gr.make_tensor_input("_fit_x")
     row_missing_idx = gr.make_tensor_input("row_missing_idx")
+    n_neighbors = gr.make_tensor_input("n_neighbors")

     op = gr.get_op_builder(scope)

@@ -1336,10 +1337,6 @@ def make_knn_imputer_column_nan_found(
         value=from_array(np.array(1, dtype=np.int64), name="value"),
         outputs=["init7_s_1"],
     )
-    c_lifted_tensor_1 = op.Constant(
-        value=from_array(np.array(3, dtype=np.int64), name="value"),
-        outputs=["c_lifted_tensor_1"],
-    )
     c_lifted_tensor_2 = op.Constant(
         value=from_array(np.array([1], dtype=np.int64), name="value"),
         outputs=["c_lifted_tensor_2"],
     )
@@ -1434,8 +1431,8 @@ def make_knn_imputer_column_nan_found(
         outputs=["select_scatter", "dist_subset_updated", "receivers_idx_updated"],
     )

-    lt = op.Less(c_lifted_tensor_1, sym_size_int_23, outputs=["lt"])
-    where_1 = op.Where(lt, c_lifted_tensor_1, sym_size_int_23, outputs=["where_1"])
+    lt = op.Less(n_neighbors, sym_size_int_23, outputs=["lt"])
+    where_1 = op.Where(lt, n_neighbors, sym_size_int_23, outputs=["where_1"])
     le = op.LessOrEqual(where_1, init7_s_0, outputs=["le"])
     # c_lifted_tensor_2 -> init7_s_1 to have a zero time, onnxruntime crashes when shapes are
     # (), (1,), ()
@@ -1502,6 +1499,7 @@ def make_knn_imputer_column(g: ModelComponentContainer, scope: Scope, itype: int
     mask = gr.make_tensor_input("mask")
     row_missing_idx = gr.make_tensor_input("row_missing_idx")
     _fit_x = gr.make_tensor_input("_fit_x")
+    n_neighbors = gr.make_tensor_input("n_neighbors")

     op = gr.get_op_builder(scope)
     zero32 = op.Constant(
@@ -1549,6 +1547,7 @@ def make_knn_imputer_column(g: ModelComponentContainer, scope: Scope, itype: int
             dist_chunk,
             _fit_x,
             row_missing_idx,
+            n_neighbors,
         ],
         ["X"],
         domain="local_domain",
@@ -1572,6 +1571,7 @@ def _knn_imputer_builder(
     _mask_fit_x: "BOOL[s0, 2]",  # noqa: F821
     _valid_mask: "BOOL[2]",  # noqa: F821
     _fit_x: "DOUBLE[s1, 2]",  # noqa: F821
+    n_neighbors: "INT",  # noqa: F821
     x: "FLOAT[s2, 2]",  # noqa: F821
     itype: int,
 ):
@@ -1614,6 +1614,7 @@ def _knn_imputer_builder(
         isnan,
         nonzero_numpy__0,
         _fit_x,
+        n_neighbors,
         domain="local_domain",
     )
     return op.Compress(x, _valid_mask, axis=1, outputs=["output_0"])
@@ -1656,6 +1657,7 @@ def convert_knn_imputer(
         knn_op._mask_fit_X,
         knn_op._valid_mask,
         training_data,
+        np.array(knn_op.n_neighbors, dtype=np.int64),
         operator.inputs[0].full_name,
         itype=proto_type,
     )
diff --git a/tests/bug.onnx b/tests/bug.onnx
deleted file mode 100644
index 1fdc58190..000000000
Binary files a/tests/bug.onnx and /dev/null differ
diff --git a/tests/test_sklearn_nearest_neighbour_converter.py b/tests/test_sklearn_nearest_neighbour_converter.py
index 27589eafc..c0555ffa6 100644
--- a/tests/test_sklearn_nearest_neighbour_converter.py
+++ b/tests/test_sklearn_nearest_neighbour_converter.py
@@ -919,7 +919,10 @@ def test_sklearn_k_neighbours_transformer_connectivity(self):
         )

     def _get_torch_knn_imputer(self):
-        import torch
+        try:
+            import torch
+        except ImportError:
+            return None, None

         def _get_weights(dist, weights):
             """Get the weights from an array of distances and a parameter ``weights``.
@@ -1730,6 +1733,56 @@ def test_sklearn_knn_imputer_issue_2025(self):
             backend="onnxruntime",
         )

+    @unittest.skipIf(KNNImputer is None, reason="new in 0.22")
+    @unittest.skipIf(
+        pv.Version(ort_version) <= pv.Version("1.16.0"),
+        reason="onnxruntime not recent enough",
+    )
+    @ignore_warnings(category=DeprecationWarning)
+    @unittest.skipIf(
+        pv.Version(ort_version) < pv.Version("1.20.0"),
+        reason="onnxruntime not recent enough",
+    )
+    @unittest.skipIf(
+        sys.platform != "linux" and pv.Version(skl_version) < pv.Version("1.6.0"),
+        "investigate why topk returns different results",
+    )
+    def test_knn_imputer_one_nan(self):
+        import numpy as np
+        import onnxruntime as rt
+
+        np.random.seed(42)
+        data = np.random.randn(8, 2).astype(np.float32)
+        data[0, -1] = np.nan
+        imputer = KNNImputer(n_neighbors=5)
+        imputer.fit(data)
+        dataft = imputer.transform(data)
+
+        tmodel_cls, _ = self._get_torch_knn_imputer()
+        if tmodel_cls is not None:
+            import torch
+
+            tmodel = tmodel_cls(imputer)
+            ty = tmodel.transform(
+                torch.from_numpy(imputer._mask_fit_X),
+                torch.from_numpy(imputer._valid_mask),
+                torch.from_numpy(imputer._fit_X.astype(numpy.float32)),
+                torch.from_numpy(data),
+            )
+            assert_almost_equal(dataft[:2], ty[:2].numpy())
+
+        input_data = data.astype(np.float32)
+        initial_type = [("float_input", FloatTensorType([None, data.shape[1]]))]
+        onnx_model = convert_sklearn(imputer, initial_types=initial_type)
+        # with open("test_knn_imputer_one_nan.onnx", "wb") as f:
+        #     f.write(onnx_model.SerializeToString())
+
+        sess = rt.InferenceSession(onnx_model.SerializeToString())
+        input_name = sess.get_inputs()[0].name
+        output_name = sess.get_outputs()[0].name
+        res = sess.run([output_name], {input_name: input_data})
+        assert_almost_equal(dataft[:2], res[0][:2])
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)
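
For reference, a minimal end-to-end sketch of the conversion path the new test exercises (not part of the diff; variable names and the tolerance are illustrative, assuming skl2onnx and onnxruntime are installed):

```python
import numpy as np
from sklearn.impute import KNNImputer
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as rt

# Fit an imputer whose n_neighbors differs from the previously hard-coded value of 3.
rng = np.random.RandomState(42)
data = rng.randn(8, 2).astype(np.float32)
data[0, -1] = np.nan
imputer = KNNImputer(n_neighbors=5).fit(data)

# Convert to ONNX; with this change n_neighbors is passed to the local
# KNN-imputer functions as an int64 tensor instead of the constant 3.
onx = convert_sklearn(
    imputer, initial_types=[("float_input", FloatTensorType([None, data.shape[1]]))]
)

# Check the ONNX model against scikit-learn.
sess = rt.InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
got = sess.run(None, {"float_input": data})[0]
np.testing.assert_allclose(imputer.transform(data), got, rtol=1e-5)
```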