TensorRT 7.1.3 / Cuda11 / CuDNN 8 - supported yet? #4841

Closed
kzuiderveld opened this issue Aug 18, 2020 · 36 comments
Labels: ep:TensorRT issues related to TensorRT execution provider

@kzuiderveld commented Aug 18, 2020:

Hi,

I pulled the latest ONNX Runtime version and compiled it with the latest NVIDIA libraries (see title). The tests passed except for the Python test (which is expected).

Unfortunately, the TensorRT path gives significantly different results compared to the CPU path (I can tell the difference by visual inspection of the activation output).

Are there currently known compatibility issues with the latest TensorRT/CuDNN/Cuda libraries? If so, is there an ETA for a fix? If not, is there somebody I can work with to troubleshoot this further?

Karel

@snnn added the ep:TensorRT label on Aug 18, 2020
@snnn (Member) commented Aug 18, 2020:

Could you please provide more information? We plan to release ORT 1.5 with TensorRT 7.1, so it should work.

@kzuiderveld (Author) commented Aug 18, 2020:

ONNX Runtime version: a3c953
TensorRT 7.1.3, CuDNN 8.0, Cuda 11.0
Using a simple U-Net model (can share under NDA if needed, but I imagine others might have similar issues).

(attached image: tile4608_8704_channel4)

Attached is one of the output channels. Visual inspection shows the outputs are similar, but they differ enough in places to yield different results. The only change in the code is switching from the default provider to the TensorRT provider, so the differences originate inside ONNX Runtime or TensorRT. TensorRT fp16 and fp32 gave virtually identical results (both different from the CPU version).

Any suggestions on how to troubleshoot this most effectively?
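
For context, the provider comparison is essentially the following (a minimal sketch rather than the exact harness; the model path, input shape, and values are placeholders, not the real U-Net's):

```python
import numpy as np
import onnxruntime as ort

MODEL_PATH = "model.onnx"                               # placeholder; the real model is a U-Net
x = np.random.rand(1, 3, 512, 512).astype(np.float32)   # placeholder input shape

def run(providers):
    # Build a session restricted to the given execution providers and run one inference.
    sess = ort.InferenceSession(MODEL_PATH, providers=providers)
    input_name = sess.get_inputs()[0].name
    return sess.run(None, {input_name: x})[0]

cpu_out = run(["CPUExecutionProvider"])
trt_out = run(["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"])

# Worst-case element-wise disagreement between the two providers.
print("max abs diff:", np.max(np.abs(cpu_out - trt_out)))
```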

@snnn (Member) commented Aug 18, 2020:

@stevenlix Do you know why?

@jywu-msft (Member) commented:

We're not aware of a general systemic issue with ORT + the TensorRT EP.
If possible, sharing a minimal repro (rather than the full model) would be very helpful for debugging the issue further on our end.

@kzuiderveld (Author) commented Aug 31, 2020:

Hello,

I rebuilt against TensorRT 7.1.3.4 - I don't believe it is working correctly. There are lots of errors like
F:\onnxruntime\onnxruntime\test\providers\provider_test_utils.cc(155): error: The difference between expected[i] and output[i] is 0.41475188732147217, which exceeds threshold

which suggest that tests are failing. Comments?

I uploaded the output of onnxruntime_test_all.exe here: https://1drv.ms/u/s!AtdjyUUyTjd8_RzBG2s9B4RMYlt9?e=S7EVyA

Here are my installation notes:

  1. Downloaded and installed Quadro NVIDIA driver 452.06 (388,202 KB) - my system has a Quadro RTX4000.

  2. Downloaded and installed cuda_11.0.3_451.82_win10.
    Installed:

  3. Built the CUDA samples. 174 succeeded, 2 failed (MPI and Vulkan not installed on my system)
    Tried various examples, all ran successfully ("Turing" with compute capability 7.5)

  4. Downloaded cudnn-11.0-windows-x64-v8.0.2.39. Unzipped into f:\cudnn8. Header file identifies version 8.0.2

  5. Downloaded TensorRT-7.1.3.4.Windows10.x86_64.cuda-11.0.cudnn8.0. Unzipped into f:\TensorRT-7.1.3.4.

  6. Performed a "git pull" on ONNX Runtime at 10:02 MDT 8/31/2020. Last check-in: "Optimize MatmulGradient when B is 2D weight" (#4977, 98f7fdd).

  7. Attempted to build OnnxRunTime with TensorRT support:
    .\build.bat --use_tensorrt --tensorrt_home "f:\TensorRT-7.1.3.4" --cuda_home "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0" --cudnn_home "f:\cudnn-8.0.2\cuda" --parallel --build_shared_lib --config RelWithDebInfo --cmake_generator "Visual Studio 16 2019"

    Fails: 2020-08-31 10:08:52,352 Build [ERROR] - No version file found in CUDA install directory. Looked for C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.0\version.txt

    -> That's an ONNX Runtime issue; the CMake support needs to be modified for CUDA 11.0.

  8. Manually created c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\V11.0\version.txt with the string
    CUDA Version 11.0.194

  9. Compilation now proceeds. However, the tests give various error messages, including many like
    1: F:\onnxruntime\onnxruntime\test\providers\provider_test_utils.cc(155): error: The difference between expected[i] and output[i] is 0.41475188732147217, which exceeds threshold, where
    1: expected[i] evaluates to 0.63340294361114502,
    1: output[i] evaluates to 0.21865105628967285, and
    1: threshold evaluates to 0.004999999888241291.
    1: i:35730, provider_type: TensorrtExecutionProvider
    This tells me something is broken.

    Last screen after completion of compilation:
    6: [----------] 1 test from TestSessionOptions (0 ms total)
    6:
    6: [----------] Global test environment tear-down
    6: [==========] 1 test from 1 test suite ran. (0 ms total)
    6: [ PASSED ] 1 test.
    6/6 Test #6: onnxruntime_api_tests_without_env ...... Passed 0.11 sec

83% tests passed, 1 tests failed out of 6

Total Test time (real) = 1322.91 sec

The following tests FAILED:
1 - onnxruntime_test_all (Failed)
Errors while running CTest
Traceback (most recent call last):
File "F:\onnxruntime\tools\ci_build\build.py", line 1824, in
sys.exit(main())
File "F:\onnxruntime\tools\ci_build\build.py", line 1766, in main
run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs)
File "F:\onnxruntime\tools\ci_build\build.py", line 1221, in run_onnxruntime_tests
run_subprocess(ctest_cmd, cwd=cwd, dll_path=dll_path)
File "F:\onnxruntime\tools\ci_build\build.py", line 428, in run_subprocess
env=my_env, shell=shell)
File "C:\Program Files\WindowsApps\PythonSoftwareFoundation.Python.3.7_3.7.2544.0_x64__qbz5n2kfra8p0\lib\subprocess.py", line 512, in run
output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['C:\Program Files\CMake\bin\ctest.EXE', '--build-config', 'RelWithDebInfo', '--verbose']' returned non-zero exit status 8.

F:\onnxruntime>

@jywu-msft (Member) commented:

Thanks for the details.
re: step 7 of your notes - @snnn, do we need to expand on #4706? It seems the version.txt file may not exist on Windows.
re: the test failures - I glanced over your log output, and there are a lot of test failures (203 tests).
Our internal CI test runs show all the tests passing; currently I don't have a good explanation for why it's failing in your environment/hardware. We haven't tested on a Quadro RTX 4000 on Windows, but I wouldn't have anticipated issues with it.
If we can procure the hardware, we will try to do more testing on different Windows hardware to see if we can repro these issues.

https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=211587&view=logs&j=72b59a5c-f9fa-5437-adfc-3f855724a45f&t=706f6d43-7b36-5d52-2781-608600615bec

1: [----------] Global test environment tear-down
1: [==========] 2186 tests from 187 test suites ran. (477588 ms total)
1: [ PASSED ] 2168 tests.
1: [ SKIPPED ] 18 tests, listed below:
1: [ SKIPPED ] InferenceSessionTests.TestLenientShapeInferencing
1: [ SKIPPED ] DequantizeLinearOpTest.Per_Channel_Axis_Default
1: [ SKIPPED ] DequantizeLinearOpTest.Per_Channel_Axis_0
1: [ SKIPPED ] DequantizeLinearOpTest.Per_Channel_Axis_1
1: [ SKIPPED ] DequantizeLinearOpTest.Per_Channel_Neg_2
1: [ SKIPPED ] MathOpTest.Sign_BFloat16
1: [ SKIPPED ] LogSoftmaxOperator.InvalidAxis
1: [ SKIPPED ] SoftmaxOperator.InvalidAxis
1: [ SKIPPED ] ConvTest.Conv1D_Invalid_Input_Shape
1: [ SKIPPED ] ConvTest.Conv2D_Invalid_Input_Shape
1: [ SKIPPED ] TfIdfVectorizerTest.Int32_TF_onlyBigrams_Skip0_Empty_Dim1Fail
1: [ SKIPPED ] TfIdfVectorizerTest.Int32_TF_onlyBigrams_Skip0_Empty_Dim2
1: [ SKIPPED ] TfIdfVectorizerTest.Int32_TF_onlyBigrams_Skip01_Empty_Dim2
1: [ SKIPPED ] QuantizeLinearOpTest.Per_Channel_Axis_Default
1: [ SKIPPED ] QuantizeLinearOpTest.Per_Channel_Axis_0
1: [ SKIPPED ] QuantizeLinearOpTest.Per_Channel_Axis_neg
1: [ SKIPPED ] TensorOpTest.Unsqueeze_Duplicate
1: [ SKIPPED ] TensorOpTest.Unsqueeze_OutOfRange

@snnn (Member) commented Sep 2, 2020:

@jywu-msft Yes, I think we should.

@xkszltl (Contributor) commented Sep 2, 2020:

FYI, I also had a colleague report this kind of issue (i.e., mismatches far beyond floating-point error) with the TensorRT provider.

@jywu-msft (Member) commented:

FYI, I also had a colleague report this kind of issue (i.e., mismatches far beyond floating-point error) with the TensorRT provider.

For a specific model? Or is it the same case as here, where the unit tests fail in certain environments?
We need help getting a repro to investigate both cases.

@jywu-msft (Member) commented:

Can you re-test with the latest master?

@kzuiderveld (Author) commented:

@jywu-msft
TensorRT 7.1.3, CuDNN 8.0, Cuda 11.0 is still broken on Windows with the latest master.

@jywu-msft (Member) commented:

@jywu-msft
TensorRT 7.1.3, CuDNN 8.0, Cuda 11.0 is still broken on Windows with the latest master.

Can you please be more specific?
It's running fine in our CI/test environments.
It's difficult to debug without more information or a repro.

@kzuiderveld (Author) commented:

@jywu-msft

I was specific in the bug report I originally filed. Here is the output of a test run using the latest branch: https://1drv.ms/u/s!AtdjyUUyTjd8gP44fraVGsr0N4_tUQ?e=d3fSjl

Same issue remains: TensorRT does not give accurate results and is unusable. The tests confirm this.

@jywu-msft (Member) commented:

I examined your logfile. Yes, it does seem like a lot of tests are failing with incorrect results.
Would you happen to have the environment variable ORT_TENSORRT_ENGINE_CACHE_ENABLE set to 1?
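
For reference, the TensorRT EP reads ORT_TENSORRT_* settings from the process environment, so a quick illustrative check (a minimal sketch) shows whether it is in effect and turns the cache off for the current process:

```python
import os

# Check whether the experimental engine-cache setting is active in this process.
print("ORT_TENSORRT_ENGINE_CACHE_ENABLE =",
      os.environ.get("ORT_TENSORRT_ENGINE_CACHE_ENABLE", "<unset>"))

# Turn it off for this process before any InferenceSession is created.
os.environ["ORT_TENSORRT_ENGINE_CACHE_ENABLE"] = "0"
```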

@kzuiderveld (Author) commented:

@jywu-msft
Ding, ding, ding! Yes, that environment variable was enabled and set to 1. Setting it to zero fixed the issue.

@jywu-msft (Member) commented:

@jywu-msft
Ding, ding, ding! Yes, that environment variable was enabled and set to 1. Setting it to zero fixed the issue.

Thanks for confirming.
Yes, that feature is experimental and still has some kinks; we will fix it.
@stevenlix, @chilo-ms FYI

@kzuiderveld (Author) commented Sep 23, 2020 via email:

While the tests are successful now, the TensorRT path gives results that are significantly different from its CUDA and CPU counterparts. We're not there yet.

Karel

@jywu-msft (Member) commented:

Is this with a specific model?
Ideally, can you share a stripped-down test case that repros the issue?

@kzuiderveld (Author) commented:

I'm reopening this issue as the TensorRT results are still incorrect.

@kzuiderveld reopened this on Sep 24, 2020
@jywu-msft (Member) commented:

How do we gain access to the model, or to a stripped-down repro test case?
Maybe the latter would be easier if you don't want to share your model.

@kzuiderveld (Author) commented Sep 28, 2020 via email.

@stevenlix (Contributor) commented Oct 1, 2020:

It would be great if you could share your model so we can debug the issue. In the meantime, I am wondering if you could try ORT rel-1.4.0 (integrated with TRT 7.0) to see if the accuracy issue is also present there.

@jywu-msft (Member) commented:

It would be great if you could share your model so we can debug the issue. In the meantime, I am wondering if you could try ORT rel-1.4.0 (integrated with TRT 7.0) to see if the accuracy issue is also present there.

We started a thread offline about model sharing.

@stevenlix (Contributor) commented:

I tried the model with ORT + TRT 7.0 and there is no accuracy issue. It seems to be a regression in TRT 7.1 itself. We will work with NVIDIA on further investigation.

@stevenlix (Contributor) commented:

Update: we found the issue may be related to opset 11 support in TRT 7.1. When the model is converted to opset 10, the accuracy issue seems to go away. Could you convert your model back to opset 10 and try again?
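
One way to attempt the downgrade is ONNX's version converter; a rough sketch (it does not handle every op, and the file names here are placeholders):

```python
import onnx
from onnx import version_converter

model = onnx.load("unet_opset11.onnx")                 # placeholder path
downgraded = version_converter.convert_version(model, 10)
onnx.checker.check_model(downgraded)                   # sanity-check the converted graph
onnx.save(downgraded, "unet_opset10.onnx")
```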

@kzuiderveld (Author) commented:

Thanks for the update - wonderful to hear that something was indeed broken. I'll try again ASAP.

@kzuiderveld (Author) commented:

Our scientists say that converting to opset 10 is not possible without changing the U-Net model (hardcoding the padding function call's input values). That would be our plan B.
Is there an ETA for a fix in TRT?

@jywu-msft (Member) commented Oct 6, 2020:

The issue is still under investigation by NVIDIA.
Note that this accuracy issue doesn't appear to be widespread. So far we've found only three models that exhibit it: yours, a Microsoft-internal model, and the YOLOv4 model from the ONNX model zoo.
All three are opset 11 models.
When we downgraded the internal model to opset 10, the problem went away, which is why we think it's an issue with TensorRT's opset 11 handling; we haven't yet pinpointed which ops specifically.
We don't currently have an ETA. We will update once NVIDIA confirms and root-causes the issue.

@stevenlix (Contributor) commented:

The issue has been confirmed by NVIDIA; it is caused by the opset 11 Resize operator. There is a PR (#5442) in ORT that includes the fix. Please try this PR, or pull master after it's merged. I've tested the model using random data and the accuracy issue seems to be gone.
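
To check whether a given model falls into the affected category, an illustrative check (the path is a placeholder) is to list its opset imports and any Resize nodes:

```python
import onnx

model = onnx.load("model.onnx")  # placeholder path
opsets = {imp.domain or "ai.onnx": imp.version for imp in model.opset_import}
resize_nodes = [n.name or n.output[0] for n in model.graph.node if n.op_type == "Resize"]

print("opset imports:", opsets)
print("Resize nodes:", resize_nodes if resize_nodes else "none")
```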

@kzuiderveld (Author) commented:

That's awesome news. I'll wait until #5442 is merged into master and then test whether the issue has been resolved.

@kzuiderveld (Author) commented:

The good news: the issue I reported has been fixed. TensorRT yields results similar to CUDA and CPU.

The bad news: the TensorRT implementation is about 2.6x slower than the CUDA implementation (Quadro RTX 4000 using fp16). When I tried TensorRT earlier this year, I saw a significant performance improvement over CUDA that has now completely vanished.

What could the root cause of this performance degradation be? Does the fix disable part of the TensorRT path and fall back to the CPU provider (or perhaps CUDA)?
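
One way to check for such a fallback (a sketch, with a placeholder model path) is to enable verbose session logging, which reports how graph nodes are partitioned across execution providers:

```python
import onnxruntime as ort

so = ort.SessionOptions()
so.log_severity_level = 0  # VERBOSE: logs node-to-execution-provider placement

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    sess_options=so,
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"],
)
print("registered providers:", sess.get_providers())
```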

@kzuiderveld (Author) commented:

@stevenlix Any update on a fix from NVIDIA? I noticed that TRT 7.2.1 arrived...

@RamanHacks commented:

Does ONNX Runtime compile with TRT 7.2.1?

@kzuiderveld (Author) commented:

@stevenlix and others
I compiled the latest ONNX Runtime with CUDA 11.0, cuDNN 8.0.5, and TensorRT 7.2.1.6. The results with our model are now correct and TensorRT (as expected) is now running at full speed.

This issue has now been resolved.

@jywu-msft (Member) commented Nov 18, 2020:

@stevenlix and others
I compiled the latest ONNX Runtime with CUDA 11.0, cuDNN 8.0.5, and TensorRT 7.2.1.6. The results with our model are now correct and TensorRT (as expected) is now running at full speed.

This issue has now been resolved.

Great, thanks for the update!
BTW, yes, it does build and run with TRT 7.2.1, but we have not officially integrated the onnx-tensorrt parser from the TRT 7.2 branch yet.
