Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cfu fp16 #14538

Merged
merged 19 commits into from
Feb 15, 2023
7 changes: 7 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/platform.cpp
${MLAS_SRC_DIR}/threading.cpp
${MLAS_SRC_DIR}/sgemm.cpp
${MLAS_SRC_DIR}/halfgemm.cpp
${MLAS_SRC_DIR}/qgemm.cpp
${MLAS_SRC_DIR}/qdwconv.cpp
${MLAS_SRC_DIR}/convolve.cpp
Expand Down Expand Up @@ -59,6 +60,7 @@ function(setup_mlas_source_for_windows)

if(onnxruntime_target_platform STREQUAL "ARM64")
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
Expand All @@ -73,6 +75,7 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm
${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm
Expand Down Expand Up @@ -305,6 +308,7 @@ else()
${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S
Expand All @@ -314,10 +318,13 @@ else()
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs})
set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64")
Expand Down
38 changes: 38 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ void CPUIDInfo::ArmLinuxInit() {
if (pytorch_cpuinfo_init_) {
is_hybrid_ = cpuinfo_get_uarchs_count() > 1;
has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot();
has_fp16_ = cpuinfo_has_arm_neon_fp16_arith();
const uint32_t core_cnt = cpuinfo_get_cores_count();
core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown);
is_armv8_narrow_ld_.resize(core_cnt, false);
Expand All @@ -165,6 +166,7 @@ void CPUIDInfo::ArmLinuxInit() {
}
} else {
has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0);
has_fp16_ |= has_arm_neon_dot_;
}
}

Expand Down Expand Up @@ -220,9 +222,45 @@ void CPUIDInfo::ArmWindowsInit() {
lastUarch = uarch;
}
}

switch (lastUarch) {
case cpuinfo_uarch_cortex_a55:
case cpuinfo_uarch_cortex_a55r0:
case cpuinfo_uarch_cortex_a76:
case cpuinfo_uarch_neoverse_n1:
case cpuinfo_uarch_cortex_a77:
case cpuinfo_uarch_exynos_m4:
case cpuinfo_uarch_exynos_m5:
has_fp16_ = true;
break;
default:
break;
}
if (!has_fp16_) {
/*
* Detecting fp16 support. Different cores should have the same instruction set.
* So we just check the first ID_AA64PFR0_EL1
* Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000),
*/
uint64_t ID_AA64PFR0_EL1;
unsigned long valsize = sizeof(uint64_t);
auto retCode = ::RegGetValueA(
HKEY_LOCAL_MACHINE,
"HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0",
"CP 4020", RRF_RT_REG_QWORD, nullptr,
&ID_AA64PFR0_EL1, &valsize);
if (retCode == ERROR_SUCCESS) {
// AdvSIMD, bits [23:20]
auto advSimd = ID_AA64PFR0_EL1 >> 20;
if ((advSimd & 0xfULL) == 1) {
has_fp16_ = true;
}
}
}
#endif /* Application Family or OneCore Family */

has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0);
has_fp16_ |= has_arm_neon_dot_;
}

#endif /* (arm or arm64) and windows */
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class CPUIDInfo {
bool HasAVX512f() const { return has_avx512f_; }
bool HasAVX512_BF16() const {return has_avx512_bf16_;}
bool HasAVX512Skylake() const { return has_avx512_skylake_; }
bool HasF16C() const { return has_f16c_; }
bool HasF16C() const { return has_f16c_; } /*fp16 conversion inst*/
bool HasSSE3() const { return has_sse3_; }
bool HasSSE4_1() const { return has_sse4_1_; }
bool IsHybrid() const { return is_hybrid_; }
Expand Down Expand Up @@ -85,6 +85,9 @@ class CPUIDInfo {
return is_armv8_narrow_ld_[coreIdx];
}

// Returns true when CPU feature detection found support for half precision
// (fp16) vector arithmetic; set from cpuinfo / HWCAP on Linux and from the
// uarch table or ID_AA64PFR0_EL1 registry value on Windows (see cpuid_info.cc).
bool HasFp16VectorAcceleration() const {
return has_fp16_;
}

private:
CPUIDInfo() {
Expand Down Expand Up @@ -118,6 +121,7 @@ class CPUIDInfo {
std::vector<bool> is_armv8_narrow_ld_;

bool has_arm_neon_dot_{false};
bool has_fp16_{false};

#ifdef CPUIDINFO_ARCH_X86

Expand Down
178 changes: 175 additions & 3 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ typedef enum { CblasLeft=141, CblasRight=142} CBLAS_SIDE;
#endif

//
// Forward declare the thread pool implementation class.
// Forward declare the thread pool implementation class and half precision floating point.
//
// N.B. Avoid including ONNX Runtime headers here to keep the dependencies for
// standalone MLAS test executables smaller.
Expand All @@ -100,10 +100,12 @@ namespace onnxruntime {
namespace concurrency {
class ThreadPool;
};
};
struct MLFloat16;
}; // namespace onnxruntime

using MLAS_THREADPOOL = onnxruntime::concurrency::ThreadPool;


//
// Platform routines.
//
Expand Down Expand Up @@ -613,7 +615,7 @@ MlasGemm(
// Currently only supported in ARM64
//
#if defined(MLAS_TARGET_ARM64)
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 15;
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 30;
#else
constexpr size_t MLAS_SYMM_QGEMM_BUF_OVERRUN = 0;
#endif
Expand Down Expand Up @@ -1366,3 +1368,173 @@ MlasQLinearMul(
size_t N,
bool IsScalarB
);

//
// Half precision routines
//

// Storage type for half precision elements. Any type with size=2 should work;
// onnxruntime::MLFloat16 (forward declared above) is used to avoid pulling in
// ONNX Runtime headers here.
using MLAS_FP16 = onnxruntime::MLFloat16;

// Size in bytes of a single half precision element.
constexpr size_t FP16_SIZE = sizeof(uint16_t);


/**
 * @brief Whether the current CPU supports accelerated half precision GEMM.
 *        When false, callers should fall back to a single precision path.
 *        NOTE(review): exact detection criteria live in the platform code,
 *        not visible in this header — confirm there.
 */
bool MLASCALL
MlasFp16AccelerationSupported();

/**
 * @brief Interface for half gemm post processors.
 *
 * Example implementations of this interface include activations,
 * conversion from half precision to single precision, etc.
 *
 * Half GEMM is computed tile by tile. When a tile of the result matrix
 * is produced, the method Process() is called to process this tile.
 * Parameters of this method describe the location and shape of the
 * tile within the result matrix C.
 */
class MLAS_HALF_GEMM_POSTPROCESSOR {
public:
virtual
void
Process(
MLAS_FP16*, /**< the address of matrix to process */
size_t, /**< the start row index of matrix */
size_t, /**< the start col index of matrix */
size_t, /**< the element count per row to process */
size_t, /**< the element count per col to process */
size_t /**< the leading dimension of matrix */
) const = 0;

// Virtual destructor so derived processors can be destroyed safely
// through a pointer to this base class.
virtual ~MLAS_HALF_GEMM_POSTPROCESSOR() {}
};

/**
 * @brief Post processor that converts a half gemm result tile to a single
 *        precision float matrix, writing into a caller-owned row-major buffer.
 *
 * The Process() override is declared here and defined in the MLAS
 * implementation, not in this header.
 */
class MLAS_HALF_GEMM_2FLOAT_PROCESSOR : public MLAS_HALF_GEMM_POSTPROCESSOR {
public:
MLAS_HALF_GEMM_2FLOAT_PROCESSOR(
float* Output, /**< address of the output matrix, row major */
size_t RowStride /**< row stride of the output matrix */
) :
Output_(Output),
RowStride_(RowStride)
{}

void
Process(
MLAS_FP16* C,
size_t StartM,
size_t StartN,
size_t CountM,
size_t CountN,
size_t ldc
) const override;

private:
float* Output_; // destination buffer, row major; not owned by this object
size_t RowStride_; // elements per row of the destination buffer
};


/**
 * @brief Data parameters for half precision GEMM routine.
 *        All except C are [in] parameters.
 *
 * A and B are typed void* because each may point to either fp16 or fp32
 * data, as selected by the AIsfp32 / BIsfp32 flags below.
 */
struct MLAS_HALF_GEMM_DATA_PARAMS {
const void* A = nullptr; /**< address of A (fp16, or fp32 when AIsfp32) */
const void* B = nullptr; /**< address of B (fp16, or fp32 when BIsfp32) */
const MLAS_FP16* Bias = nullptr; /**< address of Bias, vector size N */
MLAS_FP16* C = nullptr; /**< address of result matrix */
size_t lda = 0; /**< leading dimension of A */
size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/
size_t ldc = 0; /**< leading dimension of C*/
const MLAS_HALF_GEMM_POSTPROCESSOR* OutputProcessor = nullptr; /**< optional result tile post processor */
bool AIsfp32 = false; /**< matrix A is fp32, needs to be cast into fp16*/
bool BIsfp32 = false; /**< matrix B is fp32, needs to be cast into fp16*/
};

/**
 * @brief Half precision Batched GEMM: C = A * B + Bias
 * Either A or B can be fp32 or fp16
 *
 * Note: We only support uniform batching, so shapes and types of the
 * input must be same across all parameter blocks.
 *
 * @param[in] M row size of matrix A and C
 * @param[in] N column size of matrix B and C
 * @param[in] K column size of matrix A and row size of matrix B
 * @param[in] BatchN number of batches
 * @param[inout] DataParams An array (size BatchN) of parameter blocks
 * @param[in] ThreadPool thread pool used to parallelize the computation;
 *            nullptr presumably runs on the calling thread — confirm in
 *            the implementation
 */
void
MLASCALL
MlasHalfGemmBatch(
const size_t M,
const size_t N,
const size_t K,
const size_t BatchN,
const MLAS_HALF_GEMM_DATA_PARAMS* DataParams,
MLAS_THREADPOOL* ThreadPool = nullptr
);

/**
 * @brief For half precision GEMM, returns size of the
 *        packing buffer needed for right hand side
 *
 * The returned size is used to allocate the PackedB buffer passed to
 * MlasHalfGemmPackB (float2half == false) or
 * MlasHalfGemmConvertPackB (float2half == true).
 *
 * @param[in] N Number of columns
 * @param[in] K Number of rows
 * @param[in] float2half Whether the input is float that
 *                       needs to be converted to half precision
 * @return size of the packing buffer in bytes,
 *         0 if operation not supported
 */
size_t
MLASCALL
MlasHalfGemmPackBSize(
size_t N,
size_t K,
bool float2half
);

/**
 * @brief For half precision GEMM, pack the right hand
 *        side matrix B (already in half precision)
 *
 * @param[in] N Number of columns
 * @param[in] K Number of rows
 * @param[in] B Address of matrix B
 * @param[in] ldb leading dimension of input matrix B
 * @param[out] PackedB Address of the packed matrix; caller-allocated
 *             buffer of the size returned by MlasHalfGemmPackBSize
 */
void
MLASCALL
MlasHalfGemmPackB(
size_t N,
size_t K,
const MLAS_FP16* B,
size_t ldb,
void* PackedB
);

/**
 * @brief For half precision GEMM, convert the float matrix B
 *        to half precision and pack it into a packing buffer
 *
 * @param[in] N Number of columns
 * @param[in] K Number of rows
 * @param[in] B Address of matrix B (single precision)
 * @param[in] ldb leading dimension of input matrix B
 * @param[out] PackedB Address of the packed matrix; caller-allocated
 *             buffer of the size returned by MlasHalfGemmPackBSize
 */
void
MLASCALL
MlasHalfGemmConvertPackB(
size_t N,
size_t K,
const float* B,
size_t ldb,
void* PackedB
);
Loading