diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc
index 807bb13aea85c..d5e558d10d2a7 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include "core/platform/env_var_utils.h"
 #include "gtest/gtest.h"
 #include "test/common/tensor_op_test_utils.h"
 #include "test/common/cuda_op_test_utils.h"
@@ -1718,6 +1719,13 @@ TEST(AttentionTest, AttentionMaskIndexOutOfRange) {
 #if !defined(__wasm__)
 // TODO: fix in web assembly
 TEST(AttentionTest, AttentionPastState_dynamic) {
+  // ORT enables TF32 in GEMM for A100. TF32 will cause precision loss and fail this test.
+  // Do not run this test unless TF32 is disabled explicitly.
+  if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault("NVIDIA_TF32_OVERRIDE", 1) != 0) {
+    GTEST_SKIP() << "Skipping AttentionPastState_dynamic on A100 since TF32 is enabled";
+    return;
+  }
+
   // create rand inputs
   RandomValueGenerator random{};
 
@@ -1865,6 +1873,13 @@ static void RunModelWithRandomInput(
     std::vector<int32_t>& mask_index_data,
     std::string& onnx_model,
     bool is_float16) {
+  // ORT enables TF32 in GEMM for A100. TF32 will cause precision loss and fail this test.
+  // Do not run this test unless TF32 is disabled explicitly.
+  if (HasCudaEnvironment(800) && ParseEnvironmentVariableWithDefault("NVIDIA_TF32_OVERRIDE", 1) != 0) {
+    GTEST_SKIP() << "Skipping RunModelWithRandomInput on A100 since TF32 is enabled";
+    return;
+  }
+
   RandomValueGenerator random{234};
 
   constexpr int hidden_size = 768;