From 4f892ec13641ec423a5b74da6ad3b3cf38d392ad Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Wed, 11 Jan 2023 13:54:53 -0800 Subject: [PATCH 01/19] fp16 gemm interface --- cmake/onnxruntime_mlas.cmake | 1 + onnxruntime/core/mlas/inc/mlas.h | 102 ++++++- onnxruntime/core/mlas/lib/halfgemm.cpp | 213 ++++++++++++++ onnxruntime/core/mlas/lib/halfgemm.h | 371 +++++++++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 96 ++++++- 5 files changed, 780 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/halfgemm.cpp create mode 100644 onnxruntime/core/mlas/lib/halfgemm.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 1f9b7129943e6..267d85b855f7c 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -19,6 +19,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp ${MLAS_SRC_DIR}/qgemm.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 5b6756e4fb90b..0c67948d7b0be 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -90,7 +90,7 @@ typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE; #endif // -// Forward declare the thread pool implementation class. +// Forward declare the thread pool implementation class and half precision floating point. // // N.B. Avoid including ONNX Runtime headers here to keep the dependencies for // standalone MLAS test executables smaller. @@ -100,9 +100,12 @@ namespace onnxruntime { namespace concurrency { class ThreadPool; }; -}; + struct MLFloat16; +}; // namespace onnxruntime using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool; +using MLAS_FP16 = struct onnxruntime::MLFloat16; + // // Platform routines. @@ -1366,3 +1369,98 @@ MlasQLinearMul( size_t N, bool IsScalarB ); + +// +// Half precision routines +// + +class MLAS_HALF_GEMM_OUTPUT_PROCESSOR { +public: + virtual + void + Process( + const MLAS_FP16*, // Supplies the address of matrix to process + size_t, // Supplies the start row index of matrix + size_t, // Supplies the start col index of matrix + size_t, // Supplies the element count per row to process + size_t, // Supplies the element count per col to process + size_t // Supplies the leading dimension of matrix + ) const = 0; + + virtual ~MLAS_HALF_GEMM_OUTPUT_PROCESSOR() {} +}; + + +/** + * @brief Data parameters for half precision GEMM routine + * All except C are [in] parameters +*/ +struct MLAS_HALF_GEMM_DATA_PARAMS { + const MLAS_FP16* A = nullptr; /**< address of A */ + size_t lda = 0; /**< leading dimension of A */ + const MLAS_FP16* B = nullptr; /**< address of B */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is packed*/ + const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */ + MLAS_FP16* C = nullptr; /**< address of result matrix */ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_HALF_GEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr; +}; + +/** + * @brief Half precision Batched GEMM: C = A * B + Bias + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. 
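 *
 * A minimal single-batch call might look like this (buffer names are
 * illustrative, not part of the API; all matrices row major fp16):
 *
 *     MLAS_HALF_GEMM_DATA_PARAMS Params;
 *     Params.A = A;  Params.lda = K;  // A is M x K
 *     Params.B = B;  Params.ldb = N;  // B is K x N
 *     Params.C = C;  Params.ldc = N;  // C is M x N
 *     MlasHalfGemmBatch(M, N, K, 1, &Params, nullptr);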
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return +*/ +void +MLASCALL +MlasHalfGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr + ); + +/** + * @brief For half precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported +*/ +size_t +MLASCALL +MlasHalfGemmPackBSize( + size_t N, + size_t K + ); + +/** + * @brief For half precision GEMM, pack the right hand + * side matrix B + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix +*/ +void +MLASCALL +MlasHalfGemmPackB( + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ); diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp new file mode 100644 index 0000000000000..e049c4e0b719a --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -0,0 +1,213 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + half gemm.cpp + +Abstract: + + This module implements the half precision (fp16) matrix/matrix multiply + operation (QGEMM). + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +#include + + +void +MLASCALL +MlasHalfGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool + ) +{ + const MLAS_HALF_GEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); + MLAS_HALF_GEMM_OPERATION* operation; + if (DataParams->ldb == 0) { + // B is packed + operation = dispatch->PackedOperation; + } + else { + operation = dispatch->Operation; + } + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + operation(K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. 
+ // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + const size_t StrideM = dispatch->StrideM; + + size_t nc = N; + if ((size_t)MlasGetMaximumThreadCount(ThreadPool) > BatchN) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min(nc, MlasDivRoundup(nc, max_nc * MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + + + +size_t +MLASCALL +MlasHalfGemmPackBSize( + size_t N, + size_t K + ) +{ + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + return 0; +} + + +void +MLASCALL +MlasHalfGemmPackB( + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ) +{ + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(PackedB); + + throw std::exception("HalfGemmPacking should not be called, when MlasHalfGemmPackBSize returns 0!"); +} + + +// +// C++ implementation that runs very slowly +// + +struct MLAS_HALF_GEMM_KERNEL_DEFAULT { + + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 128; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 1; + + static constexpr MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; + static constexpr MLAS_HALF_GEMM_STRIDES PackedStrides{0, 0, 0}; +}; + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmCopyPackB( + MLAS_FP16* D, + const MLAS_FP16* B, + size_t ldb, + size_t CountN, + size_t CountK + ) +{ + MLAS_UNREFERENCED_PARAMETER(D); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(CountN); + MLAS_UNREFERENCED_PARAMETER(CountK); +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmKernel( + const size_t CountM, + const size_t CountN, + const size_t CountK, + const MLAS_FP16* A, + const size_t lda, + const MLAS_FP16* B, + const size_t ldb, + MLAS_FP16* C, + size_t ldc, + const MLAS_FP16* Bias, + const bool ZeroMode) +{ + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(CountN); + MLAS_UNREFERENCED_PARAMETER(CountK); + MLAS_UNREFERENCED_PARAMETER(A); + 
MLAS_UNREFERENCED_PARAMETER(lda); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(C); + MLAS_UNREFERENCED_PARAMETER(ldc); + MLAS_UNREFERENCED_PARAMETER(Bias); + MLAS_UNREFERENCED_PARAMETER(ZeroMode); +} + + +const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault = { + MlasHalfGemmOperation, + nullptr, + nullptr, + 128 +}; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h new file mode 100644 index 0000000000000..71e19bc6b918c --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -0,0 +1,371 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm.h + +Abstract: + + This module defines the set of template functions to implement a kernel of + half precision matrix/matrix multiply operation (QGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasHalfGemmCopyPackB + MlasHalfGemmKernel + Specialization of MlasHalfGemmTryGemvKernel is optional. + + MlasHalfGemmOperation and MlasHalfGemmPackedOperation are shared kernel drivers. + + +--*/ + +#pragma once + +#include +#include + +#include "mlasi.h" + + +/** + * @brief Define the default striding parameters for + * the half precision gemm operation + */ +struct MLAS_HALF_GEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +template +MLAS_FORCEINLINE +bool +MlasHalfGemmTryGemvKernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + size_t ldb, + MLAS_FP16* C, + size_t CountK, + size_t CountN +) +{ + MLAS_UNREFERENCED_PARAMETER(A); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(C); + MLAS_UNREFERENCED_PARAMETER(CountK); + MLAS_UNREFERENCED_PARAMETER(CountN); + + return false; +} + +template +void +MlasHalfGemmCopyPackB( + MLAS_FP16* D, + const MLAS_FP16* B, + size_t ldb, + size_t CountN, + size_t CountK +); + +template +void +MlasHalfGemmKernel( + const size_t CountM, + const size_t CountN, + const size_t CountK, + const MLAS_FP16* A, + const size_t lda, + const MLAS_FP16* B, + const size_t ldb, + MLAS_FP16* C, + size_t ldc, + const MLAS_FP16* Bias, + const bool ZeroMode +); + + +template +MLAS_FORCEINLINE +void +MlasHalfGemmThreadInit() +{ + if (!KernelType::PackNeeded) { + return; + } + constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(MLAS_FP16)); + + MlasThreadedBufAlloc(packBSize); +} + + +template +void +MlasHalfGemmOperation( + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + MlasHalfGemmThreadInit(); + + const size_t lda = Data->lda; + const size_t ldb = Data->ldb; + const size_t ldc = Data->ldc; + + const MLAS_FP16* A = Data->A + RangeStartM * lda; + const MLAS_FP16* B = Data->B + RangeStartN; + const MLAS_FP16* Bias = Data->Bias + RangeStartN; + MLAS_FP16* C = Data->C + RangeStartM * ldc + RangeStartN; + + // + // Try to use a GEMV kernel if supported by this kernel type. 
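    // (The default MlasHalfGemmTryGemvKernel specialization above always
    // returns false; this path only pays off once a platform supplies a
    // real matrix-vector kernel for the RangeCountM == 1 case.)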
+ // + + if ((RangeCountM == 1) && (Data->OutputProcessor == nullptr)) { + if (MlasHalfGemmTryGemvKernel(A, B, ldb, C, K, RangeCountN)) { + return; + } + } + + if (!KernelType::PackNeeded) { + // We are not restricted by packing panel size, so simpler tiling + + auto pa = A; + auto c = C; + size_t RowsRemaining = RangeCountM; + + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + RangeCountN, + K, + pa, + lda, + B, + ldb, + c, + ldc, + Bias, + true); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + RangeCountM - RowsRemaining, + RangeStartN, + RowsHandled, + RangeCountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + + return; + } + + // + // Three dimensional tiling due to limited packing panel size + // + + constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; + MLAS_FP16* PanelB = reinterpret_cast(ThreadedBufHolder.get()); + + // + // Step through each slice of matrix B along the K dimension. + // + + size_t CountK; + + for (size_t k = 0; k < K; k += CountK) { + + CountK = std::min(K - k, Strides.K); + + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + + for (size_t n = 0; n < RangeCountN; n += CountN) { + + CountN = std::min(RangeCountN - n, Strides.N); + + // + // Copy a panel of matrix B to a local packed buffer. + // + + MlasHalfGemmCopyPackB( + PanelB, + B + n, + ldb, + CountN, + CountK); + + // + // Step through each slice of matrix A along the M dimension. + // + + MLAS_FP16* c = C + n; + size_t CountM; + + for (size_t m = 0; m < RangeCountM; m += CountM) { + + CountM = std::min(RangeCountM - m, Strides.M); + + const MLAS_FP16* pa = A + m * lda; + size_t RowsRemaining = CountM; + + bool ZeroMode = (k == 0); + bool PostProcess = (k + CountK == K); + + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + CountN, + CountK, + pa, + lda, + PanelB, + 0, // ldb not needed for packed B + c, + ldc, + Bias, + ZeroMode); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (PostProcess && Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, + RowsHandled, + CountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + + A += CountK; + B += CountK * ldb; + } +} + + +template +void +MlasHalfGemmPackedOperation( + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + const size_t lda = Data->lda; + const size_t ldc = Data->ldc; + + auto pa = (Data->A) + RangeStartM * lda; + const size_t PackedCountK = (K + KernelType::PackedK - 1) / KernelType::PackedK; + const MLAS_FP16* b = Data->B + RangeStartN * KernelType::PackedK * PackedCountK; + const MLAS_FP16* Bias = Data->Bias + RangeStartN; + auto* c = C; + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + RangeCountN, + K, + pa, + lda, + b, + 0, // packed B ldb not needed + c, + ldc, + Bias, + true); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + RangeCountM - RowsRemaining, + 
RangeStartN + n, + RowsHandled, + RangeCountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } +} + + + + +// +// dispatch structure. +// + +typedef +void +(MLAS_HALF_GEMM_OPERATION)( + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); + + +typedef +void +(MLAS_HALF_GEMM_COPY_PACKB_ROUTINE)( + MLAS_FP16* D, + const MLAS_FP16* B, + size_t ldb, + size_t CountN, + size_t CountK + ); + +struct MLAS_HALF_GEMM_DISPATCH { + MLAS_HALF_GEMM_OPERATION* Operation; + MLAS_HALF_GEMM_OPERATION* PackedOperation; + MLAS_HALF_GEMM_COPY_PACKB_ROUTINE* CopyPackBRoutine; + size_t StrideM; +}; + +extern const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault; + +MLAS_FORCEINLINE +const MLAS_HALF_GEMM_DISPATCH* +MlasHalfGemmGetDispatch() +{ + return &MlasHalfGemmDispatchDefault; +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 31999f3294999..32eb29b107b38 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -107,6 +107,8 @@ Module Name: #include "core/common/cpuid_info.h" using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo; +#include "core/framework/float16.h" + #else // BUILD_MLAS_NO_ONNXRUNTIME class MLASCPUIDInfo @@ -179,7 +181,99 @@ enum MlasUArch { #endif // MLAS_TARGET_ARM64 -#endif // BUILD_MLAS_NO_ONNXRUNTIME +union fp32_bits { + uint32_t u; + float f; +}; + +namespace onnxruntime +{ +// MLFloat16 +struct MLFloat16 { + uint16_t val{0}; + + MLFloat16() = default; + explicit constexpr MLFloat16(uint16_t x) : val(x) {} + explicit MLFloat16(float ff) { + constexpr fp32_bits f32infty = {255 << 23}; + constexpr fp32_bits f16max = {(127 + 16) << 23}; + constexpr fp32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr uint32_t sign_mask = 0x80000000u; + + val = static_cast(0x0u); + fp32_bits f; f.f = ff; + + uint32_t sign = f.u & sign_mask; + f.u ^= sign; + + if (f.u >= f16max.u) { + // Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { + if (f.u < (113 << 23)) { + // Subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + uint32_t mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.u += ((uint32_t)(15 - 127) << 23) + 0xfff; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + } + + float ToFloat() const { + constexpr fp32_bits magic = {113 << 23}; + constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + fp32_bits o; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + uint32_t exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? 
+ o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (val & 0x8000) << 16; // sign bit + return o.f; + } + + operator float() const { return ToFloat(); } +}; + +inline bool +operator==(const MLFloat16& left, const MLFloat16& right) +{ + return left.val == right.val; +} + +inline bool +operator!=(const MLFloat16& left, const MLFloat16& right) +{ + return left.val != right.val; +} + +} // namespace onnxruntime + +#endif // BUILD_MLAS_NO_ONNXRUNTIME + // // Define the maximum number of threads supported by this implementation. From d83bf3cbec3959a2dacdb3382fd436dc01aef560 Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Wed, 11 Jan 2023 20:49:34 -0800 Subject: [PATCH 02/19] add fp32/fp16 --- onnxruntime/core/mlas/inc/mlas.h | 39 ++- onnxruntime/core/mlas/lib/halfgemm.cpp | 154 ++++++--- onnxruntime/core/mlas/lib/halfgemm.h | 439 +++++++++++++++---------- 3 files changed, 418 insertions(+), 214 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 0c67948d7b0be..fce9c17ba9465 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1396,18 +1396,21 @@ class MLAS_HALF_GEMM_OUTPUT_PROCESSOR { * All except C are [in] parameters */ struct MLAS_HALF_GEMM_DATA_PARAMS { - const MLAS_FP16* A = nullptr; /**< address of A */ - size_t lda = 0; /**< leading dimension of A */ - const MLAS_FP16* B = nullptr; /**< address of B */ - size_t ldb = 0; /**< leading dimension of B, 0 when B is packed*/ + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */ MLAS_FP16* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is packed*/ size_t ldc = 0; /**< leading dimension of C*/ const MLAS_HALF_GEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/ }; /** * @brief Half precision Batched GEMM: C = A * B + Bias + * Either A or B can be fp32 or fp16 * * Note: We only support uniform batching, so shapes and types of the * input must be same across all parameter blocks. 
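 *
 * With the AIsfp32/BIsfp32 flags a caller can mix precisions, e.g. keep
 * activations in fp32 while B holds fp16 weights; the driver converts the
 * fp32 side into fp16 panels internally before invoking the kernel.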
@@ -1433,9 +1436,11 @@ MlasHalfGemmBatch( /** * @brief For half precision GEMM, returns size of the - * packing buffer needed for right hand side + * packing buffer needed for right hand side * @param[in] N Number of columns * @param[in] K Number of rows + * @param[in] float2half Whether the input is float that + * needs to be converted to half precision * @return size of the packing buffer, * 0 if operation not supported */ @@ -1443,12 +1448,14 @@ size_t MLASCALL MlasHalfGemmPackBSize( size_t N, - size_t K + size_t K, + bool float2half ); /** * @brief For half precision GEMM, pack the right hand * side matrix B + * * @param[in] N Number of columns * @param[in] K Number of rows * @param[in] B Address of matrix B @@ -1464,3 +1471,23 @@ MlasHalfGemmPackB( size_t ldb, void* PackedB ); + +/** + * @brief For half precision GEMM, convert the float matrix B + * to half precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix +*/ +void +MLASCALL +MlasHalfGemmConvertPackB( + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index e049c4e0b719a..1029371b1c67a 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -33,19 +33,12 @@ MlasHalfGemmBatch( ) { const MLAS_HALF_GEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); - MLAS_HALF_GEMM_OPERATION* operation; - if (DataParams->ldb == 0) { - // B is packed - operation = dispatch->PackedOperation; - } - else { - operation = dispatch->Operation; - } + MLAS_HALF_GEMM_OPERATION* operation = dispatch->Operation; if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { auto Data = &DataParams[gemm_i]; - operation(K, Data, 0, M, 0, N); + operation(N, K, Data, 0, M, 0, N); } return; } @@ -103,26 +96,33 @@ MlasHalfGemmBatch( const size_t RangeStartN = ThreadIdN * StrideN; const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); - operation(K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + operation(N, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); }); } - size_t MLASCALL MlasHalfGemmPackBSize( size_t N, - size_t K + size_t K, + bool float2half ) { - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - - return 0; + const auto* dispatch = MlasHalfGemmGetDispatch(); + const auto PackedK = dispatch->PackededK; + if (!float2half && dispatch->CopyPackBRoutine == nullptr) { + // No packing routine provided + return 0; + } + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t BytesRequired = N * AlignedK * sizeof(MLAS_FP16); + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + return AlignedBytesRequired; } - void MLASCALL MlasHalfGemmPackB( @@ -133,13 +133,22 @@ MlasHalfGemmPackB( void* PackedB ) { - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(PackedB); + const auto* dispatch = MlasHalfGemmGetDispatch(); + dispatch->CopyPackBRoutine((MLAS_FP16*)PackedB, B, ldb, N, K); +} - throw std::exception("HalfGemmPacking should not be called, when MlasHalfGemmPackBSize 
returns 0!"); +void +MLASCALL +MlasHalfGemmConvertPackB( + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + dispatch->ConvertPackBRoutine((MLAS_FP16*)PackedB, B, ldb, N, K); } @@ -154,7 +163,6 @@ struct MLAS_HALF_GEMM_KERNEL_DEFAULT { static constexpr size_t PackedK = 1; static constexpr MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; - static constexpr MLAS_HALF_GEMM_STRIDES PackedStrides{0, 0, 0}; }; template<> @@ -173,41 +181,101 @@ MlasHalfGemmCopyPackB( MLAS_UNREFERENCED_PARAMETER(ldb); MLAS_UNREFERENCED_PARAMETER(CountN); MLAS_UNREFERENCED_PARAMETER(CountK); + // No packing for fp16 B. leave it alone } +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackA( + MLAS_FP16* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + for (size_t m = 0; m < CountM; m++) { + for (size_t k = 0; k < CountK; k++) { + new (D) MLAS_FP16(*(A + m * lda + k)); + D++; + } + } +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackB( + MLAS_FP16* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + for (size_t k = 0; k < CountK; k++) { + for (size_t n = 0; n < CountN; n++) { + new (D) MLAS_FP16(*(B + k * ldb + n)); + D++; + } + } +} + + template<> MLAS_FORCEINLINE void MlasHalfGemmKernel( - const size_t CountM, - const size_t CountN, - const size_t CountK, + size_t CountM, + size_t CountN, + size_t CountK, const MLAS_FP16* A, - const size_t lda, + size_t lda, const MLAS_FP16* B, - const size_t ldb, + size_t ldb, MLAS_FP16* C, size_t ldc, const MLAS_FP16* Bias, const bool ZeroMode) { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(CountK); - MLAS_UNREFERENCED_PARAMETER(A); - MLAS_UNREFERENCED_PARAMETER(lda); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(ldc); - MLAS_UNREFERENCED_PARAMETER(Bias); - MLAS_UNREFERENCED_PARAMETER(ZeroMode); + CountM = std::min(CountM, MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM); + while (CountM-- > 0) { + // + // Process a single column of matrix B in a loop. + // + const MLAS_FP16* bias = Bias; + const auto* b_col = B; + auto* c = C; + while (CountN-- > 0) { + const auto* a = A; + const auto* b = b_col; + + float Accumulator = bias->ToFloat(); + bias++; + for (size_t k = 0; k < CountK; k++) { + Accumulator += a->ToFloat() * b->ToFloat(); + a++; + b += ldb; + } + if (!ZeroMode) { + Accumulator += c->ToFloat(); + } + new (c) MLAS_FP16(Accumulator); + + c++; + b_col++; + } + A += lda; + C += ldc; + } } const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, - nullptr, - nullptr, - 128 + nullptr, + MlasHalfGemmConvertPackA, + MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, + MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM }; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 71e19bc6b918c..8cab5711e8c19 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -14,18 +14,26 @@ Module Name: half precision matrix/matrix multiply operation (QGEMM). To implement a new kernel, template functions below need to be specialized: - MlasHalfGemmCopyPackB - MlasHalfGemmKernel - Specialization of MlasHalfGemmTryGemvKernel is optional. - - MlasHalfGemmOperation and MlasHalfGemmPackedOperation are shared kernel drivers. 
- - + MlasHalfGemmCopyPackB + MlasHalfGemmConvertPackA + MlasHalfGemmConvertPackB + MlasHalfGemmPackedBOffset + MlasHalfGemmPackedBLeadingDim + MlasHalfGemmKernel + + MlasHalfGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether fp16 B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; --*/ #pragma once #include +#include #include #include "mlasi.h" @@ -41,38 +49,124 @@ struct MLAS_HALF_GEMM_STRIDES { size_t K; }; +/** + * @brief Packing function for fp16 B matrix + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack +*/ template -MLAS_FORCEINLINE -bool -MlasHalfGemmTryGemvKernel( - const MLAS_FP16* A, +MLAS_FORCEINLINE +void +MlasHalfGemmCopyPackB( + MLAS_FP16* D, const MLAS_FP16* B, size_t ldb, - MLAS_FP16* C, - size_t CountK, - size_t CountN + size_t CountN, + size_t CountK ) { - MLAS_UNREFERENCED_PARAMETER(A); + MLAS_UNREFERENCED_PARAMETER(D); MLAS_UNREFERENCED_PARAMETER(B); MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(C); - MLAS_UNREFERENCED_PARAMETER(CountK); MLAS_UNREFERENCED_PARAMETER(CountN); - - return false; + MLAS_UNREFERENCED_PARAMETER(CountK); + // No packing needed by default } +/** + * @brief Convert fp32 matrix A to fp16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of the packing buffer + * @param[in] A Address of fp32 matrix A + * @param[in] lda leading dimension of A + * @param[in] CountM # of rows to pack + * @param[in] CountK # of columns to pack +*/ template void -MlasHalfGemmCopyPackB( +MlasHalfGemmConvertPackA( MLAS_FP16* D, - const MLAS_FP16* B, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +); + +/** + * @brief Convert fp32 matrix B to fp16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasHalfGemmConvertPackB( + MLAS_FP16* D, + const float* B, size_t ldb, size_t CountN, size_t CountK ); +/** + * @brief Find the location of [StartK, StartN] in packed B buffer + * + * @tparam KernelType + * @param PackedB + * @param DimN + * @param DimK + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] +*/ +template +MLAS_FORCEINLINE +const MLAS_FP16* +MlasHalfGemmPackedBOffset( + const MLAS_FP16* PackedB, + size_t DimN, + size_t DimK, + size_t StartN, + size_t StartK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer +*/ +template +MLAS_FORCEINLINE +size_t +MlasHalfGemmPackedBLeadingDim( + size_t DimN, + size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + + template void MlasHalfGemmKernel( @@ -92,22 +186,9 @@ MlasHalfGemmKernel( template 
MLAS_FORCEINLINE -void -MlasHalfGemmThreadInit() -{ - if (!KernelType::PackNeeded) { - return; - } - constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; - constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(MLAS_FP16)); - - MlasThreadedBufAlloc(packBSize); -} - - -template void -MlasHalfGemmOperation( +MlasHalfGemmNoPackOperation( + const size_t N, const size_t K, const MLAS_HALF_GEMM_DATA_PARAMS* Data, const size_t RangeStartM, @@ -116,105 +197,165 @@ MlasHalfGemmOperation( const size_t RangeCountN ) { - MlasHalfGemmThreadInit(); + // + // Optimize for the special case where no packing is needed. + // Simpler tiling as we are not restricted by packing panel size + // const size_t lda = Data->lda; - const size_t ldb = Data->ldb; + size_t ldb = Data->ldb; // 0 if prepacked const size_t ldc = Data->ldc; - const MLAS_FP16* A = Data->A + RangeStartM * lda; - const MLAS_FP16* B = Data->B + RangeStartN; + const MLAS_FP16* pa = reinterpret_cast(Data->A) + RangeStartM * lda; + const MLAS_FP16* B; + if (ldb == 0) { + B = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), + N, + K, + RangeStartN, + 0); + ldb = MlasHalfGemmPackedBLeadingDim(N, K); + } else { + B = reinterpret_cast(Data->B) + RangeStartN; + } + const MLAS_FP16* Bias = Data->Bias + RangeStartN; - MLAS_FP16* C = Data->C + RangeStartM * ldc + RangeStartN; + MLAS_FP16* c = Data->C + RangeStartM * ldc + RangeStartN; - // - // Try to use a GEMV kernel if supported by this kernel type. - // + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + RangeCountN, + K, + pa, + lda, + B, + ldb, + c, + ldc, + Bias, + true); - if ((RangeCountM == 1) && (Data->OutputProcessor == nullptr)) { - if (MlasHalfGemmTryGemvKernel(A, B, ldb, C, K, RangeCountN)) { - return; - } - } + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); - if (!KernelType::PackNeeded) { - // We are not restricted by packing panel size, so simpler tiling + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + RangeCountM - RowsRemaining, + RangeStartN, + RowsHandled, + RangeCountN, + Data->ldc); + } - auto pa = A; - auto c = C; - size_t RowsRemaining = RangeCountM; + c += ldc * RowsHandled; + pa += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } +} - while (RowsRemaining > 0) { - MlasHalfGemmKernel( - RowsRemaining, - RangeCountN, - K, - pa, - lda, - B, - ldb, - c, - ldc, - Bias, - true); - - size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); - - if (Data->OutputProcessor != nullptr) { - Data->OutputProcessor->Process( - Data->C, - RangeStartM + RangeCountM - RowsRemaining, - RangeStartN, - RowsHandled, - RangeCountN, - Data->ldc); - } - c += ldc * RowsHandled; - pa += lda * RowsHandled; - RowsRemaining -= RowsHandled; - } +template +void +MlasHalfGemmOperation( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + const size_t lda = Data->lda; + const size_t ldb = Data->ldb; + const size_t ldc = Data->ldc; + if (!Data->AIsfp32 && (ldb == 0 || !KernelType::PackNeeded && !Data->BIsfp32)) { + // No packing needed, use a simpler driver instead + MlasHalfGemmNoPackOperation( + N, + K, + Data, + RangeStartM, + RangeCountM, + RangeStartN, + RangeCountN); return; - } + } + + const MLAS_FP16* Bias = Data->Bias + RangeStartN; + MLAS_FP16* C = Data->C + RangeStartM * ldc + 
RangeStartN; // // Three dimensional tiling due to limited packing panel size // - constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; - MLAS_FP16* PanelB = reinterpret_cast(ThreadedBufHolder.get()); + constexpr size_t packASize = UpAlignSize(Strides.M * Strides.K * sizeof(MLAS_FP16)); + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(MLAS_FP16)); + MlasThreadedBufAlloc(packASize + packBSize); + + uint8_t* p = ThreadedBufHolder.get(); + MLAS_FP16* PanelA = reinterpret_cast(p); + p += packASize; + MLAS_FP16* PanelB = reinterpret_cast(p); // // Step through each slice of matrix B along the K dimension. // size_t CountK; - for (size_t k = 0; k < K; k += CountK) { - CountK = std::min(K - k, Strides.K); + const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; // // Step through each slice of matrix B along the N dimension. // size_t CountN; - for (size_t n = 0; n < RangeCountN; n += CountN) { - CountN = std::min(RangeCountN - n, Strides.N); // // Copy a panel of matrix B to a local packed buffer. // - - MlasHalfGemmCopyPackB( - PanelB, - B + n, - ldb, - CountN, - CountK); + size_t ld_pb; + const MLAS_FP16* pb; + if (ldb == 0) { + // Already packed + pb = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), + N, + K, + RangeStartN + n, + k); + ld_pb = MlasHalfGemmPackedBLeadingDim(N, K); + } else if (Data->BIsfp32) { + MlasHalfGemmConvertPackB( + PanelB, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + ldb, + CountN, + CountK); + pb = PanelB; + ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); + } else if (KernelType::PackNeeded) { + MlasHalfGemmCopyPackB( + PanelB, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + ldb, + CountN, + CountK); + pb = PanelB; + ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); + } else { + // fp16, and no packing needed + pb = reinterpret_cast(Data->B) + ldb * k + RangeStartN + n; + ld_pb = ldb; + } // // Step through each slice of matrix A along the M dimension. @@ -222,14 +363,29 @@ MlasHalfGemmOperation( MLAS_FP16* c = C + n; size_t CountM; - for (size_t m = 0; m < RangeCountM; m += CountM) { - CountM = std::min(RangeCountM - m, Strides.M); - const MLAS_FP16* pa = A + m * lda; - size_t RowsRemaining = CountM; + // + // Copy a panel of matrix A to a local packed buffer. 
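+            // (filled from fp32 via MlasHalfGemmConvertPackA when
+            // Data->AIsfp32; otherwise the fp16 rows are addressed in place)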
+ // + const MLAS_FP16* pa; + size_t ld_pa; + if (Data->AIsfp32) { + MlasHalfGemmConvertPackA( + PanelA, + reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k, + lda, + CountM, + CountK); + pa = PanelA; + ld_pa = KernelType::PackedK * PackedCountK; + } else { + pa = reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k; + ld_pa = lda; + } + size_t RowsRemaining = CountM; bool ZeroMode = (k == 0); bool PostProcess = (k + CountK == K); @@ -239,9 +395,9 @@ MlasHalfGemmOperation( CountN, CountK, pa, - lda, - PanelB, - 0, // ldb not needed for packed B + ld_pa, + pb, + ld_pb, c, ldc, Bias, @@ -260,74 +416,15 @@ MlasHalfGemmOperation( } c += ldc * RowsHandled; - pa += lda * RowsHandled; + pa += ld_pa * RowsHandled; RowsRemaining -= RowsHandled; } } } - - A += CountK; - B += CountK * ldb; - } -} - - -template -void -MlasHalfGemmPackedOperation( - const size_t K, - const MLAS_HALF_GEMM_DATA_PARAMS* Data, - const size_t RangeStartM, - const size_t RangeCountM, - const size_t RangeStartN, - const size_t RangeCountN - ) -{ - const size_t lda = Data->lda; - const size_t ldc = Data->ldc; - - auto pa = (Data->A) + RangeStartM * lda; - const size_t PackedCountK = (K + KernelType::PackedK - 1) / KernelType::PackedK; - const MLAS_FP16* b = Data->B + RangeStartN * KernelType::PackedK * PackedCountK; - const MLAS_FP16* Bias = Data->Bias + RangeStartN; - auto* c = C; - - size_t RowsRemaining = RangeCountM; - while (RowsRemaining > 0) { - MlasHalfGemmKernel( - RowsRemaining, - RangeCountN, - K, - pa, - lda, - b, - 0, // packed B ldb not needed - c, - ldc, - Bias, - true); - - size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); - - if (Data->OutputProcessor != nullptr) { - Data->OutputProcessor->Process( - Data->C, - RangeStartM + RangeCountM - RowsRemaining, - RangeStartN + n, - RowsHandled, - RangeCountN, - Data->ldc); - } - - c += ldc * RowsHandled; - pa += lda * RowsHandled; - RowsRemaining -= RowsHandled; } } - - // // dispatch structure. 
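// A platform specific build points these entries at its hand written
// kernels; MlasHalfGemmGetDispatch() below currently returns only the
// portable MlasHalfGemmDispatchDefault.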
// @@ -335,6 +432,7 @@ MlasHalfGemmPackedOperation( typedef void (MLAS_HALF_GEMM_OPERATION)( + const size_t N, const size_t K, const MLAS_HALF_GEMM_DATA_PARAMS* Data, const size_t RangeStartM, @@ -354,10 +452,21 @@ void size_t CountK ); +typedef +void +(MLAS_HALF_GEMM_CONVERT_PACKB_ROUTINE)( + MLAS_FP16* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK + ); + struct MLAS_HALF_GEMM_DISPATCH { MLAS_HALF_GEMM_OPERATION* Operation; - MLAS_HALF_GEMM_OPERATION* PackedOperation; MLAS_HALF_GEMM_COPY_PACKB_ROUTINE* CopyPackBRoutine; + MLAS_HALF_GEMM_CONVERT_PACKB_ROUTINE* ConvertPackBRoutine; + size_t PackededK; size_t StrideM; }; From 6dfaef32a249b06b259f26466a47a66637dde88a Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Wed, 18 Jan 2023 10:45:57 -0800 Subject: [PATCH 03/19] Adding tests --- onnxruntime/core/mlas/inc/mlas.h | 9 +- onnxruntime/core/mlas/inc/mlas_float16.h | 100 ++++++ onnxruntime/core/mlas/lib/halfgemm.cpp | 90 ++--- onnxruntime/core/mlas/lib/halfgemm.h | 96 ++--- onnxruntime/core/mlas/lib/mlasi.h | 78 +--- .../test/mlas/unittest/test_halfgemm.cpp | 200 +++++++++++ .../test/mlas/unittest/test_halfgemm.h | 339 ++++++++++++++++++ 7 files changed, 746 insertions(+), 166 deletions(-) create mode 100644 onnxruntime/core/mlas/inc/mlas_float16.h create mode 100644 onnxruntime/test/mlas/unittest/test_halfgemm.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_halfgemm.h diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index fce9c17ba9465..b3e79bbb297fe 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -104,7 +104,6 @@ namespace onnxruntime { }; // namespace onnxruntime using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool; -using MLAS_FP16 = struct onnxruntime::MLFloat16; // @@ -1374,6 +1373,12 @@ MlasQLinearMul( // Half precision routines // +// Any type with size=2 should work +using MLAS_FP16 = onnxruntime::MLFloat16; + +constexpr size_t FP16_SIZE = sizeof(uint16_t); + + class MLAS_HALF_GEMM_OUTPUT_PROCESSOR { public: virtual @@ -1401,7 +1406,7 @@ struct MLAS_HALF_GEMM_DATA_PARAMS { const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */ MLAS_FP16* C = nullptr; /**< address of result matrix */ size_t lda = 0; /**< leading dimension of A */ - size_t ldb = 0; /**< leading dimension of B, 0 when B is packed*/ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ size_t ldc = 0; /**< leading dimension of C*/ const MLAS_HALF_GEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ diff --git a/onnxruntime/core/mlas/inc/mlas_float16.h b/onnxruntime/core/mlas/inc/mlas_float16.h new file mode 100644 index 0000000000000..a8d566677a126 --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_float16.h @@ -0,0 +1,100 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_float16.h + +Abstract: + + Utilities for half precision floating type conversions. Used internally + by MLAS on platforms without half precision support. Provided here as + convenience for tests or other client libraries/apps. 
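
    As a quick sanity check of the conversions: MLAS_Float2Half(1.0f)
    yields 0x3c00, MLAS_Half2Float(0x3c00) returns 1.0f, and the pair
    round-trips every fp16 representable value exactly.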
+ +--*/ + +#pragma once + +#include +#include +#include + + +using _mlas_fp16_ = uint16_t; + +union fp32_bits { + uint32_t u; + float f; +}; + +inline +_mlas_fp16_ +MLAS_Float2Half(float ff) +{ + constexpr fp32_bits f32infty = {255 << 23}; + constexpr fp32_bits f16max = {(127 + 16) << 23}; + constexpr fp32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr uint32_t sign_mask = 0x80000000u; + + auto val = static_cast(0x0u); + fp32_bits f; + f.f = ff; + + uint32_t sign = f.u & sign_mask; + f.u ^= sign; + + if (f.u >= f16max.u) { + // Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { + if (f.u < (113 << 23)) { + // Subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + uint32_t mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.u += ((uint32_t)(15 - 127) << 23) + 0xfff; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +inline +float +MLAS_Half2Float(_mlas_fp16_ val) +{ + constexpr fp32_bits magic = {113 << 23}; + constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + fp32_bits o; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + uint32_t exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? 
+ o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (val & 0x8000) << 16; // sign bit + return o.f; +} diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 1029371b1c67a..71124f26fb12f 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -16,6 +16,8 @@ Module Name: --*/ #include "mlasi.h" +#include "mlas_float16.h" + #include "halfgemm.h" #include @@ -116,7 +118,7 @@ MlasHalfGemmPackBSize( return 0; } const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - const size_t BytesRequired = N * AlignedK * sizeof(MLAS_FP16); + const size_t BytesRequired = N * AlignedK * FP16_SIZE; const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); @@ -134,7 +136,7 @@ MlasHalfGemmPackB( ) { const auto* dispatch = MlasHalfGemmGetDispatch(); - dispatch->CopyPackBRoutine((MLAS_FP16*)PackedB, B, ldb, N, K); + dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const _mlas_fp16_*)B, ldb, N, K); } void @@ -148,12 +150,12 @@ MlasHalfGemmConvertPackB( ) { const auto* dispatch = MlasHalfGemmGetDispatch(); - dispatch->ConvertPackBRoutine((MLAS_FP16*)PackedB, B, ldb, N, K); + dispatch->ConvertPackBRoutine((_mlas_fp16_*)PackedB, B, ldb, N, K); } // -// C++ implementation that runs very slowly +// Dummy C++ implementation that runs very slowly // struct MLAS_HALF_GEMM_KERNEL_DEFAULT { @@ -162,33 +164,14 @@ struct MLAS_HALF_GEMM_KERNEL_DEFAULT { static constexpr size_t KernelMaxM = 128; // max # rows the vectorized kernel can process static constexpr size_t PackedK = 1; - static constexpr MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; + static constexpr MLAS_HALF_GEMM_STRIDES Strides{8, 16, 32}; }; -template<> -MLAS_FORCEINLINE -void -MlasHalfGemmCopyPackB( - MLAS_FP16* D, - const MLAS_FP16* B, - size_t ldb, - size_t CountN, - size_t CountK - ) -{ - MLAS_UNREFERENCED_PARAMETER(D); - MLAS_UNREFERENCED_PARAMETER(B); - MLAS_UNREFERENCED_PARAMETER(ldb); - MLAS_UNREFERENCED_PARAMETER(CountN); - MLAS_UNREFERENCED_PARAMETER(CountK); - // No packing for fp16 B. leave it alone -} - template<> MLAS_FORCEINLINE void MlasHalfGemmConvertPackA( - MLAS_FP16* D, + _mlas_fp16_* D, const float* A, size_t lda, size_t CountM, @@ -197,8 +180,7 @@ MlasHalfGemmConvertPackA( { for (size_t m = 0; m < CountM; m++) { for (size_t k = 0; k < CountK; k++) { - new (D) MLAS_FP16(*(A + m * lda + k)); - D++; + *D++ = MLAS_Float2Half(*(A + m * lda + k)); } } } @@ -207,7 +189,7 @@ template<> MLAS_FORCEINLINE void MlasHalfGemmConvertPackB( - MLAS_FP16* D, + _mlas_fp16_* D, const float* B, size_t ldb, size_t CountN, @@ -216,8 +198,7 @@ MlasHalfGemmConvertPackB( { for (size_t k = 0; k < CountK; k++) { for (size_t n = 0; n < CountN; n++) { - new (D) MLAS_FP16(*(B + k * ldb + n)); - D++; + *D++ = MLAS_Float2Half(*(B + k * ldb + n)); } } } @@ -230,44 +211,35 @@ MlasHalfGemmKernel( size_t CountM, size_t CountN, size_t CountK, - const MLAS_FP16* A, + const _mlas_fp16_* A, size_t lda, - const MLAS_FP16* B, + const _mlas_fp16_* B, size_t ldb, - MLAS_FP16* C, + _mlas_fp16_* C, size_t ldc, - const MLAS_FP16* Bias, + const _mlas_fp16_* Bias, const bool ZeroMode) { - CountM = std::min(CountM, MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM); - while (CountM-- > 0) { - // - // Process a single column of matrix B in a loop. 
- // - const MLAS_FP16* bias = Bias; - const auto* b_col = B; - auto* c = C; - while (CountN-- > 0) { - const auto* a = A; - const auto* b = b_col; - - float Accumulator = bias->ToFloat(); - bias++; + for (size_t m = 0; m < CountM; m++) { + for (size_t n = 0; n < CountN; n++) { + const auto* a = A + (m * lda); + const auto* b = B + n; + auto* c = C + (m * ldc) + n; + + float sum = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[n]); + if (!ZeroMode) { + sum += MLAS_Half2Float(*c); + } + for (size_t k = 0; k < CountK; k++) { - Accumulator += a->ToFloat() * b->ToFloat(); - a++; + auto down = MLAS_Float2Half(MLAS_Half2Float(*a) * MLAS_Half2Float(*b) + sum); + sum = MLAS_Half2Float(down); b += ldb; + a += 1; } - if (!ZeroMode) { - Accumulator += c->ToFloat(); - } - new (c) MLAS_FP16(Accumulator); - c++; - b_col++; + *c = MLAS_Float2Half(sum); } - A += lda; - C += ldc; } } @@ -275,7 +247,7 @@ MlasHalfGemmKernel( const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, nullptr, - MlasHalfGemmConvertPackA, + MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM }; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 8cab5711e8c19..4b44f7c156058 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -10,8 +10,8 @@ Module Name: Abstract: - This module defines the set of template functions to implement a kernel of - half precision matrix/matrix multiply operation (QGEMM). + This module defines the set of template functions to implement half + precision matrix/matrix multiply operation (QGEMM). To implement a new kernel, template functions below need to be specialized: MlasHalfGemmCopyPackB @@ -63,8 +63,8 @@ template MLAS_FORCEINLINE void MlasHalfGemmCopyPackB( - MLAS_FP16* D, - const MLAS_FP16* B, + _mlas_fp16_* D, + const _mlas_fp16_* B, size_t ldb, size_t CountN, size_t CountK @@ -91,7 +91,7 @@ MlasHalfGemmCopyPackB( template void MlasHalfGemmConvertPackA( - MLAS_FP16* D, + _mlas_fp16_* D, const float* A, size_t lda, size_t CountM, @@ -111,7 +111,7 @@ MlasHalfGemmConvertPackA( template void MlasHalfGemmConvertPackB( - MLAS_FP16* D, + _mlas_fp16_* D, const float* B, size_t ldb, size_t CountN, @@ -119,21 +119,21 @@ MlasHalfGemmConvertPackB( ); /** - * @brief Find the location of [StartK, StartN] in packed B buffer + * @brief Find the location of PackedB[StartK, StartN] * * @tparam KernelType * @param PackedB - * @param DimN - * @param DimK + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer * @param StartN * @param StartK * @return Address of PackedB[StartK, StartN] */ template MLAS_FORCEINLINE -const MLAS_FP16* +const _mlas_fp16_* MlasHalfGemmPackedBOffset( - const MLAS_FP16* PackedB, + const _mlas_fp16_* PackedB, size_t DimN, size_t DimK, size_t StartN, @@ -173,13 +173,13 @@ MlasHalfGemmKernel( const size_t CountM, const size_t CountN, const size_t CountK, - const MLAS_FP16* A, + const _mlas_fp16_* A, const size_t lda, - const MLAS_FP16* B, + const _mlas_fp16_* B, const size_t ldb, - MLAS_FP16* C, + _mlas_fp16_* C, size_t ldc, - const MLAS_FP16* Bias, + const _mlas_fp16_* Bias, const bool ZeroMode ); @@ -206,22 +206,26 @@ MlasHalfGemmNoPackOperation( size_t ldb = Data->ldb; // 0 if prepacked const size_t ldc = Data->ldc; - const MLAS_FP16* pa = reinterpret_cast(Data->A) + RangeStartM * lda; - const MLAS_FP16* B; + const auto* pa = reinterpret_cast(Data->A) + + RangeStartM * lda; + const 
_mlas_fp16_* pb; if (ldb == 0) { - B = MlasHalfGemmPackedBOffset( - reinterpret_cast(Data->B), + pb = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), N, K, RangeStartN, 0); ldb = MlasHalfGemmPackedBLeadingDim(N, K); } else { - B = reinterpret_cast(Data->B) + RangeStartN; + pb = reinterpret_cast(Data->B) + RangeStartN; } - const MLAS_FP16* Bias = Data->Bias + RangeStartN; - MLAS_FP16* c = Data->C + RangeStartM * ldc + RangeStartN; + const _mlas_fp16_* Bias = (nullptr == Data->Bias) + ? nullptr + : reinterpret_cast(Data->Bias) + RangeStartN; + _mlas_fp16_* c = reinterpret_cast<_mlas_fp16_*>(Data->C) + + RangeStartM * ldc + RangeStartN; size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { @@ -231,7 +235,7 @@ MlasHalfGemmNoPackOperation( K, pa, lda, - B, + pb, ldb, c, ldc, @@ -274,7 +278,13 @@ MlasHalfGemmOperation( const size_t ldc = Data->ldc; if (!Data->AIsfp32 && (ldb == 0 || !KernelType::PackNeeded && !Data->BIsfp32)) { - // No packing needed, use a simpler driver instead + // !Data->AIsfp32 => A is fp16, no packing on the left hand side + // ldb == 0 => B is already packed, no packing on the right hand side + // !KernelType::PackNeeded && !Data->BIsfp32 => B is fp16 and the kernel + // does not require packing + // + // So no packing needed on either A or B, use a simpler driver instead + MlasHalfGemmNoPackOperation( N, K, @@ -286,21 +296,22 @@ MlasHalfGemmOperation( return; } - const MLAS_FP16* Bias = Data->Bias + RangeStartN; - MLAS_FP16* C = Data->C + RangeStartM * ldc + RangeStartN; + const auto* Bias = reinterpret_cast(Data->Bias); + _mlas_fp16_* C = reinterpret_cast<_mlas_fp16_*>(Data->C) + + RangeStartM * ldc + RangeStartN; // // Three dimensional tiling due to limited packing panel size // constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; - constexpr size_t packASize = UpAlignSize(Strides.M * Strides.K * sizeof(MLAS_FP16)); - constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(MLAS_FP16)); + constexpr size_t packASize = UpAlignSize(Strides.M * Strides.K * FP16_SIZE); + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * FP16_SIZE); MlasThreadedBufAlloc(packASize + packBSize); uint8_t* p = ThreadedBufHolder.get(); - MLAS_FP16* PanelA = reinterpret_cast(p); + auto* PanelA = reinterpret_cast<_mlas_fp16_*>(p); p += packASize; - MLAS_FP16* PanelB = reinterpret_cast(p); + auto* PanelB = reinterpret_cast<_mlas_fp16_*>(p); // // Step through each slice of matrix B along the K dimension. @@ -323,17 +334,18 @@ MlasHalfGemmOperation( // Copy a panel of matrix B to a local packed buffer. 
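            // (four possible sources, per the branches below: pre-packed B
            // when ldb == 0, a panel converted from fp32, a panel re-laid
            // out from fp16 when the kernel requires packing, or the fp16
            // source addressed in place)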
// size_t ld_pb; - const MLAS_FP16* pb; + const _mlas_fp16_* pb; if (ldb == 0) { // Already packed pb = MlasHalfGemmPackedBOffset( - reinterpret_cast(Data->B), + reinterpret_cast(Data->B), N, K, RangeStartN + n, k); ld_pb = MlasHalfGemmPackedBLeadingDim(N, K); } else if (Data->BIsfp32) { + // fp32, need conversion and packing MlasHalfGemmConvertPackB( PanelB, reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, @@ -343,9 +355,10 @@ MlasHalfGemmOperation( pb = PanelB; ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); } else if (KernelType::PackNeeded) { + // fp16, need packing MlasHalfGemmCopyPackB( PanelB, - reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, ldb, CountN, CountK); @@ -353,7 +366,7 @@ MlasHalfGemmOperation( ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); } else { // fp16, and no packing needed - pb = reinterpret_cast(Data->B) + ldb * k + RangeStartN + n; + pb = reinterpret_cast(Data->B) + ldb * k + RangeStartN + n; ld_pb = ldb; } @@ -361,7 +374,8 @@ MlasHalfGemmOperation( // Step through each slice of matrix A along the M dimension. // - MLAS_FP16* c = C + n; + auto* c = C + n; + const auto* pbias = (nullptr == Bias) ? nullptr : Bias + RangeStartN + n; size_t CountM; for (size_t m = 0; m < RangeCountM; m += CountM) { CountM = std::min(RangeCountM - m, Strides.M); @@ -369,7 +383,7 @@ MlasHalfGemmOperation( // // Copy a panel of matrix A to a local packed buffer. // - const MLAS_FP16* pa; + const _mlas_fp16_* pa; size_t ld_pa; if (Data->AIsfp32) { MlasHalfGemmConvertPackA( @@ -381,7 +395,7 @@ MlasHalfGemmOperation( pa = PanelA; ld_pa = KernelType::PackedK * PackedCountK; } else { - pa = reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k; + pa = reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k; ld_pa = lda; } @@ -400,7 +414,7 @@ MlasHalfGemmOperation( ld_pb, c, ldc, - Bias, + ZeroMode ? pbias : nullptr, ZeroMode); size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); @@ -445,8 +459,8 @@ void typedef void (MLAS_HALF_GEMM_COPY_PACKB_ROUTINE)( - MLAS_FP16* D, - const MLAS_FP16* B, + _mlas_fp16_* D, + const _mlas_fp16_* B, size_t ldb, size_t CountN, size_t CountK @@ -455,7 +469,7 @@ void typedef void (MLAS_HALF_GEMM_CONVERT_PACKB_ROUTINE)( - MLAS_FP16* D, + _mlas_fp16_* D, const float* B, size_t ldb, size_t CountN, diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 32eb29b107b38..051e8c0352b6a 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -181,81 +181,29 @@ enum MlasUArch { #endif // MLAS_TARGET_ARM64 -union fp32_bits { - uint32_t u; - float f; -}; +// +// Define MLAS_FP16 +// +#include "mlas_float16.h" namespace onnxruntime { -// MLFloat16 struct MLFloat16 { uint16_t val{0}; MLFloat16() = default; explicit constexpr MLFloat16(uint16_t x) : val(x) {} - explicit MLFloat16(float ff) { - constexpr fp32_bits f32infty = {255 << 23}; - constexpr fp32_bits f16max = {(127 + 16) << 23}; - constexpr fp32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; - constexpr uint32_t sign_mask = 0x80000000u; - - val = static_cast(0x0u); - fp32_bits f; f.f = ff; - - uint32_t sign = f.u & sign_mask; - f.u ^= sign; - - if (f.u >= f16max.u) { - // Inf or NaN (all exponent bits set) - val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf - } else { - if (f.u < (113 << 23)) { - // Subnormal or zero - // use a magic value to align our 10 mantissa bits at the bottom of - // the float. 
as long as FP addition is round-to-nearest-even this
-                // just works.
-                f.f += denorm_magic.f;
-
-                // and one integer subtract of the bias later, we have our final float!
-                val = static_cast<uint16_t>(f.u - denorm_magic.u);
-            } else {
-                uint32_t mant_odd = (f.u >> 13) & 1;  // resulting mantissa is odd
-
-                // update exponent, rounding bias part 1
-                f.u += ((uint32_t)(15 - 127) << 23) + 0xfff;
-                // rounding bias part 2
-                f.u += mant_odd;
-                // take the bits!
-                val = static_cast<uint16_t>(f.u >> 13);
-            }
-        }
+    explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {}
 
-        val |= static_cast<uint16_t>(sign >> 16);
-    }
+    float ToFloat() const { return MLAS_Half2Float(val); }
 
-    float ToFloat() const {
-        constexpr fp32_bits magic = {113 << 23};
-        constexpr uint32_t shifted_exp = 0x7c00 << 13;  // exponent mask after shift
-        fp32_bits o;
-
-        o.u = (val & 0x7fff) << 13;  // exponent/mantissa bits
-        uint32_t exp = shifted_exp & o.u;  // just the exponent
-        o.u += (127 - 15) << 23;  // exponent adjust
-
-        // handle exponent special cases
-        if (exp == shifted_exp) {  // Inf/NaN?
-            o.u += (128 - 16) << 23;  // extra exp adjust
-        } else if (exp == 0) {  // Zero/Denormal?
-            o.u += 1 << 23;  // extra exp adjust
-            o.f -= magic.f;  // renormalize
-        }
+    operator float() const { return ToFloat(); }
 
-        o.u |= (val & 0x8000) << 16;  // sign bit
-        return o.f;
+    MLFloat16& operator=(float ff)
+    {
+        val = MLAS_Float2Half(ff);
+        return *this;
     }
-
-    operator float() const { return ToFloat(); }
 };
 
 inline bool
@@ -270,10 +218,12 @@ operator!=(const MLFloat16& left, const MLFloat16& right)
 {
     return left.val != right.val;
 }
 
-}  // namespace onnxruntime
+}
 
 #endif  // BUILD_MLAS_NO_ONNXRUNTIME
 
+static_assert(sizeof(MLAS_FP16) == FP16_SIZE);
+
 //
 // Define the maximum number of threads supported by this implementation.
 
diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp
new file mode 100644
index 0000000000000..a275dabe7df4d
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp
@@ -0,0 +1,200 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    test_halfgemm.cpp
+
+Abstract:
+
+    Tests for MLAS half precision GEMM.
+
+--*/
+
+
+#include "test_halfgemm.h"
+
+//
+// Short Execute() test helper to register each test separately by all parameters.
+//
+template <typename AType, typename BType, bool Packed, bool Threaded>
+class HalfGemmShortExecuteTest : public MlasTestFixture<MlasHalfGemmTest<AType, BType, Packed, Threaded>> {
+ public:
+  explicit HalfGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias)
+      : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {}
+
+  void TestBody() override {
+    MlasTestFixture<MlasHalfGemmTest<AType, BType, Packed, Threaded>>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_);
+  }
+
+  static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) {
+    std::stringstream ss;
+    ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/"
+       << "hasBias" << hasBias;
+    auto test_name = ss.str();
+
+    testing::RegisterTest(
+        MlasHalfGemmTest<AType, BType, Packed, Threaded>::GetTestSuiteName(),
+        test_name.c_str(),
+        nullptr,
+        test_name.c_str(),
+        __FILE__,
+        __LINE__,
+        // Important to use the fixture type as the return type here.
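(Background on the comment above: googletest derives a registered test's fixture identity from the factory's static return type and requires all tests in a suite to share one fixture, so the factory lambda below returns the MlasTestFixture specialization rather than the derived HalfGemmShortExecuteTest type.)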
+        [=]() -> MlasTestFixture<MlasHalfGemmTest<AType, BType, Packed, Threaded>>* {
+          return new HalfGemmShortExecuteTest<AType, BType, Packed, Threaded>(
+              M, N, K, Batch, hasBias);
+        });
+
+    return 1;
+  }
+
+  static size_t RegisterShortExecuteTests() {
+    size_t test_registered = 0;
+
+    for (size_t b = 1; b < 16; b++) {
+      test_registered += RegisterSingleTest(b, b, b, 1, false);
+      test_registered += RegisterSingleTest(b, b, b, 1, true);
+    }
+    for (size_t b = 16; b <= 256; b <<= 1) {
+      test_registered += RegisterSingleTest(b, b, b, 1, false);
+      test_registered += RegisterSingleTest(b, b, b, 1, true);
+    }
+    for (size_t b = 256; b < 320; b += 32) {
+      test_registered += RegisterSingleTest(b, b, b, 1, true);
+    }
+    for (size_t b = 1; b < 96; b++) {
+      test_registered += RegisterSingleTest(1, b, 32, 1, false);
+      test_registered += RegisterSingleTest(1, 32, b, 1, true);
+      test_registered += RegisterSingleTest(1, b, b, 1, false);
+      if (!Packed) {
+        test_registered += RegisterSingleTest(1, b, 32, 3, true);
+        test_registered += RegisterSingleTest(1, 32, b, 5, false);
+      }
+    }
+    test_registered += RegisterSingleTest(43, 500, 401, 1, true);
+    test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false);
+    if (!Packed) {
+      test_registered += RegisterSingleTest(43, 500, 401, 5, true);
+      test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false);
+    }
+
+    return test_registered;
+  }
+
+ private:
+  size_t M_, N_, K_, Batch_;
+  bool hasBias_;
+};
+
+
+template <>
+MlasHalfGemmTest<MLFp16, MLFp16, false, false>* MlasTestFixture<MlasHalfGemmTest<MLFp16, MLFp16, false, false>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, MLFp16, false, false>* MlasTestFixture<MlasHalfGemmTest<float, MLFp16, false, false>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<MLFp16, float, false, false>* MlasTestFixture<MlasHalfGemmTest<MLFp16, float, false, false>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, float, false, false>* MlasTestFixture<MlasHalfGemmTest<float, float, false, false>>::mlas_tester(nullptr);
+
+template <>
+MlasHalfGemmTest<MLFp16, MLFp16, true, false>* MlasTestFixture<MlasHalfGemmTest<MLFp16, MLFp16, true, false>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, MLFp16, true, false>* MlasTestFixture<MlasHalfGemmTest<float, MLFp16, true, false>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<MLFp16, float, true, false>* MlasTestFixture<MlasHalfGemmTest<MLFp16, float, true, false>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, float, true, false>* MlasTestFixture<MlasHalfGemmTest<float, float, true, false>>::mlas_tester(nullptr);
+
+template <>
+MlasHalfGemmTest<MLFp16, MLFp16, false, true>* MlasTestFixture<MlasHalfGemmTest<MLFp16, MLFp16, false, true>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, MLFp16, false, true>* MlasTestFixture<MlasHalfGemmTest<float, MLFp16, false, true>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<MLFp16, float, false, true>* MlasTestFixture<MlasHalfGemmTest<MLFp16, float, false, true>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, float, false, true>* MlasTestFixture<MlasHalfGemmTest<float, float, false, true>>::mlas_tester(nullptr);
+
+template <>
+MlasHalfGemmTest<MLFp16, MLFp16, true, true>* MlasTestFixture<MlasHalfGemmTest<MLFp16, MLFp16, true, true>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, MLFp16, true, true>* MlasTestFixture<MlasHalfGemmTest<float, MLFp16, true, true>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<MLFp16, float, true, true>* MlasTestFixture<MlasHalfGemmTest<MLFp16, float, true, true>>::mlas_tester(nullptr);
+template <>
+MlasHalfGemmTest<float, float, true, true>* MlasTestFixture<MlasHalfGemmTest<float, float, true, true>>::mlas_tester(nullptr);
+
+static size_t HalfGemmRegistLongExecute() {
+  size_t count = 0;
+
+  count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, MLFp16, false, false>>::RegisterLongExecute();
+  count += MlasLongExecuteTests<MlasHalfGemmTest<float, MLFp16, false, false>>::RegisterLongExecute();
+  count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, float, false, false>>::RegisterLongExecute();
+  count += MlasLongExecuteTests<MlasHalfGemmTest<float, float, false, false>>::RegisterLongExecute();
+  if (MlasHalfGemmPackBSize(128, 128, false) > 0) {
+    count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, MLFp16, true, false>>::RegisterLongExecute();
+    count += MlasLongExecuteTests<MlasHalfGemmTest<float, MLFp16, true, false>>::RegisterLongExecute();
+  }
+  if (MlasHalfGemmPackBSize(128, 128, true) > 0) {
+    count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, float, true, false>>::RegisterLongExecute();
+    count += MlasLongExecuteTests<MlasHalfGemmTest<float, float, true, false>>::RegisterLongExecute();
+  }
+
+  if (GetMlasThreadPool() != nullptr) {
+    count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, MLFp16, false, true>>::RegisterLongExecute();
+    count += MlasLongExecuteTests<MlasHalfGemmTest<float, MLFp16, false, true>>::RegisterLongExecute();
+    count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, float, false, true>>::RegisterLongExecute();
+    count += MlasLongExecuteTests<MlasHalfGemmTest<float, float, false, true>>::RegisterLongExecute();
+    if (MlasHalfGemmPackBSize(128, 128, false) > 0) {
+      count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, MLFp16, true, true>>::RegisterLongExecute();
+      count += MlasLongExecuteTests<MlasHalfGemmTest<float, MLFp16, true, true>>::RegisterLongExecute();
+    }
+    if (MlasHalfGemmPackBSize(128, 128, true) > 0) {
+      count += MlasLongExecuteTests<MlasHalfGemmTest<MLFp16, float, true, true>>::RegisterLongExecute();
+      count += MlasLongExecuteTests<MlasHalfGemmTest<float, float, true, true>>::RegisterLongExecute();
+    }
+  }
+
+  return count;
+}
+
+static size_t HalfGemmRegistShortExecute() {
+  size_t count = 0;
+
+  count += HalfGemmShortExecuteTest<MLFp16, MLFp16, false, false>::RegisterShortExecuteTests();
+  count += HalfGemmShortExecuteTest<float, MLFp16, false, false>::RegisterShortExecuteTests();
+  count += HalfGemmShortExecuteTest<MLFp16, float, false, false>::RegisterShortExecuteTests();
+  count += HalfGemmShortExecuteTest<float, float, false, false>::RegisterShortExecuteTests();
+  if (MlasHalfGemmPackBSize(128, 128, false) > 0) {
+    count += HalfGemmShortExecuteTest<MLFp16, MLFp16, true, false>::RegisterShortExecuteTests();
+    count += HalfGemmShortExecuteTest<float, MLFp16, true, false>::RegisterShortExecuteTests();
+  }
+  if (MlasHalfGemmPackBSize(128, 128, true) > 0) {
+    count += HalfGemmShortExecuteTest<MLFp16, float, true, false>::RegisterShortExecuteTests();
+    count += HalfGemmShortExecuteTest<float, float, true, false>::RegisterShortExecuteTests();
+  }
+
+  if (GetMlasThreadPool() != nullptr) {
+    count += HalfGemmShortExecuteTest<MLFp16, MLFp16, false, true>::RegisterShortExecuteTests();
+    count += HalfGemmShortExecuteTest<float, MLFp16, false, true>::RegisterShortExecuteTests();
+    count += HalfGemmShortExecuteTest<MLFp16, float, false, true>::RegisterShortExecuteTests();
+    count += HalfGemmShortExecuteTest<float, float, false, true>::RegisterShortExecuteTests();
+    if (MlasHalfGemmPackBSize(128, 128, false) > 0) {
+      count += HalfGemmShortExecuteTest<MLFp16, MLFp16, true, true>::RegisterShortExecuteTests();
+      count += HalfGemmShortExecuteTest<float, MLFp16, true, true>::RegisterShortExecuteTests();
+    }
+    if (MlasHalfGemmPackBSize(128, 128, true) > 0) {
+      count += HalfGemmShortExecuteTest<MLFp16, float, true, true>::RegisterShortExecuteTests();
+      count += HalfGemmShortExecuteTest<float, float, true, true>::RegisterShortExecuteTests();
+    }
+  }
+
+  return count;
+}
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+  return is_short_execute ? HalfGemmRegistShortExecute() : HalfGemmRegistLongExecute();
+});
diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h
new file mode 100644
index 0000000000000..17435e06498a6
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h
@@ -0,0 +1,339 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    test_halfgemm.h
+
+Abstract:
+
+    Tests for MLAS half precision GEMM.
+
+--*/
+
+#pragma once
+
+#include "test_util.h"
+#include "mlas_float16.h"
+
+
+//
+// Define our own fp16 type to avoid dragging in big dependencies
+//
+struct MLFp16 {
+  uint16_t val{0};
+
+  MLFp16() = default;
+  explicit constexpr MLFp16(uint16_t x) : val(x) {}
+  explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {}
+
+  float ToFloat() const {
+    return MLAS_Half2Float(val);
+  }
+
+  operator float() const { return ToFloat(); }
+
+  MLFp16& operator=(float ff) {
+    val = MLAS_Float2Half(ff);
+    return *this;
+  }
+};
+
+inline bool
+operator==(const MLFp16& left, const MLFp16& right) {
+  return left.val == right.val;
+}
+
+inline bool
+operator!=(const MLFp16& left, const MLFp16& right) {
+  return left.val != right.val;
+}
+
+//
+// Customize buffer fill for half precision buffer
+//
+template <>
+MLFp16*
+MatrixGuardBuffer<MLFp16>::GetBuffer(size_t Elements, bool ZeroFill) {
+  //
+  // Check if the internal buffer needs to be reallocated.
+  //
+
+  if (Elements > _ElementsAllocated) {
+    ReleaseBuffer();
+
+    //
+    // Reserve a virtual address range for the allocation plus an unmapped
+    // guard region.
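The allocation below realizes that guard region; schematically, the resulting layout is:

    //  _BaseBuffer                    _GuardAddress = _BaseBuffer + BytesToAllocate
    //  |----- committed R/W pages ----|-------- unmapped guard pages --------|
    //
    // GetBuffer() hands back _GuardAddress - Elements, so the buffer ends flush
    // against the guard region and any read or write past the end faults
    // immediately instead of silently corrupting memory.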
+    //
+
+    constexpr size_t BufferAlignment = 64 * 1024;
+    constexpr size_t GuardPadding = 256 * 1024;
+
+    size_t BytesToAllocate = ((Elements * FP16_SIZE) + BufferAlignment - 1) & ~(BufferAlignment - 1);
+
+    _BaseBufferSize = BytesToAllocate + GuardPadding;
+
+#if defined(_WIN32)
+    _BaseBuffer = VirtualAlloc(NULL, _BaseBufferSize, MEM_RESERVE, PAGE_NOACCESS);
+#else
+    _BaseBuffer = mmap(0, _BaseBufferSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+#endif
+
+    if (_BaseBuffer == nullptr) {
+      abort();
+    }
+
+    //
+    // Commit the number of bytes for the allocation leaving the upper
+    // guard region as unmapped.
+    //
+
+#if defined(_WIN32)
+    if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) {
+      ORT_THROW_EX(std::bad_alloc);
+    }
+#else
+    if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) {
+      abort();
+    }
+#endif
+
+    _ElementsAllocated = BytesToAllocate / FP16_SIZE;
+    _GuardAddress = (MLFp16*)((unsigned char*)_BaseBuffer + BytesToAllocate);
+  }
+
+
+  auto* GuardAddress = _GuardAddress;
+  auto* buffer = GuardAddress - Elements;
+
+  if (ZeroFill) {
+    std::fill_n(buffer, Elements, MLFp16());
+  } else {
+    constexpr float MinimumFillValue = -11.0f;
+    constexpr float MaximumFillValue = 11.0f;
+
+    float FillValue = MinimumFillValue;
+    auto* FillAddress = buffer;
+
+    while (FillAddress < GuardAddress) {
+      *FillAddress++ = FillValue / 16.0f;
+
+      FillValue += 1.0f;
+
+      if (FillValue > MaximumFillValue) {
+        FillValue = MinimumFillValue;
+      }
+    }
+  }
+
+  return buffer;
+}
+
+
+/**
+ * @brief Test class for half precision GEMM
+ * @tparam AType  Data type of A matrix, can be either float or MLFp16
+ * @tparam BType  Data type of B matrix, can be either float or MLFp16
+ */
+template <typename AType, typename BType, bool Packed, bool Threaded>
+class MlasHalfGemmTest : public MlasTestBase {
+ private:
+  MatrixGuardBuffer<uint8_t> BufferBPacked;
+  MatrixGuardBuffer<AType> BufferA;
+  MatrixGuardBuffer<BType> BufferB;
+  MatrixGuardBuffer<MLFp16> BufferBias;
+  MatrixGuardBuffer<MLFp16> BufferC;
+  MatrixGuardBuffer<float> BufferCReference;
+  MLAS_THREADPOOL* threadpool_;
+
+  void* PackB(size_t N, size_t K, const BType* B, size_t ldb) {
+    size_t PackedBSize = MlasHalfGemmPackBSize(N, K, std::is_same<BType, float>::value);
+    if (PackedBSize == 0) {
+      return nullptr;
+    }
+    void* PackedB = BufferBPacked.GetBuffer(PackedBSize);
+    if (std::is_same<BType, float>::value) {
+      MlasHalfGemmConvertPackB(N, K, (const float*)B, ldb, PackedB);
+    } else {
+      MlasHalfGemmPackB(N, K, (const MLAS_FP16*)B, ldb, PackedB);
+    }
+    return PackedB;
+  }
+
+  void CallGemm(size_t M,
+                size_t N,
+                size_t K,
+                size_t BatchSize,
+                const AType* A,
+                size_t lda,
+                const BType* B,
+                size_t ldb,
+                const MLFp16* Bias,
+                MLFp16* C,
+                size_t ldc) {
+    std::vector<MLAS_HALF_GEMM_DATA_PARAMS> GemmParameters(BatchSize);
+
+    for (size_t i = 0; i < GemmParameters.size(); i++) {
+      auto& params = GemmParameters[i];
+      params.A = A + (M * lda * i);
+      params.lda = lda;
+      if (nullptr != Bias) {
+        params.Bias = reinterpret_cast<const MLAS_FP16*>(Bias + N * i);
+      } else {
+        params.Bias = nullptr;
+      }
+      params.C = reinterpret_cast<MLAS_FP16*>(C + (M * ldc * i));
+      params.ldc = ldc;
+
+      if (Packed) {
+        ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!";
+        params.B = PackB(N, K, B, ldb);
+        params.ldb = 0;
+      } else {
+        params.B = B + (K * N * i);
+        params.ldb = ldb;
+      }
+      params.AIsfp32 = std::is_same<AType, float>::value;
+      params.BIsfp32 = std::is_same<BType, float>::value;
+    }
+
+    MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_);
+  }
+
+  void ReferenceQgemm(size_t M,
+                      size_t N,
+                      size_t K,
+                      size_t BatchSize,
+                      const AType* A,
+                      const BType* B,
+                      const MLFp16* Bias,
+                      float*
C) {
+    for (size_t batch = 0; batch < BatchSize; batch++) {
+      for (size_t m = 0; m < M; m++) {
+        for (size_t n = 0; n < N; n++) {
+          const AType* a = A + M * K * batch + m * K;
+          const BType* b = B + K * N * batch + n;
+          float* c = C + (M * N * batch) + (m * N) + n;
+          float sum = Bias == nullptr ? 0.0f : float(Bias[n]);
+
+          for (size_t k = 0; k < K; k++) {
+            MLFp16 down(float(*b) * float(*a) + sum);
+            sum = float(down);
+            b += N;
+            a += 1;
+          }
+
+          *c = sum;
+        }
+      }
+      if (Bias) {
+        Bias += N;
+      }
+    }
+  }
+
+ public:
+  MlasHalfGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {}
+
+  void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) {
+    const AType* A = BufferA.GetBuffer(K * M * BatchSize);
+    const BType* B = BufferB.GetBuffer(N * K * BatchSize);
+    const MLFp16* Bias = withBias ? BufferBias.GetBuffer(N * BatchSize) : nullptr;
+    MLFp16* C = BufferC.GetBuffer(N * M * BatchSize);
+    float* CReference = BufferCReference.GetBuffer(N * M * BatchSize);
+
+    std::fill_n(CReference, M * N * BatchSize, float(-1.0));
+
+    this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N);
+    ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference);
+
+    for (size_t batch = 0, f = 0; batch < BatchSize; batch++) {
+      for (size_t m = 0; m < M; m++) {
+        for (size_t n = 0; n < N; n++, f++) {
+          ASSERT_EQ(float(C[f]), CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], "
+                                                << "Batch=" << BatchSize << ", M=" << M << ", N=" << N << ", K=" << K;
+        }
+      }
+    }
+  }
+
+ private:
+
+ public:
+  static const char* GetTestSuiteName() {
+    static std::string suite_name = std::string("HalfGemmFP") +
+                                    (std::is_same<AType, float>::value ? "32" : "16") +
+                                    (std::is_same<BType, float>::value ? "32" : "16") +
+                                    (Packed ? "_Packed" : "_NoPack") +
+                                    (Threaded ? "_Threaded" : "_SingleThread");
+    return suite_name.c_str();
+  }
+
+  void ExecuteLong(void) override {
+    for (size_t M = 16; M < 160; M += 32) {
+      for (size_t N = 16; N < 160; N += 32) {
+        static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320};
+        for (size_t k = 0; k < _countof(ks); k++) {
+          size_t K = ks[k];
+
+          Test(M, N, K, 1, false);
+          Test(M, N, K, 1, true);
+          Test(M + 1, N, K, 1, false);
+          Test(M, N + 1, K, 1, true);
+          Test(M + 1, N + 1, K, 1, false);
+          Test(M + 3, N + 2, K, 1, true);
+          Test(M + 4, N, K, 1, false);
+          Test(M, N + 4, K, 1, true);
+          Test(M + 4, N + 4, K, 1, false);
+          Test(M + 3, N + 7, K, 1, true);
+          Test(M + 8, N, K, 1, false);
+          Test(M, N + 8, K, 1, true);
+          Test(M + 12, N + 12, K, 1, false);
+          Test(M + 13, N, K, 1, true);
+          Test(M, N + 15, K, 1, false);
+          Test(M + 15, N + 15, K, 1, false);
+          if (!Packed) {
+            Test(M, N, K, 7, false);
+            Test(M + 3, N, K, 8, true);
+            Test(M, N + 1, K, 9, false);
+            Test(M + 12, N, K, 10, true);
+            Test(M, N + 15, K, 11, false);
+            Test(M + 15, N + 15, K, 12, true);
+          }
+        }
+      }
+      printf("M %zu\n", M);
+    }
+
+    for (size_t M = 1; M < 160; M++) {
+      for (size_t N = 1; N < 160; N++) {
+        for (size_t K = 1; K < 160; K++) {
+          Test(M, N, K, 1, true);
+        }
+      }
+      printf("M %zu\n", M);
+    }
+
+    for (size_t M = 160; M < 320; M += 24) {
+      for (size_t N = 112; N < 320; N += 24) {
+        for (size_t K = 1; K < 16; K++) {
+          Test(M, N, K, 1, true);
+        }
+        for (size_t K = 16; K < 160; K += 32) {
+          Test(M, N, K, 1, false);
+        }
+      }
+      printf("M %zu\n", M);
+    }
+  }
+};

From 9aca265e9610b1d84d8f8440c305a40a54a79c43 Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Tue, 24 Jan 2023 15:04:02 -0800
Subject: [PATCH 04/19] neon kernel msvc

---
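As background for the kernel work in this patch, this is roughly how the interface exercised by CallGemm above is driven for a single batch. A sketch only, assuming fp16 A, B, and C with no pre-packing; the AIsfp32/BIsfp32 flags were added to MLAS_HALF_GEMM_DATA_PARAMS earlier in this series, and RunHalfGemm is a hypothetical helper, not part of the patch:

    #include "mlas.h"

    // A is M x K, B is K x N, C is M x N, all row-major fp16.
    void RunHalfGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C,
                     size_t M, size_t N, size_t K) {
        MLAS_HALF_GEMM_DATA_PARAMS params;
        params.A = A;
        params.lda = K;
        params.B = B;
        params.ldb = N;            // nonzero ldb means B is not pre-packed
        params.C = C;
        params.ldc = N;
        params.Bias = nullptr;
        params.AIsfp32 = false;    // inputs are already fp16
        params.BIsfp32 = false;
        MlasHalfGemmBatch(M, N, K, /*BatchN=*/1, &params, /*ThreadPool=*/nullptr);
    }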
cmake/onnxruntime_mlas.cmake | 4 + .../mlas/lib/arm64/HalfGemmKernelNeon.asm | 426 ++++++++++++++++++ onnxruntime/core/mlas/lib/halfgemm.cpp | 6 +- onnxruntime/core/mlas/lib/halfgemm.h | 27 +- .../core/mlas/lib/halfgemm_kernel_neon.cpp | 130 ++++++ 5 files changed, 581 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm create mode 100644 onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 267d85b855f7c..720694938afdf 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -60,6 +60,7 @@ function(setup_mlas_source_for_windows) if(onnxruntime_target_platform STREQUAL "ARM64") target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp @@ -74,6 +75,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm @@ -306,6 +308,7 @@ else() ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S @@ -315,6 +318,7 @@ else() ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm new file mode 100644 index 0000000000000..bebf862b2e9f0 --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -0,0 +1,426 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + HalfGemmKernelNeon.asm + +Abstract: + + This module implements the kernels for the half precision matrix/matrix + multiply operation (HALF GEMM). + +--*/ + +#include "kxarm64.h" + +// +// Stack frame layout for the half gemm kernel. +// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. +// + +#define HGemmKernelFrame_SavedRegs (2 * 8) +#define HGemmKernelFrame_B 0 + HGemmKernelFrame_SavedRegs +#define HGemmKernelFrame_ldb 8 + HGemmKernelFrame_SavedRegs +#define HGemmKernelFrame_ZeroMode 16 + HGemmKernelFrame_SavedRegs + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute 6 rows of GEMM + +Arguments: + + CountM - (x0) the number of rows for matrix A and matrix C. + only process 6 rows + + CountN - (x1) the number of columns from matrix B and matrix C + + CountK - (x2/x0) the number of columns from matrix A and the + number of rows from matrix B. + + C - (x3) the address of matrix C. + + ldc - (x4) - the first dimension of matrix C. 
+
+    Bias - (x5) - the address of the Bias vector (optional)
+
+    A - (x6) - the address of matrix A
+
+    lda - (x7) - the first dimension of matrix A
+
+    B - the address of matrix B
+
+    ldb - the first dimension of matrix B
+
+    ZeroMode - true if the output matrix must be zero initialized, else
+        the output matrix is accumulated into
+
+--*/
+
+        LEAF_ENTRY MlasHalfGemmKernelNeon
+
+        PROLOG_SAVE_REG x19,#-HGemmKernelFrame_SavedRegs!
+        lsl     x2,x2,#1                // k *= sizeof(fp16)
+        ldr     x8,[sp,#HGemmKernelFrame_B]
+        ldr     x15,[sp,#HGemmKernelFrame_ldb]
+        CMP     x0, 2                   // if M < 2
+        ADD     x9, x6, x7              // a1 = a0 + lda
+        ADD     x16, x3, x4             // c1 = c0 + ldc
+        CSEL    x9, x6, x9, LO          //     a1 = a0
+        CSEL    x16, x3, x16, LO        //     c1 = c0
+        ADD     x10, x9, x7             // a2 = a1 + lda
+        ADD     x17, x16, x4            // c2 = c1 + ldc
+        CSEL    x10, x9, x10, LS        // if M <= 2  a2 = a1
+        CSEL    x17, x16, x17, LS       //            c2 = c1
+        CMP     x0, 4                   // if M < 4
+        ADD     x11, x10, x7            // a3 = a2 + lda
+        ADD     x14, x17, x4            // c3 = c2 + ldc
+        CSEL    x11, x10, x11, LO       //     a3 = a2
+        CSEL    x14, x17, x14, LO       //     c3 = c2
+        ADD     x12, x11, x7            // a4 = a3 + lda
+        ADD     x13, x14, x4            // c4 = c3 + ldc
+        CSEL    x12, x11, x12, LS       // if M <= 4  a4 = a3
+        CSEL    x13, x14, x13, LS       //            c4 = c3
+        CMP     x0, 6                   // if M < 6
+        ADD     x7, x12, x7             // a5 = a4 + lda
+        ADD     x4, x13, x4             // c5 = c4 + ldc
+        CSEL    x7, x12, x7, LO         //     a5 = a4
+        CSEL    x4, x13, x4, LO         //     c5 = c4
+        sub     x15,x15,16
+        ldrb    w19,[sp,#HGemmKernelFrame_ZeroMode]
+
+/****
+Main loop processes 6x16 tile, depth 4.
+                                       B 4x16
+                        ---------------------------------------
+                        |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]|  x8
+                        |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]|  x8
+                        |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]|  x8
+                        |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]|  x8
+          A 6x4         ---------------------------------------
+    ------------------  ---------------------------------------
+x6  |v0.h[0]..v0.h[3]|  |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]|  x3
+x9  |v1.h[0]..v1.h[3]|  |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]|  x16
+x10 |v2.h[0]..v2.h[3]|  |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]|  x17
+x11 |v3.h[0]..v3.h[3]|  |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]|  x14
+x12 |v4.h[0]..v4.h[3]|  |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]|  x13
+x7  |v5.h[0]..v5.h[3]|  |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]|  x4
+    ------------------  ---------------------------------------
+****/
+
+M6N16OutterLoopN
+    cbz     x5, M6N16SkipBias
+    ldp     q20,q21,[x5],32     // Load 16 Bias values
+    b       M6N16PopulateAccumulators
+
+M6N16SkipBias
+    eor     q20.16b,q20.16b,q20.16b     // No bias, reset regs
+    eor     q21.16b,q21.16b,q21.16b
+
+M6N16PopulateAccumulators
+    MOV     v22.16b, v20.16b
+    MOV     v23.16b, v21.16b
+    MOV     v24.16b, v20.16b
+    MOV     v25.16b, v21.16b
+    MOV     v26.16b, v20.16b
+    MOV     v27.16b, v21.16b
+    MOV     v28.16b, v20.16b
+    subs    x0,x2,8             // k -= 4 (8 bytes)
+    MOV     v29.16b, v21.16b
+    MOV     v30.16b, v20.16b
+    MOV     v31.16b, v21.16b
+    b.lo    M6N16RemainderK123  // remaining k 1~3
+
+    ldr     d0,[x6],8           // A0
+    ldr     q16,[x8],16         // B0.l
+    ld1     {v17.16b},[x8],x15  // B0.high  x8 <- next row
+    subs    x0,x0,8             // over-decrement k -= 4 (8 bytes)
+    ldr     d1,[x9],8           // A1
+    ldr     d2,[x10],8          // A2
+    ldr     d3,[x11],8          // A3
+    b.lo    M6N16LoopK_Epilogue // need k>=8 for main loop
+
+M6N16InnerLoopK
+    FMLA    v20.8h, v16.8h, v0.h[0]
+    FMLA    v21.8h, v17.8h, v0.h[0]
+    LDR     d4, [x12], 8        // A4
+    FMLA    v22.8h, v16.8h, v1.h[0]
+    FMLA    v23.8h, v17.8h, v1.h[0]
+    LDR     d5, [x7], 8         // A5
+    FMLA    v24.8h, v16.8h, v2.h[0]
+    FMLA    v25.8h, v17.8h, v2.h[0]
+    ldr     q18,[x8],16         // B1.low
+    FMLA    v26.8h, v16.8h, v3.h[0]
+    FMLA    v27.8h, v17.8h, v3.h[0]
+    ld1     {v19.16b},[x8],x15  // B1.high x8 <- next row
+    FMLA    v28.8h, v16.8h, v4.h[0]
+    FMLA    v29.8h, v17.8h, v4.h[0]
+
FMLA v30.8h, v16.8h, v5.h[0] + FMLA v31.8h, v17.8h, v5.h[0] + subs x0,x0,8 // k -= 4 + + FMLA v20.8h, v18.8h, v0.h[1] + FMLA v21.8h, v19.8h, v0.h[1] + ldr q16,[x8],16 // B2.low + FMLA v22.8h, v18.8h, v1.h[1] + FMLA v23.8h, v19.8h, v1.h[1] + ld1 {v17.16b},[x8],x15 // B2.high x8 <- next row + FMLA v24.8h, v18.8h, v2.h[1] + FMLA v25.8h, v19.8h, v2.h[1] + FMLA v26.8h, v18.8h, v3.h[1] + FMLA v27.8h, v19.8h, v3.h[1] + FMLA v28.8h, v18.8h, v4.h[1] + FMLA v29.8h, v19.8h, v4.h[1] + FMLA v30.8h, v18.8h, v5.h[1] + FMLA v31.8h, v19.8h, v5.h[1] + + FMLA v20.8h, v16.8h, v0.h[2] + FMLA v21.8h, v17.8h, v0.h[2] + ldr q18,[x8],16 // B3.low + FMLA v22.8h, v16.8h, v1.h[2] + FMLA v23.8h, v17.8h, v1.h[2] + ld1 {v19.16b},[x8],x15 // B3.high x8 <- next row + FMLA v24.8h, v16.8h, v2.h[2] + FMLA v25.8h, v17.8h, v2.h[2] + FMLA v26.8h, v16.8h, v3.h[2] + FMLA v27.8h, v17.8h, v3.h[2] + FMLA v28.8h, v16.8h, v4.h[2] + FMLA v29.8h, v17.8h, v4.h[2] + FMLA v30.8h, v16.8h, v5.h[2] + FMLA v31.8h, v17.8h, v5.h[2] + + ldr q16,[x8],16 // B0.low next iter + FMLA v20.8h, v18.8h, v0.h[3] + FMLA v21.8h, v19.8h, v0.h[3] + ld1 {v17.16b},[x8],x15 // B0.high x8 <- next row + FMLA v22.8h, v18.8h, v1.h[3] + FMLA v23.8h, v19.8h, v1.h[3] + LDR d0, [x6], 8 // A0 + FMLA v24.8h, v18.8h, v2.h[3] + FMLA v25.8h, v19.8h, v2.h[3] + LDR d1, [x9], 8 // A1 + FMLA v26.8h, v18.8h, v3.h[3] + FMLA v27.8h, v19.8h, v3.h[3] + LDR d2, [x10], 8 // A2 + FMLA v28.8h, v18.8h, v4.h[3] + FMLA v29.8h, v19.8h, v4.h[3] + LDR d3, [x11], 8 // A3 + FMLA v30.8h, v18.8h, v5.h[3] + FMLA v31.8h, v19.8h, v5.h[3] + b.hs M6N16InnerLoopK // k >= 8 for main loop + +M6N16LoopK_Epilogue + // last block of k >= 4, no pre-load for next iter + FMLA v20.8h, v16.8h, v0.h[0] + FMLA v21.8h, v17.8h, v0.h[0] + LDR d4, [x12], 8 // A4 + FMLA v22.8h, v16.8h, v1.h[0] + FMLA v23.8h, v17.8h, v1.h[0] + LDR d5, [x7], 8 // A5 + FMLA v24.8h, v16.8h, v2.h[0] + FMLA v25.8h, v17.8h, v2.h[0] + ldr q18,[x8],16 // B1.low + FMLA v26.8h, v16.8h, v3.h[0] + FMLA v27.8h, v17.8h, v3.h[0] + ld1 {v19.16b},[x8],x15 // B1.high x8 <- next row + FMLA v28.8h, v16.8h, v4.h[0] + FMLA v29.8h, v17.8h, v4.h[0] + FMLA v30.8h, v16.8h, v5.h[0] + FMLA v31.8h, v17.8h, v5.h[0] + adds x0,x0,8 // revert k over-decrement + + FMLA v20.8h, v18.8h, v0.h[1] + FMLA v21.8h, v19.8h, v0.h[1] + ldr q16,[x8],16 // B2.low + FMLA v22.8h, v18.8h, v1.h[1] + FMLA v23.8h, v19.8h, v1.h[1] + ld1 {v17.16b},[x8],x15 // B2.high x8 <- next row + FMLA v24.8h, v18.8h, v2.h[1] + FMLA v25.8h, v19.8h, v2.h[1] + FMLA v26.8h, v18.8h, v3.h[1] + FMLA v27.8h, v19.8h, v3.h[1] + FMLA v28.8h, v18.8h, v4.h[1] + FMLA v29.8h, v19.8h, v4.h[1] + FMLA v30.8h, v18.8h, v5.h[1] + FMLA v31.8h, v19.8h, v5.h[1] + + FMLA v20.8h, v16.8h, v0.h[2] + FMLA v21.8h, v17.8h, v0.h[2] + ldr q18,[x8],16 // B3.low + FMLA v22.8h, v16.8h, v1.h[2] + FMLA v23.8h, v17.8h, v1.h[2] + ld1 {v19.16b},[x8],x15 // B3.high x8 <- next row + FMLA v24.8h, v16.8h, v2.h[2] + FMLA v25.8h, v17.8h, v2.h[2] + FMLA v26.8h, v16.8h, v3.h[2] + FMLA v27.8h, v17.8h, v3.h[2] + FMLA v28.8h, v16.8h, v4.h[2] + FMLA v29.8h, v17.8h, v4.h[2] + FMLA v30.8h, v16.8h, v5.h[2] + FMLA v31.8h, v17.8h, v5.h[2] + + FMLA v20.8h, v18.8h, v0.h[3] + FMLA v21.8h, v19.8h, v0.h[3] + FMLA v22.8h, v18.8h, v1.h[3] + FMLA v23.8h, v19.8h, v1.h[3] + FMLA v24.8h, v18.8h, v2.h[3] + FMLA v25.8h, v19.8h, v2.h[3] + FMLA v26.8h, v18.8h, v3.h[3] + FMLA v27.8h, v19.8h, v3.h[3] + FMLA v28.8h, v18.8h, v4.h[3] + FMLA v29.8h, v19.8h, v4.h[3] + FMLA v30.8h, v18.8h, v5.h[3] + FMLA v31.8h, v19.8h, v5.h[3] + B.NE M6N16RemainderK123 // remaining k 1~3 + 
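The k bookkeeping in this kernel is easier to follow in scalar terms. A sketch only (k counted in bytes; one unrolled block consumes 8 bytes, i.e. 4 fp16 values; label names mirror the assembly):

    #include <cstddef>

    // Goto-style sketch of the flag-driven loop control above.
    void KLoopSketch(size_t K) {
        ptrdiff_t k = (ptrdiff_t)(K * 2);        // lsl  x2,x2,#1
        k -= 8;                                  // subs x0,x2,8
        if (k < 0) goto M6N16RemainderK123;      // b.lo: fewer than 4 values total
        // preload the first A and B slices
        k -= 8;                                  // subs x0,x0,8  (over-decrement)
        if (k < 0) goto M6N16LoopK_Epilogue;     // b.lo: only one 4-deep block left
        do {
            // M6N16InnerLoopK: 4-deep FMLA block, preloads the next A/B slices
            k -= 8;
        } while (k >= 0);                        // b.hs
    M6N16LoopK_Epilogue:
        // final 4-deep block, no preload
        k += 8;                                  // adds x0,x0,8  (revert over-decrement)
        if (k != 0) goto M6N16RemainderK123;     // b.ne: 1 to 3 values remain
        return;                                  // fall through to the stores / N loop
    M6N16RemainderK123:
        ;                                        // handle the 1..3 leftover values
    }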
+M6N16NextIterN + SUBS x1, x1, 16 + B.LO M6StoreRemainderN + + ldr x8,[sp,#HGemmKernelFrame_B] + add x8,x8,16 // B <- next 16 columns + str x8,[sp,#HGemmKernelFrame_B] + ST1 {v20.16b, v21.16b}, [x3], 32 + SUB x6, x6, x2 // a0 -= k + ST1 {v22.16b, v23.16b}, [x16], 32 + SUB x9, x9, x2 // a1 -= k + ST1 {v24.16b, v25.16b}, [x17], 32 + SUB x10, x10, x2 // a2 -= k + ST1 {v26.16b, v27.16b}, [x14], 32 + SUB x11, x11, x2 // a3 -= k + ST1 {v28.16b, v29.16b}, [x13], 32 + SUB x12, x12, x2 // a4 -= k + ST1 {v30.16b, v31.16b}, [x4], 32 + SUB x7, x7, x2 // a5 -= k + B.HI M6N16OutterLoopN + +ExitKernel + EPILOG_RESTORE_REG x19,#HGemmKernelFrame_SavedRegs! + EPILOG_RETURN + +M6N16RemainderK123 + TBZ x0, 2, M6N16RemainderK1 + LDR s0, [x6], 4 + LDR q16, [x8], 16 + ld1 {v17.16b},[x8],x15 + LDR s1, [x9], 4 + LDR s2, [x10], 4 + LDR s3, [x11], 4 + LDR s4, [x12], 4 + LDR s5, [x7], 4 + LDR q18, [x8], 16 + ld1 {v19.16b},[x8],x15 + FMLA v20.8h, v16.8h, v0.h[0] + FMLA v22.8h, v16.8h, v1.h[0] + FMLA v24.8h, v16.8h, v2.h[0] + FMLA v26.8h, v16.8h, v3.h[0] + FMLA v28.8h, v16.8h, v4.h[0] + FMLA v30.8h, v16.8h, v5.h[0] + FMLA v21.8h, v17.8h, v0.h[0] + FMLA v23.8h, v17.8h, v1.h[0] + FMLA v25.8h, v17.8h, v2.h[0] + FMLA v27.8h, v17.8h, v3.h[0] + FMLA v29.8h, v17.8h, v4.h[0] + FMLA v31.8h, v17.8h, v5.h[0] + + FMLA v20.8h, v18.8h, v0.h[1] + FMLA v22.8h, v18.8h, v1.h[1] + FMLA v24.8h, v18.8h, v2.h[1] + FMLA v26.8h, v18.8h, v3.h[1] + FMLA v28.8h, v18.8h, v4.h[1] + FMLA v30.8h, v18.8h, v5.h[1] + FMLA v21.8h, v19.8h, v0.h[1] + FMLA v23.8h, v19.8h, v1.h[1] + FMLA v25.8h, v19.8h, v2.h[1] + FMLA v27.8h, v19.8h, v3.h[1] + FMLA v29.8h, v19.8h, v4.h[1] + FMLA v31.8h, v19.8h, v5.h[1] + TBZ x0, 1, M6N16NextIterN + +M6N16RemainderK1 + LDR h0, [x6], 2 + LDR q16, [x8], 16 + ld1 {v17.16b},[x8],x15 + LDR h1, [x9], 2 + LDR h2, [x10], 2 + LDR h3, [x11], 2 + LDR h4, [x12], 2 + LDR h5, [x7], 2 + FMLA v20.8h, v16.8h, v0.h[0] + FMLA v22.8h, v16.8h, v1.h[0] + FMLA v24.8h, v16.8h, v2.h[0] + FMLA v26.8h, v16.8h, v3.h[0] + FMLA v28.8h, v16.8h, v4.h[0] + FMLA v30.8h, v16.8h, v5.h[0] + FMLA v21.8h, v17.8h, v0.h[0] + FMLA v23.8h, v17.8h, v1.h[0] + FMLA v25.8h, v17.8h, v2.h[0] + FMLA v27.8h, v17.8h, v3.h[0] + FMLA v29.8h, v17.8h, v4.h[0] + FMLA v31.8h, v17.8h, v5.h[0] + B M6N16NextIterN + +M6StoreRemainderN + TBZ x1, 3, M6StoreRemainderN + STR q20, [x3], 16 + MOV v20.16b, v21.16b + STR q22, [x16], 16 + MOV v22.16b, v23.16b + STR q24, [x17], 16 + MOV v24.16b, v25.16b + STR q26, [x14], 16 + MOV v26.16b, v27.16b + STR q28, [x13], 16 + MOV v28.16b, v29.16b + STR q30, [x4], 16 + MOV v30.16b, v31.16b + +M6StoreRemainderN4 + TBZ x1, 2, M6StoreRemainderN2 + STR d20, [x3], 8 + STR d22, [x16], 8 + DUP d20, v20.d[1] + DUP d22, v22.d[1] + STR d24, [x17], 8 + STR d26, [x14], 8 + DUP d24, v24.d[1] + DUP d26, v26.d[1] + STR d28, [x13], 8 + STR d30, [x4], 8 + DUP d28, v28.d[1] + DUP d30, v30.d[1] + +M6StoreRemainderN2 + TBZ x1, 1, M6StoreRemainderN1 + STR s20, [x3], 4 + STR s22, [x16], 4 + DUP s20, v20.s[1] + DUP s22, v22.s[1] + STR s24, [x17], 4 + STR s26, [x14], 4 + DUP s24, v24.s[1] + DUP s26, v26.s[1] + STR s28, [x13], 4 + STR s30, [x4], 4 + DUP s28, v28.s[1] + DUP s30, v30.s[1] + +M6StoreRemainderN1 + TBZ x1, 0, ExitKernel + STR h20, [x3] + STR h22, [x16] + STR h24, [x17] + STR h26, [x14] + STR h28, [x13] + STR h30, [x4] + b ExitKernel + + LEAF_END MlasHalfGemmKernelNeon + + END diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 71124f26fb12f..c82b2c8b298d1 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ 
b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -211,13 +211,13 @@ MlasHalfGemmKernel( size_t CountM, size_t CountN, size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, const _mlas_fp16_* A, size_t lda, const _mlas_fp16_* B, size_t ldb, - _mlas_fp16_* C, - size_t ldc, - const _mlas_fp16_* Bias, const bool ZeroMode) { for (size_t m = 0; m < CountM; m++) { diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 4b44f7c156058..ce1f4ede49171 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -37,6 +37,7 @@ Module Name: #include #include "mlasi.h" +#include "mlas_float16.h" /** @@ -173,13 +174,13 @@ MlasHalfGemmKernel( const size_t CountM, const size_t CountN, const size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, const _mlas_fp16_* A, const size_t lda, const _mlas_fp16_* B, const size_t ldb, - _mlas_fp16_* C, - size_t ldc, - const _mlas_fp16_* Bias, const bool ZeroMode ); @@ -233,13 +234,13 @@ MlasHalfGemmNoPackOperation( RowsRemaining, RangeCountN, K, + c, + ldc, + Bias, pa, lda, pb, ldb, - c, - ldc, - Bias, true); size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); @@ -408,13 +409,13 @@ MlasHalfGemmOperation( RowsRemaining, CountN, CountK, + c, + ldc, + ZeroMode ? pbias : nullptr, pa, ld_pa, pb, ld_pb, - c, - ldc, - ZeroMode ? pbias : nullptr, ZeroMode); size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); @@ -486,9 +487,17 @@ struct MLAS_HALF_GEMM_DISPATCH { extern const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault; +#if defined(MLAS_TARGET_ARM64) +extern const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchNeon; +#endif + MLAS_FORCEINLINE const MLAS_HALF_GEMM_DISPATCH* MlasHalfGemmGetDispatch() { +#if defined(MLAS_TARGET_ARM64) + return &MlasHalfGemmDispatchNeon; +#else return &MlasHalfGemmDispatchDefault; +#endif } diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..2c0684e64c56b --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -0,0 +1,130 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +// +// Define the prototypes of the NEON routines written in assembly. +// +// N.B. The kernel has not been ported to build with the Windows ARM32 toolset. 
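For scale: the kernel descriptor below sets MLAS_HALF_GEMM_STRIDES Strides{24, 128, 256}, and with the panel buffers sized as Strides.M * Strides.K and Strides.N * Strides.K fp16 elements (see the packASize/packBSize computation earlier in this series), the per-thread packing scratch works out, before alignment padding, to:

    // PanelA: 24 * 256 * 2 bytes  = 12,288 bytes  (12 KB)
    // PanelB: 128 * 256 * 2 bytes = 65,536 bytes  (64 KB)
    // roughly 76 KB per thread, presumably sized so a panel stays cache
    // resident while it is reused across row blocks.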
+//
+
+extern "C" {
+
+    size_t
+    MLASCALL
+    MlasHalfGemmKernelNeon(
+        const size_t CountM,
+        const size_t CountN,
+        const size_t CountK,
+        _mlas_fp16_* C,
+        size_t ldc,
+        const _mlas_fp16_* Bias,
+        const _mlas_fp16_* A,
+        const size_t lda,
+        const _mlas_fp16_* B,
+        const size_t ldb,
+        const bool ZeroMode
+        );
+
+}
+
+
+struct MLAS_HALF_GEMM_KERNEL_NEON {
+    static constexpr bool PackNeeded = false;
+    static constexpr size_t KernelMaxM = 6;  // max # rows the vectorized kernel can process
+    static constexpr size_t PackedK = 1;
+
+    static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 256};
+};
+
+
+template<>
+MLAS_FORCEINLINE
+void
+MlasHalfGemmConvertPackA<MLAS_HALF_GEMM_KERNEL_NEON>(
+    _mlas_fp16_* D,
+    const float* A,
+    size_t lda,
+    size_t CountM,
+    size_t CountK
+)
+{
+    for (size_t m = 0; m < CountM; m++) {
+        for (size_t k = 0; k < CountK; k++) {
+            *D++ = MLAS_Float2Half(*(A + m * lda + k));
+        }
+    }
+}
+
+template<>
+MLAS_FORCEINLINE
+void
+MlasHalfGemmConvertPackB<MLAS_HALF_GEMM_KERNEL_NEON>(
+    _mlas_fp16_* D,
+    const float* B,
+    size_t ldb,
+    size_t CountN,
+    size_t CountK
+)
+{
+    for (size_t k = 0; k < CountK; k++) {
+        for (size_t n = 0; n < CountN; n++) {
+            *D++ = MLAS_Float2Half(*(B + k * ldb + n));
+        }
+    }
+}
+
+
+template<>
+MLAS_FORCEINLINE
+void
+MlasHalfGemmKernel<MLAS_HALF_GEMM_KERNEL_NEON>(
+    size_t CountM,
+    size_t CountN,
+    size_t CountK,
+    _mlas_fp16_* C,
+    size_t ldc,
+    const _mlas_fp16_* Bias,
+    const _mlas_fp16_* A,
+    size_t lda,
+    const _mlas_fp16_* B,
+    size_t ldb,
+    const bool ZeroMode)
+{
+    MlasHalfGemmKernelNeon(
+        CountM,
+        CountN,
+        CountK,
+        C,
+        ldc,
+        Bias,
+        A,
+        lda,
+        B,
+        ldb,
+        ZeroMode);
+}
+
+
+const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchNeon = {
+    MlasHalfGemmOperation<MLAS_HALF_GEMM_KERNEL_NEON>,
+    nullptr,
+    MlasHalfGemmConvertPackB<MLAS_HALF_GEMM_KERNEL_NEON>,
+    MLAS_HALF_GEMM_KERNEL_NEON::PackedK,
+    MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM
+};

From e905bb64bb238249dc63174ebf0616e1ea165406 Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Tue, 24 Jan 2023 15:47:28 -0800
Subject: [PATCH 05/19] accumulate C

---
 .../mlas/lib/arm64/HalfGemmKernelNeon.asm     | 107 ++++++++++++++++--
 .../test/mlas/unittest/test_halfgemm.cpp      |  20 ++++
 2 files changed, 115 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm
index bebf862b2e9f0..938be26df3c48 100644
--- a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm
+++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm
@@ -70,29 +70,30 @@ Arguments:
     ldr     x8,[sp,#HGemmKernelFrame_B]
     ldr     x15,[sp,#HGemmKernelFrame_ldb]
     CMP     x0, 2                   // if M < 2
-    ADD     x9, x6, x7              // a1 = a0 + lda
-    ADD     x16, x3, x4             // c1 = c0 + ldc
+    add     x9,x6,x7,lsl #1         // a1 = a0 + lda
+    add     x16,x3,x4,lsl #1        // c1 = c0 + ldc
     CSEL    x9, x6, x9, LO          //     a1 = a0
     CSEL    x16, x3, x16, LO        //     c1 = c0
-    ADD     x10, x9, x7             // a2 = a1 + lda
-    ADD     x17, x16, x4            // c2 = c1 + ldc
+    add     x10,x9,x7,lsl #1        // a2 = a1 + lda
+    add     x17,x16,x4,lsl #1       // c2 = c1 + ldc
     CSEL    x10, x9, x10, LS        // if M <= 2  a2 = a1
    CSEL    x17, x16, x17, LS       //            c2 = c1
     CMP     x0, 4                   // if M < 4
-    ADD     x11, x10, x7            // a3 = a2 + lda
-    ADD     x14, x17, x4            // c3 = c2 + ldc
+    add     x11,x10,x7,lsl #1       // a3 = a2 + lda
+    add     x14,x17,x4,lsl #1       // c3 = c2 + ldc
     CSEL    x11, x10, x11, LO       //     a3 = a2
     CSEL    x14, x17, x14, LO       //     c3 = c2
-    ADD     x12, x11, x7            // a4 = a3 + lda
-    ADD     x13, x14, x4            // c4 = c3 + ldc
+    add     x12,x11,x7,lsl #1       // a4 = a3 + lda
+    add     x13,x14,x4,lsl #1       // c4 = c3 + ldc
     CSEL    x12, x11, x12, LS       // if M <= 4  a4 = a3
     CSEL    x13, x14, x13, LS       //            c4 = c3
     CMP     x0, 6                   // if M < 6
-    ADD     x7, x12, x7             // a5 = a4 + lda
-    ADD 
x4, x13, x4 // c5 = c4 + ldc + add x7,x12,x7,lsl #1 // a5 = a4 + lda + add x4,x13,x4,lsl #1 // c5 = c4 + ldc CSEL x7, x12, x7, LO // a5 = a4 CSEL x4, x13, x4, LO // c5 = c4 - sub x15,x15,16 + lsl x15,x15,#1 // ldb *= sizeof(fp16) + sub x15,x15,16 // ldb -= 16 ldrb w19,[sp,#HGemmKernelFrame_ZeroMode] /**** @@ -284,8 +285,30 @@ M6N16NextIterN B.LO M6StoreRemainderN ldr x8,[sp,#HGemmKernelFrame_B] - add x8,x8,16 // B <- next 16 columns + add x8,x8,32 // B <- next 16 columns str x8,[sp,#HGemmKernelFrame_B] + + cbnz x19,M6N16SkipAccumulateOutput + ldp q0,q1,[x3] + ldp q2,q3,[x16] + ldp q4,q5,[x17] + ldp q6,q7,[x14] + ldp q16,q17,[x13] + ldp q18,q19,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v21.8h,v21.8h,v1.8h + fadd v22.8h,v22.8h,v2.8h + fadd v23.8h,v23.8h,v3.8h + fadd v24.8h,v24.8h,v4.8h + fadd v25.8h,v25.8h,v5.8h + fadd v26.8h,v26.8h,v6.8h + fadd v27.8h,v27.8h,v7.8h + fadd v28.8h,v28.8h,v16.8h + fadd v29.8h,v29.8h,v17.8h + fadd v30.8h,v30.8h,v18.8h + fadd v31.8h,v31.8h,v19.8h + +M6N16SkipAccumulateOutput ST1 {v20.16b, v21.16b}, [x3], 32 SUB x6, x6, x2 // a0 -= k ST1 {v22.16b, v23.16b}, [x16], 32 @@ -368,6 +391,21 @@ M6N16RemainderK1 M6StoreRemainderN TBZ x1, 3, M6StoreRemainderN + cbnz x19,M6N8SkipAccumulateOutput + ldr q0,[x3] + ldr q1,[x16] + ldr q2,[x17] + ldr q3,[x14] + ldr q4,[x13] + ldr q5,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v22.8h,v22.8h,v1.8h + fadd v24.8h,v24.8h,v2.8h + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h + +M6N8SkipAccumulateOutput STR q20, [x3], 16 MOV v20.16b, v21.16b STR q22, [x16], 16 @@ -383,6 +421,21 @@ M6StoreRemainderN M6StoreRemainderN4 TBZ x1, 2, M6StoreRemainderN2 + cbnz x19,M6N4SkipAccumulateOutput + ldr d0,[x3] + ldr d1,[x16] + ldr d2,[x17] + ldr d3,[x14] + ldr d4,[x13] + ldr d5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd v30.4h,v30.4h,v5.4h + +M6N4SkipAccumulateOutput STR d20, [x3], 8 STR d22, [x16], 8 DUP d20, v20.d[1] @@ -398,6 +451,21 @@ M6StoreRemainderN4 M6StoreRemainderN2 TBZ x1, 1, M6StoreRemainderN1 + cbnz x19,M6N2SkipAccumulateOutput + ldr s0,[x3] + ldr s1,[x16] + ldr s2,[x17] + ldr s3,[x14] + ldr s4,[x13] + ldr s5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd v30.4h,v30.4h,v5.4h + +M6N2SkipAccumulateOutput STR s20, [x3], 4 STR s22, [x16], 4 DUP s20, v20.s[1] @@ -413,6 +481,21 @@ M6StoreRemainderN2 M6StoreRemainderN1 TBZ x1, 0, ExitKernel + cbnz x19,M6N1SkipAccumulateOutput + ldr h0,[x3] + ldr h1,[x16] + ldr h2,[x17] + ldr h3,[x14] + ldr h4,[x13] + ldr h5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd v30.4h,v30.4h,v5.4h + +M6N1SkipAccumulateOutput STR h20, [x3] STR h22, [x16] STR h24, [x17] diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index a275dabe7df4d..57d82d3f7c3c8 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -55,6 +55,26 @@ class HalfGemmShortExecuteTest : public MlasTestFixture Date: Wed, 25 Jan 2023 12:36:23 -0800 Subject: [PATCH 06/19] test buffer fill --- .../mlas/lib/arm64/HalfGemmKernelNeon.asm | 135 ++++++++++++------ onnxruntime/core/mlas/lib/halfgemm.cpp | 12 +- onnxruntime/core/mlas/lib/halfgemm.h | 24 ++-- .../core/mlas/lib/halfgemm_kernel_neon.cpp | 7 +- 
.../test/mlas/unittest/test_halfgemm.h | 102 +++---------- onnxruntime/test/mlas/unittest/test_util.h | 45 +++--- 6 files changed, 159 insertions(+), 166 deletions(-) diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm index 938be26df3c48..4680ff976bb09 100644 --- a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -66,12 +66,12 @@ Arguments: LEAF_ENTRY MlasHalfGemmKernelNeon PROLOG_SAVE_REG x19,#-HGemmKernelFrame_SavedRegs! - lsl x2,x2,#1 // k *= sizeof(fp16) - ldr x8,[sp,#HGemmKernelFrame_B] ldr x15,[sp,#HGemmKernelFrame_ldb] + lsl x2,x2,#1 // k *= sizeof(fp16) CMP x0, 2 // if M < 2 add x9,x6,x7,lsl #1 // a1 = a0 + lda add x16,x3,x4,lsl #1 // c1 = c0 + ldc + ldr x8,[sp,#HGemmKernelFrame_B] CSEL x9, x6, x9, LO // a1 = a0 CSEL x16, x3, x16, LO // c1 = c0 add x10,x9,x7,lsl #1 // a2 = a1 + lda @@ -93,8 +93,8 @@ Arguments: CSEL x7, x12, x7, LO // a5 = a4 CSEL x4, x13, x4, LO // c5 = c4 lsl x15,x15,#1 // ldb *= sizeof(fp16) - sub x15,x15,16 // ldb -= 16 ldrb w19,[sp,#HGemmKernelFrame_ZeroMode] + sub x15,x15,16 // ldb -= 16 /**** Main loop processes 6x16 tile, depth 4. @@ -282,11 +282,8 @@ M6N16LoopK_Epilogue M6N16NextIterN SUBS x1, x1, 16 - B.LO M6StoreRemainderN - ldr x8,[sp,#HGemmKernelFrame_B] - add x8,x8,32 // B <- next 16 columns - str x8,[sp,#HGemmKernelFrame_B] + B.LO M6StoreRemainderN cbnz x19,M6N16SkipAccumulateOutput ldp q0,q1,[x3] @@ -319,8 +316,10 @@ M6N16SkipAccumulateOutput SUB x11, x11, x2 // a3 -= k ST1 {v28.16b, v29.16b}, [x13], 32 SUB x12, x12, x2 // a4 -= k + add x8,x8,32 // B <- next 16 columns ST1 {v30.16b, v31.16b}, [x4], 32 SUB x7, x7, x2 // a5 -= k + str x8,[sp,#HGemmKernelFrame_B] B.HI M6N16OutterLoopN ExitKernel @@ -390,8 +389,8 @@ M6N16RemainderK1 B M6N16NextIterN M6StoreRemainderN - TBZ x1, 3, M6StoreRemainderN - cbnz x19,M6N8SkipAccumulateOutput + cbnz x19,M6StoreRemainderNZeroMode + TBZ x1, 3, M6StoreRemainderN4 ldr q0,[x3] ldr q1,[x16] ldr q2,[x17] @@ -401,17 +400,15 @@ M6StoreRemainderN fadd v20.8h,v20.8h,v0.8h fadd v22.8h,v22.8h,v1.8h fadd v24.8h,v24.8h,v2.8h - fadd v26.8h,v26.8h,v3.8h - fadd v28.8h,v28.8h,v4.8h - fadd v30.8h,v30.8h,v5.8h - -M6N8SkipAccumulateOutput STR q20, [x3], 16 MOV v20.16b, v21.16b STR q22, [x16], 16 MOV v22.16b, v23.16b STR q24, [x17], 16 MOV v24.16b, v25.16b + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h STR q26, [x14], 16 MOV v26.16b, v27.16b STR q28, [x13], 16 @@ -421,67 +418,60 @@ M6N8SkipAccumulateOutput M6StoreRemainderN4 TBZ x1, 2, M6StoreRemainderN2 - cbnz x19,M6N4SkipAccumulateOutput ldr d0,[x3] ldr d1,[x16] ldr d2,[x17] ldr d3,[x14] ldr d4,[x13] ldr d5,[x4] - fadd v20.4h,v20.4h,v0.4h - fadd v22.4h,v22.4h,v1.4h - fadd v24.4h,v24.4h,v2.4h - fadd v26.4h,v26.4h,v3.4h - fadd v28.4h,v28.4h,v4.4h - fadd v30.4h,v30.4h,v5.4h - -M6N4SkipAccumulateOutput - STR d20, [x3], 8 - STR d22, [x16], 8 + fadd v21.4h,v20.4h,v0.4h DUP d20, v20.d[1] + fadd v23.4h,v22.4h,v1.4h DUP d22, v22.d[1] - STR d24, [x17], 8 - STR d26, [x14], 8 + fadd v25.4h,v24.4h,v2.4h DUP d24, v24.d[1] + fadd v27.4h,v26.4h,v3.4h DUP d26, v26.d[1] - STR d28, [x13], 8 - STR d30, [x4], 8 + fadd v29.4h,v28.4h,v4.4h DUP d28, v28.d[1] + fadd v31.4h,v30.4h,v5.4h DUP d30, v30.d[1] + STR d21, [x3], 8 + STR d23, [x16], 8 + STR d25, [x17], 8 + STR d27, [x14], 8 + STR d29, [x13], 8 + STR d31, [x4], 8 M6StoreRemainderN2 TBZ x1, 1, M6StoreRemainderN1 - cbnz x19,M6N2SkipAccumulateOutput ldr s0,[x3] ldr s1,[x16] ldr 
s2,[x17] ldr s3,[x14] ldr s4,[x13] ldr s5,[x4] - fadd v20.4h,v20.4h,v0.4h - fadd v22.4h,v22.4h,v1.4h - fadd v24.4h,v24.4h,v2.4h - fadd v26.4h,v26.4h,v3.4h - fadd v28.4h,v28.4h,v4.4h - fadd v30.4h,v30.4h,v5.4h - -M6N2SkipAccumulateOutput - STR s20, [x3], 4 - STR s22, [x16], 4 + fadd v21.4h,v20.4h,v0.4h + fadd v23.4h,v22.4h,v1.4h + fadd v25.4h,v24.4h,v2.4h + fadd v27.4h,v26.4h,v3.4h + fadd v29.4h,v28.4h,v4.4h + fadd v31.4h,v30.4h,v5.4h + STR s21, [x3], 4 + STR s23, [x16], 4 DUP s20, v20.s[1] DUP s22, v22.s[1] - STR s24, [x17], 4 - STR s26, [x14], 4 + STR s25, [x17], 4 + STR s27, [x14], 4 DUP s24, v24.s[1] DUP s26, v26.s[1] - STR s28, [x13], 4 - STR s30, [x4], 4 + STR s29, [x13], 4 + STR s31, [x4], 4 DUP s28, v28.s[1] DUP s30, v30.s[1] M6StoreRemainderN1 TBZ x1, 0, ExitKernel - cbnz x19,M6N1SkipAccumulateOutput ldr h0,[x3] ldr h1,[x16] ldr h2,[x17] @@ -494,8 +484,61 @@ M6StoreRemainderN1 fadd v26.4h,v26.4h,v3.4h fadd v28.4h,v28.4h,v4.4h fadd v30.4h,v30.4h,v5.4h + STR h20, [x3] + STR h22, [x16] + STR h24, [x17] + STR h26, [x14] + STR h28, [x13] + STR h30, [x4] + b ExitKernel + +M6StoreRemainderNZeroMode + TBZ x1, 3, M6StoreRemainderN4ZeroMode + STR q20, [x3], 16 + MOV v20.16b, v21.16b + STR q22, [x16], 16 + MOV v22.16b, v23.16b + STR q24, [x17], 16 + MOV v24.16b, v25.16b + STR q26, [x14], 16 + MOV v26.16b, v27.16b + STR q28, [x13], 16 + MOV v28.16b, v29.16b + STR q30, [x4], 16 + MOV v30.16b, v31.16b + +M6StoreRemainderN4ZeroMode + TBZ x1, 2, M6StoreRemainderN2ZeroMode + STR d20, [x3], 8 + STR d22, [x16], 8 + DUP d20, v20.d[1] + DUP d22, v22.d[1] + STR d24, [x17], 8 + STR d26, [x14], 8 + DUP d24, v24.d[1] + DUP d26, v26.d[1] + STR d28, [x13], 8 + STR d30, [x4], 8 + DUP d28, v28.d[1] + DUP d30, v30.d[1] -M6N1SkipAccumulateOutput +M6StoreRemainderN2ZeroMode + TBZ x1, 1, M6StoreRemainderN1ZeroMode + STR s20, [x3], 4 + STR s22, [x16], 4 + DUP s20, v20.s[1] + DUP s22, v22.s[1] + STR s24, [x17], 4 + STR s26, [x14], 4 + DUP s24, v24.s[1] + DUP s26, v26.s[1] + STR s28, [x13], 4 + STR s30, [x4], 4 + DUP s28, v28.s[1] + DUP s30, v30.s[1] + +M6StoreRemainderN1ZeroMode + TBZ x1, 0, ExitKernel STR h20, [x3] STR h22, [x16] STR h24, [x17] diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index c82b2c8b298d1..174c9131eb1f1 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -34,8 +34,8 @@ MlasHalfGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - const MLAS_HALF_GEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); - MLAS_HALF_GEMM_OPERATION* operation = dispatch->Operation; + const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); + MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation; if (ThreadPool == nullptr) { for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { @@ -112,13 +112,14 @@ MlasHalfGemmPackBSize( ) { const auto* dispatch = MlasHalfGemmGetDispatch(); + const auto padding = dispatch->BufOverRead; const auto PackedK = dispatch->PackededK; if (!float2half && dispatch->CopyPackBRoutine == nullptr) { // No packing routine provided return 0; } const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); - const size_t BytesRequired = N * AlignedK * FP16_SIZE; + const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding; const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); @@ -244,10 +245,11 @@ MlasHalfGemmKernel( } -const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault = { +const 
MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, nullptr, MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, - MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM + MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM, + 0 }; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index ce1f4ede49171..65399fe044202 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -446,7 +446,7 @@ MlasHalfGemmOperation( typedef void -(MLAS_HALF_GEMM_OPERATION)( +(MLAS_HALFGEMM_OPERATION)( const size_t N, const size_t K, const MLAS_HALF_GEMM_DATA_PARAMS* Data, @@ -459,7 +459,7 @@ void typedef void -(MLAS_HALF_GEMM_COPY_PACKB_ROUTINE)( +(MLAS_HALFGEMM_COPYPACKB_ROUTINE)( _mlas_fp16_* D, const _mlas_fp16_* B, size_t ldb, @@ -469,7 +469,7 @@ void typedef void -(MLAS_HALF_GEMM_CONVERT_PACKB_ROUTINE)( +(MLAS_HALFGEMM_CONVERTPACKB_ROUTINE)( _mlas_fp16_* D, const float* B, size_t ldb, @@ -477,22 +477,26 @@ void size_t CountK ); -struct MLAS_HALF_GEMM_DISPATCH { - MLAS_HALF_GEMM_OPERATION* Operation; - MLAS_HALF_GEMM_COPY_PACKB_ROUTINE* CopyPackBRoutine; - MLAS_HALF_GEMM_CONVERT_PACKB_ROUTINE* ConvertPackBRoutine; +/** + * @brief Hardware dependent dispatch for half precision GEMM +*/ +struct MLAS_HALFGEMM_DISPATCH { + MLAS_HALFGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_HALFGEMM_COPYPACKB_ROUTINE* CopyPackBRoutine; /**< Pack function for B */ + MLAS_HALFGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ size_t PackededK; size_t StrideM; + size_t BufOverRead; }; -extern const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchDefault; +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault; #if defined(MLAS_TARGET_ARM64) -extern const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchNeon; +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon; #endif MLAS_FORCEINLINE -const MLAS_HALF_GEMM_DISPATCH* +const MLAS_HALFGEMM_DISPATCH* MlasHalfGemmGetDispatch() { #if defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp index 2c0684e64c56b..def3f7732fb41 100644 --- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -49,7 +49,7 @@ struct MLAS_HALF_GEMM_KERNEL_NEON { static constexpr size_t KernelMaxM = 6; // max # rows the vectorized kernel can process static constexpr size_t PackedK = 1; - static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 16}; }; @@ -121,10 +121,11 @@ MlasHalfGemmKernel( } -const MLAS_HALF_GEMM_DISPATCH MlasHalfGemmDispatchNeon = { +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = { MlasHalfGemmOperation, nullptr, MlasHalfGemmConvertPackB, MLAS_HALF_GEMM_KERNEL_NEON::PackedK, - MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM + MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes }; diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h index 17435e06498a6..860c56b8674e9 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.h +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -28,6 +28,7 @@ struct MLFp16 { MLFp16() = default; explicit constexpr MLFp16(uint16_t x) : val(x) {} + explicit constexpr MLFp16(int32_t x) : val((uint16_t)x) {} explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {} float ToFloat() const { @@ -52,85 +53,16 @@ 
operator!=(const MLFp16& left, const MLFp16& right) { return left.val != right.val; } -// -// Customize buffer fill for half precision buffer -// -template <> -MLFp16* -MatrixGuardBuffer::GetBuffer(size_t Elements, bool ZeroFill) { - // - // Check if the internal buffer needs to be reallocated. - // - - if (Elements > _ElementsAllocated) { - ReleaseBuffer(); - - // - // Reserve a virtual address range for the allocation plus an unmapped - // guard region. - // - - constexpr size_t BufferAlignment = 64 * 1024; - constexpr size_t GuardPadding = 256 * 1024; - - size_t BytesToAllocate = ((Elements * FP16_SIZE) + BufferAlignment - 1) & ~(BufferAlignment - 1); - - _BaseBufferSize = BytesToAllocate + GuardPadding; - -#if defined(_WIN32) - _BaseBuffer = VirtualAlloc(NULL, _BaseBufferSize, MEM_RESERVE, PAGE_NOACCESS); -#else - _BaseBuffer = mmap(0, _BaseBufferSize, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); -#endif - - if (_BaseBuffer == nullptr) { - abort(); - } - - // - // Commit the number of bytes for the allocation leaving the upper - // guard region as unmapped. - // - -#if defined(_WIN32) - if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) { - ORT_THROW_EX(std::bad_alloc); - } -#else - if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) { - abort(); - } -#endif +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto* FillAddress = start; + size_t offset = size % 23; - _ElementsAllocated = BytesToAllocate / FP16_SIZE; - _GuardAddress = (MLFp16*)((unsigned char*)_BaseBuffer + BytesToAllocate); + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); } - - - auto* GuardAddress = _GuardAddress; - auto* buffer = GuardAddress - Elements; - - if (ZeroFill) { - std::fill_n(buffer, Elements, MLFp16()); - } else { - constexpr float MinimumFillValue = -11.0f; - constexpr float MaximumFillValue = 11.0f; - - float FillValue = MinimumFillValue; - auto* FillAddress = buffer; - - while (FillAddress < GuardAddress) { - *FillAddress++ = FillValue/16.0f; - - FillValue+=1.0f; - - if (FillValue > MaximumFillValue) { - FillValue = MinimumFillValue; - } - } - } - - return buffer; } @@ -242,13 +174,15 @@ class MlasHalfGemmTest : public MlasTestBase { MlasHalfGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { - const AType* A = BufferA.GetBuffer(K * M * BatchSize); - const BType* B = BufferB.GetBuffer(N * K * BatchSize); - const MLFp16* Bias = withBias ? BufferBias.GetBuffer(N * BatchSize) : nullptr; - MLFp16* C = BufferC.GetBuffer(N * M * BatchSize); - float* CReference = BufferCReference.GetBuffer(N * M * BatchSize); - - std::fill_n(CReference, M * N * BatchSize, float(-1.0)); + const AType* A = BufferA.GetFilledBuffer(K * M * BatchSize, SmallFloatFill); + const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize, SmallFloatFill); + const MLFp16* Bias = withBias ? 
BufferBias.GetFilledBuffer(N * BatchSize, SmallFloatFill) : nullptr; + MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N); ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference); diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index c3d97eb3bb402..d3fc735260742 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -50,7 +50,7 @@ class MatrixGuardBuffer { ReleaseBuffer(); } - T* GetBuffer(size_t Elements, bool ZeroFill = false) { + T* GetFilledBuffer(size_t Elements, std::function const & fillFunc) { // // Check if the internal buffer needs to be reallocated. // @@ -105,29 +105,38 @@ class MatrixGuardBuffer { T* GuardAddress = _GuardAddress; T* buffer = GuardAddress - Elements; + fillFunc(buffer, Elements); - if (ZeroFill) { - std::fill_n(buffer, Elements, T(0)); - - } else { - constexpr int MinimumFillValue = -23; - constexpr int MaximumFillValue = 23; + return buffer; + } - int FillValue = MinimumFillValue; - T* FillAddress = buffer; + T* GetBuffer(size_t Elements, bool ZeroFill = false) { + if (ZeroFill) { + return GetFilledBuffer( + Elements, + [](T* start, size_t size) { + std::fill_n(start, size, T(0)); + }); + } - while (FillAddress < GuardAddress) { - *FillAddress++ = (T)FillValue; + return GetFilledBuffer( + Elements, + [](T* start, size_t size) { + constexpr int MinimumFillValue = -23; + constexpr int MaximumFillValue = 23; - FillValue++; + int FillValue = MinimumFillValue; + T* FillAddress = start; + for (size_t i = 0; i < size; i++) { + *FillAddress++ = (T)FillValue; - if (FillValue > MaximumFillValue) { - FillValue = MinimumFillValue; - } - } - } + FillValue++; - return buffer; + if (FillValue > MaximumFillValue) { + FillValue = MinimumFillValue; + } + } + }); } void ReleaseBuffer(void) { From 300c1d925714c03371561f878057563e124f5500 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 26 Jan 2023 14:40:43 -0800 Subject: [PATCH 07/19] half gemm kernel adjust --- .../mlas/lib/arm64/HalfGemmKernelNeon.asm | 230 +++++++++--------- 1 file changed, 115 insertions(+), 115 deletions(-) diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm index 4680ff976bb09..a0e6276eecfe4 100644 --- a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -66,35 +66,35 @@ Arguments: LEAF_ENTRY MlasHalfGemmKernelNeon PROLOG_SAVE_REG x19,#-HGemmKernelFrame_SavedRegs! 
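For readers following the renaming in this hunk, the register assignments after the change are (collected from the lines below):

    //   A row pointers a0..a5 : x6, x14, x15, x16, x17, x7
    //   C row pointers c0..c5 : x3, x10, x11, x12, x13, x4
    //   B panel pointer : x8      ldb (in bytes, minus 16) : x9
    //   ZeroMode flag : w19       k counter (bytes) : x2 / x0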
- ldr x15,[sp,#HGemmKernelFrame_ldb] + ldr x9,[sp,#HGemmKernelFrame_ldb] lsl x2,x2,#1 // k *= sizeof(fp16) - CMP x0, 2 // if M < 2 - add x9,x6,x7,lsl #1 // a1 = a0 + lda - add x16,x3,x4,lsl #1 // c1 = c0 + ldc + cmp x0,2 + add x14,x6,x7,lsl #1 // a1 = a0 + lda + add x10,x3,x4,lsl #1 // c1 = c0 + ldc ldr x8,[sp,#HGemmKernelFrame_B] - CSEL x9, x6, x9, LO // a1 = a0 - CSEL x16, x3, x16, LO // c1 = c0 - add x10,x9,x7,lsl #1 // a2 = a1 + lda - add x17,x16,x4,lsl #1 // c2 = c1 + ldc - CSEL x10, x9, x10, LS // if M <= 2 a2 = a1 - CSEL x17, x16, x17, LS // c2 = c1 - CMP x0, 4 // if M < 4 - add x11,x10,x7,lsl #1 // a3 = a2 + lda - add x14,x17,x4,lsl #1 // c3 = c2 + ldc - CSEL x11, x10, x11, LO // a3 = a2 - CSEL x14, x17, x14, LO // c3 = c2 - add x12,x11,x7,lsl #1 // a4 = a3 + lda - add x13,x14,x4,lsl #1 // c4 = c3 + ldc - CSEL x12, x11, x12, LS // if M <= 4 a4 = a3 - CSEL x13, x14, x13, LS // c4 = c3 - CMP x0, 6 // if M < 6 - add x7,x12,x7,lsl #1 // a5 = a4 + lda + csel x14,x6,x14,LO // M < 2 ? a1 = a0 + csel x10,x3,x10,LO // c1 = c0 + add x15,x14,x7,lsl #1 // a2 = a1 + lda + add x11,x10,x4,lsl #1 // c2 = c1 + ldc + csel x15,x14,x15,LS // M <= 2 ? a2 = a1 + csel x11,x10,x11,LS // c2 = c1 + cmp x0,4 + add x16,x15,x7,lsl #1 // a3 = a2 + lda + add x12,x11,x4,lsl #1 // c3 = c2 + ldc + csel x16,x15,x16,LO // M < 4 ? a3 = a2 + csel x12,x11,x12,LO // c3 = c2 + add x17,x16,x7,lsl #1 // a4 = a3 + lda + add x13,x12,x4,lsl #1 // c4 = c3 + ldc + csel x17,x16,x17,LS // M <= 4 ? a4 = a3 + csel x13,x12,x13,LS // c4 = c3 + cmp x0,6 + add x7,x17,x7,lsl #1 // a5 = a4 + lda add x4,x13,x4,lsl #1 // c5 = c4 + ldc - CSEL x7, x12, x7, LO // a5 = a4 - CSEL x4, x13, x4, LO // c5 = c4 - lsl x15,x15,#1 // ldb *= sizeof(fp16) + csel x7,x17,x7,LO // M < 6 ? a5 = a4 + csel x4,x13,x4,LO // c5 = c4 + lsl x9,x9,#1 // ldb *= sizeof(fp16) ldrb w19,[sp,#HGemmKernelFrame_ZeroMode] - sub x15,x15,16 // ldb -= 16 + sub x9,x9,16 // ldb -= 16 /**** Main loop processes 6x16 tile, depth 4. @@ -107,16 +107,16 @@ Main loop processes 6x16 tile, depth 4. 
A 6x4 --------------------------------------- ------------------ --------------------------------------- x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 -x9 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x16 -x10 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x17 -x11 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x14 -x12 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 +x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 +x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 +x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 +x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 ------------------ --------------------------------------- ****/ M6N16OutterLoopN - cbz x5, M6N16SkipBias + cbz x5,M6N16SkipBias ldp q20,q21,[x5],32 // Load 16 Bias values b M6N16PopulateAccumulators @@ -125,32 +125,32 @@ M6N16SkipBias eor q21.16b,q21.16b,q21.16b M6N16PopulateAccumulators - MOV v22.16b, v20.16b - MOV v23.16b, v21.16b - MOV v24.16b, v20.16b - MOV v25.16b, v21.16b - MOV v26.16b, v20.16b - MOV v27.16b, v21.16b - MOV v28.16b, v20.16b + mov v22.16b,v20.16b + mov v23.16b,v21.16b + mov v24.16b,v20.16b + mov v25.16b,v21.16b + mov v26.16b,v20.16b + mov v27.16b,v21.16b + mov v28.16b,v20.16b subs x0,x2,8 // k -= 4 (8 bytes) - MOV v29.16b, v21.16b - MOV v30.16b, v20.16b - MOV v31.16b, v21.16b - b.lo M6N16RemainderK123 // remaining k 1~3 + mov v29.16b,v21.16b + mov v30.16b,v20.16b + mov v31.16b,v21.16b + b.LO M6N16RemainderK123 // remaining k 1~3 ldr d0,[x6],8 // A0 ldr q16,[x8],16 // B0.l - ld1 {v17.16b},[x8],x15 // B0.high x8 <- next row + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row subs x0,x0,8 // over decement k -= 4 (8 bytes) - ldr d1,[x9],8 // A1 - ldr d2,[x10],8 // A2 - ldr d3,[x11],8 // A3 - b.lo M6N16LoopK_Epilogue // need k>=8 for main loop + ldr d1,[x14],8 // A1 + ldr d2,[x15],8 // A2 + ldr d3,[x16],8 // A3 + b.LO M6N16LoopK_Epilogue // need k>=8 for main loop M6N16InnerLoopK FMLA v20.8h, v16.8h, v0.h[0] FMLA v21.8h, v17.8h, v0.h[0] - LDR d4, [x12], 8 // A4 + LDR d4, [x17], 8 // A4 FMLA v22.8h, v16.8h, v1.h[0] FMLA v23.8h, v17.8h, v1.h[0] LDR d5, [x7], 8 // A5 @@ -159,7 +159,7 @@ M6N16InnerLoopK ldr q18,[x8],16 // B1.low FMLA v26.8h, v16.8h, v3.h[0] FMLA v27.8h, v17.8h, v3.h[0] - ld1 {v19.16b},[x8],x15 // B1.high x8 <- next row + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row FMLA v28.8h, v16.8h, v4.h[0] FMLA v29.8h, v17.8h, v4.h[0] FMLA v30.8h, v16.8h, v5.h[0] @@ -171,7 +171,7 @@ M6N16InnerLoopK ldr q16,[x8],16 // B2.low FMLA v22.8h, v18.8h, v1.h[1] FMLA v23.8h, v19.8h, v1.h[1] - ld1 {v17.16b},[x8],x15 // B2.high x8 <- next row + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row FMLA v24.8h, v18.8h, v2.h[1] FMLA v25.8h, v19.8h, v2.h[1] FMLA v26.8h, v18.8h, v3.h[1] @@ -186,7 +186,7 @@ M6N16InnerLoopK ldr q18,[x8],16 // B3.low FMLA v22.8h, v16.8h, v1.h[2] FMLA v23.8h, v17.8h, v1.h[2] - ld1 {v19.16b},[x8],x15 // B3.high x8 <- next row + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row FMLA v24.8h, v16.8h, v2.h[2] FMLA v25.8h, v17.8h, v2.h[2] FMLA v26.8h, v16.8h, v3.h[2] @@ -199,19 +199,19 @@ M6N16InnerLoopK ldr q16,[x8],16 // B0.low next iter FMLA v20.8h, v18.8h, v0.h[3] FMLA v21.8h, v19.8h, v0.h[3] - ld1 {v17.16b},[x8],x15 // B0.high x8 <- next row + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row FMLA v22.8h, v18.8h, v1.h[3] FMLA v23.8h, v19.8h, v1.h[3] LDR d0, [x6], 8 // A0 FMLA v24.8h, v18.8h, v2.h[3] FMLA 
v25.8h, v19.8h, v2.h[3] - LDR d1, [x9], 8 // A1 + LDR d1, [x14], 8 // A1 FMLA v26.8h, v18.8h, v3.h[3] FMLA v27.8h, v19.8h, v3.h[3] - LDR d2, [x10], 8 // A2 + LDR d2, [x15], 8 // A2 FMLA v28.8h, v18.8h, v4.h[3] FMLA v29.8h, v19.8h, v4.h[3] - LDR d3, [x11], 8 // A3 + LDR d3, [x16], 8 // A3 FMLA v30.8h, v18.8h, v5.h[3] FMLA v31.8h, v19.8h, v5.h[3] b.hs M6N16InnerLoopK // k >= 8 for main loop @@ -220,7 +220,7 @@ M6N16LoopK_Epilogue // last block of k >= 4, no pre-load for next iter FMLA v20.8h, v16.8h, v0.h[0] FMLA v21.8h, v17.8h, v0.h[0] - LDR d4, [x12], 8 // A4 + LDR d4, [x17], 8 // A4 FMLA v22.8h, v16.8h, v1.h[0] FMLA v23.8h, v17.8h, v1.h[0] LDR d5, [x7], 8 // A5 @@ -229,7 +229,7 @@ M6N16LoopK_Epilogue ldr q18,[x8],16 // B1.low FMLA v26.8h, v16.8h, v3.h[0] FMLA v27.8h, v17.8h, v3.h[0] - ld1 {v19.16b},[x8],x15 // B1.high x8 <- next row + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row FMLA v28.8h, v16.8h, v4.h[0] FMLA v29.8h, v17.8h, v4.h[0] FMLA v30.8h, v16.8h, v5.h[0] @@ -241,7 +241,7 @@ M6N16LoopK_Epilogue ldr q16,[x8],16 // B2.low FMLA v22.8h, v18.8h, v1.h[1] FMLA v23.8h, v19.8h, v1.h[1] - ld1 {v17.16b},[x8],x15 // B2.high x8 <- next row + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row FMLA v24.8h, v18.8h, v2.h[1] FMLA v25.8h, v19.8h, v2.h[1] FMLA v26.8h, v18.8h, v3.h[1] @@ -256,7 +256,7 @@ M6N16LoopK_Epilogue ldr q18,[x8],16 // B3.low FMLA v22.8h, v16.8h, v1.h[2] FMLA v23.8h, v17.8h, v1.h[2] - ld1 {v19.16b},[x8],x15 // B3.high x8 <- next row + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row FMLA v24.8h, v16.8h, v2.h[2] FMLA v25.8h, v17.8h, v2.h[2] FMLA v26.8h, v16.8h, v3.h[2] @@ -287,9 +287,9 @@ M6N16NextIterN cbnz x19,M6N16SkipAccumulateOutput ldp q0,q1,[x3] - ldp q2,q3,[x16] - ldp q4,q5,[x17] - ldp q6,q7,[x14] + ldp q2,q3,[x10] + ldp q4,q5,[x11] + ldp q6,q7,[x12] ldp q16,q17,[x13] ldp q18,q19,[x4] fadd v20.8h,v20.8h,v0.8h @@ -308,14 +308,14 @@ M6N16NextIterN M6N16SkipAccumulateOutput ST1 {v20.16b, v21.16b}, [x3], 32 SUB x6, x6, x2 // a0 -= k - ST1 {v22.16b, v23.16b}, [x16], 32 - SUB x9, x9, x2 // a1 -= k - ST1 {v24.16b, v25.16b}, [x17], 32 - SUB x10, x10, x2 // a2 -= k - ST1 {v26.16b, v27.16b}, [x14], 32 - SUB x11, x11, x2 // a3 -= k + ST1 {v22.16b, v23.16b}, [x10], 32 + SUB x14, x14, x2 // a1 -= k + ST1 {v24.16b, v25.16b}, [x11], 32 + SUB x15, x15, x2 // a2 -= k + ST1 {v26.16b, v27.16b}, [x12], 32 + SUB x16, x16, x2 // a3 -= k ST1 {v28.16b, v29.16b}, [x13], 32 - SUB x12, x12, x2 // a4 -= k + SUB x17, x17, x2 // a4 -= k add x8,x8,32 // B <- next 16 columns ST1 {v30.16b, v31.16b}, [x4], 32 SUB x7, x7, x2 // a5 -= k @@ -330,14 +330,14 @@ M6N16RemainderK123 TBZ x0, 2, M6N16RemainderK1 LDR s0, [x6], 4 LDR q16, [x8], 16 - ld1 {v17.16b},[x8],x15 - LDR s1, [x9], 4 - LDR s2, [x10], 4 - LDR s3, [x11], 4 - LDR s4, [x12], 4 + ld1 {v17.16b},[x8],x9 + LDR s1, [x14], 4 + LDR s2, [x15], 4 + LDR s3, [x16], 4 + LDR s4, [x17], 4 LDR s5, [x7], 4 LDR q18, [x8], 16 - ld1 {v19.16b},[x8],x15 + ld1 {v19.16b},[x8],x9 FMLA v20.8h, v16.8h, v0.h[0] FMLA v22.8h, v16.8h, v1.h[0] FMLA v24.8h, v16.8h, v2.h[0] @@ -368,11 +368,11 @@ M6N16RemainderK123 M6N16RemainderK1 LDR h0, [x6], 2 LDR q16, [x8], 16 - ld1 {v17.16b},[x8],x15 - LDR h1, [x9], 2 - LDR h2, [x10], 2 - LDR h3, [x11], 2 - LDR h4, [x12], 2 + ld1 {v17.16b},[x8],x9 + LDR h1, [x14], 2 + LDR h2, [x15], 2 + LDR h3, [x16], 2 + LDR h4, [x17], 2 LDR h5, [x7], 2 FMLA v20.8h, v16.8h, v0.h[0] FMLA v22.8h, v16.8h, v1.h[0] @@ -392,9 +392,9 @@ M6StoreRemainderN cbnz x19,M6StoreRemainderNZeroMode TBZ x1, 3, M6StoreRemainderN4 ldr q0,[x3] - ldr q1,[x16] - ldr q2,[x17] - 
ldr q3,[x14] + ldr q1,[x10] + ldr q2,[x11] + ldr q3,[x12] ldr q4,[x13] ldr q5,[x4] fadd v20.8h,v20.8h,v0.8h @@ -402,14 +402,14 @@ M6StoreRemainderN fadd v24.8h,v24.8h,v2.8h STR q20, [x3], 16 MOV v20.16b, v21.16b - STR q22, [x16], 16 + STR q22, [x10], 16 MOV v22.16b, v23.16b - STR q24, [x17], 16 + STR q24, [x11], 16 MOV v24.16b, v25.16b fadd v26.8h,v26.8h,v3.8h fadd v28.8h,v28.8h,v4.8h fadd v30.8h,v30.8h,v5.8h - STR q26, [x14], 16 + STR q26, [x12], 16 MOV v26.16b, v27.16b STR q28, [x13], 16 MOV v28.16b, v29.16b @@ -419,9 +419,9 @@ M6StoreRemainderN M6StoreRemainderN4 TBZ x1, 2, M6StoreRemainderN2 ldr d0,[x3] - ldr d1,[x16] - ldr d2,[x17] - ldr d3,[x14] + ldr d1,[x10] + ldr d2,[x11] + ldr d3,[x12] ldr d4,[x13] ldr d5,[x4] fadd v21.4h,v20.4h,v0.4h @@ -437,18 +437,18 @@ M6StoreRemainderN4 fadd v31.4h,v30.4h,v5.4h DUP d30, v30.d[1] STR d21, [x3], 8 - STR d23, [x16], 8 - STR d25, [x17], 8 - STR d27, [x14], 8 + STR d23, [x10], 8 + STR d25, [x11], 8 + STR d27, [x12], 8 STR d29, [x13], 8 STR d31, [x4], 8 M6StoreRemainderN2 TBZ x1, 1, M6StoreRemainderN1 ldr s0,[x3] - ldr s1,[x16] - ldr s2,[x17] - ldr s3,[x14] + ldr s1,[x10] + ldr s2,[x11] + ldr s3,[x12] ldr s4,[x13] ldr s5,[x4] fadd v21.4h,v20.4h,v0.4h @@ -458,11 +458,11 @@ M6StoreRemainderN2 fadd v29.4h,v28.4h,v4.4h fadd v31.4h,v30.4h,v5.4h STR s21, [x3], 4 - STR s23, [x16], 4 + STR s23, [x10], 4 DUP s20, v20.s[1] DUP s22, v22.s[1] - STR s25, [x17], 4 - STR s27, [x14], 4 + STR s25, [x11], 4 + STR s27, [x12], 4 DUP s24, v24.s[1] DUP s26, v26.s[1] STR s29, [x13], 4 @@ -473,9 +473,9 @@ M6StoreRemainderN2 M6StoreRemainderN1 TBZ x1, 0, ExitKernel ldr h0,[x3] - ldr h1,[x16] - ldr h2,[x17] - ldr h3,[x14] + ldr h1,[x10] + ldr h2,[x11] + ldr h3,[x12] ldr h4,[x13] ldr h5,[x4] fadd v20.4h,v20.4h,v0.4h @@ -485,9 +485,9 @@ M6StoreRemainderN1 fadd v28.4h,v28.4h,v4.4h fadd v30.4h,v30.4h,v5.4h STR h20, [x3] - STR h22, [x16] - STR h24, [x17] - STR h26, [x14] + STR h22, [x10] + STR h24, [x11] + STR h26, [x12] STR h28, [x13] STR h30, [x4] b ExitKernel @@ -496,11 +496,11 @@ M6StoreRemainderNZeroMode TBZ x1, 3, M6StoreRemainderN4ZeroMode STR q20, [x3], 16 MOV v20.16b, v21.16b - STR q22, [x16], 16 + STR q22, [x10], 16 MOV v22.16b, v23.16b - STR q24, [x17], 16 + STR q24, [x11], 16 MOV v24.16b, v25.16b - STR q26, [x14], 16 + STR q26, [x12], 16 MOV v26.16b, v27.16b STR q28, [x13], 16 MOV v28.16b, v29.16b @@ -510,11 +510,11 @@ M6StoreRemainderNZeroMode M6StoreRemainderN4ZeroMode TBZ x1, 2, M6StoreRemainderN2ZeroMode STR d20, [x3], 8 - STR d22, [x16], 8 + STR d22, [x10], 8 DUP d20, v20.d[1] DUP d22, v22.d[1] - STR d24, [x17], 8 - STR d26, [x14], 8 + STR d24, [x11], 8 + STR d26, [x12], 8 DUP d24, v24.d[1] DUP d26, v26.d[1] STR d28, [x13], 8 @@ -525,11 +525,11 @@ M6StoreRemainderN4ZeroMode M6StoreRemainderN2ZeroMode TBZ x1, 1, M6StoreRemainderN1ZeroMode STR s20, [x3], 4 - STR s22, [x16], 4 + STR s22, [x10], 4 DUP s20, v20.s[1] DUP s22, v22.s[1] - STR s24, [x17], 4 - STR s26, [x14], 4 + STR s24, [x11], 4 + STR s26, [x12], 4 DUP s24, v24.s[1] DUP s26, v26.s[1] STR s28, [x13], 4 @@ -540,9 +540,9 @@ M6StoreRemainderN2ZeroMode M6StoreRemainderN1ZeroMode TBZ x1, 0, ExitKernel STR h20, [x3] - STR h22, [x16] - STR h24, [x17] - STR h26, [x14] + STR h22, [x10] + STR h24, [x11] + STR h26, [x12] STR h28, [x13] STR h30, [x4] b ExitKernel From dd8c90f2dcac44e8b3d4ca691fee1062899f02f4 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 27 Jan 2023 08:58:11 -0800 Subject: [PATCH 08/19] kernel adjust 1 --- .../mlas/lib/arm64/HalfGemmKernelNeon.asm | 
582 +++++++++--------- 1 file changed, 291 insertions(+), 291 deletions(-) diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm index a0e6276eecfe4..f4f26da711097 100644 --- a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -148,142 +148,142 @@ M6N16PopulateAccumulators b.LO M6N16LoopK_Epilogue // need k>=8 for main loop M6N16InnerLoopK - FMLA v20.8h, v16.8h, v0.h[0] - FMLA v21.8h, v17.8h, v0.h[0] - LDR d4, [x17], 8 // A4 - FMLA v22.8h, v16.8h, v1.h[0] - FMLA v23.8h, v17.8h, v1.h[0] - LDR d5, [x7], 8 // A5 - FMLA v24.8h, v16.8h, v2.h[0] - FMLA v25.8h, v17.8h, v2.h[0] + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] ldr q18,[x8],16 // B1.low - FMLA v26.8h, v16.8h, v3.h[0] - FMLA v27.8h, v17.8h, v3.h[0] - ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row - FMLA v28.8h, v16.8h, v4.h[0] - FMLA v29.8h, v17.8h, v4.h[0] - FMLA v30.8h, v16.8h, v5.h[0] - FMLA v31.8h, v17.8h, v5.h[0] - subs x0,x0,8 // k -= 4 - - FMLA v20.8h, v18.8h, v0.h[1] - FMLA v21.8h, v19.8h, v0.h[1] + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + subs x0,x0,8 // k -= 4 + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] ldr q16,[x8],16 // B2.low - FMLA v22.8h, v18.8h, v1.h[1] - FMLA v23.8h, v19.8h, v1.h[1] - ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row - FMLA v24.8h, v18.8h, v2.h[1] - FMLA v25.8h, v19.8h, v2.h[1] - FMLA v26.8h, v18.8h, v3.h[1] - FMLA v27.8h, v19.8h, v3.h[1] - FMLA v28.8h, v18.8h, v4.h[1] - FMLA v29.8h, v19.8h, v4.h[1] - FMLA v30.8h, v18.8h, v5.h[1] - FMLA v31.8h, v19.8h, v5.h[1] - - FMLA v20.8h, v16.8h, v0.h[2] - FMLA v21.8h, v17.8h, v0.h[2] + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] ldr q18,[x8],16 // B3.low - FMLA v22.8h, v16.8h, v1.h[2] - FMLA v23.8h, v17.8h, v1.h[2] - ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row - FMLA v24.8h, v16.8h, v2.h[2] - FMLA v25.8h, v17.8h, v2.h[2] - FMLA v26.8h, v16.8h, v3.h[2] - FMLA v27.8h, v17.8h, v3.h[2] - FMLA v28.8h, v16.8h, v4.h[2] - FMLA v29.8h, v17.8h, v4.h[2] - FMLA v30.8h, v16.8h, v5.h[2] - FMLA v31.8h, v17.8h, v5.h[2] - - ldr q16,[x8],16 // B0.low next iter - FMLA v20.8h, v18.8h, v0.h[3] - FMLA v21.8h, v19.8h, v0.h[3] - ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row - FMLA v22.8h, v18.8h, v1.h[3] - FMLA v23.8h, v19.8h, v1.h[3] - LDR d0, [x6], 8 // A0 - FMLA v24.8h, v18.8h, v2.h[3] - FMLA v25.8h, v19.8h, v2.h[3] - LDR d1, [x14], 8 // A1 - FMLA v26.8h, v18.8h, v3.h[3] - FMLA v27.8h, v19.8h, v3.h[3] - LDR d2, [x15], 8 // A2 - FMLA v28.8h, v18.8h, v4.h[3] - FMLA v29.8h, v19.8h, v4.h[3] - LDR d3, [x16], 8 // A3 - FMLA v30.8h, v18.8h, v5.h[3] - FMLA v31.8h, v19.8h, v5.h[3] - b.hs M6N16InnerLoopK // k >= 8 for main loop + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla 
v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + ldr q16,[x8],16 // Load B0.low for next iter + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + ld1 {v17.16b},[x8],x9 // Load B0.high for next iter + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + ldr d0,[x6],8 // Load A0 for next iter + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + ldr d1,[x14],8 // Load A1 for next iter + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + ldr d2,[x15],8 // Load A2 for next iter + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + ldr d3,[x16],8 // Load A3 for next iter + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.hs M6N16InnerLoopK // k >= 8 for main loop M6N16LoopK_Epilogue // last block of k >= 4, no pre-load for next iter - FMLA v20.8h, v16.8h, v0.h[0] - FMLA v21.8h, v17.8h, v0.h[0] - LDR d4, [x17], 8 // A4 - FMLA v22.8h, v16.8h, v1.h[0] - FMLA v23.8h, v17.8h, v1.h[0] - LDR d5, [x7], 8 // A5 - FMLA v24.8h, v16.8h, v2.h[0] - FMLA v25.8h, v17.8h, v2.h[0] + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] ldr q18,[x8],16 // B1.low - FMLA v26.8h, v16.8h, v3.h[0] - FMLA v27.8h, v17.8h, v3.h[0] - ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row - FMLA v28.8h, v16.8h, v4.h[0] - FMLA v29.8h, v17.8h, v4.h[0] - FMLA v30.8h, v16.8h, v5.h[0] - FMLA v31.8h, v17.8h, v5.h[0] - adds x0,x0,8 // revert k over-decrement - - FMLA v20.8h, v18.8h, v0.h[1] - FMLA v21.8h, v19.8h, v0.h[1] + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + adds x0,x0,8 // revert k over-decrement + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] ldr q16,[x8],16 // B2.low - FMLA v22.8h, v18.8h, v1.h[1] - FMLA v23.8h, v19.8h, v1.h[1] - ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row - FMLA v24.8h, v18.8h, v2.h[1] - FMLA v25.8h, v19.8h, v2.h[1] - FMLA v26.8h, v18.8h, v3.h[1] - FMLA v27.8h, v19.8h, v3.h[1] - FMLA v28.8h, v18.8h, v4.h[1] - FMLA v29.8h, v19.8h, v4.h[1] - FMLA v30.8h, v18.8h, v5.h[1] - FMLA v31.8h, v19.8h, v5.h[1] - - FMLA v20.8h, v16.8h, v0.h[2] - FMLA v21.8h, v17.8h, v0.h[2] + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] ldr q18,[x8],16 // B3.low - FMLA v22.8h, v16.8h, v1.h[2] - FMLA v23.8h, v17.8h, v1.h[2] - ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row - FMLA v24.8h, v16.8h, v2.h[2] - FMLA v25.8h, v17.8h, v2.h[2] - FMLA v26.8h, v16.8h, v3.h[2] - FMLA v27.8h, v17.8h, v3.h[2] - FMLA v28.8h, v16.8h, v4.h[2] - FMLA v29.8h, v17.8h, v4.h[2] - FMLA v30.8h, v16.8h, v5.h[2] - FMLA v31.8h, v17.8h, v5.h[2] - - FMLA v20.8h, v18.8h, v0.h[3] - FMLA v21.8h, v19.8h, v0.h[3] - FMLA v22.8h, v18.8h, v1.h[3] - FMLA v23.8h, v19.8h, v1.h[3] - FMLA v24.8h, v18.8h, v2.h[3] - FMLA v25.8h, v19.8h, v2.h[3] - FMLA v26.8h, 
v18.8h, v3.h[3] - FMLA v27.8h, v19.8h, v3.h[3] - FMLA v28.8h, v18.8h, v4.h[3] - FMLA v29.8h, v19.8h, v4.h[3] - FMLA v30.8h, v18.8h, v5.h[3] - FMLA v31.8h, v19.8h, v5.h[3] - B.NE M6N16RemainderK123 // remaining k 1~3 - -M6N16NextIterN - SUBS x1, x1, 16 + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.NE M6N16RemainderK123 // remaining k 1~3 + +M6N16OutterLoopNTail + subs x1,x1,16 // N -= 16 ldr x8,[sp,#HGemmKernelFrame_B] - B.LO M6StoreRemainderN + b.LO M6StoreRemainderN // remaining k < 16 cbnz x19,M6N16SkipAccumulateOutput ldp q0,q1,[x3] @@ -292,8 +292,8 @@ M6N16NextIterN ldp q6,q7,[x12] ldp q16,q17,[x13] ldp q18,q19,[x4] - fadd v20.8h,v20.8h,v0.8h - fadd v21.8h,v21.8h,v1.8h + fadd v20.8h,v20.8h,v0.8h // !ZeroMode + fadd v21.8h,v21.8h,v1.8h // accumulate into C fadd v22.8h,v22.8h,v2.8h fadd v23.8h,v23.8h,v3.8h fadd v24.8h,v24.8h,v4.8h @@ -306,91 +306,91 @@ M6N16NextIterN fadd v31.8h,v31.8h,v19.8h M6N16SkipAccumulateOutput - ST1 {v20.16b, v21.16b}, [x3], 32 - SUB x6, x6, x2 // a0 -= k - ST1 {v22.16b, v23.16b}, [x10], 32 - SUB x14, x14, x2 // a1 -= k - ST1 {v24.16b, v25.16b}, [x11], 32 - SUB x15, x15, x2 // a2 -= k - ST1 {v26.16b, v27.16b}, [x12], 32 - SUB x16, x16, x2 // a3 -= k - ST1 {v28.16b, v29.16b}, [x13], 32 - SUB x17, x17, x2 // a4 -= k + st1 {v20.16b,v21.16b},[x3],32 + sub x6,x6,x2 // restore a0 + st1 {v22.16b,v23.16b},[x10],32 + sub x14,x14,x2 // restore a1 + st1 {v24.16b,v25.16b},[x11],32 + sub x15,x15,x2 // restore a2 + st1 {v26.16b,v27.16b},[x12],32 + sub x16,x16,x2 // restore a3 + st1 {v28.16b,v29.16b},[x13],32 + sub x17,x17,x2 // restore a4 add x8,x8,32 // B <- next 16 columns - ST1 {v30.16b, v31.16b}, [x4], 32 - SUB x7, x7, x2 // a5 -= k + st1 {v30.16b,v31.16b},[x4],32 + sub x7,x7,x2 // restore a5 str x8,[sp,#HGemmKernelFrame_B] - B.HI M6N16OutterLoopN + b.HI M6N16OutterLoopN ExitKernel EPILOG_RESTORE_REG x19,#HGemmKernelFrame_SavedRegs! 
EPILOG_RETURN M6N16RemainderK123 - TBZ x0, 2, M6N16RemainderK1 - LDR s0, [x6], 4 - LDR q16, [x8], 16 - ld1 {v17.16b},[x8],x9 - LDR s1, [x14], 4 - LDR s2, [x15], 4 - LDR s3, [x16], 4 - LDR s4, [x17], 4 - LDR s5, [x7], 4 - LDR q18, [x8], 16 - ld1 {v19.16b},[x8],x9 - FMLA v20.8h, v16.8h, v0.h[0] - FMLA v22.8h, v16.8h, v1.h[0] - FMLA v24.8h, v16.8h, v2.h[0] - FMLA v26.8h, v16.8h, v3.h[0] - FMLA v28.8h, v16.8h, v4.h[0] - FMLA v30.8h, v16.8h, v5.h[0] - FMLA v21.8h, v17.8h, v0.h[0] - FMLA v23.8h, v17.8h, v1.h[0] - FMLA v25.8h, v17.8h, v2.h[0] - FMLA v27.8h, v17.8h, v3.h[0] - FMLA v29.8h, v17.8h, v4.h[0] - FMLA v31.8h, v17.8h, v5.h[0] - - FMLA v20.8h, v18.8h, v0.h[1] - FMLA v22.8h, v18.8h, v1.h[1] - FMLA v24.8h, v18.8h, v2.h[1] - FMLA v26.8h, v18.8h, v3.h[1] - FMLA v28.8h, v18.8h, v4.h[1] - FMLA v30.8h, v18.8h, v5.h[1] - FMLA v21.8h, v19.8h, v0.h[1] - FMLA v23.8h, v19.8h, v1.h[1] - FMLA v25.8h, v19.8h, v2.h[1] - FMLA v27.8h, v19.8h, v3.h[1] - FMLA v29.8h, v19.8h, v4.h[1] - FMLA v31.8h, v19.8h, v5.h[1] - TBZ x0, 1, M6N16NextIterN + tbz x0,2,M6N16RemainderK1 + ldr s0,[x6],4 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr s1,[x14],4 // A1 + ldr s2,[x15],4 // A2 + ldr s3,[x16],4 // A3 + ldr s4,[x17],4 // A4 + ldr s5,[x7],4 // A5 + ldr q18,[x8],16 // B1.low + ld1 {v19.16b},[x8],x9 // B2.high + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + + fmla v20.8h,v18.8h,v0.h[1] + fmla v22.8h,v18.8h,v1.h[1] + fmla v24.8h,v18.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v21.8h,v19.8h,v0.h[1] + fmla v23.8h,v19.8h,v1.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v31.8h,v19.8h,v5.h[1] + tbz x0,1,M6N16OutterLoopNTail M6N16RemainderK1 - LDR h0, [x6], 2 - LDR q16, [x8], 16 - ld1 {v17.16b},[x8],x9 - LDR h1, [x14], 2 - LDR h2, [x15], 2 - LDR h3, [x16], 2 - LDR h4, [x17], 2 - LDR h5, [x7], 2 - FMLA v20.8h, v16.8h, v0.h[0] - FMLA v22.8h, v16.8h, v1.h[0] - FMLA v24.8h, v16.8h, v2.h[0] - FMLA v26.8h, v16.8h, v3.h[0] - FMLA v28.8h, v16.8h, v4.h[0] - FMLA v30.8h, v16.8h, v5.h[0] - FMLA v21.8h, v17.8h, v0.h[0] - FMLA v23.8h, v17.8h, v1.h[0] - FMLA v25.8h, v17.8h, v2.h[0] - FMLA v27.8h, v17.8h, v3.h[0] - FMLA v29.8h, v17.8h, v4.h[0] - FMLA v31.8h, v17.8h, v5.h[0] - B M6N16NextIterN + ldr h0,[x6],2 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr h1,[x14],2 // A1 + ldr h2,[x15],2 // A2 + ldr h3,[x16],2 // A3 + ldr h4,[x17],2 // A4 + ldr h5,[x7],2 // A5 + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + b M6N16OutterLoopNTail M6StoreRemainderN cbnz x19,M6StoreRemainderNZeroMode - TBZ x1, 3, M6StoreRemainderN4 + tbz x1,3,M6StoreRemainderN4 ldr q0,[x3] ldr q1,[x10] ldr q2,[x11] @@ -400,24 +400,24 @@ M6StoreRemainderN fadd v20.8h,v20.8h,v0.8h fadd v22.8h,v22.8h,v1.8h fadd v24.8h,v24.8h,v2.8h - STR q20, [x3], 16 - MOV v20.16b, v21.16b - STR q22, [x10], 16 - MOV v22.16b, v23.16b - STR 
q24, [x11], 16 - MOV v24.16b, v25.16b + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b fadd v26.8h,v26.8h,v3.8h fadd v28.8h,v28.8h,v4.8h fadd v30.8h,v30.8h,v5.8h - STR q26, [x12], 16 - MOV v26.16b, v27.16b - STR q28, [x13], 16 - MOV v28.16b, v29.16b - STR q30, [x4], 16 - MOV v30.16b, v31.16b + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b M6StoreRemainderN4 - TBZ x1, 2, M6StoreRemainderN2 + tbz x1,2,M6StoreRemainderN2 ldr d0,[x3] ldr d1,[x10] ldr d2,[x11] @@ -425,26 +425,26 @@ M6StoreRemainderN4 ldr d4,[x13] ldr d5,[x4] fadd v21.4h,v20.4h,v0.4h - DUP d20, v20.d[1] + dup d20,v20.d[1] fadd v23.4h,v22.4h,v1.4h - DUP d22, v22.d[1] + dup d22,v22.d[1] fadd v25.4h,v24.4h,v2.4h - DUP d24, v24.d[1] + dup d24,v24.d[1] fadd v27.4h,v26.4h,v3.4h - DUP d26, v26.d[1] + dup d26,v26.d[1] fadd v29.4h,v28.4h,v4.4h - DUP d28, v28.d[1] + dup d28,v28.d[1] fadd v31.4h,v30.4h,v5.4h - DUP d30, v30.d[1] - STR d21, [x3], 8 - STR d23, [x10], 8 - STR d25, [x11], 8 - STR d27, [x12], 8 - STR d29, [x13], 8 - STR d31, [x4], 8 + dup d30,v30.d[1] + str d21,[x3],8 + str d23,[x10],8 + str d25,[x11],8 + str d27,[x12],8 + str d29,[x13],8 + str d31,[x4],8 M6StoreRemainderN2 - TBZ x1, 1, M6StoreRemainderN1 + tbz x1,1,M6StoreRemainderN1 ldr s0,[x3] ldr s1,[x10] ldr s2,[x11] @@ -457,21 +457,21 @@ M6StoreRemainderN2 fadd v27.4h,v26.4h,v3.4h fadd v29.4h,v28.4h,v4.4h fadd v31.4h,v30.4h,v5.4h - STR s21, [x3], 4 - STR s23, [x10], 4 - DUP s20, v20.s[1] - DUP s22, v22.s[1] - STR s25, [x11], 4 - STR s27, [x12], 4 - DUP s24, v24.s[1] - DUP s26, v26.s[1] - STR s29, [x13], 4 - STR s31, [x4], 4 - DUP s28, v28.s[1] - DUP s30, v30.s[1] + str s21,[x3],4 + str s23,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s25,[x11],4 + str s27,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s29,[x13],4 + str s31,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] M6StoreRemainderN1 - TBZ x1, 0, ExitKernel + tbz x1,0,ExitKernel ldr h0,[x3] ldr h1,[x10] ldr h2,[x11] @@ -484,67 +484,67 @@ M6StoreRemainderN1 fadd v26.4h,v26.4h,v3.4h fadd v28.4h,v28.4h,v4.4h fadd v30.4h,v30.4h,v5.4h - STR h20, [x3] - STR h22, [x10] - STR h24, [x11] - STR h26, [x12] - STR h28, [x13] - STR h30, [x4] + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] b ExitKernel M6StoreRemainderNZeroMode - TBZ x1, 3, M6StoreRemainderN4ZeroMode - STR q20, [x3], 16 - MOV v20.16b, v21.16b - STR q22, [x10], 16 - MOV v22.16b, v23.16b - STR q24, [x11], 16 - MOV v24.16b, v25.16b - STR q26, [x12], 16 - MOV v26.16b, v27.16b - STR q28, [x13], 16 - MOV v28.16b, v29.16b - STR q30, [x4], 16 - MOV v30.16b, v31.16b + tbz x1,3,M6StoreRemainderN4ZeroMode + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b M6StoreRemainderN4ZeroMode - TBZ x1, 2, M6StoreRemainderN2ZeroMode - STR d20, [x3], 8 - STR d22, [x10], 8 - DUP d20, v20.d[1] - DUP d22, v22.d[1] - STR d24, [x11], 8 - STR d26, [x12], 8 - DUP d24, v24.d[1] - DUP d26, v26.d[1] - STR d28, [x13], 8 - STR d30, [x4], 8 - DUP d28, v28.d[1] - DUP d30, v30.d[1] + tbz x1,2,M6StoreRemainderN2ZeroMode + str d20,[x3],8 + str d22,[x10],8 + dup d20,v20.d[1] + dup d22,v22.d[1] + str d24,[x11],8 + str d26,[x12],8 + dup d24,v24.d[1] + dup d26,v26.d[1] + str d28,[x13],8 + str d30,[x4],8 + dup d28,v28.d[1] 
+ dup d30,v30.d[1] M6StoreRemainderN2ZeroMode - TBZ x1, 1, M6StoreRemainderN1ZeroMode - STR s20, [x3], 4 - STR s22, [x10], 4 - DUP s20, v20.s[1] - DUP s22, v22.s[1] - STR s24, [x11], 4 - STR s26, [x12], 4 - DUP s24, v24.s[1] - DUP s26, v26.s[1] - STR s28, [x13], 4 - STR s30, [x4], 4 - DUP s28, v28.s[1] - DUP s30, v30.s[1] + tbz x1,1,M6StoreRemainderN1ZeroMode + str s20,[x3],4 + str s22,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s24,[x11],4 + str s26,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s28,[x13],4 + str s30,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] M6StoreRemainderN1ZeroMode - TBZ x1, 0, ExitKernel - STR h20, [x3] - STR h22, [x10] - STR h24, [x11] - STR h26, [x12] - STR h28, [x13] - STR h30, [x4] + tbz x1,0,ExitKernel + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] b ExitKernel LEAF_END MlasHalfGemmKernelNeon From 4e634d8346e92fe0628eda63cef715c793eaec8c Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Fri, 27 Jan 2023 11:20:52 -0800 Subject: [PATCH 09/19] reorg conversion code --- .../core/mlas/lib/halfgemm_kernel_neon.cpp | 41 +++++++++---- .../test/mlas/unittest/test_halfgemm.cpp | 20 ------- .../test/mlas/unittest/test_halfgemm.h | 59 +++++++++++++++---- 3 files changed, 79 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp index def3f7732fb41..f3dea5ffe5f7d 100644 --- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -52,6 +52,35 @@ struct MLAS_HALF_GEMM_KERNEL_NEON { static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 16}; }; +/** + * @brief Convert a 2D matrix from float to fp16 +*/ +MLAS_FORCEINLINE +void +CvtFloat2Half2D( + _mlas_fp16_* dest, + const float* src, + size_t stride, + size_t CntRow, + size_t CntCol + ) +{ + int64_t stride_gap = size_t(int64_t(stride) - int64_t(CntCol)); + if (0 == stride_gap) { + const size_t len = CntRow * CntCol; + for (size_t i = 0; i < len; i++) { + *dest++ = MLAS_Float2Half(*(src++)); + } + return; + } + while (CntRow > 0) { + for (size_t k = 0; k < CntCol; k++) { + *dest++ = MLAS_Float2Half(*(src++)); + } + src += stride_gap; + CntRow--; + } +} template<> MLAS_FORCEINLINE @@ -64,11 +93,7 @@ MlasHalfGemmConvertPackA( size_t CountK ) { - for (size_t m = 0; m < CountM; m++) { - for (size_t k = 0; k < CountK; k++) { - *D++ = MLAS_Float2Half(*(A + m * lda + k)); - } - } + CvtFloat2Half2D(D, A, lda, CountM, CountK); } template<> @@ -82,11 +107,7 @@ MlasHalfGemmConvertPackB( size_t CountK ) { - for (size_t k = 0; k < CountK; k++) { - for (size_t n = 0; n < CountN; n++) { - *D++ = MLAS_Float2Half(*(B + k * ldb + n)); - } - } + CvtFloat2Half2D(D, B, ldb, CountK, CountN); } diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp index 57d82d3f7c3c8..a275dabe7df4d 100644 --- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -55,26 +55,6 @@ class HalfGemmShortExecuteTest : public MlasTestFixture); - const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize, SmallFloatFill); - const MLFp16* Bias = withBias ? 
BufferBias.GetFilledBuffer(N * BatchSize, SmallFloatFill<MLFp16>) : nullptr;
+    const AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill<AType>);
+    AType Atail[16];
+    std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType));
+
+    const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill<BType>);
+    BType Btail[16];
+    std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType));
+
+    MLFp16 BiasTail[16];
+    const MLFp16* Bias = nullptr;
+    if (withBias) {
+      Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill<MLFp16>);
+      std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16));
+    }
+
     MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill<MLFp16>);
     float* CReference = BufferCReference.GetFilledBuffer(
@@ -195,6 +228,10 @@ class MlasHalfGemmTest : public MlasTestBase {
       }
     }
   }
+    ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!";
+    ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!";
+    if (withBias)
+      ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16)), 0) << "Bias buffer overwritten!";
   }
 
  private:
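The 16-element tails introduced above are an overrun guard: each input is over-allocated, the bytes past the logical end are snapshotted before the GEMM runs and compared afterwards, so a kernel that writes out of bounds fails loudly instead of silently corrupting memory. The same idiom reduced to a standalone sketch; RunWithTailGuard is an illustrative name, not an MLAS API:

#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Over-allocate, snapshot the tail, run the kernel, verify the tail.
template <typename T, typename Kernel>
void RunWithTailGuard(size_t used, Kernel&& kernel) {
  std::vector<T> buf(used + 16, T{});
  T tail[16];
  std::memcpy(tail, buf.data() + used, sizeof(tail));
  kernel(buf.data(), used);
  assert(std::memcmp(tail, buf.data() + used, sizeof(tail)) == 0 &&
         "kernel wrote past the end of its buffer");
}

int main() {
  RunWithTailGuard<float>(64, [](float* p, size_t n) {
    for (size_t i = 0; i < n; i++) p[i] = 1.0f;  // well-behaved kernel
  });
}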
From db0958d49c513551f96fb80d588b0495f0f54cb4 Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Fri, 27 Jan 2023 15:53:55 -0800
Subject: [PATCH 10/19] vector fp16 conversion

---
 .../core/mlas/lib/halfgemm_kernel_neon.cpp    | 88 +++++++++++++++++--
 1 file changed, 79 insertions(+), 9 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
index f3dea5ffe5f7d..a8f33a72bd70e 100644
--- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
@@ -17,6 +17,8 @@ Module Name:
 #include "mlasi.h"
 #include "halfgemm.h"
 
+#include "arm64_neon.h"
+
 //
 // Define the prototypes of the NEON routines written in assembly.
 //
@@ -52,6 +54,78 @@ struct MLAS_HALF_GEMM_KERNEL_NEON {
     static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 16};
 };
 
+
+MLAS_FORCEINLINE
+void
+CvtHalf2Float(
+    float* dest,
+    const _mlas_fp16_* src,
+    size_t len
+)
+{
+    while (len >= 4) {
+        const auto* srcPtr = reinterpret_cast<const float16x4_t*>(src);
+        auto* dstPtr = reinterpret_cast<float32x4_t*>(dest);
+        *dstPtr = vcvt_f32_f16(*srcPtr);
+        src += 4;
+        dest += 4;
+        len -= 4;
+    }
+
+    if (0 == len) {
+        return;
+    }
+
+    float16x4_t buf;
+    std::memcpy(&buf, src, len * sizeof(_mlas_fp16_));
+    float32x4_t res = vcvt_f32_f16(buf);
+
+    if ((len & 2) != 0) {
+        vst1q_lane_f64(dest, res, 0);
+        res = vdupq_laneq_f64(res, 1);
+        dest += 2;
+    }
+    if ((len & 1) != 0) {
+        vst1q_lane_f32(dest, res, 0);
+    }
+}
+
+
+MLAS_FORCEINLINE
+void
+CvtFloat2Half(
+    _mlas_fp16_* dest,
+    const float* src,
+    size_t len
+)
+{
+    while (len >= 4) {
+        const auto* srcPtr = reinterpret_cast<const float32x4_t*>(src);
+        auto* dstPtr = reinterpret_cast<float16x4_t*>(dest);
+        *dstPtr = vcvt_f16_f32(*srcPtr);
+        src += 4;
+        dest += 4;
+        len -= 4;
+    }
+
+    if (0 == len) {
+        return;
+    }
+
+    float32x4_t buf;
+    std::memcpy(&buf, src, len * sizeof(float));
+    float16x4_t res = vcvt_f16_f32(buf);
+
+    if ((len & 2) != 0) {
+        vst1_lane_f32(dest, res, 0);
+        res = vdup_lane_f32(res, 1);
+        dest += 2;
+    }
+    if ((len & 1) != 0) {
+        vst1_lane_f16(dest, res, 0);
+    }
+}
+
 /**
  * @brief Convert a 2D matrix from float to fp16
 */
@@ -65,19 +139,15 @@ CvtFloat2Half2D(
     size_t CntCol
     )
 {
-    int64_t stride_gap = size_t(int64_t(stride) - int64_t(CntCol));
-    if (0 == stride_gap) {
+    if (stride == CntCol) {
         const size_t len = CntRow * CntCol;
-        for (size_t i = 0; i < len; i++) {
-            *dest++ = MLAS_Float2Half(*(src++));
-        }
+        CvtFloat2Half(dest, src, len);
         return;
     }
     while (CntRow > 0) {
-        for (size_t k = 0; k < CntCol; k++) {
-            *dest++ = MLAS_Float2Half(*(src++));
-        }
-        src += stride_gap;
+        CvtFloat2Half(dest, src, CntCol);
+        src += stride;
+        dest += CntCol;
         CntRow--;
     }
 }
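The tail handling in these converters is the subtle part: when fewer than four elements remain, the source is memcpy'd into a full lane vector so vcvt still consumes defined data, and the converted lanes are written back as a two-element store and/or a one-element store, so no byte outside [0, len) is ever touched. A scalar model of that control flow for the half-to-float direction, with the per-lane conversion written long-hand; this is a sketch for reading, not the MLAS implementation, and the decode covers normals and subnormals only:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstring>

using fp16_bits = unsigned short;

// Minimal fp16 decode (no Inf/NaN handling), standing in for vcvt_f32_f16.
float Fp16ToFloat(fp16_bits h) {
  int sign = (h >> 15) & 1, exp = (h >> 10) & 0x1f, man = h & 0x3ff;
  float v = exp ? std::ldexp(1.0f + man / 1024.0f, exp - 15)
                : std::ldexp(man / 1024.0f, -14);
  return sign ? -v : v;
}

// Scalar model of CvtHalf2Float: groups of 4, then a padded conversion
// whose results are emitted as a pair store and/or a single store.
void CvtHalf2FloatModel(float* dest, const fp16_bits* src, size_t len) {
  while (len >= 4) {
    for (int i = 0; i < 4; i++) dest[i] = Fp16ToFloat(src[i]);
    src += 4; dest += 4; len -= 4;
  }
  if (len == 0) return;
  fp16_bits buf[4] = {};                 // padded copy: safe to convert all 4
  std::memcpy(buf, src, len * sizeof(fp16_bits));
  float res[4];
  for (int i = 0; i < 4; i++) res[i] = Fp16ToFloat(buf[i]);
  size_t pos = 0;
  if (len & 2) { std::memcpy(dest, res, 2 * sizeof(float)); dest += 2; pos = 2; }
  if (len & 1) { *dest = res[pos]; }     // lane 0, or lane 2 after the pair
}

int main() {
  const fp16_bits src[3] = {0x3C00, 0x4000, 0x3800};  // 1.0, 2.0, 0.5
  float dst[3] = {};
  CvtHalf2FloatModel(dst, src, 3);
  assert(dst[0] == 1.0f && dst[1] == 2.0f && dst[2] == 0.5f);
}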
From 9cfcd87f8ae9c7388acb70df9e34db8d9ba493fe Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Mon, 30 Jan 2023 11:11:29 -0800
Subject: [PATCH 11/19] halfgemm post processor

---
 onnxruntime/core/mlas/inc/mlas.h              | 69 ++++++++++++--
 onnxruntime/core/mlas/lib/halfgemm.cpp        | 90 +++++++++++++++++++
 .../core/mlas/lib/halfgemm_kernel_neon.cpp    | 36 --------
 .../test/mlas/unittest/test_halfgemm.h        | 14 ++-
 4 files changed, 162 insertions(+), 47 deletions(-)

diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index b3e79bbb297fe..dc59634064102 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1379,20 +1379,71 @@ using MLAS_FP16 = onnxruntime::MLFloat16;
 
 constexpr size_t FP16_SIZE = sizeof(uint16_t);
 
-class MLAS_HALF_GEMM_OUTPUT_PROCESSOR {
+/**
+ * @brief Interface for half gemm post processors.
+ *
+ * Example implementation of this interface includes activations,
+ * conversion from half precision to single precision, etc.
+ *
+ * Half GEMM is computed tile by tile. When a tile of result matrix
+ * is produced, the method Process() is called to process this tile.
+ * Parameters of this method describe the location and shape of the
+ * tile.
+*/
+class MLAS_HALF_GEMM_POSTPROCESSOR {
 public:
     virtual
     void
     Process(
-        const MLAS_FP16*, // Supplies the address of matrix to process
-        size_t,           // Supplies the start row index of matrix
-        size_t,           // Supplies the start col index of matrix
-        size_t,           // Supplies the element count per row to process
-        size_t,           // Supplies the element count per col to process
-        size_t            // Supplies the leading dimension of matrix
+        const MLAS_FP16*, /**< the address of matrix to process */
+        size_t,           /**< the start row index of matrix */
+        size_t,           /**< the start col index of matrix */
+        size_t,           /**< the element count per row to process */
+        size_t,           /**< the element count per col to process */
+        size_t            /**< the leading dimension of matrix */
        ) const = 0;
 
-    virtual ~MLAS_HALF_GEMM_OUTPUT_PROCESSOR() {}
+    virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {}
+};
+
+/**
+ * @brief Convert half gemm result matrix to single precision float matrix
+*/
+class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
+public:
+    MLAS_HALF_GEMM_2FLOAT_PROCESSOR(
+        float* Output,    /**< address of the output matrix, row major */
+        size_t RowStride  /**< row stride of the output matrix */
+    ) :
+        Output_(Output),
+        RowStride_(RowStride)
+    {}
+
+    void
+    Process(
+        const MLAS_FP16* C,
+        size_t StartM,
+        size_t StartN,
+        size_t CountM,
+        size_t CountN,
+        size_t ldc
+        ) const override;
+
+private:
+    inline
+    void
+    ProcessImpl(
+        const MLAS_FP16* C,
+        size_t StartM,
+        size_t StartN,
+        size_t CountM,
+        size_t CountN,
+        size_t ldc
+        ) const;
+
+private:
+    float* Output_;
+    size_t RowStride_;
 };
 
 
@@ -1408,7 +1459,7 @@ struct MLAS_HALF_GEMM_DATA_PARAMS {
     size_t lda = 0;       /**< leading dimension of A */
     size_t ldb = 0;       /**< leading dimension of B, 0 when B is pre-packed*/
     size_t ldc = 0;       /**< leading dimension of C*/
-    const MLAS_HALF_GEMM_OUTPUT_PROCESSOR* OutputProcessor = nullptr;
+    const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr;
     bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/
     bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/
 };
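From the caller's side the contract is small: construct a processor whose lifetime spans the GEMM call and point OutputProcessor at it. A hedged usage sketch against the declarations above; it assumes the fp16 inputs are row major and that the A/B/C pointer fields of MLAS_HALF_GEMM_DATA_PARAMS still accept fp16 pointers as before, and it elides allocation and error handling:

#include "mlas.h"  // the interface declared in this patch series

// Run one fp16 GEMM and convert the result tile by tile to fp32 in the
// same pass, instead of a separate conversion sweep over C.
void HalfGemmToFloat(const MLAS_FP16* A, const MLAS_FP16* B,
                     MLAS_FP16* C, float* Cfloat,
                     size_t M, size_t N, size_t K,
                     MLAS_THREADPOOL* pool) {
  MLAS_HALF_GEMM_2FLOAT_PROCESSOR to_float(Cfloat, /*RowStride=*/N);

  MLAS_HALF_GEMM_DATA_PARAMS params;
  params.A = A;   params.lda = K;      // A is MxK, row major
  params.B = B;   params.ldb = N;      // B is KxN, row major
  params.C = C;   params.ldc = N;
  params.OutputProcessor = &to_float;  // must outlive the call below

  MlasHalfGemmBatch(M, N, K, /*BatchN=*/1, &params, pool);
}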
diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index 174c9131eb1f1..2a8ae1730461b 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -155,6 +155,96 @@ MlasHalfGemmConvertPackB(
 }
 
 
+//
+// Post Processor Implementations
+//
+
+void
+MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process(
+    const MLAS_FP16* C,
+    size_t StartM,
+    size_t StartN,
+    size_t CountM,
+    size_t CountN,
+    size_t ldc
+    ) const
+{
+    ProcessImpl(
+        C,
+        StartM,
+        StartN,
+        CountM,
+        CountN,
+        ldc);
+}
+
+MLAS_FORCEINLINE
+void
+CvtHalf2Float(
+    float* dest,
+    const _mlas_fp16_* src,
+    size_t len
+)
+{
+#ifdef MLAS_TARGET_ARM64
+    while (len >= 4) {
+        const auto* srcPtr = reinterpret_cast<const float16x4_t*>(src);
+        auto* dstPtr = reinterpret_cast<float32x4_t*>(dest);
+        *dstPtr = vcvt_f32_f16(*srcPtr);
+        src += 4;
+        dest += 4;
+        len -= 4;
+    }
+
+    if (0 == len) {
+        return;
+    }
+
+    float16x4_t buf;
+    std::memcpy(&buf, src, len * sizeof(_mlas_fp16_));
+    float32x4_t res = vcvt_f32_f16(buf);
+
+    if ((len & 2) != 0) {
+        vst1q_lane_f64(dest, res, 0);
+        res = vdupq_laneq_f64(res, 1);
+        dest += 2;
+    }
+    if ((len & 1) != 0) {
+        vst1q_lane_f32(dest, res, 0);
+    }
+#else
+    throw std::invalid_argument("FP16 acceleration not supported in this platform!");
+#endif // MLAS_TARGET_ARM64
+
+}
+
+MLAS_FORCEINLINE
+void
+MLAS_HALF_GEMM_2FLOAT_PROCESSOR::ProcessImpl(
+    const MLAS_FP16* C,
+    size_t StartM,
+    size_t StartN,
+    size_t CountM,
+    size_t CountN,
+    size_t ldc) const
+{
+    //
+    // TODO!! use templates to add activations in this impl
+    //
+    float* Output = Output_;
+    const auto* CRow = reinterpret_cast<const _mlas_fp16_*>(C);
+    CRow += StartM * ldc + StartN;
+    Output += StartM * RowStride_ + StartN;
+
+    while (CountM-- > 0) {
+        CvtHalf2Float(Output, CRow, CountN);
+
+        CRow += ldc;
+        Output += RowStride_;
+    }
+}
+
+
 //
 // Dummy C++ implementation that runs very slowly
 //
diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
index a8f33a72bd70e..4479876c4e346 100644
--- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
@@ -55,42 +55,6 @@ struct MLAS_HALF_GEMM_KERNEL_NEON {
 };
 
 
-MLAS_FORCEINLINE
-void
-CvtHalf2Float(
-    float* dest,
-    const _mlas_fp16_* src,
-    size_t len
-)
-{
-    while (len >= 4) {
-        const auto* srcPtr = reinterpret_cast<const float16x4_t*>(src);
-        auto* dstPtr = reinterpret_cast<float32x4_t*>(dest);
-        *dstPtr = vcvt_f32_f16(*srcPtr);
-        src += 4;
-        dest += 4;
-        len -= 4;
-    }
-
-    if (0 == len) {
-        return;
-    }
-
-    float16x4_t buf;
-    std::memcpy(&buf, src, len * sizeof(_mlas_fp16_));
-    float32x4_t res = vcvt_f32_f16(buf);
-
-    if ((len & 2) != 0) {
-        vst1q_lane_f64(dest, res, 0);
-        res = vdupq_laneq_f64(res, 1);
-        dest += 2;
-    }
-    if ((len & 1) != 0) {
-        vst1q_lane_f32(dest, res, 0);
-    }
-}
-
-
 MLAS_FORCEINLINE
 void
 CvtFloat2Half(
diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h
index d0d61cf427ca7..e8d0302390f43 100644
--- a/onnxruntime/test/mlas/unittest/test_halfgemm.h
+++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h
@@ -81,6 +81,7 @@ class MlasHalfGemmTest : public MlasTestBase {
   MatrixGuardBuffer<MLFp16> BufferBias;
   MatrixGuardBuffer<MLFp16> BufferC;
   MatrixGuardBuffer<float> BufferCReference;
+  MatrixGuardBuffer<float> BufferFloatC;
   MLAS_THREADPOOL* threadpool_;
 
   void* PackB(size_t N, size_t K, const BType* B, size_t ldb) {
@@ -105,7 +108,10 @@ class MlasHalfGemmTest : public MlasTestBase {
       size_t ldb,
       const MLFp16* Bias,
       MLFp16* C,
-      size_t ldc) {
+      size_t ldc,
+      float* Cfloat) {
+    std::vector<MLAS_HALF_GEMM_2FLOAT_PROCESSOR> Converters;
+    Converters.reserve(BatchSize);
 
     std::vector<MLAS_HALF_GEMM_DATA_PARAMS> GemmParameters(BatchSize);
@@ -133,6 +137,8 @@ class MlasHalfGemmTest : public MlasTestBase {
     }
     params.AIsfp32 = std::is_same<AType, float>::value;
     params.BIsfp32 = std::is_same<BType, float>::value;
+    Converters.emplace_back(Cfloat + (M * N * i), N);
+    params.OutputProcessor = &(Converters[i]);
   }
 
   MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_);
@@ -211,13 +217,14 @@ class MlasHalfGemmTest : public MlasTestBase {
     }
 
     MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill<MLFp16>);
+    float* Cfloat = BufferFloatC.GetBuffer(N * M * BatchSize, true);
     float* CReference = BufferCReference.GetFilledBuffer(
         N * M * BatchSize,
         [](float* start, size_t size) {
          std::fill_n(start, size, -1.0f);
        });
 
-    this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N);
+    this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, Cfloat);
     ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference);
 
     for (size_t batch = 0, f = 0; batch < BatchSize; batch++) {
@@ -225,6 +232,9 @@ class MlasHalfGemmTest : public MlasTestBase {
       for (size_t n = 0; n < N; n++, f++) {
         ASSERT_EQ(float(C[f]), CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], "
                                               << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K;
+        ASSERT_EQ(Cfloat[f], CReference[f]) << "Converted@[" << batch << "x" << m << "x" << n << "], "
+                                            << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K;
+
       }
     }
   }
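One detail in CallGemm above is load-bearing: the parameter blocks keep raw pointers into Converters, so the reserve(BatchSize) call is what guarantees the subsequent emplace_back calls never reallocate and silently invalidate pointers already stored in earlier blocks. The same constraint, distilled with illustrative stand-in types:

#include <cstddef>
#include <vector>

struct Processor { };                        // stand-in for the postprocessor
struct Params { const Processor* proc = nullptr; };

// Capacity must be fixed before the first address is taken; with the
// reserve in place, emplace_back is guaranteed not to reallocate.
void Wire(std::vector<Processor>& converters,
          std::vector<Params>& params, size_t batch) {
  converters.reserve(batch);
  params.resize(batch);
  for (size_t i = 0; i < batch; ++i) {
    converters.emplace_back();
    params[i].proc = &converters[i];         // stays valid for the GEMM call
  }
}

int main() {
  std::vector<Processor> converters;
  std::vector<Params> params;
  Wire(converters, params, 8);
}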
From 39604690629781c2072755dd09f32ab7d6d6e5e2 Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Thu, 2 Feb 2023 04:21:57 +0000 Subject: [PATCH 12/19] Linux compile and test --- cmake/onnxruntime_mlas.cmake | 2 + .../mlas/lib/aarch64/HalfGemmKernelNeon.S | 550 ++++++++++++++++++ onnxruntime/core/mlas/lib/halfgemm.cpp | 7 +- onnxruntime/core/mlas/lib/halfgemm.h | 38 +- .../core/mlas/lib/halfgemm_kernel_neon.cpp | 15 +- .../test/mlas/unittest/test_halfgemm.h | 5 +- 6 files changed, 585 insertions(+), 32 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 720694938afdf..80a65c6787eb9 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -323,6 +323,8 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") diff --git a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S new file mode 100644 index 0000000000000..8622929a9a0fe --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S @@ -0,0 +1,550 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + HalfGemmKernelNeon.s + +Abstract: + + This module implements the kernels for the half precision matrix/matrix + multiply operation (HALF GEMM). + +--*/ + +#include "asmmacro.h" + +// +// Stack frame layout for the half gemm kernel. +// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. +// + .equ .LHGemmKernelFrame_SavedRegs, (2 * 8) + .equ .LHGemmKernelFrame_B, 0 + .LHGemmKernelFrame_SavedRegs + .equ .LHGemmKernelFrame_ldb, 8 + .LHGemmKernelFrame_SavedRegs + .equ .LHGemmKernelFrame_ZeroMode, 16 + .LHGemmKernelFrame_SavedRegs + + .text + +/*++ + +Routine Description: + + This routine is an inner kernel to compute 6 rows of GEMM + +Arguments: + + CountM - (x0) the number of rows for matrix A and matrix C. + only process 6 rows + + CountN - (x1) the number of columns from matrix B and matrix C + + CountK - (x2/x0) the number of columns from matrix A and the + number of rows from matrix B. + + C - (x3) the address of matrix C. + + ldc - (x4) - the first dimension of matrix C. + + Bias - (x5) - the address of the Bias vector (optional) + + A - (x6) - the address of matrix A + + lda - (x7) - the first dimension of matrix A + + B - the address of matrix B + + ldb - the first dimension of matrix B + + ZeroMode - true if the output matrix must be zero initialized, else + if the output matrix is accumulated into + +--*/ + + FUNCTION_ENTRY MlasHalfGemmKernelNeon + + str x19,[sp,#-.LHGemmKernelFrame_SavedRegs]! + ldr x9,[sp,#.LHGemmKernelFrame_ldb] + lsl x2,x2,#1 // k *= sizeof(fp16) + cmp x0,2 + add x14,x6,x7,lsl #1 // a1 = a0 + lda + add x10,x3,x4,lsl #1 // c1 = c0 + ldc + ldr x8,[sp,#.LHGemmKernelFrame_B] + csel x14,x6,x14,LO // M < 2 ? a1 = a0 + csel x10,x3,x10,LO // c1 = c0 + add x15,x14,x7,lsl #1 // a2 = a1 + lda + add x11,x10,x4,lsl #1 // c2 = c1 + ldc + csel x15,x14,x15,LS // M <= 2 ? 
a2 = a1 + csel x11,x10,x11,LS // c2 = c1 + cmp x0,4 + add x16,x15,x7,lsl #1 // a3 = a2 + lda + add x12,x11,x4,lsl #1 // c3 = c2 + ldc + csel x16,x15,x16,LO // M < 4 ? a3 = a2 + csel x12,x11,x12,LO // c3 = c2 + add x17,x16,x7,lsl #1 // a4 = a3 + lda + add x13,x12,x4,lsl #1 // c4 = c3 + ldc + csel x17,x16,x17,LS // M <= 4 ? a4 = a3 + csel x13,x12,x13,LS // c4 = c3 + cmp x0,6 + add x7,x17,x7,lsl #1 // a5 = a4 + lda + add x4,x13,x4,lsl #1 // c5 = c4 + ldc + csel x7,x17,x7,LO // M < 6 ? a5 = a4 + csel x4,x13,x4,LO // c5 = c4 + lsl x9,x9,#1 // ldb *= sizeof(fp16) + ldrb w19,[sp,#.LHGemmKernelFrame_ZeroMode] + sub x9,x9,16 // ldb -= 16 + +/**** +Main loop processes 6x16 tile, depth 4. + B 4x16 + --------------------------------------- + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + A 6x4 --------------------------------------- + ------------------ --------------------------------------- +x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 +x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 +x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 +x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 +x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 +x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 + ------------------ --------------------------------------- +****/ + +.LM6N16OutterLoopN: + cbz x5,.LM6N16SkipBias + ldp q20,q21,[x5],32 // Load 16 Bias values + b .LM6N16PopulateAccumulators + +.LM6N16SkipBias: + eor v20.16b,v20.16b,v20.16b // No bias, reset regs + eor v21.16b,v21.16b,v21.16b + +.LM6N16PopulateAccumulators: + mov v22.16b,v20.16b + mov v23.16b,v21.16b + mov v24.16b,v20.16b + mov v25.16b,v21.16b + mov v26.16b,v20.16b + mov v27.16b,v21.16b + mov v28.16b,v20.16b + subs x0,x2,8 // k -= 4 (8 bytes) + mov v29.16b,v21.16b + mov v30.16b,v20.16b + mov v31.16b,v21.16b + b.LO .LM6N16RemainderK123 // remaining k 1~3 + + ldr d0,[x6],8 // A0 + ldr q16,[x8],16 // B0.l + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row + subs x0,x0,8 // over decement k -= 4 (8 bytes) + ldr d1,[x14],8 // A1 + ldr d2,[x15],8 // A2 + ldr d3,[x16],8 // A3 + b.LO .LM6N16LoopK_Epilogue // need k>=8 for main loop + +.LM6N16InnerLoopK: + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + subs x0,x0,8 // k -= 4 + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla 
v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + ldr q16,[x8],16 // Load B0.low for next iter + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + ld1 {v17.16b},[x8],x9 // Load B0.high for next iter + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + ldr d0,[x6],8 // Load A0 for next iter + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + ldr d1,[x14],8 // Load A1 for next iter + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + ldr d2,[x15],8 // Load A2 for next iter + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + ldr d3,[x16],8 // Load A3 for next iter + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.hs .LM6N16InnerLoopK // k >= 8 for main loop + +.LM6N16LoopK_Epilogue: + // last block of k >= 4, no pre-load for next iter + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + adds x0,x0,8 // revert k over-decrement + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.NE .LM6N16RemainderK123 // remaining k 1~3 + +.LM6N16OutterLoopNTail: + subs x1,x1,16 // N -= 16 + ldr x8,[sp,#.LHGemmKernelFrame_B] + b.LO .LM6StoreRemainderN // remaining k < 16 + + cbnz x19,.LM6N16SkipAccumulateOutput + ldp q0,q1,[x3] + ldp q2,q3,[x10] + ldp q4,q5,[x11] + ldp q6,q7,[x12] + ldp q16,q17,[x13] + ldp q18,q19,[x4] + fadd v20.8h,v20.8h,v0.8h // !ZeroMode + fadd v21.8h,v21.8h,v1.8h // accumulate into C + fadd v22.8h,v22.8h,v2.8h + fadd v23.8h,v23.8h,v3.8h + fadd v24.8h,v24.8h,v4.8h + fadd v25.8h,v25.8h,v5.8h + fadd v26.8h,v26.8h,v6.8h + fadd v27.8h,v27.8h,v7.8h + fadd v28.8h,v28.8h,v16.8h + fadd v29.8h,v29.8h,v17.8h + fadd v30.8h,v30.8h,v18.8h + fadd v31.8h,v31.8h,v19.8h + +.LM6N16SkipAccumulateOutput: + st1 {v20.16b,v21.16b},[x3],32 + sub x6,x6,x2 // restore a0 + st1 {v22.16b,v23.16b},[x10],32 + sub x14,x14,x2 // restore a1 + st1 {v24.16b,v25.16b},[x11],32 + sub x15,x15,x2 // restore a2 + 
st1 {v26.16b,v27.16b},[x12],32 + sub x16,x16,x2 // restore a3 + st1 {v28.16b,v29.16b},[x13],32 + sub x17,x17,x2 // restore a4 + add x8,x8,32 // B <- next 16 columns + st1 {v30.16b,v31.16b},[x4],32 + sub x7,x7,x2 // restore a5 + str x8,[sp,#.LHGemmKernelFrame_B] + b.HI .LM6N16OutterLoopN + +.LExitKernel: + ldr x19,[sp],#.LHGemmKernelFrame_SavedRegs + ret + +.LM6N16RemainderK123: + tbz x0,2,.LM6N16RemainderK1 + ldr s0,[x6],4 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr s1,[x14],4 // A1 + ldr s2,[x15],4 // A2 + ldr s3,[x16],4 // A3 + ldr s4,[x17],4 // A4 + ldr s5,[x7],4 // A5 + ldr q18,[x8],16 // B1.low + ld1 {v19.16b},[x8],x9 // B2.high + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + + fmla v20.8h,v18.8h,v0.h[1] + fmla v22.8h,v18.8h,v1.h[1] + fmla v24.8h,v18.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v21.8h,v19.8h,v0.h[1] + fmla v23.8h,v19.8h,v1.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v31.8h,v19.8h,v5.h[1] + tbz x0,1,.LM6N16OutterLoopNTail + +.LM6N16RemainderK1: + ldr h0,[x6],2 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr h1,[x14],2 // A1 + ldr h2,[x15],2 // A2 + ldr h3,[x16],2 // A3 + ldr h4,[x17],2 // A4 + ldr h5,[x7],2 // A5 + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + b .LM6N16OutterLoopNTail + +.LM6StoreRemainderN: + cbnz x19,.LM6StoreRemainderNZeroMode + tbz x1,3,.LM6StoreRemainderN4 + ldr q0,[x3] + ldr q1,[x10] + ldr q2,[x11] + ldr q3,[x12] + ldr q4,[x13] + ldr q5,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v22.8h,v22.8h,v1.8h + fadd v24.8h,v24.8h,v2.8h + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +.LM6StoreRemainderN4: + tbz x1,2,.LM6StoreRemainderN2 + ldr d0,[x3] + ldr d1,[x10] + ldr d2,[x11] + ldr d3,[x12] + ldr d4,[x13] + ldr d5,[x4] + fadd v21.4h,v20.4h,v0.4h + dup d20,v20.d[1] + fadd v23.4h,v22.4h,v1.4h + dup d22,v22.d[1] + fadd v25.4h,v24.4h,v2.4h + dup d24,v24.d[1] + fadd v27.4h,v26.4h,v3.4h + dup d26,v26.d[1] + fadd v29.4h,v28.4h,v4.4h + dup d28,v28.d[1] + fadd v31.4h,v30.4h,v5.4h + dup d30,v30.d[1] + str d21,[x3],8 + str d23,[x10],8 + str d25,[x11],8 + str d27,[x12],8 + str d29,[x13],8 + str d31,[x4],8 + +.LM6StoreRemainderN2: + tbz x1,1,.LM6StoreRemainderN1 + ldr s0,[x3] + ldr s1,[x10] + ldr s2,[x11] + ldr s3,[x12] + ldr s4,[x13] + ldr s5,[x4] + fadd v21.4h,v20.4h,v0.4h + fadd v23.4h,v22.4h,v1.4h + fadd v25.4h,v24.4h,v2.4h + fadd v27.4h,v26.4h,v3.4h + fadd v29.4h,v28.4h,v4.4h + fadd v31.4h,v30.4h,v5.4h + str s21,[x3],4 + str s23,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s25,[x11],4 + str s27,[x12],4 + dup s24,v24.s[1] + 
+        dup s26,v26.s[1]
+        str s29,[x13],4
+        str s31,[x4],4
+        dup s28,v28.s[1]
+        dup s30,v30.s[1]
+
+.LM6StoreRemainderN1:
+        tbz x1,0,.LExitKernel
+        ldr h0,[x3]
+        ldr h1,[x10]
+        ldr h2,[x11]
+        ldr h3,[x12]
+        ldr h4,[x13]
+        ldr h5,[x4]
+        fadd v20.4h,v20.4h,v0.4h
+        fadd v22.4h,v22.4h,v1.4h
+        fadd v24.4h,v24.4h,v2.4h
+        fadd v26.4h,v26.4h,v3.4h
+        fadd v28.4h,v28.4h,v4.4h
+        fadd v30.4h,v30.4h,v5.4h
+        str h20,[x3]
+        str h22,[x10]
+        str h24,[x11]
+        str h26,[x12]
+        str h28,[x13]
+        str h30,[x4]
+        b .LExitKernel
+
+.LM6StoreRemainderNZeroMode:
+        tbz x1,3,.LM6StoreRemainderN4ZeroMode
+        str q20,[x3],16
+        mov v20.16b,v21.16b
+        str q22,[x10],16
+        mov v22.16b,v23.16b
+        str q24,[x11],16
+        mov v24.16b,v25.16b
+        str q26,[x12],16
+        mov v26.16b,v27.16b
+        str q28,[x13],16
+        mov v28.16b,v29.16b
+        str q30,[x4],16
+        mov v30.16b,v31.16b
+
+.LM6StoreRemainderN4ZeroMode:
+        tbz x1,2,.LM6StoreRemainderN2ZeroMode
+        str d20,[x3],8
+        str d22,[x10],8
+        dup d20,v20.d[1]
+        dup d22,v22.d[1]
+        str d24,[x11],8
+        str d26,[x12],8
+        dup d24,v24.d[1]
+        dup d26,v26.d[1]
+        str d28,[x13],8
+        str d30,[x4],8
+        dup d28,v28.d[1]
+        dup d30,v30.d[1]
+
+.LM6StoreRemainderN2ZeroMode:
+        tbz x1,1,.LM6StoreRemainderN1ZeroMode
+        str s20,[x3],4
+        str s22,[x10],4
+        dup s20,v20.s[1]
+        dup s22,v22.s[1]
+        str s24,[x11],4
+        str s26,[x12],4
+        dup s24,v24.s[1]
+        dup s26,v26.s[1]
+        str s28,[x13],4
+        str s30,[x4],4
+        dup s28,v28.s[1]
+        dup s30,v30.s[1]
+
+.LM6StoreRemainderN1ZeroMode:
+        tbz x1,0,.LExitKernel
+        str h20,[x3]
+        str h22,[x10]
+        str h24,[x11]
+        str h26,[x12]
+        str h28,[x13]
+        str h30,[x4]
+        b .LExitKernel
+
+        .end
diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index 2a8ae1730461b..0f32f63e0d294 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -205,8 +205,9 @@ CvtHalf2Float(
     float32x4_t res = vcvt_f32_f16(buf);

     if ((len & 2) != 0) {
-        vst1q_lane_f64(dest, res, 0);
-        res = vdupq_laneq_f64(res, 1);
+        auto wide = vreinterpretq_f64_f32(res);
+        vst1q_lane_f64((float64_t*)dest, wide, 0);
+        res = vreinterpretq_f32_f64(vdupq_laneq_f64(wide, 1));
         dest += 2;
     }
     if ((len & 1) != 0) {
@@ -337,7 +338,7 @@ MlasHalfGemmKernel(

 const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = {
     MlasHalfGemmOperation<MLAS_HALF_GEMM_KERNEL_DEFAULT>,
-    nullptr,
+    nullptr,
     MlasHalfGemmConvertPackB<MLAS_HALF_GEMM_KERNEL_DEFAULT>,
     MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK,
     MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM,
diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h
index 65399fe044202..21284ae6fd0af 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.h
+++ b/onnxruntime/core/mlas/lib/halfgemm.h
@@ -52,16 +52,16 @@ struct MLAS_HALF_GEMM_STRIDES {

 /**
  * @brief Packing function for fp16 B matrix
- *
- * @tparam KernelType
+ *
+ * @tparam KernelType
  * @param[out] D      Address of packing buffer
  * @param[in]  B      Address of source matrix B
- * @param[in]  ldb    Leading dimension of B
- * @param[in]  CountN # of column to pack
+ * @param[in]  ldb    Leading dimension of B
+ * @param[in]  CountN # of columns to pack
  * @param[in]  CountK # of rows to pack
 */
 template<typename KernelType>
-MLAS_FORCEINLINE
+MLAS_FORCEINLINE
 void
 MlasHalfGemmCopyPackB(
     _mlas_fp16_* D,
@@ -81,12 +81,12 @@ MlasHalfGemmCopyPackB(

 /**
  * @brief Convert fp32 matrix A to fp16 and pack the data
- *
- * @tparam KernelType
+ *
+ * @tparam KernelType
  * @param[out] D      Address of the packing buffer
  * @param[in]  A      Address of fp32 matrix A
  * @param[in]  lda    leading dimension of A
- * @param[in]  CountM # of rows to pack
+ * @param[in]  CountM # of rows to pack
  * @param[in]  CountK # of columns to pack
 */
 template<typename KernelType>
@@ -121,13 +121,13 @@ MlasHalfGemmConvertPackB(

 /**
  * @brief Find the location of PackedB[StartK, StartN]
- *
- * @tparam KernelType
- * @param PackedB
+ *
+ * @tparam KernelType
+ * @param PackedB
  * @param DimN   Total columns of the packing buffer
  * @param DimK   Total rows of the packing buffer
- * @param StartN
- * @param StartK
+ * @param StartN
+ * @param StartK
  * @return Address of PackedB[StartK, StartN]
  */
 template<typename KernelType>
-MLAS_FORCEINLINE
+MLAS_FORCEINLINE
 const _mlas_fp16_*
 MlasHalfGemmPackedBOffset(
@@ -149,9 +149,9 @@
 /**
  * @brief leading dimension of the packed B buffer
  *        Related to how B is packed
- * @tparam KernelType
- * @param DimN
- * @param DimK
+ * @tparam KernelType
+ * @param DimN
+ * @param DimK
  * @return leading dimension of the packed B buffer
  */
 template<typename KernelType>
@@ -223,7 +223,7 @@ MlasHalfGemmNoPackOperation(
     }

     const _mlas_fp16_* Bias = (nullptr == Data->Bias)
-        ? nullptr
+        ? nullptr
         : reinterpret_cast<const _mlas_fp16_*>(Data->Bias) + RangeStartN;
     _mlas_fp16_* c = reinterpret_cast<_mlas_fp16_*>(Data->C)
         + RangeStartM * ldc + RangeStartN;
@@ -278,12 +278,12 @@ MlasHalfGemmOperation(
     const size_t ldb = Data->ldb;
     const size_t ldc = Data->ldc;

-    if (!Data->AIsfp32 && (ldb == 0 || !KernelType::PackNeeded && !Data->BIsfp32)) {
+    if (!Data->AIsfp32 && (ldb == 0 || (!KernelType::PackNeeded && !Data->BIsfp32))) {
         // !Data->AIsfp32 => A is fp16, no packing on the left hand side
         // ldb == 0 => B is already packed, no packing on the right hand side
         // !KernelType::PackNeeded && !Data->BIsfp32 => B is fp16 and the kernel
         //     does not require packing
-        //
+        //
         // So no packing needed on either A or B, use a simpler driver instead

         MlasHalfGemmNoPackOperation<KernelType>(
diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
index 4479876c4e346..d7f5a90b00589 100644
--- a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp
@@ -17,7 +17,7 @@ Module Name:

 #include "mlasi.h"
 #include "halfgemm.h"

-#include "arm64_neon.h"
+#include "arm_neon.h"

 //
 // Define the prototypes of the NEON routines written in assembly.
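
[Editor's note: the packed-B helpers above (MlasHalfGemmPackedBOffset and
MlasHalfGemmPackedBLeadingDim) describe a plain row-major K x N layout. The
standalone C++ sketch below restates that addressing rule; the names
PackedBAt and fp16_t are illustrative only and are not part of this patch.]

    #include <cstddef>
    #include <cstdint>

    using fp16_t = uint16_t;  // same bit layout as _mlas_fp16_

    // For the default (unblocked) packing, the leading dimension of the
    // packed buffer is simply DimN, so element [StartK, StartN] sits at a
    // row-major offset from the start of the buffer.
    inline const fp16_t* PackedBAt(
        const fp16_t* PackedB, size_t DimN, size_t StartN, size_t StartK)
    {
        return PackedB + StartK * DimN + StartN;
    }
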
@@ -51,7 +51,7 @@ struct MLAS_HALF_GEMM_KERNEL_NEON {
     static constexpr size_t KernelMaxM = 6;  // max # rows the vectorized kernel can process
     static constexpr size_t PackedK = 1;

-    static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 16};
+    static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 512};
 };


@@ -81,12 +81,13 @@ CvtFloat2Half(
     float16x4_t res = vcvt_f16_f32(buf);

     if ((len & 2) != 0) {
-        vst1_lane_f32(dest, res, 0);
-        res = vdup_lane_f32(res, 1);
+        auto wide = vreinterpret_f32_f16(res);
+        vst1_lane_f32((float32_t*)dest, wide, 0);
+        res = vreinterpret_f16_f32(vdup_lane_f32(wide, 1));
         dest += 2;
     }
     if ((len & 1) != 0) {
-        vst1_lane_f16(dest, res, 0);
+        vst1_lane_u16(dest, vreinterpret_u16_f16(res), 0);
     }
 }

@@ -141,7 +142,7 @@ MlasHalfGemmConvertPackB(
     size_t CountK
     )
 {
-    CvtFloat2Half2D(D, B, ldb, CountK, CountN);
+    CvtFloat2Half2D(D, B, ldb, CountK, CountN);
 }


@@ -178,7 +179,7 @@ MlasHalfGemmKernel(

 const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = {
     MlasHalfGemmOperation<MLAS_HALF_GEMM_KERNEL_NEON>,
-    nullptr,
+    nullptr,
     MlasHalfGemmConvertPackB<MLAS_HALF_GEMM_KERNEL_NEON>,
     MLAS_HALF_GEMM_KERNEL_NEON::PackedK,
     MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM,
diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h
index e8d0302390f43..6b1118d27124e 100644
--- a/onnxruntime/test/mlas/unittest/test_halfgemm.h
+++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h
@@ -240,8 +240,9 @@ class MlasHalfGemmTest : public MlasTestBase {
     }
     ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!";
     ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!";
-    if (withBias)
+    if (withBias) {
       ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16)), 0) << "Bias buffer overwritten!";
+    }
   }

 private:
@@ -316,5 +317,3 @@ class MlasHalfGemmTest : public MlasTestBase {

 };
-
-

From f9e2563bfe94f44cd5f7a6371f6a5c7d803ef6b1 Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Thu, 2 Feb 2023 09:07:33 -0800
Subject: [PATCH 13/19] fix compilation warning

---
 onnxruntime/core/mlas/lib/halfgemm.cpp | 29 ++++++--------------------
 1 file changed, 6 insertions(+), 23 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index 0f32f63e0d294..b288867d94763 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -159,25 +159,6 @@ MlasHalfGemmConvertPackB(
 // Post Processor Implementations
 //

-void
-MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process(
-    const MLAS_FP16* C,
-    size_t StartM,
-    size_t StartN,
-    size_t CountM,
-    size_t CountN,
-    size_t ldc
-    ) const
-{
-    ProcessImpl(
-        C,
-        StartM,
-        StartN,
-        CountM,
-        CountN,
-        ldc);
-}
-
 MLAS_FORCEINLINE
 void
 CvtHalf2Float(
@@ -214,20 +195,22 @@ CvtHalf2Float(
         vst1q_lane_f32(dest, res, 0);
     }
 #else
+    MLAS_UNREFERENCED_PARAMETER(dest);
+    MLAS_UNREFERENCED_PARAMETER(src);
+    MLAS_UNREFERENCED_PARAMETER(len);
     throw std::invalid_argument("FP16 acceleration not supported in this platform!");
 #endif  // MLAS_TARGET_ARM64
-
 }

-MLAS_FORCEINLINE
 void
-MLAS_HALF_GEMM_2FLOAT_PROCESSOR::ProcessImpl(
+MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process(
     const MLAS_FP16* C,
     size_t StartM,
    size_t StartN,
     size_t CountM,
     size_t CountN,
-    size_t ldc) const
+    size_t ldc
+    ) const
 {
     //
     // TODO!! use templates to add activations in this impl

From 5c0416899a5210ad9b86b0d22f64183eaa37689c Mon Sep 17 00:00:00 2001
From: Chen Fu <1316708+chenfucn@users.noreply.github.com>
Date: Thu, 2 Feb 2023 10:14:04 -0800
Subject: [PATCH 14/19] detect fp16 hardware acceleration support

---
 onnxruntime/core/mlas/inc/mlas.h              | 29 +++++++------------
 onnxruntime/core/mlas/lib/halfgemm.cpp        | 15 +++++++++-
 .../test/mlas/unittest/test_halfgemm.cpp      |  8 ++++-
 3 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index dc59634064102..4aac20fe08f24 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1379,6 +1379,9 @@ using MLAS_FP16 = onnxruntime::MLFloat16;

 constexpr size_t FP16_SIZE = sizeof(uint16_t);

+bool MLASCALL
+MlasFp16AccelerationSupported();
+
 /**
  * @brief Interface for half gemm post processors.
  *
@@ -1395,12 +1398,12 @@ class MLAS_HALF_GEMM_POSTPROCESSOR {
     virtual
     void
     Process(
-        const MLAS_FP16*, /**< the address of matrix to process */
-        size_t,           /**< the start row index of matrix */
-        size_t,           /**< the start col index of matrix */
-        size_t,           /**< the element count per row to process */
-        size_t,           /**< the element count per col to process */
-        size_t            /**< the leading dimension of matrix */
+        MLAS_FP16*,       /**< the address of matrix to process */
+        size_t,           /**< the start row index of matrix */
+        size_t,           /**< the start col index of matrix */
+        size_t,           /**< the element count per row to process */
+        size_t,           /**< the element count per col to process */
+        size_t            /**< the leading dimension of matrix */
         ) const = 0;

     virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {}
@@ -1421,7 +1424,7 @@ class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
     void
     Process(
-        const MLAS_FP16* C,
+        MLAS_FP16* C,
         size_t StartM,
         size_t StartN,
         size_t CountM,
@@ -1429,18 +1432,6 @@ class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
         size_t ldc
         ) const override;

-private:
-    inline
-    void
-    ProcessImpl(
-        const MLAS_FP16* C,
-        size_t StartM,
-        size_t StartN,
-        size_t CountM,
-        size_t CountN,
-        size_t ldc
-        ) const;
-
 private:
     float* Output_;
     size_t RowStride_;
diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index b288867d94763..f50e03980c655 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -22,6 +22,19 @@ Module Name:

 #include <exception>

+bool MLASCALL
+MlasFp16AccelerationSupported()
+{
+#ifdef MLAS_NEON64_INTRINSICS
+    // TODO!! Only support for ARMv8.2
+    // TODO!! how to detect ARMv8.0 ???
+    return true;
+#else
+    return false;
+#endif
+}
+
+
 void
 MLASCALL
 MlasHalfGemmBatch(
@@ -204,7 +217,7 @@ CvtHalf2Float(

 void
 MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process(
-    const MLAS_FP16* C,
+    MLAS_FP16* C,
     size_t StartM,
     size_t StartN,
     size_t CountM,
diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp
index a275dabe7df4d..652645448d9dd 100644
--- a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp
+++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp
@@ -196,5 +196,11 @@ static size_t HalfGemmRegistShortExecute() {
 }

 static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
-  return is_short_execute ?
HalfGemmRegistShortExecute() : HalfGemmRegistLongExecute(); + if (!MlasFp16AccelerationSupported()) { + return false; + } + if (is_short_execute) { + return HalfGemmRegistShortExecute() > 0; + } + return HalfGemmRegistLongExecute() > 0; }); From 038f59c62abfe444eb64d49739beb594aa2f4b44 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 2 Feb 2023 10:35:40 -0800 Subject: [PATCH 15/19] fix dummy path compiler warning --- onnxruntime/core/mlas/lib/halfgemm.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index f50e03980c655..837b8f1bbd2b8 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -208,10 +208,9 @@ CvtHalf2Float( vst1q_lane_f32(dest, res, 0); } #else - MLAS_UNREFERENCED_PARAMETER(dest); - MLAS_UNREFERENCED_PARAMETER(src); - MLAS_UNREFERENCED_PARAMETER(len); - throw std::invalid_argument("FP16 acceleration not supported in this platform!"); + for (size_t i = 0; i < len; i++) { + *dest++ = MLAS_Half2Float(*src++); + } #endif // MLAS_TARGET_ARM64 } From 21ac3be69939a13e0ab7595ef44c66d93a2701e0 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 2 Feb 2023 13:29:35 -0800 Subject: [PATCH 16/19] suppress prefast warnings --- onnxruntime/core/mlas/inc/mlas_float16.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/onnxruntime/core/mlas/inc/mlas_float16.h b/onnxruntime/core/mlas/inc/mlas_float16.h index a8d566677a126..33227ea90d6be 100644 --- a/onnxruntime/core/mlas/inc/mlas_float16.h +++ b/onnxruntime/core/mlas/inc/mlas_float16.h @@ -30,6 +30,17 @@ union fp32_bits { float f; }; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) + +/*PreFast told us to convert them to constexpr but the compiler says we can't.*/ +#pragma warning(disable : 26497) + +/*Added whole bunch of casts, still can't get rid of these overflow warnings.*/ +#pragma warning(disable : 26450) +#pragma warning(disable : 26451) +#endif + inline _mlas_fp16_ MLAS_Float2Half(float ff) @@ -98,3 +109,7 @@ MLAS_Half2Float(_mlas_fp16_ val) o.u |= (val & 0x8000) << 16; // sign bit return o.f; } + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif \ No newline at end of file From ed874c1e4c85c8275707e4f2a55267301503af60 Mon Sep 17 00:00:00 2001 From: Chen Fu Date: Fri, 3 Feb 2023 00:35:34 +0000 Subject: [PATCH 17/19] fix prefast warning 2 --- onnxruntime/core/mlas/lib/halfgemm.h | 10 +++++++++- onnxruntime/core/mlas/lib/mlasi.h | 6 +++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h index 21284ae6fd0af..9e781207571a4 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.h +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -146,6 +146,12 @@ MlasHalfGemmPackedBOffset( return PackedB + StartK * DimN + StartN; } +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +/*No it can NOT be constexpr!.*/ +#pragma warning(disable : 26497) +#endif + /** * @brief leading dimension of the packed B buffer * Related to how B is packed @@ -166,7 +172,9 @@ MlasHalfGemmPackedBLeadingDim( MLAS_UNREFERENCED_PARAMETER(DimK); return DimN; } - +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif template void diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 051e8c0352b6a..495c2f06ea757 100644 --- 
a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -744,9 +744,9 @@ extern "C" { // thread to perform additional work. // -#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024) -#define MLAS_DGEMM_THREAD_COMPLEXITY (64 * 1024) -#define MLAS_QGEMM_THREAD_COMPLEXITY (64 * 1024) +#define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#define MLAS_QGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) // // Single-threaded single precision matrix/matrix multiply operation. From a451ce9caafe0a2c8188c5d1dac217f9feacc135 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Mon, 6 Feb 2023 11:28:14 -0800 Subject: [PATCH 18/19] fp16 detect --- onnxruntime/core/common/cpuid_info.cc | 38 ++++++++++++++++++++++++++ onnxruntime/core/common/cpuid_info.h | 6 +++- onnxruntime/core/mlas/lib/halfgemm.cpp | 9 +----- onnxruntime/core/mlas/lib/mlasi.h | 3 ++ onnxruntime/core/mlas/lib/platform.cpp | 3 ++ 5 files changed, 50 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index b950e4e734fa5..03460c9def5bd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -143,6 +143,7 @@ void CPUIDInfo::ArmLinuxInit() { if (pytorch_cpuinfo_init_) { is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); is_armv8_narrow_ld_.resize(core_cnt, false); @@ -165,6 +166,7 @@ void CPUIDInfo::ArmLinuxInit() { } } else { has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; } } @@ -220,9 +222,45 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } + + switch (lastUarch) { + case cpuinfo_uarch_cortex_a55: + case cpuinfo_uarch_cortex_a55r0: + case cpuinfo_uarch_cortex_a76: + case cpuinfo_uarch_neoverse_n1: + case cpuinfo_uarch_cortex_a77: + case cpuinfo_uarch_exynos_m4: + case cpuinfo_uarch_exynos_m5: + has_fp16_ = true; + break; + default: + break; + } + if (!has_fp16_) { + /* + * Detecting fp16 support. Different cores should have the same instruction set. 
+     * So we just check the first ID_AA64PFR0_EL1
+     *    Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000),
+     */
+    uint64_t ID_AA64PFR0_EL1;
+    unsigned long valsize = sizeof(uint64_t);
+    auto retCode = ::RegGetValueA(
+        HKEY_LOCAL_MACHINE,
+        "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0",
+        "CP 4020", RRF_RT_REG_QWORD, nullptr,
+        &ID_AA64PFR0_EL1, &valsize);
+    if (retCode == ERROR_SUCCESS) {
+      // AdvSIMD, bits [23:20]
+      auto advSimd = ID_AA64PFR0_EL1 >> 20;
+      if ((advSimd & 0xfULL) == 1) {
+        has_fp16_ = true;
+      }
+    }
+  }
 #endif /* Application Family or OneCore Family */

   has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
+  has_fp16_ |= has_arm_neon_dot_;
 }

 #endif /* (arm or arm64) and windows */
diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h
index 858f8595b8220..c413e0ca7ed5f 100644
--- a/onnxruntime/core/common/cpuid_info.h
+++ b/onnxruntime/core/common/cpuid_info.h
@@ -21,7 +21,7 @@ class CPUIDInfo {
   bool HasAVX512f() const { return has_avx512f_; }
   bool HasAVX512_BF16() const {return has_avx512_bf16_;}
   bool HasAVX512Skylake() const { return has_avx512_skylake_; }
-  bool HasF16C() const { return has_f16c_; }
+  bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
   bool HasSSE3() const { return has_sse3_; }
   bool HasSSE4_1() const { return has_sse4_1_; }
   bool IsHybrid() const { return is_hybrid_; }
@@ -85,6 +85,9 @@ class CPUIDInfo {
     return is_armv8_narrow_ld_[coreIdx];
   }

+  bool HasFp16VectorAcceleration() const {
+    return has_fp16_;
+  }
 private:
   CPUIDInfo() {
@@ -118,6 +121,7 @@ class CPUIDInfo {
   std::vector<bool> is_armv8_narrow_ld_;

   bool has_arm_neon_dot_{false};
+  bool has_fp16_{false};

 #ifdef CPUIDINFO_ARCH_X86

diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp
index 837b8f1bbd2b8..778db2003d6c6 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.cpp
+++ b/onnxruntime/core/mlas/lib/halfgemm.cpp
@@ -25,17 +25,10 @@ Module Name:
 bool MLASCALL
 MlasFp16AccelerationSupported()
 {
-#ifdef MLAS_NEON64_INTRINSICS
-    // TODO!! Only support for ARMv8.2
-    // TODO!! how to detect ARMv8.0 ???
-    return true;
-#else
-    return false;
-#endif
+    return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration();
 }

-
 void
 MLASCALL
 MlasHalfGemmBatch(
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 495c2f06ea757..21949535cf63b 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -123,6 +123,8 @@ class MLASCPUIDInfo
     // ARM
     bool HasArmNeonDot() const { return has_arm_neon_dot_; }

+    bool HasFp16VectorAcceleration() const { return has_fp16_; }
+
     uint32_t GetCurrentCoreIdx() const { return 0xFFFFFFFF; }

     int32_t GetCurrentUarch() const { return -1; }
@@ -137,6 +139,7 @@ class MLASCPUIDInfo
     MLASCPUIDInfo();

     bool has_arm_neon_dot_{false};
+    bool has_fp16_{false};
 };

 using MLAS_CPUIDINFO = MLASCPUIDInfo;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index 7d8624a32a218..32ada2ee4cfa9 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -36,6 +36,9 @@ Module Name:
 MLASCPUIDInfo::MLASCPUIDInfo()
 {
     has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
+
+    // raw hack!
Need CPUIDInfo implementation for more precise detection + has_fp16_ = has_arm_neon_dot_; } #endif From 7925c91bab6200c50c6bfd6eaa795fcf676e35a1 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Mon, 13 Feb 2023 11:35:25 -0800 Subject: [PATCH 19/19] comments --- onnxruntime/core/mlas/inc/mlas.h | 2 +- onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S | 2 +- onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm | 2 +- onnxruntime/core/mlas/lib/platform.cpp | 8 +++++++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 4aac20fe08f24..84285d7bcacf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -615,7 +615,7 @@ MlasGemm( // Currently only supported in ARM64 // #if defined(MLAS_TARGET_ARM64) -constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 15; +constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30; #else constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0; #endif diff --git a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S index 8622929a9a0fe..036928d21b8ca 100644 --- a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S @@ -283,7 +283,7 @@ x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 .LM6N16OutterLoopNTail: subs x1,x1,16 // N -= 16 ldr x8,[sp,#.LHGemmKernelFrame_B] - b.LO .LM6StoreRemainderN // remaining k < 16 + b.LO .LM6StoreRemainderN // remaining N < 16 cbnz x19,.LM6N16SkipAccumulateOutput ldp q0,q1,[x3] diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm index f4f26da711097..d7b626327780c 100644 --- a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -283,7 +283,7 @@ M6N16LoopK_Epilogue M6N16OutterLoopNTail subs x1,x1,16 // N -= 16 ldr x8,[sp,#HGemmKernelFrame_B] - b.LO M6StoreRemainderN // remaining k < 16 + b.LO M6StoreRemainderN // remaining N < 16 cbnz x19,M6N16SkipAccumulateOutput ldp q0,q1,[x3] diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 32ada2ee4cfa9..c52d4f3b0b8c4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -53,7 +53,13 @@ MLASCPUIDInfo::MLASCPUIDInfo() #endif #if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() { has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); } +MLASCPUIDInfo::MLASCPUIDInfo() +{ + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + + // raw hack! Need CPUIDInfo implementation for more precise detection + has_fp16_ = has_arm_neon_dot_; +} #endif #else
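
[Editor's note: both constructors above reuse the dot-product capability bit
as a stand-in for fp16 support, as the "raw hack" comments admit. On aarch64
Linux a more direct probe would read the dedicated half-precision HWCAP bit;
the sketch below is editorial, with illustrative naming, and is not part of
this patch.]

    #if defined(__aarch64__) && defined(__linux__)
    #include <sys/auxv.h>
    #include <asm/hwcap.h>

    // HWCAP_ASIMDHP reports Advanced SIMD half-precision arithmetic
    // (the ARMv8.2 FP16 extension the NEON half-gemm kernels rely on).
    static bool DetectFp16VectorArith()
    {
        return (getauxval(AT_HWCAP) & HWCAP_ASIMDHP) != 0;
    }
    #endif
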