From 20c372848c093fc3ab6ddce5b697ab0ef7870ce5 Mon Sep 17 00:00:00 2001 From: xadupre Date: Sun, 22 Dec 2024 23:52:11 +0100 Subject: [PATCH] fix comparison Signed-off-by: xadupre --- .../utils/utils_backend_onnxruntime.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/onnxmltools/utils/utils_backend_onnxruntime.py b/onnxmltools/utils/utils_backend_onnxruntime.py index edb6fea7..9076fb83 100644 --- a/onnxmltools/utils/utils_backend_onnxruntime.py +++ b/onnxmltools/utils/utils_backend_onnxruntime.py @@ -297,6 +297,30 @@ def _compare_expected( (len(expected), len(output.ravel()) // len(expected)) ) if len(expected) != len(output): + if ( + len(output) == 2 + and len(expected) == 1 + and output[0].dtype in (numpy.int64, numpy.int32) + ): + # a classifier + if len(expected[0].shape) == 1: + expected = [ + numpy.hstack( + [ + 1 - expected[0].reshape((-1, 1)), + expected[0].reshape((-1, 1)), + ] + ) + ] + return _compare_expected( + expected, + output[1:], + sess, + onnx, + decimal=5, + onnx_shape=None, + **kwargs + ) raise OnnxRuntimeAssertionError( "Unexpected number of outputs '{0}', expected={1}, got={2}".format( onnx, len(expected), len(output)