Skip to content

Commit

Permalink
fix comparison
Browse files Browse the repository at this point in the history
Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre committed Dec 22, 2024
1 parent 7069b3d commit 20c3728
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions onnxmltools/utils/utils_backend_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,30 @@ def _compare_expected(
(len(expected), len(output.ravel()) // len(expected))
)
if len(expected) != len(output):
if (
len(output) == 2
and len(expected) == 1
and output[0].dtype in (numpy.int64, numpy.int32)
):
# a classifier
if len(expected[0].shape) == 1:
expected = [
numpy.hstack(
[
1 - expected[0].reshape((-1, 1)),
expected[0].reshape((-1, 1)),
]
)
]
return _compare_expected(
expected,
output[1:],
sess,
onnx,
decimal=5,
onnx_shape=None,
**kwargs
)
raise OnnxRuntimeAssertionError(
"Unexpected number of outputs '{0}', expected={1}, got={2}".format(
onnx, len(expected), len(output)
Expand Down

0 comments on commit 20c3728

Please sign in to comment.