
Regression: Torch exported Onnx doesn't run after Onnxruntime>=1.17 update - [ShapeInferenceError] #20808

Closed
MengLinMaker opened this issue May 24, 2024 · 6 comments


MengLinMaker commented May 24, 2024

Describe the issue (Issue solved, see closing comment)

Previously, a Torch==2.3 model exported to Onnx would run on Onnxruntime==1.16.

Currently, the same exported model runs on neither Onnxruntime==1.17 nor Onnxruntime==1.18.

The error originates from: [ShapeInferenceError] First input does not have rank 2

However, I could not track down the location of the error using Netron.

It's also possible that the problem is caused by torch.onnx.export, which can be found in this tutorial.

Truncated log, for context:

  File "/Users/menglinmaker/Documents/2-Engineering/Personal/Musidi/transcribe/src/transcribe/inference.py", line 19, in inference
    model = InferenceSession(f'layer/model_onnx/{onnx_path}', providers=['CPUExecutionProvider'])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/menglinmaker/Documents/2-Engineering/Personal/Musidi/transcribe/.venv/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/Users/menglinmaker/Documents/2-Engineering/Personal/Musidi/transcribe/.venv/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 483, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)

Last line of the log - I believe this is the cause of the error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (MatMulBnFusion_Gemm) Op (Gemm) [ShapeInferenceError] First input does not have rank 2

To reproduce

I'm using a custom model. If necessary, I can create a Google Colab notebook if that helps; a hypothetical sketch of the failing pattern is included below.
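In the meantime, here is a minimal, hypothetical sketch (not the author's actual model) of the shape situation the error names: a Linear (exported as MatMul) feeding a BatchNorm over a rank-3 input, so the first input of a fused Gemm would not be rank 2. Whether this exact snippet trips onnxruntime's MatMul + BatchNormalization fusion depends on the torch and onnxruntime versions:

import torch
import onnxruntime

# Hypothetical stand-in model: Linear (MatMul) followed by BatchNorm
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.bn = torch.nn.BatchNorm1d(4)  # normalizes the channel dim of a rank-3 input

    def forward(self, x):
        return self.bn(self.linear(x))

model = TinyModel().eval()
dummy = torch.randn(2, 4, 8)  # rank-3 input, so the MatMul input is not rank 2
torch.onnx.export(model, dummy, "tiny.onnx")  # relies on torch's default opset

# Session creation is where the regression surfaced on Onnxruntime>=1.17
session = onnxruntime.InferenceSession("tiny.onnx", providers=["CPUExecutionProvider"])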

Urgency

I can no longer update Onnxruntime without breaking the application.

Platform

Mac

OS Version

macOS Sonoma 14.5

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.0 and above

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

Default CPU

Execution Provider Library Version

No response

@MengLinMaker MengLinMaker changed the title Regression: Torch exported Onnx doesn't run after OnnxRT update Regression: Torch exported Onnx doesn't run after Onnxruntime>=1.17 update May 24, 2024
MengLinMaker commented May 24, 2024

Here's the DEBUG (0) log from Onnxruntime - I couldn't find any helpful info:

2024-05-24 23:25:21.261904 [I:onnxruntime:, inference_session.cc:533 TraceSessionOptions] Session Options {  execution_mode:0 execution_order:DEFAULT enable_profiling:0 optimized_model_filepath: enable_mem_pattern:1 enable_mem_reuse:1 enable_cpu_mem_arena:1 profile_file_prefix:onnxruntime_profile_ session_logid: session_log_severity_level:0 session_log_verbosity_level:0 max_num_graph_transformation_steps:10 graph_optimization_level:3 intra_op_param:OrtThreadPoolParams { thread_pool_size: 0 auto_set_affinity: 0 allow_spinning: 1 dynamic_block_base_: 0 stack_size: 0 affinity_str:  set_denormal_as_zero: 0 } inter_op_param:OrtThreadPoolParams { thread_pool_size: 0 auto_set_affinity: 0 allow_spinning: 1 dynamic_block_base_: 0 stack_size: 0 affinity_str:  set_denormal_as_zero: 0 } use_per_session_threads:1 thread_pool_allow_spinning:1 use_deterministic_compute:0 config_options: {  } }
2024-05-24 23:25:21.262142 [I:onnxruntime:, inference_session.cc:433 operator()] Flush-to-zero and denormal-as-zero are off
2024-05-24 23:25:21.262152 [I:onnxruntime:, inference_session.cc:441 ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
2024-05-24 23:25:21.262158 [I:onnxruntime:, inference_session.cc:459 ConstructorCommon] Dynamic block base set to 0
2024-05-24 23:25:21.289917 [I:onnxruntime:, inference_session.cc:1602 Initialize] Initializing session.
2024-05-24 23:25:21.304691 [I:onnxruntime:, graph_partitioner.cc:900 InlineFunctionsAOT] This model does not have any local functions defined. AOT Inlining is not performed
2024-05-24 23:25:21.305127 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer EnsureUniqueDQForNodeUnit modified: 0 with status: OK
2024-05-24 23:25:21.322771 [I:onnxruntime:, graph_transformer.cc:15 Apply] GraphTransformer Level1_RuleBasedTransformer modified: 1 with status: OK
multiprocessing.pool.RemoteTraceback:
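For reference, a sketch of how a log at this verbosity can be produced (the model path is a placeholder); setting SessionOptions.log_severity_level to 0 enables verbose output:

import onnxruntime

so = onnxruntime.SessionOptions()
so.log_severity_level = 0  # 0 = verbose, the most detailed level
session = onnxruntime.InferenceSession(
    "model.onnx",  # hypothetical path to the exported model
    sess_options=so,
    providers=["CPUExecutionProvider"],
)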

@MengLinMaker MengLinMaker changed the title Regression: Torch exported Onnx doesn't run after Onnxruntime>=1.17 update Regression: Torch exported Onnx doesn't run after Onnxruntime>=1.17 update - [ShapeInferenceError] May 24, 2024

MengLinMaker commented May 25, 2024

Solved the issue!

Cause:

Onnx opset version was not compatible with onnxruntime.

Note: This is not an issue with ONNXRuntime.


Fix:

  1. Check which ONNX opset and onnx version your onnxruntime release requires. E.g. onnxruntime==1.18 requires onnx==1.16 and opset 21 (a version-check sketch follows the snippet in step 2).

  2. Upgrade the model's ONNX opset:

import onnx
from onnx import version_converter

modelPath = 'model.onnx'  # placeholder: path to the exported model

oldModel = onnx.load(modelPath)
# Rewrite the model to the opset version onnxruntime expects (21 here)
upgradedModel = version_converter.convert_version(oldModel, 21)
onnx.save(upgradedModel, modelPath)
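For the check in step 1 - assuming standard onnx and onnxruntime installs and a hypothetical model path - the installed package versions and the opset the model was exported with can be printed like this:

import onnx
import onnxruntime

print('onnx:', onnx.__version__)
print('onnxruntime:', onnxruntime.__version__)

# Each entry in opset_import names a domain and the opset version the model uses
model = onnx.load('model.onnx')  # hypothetical path
for opset in model.opset_import:
    print(opset.domain or 'ai.onnx', opset.version)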

MengLinMaker commented May 31, 2024

This issue will be updated in a few months from the 31st of May 2024:
pytorch/pytorch#127167

As a general best practice, I recommend explicitly stating the ONNX opset version when exporting; a sketch follows.
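A minimal sketch of that recommendation, using a hypothetical tiny model and opset 17 as the pinned target:

import torch

# Hypothetical stand-in for the real model
model = torch.nn.Linear(8, 8).eval()
dummy_input = torch.randn(1, 8)

# Pin the opset explicitly rather than relying on the exporter's default
torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=17)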

csukuangfj (Contributor) commented:

> Onnx opset version was not compatible with onnxruntime.

@MengLinMaker Could you list any references for it?

The page
https://onnxruntime.ai/docs/reference/compatibility.html#onnx-opset-support
says:

> ONNX Runtime supports all opsets from the latest released version of the ONNX spec. All versions of ONNX Runtime support ONNX opsets from ONNX v1.2.1+ (opset version 7 and higher).

which means the latest onnxruntime would support all opsets >= 7.


MengLinMaker commented Jul 10, 2024

> > Onnx opset version was not compatible with onnxruntime.
>
> @MengLinMaker Could you list any references for it?
>
> The page
> https://onnxruntime.ai/docs/reference/compatibility.html#onnx-opset-support
> says:
>
> > ONNX Runtime supports all opsets from the latest released version of the ONNX spec. All versions of ONNX Runtime support ONNX opsets from ONNX v1.2.1+ (opset version 7 and higher).
>
> which means the latest onnxruntime would support all opsets >= 7.

Yep, there are no issues with ONNX Runtime that I could find, except that the statement you referenced is inaccurate: a specific opset version was required to solve this issue.

PyTorch hard-codes the default opset for ONNX conversion; a sketch showing how to see which default it picks follows.
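One way to observe the hard-coded default, with a hypothetical tiny model: export without specifying opset_version and read the opset back out of the file.

import torch
import onnx

# Hypothetical tiny model; export without specifying opset_version
model = torch.nn.Linear(4, 4).eval()
torch.onnx.export(model, torch.randn(1, 4), 'default_opset.onnx')

# Read back the opset the exporter chose by default
exported = onnx.load('default_opset.onnx')
print([(o.domain or 'ai.onnx', o.version) for o in exported.opset_import])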

MengLinMaker commented:

Update:
This issue should be solved by pytorch/pytorch#134571.
PyTorch will be able to generate opset 21 ONNX for ONNXRuntime>=1.17 once the PR is merged and a new package is released.
