Skip to content

Commit

Permalink
Add GroupNorm forward operation (#2623)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyeonghwanryu authored Jan 30, 2024
1 parent b9f6ef2 commit c7fa6f7
Show file tree
Hide file tree
Showing 36 changed files with 1,996 additions and 103 deletions.
1 change: 1 addition & 0 deletions docs/apireference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ API Reference
layernorm
sum
argmax
groupnorm

20 changes: 20 additions & 0 deletions docs/groupnorm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

GroupNorm Layer(experimental)
=============================

The groupnorm types and functions.
It splits input channels into num_group groups and do normalize for each group.

To enable this, define MIOPEN_BETA_API before including miopen.h.


miopenNormMode_t
-----------------------

.. doxygenenum:: miopenNormMode_t

miopenGroupNormForward
----------------------------------

.. doxygenfunction:: miopenGroupNormForward

4 changes: 2 additions & 2 deletions docs/layernorm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ The layernorm types and functions.
To enable this, define MIOPEN_BETA_API before including miopen.h.


miopenLayerNormMode_t
miopenNormMode_t
-----------------------

.. doxygenenum:: miopenLayerNormMode_t
.. doxygenenum:: miopenNormMode_t

miopenLayerNormForward
----------------------------------
Expand Down
3 changes: 2 additions & 1 deletion driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"pool[fp16], lrn[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
"tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], "
"argmax[bfp16|fp16]\n");
"argmax[bfp16|fp16], groupnorm[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand All @@ -175,6 +175,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" &&
arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" &&
arg != "sumbfp16" && arg != "argmax" && arg != "argmaxfp16" && arg != "argmaxbfp16" &&
arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" &&
arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Expand Down
Loading

0 comments on commit c7fa6f7

Please sign in to comment.