Skip to content

Commit

Permalink
better or not?
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Feb 25, 2025
1 parent f150bf7 commit 9c9b10f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
9 changes: 5 additions & 4 deletions skl2onnx/operator_converters/nearest_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,13 +970,14 @@ def make_calc_impute_donors(g: ModelComponentContainer, scope: Scope, itype: int
sym_size_int_4 = op.SqueezeAnyOpset(
_shape_dist_pot_donors0, outputs=["sym_size_int_4"]
)
unused_topk_values, output_0 = op.TopK(
dist_pot_donors,
unused_topk_values, neg_output_0 = op.TopK(
op.Neg(dist_pot_donors),
op.Reshape(n_neighbors, np.array([1], dtype=np.int64)),
largest=0,
largest=1,
sorted=1,
outputs=["unused_topk_values", "output_0"],
outputs=["unused_topk_values", "neg_output_0"],
)
output_0 = op.Neg(neg_output_0, outputs=["output_0"])
arange = op.Range(init7_s_0, sym_size_int_4, init7_s_1, outputs=["arange"])
unsqueeze = op.UnsqueezeAnyOpset(arange, init7_s1_1, outputs=["unsqueeze"])
_onx_gathernd_dist_pot_donors0 = op.GatherND(
Expand Down
9 changes: 9 additions & 0 deletions tests/test_sklearn_nearest_neighbour_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,8 @@ def test_model_knn_iris_classifier_multi_reg3_weight(self):
@unittest.skipIf(TARGET_OPSET < 18, reason="not available")
@ignore_warnings(category=DeprecationWarning)
def test_sklearn_knn_imputer_issue_2025(self):
# This test is about having nan or the fact TopK
# does not handle largest=1 in opset < 11.
from onnxruntime import InferenceSession

data = (numpy.arange(14) + 100).reshape((-1, 2)).astype(float)
Expand All @@ -1691,6 +1693,13 @@ def test_sklearn_knn_imputer_issue_2025(self):
None, {"float_input": input_data}
)[0]
assert_almost_equal(expected, got)

# in case onnruntime fails
# from experimental_experiment.reference import OrtEval

# got = OrtEval(onnx_model, verbose=10).run(None, {"float_input": input_data})[0]
# assert_almost_equal(expected, got)

got = InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
).run(None, {"float_input": input_data})[0]
Expand Down

0 comments on commit 9c9b10f

Please sign in to comment.