Update BiasGelu fusion and related ops (#23518)
### Description
(1) Update the BiasGelu fusion to support ONNX Gelu-20.

Since ONNX Gelu-20 supports float/double/bfloat16/float16, the related ops are updated to support these data types in the CUDA and ROCm execution providers:
(2) Add double support for the Gelu/FastGelu ops in the CUDA and ROCm execution providers.
(3) Add BFloat16 support for the Gelu ops in the CUDA execution provider.
(4) Add unit tests.
(5) Update the operator documentation.

The Gelu variants that the fusion targets are sketched right after this list.
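For reference (a sketch added here, not text from the original PR description), the fusion rewrites an `Add(X, bias)` followed by a Gelu activation into a single `com.microsoft.BiasGelu` or `com.microsoft.FastGelu` node. The two Gelu variants involved are the exact (erf) form and the tanh approximation:

$$\mathrm{Gelu}(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right),\qquad
\mathrm{FastGelu}(x) \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right).$$

ONNX Gelu-20 computes the tanh form when its `approximate` attribute is `"tanh"`, so the fusion maps that case to `FastGelu` and the exact case to `BiasGelu` (see the `bias_gelu_fusion.cc` change below).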

### Motivation and Context
#23491
tianleiwu authored Jan 31, 2025
1 parent 4dde74a commit 0bb4ea6
Showing 18 changed files with 193 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/ContribOperators.md
@@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Type Constraints

<dl>
-<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
+<dt><tt>T</tt> : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float or half tensors.</dd>
</dl>

4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -912,11 +912,11 @@ Do not modify directly.*
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|DynamicTimeWarping|*in* input:**F**<br> *out* output:**I**|1+|**F** = tensor(float)<br/> **I** = tensor(int32)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
-|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
+|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -30,6 +30,7 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

using namespace ONNX_NAMESPACE;

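For context, a rough sketch (an assumption based on the typical registration pattern in this file, not text quoted from the diff) of what `REGISTER_KERNEL_TYPED(T)` expands to — which is why the single added line is enough to register the double kernel:

```cpp
// Approximate shape of the registration macro used above (illustrative only;
// the real definition lives earlier in fast_gelu.cc).
#define REGISTER_KERNEL_TYPED(T)                                  \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                  \
      FastGelu,                                                   \
      kMSDomain,                                                  \
      1,                                                          \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
      FastGelu<T>);
```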
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -25,10 +25,14 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);

class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
@@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
-class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
@@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
@@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
-BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/elementwise.h
@@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp<ElementwiseParams<T>> {
}

ELEMENTWISE_FWD_DECL(FastGeLU, float);
ELEMENTWISE_FWD_DECL(FastGeLU, double);
ELEMENTWISE_FWD_DECL(FastGeLU, half);
ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16);

ELEMENTWISE_FWD_DECL(GeLU, float);
ELEMENTWISE_FWD_DECL(GeLU, double);
ELEMENTWISE_FWD_DECL(GeLU, half);
ELEMENTWISE_FWD_DECL(GeLU, BFloat16);

@@ -4,5 +4,6 @@
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"

ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
@@ -3,6 +3,7 @@

#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"

ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -11,10 +11,13 @@ namespace contrib {
namespace rocm {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
@@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
@@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
@@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
// TransposedMatMul is still here for backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(0, "X", "input tensor", "T")
.Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
.Output(0, "Y", "output tensor", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
// fastgelu(x) =
13 changes: 11 additions & 2 deletions onnxruntime/core/optimizer/bias_gelu_fusion.cc
@@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}

const Node& next_node = (*next_node_itr);
-if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||

bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain);
if (!(is_onnx_gelu ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
continue;
@@ -72,14 +75,20 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}

bool is_approximate = is_fast_gelu;
if (is_onnx_gelu) {
const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate");
is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh");
}

if (graph.NodeProducesGraphOutput(node)) {
continue;
}

Node& add_node = node;
Node& gelu_node = const_cast<Node&>(next_node);
std::string op_type = "BiasGelu";
-if (is_fast_gelu) op_type = "FastGelu";
+if (is_approximate) op_type = "FastGelu";

Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
op_type,
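A minimal standalone sketch (not code from this commit; the helper name and signature are hypothetical) of the op-type selection that the new fusion code above performs:

```cpp
// Sketch of the fusion's new op-type choice (hypothetical helper, not ORT code).
// ONNX Gelu-20 carries an "approximate" attribute; only the "tanh" variant is
// equivalent to the FastGelu contrib op, so every other case fuses to BiasGelu.
#include <string>

std::string FusedOpTypeFor(bool is_onnx_gelu, bool is_fast_gelu,
                           const std::string* approximate_attr) {
  bool is_approximate = is_fast_gelu;  // com.microsoft.FastGelu is always the tanh form
  if (is_onnx_gelu) {
    is_approximate = approximate_attr != nullptr && *approximate_attr == "tanh";
  }
  return is_approximate ? "FastGelu" : "BiasGelu";
}
```

For example, Add + Gelu-20 with `approximate="tanh"` yields "FastGelu", while Add + Gelu-20 in the default exact mode yields "BiasGelu", matching the two new graph-transform tests further down.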
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/gelu.cc
@@ -23,6 +23,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
REGISTER_KERNEL_TYPED(double)

template <typename T>
@@ -80,6 +81,7 @@ namespace contrib::cuda {

REGISTER_CONTRIB_KERNEL_TYPED(float)
REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(BFloat16)
REGISTER_CONTRIB_KERNEL_TYPED(double)

#undef REGISTER_CONTRIB_KERNEL_TYPED
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/tensor/gelu_impl.cu
@@ -40,6 +40,7 @@ Status LaunchGeluKernel(

SPECIALIZED_GELU_IMPL(float);
SPECIALIZED_GELU_IMPL(half);
SPECIALIZED_GELU_IMPL(BFloat16);
SPECIALIZED_GELU_IMPL(double);

#undef SPECIALIZED_GELU_IMPL
40 changes: 39 additions & 1 deletion onnxruntime/test/contrib_ops/fastgelu_op_test.cc
@@ -389,7 +389,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
#ifdef USE_CUDA
-int min_cuda_architecture = 530;
+int min_cuda_architecture = 800;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
return;
@@ -440,5 +440,43 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
}
#endif

// CUDA and ROCm only for double type.
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_Double) {
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);

int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;

std::vector<double> X = {
0.8, -0.5, 0.0, 1.0,
0.5, 0.2, 0.3, -0.6};

std::vector<double> B = {
-0.5, 0.6, 1.2, 2.1};

std::vector<double> Y = {
0.185371, 0.053983, 1.061703, 3.097373,
0.000000, 0.630432, 1.399572, 1.399572};

std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;

tester.AddInput<double>("X", input_dims, X);
tester.AddInput<double>("bias", bias_dims, B);
tester.AddOutput<double>("Y", output_dims, Y);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

} // namespace test
} // namespace onnxruntime
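As a quick sanity check (not part of the commit), the first expected value of the double test above follows from the tanh approximation given earlier:

$$x + b = 0.8 + (-0.5) = 0.3,\qquad \frac{0.3}{2}\left(1 + \tanh\!\bigl(0.79788\,(0.3 + 0.044715\cdot 0.3^{3})\bigr)\right) \approx 0.185371,$$

which matches `Y[0]`; the remaining expected values follow the same formula elementwise on `X + bias`.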
40 changes: 40 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4781,6 +4781,46 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
}

TEST_F(GraphTransformationTests, BiasOnnxGeluTest) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_gelu_fusion.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
}

TEST_F(GraphTransformationTests, BiasOnnxFastGeluTest) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_fast_gelu_fusion.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 0);
ASSERT_TRUE(op_to_count["Gelu"] == 0);
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 0);
}

// BiasGelu allows input switching based on input dimensions.
// This test validates the input edges are plugged correct in the optimized graph.
TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
Binary file modified onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion.onnx
