From 733ca85b7395a83c6b671936db94be98d600e3ae Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Wed, 15 Feb 2023 12:51:53 -0800 Subject: [PATCH] Cfu fp16 (#14538) ### Description FP16 GEMM, including hardware agnostic driver code, a slow C++ kernel, and ARM64 NEON kernel. ### Motivation and Context First step in creating native support of fp16 model inferencing on ARM64 and AMD64 platforms. --------- Co-authored-by: Chen Fu --- cmake/onnxruntime_mlas.cmake | 7 + onnxruntime/core/common/cpuid_info.cc | 38 ++ onnxruntime/core/common/cpuid_info.h | 6 +- onnxruntime/core/mlas/inc/mlas.h | 178 +++++- onnxruntime/core/mlas/inc/mlas_float16.h | 115 ++++ .../mlas/lib/aarch64/HalfGemmKernelNeon.S | 550 +++++++++++++++++ .../mlas/lib/arm64/HalfGemmKernelNeon.asm | 552 ++++++++++++++++++ onnxruntime/core/mlas/lib/halfgemm.cpp | 334 +++++++++++ onnxruntime/core/mlas/lib/halfgemm.h | 515 ++++++++++++++++ .../core/mlas/lib/halfgemm_kernel_neon.cpp | 187 ++++++ onnxruntime/core/mlas/lib/mlasi.h | 55 +- onnxruntime/core/mlas/lib/platform.cpp | 11 +- .../test/mlas/unittest/test_halfgemm.cpp | 206 +++++++ .../test/mlas/unittest/test_halfgemm.h | 319 ++++++++++ onnxruntime/test/mlas/unittest/test_util.h | 45 +- 15 files changed, 3091 insertions(+), 27 deletions(-) create mode 100644 onnxruntime/core/mlas/inc/mlas_float16.h create mode 100644 onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm create mode 100644 onnxruntime/core/mlas/lib/halfgemm.cpp create mode 100644 onnxruntime/core/mlas/lib/halfgemm.h create mode 100644 onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_halfgemm.cpp create mode 100644 onnxruntime/test/mlas/unittest/test_halfgemm.h diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 1f9b7129943e6..80a65c6787eb9 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -19,6 +19,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp ${MLAS_SRC_DIR}/qgemm.cpp ${MLAS_SRC_DIR}/qdwconv.cpp ${MLAS_SRC_DIR}/convolve.cpp @@ -59,6 +60,7 @@ function(setup_mlas_source_for_windows) if(onnxruntime_target_platform STREQUAL "ARM64") target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp @@ -73,6 +75,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm @@ -305,6 +308,7 @@ else() ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S @@ -314,10 +318,13 @@ else() ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + 
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + if(ONNXRUNTIME_MLAS_MULTI_ARCH) onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index b950e4e734fa5..03460c9def5bd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -143,6 +143,7 @@ void CPUIDInfo::ArmLinuxInit() { if (pytorch_cpuinfo_init_) { is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); is_armv8_narrow_ld_.resize(core_cnt, false); @@ -165,6 +166,7 @@ void CPUIDInfo::ArmLinuxInit() { } } else { has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; } } @@ -220,9 +222,45 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } + + switch (lastUarch) { + case cpuinfo_uarch_cortex_a55: + case cpuinfo_uarch_cortex_a55r0: + case cpuinfo_uarch_cortex_a76: + case cpuinfo_uarch_neoverse_n1: + case cpuinfo_uarch_cortex_a77: + case cpuinfo_uarch_exynos_m4: + case cpuinfo_uarch_exynos_m5: + has_fp16_ = true; + break; + default: + break; + } + if (!has_fp16_) { + /* + * Detecting fp16 support. Different cores should have the same instruction set. + * So we just check the first ID_AA64PFR0_EL1 + * Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000), + */ + uint64_t ID_AA64PFR0_EL1; + unsigned long valsize = sizeof(uint64_t); + auto retCode = ::RegGetValueA( + HKEY_LOCAL_MACHINE, + "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", + "CP 4020", RRF_RT_REG_QWORD, nullptr, + &ID_AA64PFR0_EL1, &valsize); + if (retCode == ERROR_SUCCESS) { + // AdvSIMD, bits [23:20] + auto advSimd = ID_AA64PFR0_EL1 >> 20; + if ((advSimd & 0xfULL) == 1) { + has_fp16_ = true; + } + } + } #endif /* Application Family or OneCore Family */ has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); + has_fp16_ |= has_arm_neon_dot_; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 858f8595b8220..c413e0ca7ed5f 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -21,7 +21,7 @@ class CPUIDInfo { bool HasAVX512f() const { return has_avx512f_; } bool HasAVX512_BF16() const {return has_avx512_bf16_;} bool HasAVX512Skylake() const { return has_avx512_skylake_; } - bool HasF16C() const { return has_f16c_; } + bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/ bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } @@ -85,6 +85,9 @@ class CPUIDInfo { return is_armv8_narrow_ld_[coreIdx]; } + bool HasFp16VectorAcceleration() const { + return has_fp16_; + } private: CPUIDInfo() { @@ -118,6 +121,7 @@ class CPUIDInfo { std::vector is_armv8_narrow_ld_; bool has_arm_neon_dot_{false}; + bool has_fp16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/mlas/inc/mlas.h 
b/onnxruntime/core/mlas/inc/mlas.h index c1a4d16fd44fb..07757917de59d 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -90,7 +90,7 @@ typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE; #endif // -// Forward declare the thread pool implementation class. +// Forward declare the thread pool implementation class and half precision floating point. // // N.B. Avoid including ONNX Runtime headers here to keep the dependencies for // standalone MLAS test executables smaller. @@ -100,10 +100,12 @@ namespace onnxruntime { namespace concurrency { class ThreadPool; }; -}; + struct MLFloat16; +}; // namespace onnxruntime using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool; + // // Platform routines. // @@ -613,7 +615,7 @@ MlasGemm( // Currently only supported in ARM64 // #if defined(MLAS_TARGET_ARM64) -constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 15; +constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30; #else constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0; #endif @@ -1367,3 +1369,173 @@ MlasQLinearMul( size_t N, bool IsScalarB ); + +// +// Half precision routines +// + +// Any type with size=2 should work +using MLAS_FP16 = onnxruntime::MLFloat16; + +constexpr size_t FP16_SIZE = sizeof(uint16_t); + + +bool MLASCALL +MlasFp16AccelerationSupported(); + +/** + * @brief Interface for half gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from half precision to single precision, etc. + * + * Half GEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. +*/ +class MLAS_HALF_GEMM_POSTPROCESSOR { +public: + virtual + void + Process( + MLAS_FP16*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {} +}; + +/** + * @brief Convert half gemm result matrix to single precision float matrix +*/ +class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR { +public: + MLAS_HALF_GEMM_2FLOAT_PROCESSOR( + float* Output, /**< address of the output matrix, row major */ + size_t RowStride /**< row stride of the output matrix */ + ) : + Output_(Output), + RowStride_(RowStride) + {} + + void + Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const override; + +private: + float* Output_; + size_t RowStride_; +}; + + +/** + * @brief Data parameters for half precision GEMM routine + * All except C are [in] parameters +*/ +struct MLAS_HALF_GEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */ + MLAS_FP16* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be casted into fp16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be casted into fp16*/ +}; + +/** + * @brief Half 
precision Batched GEMM: C = A * B + Bias + * Either A or B can be fp32 or fp16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return +*/ +void +MLASCALL +MlasHalfGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool = nullptr + ); + +/** + * @brief For half precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] float2half Whether the input is float that + * needs to be converted to half precision + * @return size of the packing buffer, + * 0 if operation not supported +*/ +size_t +MLASCALL +MlasHalfGemmPackBSize( + size_t N, + size_t K, + bool float2half + ); + +/** + * @brief For half precision GEMM, pack the right hand + * side matrix B + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix +*/ +void +MLASCALL +MlasHalfGemmPackB( + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ); + +/** + * @brief For half precision GEMM, convert the float matrix B + * to half precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix +*/ +void +MLASCALL +MlasHalfGemmConvertPackB( + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); diff --git a/onnxruntime/core/mlas/inc/mlas_float16.h b/onnxruntime/core/mlas/inc/mlas_float16.h new file mode 100644 index 0000000000000..33227ea90d6be --- /dev/null +++ b/onnxruntime/core/mlas/inc/mlas_float16.h @@ -0,0 +1,115 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + mlas_float16.h + +Abstract: + + Utilities for half precision floating type conversions. Used internally + by MLAS on platforms without half precision support. Provided here as + convenience for tests or other client libraries/apps. 
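For reference, a minimal usage sketch of the half precision GEMM API declared above. Only the function and struct signatures come from this header; the shapes, the single-batch setup, and the choice to pre-pack an fp32 weight matrix are illustrative assumptions.

```cpp
#include <cstdint>
#include <vector>
// Assumes this header (mlas.h) is on the include path.

void RunHalfGemmSketch()
{
    // Hypothetical shapes, for illustration only.
    const size_t M = 64, N = 128, K = 256;

    // "Any type with size=2 should work" for fp16 storage, so plain
    // uint16_t buffers keep the sketch self-contained.
    std::vector<uint16_t> A(M * K);   // activations, already fp16
    std::vector<float> Bfp32(K * N);  // weights, still fp32
    std::vector<uint16_t> C(M * N);   // fp16 result

    if (!MlasFp16AccelerationSupported()) {
        return;  // a real caller may prefer to stay on the fp32 path here
    }

    // Convert and pack the fp32 weights once, then reuse the packed buffer.
    const size_t PackedSize = MlasHalfGemmPackBSize(N, K, /*float2half=*/true);
    std::vector<uint8_t> PackedB(PackedSize);
    MlasHalfGemmConvertPackB(N, K, Bfp32.data(), /*ldb=*/N, PackedB.data());

    MLAS_HALF_GEMM_DATA_PARAMS Params;
    Params.A = A.data();
    Params.lda = K;
    Params.B = PackedB.data();
    Params.ldb = 0;                   // 0 marks B as pre-packed
    Params.C = reinterpret_cast<MLAS_FP16*>(C.data());
    Params.ldc = N;

    // Optional: convert result tiles to fp32 on the fly via a post processor.
    std::vector<float> Cfp32(M * N);
    MLAS_HALF_GEMM_2FLOAT_PROCESSOR ToFloat(Cfp32.data(), /*RowStride=*/N);
    Params.OutputProcessor = &ToFloat;

    MlasHalfGemmBatch(M, N, K, /*BatchN=*/1, &Params, /*ThreadPool=*/nullptr);
}
```

Passing `nullptr` for the thread pool runs the single-threaded path; supplying a `MLAS_THREADPOOL*` lets the driver in halfgemm.cpp partition the work across M and N tiles.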
+ +--*/ + +#pragma once + +#include +#include +#include + + +using _mlas_fp16_ = uint16_t; + +union fp32_bits { + uint32_t u; + float f; +}; + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) + +/*PreFast told us to convert them to constexpr but the compiler says we can't.*/ +#pragma warning(disable : 26497) + +/*Added whole bunch of casts, still can't get rid of these overflow warnings.*/ +#pragma warning(disable : 26450) +#pragma warning(disable : 26451) +#endif + +inline +_mlas_fp16_ +MLAS_Float2Half(float ff) +{ + constexpr fp32_bits f32infty = {255 << 23}; + constexpr fp32_bits f16max = {(127 + 16) << 23}; + constexpr fp32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr uint32_t sign_mask = 0x80000000u; + + auto val = static_cast(0x0u); + fp32_bits f; + f.f = ff; + + uint32_t sign = f.u & sign_mask; + f.u ^= sign; + + if (f.u >= f16max.u) { + // Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { + if (f.u < (113 << 23)) { + // Subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + uint32_t mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.u += ((uint32_t)(15 - 127) << 23) + 0xfff; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +inline +float +MLAS_Half2Float(_mlas_fp16_ val) +{ + constexpr fp32_bits magic = {113 << 23}; + constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + fp32_bits o; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + uint32_t exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (val & 0x8000) << 16; // sign bit + return o.f; +} + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S new file mode 100644 index 0000000000000..036928d21b8ca --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/HalfGemmKernelNeon.S @@ -0,0 +1,550 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + HalfGemmKernelNeon.s + +Abstract: + + This module implements the kernels for the half precision matrix/matrix + multiply operation (HALF GEMM). + +--*/ + +#include "asmmacro.h" + +// +// Stack frame layout for the half gemm kernel. +// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. 
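A quick sanity check of the two converters above, as a standalone sketch; the expected bit pattern follows from the IEEE half precision layout (1 sign bit, 5 exponent bits, 10 mantissa bits):

```cpp
#include <cassert>
#include <cmath>

void Fp16ConversionSketch()
{
    // 0.5f is exactly representable: sign 0, biased exponent 14, mantissa 0.
    _mlas_fp16_ h = MLAS_Float2Half(0.5f);
    assert(h == 0x3800);
    assert(MLAS_Half2Float(h) == 0.5f);

    // Values far beyond the fp16 range (max normal is 65504) map to infinity.
    assert(std::isinf(MLAS_Half2Float(MLAS_Float2Half(1.0e8f))));

    // Values below 2^-14 fall into the subnormal branch and are rounded,
    // so round trips there are not exact in general.
}
```

The rounding performed by `MLAS_Float2Half` is round-to-nearest-even, which is why the denormal-magic addition trick in the subnormal branch works.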
+// + .equ .LHGemmKernelFrame_SavedRegs, (2 * 8) + .equ .LHGemmKernelFrame_B, 0 + .LHGemmKernelFrame_SavedRegs + .equ .LHGemmKernelFrame_ldb, 8 + .LHGemmKernelFrame_SavedRegs + .equ .LHGemmKernelFrame_ZeroMode, 16 + .LHGemmKernelFrame_SavedRegs + + .text + +/*++ + +Routine Description: + + This routine is an inner kernel to compute 6 rows of GEMM + +Arguments: + + CountM - (x0) the number of rows for matrix A and matrix C. + only process 6 rows + + CountN - (x1) the number of columns from matrix B and matrix C + + CountK - (x2/x0) the number of columns from matrix A and the + number of rows from matrix B. + + C - (x3) the address of matrix C. + + ldc - (x4) - the first dimension of matrix C. + + Bias - (x5) - the address of the Bias vector (optional) + + A - (x6) - the address of matrix A + + lda - (x7) - the first dimension of matrix A + + B - the address of matrix B + + ldb - the first dimension of matrix B + + ZeroMode - true if the output matrix must be zero initialized, else + if the output matrix is accumulated into + +--*/ + + FUNCTION_ENTRY MlasHalfGemmKernelNeon + + str x19,[sp,#-.LHGemmKernelFrame_SavedRegs]! + ldr x9,[sp,#.LHGemmKernelFrame_ldb] + lsl x2,x2,#1 // k *= sizeof(fp16) + cmp x0,2 + add x14,x6,x7,lsl #1 // a1 = a0 + lda + add x10,x3,x4,lsl #1 // c1 = c0 + ldc + ldr x8,[sp,#.LHGemmKernelFrame_B] + csel x14,x6,x14,LO // M < 2 ? a1 = a0 + csel x10,x3,x10,LO // c1 = c0 + add x15,x14,x7,lsl #1 // a2 = a1 + lda + add x11,x10,x4,lsl #1 // c2 = c1 + ldc + csel x15,x14,x15,LS // M <= 2 ? a2 = a1 + csel x11,x10,x11,LS // c2 = c1 + cmp x0,4 + add x16,x15,x7,lsl #1 // a3 = a2 + lda + add x12,x11,x4,lsl #1 // c3 = c2 + ldc + csel x16,x15,x16,LO // M < 4 ? a3 = a2 + csel x12,x11,x12,LO // c3 = c2 + add x17,x16,x7,lsl #1 // a4 = a3 + lda + add x13,x12,x4,lsl #1 // c4 = c3 + ldc + csel x17,x16,x17,LS // M <= 4 ? a4 = a3 + csel x13,x12,x13,LS // c4 = c3 + cmp x0,6 + add x7,x17,x7,lsl #1 // a5 = a4 + lda + add x4,x13,x4,lsl #1 // c5 = c4 + ldc + csel x7,x17,x7,LO // M < 6 ? a5 = a4 + csel x4,x13,x4,LO // c5 = c4 + lsl x9,x9,#1 // ldb *= sizeof(fp16) + ldrb w19,[sp,#.LHGemmKernelFrame_ZeroMode] + sub x9,x9,16 // ldb -= 16 + +/**** +Main loop processes 6x16 tile, depth 4. 
+ B 4x16 + --------------------------------------- + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + A 6x4 --------------------------------------- + ------------------ --------------------------------------- +x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 +x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 +x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 +x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 +x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 +x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 + ------------------ --------------------------------------- +****/ + +.LM6N16OutterLoopN: + cbz x5,.LM6N16SkipBias + ldp q20,q21,[x5],32 // Load 16 Bias values + b .LM6N16PopulateAccumulators + +.LM6N16SkipBias: + eor v20.16b,v20.16b,v20.16b // No bias, reset regs + eor v21.16b,v21.16b,v21.16b + +.LM6N16PopulateAccumulators: + mov v22.16b,v20.16b + mov v23.16b,v21.16b + mov v24.16b,v20.16b + mov v25.16b,v21.16b + mov v26.16b,v20.16b + mov v27.16b,v21.16b + mov v28.16b,v20.16b + subs x0,x2,8 // k -= 4 (8 bytes) + mov v29.16b,v21.16b + mov v30.16b,v20.16b + mov v31.16b,v21.16b + b.LO .LM6N16RemainderK123 // remaining k 1~3 + + ldr d0,[x6],8 // A0 + ldr q16,[x8],16 // B0.l + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row + subs x0,x0,8 // over decement k -= 4 (8 bytes) + ldr d1,[x14],8 // A1 + ldr d2,[x15],8 // A2 + ldr d3,[x16],8 // A3 + b.LO .LM6N16LoopK_Epilogue // need k>=8 for main loop + +.LM6N16InnerLoopK: + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + subs x0,x0,8 // k -= 4 + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + ldr q16,[x8],16 // Load B0.low for next iter + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + ld1 {v17.16b},[x8],x9 // Load B0.high for next iter + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + ldr d0,[x6],8 // Load A0 for next iter + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + ldr d1,[x14],8 // Load A1 for next iter + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + ldr d2,[x15],8 // Load A2 for next iter + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + ldr 
d3,[x16],8 // Load A3 for next iter + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.hs .LM6N16InnerLoopK // k >= 8 for main loop + +.LM6N16LoopK_Epilogue: + // last block of k >= 4, no pre-load for next iter + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + adds x0,x0,8 // revert k over-decrement + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.NE .LM6N16RemainderK123 // remaining k 1~3 + +.LM6N16OutterLoopNTail: + subs x1,x1,16 // N -= 16 + ldr x8,[sp,#.LHGemmKernelFrame_B] + b.LO .LM6StoreRemainderN // remaining N < 16 + + cbnz x19,.LM6N16SkipAccumulateOutput + ldp q0,q1,[x3] + ldp q2,q3,[x10] + ldp q4,q5,[x11] + ldp q6,q7,[x12] + ldp q16,q17,[x13] + ldp q18,q19,[x4] + fadd v20.8h,v20.8h,v0.8h // !ZeroMode + fadd v21.8h,v21.8h,v1.8h // accumulate into C + fadd v22.8h,v22.8h,v2.8h + fadd v23.8h,v23.8h,v3.8h + fadd v24.8h,v24.8h,v4.8h + fadd v25.8h,v25.8h,v5.8h + fadd v26.8h,v26.8h,v6.8h + fadd v27.8h,v27.8h,v7.8h + fadd v28.8h,v28.8h,v16.8h + fadd v29.8h,v29.8h,v17.8h + fadd v30.8h,v30.8h,v18.8h + fadd v31.8h,v31.8h,v19.8h + +.LM6N16SkipAccumulateOutput: + st1 {v20.16b,v21.16b},[x3],32 + sub x6,x6,x2 // restore a0 + st1 {v22.16b,v23.16b},[x10],32 + sub x14,x14,x2 // restore a1 + st1 {v24.16b,v25.16b},[x11],32 + sub x15,x15,x2 // restore a2 + st1 {v26.16b,v27.16b},[x12],32 + sub x16,x16,x2 // restore a3 + st1 {v28.16b,v29.16b},[x13],32 + sub x17,x17,x2 // restore a4 + add x8,x8,32 // B <- next 16 columns + st1 {v30.16b,v31.16b},[x4],32 + sub x7,x7,x2 // restore a5 + str x8,[sp,#.LHGemmKernelFrame_B] + b.HI .LM6N16OutterLoopN + +.LExitKernel: + ldr x19,[sp],#.LHGemmKernelFrame_SavedRegs + ret + +.LM6N16RemainderK123: + tbz x0,2,.LM6N16RemainderK1 + ldr s0,[x6],4 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr s1,[x14],4 // A1 + ldr s2,[x15],4 // A2 + ldr s3,[x16],4 // A3 + ldr s4,[x17],4 // A4 + ldr s5,[x7],4 // A5 + ldr q18,[x8],16 // B1.low + ld1 {v19.16b},[x8],x9 // B2.high + fmla 
v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + + fmla v20.8h,v18.8h,v0.h[1] + fmla v22.8h,v18.8h,v1.h[1] + fmla v24.8h,v18.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v21.8h,v19.8h,v0.h[1] + fmla v23.8h,v19.8h,v1.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v31.8h,v19.8h,v5.h[1] + tbz x0,1,.LM6N16OutterLoopNTail + +.LM6N16RemainderK1: + ldr h0,[x6],2 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr h1,[x14],2 // A1 + ldr h2,[x15],2 // A2 + ldr h3,[x16],2 // A3 + ldr h4,[x17],2 // A4 + ldr h5,[x7],2 // A5 + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + b .LM6N16OutterLoopNTail + +.LM6StoreRemainderN: + cbnz x19,.LM6StoreRemainderNZeroMode + tbz x1,3,.LM6StoreRemainderN4 + ldr q0,[x3] + ldr q1,[x10] + ldr q2,[x11] + ldr q3,[x12] + ldr q4,[x13] + ldr q5,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v22.8h,v22.8h,v1.8h + fadd v24.8h,v24.8h,v2.8h + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +.LM6StoreRemainderN4: + tbz x1,2,.LM6StoreRemainderN2 + ldr d0,[x3] + ldr d1,[x10] + ldr d2,[x11] + ldr d3,[x12] + ldr d4,[x13] + ldr d5,[x4] + fadd v21.4h,v20.4h,v0.4h + dup d20,v20.d[1] + fadd v23.4h,v22.4h,v1.4h + dup d22,v22.d[1] + fadd v25.4h,v24.4h,v2.4h + dup d24,v24.d[1] + fadd v27.4h,v26.4h,v3.4h + dup d26,v26.d[1] + fadd v29.4h,v28.4h,v4.4h + dup d28,v28.d[1] + fadd v31.4h,v30.4h,v5.4h + dup d30,v30.d[1] + str d21,[x3],8 + str d23,[x10],8 + str d25,[x11],8 + str d27,[x12],8 + str d29,[x13],8 + str d31,[x4],8 + +.LM6StoreRemainderN2: + tbz x1,1,.LM6StoreRemainderN1 + ldr s0,[x3] + ldr s1,[x10] + ldr s2,[x11] + ldr s3,[x12] + ldr s4,[x13] + ldr s5,[x4] + fadd v21.4h,v20.4h,v0.4h + fadd v23.4h,v22.4h,v1.4h + fadd v25.4h,v24.4h,v2.4h + fadd v27.4h,v26.4h,v3.4h + fadd v29.4h,v28.4h,v4.4h + fadd v31.4h,v30.4h,v5.4h + str s21,[x3],4 + str s23,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s25,[x11],4 + str s27,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s29,[x13],4 + str s31,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +.LM6StoreRemainderN1: + tbz x1,0,.LExitKernel + ldr h0,[x3] + ldr h1,[x10] + ldr h2,[x11] + ldr h3,[x12] + ldr h4,[x13] + ldr h5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd v30.4h,v30.4h,v5.4h + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b .LExitKernel + +.LM6StoreRemainderNZeroMode: + tbz x1,3,.LM6StoreRemainderN4ZeroMode + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov 
v24.16b,v25.16b + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +.LM6StoreRemainderN4ZeroMode: + tbz x1,2,.LM6StoreRemainderN2ZeroMode + str d20,[x3],8 + str d22,[x10],8 + dup d20,v20.d[1] + dup d22,v22.d[1] + str d24,[x11],8 + str d26,[x12],8 + dup d24,v24.d[1] + dup d26,v26.d[1] + str d28,[x13],8 + str d30,[x4],8 + dup d28,v28.d[1] + dup d30,v30.d[1] + +.LM6StoreRemainderN2ZeroMode: + tbz x1,1,.LM6StoreRemainderN1ZeroMode + str s20,[x3],4 + str s22,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s24,[x11],4 + str s26,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s28,[x13],4 + str s30,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +.LM6StoreRemainderN1ZeroMode: + tbz x1,0,.LExitKernel + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b .LExitKernel + + .end diff --git a/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm new file mode 100644 index 0000000000000..d7b626327780c --- /dev/null +++ b/onnxruntime/core/mlas/lib/arm64/HalfGemmKernelNeon.asm @@ -0,0 +1,552 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + HalfGemmKernelNeon.asm + +Abstract: + + This module implements the kernels for the half precision matrix/matrix + multiply operation (HALF GEMM). + +--*/ + +#include "kxarm64.h" + +// +// Stack frame layout for the half gemm kernel. +// Callee save registers: d8-d15, x19-x30. x18 is reserved by the OS. +// + +#define HGemmKernelFrame_SavedRegs (2 * 8) +#define HGemmKernelFrame_B 0 + HGemmKernelFrame_SavedRegs +#define HGemmKernelFrame_ldb 8 + HGemmKernelFrame_SavedRegs +#define HGemmKernelFrame_ZeroMode 16 + HGemmKernelFrame_SavedRegs + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute 6 rows of GEMM + +Arguments: + + CountM - (x0) the number of rows for matrix A and matrix C. + only process 6 rows + + CountN - (x1) the number of columns from matrix B and matrix C + + CountK - (x2/x0) the number of columns from matrix A and the + number of rows from matrix B. + + C - (x3) the address of matrix C. + + ldc - (x4) - the first dimension of matrix C. + + Bias - (x5) - the address of the Bias vector (optional) + + A - (x6) - the address of matrix A + + lda - (x7) - the first dimension of matrix A + + B - the address of matrix B + + ldb - the first dimension of matrix B + + ZeroMode - true if the output matrix must be zero initialized, else + if the output matrix is accumulated into + +--*/ + + LEAF_ENTRY MlasHalfGemmKernelNeon + + PROLOG_SAVE_REG x19,#-HGemmKernelFrame_SavedRegs! + ldr x9,[sp,#HGemmKernelFrame_ldb] + lsl x2,x2,#1 // k *= sizeof(fp16) + cmp x0,2 + add x14,x6,x7,lsl #1 // a1 = a0 + lda + add x10,x3,x4,lsl #1 // c1 = c0 + ldc + ldr x8,[sp,#HGemmKernelFrame_B] + csel x14,x6,x14,LO // M < 2 ? a1 = a0 + csel x10,x3,x10,LO // c1 = c0 + add x15,x14,x7,lsl #1 // a2 = a1 + lda + add x11,x10,x4,lsl #1 // c2 = c1 + ldc + csel x15,x14,x15,LS // M <= 2 ? a2 = a1 + csel x11,x10,x11,LS // c2 = c1 + cmp x0,4 + add x16,x15,x7,lsl #1 // a3 = a2 + lda + add x12,x11,x4,lsl #1 // c3 = c2 + ldc + csel x16,x15,x16,LO // M < 4 ? a3 = a2 + csel x12,x11,x12,LO // c3 = c2 + add x17,x16,x7,lsl #1 // a4 = a3 + lda + add x13,x12,x4,lsl #1 // c4 = c3 + ldc + csel x17,x16,x17,LS // M <= 4 ? 
a4 = a3 + csel x13,x12,x13,LS // c4 = c3 + cmp x0,6 + add x7,x17,x7,lsl #1 // a5 = a4 + lda + add x4,x13,x4,lsl #1 // c5 = c4 + ldc + csel x7,x17,x7,LO // M < 6 ? a5 = a4 + csel x4,x13,x4,LO // c5 = c4 + lsl x9,x9,#1 // ldb *= sizeof(fp16) + ldrb w19,[sp,#HGemmKernelFrame_ZeroMode] + sub x9,x9,16 // ldb -= 16 + +/**** +Main loop processes 6x16 tile, depth 4. + B 4x16 + --------------------------------------- + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + |v16.h[0]..v16.h[7] v17.h[0]..v17.h[7]| x8 + |v18.h[0]..v18.h[7] v19.h[0]..v19.h[7]| x8 + A 6x4 --------------------------------------- + ------------------ --------------------------------------- +x6 |v0.h[0]..v0.h[3]| |v20.h[0]..v20.h[7] v21.h[0]..v21.h[7]| x3 +x14 |v1.h[0]..v1.h[3]| |v22.h[0]..v22.h[7] v23.h[0]..v23.h[7]| x10 +x15 |v2.h[0]..v2.h[3]| |v24.h[0]..v24.h[7] v25.h[0]..v25.h[7]| x11 +x16 |v3.h[0]..v3.h[3]| |v26.h[0]..v26.h[7] v27.h[0]..v27.h[7]| x12 +x17 |v4.h[0]..v4.h[3]| |v28.h[0]..v28.h[7] v29.h[0]..v29.h[7]| x13 +x7 |v5.h[0]..v5.h[3]| |v30.h[0]..v30.h[7] v31.h[0]..v31.h[7]| x4 + ------------------ --------------------------------------- +****/ + +M6N16OutterLoopN + cbz x5,M6N16SkipBias + ldp q20,q21,[x5],32 // Load 16 Bias values + b M6N16PopulateAccumulators + +M6N16SkipBias + eor q20.16b,q20.16b,q20.16b // No bias, reset regs + eor q21.16b,q21.16b,q21.16b + +M6N16PopulateAccumulators + mov v22.16b,v20.16b + mov v23.16b,v21.16b + mov v24.16b,v20.16b + mov v25.16b,v21.16b + mov v26.16b,v20.16b + mov v27.16b,v21.16b + mov v28.16b,v20.16b + subs x0,x2,8 // k -= 4 (8 bytes) + mov v29.16b,v21.16b + mov v30.16b,v20.16b + mov v31.16b,v21.16b + b.LO M6N16RemainderK123 // remaining k 1~3 + + ldr d0,[x6],8 // A0 + ldr q16,[x8],16 // B0.l + ld1 {v17.16b},[x8],x9 // B0.high x8 <- next row + subs x0,x0,8 // over decement k -= 4 (8 bytes) + ldr d1,[x14],8 // A1 + ldr d2,[x15],8 // A2 + ldr d3,[x16],8 // A3 + b.LO M6N16LoopK_Epilogue // need k>=8 for main loop + +M6N16InnerLoopK + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + subs x0,x0,8 // k -= 4 + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + ldr q16,[x8],16 // Load B0.low for next iter + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + ld1 {v17.16b},[x8],x9 // Load B0.high for next iter + fmla v22.8h,v18.8h,v1.h[3] + 
fmla v23.8h,v19.8h,v1.h[3] + ldr d0,[x6],8 // Load A0 for next iter + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + ldr d1,[x14],8 // Load A1 for next iter + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + ldr d2,[x15],8 // Load A2 for next iter + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + ldr d3,[x16],8 // Load A3 for next iter + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.hs M6N16InnerLoopK // k >= 8 for main loop + +M6N16LoopK_Epilogue + // last block of k >= 4, no pre-load for next iter + fmla v20.8h,v16.8h,v0.h[0] + fmla v21.8h,v17.8h,v0.h[0] + ldr d4,[x17],8 // A4 + fmla v22.8h,v16.8h,v1.h[0] + fmla v23.8h,v17.8h,v1.h[0] + ldr d5,[x7],8 // A5 + fmla v24.8h,v16.8h,v2.h[0] + fmla v25.8h,v17.8h,v2.h[0] + ldr q18,[x8],16 // B1.low + fmla v26.8h,v16.8h,v3.h[0] + fmla v27.8h,v17.8h,v3.h[0] + ld1 {v19.16b},[x8],x9 // B1.high x8 <- next row + fmla v28.8h,v16.8h,v4.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v31.8h,v17.8h,v5.h[0] + adds x0,x0,8 // revert k over-decrement + + fmla v20.8h,v18.8h,v0.h[1] + fmla v21.8h,v19.8h,v0.h[1] + ldr q16,[x8],16 // B2.low + fmla v22.8h,v18.8h,v1.h[1] + fmla v23.8h,v19.8h,v1.h[1] + ld1 {v17.16b},[x8],x9 // B2.high x8 <- next row + fmla v24.8h,v18.8h,v2.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v31.8h,v19.8h,v5.h[1] + + fmla v20.8h,v16.8h,v0.h[2] + fmla v21.8h,v17.8h,v0.h[2] + ldr q18,[x8],16 // B3.low + fmla v22.8h,v16.8h,v1.h[2] + fmla v23.8h,v17.8h,v1.h[2] + ld1 {v19.16b},[x8],x9 // B3.high x8 <- next row + fmla v24.8h,v16.8h,v2.h[2] + fmla v25.8h,v17.8h,v2.h[2] + fmla v26.8h,v16.8h,v3.h[2] + fmla v27.8h,v17.8h,v3.h[2] + fmla v28.8h,v16.8h,v4.h[2] + fmla v29.8h,v17.8h,v4.h[2] + fmla v30.8h,v16.8h,v5.h[2] + fmla v31.8h,v17.8h,v5.h[2] + + fmla v20.8h,v18.8h,v0.h[3] + fmla v21.8h,v19.8h,v0.h[3] + fmla v22.8h,v18.8h,v1.h[3] + fmla v23.8h,v19.8h,v1.h[3] + fmla v24.8h,v18.8h,v2.h[3] + fmla v25.8h,v19.8h,v2.h[3] + fmla v26.8h,v18.8h,v3.h[3] + fmla v27.8h,v19.8h,v3.h[3] + fmla v28.8h,v18.8h,v4.h[3] + fmla v29.8h,v19.8h,v4.h[3] + fmla v30.8h,v18.8h,v5.h[3] + fmla v31.8h,v19.8h,v5.h[3] + b.NE M6N16RemainderK123 // remaining k 1~3 + +M6N16OutterLoopNTail + subs x1,x1,16 // N -= 16 + ldr x8,[sp,#HGemmKernelFrame_B] + b.LO M6StoreRemainderN // remaining N < 16 + + cbnz x19,M6N16SkipAccumulateOutput + ldp q0,q1,[x3] + ldp q2,q3,[x10] + ldp q4,q5,[x11] + ldp q6,q7,[x12] + ldp q16,q17,[x13] + ldp q18,q19,[x4] + fadd v20.8h,v20.8h,v0.8h // !ZeroMode + fadd v21.8h,v21.8h,v1.8h // accumulate into C + fadd v22.8h,v22.8h,v2.8h + fadd v23.8h,v23.8h,v3.8h + fadd v24.8h,v24.8h,v4.8h + fadd v25.8h,v25.8h,v5.8h + fadd v26.8h,v26.8h,v6.8h + fadd v27.8h,v27.8h,v7.8h + fadd v28.8h,v28.8h,v16.8h + fadd v29.8h,v29.8h,v17.8h + fadd v30.8h,v30.8h,v18.8h + fadd v31.8h,v31.8h,v19.8h + +M6N16SkipAccumulateOutput + st1 {v20.16b,v21.16b},[x3],32 + sub x6,x6,x2 // restore a0 + st1 {v22.16b,v23.16b},[x10],32 + sub x14,x14,x2 // restore a1 + st1 {v24.16b,v25.16b},[x11],32 + sub x15,x15,x2 // restore a2 + st1 {v26.16b,v27.16b},[x12],32 + sub x16,x16,x2 // restore a3 + st1 {v28.16b,v29.16b},[x13],32 + sub x17,x17,x2 // restore a4 + add x8,x8,32 // B <- next 16 columns + st1 {v30.16b,v31.16b},[x4],32 + sub x7,x7,x2 // restore a5 + str x8,[sp,#HGemmKernelFrame_B] + b.HI M6N16OutterLoopN + +ExitKernel + EPILOG_RESTORE_REG x19,#HGemmKernelFrame_SavedRegs! 
+ EPILOG_RETURN + +M6N16RemainderK123 + tbz x0,2,M6N16RemainderK1 + ldr s0,[x6],4 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr s1,[x14],4 // A1 + ldr s2,[x15],4 // A2 + ldr s3,[x16],4 // A3 + ldr s4,[x17],4 // A4 + ldr s5,[x7],4 // A5 + ldr q18,[x8],16 // B1.low + ld1 {v19.16b},[x8],x9 // B2.high + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + + fmla v20.8h,v18.8h,v0.h[1] + fmla v22.8h,v18.8h,v1.h[1] + fmla v24.8h,v18.8h,v2.h[1] + fmla v26.8h,v18.8h,v3.h[1] + fmla v28.8h,v18.8h,v4.h[1] + fmla v30.8h,v18.8h,v5.h[1] + fmla v21.8h,v19.8h,v0.h[1] + fmla v23.8h,v19.8h,v1.h[1] + fmla v25.8h,v19.8h,v2.h[1] + fmla v27.8h,v19.8h,v3.h[1] + fmla v29.8h,v19.8h,v4.h[1] + fmla v31.8h,v19.8h,v5.h[1] + tbz x0,1,M6N16OutterLoopNTail + +M6N16RemainderK1 + ldr h0,[x6],2 // A0 + ldr q16,[x8],16 // B0.low + ld1 {v17.16b},[x8],x9 // B0.high + ldr h1,[x14],2 // A1 + ldr h2,[x15],2 // A2 + ldr h3,[x16],2 // A3 + ldr h4,[x17],2 // A4 + ldr h5,[x7],2 // A5 + fmla v20.8h,v16.8h,v0.h[0] + fmla v22.8h,v16.8h,v1.h[0] + fmla v24.8h,v16.8h,v2.h[0] + fmla v26.8h,v16.8h,v3.h[0] + fmla v28.8h,v16.8h,v4.h[0] + fmla v30.8h,v16.8h,v5.h[0] + fmla v21.8h,v17.8h,v0.h[0] + fmla v23.8h,v17.8h,v1.h[0] + fmla v25.8h,v17.8h,v2.h[0] + fmla v27.8h,v17.8h,v3.h[0] + fmla v29.8h,v17.8h,v4.h[0] + fmla v31.8h,v17.8h,v5.h[0] + b M6N16OutterLoopNTail + +M6StoreRemainderN + cbnz x19,M6StoreRemainderNZeroMode + tbz x1,3,M6StoreRemainderN4 + ldr q0,[x3] + ldr q1,[x10] + ldr q2,[x11] + ldr q3,[x12] + ldr q4,[x13] + ldr q5,[x4] + fadd v20.8h,v20.8h,v0.8h + fadd v22.8h,v22.8h,v1.8h + fadd v24.8h,v24.8h,v2.8h + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + fadd v26.8h,v26.8h,v3.8h + fadd v28.8h,v28.8h,v4.8h + fadd v30.8h,v30.8h,v5.8h + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +M6StoreRemainderN4 + tbz x1,2,M6StoreRemainderN2 + ldr d0,[x3] + ldr d1,[x10] + ldr d2,[x11] + ldr d3,[x12] + ldr d4,[x13] + ldr d5,[x4] + fadd v21.4h,v20.4h,v0.4h + dup d20,v20.d[1] + fadd v23.4h,v22.4h,v1.4h + dup d22,v22.d[1] + fadd v25.4h,v24.4h,v2.4h + dup d24,v24.d[1] + fadd v27.4h,v26.4h,v3.4h + dup d26,v26.d[1] + fadd v29.4h,v28.4h,v4.4h + dup d28,v28.d[1] + fadd v31.4h,v30.4h,v5.4h + dup d30,v30.d[1] + str d21,[x3],8 + str d23,[x10],8 + str d25,[x11],8 + str d27,[x12],8 + str d29,[x13],8 + str d31,[x4],8 + +M6StoreRemainderN2 + tbz x1,1,M6StoreRemainderN1 + ldr s0,[x3] + ldr s1,[x10] + ldr s2,[x11] + ldr s3,[x12] + ldr s4,[x13] + ldr s5,[x4] + fadd v21.4h,v20.4h,v0.4h + fadd v23.4h,v22.4h,v1.4h + fadd v25.4h,v24.4h,v2.4h + fadd v27.4h,v26.4h,v3.4h + fadd v29.4h,v28.4h,v4.4h + fadd v31.4h,v30.4h,v5.4h + str s21,[x3],4 + str s23,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s25,[x11],4 + str s27,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s29,[x13],4 + str s31,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +M6StoreRemainderN1 + tbz x1,0,ExitKernel + ldr h0,[x3] + ldr h1,[x10] + ldr h2,[x11] + ldr h3,[x12] + ldr h4,[x13] + ldr h5,[x4] + fadd v20.4h,v20.4h,v0.4h + fadd v22.4h,v22.4h,v1.4h + fadd v24.4h,v24.4h,v2.4h + fadd v26.4h,v26.4h,v3.4h + fadd v28.4h,v28.4h,v4.4h + fadd 
v30.4h,v30.4h,v5.4h + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b ExitKernel + +M6StoreRemainderNZeroMode + tbz x1,3,M6StoreRemainderN4ZeroMode + str q20,[x3],16 + mov v20.16b,v21.16b + str q22,[x10],16 + mov v22.16b,v23.16b + str q24,[x11],16 + mov v24.16b,v25.16b + str q26,[x12],16 + mov v26.16b,v27.16b + str q28,[x13],16 + mov v28.16b,v29.16b + str q30,[x4],16 + mov v30.16b,v31.16b + +M6StoreRemainderN4ZeroMode + tbz x1,2,M6StoreRemainderN2ZeroMode + str d20,[x3],8 + str d22,[x10],8 + dup d20,v20.d[1] + dup d22,v22.d[1] + str d24,[x11],8 + str d26,[x12],8 + dup d24,v24.d[1] + dup d26,v26.d[1] + str d28,[x13],8 + str d30,[x4],8 + dup d28,v28.d[1] + dup d30,v30.d[1] + +M6StoreRemainderN2ZeroMode + tbz x1,1,M6StoreRemainderN1ZeroMode + str s20,[x3],4 + str s22,[x10],4 + dup s20,v20.s[1] + dup s22,v22.s[1] + str s24,[x11],4 + str s26,[x12],4 + dup s24,v24.s[1] + dup s26,v26.s[1] + str s28,[x13],4 + str s30,[x4],4 + dup s28,v28.s[1] + dup s30,v30.s[1] + +M6StoreRemainderN1ZeroMode + tbz x1,0,ExitKernel + str h20,[x3] + str h22,[x10] + str h24,[x11] + str h26,[x12] + str h28,[x13] + str h30,[x4] + b ExitKernel + + LEAF_END MlasHalfGemmKernelNeon + + END diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp new file mode 100644 index 0000000000000..778db2003d6c6 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -0,0 +1,334 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + half gemm.cpp + +Abstract: + + This module implements the half precision (fp16) matrix/matrix multiply + operation (QGEMM). + +--*/ + +#include "mlasi.h" +#include "mlas_float16.h" + +#include "halfgemm.h" + +#include + +bool MLASCALL +MlasFp16AccelerationSupported() +{ + return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration(); +} + + +void +MLASCALL +MlasHalfGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_HALF_GEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool + ) +{ + const MLAS_HALFGEMM_DISPATCH* dispatch = MlasHalfGemmGetDispatch(); + MLAS_HALFGEMM_OPERATION* operation = dispatch->Operation; + + if (ThreadPool == nullptr) { + for (size_t gemm_i = 0; gemm_i < BatchN; gemm_i++) { + auto Data = &DataParams[gemm_i]; + operation(N, K, Data, 0, M, 0, N); + } + return; + } + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. 
+ // + + const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchN; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + const size_t StrideM = dispatch->StrideM; + + size_t nc = N; + if ((size_t)MlasGetMaximumThreadCount(ThreadPool) > BatchN) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min(nc, MlasDivRoundup(nc, max_nc * MLAS_QGEMM_STRIDEN_THREAD_ALIGN) * + MLAS_QGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + auto Data = &DataParams[gemm_i]; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + operation(N, K, Data, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} + + +size_t +MLASCALL +MlasHalfGemmPackBSize( + size_t N, + size_t K, + bool float2half + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackededK; + if (!float2half && dispatch->CopyPackBRoutine == nullptr) { + // No packing routine provided + return 0; + } + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t BytesRequired = N * AlignedK * FP16_SIZE + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + return AlignedBytesRequired; +} + +void +MLASCALL +MlasHalfGemmPackB( + size_t N, + size_t K, + const MLAS_FP16* B, + size_t ldb, + void* PackedB + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + dispatch->CopyPackBRoutine((_mlas_fp16_*)PackedB, (const _mlas_fp16_*)B, ldb, N, K); +} + +void +MLASCALL +MlasHalfGemmConvertPackB( + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ) +{ + const auto* dispatch = MlasHalfGemmGetDispatch(); + dispatch->ConvertPackBRoutine((_mlas_fp16_*)PackedB, B, ldb, N, K); +} + + +// +// Post Processor Implementations +// + +MLAS_FORCEINLINE +void +CvtHalf2Float( + float* dest, + const _mlas_fp16_* src, + size_t len +) +{ +#ifdef MLAS_TARGET_ARM64 + while (len >= 4) { + const auto* srcPtr = reinterpret_cast(src); + auto* dstPtr = reinterpret_cast(dest); + *dstPtr = vcvt_f32_f16(*srcPtr); + src += 4; + dest += 4; + len -= 4; + } + + if (0 == len) { + return; + } + + float16x4_t buf; + std::memcpy(&buf, src, len * sizeof(_mlas_fp16_)); + float32x4_t res = vcvt_f32_f16(buf); + + if ((len & 2) != 0) { + auto wide = vreinterpretq_f64_f32(res); + 
vst1q_lane_f64((float64_t*)dest, wide, 0); + res = vreinterpretq_f32_f64(vdupq_laneq_f64(wide, 1)); + dest += 2; + } + if ((len & 1) != 0) { + vst1q_lane_f32(dest, res, 0); + } +#else + for (size_t i = 0; i < len; i++) { + *dest++ = MLAS_Half2Float(*src++); + } +#endif // MLAS_TARGET_ARM64 +} + +void +MLAS_HALF_GEMM_2FLOAT_PROCESSOR::Process( + MLAS_FP16* C, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN, + size_t ldc + ) const +{ + // + // TODO!! use templates to add activations in this impl + // + float* Output = Output_; + const auto* CRow = reinterpret_cast(C); + CRow += StartM * ldc + StartN; + Output += StartM * RowStride_ + StartN; + + while (CountM-- > 0) { + CvtHalf2Float(Output, CRow, CountN); + + CRow += ldc; + Output += RowStride_; + } +} + + +// +// Dummy C++ implementation that runs very slowly +// + +struct MLAS_HALF_GEMM_KERNEL_DEFAULT { + + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 128; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 1; + + static constexpr MLAS_HALF_GEMM_STRIDES Strides{8, 16, 32}; +}; + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + for (size_t m = 0; m < CountM; m++) { + for (size_t k = 0; k < CountK; k++) { + *D++ = MLAS_Float2Half(*(A + m * lda + k)); + } + } +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + for (size_t k = 0; k < CountK; k++) { + for (size_t n = 0; n < CountN; n++) { + *D++ = MLAS_Float2Half(*(B + k * ldb + n)); + } + } +} + + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmKernel( + size_t CountM, + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + const bool ZeroMode) +{ + for (size_t m = 0; m < CountM; m++) { + for (size_t n = 0; n < CountN; n++) { + const auto* a = A + (m * lda); + const auto* b = B + n; + auto* c = C + (m * ldc) + n; + + float sum = Bias == nullptr ? 0.0f : MLAS_Half2Float(Bias[n]); + if (!ZeroMode) { + sum += MLAS_Half2Float(*c); + } + + for (size_t k = 0; k < CountK; k++) { + auto down = MLAS_Float2Half(MLAS_Half2Float(*a) * MLAS_Half2Float(*b) + sum); + sum = MLAS_Half2Float(down); + b += ldb; + a += 1; + } + + *c = MLAS_Float2Half(sum); + } + } +} + + +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { + MlasHalfGemmOperation, + nullptr, + MlasHalfGemmConvertPackB, + MLAS_HALF_GEMM_KERNEL_DEFAULT::PackedK, + MLAS_HALF_GEMM_KERNEL_DEFAULT::KernelMaxM, + 0 +}; diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h new file mode 100644 index 0000000000000..9e781207571a4 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm.h @@ -0,0 +1,515 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm.h + +Abstract: + + This module defines the set of template functions to implement half + precision matrix/matrix multiply operation (QGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasHalfGemmCopyPackB + MlasHalfGemmConvertPackA + MlasHalfGemmConvertPackB + MlasHalfGemmPackedBOffset + MlasHalfGemmPackedBLeadingDim + MlasHalfGemmKernel + + MlasHalfGemmOperation is the shared kernel driver. 
+ + A kernel type should define the following constants: + bool PackNeeded; Whether fp16 B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + MLAS_HALF_GEMM_STRIDES Strides{128, 128, 128}; +--*/ + +#pragma once + +#include +#include +#include + +#include "mlasi.h" +#include "mlas_float16.h" + + +/** + * @brief Define the default striding parameters for + * the half precision gemm operation + */ +struct MLAS_HALF_GEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Packing function for fp16 B matrix + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack +*/ +template +MLAS_FORCEINLINE +void +MlasHalfGemmCopyPackB( + _mlas_fp16_* D, + const _mlas_fp16_* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + MLAS_UNREFERENCED_PARAMETER(D); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(ldb); + MLAS_UNREFERENCED_PARAMETER(CountN); + MLAS_UNREFERENCED_PARAMETER(CountK); + // No packing needed by default +} + +/** + * @brief Convert fp32 matrix A to fp16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of the packing buffer + * @param[in] A Address of fp32 matrix A + * @param[in] lda leading dimension of A + * @param[in] CountM # of rows to pack + * @param[in] CountK # of columns to pack +*/ +template +void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +); + +/** + * @brief Convert fp32 matrix B to fp16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] +*/ +template +MLAS_FORCEINLINE +const _mlas_fp16_* +MlasHalfGemmPackedBOffset( + const _mlas_fp16_* PackedB, + size_t DimN, + size_t DimK, + size_t StartN, + size_t StartK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(push) +/*No it can NOT be constexpr!.*/ +#pragma warning(disable : 26497) +#endif + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer +*/ +template +MLAS_FORCEINLINE +size_t +MlasHalfGemmPackedBLeadingDim( + size_t DimN, + size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif + +template +void +MlasHalfGemmKernel( + const size_t CountM, + const size_t CountN, + const size_t CountK, + _mlas_fp16_* C, 
+ size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + const size_t lda, + const _mlas_fp16_* B, + const size_t ldb, + const bool ZeroMode +); + + +template +MLAS_FORCEINLINE +void +MlasHalfGemmNoPackOperation( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + // + // Optimize for the special case where no packing is needed. + // Simpler tiling as we are not restricted by packing panel size + // + + const size_t lda = Data->lda; + size_t ldb = Data->ldb; // 0 if prepacked + const size_t ldc = Data->ldc; + + const auto* pa = reinterpret_cast(Data->A) + + RangeStartM * lda; + const _mlas_fp16_* pb; + if (ldb == 0) { + pb = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), + N, + K, + RangeStartN, + 0); + ldb = MlasHalfGemmPackedBLeadingDim(N, K); + } else { + pb = reinterpret_cast(Data->B) + RangeStartN; + } + + const _mlas_fp16_* Bias = (nullptr == Data->Bias) + ? nullptr + : reinterpret_cast(Data->Bias) + RangeStartN; + _mlas_fp16_* c = reinterpret_cast<_mlas_fp16_*>(Data->C) + + RangeStartM * ldc + RangeStartN; + + size_t RowsRemaining = RangeCountM; + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + RangeCountN, + K, + c, + ldc, + Bias, + pa, + lda, + pb, + ldb, + true); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + RangeCountM - RowsRemaining, + RangeStartN, + RowsHandled, + RangeCountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += lda * RowsHandled; + RowsRemaining -= RowsHandled; + } +} + + +template +void +MlasHalfGemmOperation( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ) +{ + const size_t lda = Data->lda; + const size_t ldb = Data->ldb; + const size_t ldc = Data->ldc; + + if (!Data->AIsfp32 && (ldb == 0 || (!KernelType::PackNeeded && !Data->BIsfp32))) { + // !Data->AIsfp32 => A is fp16, no packing on the left hand side + // ldb == 0 => B is already packed, no packing on the right hand side + // !KernelType::PackNeeded && !Data->BIsfp32 => B is fp16 and the kernel + // does not require packing + // + // So no packing needed on either A or B, use a simpler driver instead + + MlasHalfGemmNoPackOperation( + N, + K, + Data, + RangeStartM, + RangeCountM, + RangeStartN, + RangeCountN); + return; + } + + const auto* Bias = reinterpret_cast(Data->Bias); + _mlas_fp16_* C = reinterpret_cast<_mlas_fp16_*>(Data->C) + + RangeStartM * ldc + RangeStartN; + + // + // Three dimensional tiling due to limited packing panel size + // + constexpr MLAS_HALF_GEMM_STRIDES Strides = KernelType::Strides; + constexpr size_t packASize = UpAlignSize(Strides.M * Strides.K * FP16_SIZE); + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * FP16_SIZE); + MlasThreadedBufAlloc(packASize + packBSize); + + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelA = reinterpret_cast<_mlas_fp16_*>(p); + p += packASize; + auto* PanelB = reinterpret_cast<_mlas_fp16_*>(p); + + // + // Step through each slice of matrix B along the K dimension. 
+ // + + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, Strides.K); + const size_t PackedCountK = (CountK + KernelType::PackedK - 1) / KernelType::PackedK; + + // + // Step through each slice of matrix B along the N dimension. + // + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, Strides.N); + + // + // Copy a panel of matrix B to a local packed buffer. + // + size_t ld_pb; + const _mlas_fp16_* pb; + if (ldb == 0) { + // Already packed + pb = MlasHalfGemmPackedBOffset( + reinterpret_cast(Data->B), + N, + K, + RangeStartN + n, + k); + ld_pb = MlasHalfGemmPackedBLeadingDim(N, K); + } else if (Data->BIsfp32) { + // fp32, need conversion and packing + MlasHalfGemmConvertPackB( + PanelB, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + ldb, + CountN, + CountK); + pb = PanelB; + ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); + } else if (KernelType::PackNeeded) { + // fp16, need packing + MlasHalfGemmCopyPackB( + PanelB, + reinterpret_cast(Data->B) + ldb * k + RangeStartN + n, + ldb, + CountN, + CountK); + pb = PanelB; + ld_pb = MlasHalfGemmPackedBLeadingDim(CountN, CountK); + } else { + // fp16, and no packing needed + pb = reinterpret_cast(Data->B) + ldb * k + RangeStartN + n; + ld_pb = ldb; + } + + // + // Step through each slice of matrix A along the M dimension. + // + + auto* c = C + n; + const auto* pbias = (nullptr == Bias) ? nullptr : Bias + RangeStartN + n; + size_t CountM; + for (size_t m = 0; m < RangeCountM; m += CountM) { + CountM = std::min(RangeCountM - m, Strides.M); + + // + // Copy a panel of matrix A to a local packed buffer. + // + const _mlas_fp16_* pa; + size_t ld_pa; + if (Data->AIsfp32) { + MlasHalfGemmConvertPackA( + PanelA, + reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k, + lda, + CountM, + CountK); + pa = PanelA; + ld_pa = KernelType::PackedK * PackedCountK; + } else { + pa = reinterpret_cast(Data->A) + (RangeStartM + m) * lda + k; + ld_pa = lda; + } + + size_t RowsRemaining = CountM; + bool ZeroMode = (k == 0); + bool PostProcess = (k + CountK == K); + + while (RowsRemaining > 0) { + MlasHalfGemmKernel( + RowsRemaining, + CountN, + CountK, + c, + ldc, + ZeroMode ? pbias : nullptr, + pa, + ld_pa, + pb, + ld_pb, + ZeroMode); + + size_t RowsHandled = std::min(RowsRemaining, KernelType::KernelMaxM); + + if (PostProcess && Data->OutputProcessor != nullptr) { + Data->OutputProcessor->Process( + Data->C, + RangeStartM + m + CountM - RowsRemaining, + RangeStartN + n, + RowsHandled, + CountN, + Data->ldc); + } + + c += ldc * RowsHandled; + pa += ld_pa * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + } +} + + +// +// dispatch structure. 
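//
// For kernel types that keep the default (non-specialized) packed-B helpers
// above, PackedB is simply a row-major K by N fp16 matrix. A minimal sketch of
// that addressing contract follows; it is illustrative only, KernelTag is a
// placeholder type and <cassert> is assumed to be available.
//

template <typename KernelTag>
void
MlasHalfGemmDefaultPackedBLayoutSketch()
{
    _mlas_fp16_ Packed[4 * 8] = {};           // DimK = 4 rows by DimN = 8 columns
    // PackedB[StartK = 2, StartN = 3] lives at linear offset 2 * 8 + 3 = 19
    assert(MlasHalfGemmPackedBOffset<KernelTag>(Packed, 8, 4, 3, 2) == Packed + 19);
    assert(MlasHalfGemmPackedBLeadingDim<KernelTag>(8, 4) == 8);
}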
+// + +typedef +void +(MLAS_HALFGEMM_OPERATION)( + const size_t N, + const size_t K, + const MLAS_HALF_GEMM_DATA_PARAMS* Data, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN + ); + + +typedef +void +(MLAS_HALFGEMM_COPYPACKB_ROUTINE)( + _mlas_fp16_* D, + const _mlas_fp16_* B, + size_t ldb, + size_t CountN, + size_t CountK + ); + +typedef +void +(MLAS_HALFGEMM_CONVERTPACKB_ROUTINE)( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK + ); + +/** + * @brief Hardware dependent dispatch for half precision GEMM +*/ +struct MLAS_HALFGEMM_DISPATCH { + MLAS_HALFGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_HALFGEMM_COPYPACKB_ROUTINE* CopyPackBRoutine; /**< Pack function for B */ + MLAS_HALFGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackededK; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault; + +#if defined(MLAS_TARGET_ARM64) +extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon; +#endif + +MLAS_FORCEINLINE +const MLAS_HALFGEMM_DISPATCH* +MlasHalfGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasHalfGemmDispatchNeon; +#else + return &MlasHalfGemmDispatchDefault; +#endif +} diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..d7f5a90b00589 --- /dev/null +++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon.cpp @@ -0,0 +1,187 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +#include "arm_neon.h" + +// +// Define the prototypes of the NEON routines written in assembly. +// +// N.B. The kernel has not been ported to build with the Windows ARM32 toolset. 
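//
// As a rough guide to what the assembly routine computes, a scalar sketch of
// the same contract is given below. This is a hypothetical reference only,
// assuming nothing beyond the helpers in mlas_float16.h: it handles all CountM
// rows in one call and accumulates in fp32, whereas the real kernel processes
// at most KernelMaxM rows per call and may round intermediate sums to fp16.
//

static
void
MlasHalfGemmKernelNeonReferenceSketch(
    size_t CountM,
    size_t CountN,
    size_t CountK,
    _mlas_fp16_* C,
    size_t ldc,
    const _mlas_fp16_* Bias,
    const _mlas_fp16_* A,
    size_t lda,
    const _mlas_fp16_* B,
    size_t ldb,
    bool ZeroMode
)
{
    for (size_t m = 0; m < CountM; m++) {
        for (size_t n = 0; n < CountN; n++) {
            float sum;
            if (ZeroMode) {
                // The drivers only pass a Bias pointer together with ZeroMode.
                sum = (Bias != nullptr) ? MLAS_Half2Float(Bias[n]) : 0.0f;
            } else {
                sum = MLAS_Half2Float(C[m * ldc + n]);
            }
            for (size_t k = 0; k < CountK; k++) {
                sum += MLAS_Half2Float(A[m * lda + k]) * MLAS_Half2Float(B[k * ldb + n]);
            }
            C[m * ldc + n] = MLAS_Float2Half(sum);
        }
    }
}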
+// + +extern "C" { + + size_t + MLASCALL + MlasHalfGemmKernelNeon( + const size_t CountM, + const size_t CountN, + const size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + const size_t lda, + const _mlas_fp16_* B, + const size_t ldb, + const bool ZeroMode + ); + +} + + +struct MLAS_HALF_GEMM_KERNEL_NEON { + static constexpr bool PackNeeded = false; + static constexpr size_t KernelMaxM = 6; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 1; + + static constexpr MLAS_HALF_GEMM_STRIDES Strides{24, 128, 512}; +}; + + +MLAS_FORCEINLINE +void +CvtFloat2Half( + _mlas_fp16_* dest, + const float* src, + size_t len +) +{ + while (len >= 4) { + const auto* srcPtr = reinterpret_cast(src); + auto* dstPtr = reinterpret_cast(dest); + *dstPtr = vcvt_f16_f32(*srcPtr); + src += 4; + dest += 4; + len -= 4; + } + + if (0 == len) { + return; + } + + float32x4_t buf; + std::memcpy(&buf, src, len * sizeof(float)); + float16x4_t res = vcvt_f16_f32(buf); + + if ((len & 2) != 0) { + auto wide = vreinterpret_f32_f16(res); + vst1_lane_f32((float32_t*)dest, wide, 0); + res = vreinterpret_f16_f32(vdup_lane_f32(wide, 1)); + dest += 2; + } + if ((len & 1) != 0) { + vst1_lane_u16(dest, vreinterpret_u16_f16(res), 0); + } +} + +/** + * @brief Convert a 2D matrix from float to fp16 +*/ +MLAS_FORCEINLINE +void +CvtFloat2Half2D( + _mlas_fp16_* dest, + const float* src, + size_t stride, + size_t CntRow, + size_t CntCol + ) +{ + if (stride == CntCol) { + const size_t len = CntRow * CntCol; + CvtFloat2Half(dest, src, len); + return; + } + while (CntRow > 0) { + CvtFloat2Half(dest, src, CntCol); + src += stride; + dest += CntCol; + CntRow--; + } +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackA( + _mlas_fp16_* D, + const float* A, + size_t lda, + size_t CountM, + size_t CountK +) +{ + CvtFloat2Half2D(D, A, lda, CountM, CountK); +} + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmConvertPackB( + _mlas_fp16_* D, + const float* B, + size_t ldb, + size_t CountN, + size_t CountK +) +{ + CvtFloat2Half2D(D, B, ldb, CountK, CountN); +} + + +template<> +MLAS_FORCEINLINE +void +MlasHalfGemmKernel( + size_t CountM, + size_t CountN, + size_t CountK, + _mlas_fp16_* C, + size_t ldc, + const _mlas_fp16_* Bias, + const _mlas_fp16_* A, + size_t lda, + const _mlas_fp16_* B, + size_t ldb, + const bool ZeroMode) +{ + MlasHalfGemmKernelNeon( + CountM, + CountN, + CountK, + C, + ldc, + Bias, + A, + lda, + B, + ldb, + ZeroMode); +} + + +const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon = { + MlasHalfGemmOperation, + nullptr, + MlasHalfGemmConvertPackB, + MLAS_HALF_GEMM_KERNEL_NEON::PackedK, + MLAS_HALF_GEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 31999f3294999..21949535cf63b 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -107,6 +107,8 @@ Module Name: #include "core/common/cpuid_info.h" using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo; +#include "core/framework/float16.h" + #else // BUILD_MLAS_NO_ONNXRUNTIME class MLASCPUIDInfo @@ -121,6 +123,8 @@ class MLASCPUIDInfo // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + bool HasFp16VectorAcceleration() const { return has_fp16_; } + uint32_t GetCurrentCoreIdx() const { return 0xFFFFFFFF; } int32_t GetCurrentUarch() const { return -1; } @@ -135,6 +139,7 @@ class MLASCPUIDInfo MLASCPUIDInfo(); bool 
has_arm_neon_dot_{false}; + bool has_fp16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -179,7 +184,49 @@ enum MlasUArch { #endif // MLAS_TARGET_ARM64 -#endif // BUILD_MLAS_NO_ONNXRUNTIME +// +// Define MLAS_FP16 +// +#include "mlas_float16.h" + +namespace onnxruntime +{ +struct MLFloat16 { + uint16_t val{0}; + + MLFloat16() = default; + explicit constexpr MLFloat16(uint16_t x) : val(x) {} + explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {} + + float ToFloat() const { return MLAS_Half2Float(val); } + + operator float() const { return ToFloat(); } + + MLFloat16& operator=(float ff) + { + val = MLAS_Float2Half(ff); + return *this; + } +}; + +inline bool +operator==(const MLFloat16& left, const MLFloat16& right) +{ + return left.val == right.val; +} + +inline bool +operator!=(const MLFloat16& left, const MLFloat16& right) +{ + return left.val != right.val; +} + +} + +#endif // BUILD_MLAS_NO_ONNXRUNTIME + +static_assert(sizeof(MLAS_FP16) == FP16_SIZE); + // // Define the maximum number of threads supported by this implementation. @@ -700,9 +747,9 @@ extern "C" { // thread to perform additional work. // -#define MLAS_SGEMM_THREAD_COMPLEXITY (64 * 1024) -#define MLAS_DGEMM_THREAD_COMPLEXITY (64 * 1024) -#define MLAS_QGEMM_THREAD_COMPLEXITY (64 * 1024) +#define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#define MLAS_QGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) // // Single-threaded single precision matrix/matrix multiply operation. diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 7d8624a32a218..c52d4f3b0b8c4 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -36,6 +36,9 @@ Module Name: MLASCPUIDInfo::MLASCPUIDInfo() { has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); + + // raw hack! Need CPUIDInfo implementation for more precise detection + has_fp16_ = has_arm_neon_dot_; } #endif @@ -50,7 +53,13 @@ MLASCPUIDInfo::MLASCPUIDInfo() #endif #if defined(BUILD_MLAS_NO_ONNXRUNTIME) -MLASCPUIDInfo::MLASCPUIDInfo() { has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); } +MLASCPUIDInfo::MLASCPUIDInfo() +{ + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + + // raw hack! Need CPUIDInfo implementation for more precise detection + has_fp16_ = has_arm_neon_dot_; +} #endif #else diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.cpp b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp new file mode 100644 index 0000000000000..652645448d9dd --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.cpp @@ -0,0 +1,206 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_halfgemm.cpp + +Abstract: + + Tests for MLAS half precision GEMM. + +--*/ + + +#include "test_halfgemm.h" + +// +// Short Execute() test helper to register each test seperately by all parameters. 
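//
// The fixtures below compare results by converting through the helpers in
// mlas_float16.h. A minimal round-trip sketch of those helpers, assuming
// IEEE-754 binary16 encoding (1.0f encodes as 0x3C00) and <cassert>:
//

static void MlasFp16ConversionSmokeSketch() {
  const _mlas_fp16_ half_one = MLAS_Float2Half(1.0f);
  assert(half_one == 0x3C00);                // 1.0 in binary16
  assert(MLAS_Half2Float(half_one) == 1.0f); // exact round trip for 1.0
}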
+// +template +class HalfGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit HalfGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) + : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { + std::stringstream ss; + ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" + << "hasBias" << hasBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasHalfGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new HalfGemmShortExecuteTest( + M, N, K, Batch, hasBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, 1, false); + test_registered += RegisterSingleTest(1, 32, b, 1, true); + test_registered += RegisterSingleTest(1, b, b, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(1, b, 32, 3, true); + test_registered += RegisterSingleTest(1, 32, b, 5, false); + } + } + test_registered += RegisterSingleTest(43, 500, 401, 1, true); + test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(43, 500, 401, 5, true); + test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); + } + + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; + bool hasBias_; +}; + + +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); + +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); + +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); + +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); +template <> +MlasHalfGemmTest* MlasTestFixture>::mlas_tester(nullptr); + +static size_t HalfGemmRegistLongExecute() { + size_t count = 0; + + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += 
MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasHalfGemmPackBSize(128, 128, false) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + if (MlasHalfGemmPackBSize(128, 128, true) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasHalfGemmPackBSize(128, 128, false) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + if (MlasHalfGemmPackBSize(128, 128, true) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } + + return count; +} + +static size_t HalfGemmRegistShortExecute() { + size_t count = 0; + + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasHalfGemmPackBSize(128, 128, false) > 0) { + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + } + if (MlasHalfGemmPackBSize(128, 128, true) > 0) { + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + } + + if (GetMlasThreadPool() != nullptr) { + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasHalfGemmPackBSize(128, 128, false) > 0) { + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + } + if (MlasHalfGemmPackBSize(128, 128, true) > 0) { + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + count += HalfGemmShortExecuteTest::RegisterShortExecuteTests(); + } + } + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (!MlasFp16AccelerationSupported()) { + return false; + } + if (is_short_execute) { + return HalfGemmRegistShortExecute() > 0; + } + return HalfGemmRegistLongExecute() > 0; +}); diff --git a/onnxruntime/test/mlas/unittest/test_halfgemm.h b/onnxruntime/test/mlas/unittest/test_halfgemm.h new file mode 100644 index 0000000000000..6b1118d27124e --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_halfgemm.h @@ -0,0 +1,319 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_halfgemm.h + +Abstract: + + Tests for MLAS half precision GEMM. 
+ +--*/ + +#pragma once + +#include "test_util.h" +#include "mlas_float16.h" + + +// +// Define our own fp16 type to avoid dragging in big dependencies +// +struct MLFp16 { + uint16_t val{0}; + + MLFp16() = default; + explicit constexpr MLFp16(uint16_t x) : val(x) {} + explicit constexpr MLFp16(int32_t x) : val((uint16_t)x) {} + explicit MLFp16(float ff) : val(MLAS_Float2Half(ff)) {} + + float ToFloat() const { + return MLAS_Half2Float(val); + } + + operator float() const { return ToFloat(); } + + MLFp16& operator=(float ff) { + val = MLAS_Float2Half(ff); + return *this; + } +}; + +inline bool +operator==(const MLFp16& left, const MLFp16& right) { + return left.val == right.val; +} + +inline bool +operator!=(const MLFp16& left, const MLFp16& right) { + return left.val != right.val; +} + +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto* FillAddress = start; + size_t offset = size % 23; + + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); + } +} + + +/** + * @brief Test class for half precision GEMM + * @tparam AType Data type of A matrix, can be either float or MLFp16 + * @tparam BType Data type of b matrix, can be either float or MLFp16 +*/ +template +class MlasHalfGemmTest : public MlasTestBase { + +private: + MatrixGuardBuffer BufferBPacked; + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + MatrixGuardBuffer BufferFloatC; + MLAS_THREADPOOL* threadpool_; + + void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { + size_t PackedBSize = MlasHalfGemmPackBSize(N, K, std::is_same::value); + if (PackedBSize == 0) { + return nullptr; + } + void* PackedB = BufferBPacked.GetBuffer(PackedBSize); + if (std::is_same::value) { + MlasHalfGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + } else { + MlasHalfGemmPackB(N, K, (const MLAS_FP16*)B, ldb, PackedB); + } + return PackedB; + } + + void CallGemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + size_t lda, + const BType* B, + size_t ldb, + const MLFp16* Bias, + MLFp16* C, + size_t ldc, + float* Cfloat) { + std::vector Converters; + Converters.reserve(BatchSize); + + std::vector GemmParameters(BatchSize); + + for (size_t i = 0; i < GemmParameters.size(); i++) { + auto& params = GemmParameters[i]; + params.A = A + (M * lda * i); + params.lda = lda; + if (nullptr != Bias) { + params.Bias = reinterpret_cast(Bias + N * i); + } else { + params.Bias = nullptr; + } + params.C = reinterpret_cast(C + (M * ldc * i)); + params.ldc = ldc; + + if (Packed) { + ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; + params.B = PackB(N, K, B, ldb); + params.ldb = 0; + } else { + params.B = B + (K * N * i); + params.ldb = ldb; + } + params.AIsfp32 = std::is_same::value; + params.BIsfp32 = std::is_same::value; + Converters.emplace_back(Cfloat + (M * N * i), N); + params.OutputProcessor = &(Converters[i]); + } + + MlasHalfGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + } + + void ReferenceQgemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const MLFp16* Bias, + float* C) { + // TODO!! deal with half precision accumulation error + // Most CPUs does not support mixed precision accumulation, + // only mul & add fuse. As a result, different striding + // on the K dimension may lead to rounding error. 
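    // A small worked example, assuming IEEE binary16 (11-bit significand, so
    // representable values near 2048 are 2 apart):
    //
    //   rounding after every add:   fp16(fp16(2048 + 1) + 1) = 2048
    //   summing small terms first:  fp16(2048 + (1 + 1))     = 2050
    //
    // so the result depends on how partial sums are grouped along K.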
+ // Accumulation of these rounding error maybe significant. + // + // An ugly hack now is to change the K stride of the kernel + // under test to be 16, pass this test and then change it + // back :-(. + // + constexpr size_t KStride = 16; + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const AType* a = A + M * K * batch + m * K; + const BType* b = B + K * N * batch + n; + float* c = C + (M * N * batch) + (m * N) + n; + + for (size_t k = 0; k < K; k+=KStride) { + float sum = 0.0f; + if (k == 0 && Bias != nullptr) { + sum = float(Bias[n]); + } + for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { + MLFp16 down(float(*b) * float(*a) + sum); + sum = float(down); + b += N; + a += 1; + } + if (k == 0) { + *c = sum; + } else { + MLFp16 d(sum + *c); + *c = float(d); + } + } + } + } + if (Bias) { + Bias += N; + } + } + } + +public: + MlasHalfGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} + + void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + const AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + const BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + MLFp16 BiasTail[16]; + const MLFp16* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16)); + } + + MLFp16* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* Cfloat = BufferFloatC.GetBuffer(N * M * BatchSize, true); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + + this->CallGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, Cfloat); + ReferenceQgemm(M, N, K, BatchSize, A, B, Bias, CReference); + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_EQ(float(C[f]), CReference[f]) << "@[" << batch << "x" << m << "x" << n << "], " + << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K; + ASSERT_EQ(Cfloat[f], CReference[f]) << "Converted@[" << batch << "x" << m << "x" << n << "], " + << "Batch=" << BatchSize << "M=" << M << ", N=" << N << ", K=" << K; + + } + } + } + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + if (withBias){ + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(MLFp16)), 0) << "Bias buffer overwritten!"; + } + } + + private: + + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("HalfGemmFP") + + (std::is_same::value ? "32" : "16") + + (std::is_same::value ? "32" : "16") + + (Packed ? "_Packed" : "_NoPack") + + (Threaded ? 
"_Threaded" : "_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong(void) override { + for (size_t M = 16; M < 160; M += 32) { + for (size_t N = 16; N < 160; N += 32) { + static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; + for (size_t k = 0; k < _countof(ks); k++) { + size_t K = ks[k]; + + Test(M, N, K, 1, false); + Test(M, N, K, 1, true); + Test(M + 1, N, K, 1, false); + Test(M, N + 1, K, 1, true); + Test(M + 1, N + 1, K, 1, false); + Test(M + 3, N + 2, K, 1, true); + Test(M + 4, N, K, 1, false); + Test(M, N + 4, K, 1, true); + Test(M + 4, N + 4, K, 1, false); + Test(M + 3, N + 7, K, 1, true); + Test(M + 8, N, K, 1, false); + Test(M, N + 8, K, 1, true); + Test(M + 12, N + 12, K, 1, false); + Test(M + 13, N, K, 1, true); + Test(M, N + 15, K, 1, false); + Test(M + 15, N + 15, K, 1, false); + if (!Packed) { + Test(M, N, K, 7, false); + Test(M + 3, N, K, 8, true); + Test(M, N + 1, K, 9, false); + Test(M + 12, N, K, 10, true); + Test(M, N + 15, K, 11, false); + Test(M + 15, N + 15, K, 12, true); + } + } + } + printf("M %zd\n", M); + } + + for (size_t M = 1; M < 160; M++) { + for (size_t N = 1; N < 160; N++) { + for (size_t K = 1; K < 160; K++) { + Test(M, N, K, 1, true); + } + } + printf("M %zd\n", M); + } + + for (size_t M = 160; M < 320; M += 24) { + for (size_t N = 112; N < 320; N += 24) { + for (size_t K = 1; K < 16; K++) { + Test(M, N, K, 1, true); + } + for (size_t K = 16; K < 160; K += 32) { + Test(M, N, K, 1, false); + } + } + printf("M %zd\n", M); + } + } + + +}; diff --git a/onnxruntime/test/mlas/unittest/test_util.h b/onnxruntime/test/mlas/unittest/test_util.h index c3d97eb3bb402..d3fc735260742 100644 --- a/onnxruntime/test/mlas/unittest/test_util.h +++ b/onnxruntime/test/mlas/unittest/test_util.h @@ -50,7 +50,7 @@ class MatrixGuardBuffer { ReleaseBuffer(); } - T* GetBuffer(size_t Elements, bool ZeroFill = false) { + T* GetFilledBuffer(size_t Elements, std::function const & fillFunc) { // // Check if the internal buffer needs to be reallocated. // @@ -105,29 +105,38 @@ class MatrixGuardBuffer { T* GuardAddress = _GuardAddress; T* buffer = GuardAddress - Elements; + fillFunc(buffer, Elements); - if (ZeroFill) { - std::fill_n(buffer, Elements, T(0)); - - } else { - constexpr int MinimumFillValue = -23; - constexpr int MaximumFillValue = 23; + return buffer; + } - int FillValue = MinimumFillValue; - T* FillAddress = buffer; + T* GetBuffer(size_t Elements, bool ZeroFill = false) { + if (ZeroFill) { + return GetFilledBuffer( + Elements, + [](T* start, size_t size) { + std::fill_n(start, size, T(0)); + }); + } - while (FillAddress < GuardAddress) { - *FillAddress++ = (T)FillValue; + return GetFilledBuffer( + Elements, + [](T* start, size_t size) { + constexpr int MinimumFillValue = -23; + constexpr int MaximumFillValue = 23; - FillValue++; + int FillValue = MinimumFillValue; + T* FillAddress = start; + for (size_t i = 0; i < size; i++) { + *FillAddress++ = (T)FillValue; - if (FillValue > MaximumFillValue) { - FillValue = MinimumFillValue; - } - } - } + FillValue++; - return buffer; + if (FillValue > MaximumFillValue) { + FillValue = MinimumFillValue; + } + } + }); } void ReleaseBuffer(void) {