diff --git a/docs/.doxygen/Doxyfile b/docs/.doxygen/Doxyfile index d8b8d767f0..79eec5c32e 100644 --- a/docs/.doxygen/Doxyfile +++ b/docs/.doxygen/Doxyfile @@ -2065,7 +2065,8 @@ INCLUDE_FILE_PATTERNS = PREDEFINED = __attribute__(x)= \ __inline= \ - MIOPEN_EXPORT + MIOPEN_EXPORT \ + MIOPEN_BETA_API # If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this # tag can be used to specify a list of macro names that should be expanded. The diff --git a/docs/apireference.rst b/docs/apireference.rst index aa27afa459..ac56c2ee04 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -21,4 +21,5 @@ API Reference loss dropout reduction + layernorm diff --git a/docs/layernorm.rst b/docs/layernorm.rst new file mode 100644 index 0000000000..1d480df7d6 --- /dev/null +++ b/docs/layernorm.rst @@ -0,0 +1,17 @@ + +Layernorm Layer +=================== + +The layernorm types and functions. + + +miopenLayerNormMode_t +----------------------- + +.. doxygenenum:: miopenLayerNormMode_t + +miopenLayerNormForward +---------------------------------- + +.. doxygenfunction:: miopenLayerNormForward + diff --git a/driver/driver.hpp b/driver/driver.hpp index fd5774ae17..d8e5352255 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -150,11 +150,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) printf("Supported Base Arguments: conv[fp16|int8|bfp16|fp8|bfp8], CBAInfer[fp16], " "pool[fp16], lrn[fp16], " "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], " - "tensorop[fp16], reduce[fp16,fp64]" -#ifdef MIOPEN_BETA_API - ", layernorm[bf16, fp16, fp32]" -#endif - "\n"); + "tensorop[fp16], reduce[fp16,fp64], layernorm[bfp16, fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -175,11 +171,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "bnormfp16" && arg != "rnn" && arg != "rnnfp16" && arg != "rnn_seq" && arg != "rnn_seqfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" && arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" && - arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && -#ifdef MIOPEN_BETA_API - arg != "layernorm" && arg != "layernormfp16" && arg != "layernormbfp16" && -#endif - arg != "--version") + arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" && + arg != "layernormfp16" && arg != "layernormbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/layernorm_driver.hpp b/driver/layernorm_driver.hpp index e43f3f2f37..dbdda790b5 100644 --- a/driver/layernorm_driver.hpp +++ b/driver/layernorm_driver.hpp @@ -24,7 +24,6 @@ * *******************************************************************************/ #include -#ifdef MIOPEN_BETA_API #ifndef GUARD_MIOPEN_LAYERNORM_DRIVER_HPP #define GUARD_MIOPEN_LAYERNORM_DRIVER_HPP @@ -112,8 +111,8 @@ class LayerNormDriver : public Driver std::vector weight; std::vector bias; std::vector out; - std::vector mean; - std::vector rstd; + std::vector mean; + std::vector rstd; std::vector outhost; std::vector meanhost; std::vector rstdhost; @@ -164,7 +163,7 @@ int LayerNormDriver::GetandSetData() eps = static_cast(inflags.GetValueDouble("eps")); mode = miopenLayerNormMode_t(inflags.GetValueInt("mode")); - return (0); + return 0; } template @@ -200,24 +199,31 @@ std::vector LayerNormDriver::GetInputTensorLengthsFromCmdLine() int in_h = inflags.GetValueInt("in_h"); int in_d = inflags.GetValueInt("in_d"); - if(in_h != 0) + if((in_n != 0) && (in_c != 0) && (in_d != 0) && (in_h != 0) && (in_w != 0)) { - if(in_d != 0) - { - dim_size = 5; - return std::vector({in_n, in_c, in_d, in_h, in_w}); - } - else - { - dim_size = 4; - return std::vector({in_n, in_c, in_h, in_w}); - } + dim_size = 5; + return std::vector({in_n, in_c, in_d, in_h, in_w}); } - else + else if((in_n != 0) && (in_c != 0) && (in_h != 0) && (in_w != 0)) + { + dim_size = 4; + return std::vector({in_n, in_c, in_h, in_w}); + } + else if((in_n != 0) && (in_c != 0) && (in_w != 0)) { dim_size = 3; return std::vector({in_n, in_c, in_w}); } + else if((in_n != 0) && (in_w != 0)) + { + dim_size = 2; + return std::vector({in_n, in_w}); + } + else + { + std::cout << "Error Input Tensor Lengths\n" << std::endl; + return std::vector({0}); + } } template @@ -230,27 +236,25 @@ int LayerNormDriver::AllocateBuffersAndCopy() size_t mean_sz = GetTensorSize(meanDesc); size_t rstd_sz = GetTensorSize(rstdDesc); - // MIOPEN_BACKEND_HIP uint32_t ctx = 0; in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); weight_dev = std::unique_ptr(new GPUMem(ctx, weight_sz, sizeof(Tgpu))); bias_dev = std::unique_ptr(new GPUMem(ctx, bias_sz, sizeof(Tgpu))); out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); - mean_dev = std::unique_ptr(new GPUMem(ctx, mean_sz, sizeof(Tgpu))); - rstd_dev = std::unique_ptr(new GPUMem(ctx, rstd_sz, sizeof(Tgpu))); + mean_dev = std::unique_ptr(new GPUMem(ctx, mean_sz, sizeof(Tref))); + rstd_dev = std::unique_ptr(new GPUMem(ctx, rstd_sz, sizeof(Tref))); in = std::vector(in_sz, static_cast(0)); weight = std::vector(weight_sz, static_cast(0)); bias = std::vector(bias_sz, static_cast(0)); out = std::vector(out_sz, static_cast(0)); - mean = std::vector(mean_sz, static_cast(0)); - rstd = std::vector(rstd_sz, static_cast(0)); + mean = std::vector(mean_sz, static_cast(0)); + rstd = std::vector(rstd_sz, static_cast(0)); outhost = std::vector(out_sz, static_cast(0)); meanhost = std::vector(mean_sz, static_cast(0)); rstdhost = std::vector(rstd_sz, static_cast(0)); - // MIOPEN_BACKEND_HIP int status; for(int i = 0; i < in_sz; i++) @@ -261,22 +265,28 @@ int LayerNormDriver::AllocateBuffersAndCopy() for(int i = 0; i < weight_sz; i++) { - weight[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + if(mode == MIOPEN_ELEMENTWISE_AFFINE) + weight[i] = static_cast(1); + else + weight[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); } - status = weight_dev->ToGPU(q, weight.data()); + status |= weight_dev->ToGPU(q, weight.data()); for(int i = 0; i < bias_sz; i++) { - bias[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + if(mode == MIOPEN_ELEMENTWISE_AFFINE) + bias[i] = static_cast(0); + else + bias[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); } - status = bias_dev->ToGPU(q, bias.data()); + status |= bias_dev->ToGPU(q, bias.data()); status |= out_dev->ToGPU(q, out.data()); status |= mean_dev->ToGPU(q, mean.data()); status |= rstd_dev->ToGPU(q, rstd.data()); - if(status != CL_SUCCESS) - printf("Error copying data to GPU\n"); + if(status != 0) + std::cout << "Error copying data to GPU\n" << std::endl; return miopenStatusSuccess; } @@ -390,6 +400,7 @@ int LayerNormDriver::VerifyForward() if(!std::isfinite(error) || error > tolerance) { std::cout << "Forward LayerNorm FAILED: " << error << std::endl; + return EC_VerifyFwd; } else { @@ -400,6 +411,7 @@ int LayerNormDriver::VerifyForward() if(!std::isfinite(meanerror) || meanerror > tolerance) { std::cout << "Forward LayerNorm mean FAILED: " << meanerror << std::endl; + return EC_VerifyFwd; } else { @@ -410,6 +422,7 @@ int LayerNormDriver::VerifyForward() if(!std::isfinite(rstderror) || rstderror > tolerance) { std::cout << "Forward LayerNorm rstd FAILED: " << rstderror << std::endl; + return EC_VerifyFwd; } else { @@ -425,5 +438,4 @@ int LayerNormDriver::VerifyBackward() return miopenStatusSuccess; } -#endif // GUARD_MIOPEN_SOFTMAX_DRIVER_HPP -#endif +#endif // GUARD_MIOPEN_LAYERNORM_DRIVER_HPP diff --git a/driver/main.cpp b/driver/main.cpp index 79e52e5e38..c4aa25c7e8 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -43,9 +43,7 @@ #include "reduce_driver.hpp" #include #include -#ifdef MIOPEN_BETA_API #include "layernorm_driver.hpp" -#endif int main(int argc, char* argv[]) { @@ -199,7 +197,6 @@ int main(int argc, char* argv[]) { drv = new ReduceDriver(); } -#ifdef MIOPEN_BETA_API else if(base_arg == "layernorm") { drv = new LayerNormDriver(); @@ -212,7 +209,6 @@ int main(int argc, char* argv[]) { drv = new LayerNormDriver(); } -#endif else { printf("Incorrect BaseArg\n"); diff --git a/driver/mloLayerNormHost.hpp b/driver/mloLayerNormHost.hpp index 5c504f8068..2350487144 100644 --- a/driver/mloLayerNormHost.hpp +++ b/driver/mloLayerNormHost.hpp @@ -23,7 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#ifdef MIOPEN_BETA_API #ifndef MLO_LAYERNORMHOST_H_ #define MLO_LAYERNORMHOST_H_ @@ -46,15 +45,13 @@ int32_t mloLayerNormForwardRunHost(miopenTensorDescriptor_t inputDesc, auto dims = miopen::deref(inputDesc).GetLengths(); size_t outer_size = 1; size_t inner_size = 1; - size_t i = 0; - for(; i < normalized_dim; i++) - { - outer_size *= dims[i]; - } - for(; i < dims.size(); i++) + for(size_t i = 0ULL; i < dims.size(); ++i) { - inner_size *= dims[i]; + if(i < normalized_dim) + outer_size *= dims[i]; + else + inner_size *= dims[i]; } int32_t ret = 0; @@ -63,7 +60,7 @@ int32_t mloLayerNormForwardRunHost(miopenTensorDescriptor_t inputDesc, { Tcheck pmean = 0.0f; Tcheck pvar = 0.0f; - for(i = 0; i < inner_size; i++) + for(int32_t i = 0; i < inner_size; i++) { Tcheck tmp = static_cast(input[o * inner_size + i]); pmean += tmp; @@ -77,10 +74,10 @@ int32_t mloLayerNormForwardRunHost(miopenTensorDescriptor_t inputDesc, meanhost[o] = pmean; rstdhost[o] = prstd; - for(i = 0; i < inner_size; i++) + for(int32_t i = 0; i < inner_size; i++) { - Tcheck pweight = mode ? 1 : static_cast(weight[i]); - Tcheck pbias = mode ? 0 : static_cast(bias[i]); + Tcheck pweight = mode ? static_cast(weight[i]) : 1; + Tcheck pbias = mode ? static_cast(bias[i]) : 0; outputhost[o * inner_size + i] = (static_cast(input[o * inner_size + i]) - pmean) * prstd * pweight + pbias; } @@ -88,4 +85,3 @@ int32_t mloLayerNormForwardRunHost(miopenTensorDescriptor_t inputDesc, return ret; } #endif -#endif diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 9c47d62e81..9f091c835d 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -464,8 +464,8 @@ typedef enum } miopenLRNMode_t; #ifdef MIOPEN_BETA_API /*! @ingroup layernorm - * @enum miopenLayerNormAlgorithm_t - * LayerNorm implementation algorithms + * @enum miopenLayerNormMode_t + * LayerNorm mode */ typedef enum { @@ -2490,8 +2490,6 @@ MIOPEN_EXPORT miopenStatus_t miopenDestroyLRNDescriptor(miopenLRNDescriptor_t lr * @{ */ /*! @brief Execute a layernorm forward layer - * - * This API only implements the LAYERNORM_MODE_CHANNEL in LAYERNORM_ACCURATE path. * * @param handle MIOpen handle (input) * @param mode LayerNorm mode (input) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 77f835c279..41f134e239 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -120,17 +120,18 @@ set( MIOpen_Source find_db.cpp fused_api.cpp fusion.cpp - fusion/problem_description.cpp + fusion/problem_description.cpp generic_search.cpp handle_api.cpp invoker_cache.cpp kernel_build_params.cpp kernel_warnings.cpp + layernorm_api.cpp load_file.cpp lock_file.cpp logger.cpp - layernorm_api.cpp lrn_api.cpp + norm/problem_description.cpp op_args.cpp operator.cpp performance_config.cpp @@ -238,6 +239,9 @@ set( MIOpen_Source solver/gemm_bwd.cpp solver/gemm_common.cpp solver/gemm_wrw.cpp + solver/norm/forward_layernorm.cpp + solver/norm/forward_layernorm2d_ck.cpp + solver/norm/forward_layernorm4d_ck.cpp solver/pooling/forward2d.cpp solver/pooling/forwardNaive.cpp solver/pooling/forwardNd.cpp @@ -537,6 +541,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN list(APPEND MIOpen_Source activ.cpp kernel_cache.cpp + layer_norm.cpp lrn.cpp mlo_dir_conv.cpp exec_utils.cpp @@ -557,7 +562,6 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN hip/hip_build_utils.cpp hip/batched_transpose_sol.cpp hip/general_tensor_reorder_sol.cpp - layer_norm.cpp pooling.cpp ocl/fusionopconvocl.cpp ocl/fusionopbiasbnactivocl.cpp diff --git a/src/include/miopen/kernel_build_params.hpp b/src/include/miopen/kernel_build_params.hpp index 987345717e..946dcac597 100644 --- a/src/include/miopen/kernel_build_params.hpp +++ b/src/include/miopen/kernel_build_params.hpp @@ -142,6 +142,11 @@ struct GcnAsm { static std::string Generate(const std::vector& options); }; + +struct HIP +{ + static std::string Generate(const std::vector& options); +}; } // namespace kbp } // namespace miopen diff --git a/src/include/miopen/layernorm.hpp b/src/include/miopen/layernorm.hpp index f897e79eea..64a3ea8339 100644 --- a/src/include/miopen/layernorm.hpp +++ b/src/include/miopen/layernorm.hpp @@ -24,7 +24,6 @@ * *******************************************************************************/ #include -#ifdef MIOPEN_BETA_API #ifndef MIOPEN_LAYERNORM_HPP_ #define MIOPEN_LAYERNORM_HPP_ @@ -35,7 +34,7 @@ namespace miopen { struct Handle; struct TensorDescriptor; -miopenStatus_t LayerNormForward(const Handle& handle, +miopenStatus_t LayerNormForward(Handle& handle, const TensorDescriptor& xDesc, ConstData_t x, const TensorDescriptor& weightDesc, @@ -54,4 +53,3 @@ miopenStatus_t LayerNormForward(const Handle& handle, } // namespace miopen #endif // _MIOPEN_LAYERNORM_HPP_ -#endif diff --git a/src/include/miopen/norm/invoke_params.hpp b/src/include/miopen/norm/invoke_params.hpp new file mode 100644 index 0000000000..de6abd8c7a --- /dev/null +++ b/src/include/miopen/norm/invoke_params.hpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +namespace miopen { +namespace norm { + +struct InvokeParams : public miopen::InvokeParams +{ + InvokeParams() = default; + + const TensorDescriptor* xDesc = nullptr; + + ConstData_t x = nullptr; + ConstData_t weight = nullptr; + ConstData_t bias = nullptr; + Data_t y = nullptr; + Data_t mean = nullptr; + Data_t rstd = nullptr; + float epsilon = 0; + int32_t normalized_dim = 0; + miopenLayerNormMode_t mode = MIOPEN_ELEMENTWISE_AFFINE; + + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace norm + +} // namespace miopen diff --git a/src/include/miopen/norm/problem_description.hpp b/src/include/miopen/norm/problem_description.hpp new file mode 100644 index 0000000000..ebc0f657c8 --- /dev/null +++ b/src/include/miopen/norm/problem_description.hpp @@ -0,0 +1,148 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace norm { + +struct ProblemDescription : ProblemDescriptionBase +{ + ProblemDescription(miopenLayerNormMode_t mode_, + const TensorDescriptor& xDesc_, + const TensorDescriptor& weightDesc_, + const TensorDescriptor& biasDesc_, + const TensorDescriptor& yDesc_, + const TensorDescriptor& meanDesc_, + const TensorDescriptor& rstdDesc_, + float epsilon_, + int32_t normalized_dim_) + : mode(mode_), + xDesc(xDesc_), + weightDesc(weightDesc_), + biasDesc(biasDesc_), + yDesc(yDesc_), + meanDesc(meanDesc_), + rstdDesc(rstdDesc_), + epsilon(epsilon_), + normalized_dim(normalized_dim_) + { + } + + miopenLayerNormMode_t GetMode() const { return mode; } + const TensorDescriptor& GetXDesc() const { return xDesc; } + const TensorDescriptor& GetWeightDesc() const { return weightDesc; } + const TensorDescriptor& GetBiasDesc() const { return biasDesc; } + const TensorDescriptor& GetYDesc() const { return yDesc; } + const TensorDescriptor& GetMeanDesc() const { return meanDesc; } + const TensorDescriptor& GetRstdDesc() const { return rstdDesc; } + float GetEpsilon() const { return epsilon; } + int32_t GetNormalizedDim() const { return normalized_dim; } + + bool IsSameType() const + { + if(xDesc.GetType() != yDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, "LayerNormForward: Tensor types do not match."); + } + return true; + } + + bool IsSameLength() const + { + if(xDesc.GetLengths() != yDesc.GetLengths()) + { + MIOPEN_THROW(miopenStatusBadParm, + "LayerNormForward: Tensor dimension lengths do not match."); + } + return true; + } + + bool IsRightNormDim() const + { + if((normalized_dim < 0) || (normalized_dim > xDesc.GetLengths().size())) + { + MIOPEN_THROW( + miopenStatusBadParm, + "LayerNormForward: normalized dim is greater than 0 and less than or equal " + "Tensor dimension length."); + } + return true; + } + + bool IsAllPacked() const + { + if(!(xDesc.IsPacked() && weightDesc.IsPacked() && biasDesc.IsPacked() && yDesc.IsPacked() && + meanDesc.IsPacked() && rstdDesc.IsPacked())) + { + MIOPEN_THROW(miopenStatusBadParm, "LayerNormForward: Unpacked tensors not supported."); + } + return true; + } + + bool IsLargeSize() const + { + auto dims = xDesc.GetLengths(); + + size_t outer_size = 1; + for(size_t i = 0; i < normalized_dim; i++) + { + outer_size *= dims[i]; + } + + return (outer_size > 32); + } + + NetworkConfig MakeNetworkConfig() const override; + +private: + miopenLayerNormMode_t mode; + TensorDescriptor xDesc; + TensorDescriptor weightDesc; + TensorDescriptor biasDesc; + TensorDescriptor yDesc; + TensorDescriptor meanDesc; + TensorDescriptor rstdDesc; + + float epsilon; + int32_t normalized_dim; + + NetworkConfig MakeForwardNetworkConfig() const; +}; + +} // namespace norm + +} // namespace miopen diff --git a/src/include/miopen/norm/solvers.hpp b/src/include/miopen/norm/solvers.hpp new file mode 100644 index 0000000000..cc2131aedb --- /dev/null +++ b/src/include/miopen/norm/solvers.hpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace norm { + +using NormalizationSolver = + NonTunableSolverBase; + +struct LayernormForward final : NormalizationSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const override; +}; + +struct Layernorm2DCKForward final : NormalizationSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const override; +}; + +struct Layernorm4DCKForward final : NormalizationSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const override; +}; + +} // namespace norm + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index f0ce465a11..477053ed20 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -52,6 +52,7 @@ enum class Primitive Bias, Fusion, Pooling, + Normalization }; struct Id diff --git a/src/kernel_build_params.cpp b/src/kernel_build_params.cpp index f0f36361ed..10db70380e 100644 --- a/src/kernel_build_params.cpp +++ b/src/kernel_build_params.cpp @@ -76,4 +76,9 @@ std::string kbp::GcnAsm::Generate(const std::vector& optio return GenerateDefines(options, "Wa,-defsym,"); } +std::string kbp::HIP::Generate(const std::vector& options) +{ + return GenerateDefines(options, "D"); +} + } // namespace miopen diff --git a/src/kernels/MIOpenLayerNorm.cpp b/src/kernels/MIOpenLayerNorm.cpp index 520ca01da9..98e1b56d56 100644 --- a/src/kernels/MIOpenLayerNorm.cpp +++ b/src/kernels/MIOpenLayerNorm.cpp @@ -23,8 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#ifdef MIOPEN_BETA_API - #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include @@ -32,17 +30,20 @@ #include "float_types.h" -//#if MIOPEN_USE_BFP16 == 1 -//#undef FLOAT -//#define FLOAT hip_bfloat16 -//#endif +#if MIOPEN_USE_BFP16 == 1 +#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) +#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) +#define CVT_INTEGRAL2ACCUM(x) ((_FLOAT_ACCUM)(x)) +#define CVT_FP32_2FLOAT(x) (CVT_ACCUM2FLOAT(x)) +#define CVT_FP32_2ACCUM(x) (x) +#endif extern "C" __global__ void LayernormFwdContiguous(const FLOAT* __restrict__ x, FLOAT* __restrict__ y, const FLOAT* __restrict__ weight, const FLOAT* __restrict__ bias, - FLOAT* __restrict__ mean, - FLOAT* __restrict__ rstd, + FLOAT_ACCUM* __restrict__ mean, + FLOAT_ACCUM* __restrict__ rstd, float eps, uint64_t inner_size, bool mode) @@ -67,15 +68,15 @@ extern "C" __global__ void LayernormFwdContiguous(const FLOAT* __restrict__ x, const uint64_t gid = blockIdx.x; const uint64_t lid = threadIdx.x; - FLOAT_ACCUM pmean = CVT_FLOAT2ACCUM(0); - FLOAT_ACCUM pvar = CVT_FLOAT2ACCUM(0); + FLOAT_ACCUM pmean = static_cast(0); + FLOAT_ACCUM pvar = static_cast(0); __shared__ FLOAT_ACCUM ltmp1[LOCAL_SIZE]; __shared__ FLOAT_ACCUM ltmp2[LOCAL_SIZE]; // reduce sum for mean and var for(uint64_t i = lid; i < inner_size; i += LOCAL_SIZE) { - uint64_t x_idx = gid * inner_size + i; + size_t x_idx = gid * inner_size + i; FLOAT_ACCUM tmp = CVT_FLOAT2ACCUM(x[x_idx]); pmean += tmp; @@ -85,7 +86,7 @@ extern "C" __global__ void LayernormFwdContiguous(const FLOAT* __restrict__ x, ltmp1[lid] = pmean; ltmp2[lid] = pvar; __syncthreads(); - for(uint64_t i = LOCAL_SIZE >> 1; i > 0; i >>= 1) + for(uint32_t i = LOCAL_SIZE >> 1; i > 0; i >>= 1) { if(lid < i) { @@ -101,24 +102,23 @@ extern "C" __global__ void LayernormFwdContiguous(const FLOAT* __restrict__ x, if(lid == 0) { if(mean) - mean[gid] = CVT_ACCUM2FLOAT(pmean); + mean[gid] = pmean; if(rstd) - rstd[gid] = CVT_ACCUM2FLOAT(prstd); + rstd[gid] = prstd; } // forward calculation for(uint64_t i = lid; i < inner_size; i += LOCAL_SIZE) { - uint64_t idx = gid * inner_size + i; + size_t idx = gid * inner_size + i; FLOAT_ACCUM pweight; FLOAT_ACCUM pbias; - pweight = mode ? CVT_FLOAT2ACCUM(1) : CVT_FLOAT2ACCUM(weight[i]); - pbias = mode ? CVT_FLOAT2ACCUM(0) : CVT_FLOAT2ACCUM(bias[i]); + pweight = mode ? CVT_FLOAT2ACCUM(weight[i]) : CVT_FP32_2ACCUM(1.0f); + pbias = mode ? CVT_FLOAT2ACCUM(bias[i]) : static_cast(0); FLOAT_ACCUM val = (CVT_FLOAT2ACCUM(x[idx]) - pmean) * prstd * pweight + pbias; y[idx] = CVT_ACCUM2FLOAT(val); } } -#endif diff --git a/src/layer_norm.cpp b/src/layer_norm.cpp index 33030887ee..ae10f62eb0 100644 --- a/src/layer_norm.cpp +++ b/src/layer_norm.cpp @@ -24,17 +24,16 @@ * *******************************************************************************/ #include -#ifdef MIOPEN_BETA_API #include #include -#include #include - -#define LOCAL_SIZE 256 +#include +#include +#include namespace miopen { -miopenStatus_t LayerNormForward(const Handle& handle, +miopenStatus_t LayerNormForward(Handle& handle, const TensorDescriptor& xDesc, ConstData_t x, const TensorDescriptor& weightDesc, @@ -51,85 +50,33 @@ miopenStatus_t LayerNormForward(const Handle& handle, float epsilon, int32_t normalized_dim) { - if(x == nullptr || y == nullptr) - { - MIOPEN_THROW(miopenStatusBadParm, "Null pointer for tensor."); - } - - if(xDesc.GetType() != yDesc.GetType()) - { - MIOPEN_THROW(miopenStatusBadParm, "Tensor types do not match."); - } - - if(xDesc.GetLengths() != yDesc.GetLengths()) - { - MIOPEN_THROW(miopenStatusBadParm, "Tensor dimension lengths do not match."); - } - - bool is_all_packed = xDesc.IsPacked() && weightDesc.IsPacked() && biasDesc.IsPacked() && - yDesc.IsPacked() && meanDesc.IsPacked() && rstdDesc.IsPacked(); - - if(!is_all_packed) - { - MIOPEN_THROW(miopenStatusBadParm, "All tensor is not packed."); - } - - auto dims = xDesc.GetLengths(); - size_t grid_size = 1; - size_t outer_size = 1; - size_t inner_size = 1; - size_t i = 0; - for(; i < normalized_dim; i++) - { - outer_size *= dims[i]; - grid_size *= dims[i]; - } - - for(; i < dims.size(); i++) - { - inner_size *= dims[i]; - grid_size *= dims[i]; - } - - auto dtype = xDesc.GetType(); - - const std::vector vld{LOCAL_SIZE, 1, 1}; - const std::vector vgd{outer_size * vld[0], 1, 1}; - - std::string algo_name = "LayerNormForward"; - std::string network_config = - "lnfwd-dtype" + std::to_string(static_cast(dtype)) + "g" + std::to_string(vgd[0]) + - "l" + std::to_string(vld[0]) + "normalized_dim" + std::to_string(normalized_dim) + "grid" + - std::to_string(grid_size) + "outer_size" + std::to_string(outer_size) + "inner_size" + - std::to_string(inner_size) + "mode" + std::to_string(static_cast(mode)) + "eps" + - std::to_string(static_cast(epsilon)); - - std::string program_name = "MIOpenLayerNorm.cpp"; - std::string kernel_name = "LayernormFwdContiguous"; - - // compile parameters - std::string parms = - " -DMIOPEN_USE_FP16=" + std::to_string(static_cast(dtype == miopenHalf)) + - " -DMIOPEN_USE_FP32=" + std::to_string(static_cast(dtype == miopenFloat)) + - " -DMIOPEN_USE_FP64=" + std::to_string(static_cast(dtype == miopenDouble)) + - " -DMIOPEN_USE_BFP16=" + std::to_string(static_cast(dtype == miopenBFloat16)); - - parms += " -DMIOPEN_BETA_API=1"; - parms += " -DLOCAL_SIZE=" + std::to_string(LOCAL_SIZE); - - auto&& kernels = handle.GetKernels(algo_name, network_config); - if(!kernels.empty()) - { - kernels.front()(x, y, weight, bias, mean, rstd, epsilon, inner_size, mode); - } - else - { - handle.AddKernel(algo_name, network_config, program_name, kernel_name, vld, vgd, parms)( - x, y, weight, bias, mean, rstd, epsilon, inner_size, mode); - } + const auto problem = norm::ProblemDescription{ + mode, xDesc, weightDesc, biasDesc, yDesc, meanDesc, rstdDesc, epsilon, normalized_dim}; + + const auto invoke_params = [&]() { + auto tmp = norm::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.xDesc = &xDesc; + tmp.x = x; + tmp.weight = weight; + tmp.bias = bias; + tmp.y = y; + tmp.mean = mean; + tmp.rstd = rstd; + tmp.epsilon = epsilon; + tmp.normalized_dim = normalized_dim; + tmp.mode = mode; + return tmp; + }(); + + const auto algo = AlgorithmName{"LayerNormForward"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); return miopenStatusSuccess; } } // namespace miopen -#endif diff --git a/src/layernorm_api.cpp b/src/layernorm_api.cpp index 1c8f8d0cca..dc3dcb4a53 100644 --- a/src/layernorm_api.cpp +++ b/src/layernorm_api.cpp @@ -24,7 +24,6 @@ * *******************************************************************************/ #include -#ifdef MIOPEN_BETA_API #include #include #include @@ -134,4 +133,3 @@ extern "C" miopenStatus_t miopenLayerNormForward(miopenHandle_t handle, normalized_dim); }); } -#endif diff --git a/src/norm/problem_description.cpp b/src/norm/problem_description.cpp new file mode 100644 index 0000000000..3e99557187 --- /dev/null +++ b/src/norm/problem_description.cpp @@ -0,0 +1,64 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include + +#include + +namespace miopen { + +namespace norm { + +NetworkConfig ProblemDescription::MakeNetworkConfig() const +{ + auto dims = xDesc.GetLengths(); + size_t outer_size = 1; + size_t inner_size = 1; + + for(size_t i = 0ULL; i < dims.size(); ++i) + { + if(i < normalized_dim) + outer_size *= dims[i]; + else + inner_size *= dims[i]; + } + + auto dtype = xDesc.GetType(); + + std::ostringstream ss; + + ss << "dtype" << dtype; + ss << "normalized_dim" << normalized_dim; + ss << "outer_size" << outer_size; + ss << "inner_size" << inner_size; + + return NetworkConfig{ss.str()}; +} + +} // namespace norm + +} // namespace miopen diff --git a/src/solver.cpp b/src/solver.cpp index 01835dcb1c..a15e82db7c 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -606,6 +607,9 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId()); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId()); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId()); + Register(registry, ++id, Primitive::Normalization, norm::Layernorm2DCKForward{}.SolverDbId()); + Register(registry, ++id, Primitive::Normalization, norm::Layernorm4DCKForward{}.SolverDbId()); + Register(registry, ++id, Primitive::Normalization, norm::LayernormForward{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/norm/forward_layernorm.cpp b/src/solver/norm/forward_layernorm.cpp new file mode 100644 index 0000000000..be258a5d22 --- /dev/null +++ b/src/solver/norm/forward_layernorm.cpp @@ -0,0 +1,144 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace norm { + +std::size_t sizeof_kernel_FLOAT(const miopen::norm::ProblemDescription& problem) +{ + const auto datatype = problem.GetXDesc().GetType(); + return get_data_size(datatype); +} + +std::size_t sizeof_local_memory(const miopen::norm::ProblemDescription& problem) +{ + std::size_t rv = 0; + rv += LOCAL_SIZE * sizeof_kernel_FLOAT(problem) * 2; + return rv; +} + +bool LayernormForward::IsApplicable(const ExecutionContext&, + const miopen::norm::ProblemDescription& problem) const +{ + return (sizeof_local_memory(problem) <= TargetProperties::GetMaxLocalMemorySize()); +} + +ConvSolution LayernormForward::GetSolution(const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) const +{ + std::ignore = context; + + auto result = ConvSolution{miopenStatusSuccess}; + + { + auto dtype = problem.GetXDesc().GetType(); + auto dims = problem.GetXDesc().GetLengths(); + + size_t outer_size = 1; + for(size_t i = 0; i < problem.GetNormalizedDim(); i++) + { + outer_size *= dims[i]; + } + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = outer_size * xlocalsize; + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenLayerNorm.cpp"; + kernel.kernel_name = "LayernormFwdContiguous"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"LOCAL_SIZE", LOCAL_SIZE}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto dims = params.xDesc->GetLengths(); + size_t inner_size = 1; + + for(size_t i = params.normalized_dim; i < dims.size(); i++) + { + inner_size *= dims[i]; + } + + kernel(params.x, + params.y, + params.weight, + params.bias, + params.mean, + params.rstd, + params.epsilon, + inner_size, + static_cast(params.mode)); + }; + }; + + return result; +} + +} // namespace norm + +} // namespace solver + +} // namespace miopen diff --git a/src/solver/norm/forward_layernorm2d_ck.cpp b/src/solver/norm/forward_layernorm2d_ck.cpp new file mode 100644 index 0000000000..7b14c77429 --- /dev/null +++ b/src/solver/norm/forward_layernorm2d_ck.cpp @@ -0,0 +1,283 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#if MIOPEN_USE_COMPOSABLEKERNEL +#include +#include +#endif +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_LAYERNORM2DCKFORWARD_CONV_CK_LN) + +namespace miopen { +namespace solver { +namespace norm { +#if MIOPEN_USE_COMPOSABLEKERNEL + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; +using BF16 = ushort; + +template +using DeviceOp = ck::tensor_operation::device::DeviceNormalization< + XDataType, + GammaDataType, + BetaDataType, + YDataType, + SaveMeanInvStdDataType, + ck::tensor_operation::element_wise::PassThrough, + 2, + 1>; +template +using DeviceOpLnFwdPtrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>; + +namespace { +struct CKArgs +{ + CKArgs(const miopen::norm::ProblemDescription& problem) + { + auto length = problem.GetXDesc().GetLengths(); + + N = length[0]; + W = length[1]; + + N_stride = W; + W_stride = 1; + + xyLengths = {N, W}; + xyStrides = {N_stride, W_stride}; + gammaStrides = {0, W_stride}; + betaStrides = {0, W_stride}; + meanStrides = {1}; + rstdStrides = {1}; + epsilon = problem.GetEpsilon(); + } + + CKArgs(const CKArgs&) = default; + CKArgs(CKArgs&&) = default; + CKArgs& operator=(const CKArgs&) = default; + + template + auto MakeArgPtr(const LNPtr& ln_ptr, + ConstData_t x, + ConstData_t weight, + ConstData_t bias, + Data_t y, + Data_t mean, + Data_t rstd) const + { + return ln_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + gammaStrides, + betaStrides, + xyStrides, + meanStrides, + rstdStrides, + {1}, + epsilon, + x, + weight, + bias, + y, + mean, + rstd, + ck::tensor_operation::element_wise::PassThrough{}); + } + + template + bool IsSupportedBy(const LNPtr& ln_ptr) const + { + auto arg_ptr = MakeArgPtr(ln_ptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + return ln_ptr->IsSupportedArgument(arg_ptr.get()); + } + + int32_t N; + int32_t W; + int32_t N_stride; + int32_t W_stride; + std::vector xyLengths; + std::vector xyStrides; + std::vector gammaStrides; + std::vector betaStrides; + std::vector meanStrides; + std::vector rstdStrides; + float epsilon; +}; +} // namespace + +template +bool CheckCKApplicability(const miopen::norm::ProblemDescription& problem) +{ + const auto ln_args = CKArgs{problem}; + const auto ln_ptrs = DeviceOpType::GetInstances(); + + return std::any_of(ln_ptrs.begin(), ln_ptrs.end(), [&ln_args](auto& ln_ptrs) { + return ln_args.IsSupportedBy(ln_ptrs); + }); +} + +template +typename LnPtrsType::iterator FindLnPtr(LnPtrsType& ln_ptrs, + const miopen::norm::ProblemDescription& problem) +{ + const auto ln_args = CKArgs{problem}; + return std::find_if(ln_ptrs.begin(), ln_ptrs.end(), [&ln_args](auto& ln_ptrs) { + return ln_args.IsSupportedBy(ln_ptrs); + }); +} + +template +ConvSolution MakeInvokerFactory([[maybe_unused]] const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) +{ + auto ln_ptr = DeviceOpType::GetInstances(); + auto ln_ptr_iter = FindLnPtr(ln_ptr, problem); + + if(ln_ptr_iter == ln_ptr.end()) + { + MIOPEN_LOG_E("Layernorm kernel does not exist."); + return {miopenStatusInvalidValue}; + } + + ConvSolution result; + result.invoker_factory = + [ck_args = CKArgsType{problem}, + sh_ln_ptr = std::shared_ptr{std::move(*ln_ptr_iter)}](const std::vector&) mutable { + return [ck_args = std::move(ck_args), sh_ln_ptr = std::move(sh_ln_ptr)]( + const Handle& handle, const AnyInvokeParams& primitive_parameters) { + const auto& data_ctx = primitive_parameters.CastTo(); + auto argument_ptr = ck_args.MakeArgPtr(sh_ln_ptr, + data_ctx.x, + data_ctx.weight, + data_ctx.bias, + data_ctx.y, + data_ctx.mean, + data_ctx.rstd); + auto invoker_ptr = sh_ln_ptr->MakeInvokerPointer(); + + const auto enable_profiling = handle.IsProfilingEnabled(); + float elapsed_time = + invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling}); + if(enable_profiling) + { + handle.ResetKernelTime(); + handle.AccumKernelTime(elapsed_time); + } + }; + }; + return result; +} +#endif + +bool IsRank2Dim1(const miopen::norm::ProblemDescription& problem) +{ + return (problem.GetXDesc().GetLengths().size() == 2) && (problem.GetNormalizedDim() == 1); +} + +bool Layernorm2DCKForward::IsApplicable( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const +{ +#if MIOPEN_USE_COMPOSABLEKERNEL + if(miopen::IsDisabled(MIOPEN_DEBUG_LAYERNORM2DCKFORWARD_CONV_CK_LN{})) + return false; + if(!problem.IsSameType()) + return false; + if(!problem.IsSameLength()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!IsRank2Dim1(problem)) + return false; + if(!problem.IsLargeSize()) + return false; + if(!ck_utility::is_ck_supported_hardware(context.GetStream())) + return false; + + switch(problem.GetXDesc().GetType()) + { + case miopenHalf: + return CheckCKApplicability>(problem); + case miopenFloat: + return CheckCKApplicability>(problem); + case miopenBFloat16: return false; + case miopenDouble: + case miopenInt32: + case miopenInt8: + case miopenFloat8: + case miopenBFloat8: + default: MIOPEN_THROW("Unsupported datatype"); + } +#endif + return false; +} + +ConvSolution Layernorm2DCKForward::GetSolution( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const +{ +#if MIOPEN_USE_COMPOSABLEKERNEL + switch(problem.GetXDesc().GetType()) + { + case miopenHalf: + return MakeInvokerFactory, + CKArgs, + miopen::norm::InvokeParams>(context, problem); + case miopenFloat: + return MakeInvokerFactory, + CKArgs, + miopen::norm::InvokeParams>(context, problem); + case miopenDouble: + case miopenBFloat16: + case miopenInt8: + case miopenInt32: + case miopenFloat8: + case miopenBFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, + "ConvHipImplicitGemmFwdXdlops operation not implemented for this data type"); + } +#endif + return {}; +} + +} // namespace norm +} // namespace solver +} // namespace miopen diff --git a/src/solver/norm/forward_layernorm4d_ck.cpp b/src/solver/norm/forward_layernorm4d_ck.cpp new file mode 100644 index 0000000000..29d706cd2c --- /dev/null +++ b/src/solver/norm/forward_layernorm4d_ck.cpp @@ -0,0 +1,291 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#if MIOPEN_USE_COMPOSABLEKERNEL +#include +#include +#endif +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_LAYERNORM4DCKFORWARD_CONV_CK_LN) + +namespace miopen { +namespace solver { +namespace norm { +#if MIOPEN_USE_COMPOSABLEKERNEL + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; +using BF16 = ushort; + +template +using DeviceOp = ck::tensor_operation::device::DeviceNormalization< + XDataType, + GammaDataType, + BetaDataType, + YDataType, + SaveMeanInvStdDataType, + ck::tensor_operation::element_wise::PassThrough, + 4, + 3>; +template +using DeviceOpLnFwdPtrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>; + +namespace { +struct CKArgs +{ + CKArgs(const miopen::norm::ProblemDescription& problem) + { + auto length = problem.GetXDesc().GetLengths(); + + N = length[0]; + H = length[1]; + W = length[2]; + C = length[3]; + + N_stride = H * W * C; + H_stride = W * C; + W_stride = C; + C_stride = 1; + + xyLengths = {N, H, W, C}; + xyStrides = {N_stride, H_stride, W_stride, C_stride}; + gammaStrides = {0, H_stride, W_stride, C_stride}; + betaStrides = {0, H_stride, W_stride, C_stride}; + meanStrides = {1}; + rstdStrides = {1}; + epsilon = problem.GetEpsilon(); + } + + CKArgs(const CKArgs&) = default; + CKArgs(CKArgs&&) = default; + CKArgs& operator=(const CKArgs&) = default; + + template + auto MakeArgPtr(const LNPtr& ln_ptr, + ConstData_t x, + ConstData_t weight, + ConstData_t bias, + Data_t y, + Data_t mean, + Data_t rstd) const + { + return ln_ptr->MakeArgumentPointer(xyLengths, + xyStrides, + gammaStrides, + betaStrides, + xyStrides, + meanStrides, + rstdStrides, + {1, 2, 3}, + epsilon, + x, + weight, + bias, + y, + mean, + rstd, + ck::tensor_operation::element_wise::PassThrough{}); + } + + template + bool IsSupportedBy(const LNPtr& ln_ptr) const + { + auto arg_ptr = MakeArgPtr(ln_ptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + return ln_ptr->IsSupportedArgument(arg_ptr.get()); + } + + int32_t N; + int32_t C; + int32_t H; + int32_t W; + int32_t N_stride; + int32_t C_stride; + int32_t H_stride; + int32_t W_stride; + std::vector xyLengths; + std::vector xyStrides; + std::vector gammaStrides; + std::vector betaStrides; + std::vector meanStrides; + std::vector rstdStrides; + float epsilon; +}; +} // namespace + +template +bool CheckCKApplicability(const miopen::norm::ProblemDescription& problem) +{ + const auto ln_args = CKArgs{problem}; + const auto ln_ptrs = DeviceOpType::GetInstances(); + + return std::any_of(ln_ptrs.begin(), ln_ptrs.end(), [&ln_args](auto& ln_ptrs) { + return ln_args.IsSupportedBy(ln_ptrs); + }); +} + +template +typename LnPtrsType::iterator FindLnPtr(LnPtrsType& ln_ptrs, + const miopen::norm::ProblemDescription& problem) +{ + const auto ln_args = CKArgs{problem}; + return std::find_if(ln_ptrs.begin(), ln_ptrs.end(), [&ln_args](auto& ln_ptrs) { + return ln_args.IsSupportedBy(ln_ptrs); + }); +} + +template +ConvSolution MakeInvokerFactory([[maybe_unused]] const ExecutionContext& context, + const miopen::norm::ProblemDescription& problem) +{ + auto ln_ptr = DeviceOpType::GetInstances(); + auto ln_ptr_iter = FindLnPtr(ln_ptr, problem); + + if(ln_ptr_iter == ln_ptr.end()) + { + MIOPEN_LOG_E("Layernorm kernel does not exist."); + return {miopenStatusInvalidValue}; + } + + ConvSolution result; + result.invoker_factory = + [ck_args = CKArgsType{problem}, + sh_ln_ptr = std::shared_ptr{std::move(*ln_ptr_iter)}](const std::vector&) mutable { + return [ck_args = std::move(ck_args), sh_ln_ptr = std::move(sh_ln_ptr)]( + const Handle& handle, const AnyInvokeParams& primitive_parameters) { + const auto& data_ctx = primitive_parameters.CastTo(); + auto argument_ptr = ck_args.MakeArgPtr(sh_ln_ptr, + data_ctx.x, + data_ctx.weight, + data_ctx.bias, + data_ctx.y, + data_ctx.mean, + data_ctx.rstd); + auto invoker_ptr = sh_ln_ptr->MakeInvokerPointer(); + + const auto enable_profiling = handle.IsProfilingEnabled(); + float elapsed_time = + invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling}); + if(enable_profiling) + { + handle.ResetKernelTime(); + handle.AccumKernelTime(elapsed_time); + } + }; + }; + return result; +} +#endif + +bool IsRank4Dim1(const miopen::norm::ProblemDescription& problem) +{ + return (problem.GetXDesc().GetLengths().size() == 4) && (problem.GetNormalizedDim() == 1); +} + +bool Layernorm4DCKForward::IsApplicable( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const +{ +#if MIOPEN_USE_COMPOSABLEKERNEL + if(miopen::IsDisabled(MIOPEN_DEBUG_LAYERNORM4DCKFORWARD_CONV_CK_LN{})) + return false; + if(!problem.IsSameType()) + return false; + if(!problem.IsSameLength()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!IsRank4Dim1(problem)) + return false; + if(!problem.IsLargeSize()) + return false; + if(!ck_utility::is_ck_supported_hardware(context.GetStream())) + return false; + + switch(problem.GetXDesc().GetType()) + { + case miopenHalf: + return CheckCKApplicability>(problem); + case miopenFloat: + return CheckCKApplicability>(problem); + case miopenBFloat16: return false; + case miopenDouble: + case miopenInt32: + case miopenInt8: + case miopenFloat8: + case miopenBFloat8: + default: MIOPEN_THROW("Unsupported datatype"); + } +#endif + return false; +} + +ConvSolution Layernorm4DCKForward::GetSolution( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::norm::ProblemDescription& problem) const +{ +#if MIOPEN_USE_COMPOSABLEKERNEL + switch(problem.GetXDesc().GetType()) + { + case miopenHalf: + return MakeInvokerFactory, + CKArgs, + miopen::norm::InvokeParams>(context, problem); + case miopenFloat: + return MakeInvokerFactory, + CKArgs, + miopen::norm::InvokeParams>(context, problem); + case miopenDouble: + case miopenBFloat16: + case miopenInt8: + case miopenInt32: + case miopenFloat8: + case miopenBFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, + "ConvHipImplicitGemmFwdXdlops operation not implemented for this data type"); + } +#endif + return {}; +} + +} // namespace norm +} // namespace solver +} // namespace miopen diff --git a/test/cpu_layernorm.hpp b/test/cpu_layernorm.hpp index 08cf44368e..9f89249a1b 100644 --- a/test/cpu_layernorm.hpp +++ b/test/cpu_layernorm.hpp @@ -23,7 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#ifdef MIOPEN_BETA_API #ifndef GUARD_CPU_LAYERNORM_HPP #define GUARD_CPU_LAYERNORM_HPP @@ -72,12 +71,11 @@ void cpu_layernorm_forward(tensor input, ref_rstd[o] = rstd_v; ford(inner_size)([&](int32_t i) { - T weight_v = mode ? 1 : weight[i]; - T bias_v = mode ? 0 : bias[i]; + T weight_v = mode ? weight[i] : 1; + T bias_v = mode ? bias[i] : 0; ref_output[o * inner_size + i] = (input[o * inner_size + i] - mean_v) * rstd_v * weight_v + bias_v; }); }); } #endif -#endif diff --git a/test/gtest/layernorm.cpp b/test/gtest/layernorm.cpp index dbb38c2a02..56db7be0a6 100644 --- a/test/gtest/layernorm.cpp +++ b/test/gtest/layernorm.cpp @@ -24,15 +24,29 @@ * *******************************************************************************/ #include "layernorm.hpp" -#ifdef MIOPEN_BETA_API -struct LayerNormSolverTestFloat : LayerNormSolverTest +std::string GetFloatArg() +{ + static const auto tmp = miopen::GetEnv("MIOPEN_TEST_FLOAT_ARG"); + if(tmp.empty()) + { + return ""; + } + return tmp.front(); +} + +struct LayerNormTestFloat : LayerNormTest { }; -TEST_P(LayerNormSolverTestFloat, LayerNormTestFw){}; +TEST_P(LayerNormTestFloat, LayerNormTestFw) +{ + if(!(miopen::IsEnvvarValueEnabled("MIOPEN_TEST_ALL")) && (GetFloatArg() != "--float")) + { + GTEST_SKIP(); + } +}; INSTANTIATE_TEST_SUITE_P(LayerNormTestSet, - LayerNormSolverTestFloat, + LayerNormTestFloat, testing::ValuesIn(LayerNormTestConfigs())); -#endif diff --git a/test/gtest/layernorm.hpp b/test/gtest/layernorm.hpp index 740108a887..f3491b7b47 100644 --- a/test/gtest/layernorm.hpp +++ b/test/gtest/layernorm.hpp @@ -23,8 +23,8 @@ * SOFTWARE. * *******************************************************************************/ +#define MIOPEN_BETA_API 1 #include -#ifdef MIOPEN_BETA_API #include #include @@ -52,42 +52,102 @@ struct LayerNormTestCase << " LayerNorm_mode:" << tc.ln_mode; } - std::vector GetInput() { return {N, C, D, H, W}; } + std::vector GetInput() + { + if((N != 0) && (C != 0) && (D != 0) && (H != 0) && (W != 0)) + { + return std::vector({N, C, D, H, W}); + } + else if((N != 0) && (C != 0) && (H != 0) && (W != 0)) + { + return std::vector({N, C, H, W}); + } + else if((N != 0) && (C != 0) && (W != 0)) + { + return std::vector({N, C, W}); + } + else if((N != 0) && (W != 0)) + { + return std::vector({N, W}); + } + else + { + std::cout << "Error Input Tensor Lengths\n" << std::endl; + return std::vector({0}); + } + } }; std::vector LayerNormTestConfigs() -{ // n c h d w nomalized_dim eps ln_mode +{ // n c d h w nomalized_dim eps ln_mode // clang-format off return { - { 32, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 32x32x32 based on VoxNet arch + { 32, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 32x32x32 based on VoxNet arch { 32, 1, 14, 14, 14 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 14, 14, 14 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 12, 12, 12 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 6, 6, 6 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, - { 256, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 32x32x32 based on VoxNet arch + { 256, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 32x32x32 based on VoxNet arch { 256, 32, 14, 14, 14 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 256, 32, 12, 12, 12 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 256, 32, 6, 6, 6 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, - { 512, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 32x32x32 based on VoxNet arch + { 512, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 32x32x32 based on VoxNet arch { 512, 32, 14, 14, 14 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 512, 32, 12, 12, 12 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 512, 32, 6, 6, 6 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, - { 32, 2, 32, 57, 125 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // Hand-gesture recognition CVPR 2015 paper High Res Net Path + { 32, 2, 32, 57, 125 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // Hand-gesture recognition CVPR 2015 paper High Res Net Path { 32, 32, 14, 25, 59 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 6, 10, 27 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 4, 6, 11 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 2, 2, 3 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, - { 32, 32, 32, 28, 62 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path + { 32, 32, 32, 28, 62 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path { 32, 32, 14, 12, 29 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 6, 4, 12 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, { 32, 32, 4, 2, 2 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, - { 16, 32, 6, 50, 50 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // Multi-view 3D convnet - { 1, 3, 8, 240, 320 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video - { 1, 3, 16, 240, 320 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video - { 1, 3, 8, 128, 171 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video - { 1, 3, 16, 128, 171 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video - { 1, 3, 8, 112, 112 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video - { 1, 3, 16, 112, 112 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE} // 3D convet on video + { 16, 32, 6, 50, 50 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // Multi-view 3D convnet + { 1, 3, 8, 240, 320 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video + { 1, 3, 16, 240, 320 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video + { 1, 3, 8, 128, 171 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video + { 1, 3, 16, 128, 171 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video + { 1, 3, 8, 112, 112 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video + { 1, 3, 16, 112, 112 ,4 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, // 3D convet on video + { 32, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 32x32x32 based on VoxNet arch + { 32, 1, 14, 14, 14 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 14, 14, 14 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 12, 12, 12 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 6, 6, 6 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 256, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 32x32x32 based on VoxNet arch + { 256, 32, 14, 14, 14 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 256, 32, 12, 12, 12 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 256, 32, 6, 6, 6 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 512, 1, 32, 32, 32 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 32x32x32 based on VoxNet arch + { 512, 32, 14, 14, 14 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 512, 32, 12, 12, 12 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 512, 32, 6, 6, 6 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 2, 32, 57, 125 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // Hand-gesture recognition CVPR 2015 paper High Res Net Path + { 32, 32, 14, 25, 59 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 6, 10, 27 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 4, 6, 11 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 2, 2, 3 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 32, 28, 62 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // Hand-gesture recognition CVPR 2015 paper Low Res Net Path + { 32, 32, 14, 12, 29 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 6, 4, 12 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 32, 32, 4, 2, 2 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, + { 16, 32, 6, 50, 50 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // Multi-view 3D convnet + { 1, 3, 8, 240, 320 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 3D convet on video + { 1, 3, 16, 240, 320 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 3D convet on video + { 1, 3, 8, 128, 171 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 3D convet on video + { 1, 3, 16, 128, 171 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 3D convet on video + { 1, 3, 8, 112, 112 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 3D convet on video + { 1, 3, 16, 112, 112 ,4 , 1e-5, MIOPEN_WEIGHT_BIAS}, // 3D convet on video + {32, 4, 0, 4, 256 ,1 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {64, 4, 0, 4, 256 ,1 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 4, 0, 4, 256 ,1 , 1e-5, MIOPEN_WEIGHT_BIAS}, + {64, 4, 0, 4, 256 ,1 , 1e-5, MIOPEN_WEIGHT_BIAS}, + {32, 0, 0, 0, 256 ,1 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {64, 0, 0, 0, 256 ,1 , 1e-5, MIOPEN_ELEMENTWISE_AFFINE}, + {32, 0, 0, 0, 256 ,1 , 1e-5, MIOPEN_WEIGHT_BIAS}, + {64, 0, 0, 0, 256 ,1 , 1e-5, MIOPEN_WEIGHT_BIAS} }; // clang-format on } @@ -102,13 +162,12 @@ inline int32_t SetTensorLayout(miopen::TensorDescriptor& desc) } template -struct LayerNormSolverTest : public ::testing::TestWithParam +struct LayerNormTest : public ::testing::TestWithParam { protected: void SetUp() override { auto&& handle = get_handle(); - test_skipped = false; layernorm_config = GetParam(); std::mt19937 gen(0); std::uniform_real_distribution<> d{-3, 3}; @@ -122,13 +181,23 @@ struct LayerNormSolverTest : public ::testing::TestWithParam input = tensor{in_dim}.generate(gen_value); + std::vector inner_dim; + if(nomalized_dim == in_dim.size()) + inner_dim = {1}; + else + inner_dim = {in_dim.begin() + nomalized_dim, in_dim.end()}; + if(ln_mode == MIOPEN_ELEMENTWISE_AFFINE) { - std::vector inner_dim; - if(nomalized_dim == in_dim.size()) - inner_dim = {1}; - else - inner_dim = {in_dim.begin() + nomalized_dim, in_dim.end()}; + auto gen_one = [&](auto...) { return 1; }; + auto gen_zero = [&](auto...) { return 0; }; + weight = tensor{inner_dim}.generate(gen_one); + bias = tensor{inner_dim}.generate(gen_zero); + SetTensorLayout(weight.desc); + SetTensorLayout(bias.desc); + } + else + { weight = tensor{inner_dim}.generate(gen_value); bias = tensor{inner_dim}.generate(gen_value); SetTensorLayout(weight.desc); @@ -169,9 +238,6 @@ struct LayerNormSolverTest : public ::testing::TestWithParam } void TearDown() override { - if(test_skipped) - return; - auto&& handle = get_handle(); cpu_layernorm_forward( @@ -210,7 +276,7 @@ struct LayerNormSolverTest : public ::testing::TestWithParam error = miopen::rms_range(ref_mean, mean); EXPECT_TRUE(miopen::range_distance(ref_mean) == miopen::range_distance(mean)); - EXPECT_TRUE(error < threshold) + EXPECT_TRUE(error < threshold * 20) << "Error mean beyond tolerance Error:" << error << ", Threshold: " << threshold; error = miopen::rms_range(ref_rstd, rstd); @@ -241,7 +307,4 @@ struct LayerNormSolverTest : public ::testing::TestWithParam size_t nomalized_dim; float eps; miopenLayerNormMode_t ln_mode; - - bool test_skipped = false; }; -#endif