some model convert error question when tensorrt backend #9853

Open · ggslayer opened this issue Nov 24, 2021 · 5 comments

@ggslayer

Hi, I am testing the ONNX Runtime TensorRT backend and I'm blocked at the first step:
when I ran my test ONNX model on the TensorRT backend, I got the warnings and errors below:

warn: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.

error: [shapeContext.cpp::volumeOfShapeTensor::497] Error Code 2: Internal Error (Assertion hasAllConstantValues(t.extent) && "shape tensor must have build-time extent" failed.)

warn: Output type must be INT32 for shape outputs
loglevel:1 tag:ONNXRuntime msg:Exception during initialization:

error: SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: 252 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs

Environment: Ubuntu 20.04, onnxruntime 1.9.1, TensorRT 8.0.3.4

I'm new to onnxruntime & TensorRT. Please give me some explanation of these warnings and errors, and some advice on how to get this running successfully.
Thank you very much!

Note: my model runs fine under the CUDA backend.
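
For context, a minimal sketch of selecting the TensorRT backend from the ORT Python API (the model file name fs.onnx is an assumption taken from the traceback later in this thread):

import onnxruntime as ort

# Request the TensorRT EP first, falling back to CUDA and then CPU.
sess = ort.InferenceSession(
    "fs.onnx",
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)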

@pranavsharma added the ep:TensorRT label Nov 24, 2021
@stevenlix (Contributor)

Could you run shape inference on the model first, as indicated by message #4? TensorRT requires input shape info to be available before inference. Messages #1 and #3 are warnings, not errors. I'm not sure what the cause of #2 is. Can you share your model, if possible?
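
For reference, a minimal way to run ORT's symbolic shape inference from Python (a sketch only; the packaged module path onnxruntime.tools.symbolic_shape_infer and the file names are assumptions):

import onnx
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Load the model, infer symbolic shapes, and save the annotated model.
model = onnx.load("fs.onnx")
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
onnx.save(inferred, "fs_shaped.onnx")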

@ggslayer (Author)

Hi @stevenlix, thank you for your reply. I tried onnxruntime/python/tools/symbolic_shape_infer.py on my ONNX model, but got an error:

symbolic_shape_infer.py --input=./fs.onnx
input model: ./fs.onnx
Doing symbolic shape inference...
Traceback (most recent call last):
  File "/data/projects/onnxruntime_1.9.1/onnxruntime/python/tools/symbolic_shape_infer.py", line 2035, in <module>
    out_mp = SymbolicShapeInference.infer_shapes(onnx.load(args.input), args.int_max, args.auto_merge,
  File "/data/projects/onnxruntime_1.9.1/onnxruntime/python/tools/symbolic_shape_infer.py", line 1999, in infer_shapes
    all_shapes_inferred = symbolic_shape_inference._infer_impl()
  File "/data/projects/onnxruntime_1.9.1/onnxruntime/python/tools/symbolic_shape_infer.py", line 1833, in _infer_impl
    self.dispatcher_[node.op_type](node)
  File "/data/projects/onnxruntime_1.9.1/onnxruntime/python/tools/symbolic_shape_infer.py", line 1627, in _infer_Unsqueeze
    output_rank = len(input_shape) + len(axes)
TypeError: object of type 'NoneType' has no len()

Please give me some advice on this error, thank you!

@askhade (Contributor) commented Nov 29, 2021

Please run the ONNX checker to make sure the model is valid.
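
A minimal sketch of that check (assuming the model file is fs.onnx, as in the traceback above):

import onnx

# check_model raises onnx.checker.ValidationError if the model is malformed.
model = onnx.load("fs.onnx")
onnx.checker.check_model(model)
print("Model is structurally valid.")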

@weixsong commented Dec 19, 2021

The model is valid: ORT can load it and run inference as expected.

We want to convert the ONNX model to TensorRT; the error seems to come from the mask code:

import torch

def get_mask_from_lengths(lengths, max_len=None):
    # lengths: (batch,) tensor of sequence lengths; `device` is a global from the surrounding module
    batch_size = lengths.shape[0]
    if max_len is None:
        # Data-dependent scalar: the mask width is only known at run time
        max_len = torch.max(lengths)

    # ids[b, t] = t, broadcast over the batch
    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    # mask[b, t] is True at padding positions (t >= lengths[b])
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask

We also tried using torch2trt to export the TRT model directly, but encountered the following error:

Traceback (most recent call last):
  File "torch2trt_tutorial.py", line 207, in <module>
    model_trt = torch2trt(model, [speakers, texts, src_lens, duration_control])
  File "/export/anaconda3/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7.egg/torch2trt/torch2trt.py", line 553, in torch2trt
    outputs = module(*inputs)
  File "/export/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/export/users/songwei/workspace/FastSpeech/exp02_test/fastspeech/model/fastspeech2.py", line 35, in forward
    src_masks = get_mask_from_lengths(src_lens, max_src_len)
  File "/export/users/songwei/workspace/FastSpeech/exp02_test/fastspeech/utils/tools.py", line 64, in get_mask_from_lengths
    max_len = torch.max(lengths)
  File "/export/anaconda3/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7.egg/torch2trt/torch2trt.py", line 300, in wrapper
    converter["converter"](ctx)
  File "/export/anaconda3/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7.egg/torch2trt/converters/max.py", line 33, in convert_max
    __convert_max_reduce(ctx)
  File "/export/anaconda3/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7.egg/torch2trt/converters/max.py", line 21, in __convert_max_reduce
    output_val = ctx.method_return[0]
  File "/export/anaconda3/lib/python3.7/site-packages/torch2trt-0.3.0-py3.7.egg/torch2trt/torch2trt.py", line 291, in wrapper
    outputs = method(*args, **kwargs)
IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

This error also comes from the mask code; it seems the ONNX-to-TRT error and the torch2trt error stem from the same part of the model.

This dynamically generated mask is essential for our model because the decoder length is predicted dynamically.
How can we solve this issue?

It seems TRT does not support torch.arange?
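
One possible workaround (a sketch only, not verified against this model): avoid the data-dependent torch.max by passing max_len as a plain Python int fixed at export time, so the mask width is known when the graph is built. The function name below is hypothetical:

import torch

def get_mask_from_lengths_static(lengths, max_len):
    # lengths: (batch,) int tensor; max_len: plain Python int chosen at export time,
    # e.g. the longest sequence the deployed model must handle (an assumption here).
    batch_size = lengths.shape[0]
    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1)
    return ids >= lengths.unsqueeze(1).expand(-1, max_len)

The trade-off is that the mask (and everything shaped by it) is padded to the fixed max_len instead of the true per-batch maximum.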

@stale bot commented Apr 17, 2022

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

@stale bot added the stale label Apr 17, 2022