From 13348c572a315806c7745280073aacdf384dcea9 Mon Sep 17 00:00:00 2001
From: Jing Fang <126209182+fajin-corp@users.noreply.github.com>
Date: Fri, 24 Jan 2025 15:25:24 -0800
Subject: [PATCH 1/2] [ARM CPU] hgemm optimized for gqa (#23107)

### Description
Add fp16 kernels for the GQA matmul on ARM CPU. The kernels implement MLAS hgemm for
C = alpha * A x B' + beta * C.

### Motivation and Context
Add fp16 support for GQA to speed up the operator and reduce memory usage.

__Token Generation__

| | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------------------|--------------------|--------------------|--------------|
| M:1/N:4096/K:4096 | 251551 | 1775905 | 85.84 |
| M:1/N:11008/K:4096 | 892507 | 4649145 | 80.80 |
| M:1/N:4096/K:11008 | 866860 | 3240015 | 73.25 |
| M:1/N:11008/K:11008 | 2631615 | 8783877 | 70.04 |

__Prompting__

| | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------------------|--------------------|--------------------|--------------|
| M:1024/N:4096/K:4096 | 90508701 | 111283029 | 18.67 |
| M:2048/N:4096/K:4096 | 181307522 | 240211107 | 24.52 |
| M:1024/N:11008/K:4096 | 241120234 | 307707933 | 21.64 |
| M:2048/N:11008/K:4096 | 481091232 | 648921367 | 25.86 |
| M:1024/N:4096/K:11008 | 241736343 | 310129880 | 22.05 |
| M:2048/N:4096/K:11008 | 480456703 | 644814999 | 25.49 |
| M:1024/N:11008/K:11008 | 642121440 | 847925766 | 24.27 |
| M:2048/N:11008/K:11008 | 1276097154 | 1731314509 | 26.29 |
---
 cmake/onnxruntime_mlas.cmake                  |    5 +
 .../contrib_ops/cpu/bert/gqa_attention_base.h |    6 +
 onnxruntime/core/mlas/inc/mlas.h              |  102 +-
 onnxruntime/core/mlas/lib/fp16_common.h       |   99 ++
 onnxruntime/core/mlas/lib/halfgemm.cpp        |  170 ++
 onnxruntime/core/mlas/lib/halfgemm.h          |  122 ++
 .../mlas/lib/halfgemm_kernel_neon_fp16.cpp    | 1572 +++++++++++++++++
 .../core/mlas/lib/hgemm_kernel_neon.cpp       |   28 +
 .../mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp  |   33 -
 onnxruntime/core/mlas/lib/mlasi.h             |   11 +
 onnxruntime/core/mlas/lib/platform.cpp        |    1 +
 onnxruntime/test/mlas/bench/bench_hgemm.cpp   |   86 +
 .../test/mlas/unittest/test_hgemm_neon.cpp    |  393 +++++
 13 files changed, 2594 insertions(+), 34 deletions(-)
 create mode 100644 onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp
 create mode 100644 onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp
 create mode 100644 onnxruntime/test/mlas/bench/bench_hgemm.cpp
 create mode 100644 onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp
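As an illustrative sketch of the new API (the shapes, buffers and thread pool below are
hypothetical placeholders, not part of the patch), a caller is expected to guard the fp16
path behind the capability check and fall back to the fp32 path otherwise:

```cpp
// A, B, C are MLAS_FP16 buffers; B is stored transposed (N x K).
// alpha and beta are passed as fp16 bit patterns via MLAS_FP16(...).val.
if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
    MlasGemm(CblasNoTrans, CblasTrans, M, N, K,
             A, /*lda=*/K, B, /*ldb=*/K, C, /*ldc=*/N,
             MLAS_FP16(1.0f).val, MLAS_FP16(0.0f).val, thread_pool);
} else {
    // convert to fp32 and use the existing SGEMM path
}
```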
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 5124262ec0004..ed3ad89247975 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -95,6 +95,8 @@ function(setup_mlas_source_for_windows)
     ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
     ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
     ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
+    ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
+    ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
   )

   set(mlas_platform_preprocess_srcs
@@ -374,6 +376,7 @@ else()
           ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
           ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
           ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
+          ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
         )
         set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -394,6 +397,7 @@ else()
           ${MLAS_SRC_DIR}/cast_kernel_neon.cpp
           ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
           ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
+          ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
         )
         set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
         set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -406,6 +410,7 @@ else()
         set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
         set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
         set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+        set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
       endif()

 if(ONNXRUNTIME_MLAS_MULTI_ARCH)
diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
index ccaeb6654e286..abb24e20a6178 100644
--- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -75,6 +75,7 @@ class GQAAttentionBase {
     int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

     // Compute the attention score.
+    // TODO(fajin): type depends on kernel supportability
     size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
     auto attention_probs = allocator->Alloc(bytes);
     BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -198,6 +199,11 @@ class GQAAttentionBase {
       math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
                                       static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/, output,
                                       static_cast<int>(present_buffer_sequence_length), nullptr);
+      // TODO(fajin): update later
+      // } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
+      //   MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
+      //            q, static_cast<size_t>(head_size), k, static_cast<size_t>(head_size), output,
+      //            static_cast<size_t>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
     } else {
       size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
       auto q_k_fp32 = allocator->Alloc(bytes);
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 207c058d899b4..7e0335cc66ef0 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1458,7 +1458,107 @@ MlasRotaryEmbedOneRow(
     T* output
 );

- /**
+/**
+ * @brief Supplies matrix data information to the half precision gemm functions
+ */
+struct MLAS_HGEMM_DATA_PARAMS {
+    const MLAS_FP16* A; /**< Supplies the address of matrix A */
+    size_t lda;         /**< Supplies the first dimension of matrix A. */
+    const MLAS_FP16* B; /**< Supplies the address of matrix B */
+    size_t ldb;         /**< Supplies the first dimension of matrix B. */
+    MLAS_FP16* C;       /**< Supplies the address of matrix C */
+    size_t ldc;         /**< Supplies the first dimension of matrix C. */
+    uint16_t alpha;     /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */
+    uint16_t beta;      /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */
+};
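+
+// Note: alpha and beta are IEEE fp16 bit patterns rather than floats; for
+// example, MLAS_FP16(1.0f).val == 0x3C00 and MLAS_FP16(0.0f).val == 0x0000.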
+
+/**
+ * @brief Check whether the current CPU supports half precision gemm.
+ */
+bool
+MLASCALL
+MlasHGemmSupported(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB
+    );
+
+/**
+ * @brief Batched half precision matrix/matrix multiply operation (HGEMM)
+ *
+ * @param TransA     Supplies the transpose operation for matrix A.
+ * @param TransB     Supplies the transpose operation for matrix B.
+ * @param M          Supplies the number of rows of matrix A and matrix C.
+ * @param N          Supplies the number of columns of matrix B and matrix C.
+ * @param K          Supplies the number of columns of matrix A and the number of rows of matrix B.
+ * @param Data       An array of matrix data parameters.
+ * @param BatchSize  Supplies the number of multiplications in this batch.
+ * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
+                     base library threading support should be used.
+ */
+void
+MLASCALL
+MlasGemmBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_HGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+    );
+
+/**
+ * @brief Half precision matrix/matrix multiply operation (HGEMM)
+ *        C = alpha * op(A) * op(B) + beta * C
+ *
+ * @param TransA     Supplies the transpose operation for matrix A. Currently only CblasNoTrans is supported.
+ * @param TransB     Supplies the transpose operation for matrix B. Currently only CblasTrans is supported.
+ * @param M          Supplies the number of rows of matrix A and matrix C.
+ * @param N          Supplies the number of columns of matrix B and matrix C.
+ * @param K          Supplies the number of columns of matrix A and the number of rows of matrix B.
+ * @param A          Supplies the address of matrix A.
+ * @param lda        Supplies the first dimension of matrix A.
+ * @param B          Supplies the address of matrix B.
+ * @param ldb        Supplies the first dimension of matrix B.
+ * @param C          Supplies the address of matrix C.
+ * @param ldc        Supplies the first dimension of matrix C.
+ * @param alpha      Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding.
+ * @param beta       Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding.
+ * @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
+ *                   should be used.
+ */
+inline
+void
+MlasGemm(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_FP16* A,
+    size_t lda,
+    const MLAS_FP16* B,
+    size_t ldb,
+    MLAS_FP16* C,
+    size_t ldc,
+    uint16_t alpha,
+    uint16_t beta,
+    MLAS_THREADPOOL* ThreadPool
+) {
+    MLAS_HGEMM_DATA_PARAMS Data;
+    Data.A = A;
+    Data.lda = lda;
+    Data.B = B;
+    Data.ldb = ldb;
+    Data.C = C;
+    Data.ldc = ldc;
+    Data.alpha = alpha;
+    Data.beta = beta;
+    MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
+}
+
+/**
  * @brief Whether current CPU supports FP16 acceleration.
*/ bool MLASCALL diff --git a/onnxruntime/core/mlas/lib/fp16_common.h b/onnxruntime/core/mlas/lib/fp16_common.h index f4c49905ebbd7..acee567162b9d 100644 --- a/onnxruntime/core/mlas/lib/fp16_common.h +++ b/onnxruntime/core/mlas/lib/fp16_common.h @@ -349,4 +349,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT return vbsl_f16(select, ones, zeros); } +MLAS_FORCEINLINE +void +Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3, + MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // |v40|v41|v42|v43|v44|v45|v46|v47| + // |v50|v51|v52|v53|v54|v55|v56|v57| + // |v60|v61|v62|v63|v64|v65|v66|v67| + // |v70|v71|v72|v73|v74|v75|v76|v77| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + float16x8x2_t t45 = vtrnq_f16(v4, v5); + float16x8x2_t t67 = vtrnq_f16(v6, v7); + // |v00|v10|v02|v12|v04|v14|v06|v16| + // |v01|v11|v03|v13|v05|v15|v07|v17| + // |v20|v30|v22|v32|v24|v34|v26|v36| + // |v21|v31|v23|v33|v25|v35|v27|v37| + // |v40|v50|v42|v52|v44|v54|v46|v56| + // |v41|v51|v43|v53|v45|v55|v47|v57| + // |v60|v70|v62|v72|v64|v74|v66|v76| + // |v61|v71|v63|v73|v65|v75|v67|v77| + float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])); + float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])); + float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0])); + float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1])); + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + // |v40|v50|v60|v70|v44|v54|v64|v74| + // |v41|v51|v61|v71|v45|v55|v65|v75| + // |v42|v52|v62|v72|v46|v56|v66|v76| + // |v43|v53|v63|v73|v47|v57|v67|v77| + v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + // |v00|v10|v20|v30|v40|v50|v60|v70| + // |v01|v11|v21|v31|v41|v51|v61|v71| + // |v02|v12|v22|v32|v42|v52|v62|v72| + // |v03|v13|v23|v33|v43|v53|v63|v73| + // |v04|v14|v24|v34|v44|v54|v64|v74| + // |v05|v15|v25|v35|v45|v55|v65|v75| + // |v06|v16|v26|v36|v46|v56|v66|v76| + // |v07|v17|v27|v37|v47|v57|v67|v77| +} + +MLAS_FORCEINLINE +void +Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| 
+ // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE +void +Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3) +{ + // |v00|v01|v02|v03| + // |v10|v11|v12|v13| + // |v20|v21|v22|v23| + // |v30|v31|v32|v33| + // => + // |v00|v10|v20|v30| + // |v01|v11|v21|v31| + // |v02|v12|v22|v32| + // |v03|v13|v23|v33| + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + #endif // fp16 vector intrinsic supported diff --git a/onnxruntime/core/mlas/lib/halfgemm.cpp b/onnxruntime/core/mlas/lib/halfgemm.cpp index 49387d2fc998f..65ab0e9ce4630 100644 --- a/onnxruntime/core/mlas/lib/halfgemm.cpp +++ b/onnxruntime/core/mlas/lib/halfgemm.cpp @@ -324,6 +324,176 @@ MlasHalfGemmKernel( } } +bool +MLASCALL +MlasHGemmSupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + auto* dispatch = GetMlasPlatform().HGemmDispatch; + if (TransA == CblasNoTrans && TransB == CblasTrans) { + return dispatch && + dispatch->HGemmKernel_TransposedB && + dispatch->HPackBKernel_TransposedB && + dispatch->HGemmKernel_TransposedPackedB; + } + + return false; +} + +void +HGemmOperation( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t K, // full K slice + const MLAS_HGEMM_DATA_PARAMS* DataParams, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) { + const size_t lda = DataParams->lda; + const size_t ldb = DataParams->ldb; + const size_t ldc = DataParams->ldc; + const _mlas_fp16_ alpha = DataParams->alpha; + const _mlas_fp16_ beta = DataParams->beta; + auto* dispatch = GetMlasPlatform().HGemmDispatch; + constexpr size_t StrideM = 2; + const auto beta_add = MLAS_FP16(1.0f); + constexpr size_t buffer_size = MLAS_HGEMM_STRIDEN * MLAS_HGEMM_STRIDEK; + MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], 16 * sizeof(_mlas_fp16_)); + + if (TransA == CblasNoTrans && TransB == CblasTrans) { + const auto* A = DataParams->A + RangeStartM * lda; + const auto* B = DataParams->B + RangeStartN * ldb; + auto* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM <= StrideM) { + if (!dispatch || !dispatch->HGemmKernel_TransposedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // When M is small, B is visited once. The overhead of Pack(B') exceeds the benefits + // from A x Pack(B'). 
Therefore, calculate A x B' directly.
            // Without PackB, to utilize memory locality, iterate over the full K.
            constexpr size_t StrideN = 16;
            for (size_t n = 0, countN; n < RangeCountN; n += countN) {
                countN = std::min(StrideN, RangeCountN - n);
                dispatch->HGemmKernel_TransposedB(A, B, C, RangeCountM, countN, K, lda, ldb, ldc, alpha, beta);
                B += countN * ldb;
                C += countN;
            }
        } else {
            if (!dispatch || !dispatch->HPackBKernel_TransposedB || !dispatch->HGemmKernel_TransposedPackedB) {
                MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels");
            }
            // 16N is the smallest pack unit.
            const size_t StrideK = std::min(K, size_t(MLAS_HGEMM_STRIDEK));
            const size_t StrideN = buffer_size / StrideK & (~15); // >= MLAS_HGEMM_STRIDEN
            for (size_t n = 0, countN; n < RangeCountN; n += countN) {
                countN = std::min(StrideN, RangeCountN - n);
                const MLAS_FP16* a = A;
                const MLAS_FP16* b = B;
                MLAS_FP16* c = C;
                for (size_t k = 0, countK; k < K; k += countK) {
                    countK = std::min(StrideK, K - k);
                    dispatch->HPackBKernel_TransposedB(b, PackedB, countN, countK, ldb);
                    const MLAS_FP16* aa = a;
                    MLAS_FP16* cc = c;
                    for (size_t m = 0, countM; m < RangeCountM; m += countM) {
                        countM = std::min(StrideM, RangeCountM - m);
                        // In the first K iteration, beta is applied to the whole C tile. In the
                        // remaining K iterations, use add mode (beta = 1) to accumulate.
                        dispatch->HGemmKernel_TransposedPackedB(
                            aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? beta : beta_add.val);
                        aa += countM * lda;
                        cc += countM * ldc;
                    }
                    a += countK;
                    b += countK;
                }
                B += countN * ldb;
                C += countN;
            }
        }
    } else {
        MLAS_THROW_EX(std::runtime_error, "hgemm currently only supports A x Transpose(B)");
    }
}

void
MLASCALL
MlasGemmBatch(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_HGEMM_DATA_PARAMS* Data,
    size_t BatchSize,
    MLAS_THREADPOOL* ThreadPool
) {
    if (!ThreadPool) {
        for (size_t gemm_i = 0; gemm_i < BatchSize; gemm_i++) {
            HGemmOperation(TransA, TransB, K, &Data[gemm_i], 0, M, 0, N);
        }
        return;
    }

    const double Complexity = double(M) * double(N) * double(K) * double(BatchSize);
    ptrdiff_t TargetThreadCount;

    if (Complexity < double(MLAS_HGEMM_THREAD_COMPLEXITY) * GetMlasPlatform().MaximumThreadCount) {
        TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_HGEMM_THREAD_COMPLEXITY)) + 1;
    } else {
        TargetThreadCount = GetMlasPlatform().MaximumThreadCount;
    }

    ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
    if (TargetThreadCount >= MaximumThreadCount) {
        TargetThreadCount = MaximumThreadCount;
    }

    // Segment the operation across multiple threads.
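+    //
+    // Worked example (hypothetical shapes): M=1024, N=4096, BatchSize=1 and a
+    // target of 8 threads give ThreadsPerGemm = 8. With StrideM = 128,
+    // BlockedM = ceil(1024 / 128) = 8 and max_nc = ceil(4096 * 8 / 8) = 4096,
+    // so StrideN stays 4096: the GEMM splits into 8 tasks of 128 rows, each
+    // covering the full N range.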
+
+    ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchSize;
+    if (ThreadsPerGemm < 1) {
+        ThreadsPerGemm = 1;
+    }
+
+    constexpr size_t StrideM = 128;
+
+    size_t nc = N;
+    if (ThreadsPerGemm > 1) {
+        // more than one thread per GEMM
+
+        const size_t BlockedM = MlasDivRoundup(M, StrideM);
+        const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm);
+        if (max_nc < nc) {
+            nc = std::min(
+                nc, MlasDivRoundup(max_nc, MLAS_HGEMM_STRIDEN_THREAD_ALIGN) * MLAS_HGEMM_STRIDEN_THREAD_ALIGN);
+        }
+    }
+    const size_t StrideN = nc;
+
+    const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
+    const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
+    ThreadsPerGemm = ThreadCountM * ThreadCountN;
+
+    MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * static_cast<ptrdiff_t>(BatchSize), [&](ptrdiff_t tid) {
+        const auto gemm_i = tid / ThreadsPerGemm;
+        const auto blk_i = tid % ThreadsPerGemm;
+
+        const ptrdiff_t ThreadIdN = blk_i / ThreadCountM;
+        const ptrdiff_t ThreadIdM = blk_i % ThreadCountM;
+
+        const size_t RangeStartM = ThreadIdM * StrideM;
+        const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM);
+
+        const size_t RangeStartN = ThreadIdN * StrideN;
+        const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN);
+
+        HGemmOperation(TransA, TransB, K, &Data[gemm_i], RangeStartM, RangeCountM, RangeStartN, RangeCountN);
+    });
+}

 const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = {
     MlasHalfGemmOperation,
diff --git a/onnxruntime/core/mlas/lib/halfgemm.h b/onnxruntime/core/mlas/lib/halfgemm.h
index 61e2fbb0afc6a..e280e6d40973f 100644
--- a/onnxruntime/core/mlas/lib/halfgemm.h
+++ b/onnxruntime/core/mlas/lib/halfgemm.h
@@ -513,3 +513,125 @@ MlasHalfGemmGetDispatch()
     return &MlasHalfGemmDispatchDefault;
 #endif
 }
+
+namespace hgemm_neon {
+
+void HPackB_TransposedB_Kernel(
+    const MLAS_FP16* B,
+    MLAS_FP16* PackedB,
+    size_t CountN,
+    size_t CountK,
+    size_t ldb
+);
+
+void HGemm_TransposedB_Kernel(
+    const MLAS_FP16* A,
+    const MLAS_FP16* B,
+    MLAS_FP16* C,
+    size_t CountM,
+    size_t CountN,
+    size_t CountK,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    _mlas_fp16_ alpha,
+    _mlas_fp16_ beta
+);
+
+void HGemm_TransposedPackedB_Kernel(
+    const MLAS_FP16* A,
+    const MLAS_FP16* PackedB,
+    MLAS_FP16* C,
+    size_t CountM,
+    size_t CountN,
+    size_t CountK,
+    size_t lda,
+    size_t ldc,
+    _mlas_fp16_ alpha,
+    _mlas_fp16_ beta
+);
+
+}  // namespace hgemm_neon
+
+struct MLAS_HGEMM_DISPATCH {
+    /**
+     * @brief Pack the B matrix segment. B is column-major. Elements from CountK rows x CountN columns are packed
+     *        contiguously in row-major order.
+     *        First pack CountK rows x 16 columns, then pack CountK rows x 8 columns.
+     *        If there are < 8 columns left, pad the columns with 0.
+     * @param B the first element of the B matrix segment. Column major.
+     * @param[out] PackedB the first element of the packed B matrix segment.
+     * @param CountN the number of columns of the B chunk.
+     * @param CountK the number of rows of the B chunk.
+     * @param ldb the leading dimension of B.
+     */
+    typedef void(HPackBKernel_TransposedB_Fn) (
+        const MLAS_FP16* B,
+        MLAS_FP16* PackedB,
+        size_t CountN,
+        size_t CountK,
+        size_t ldb
+    );
+
+    HPackBKernel_TransposedB_Fn* HPackBKernel_TransposedB = nullptr;
+
+    /**
+     * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B is not packed. Used when M is small.
+     *
+     * @param A first row of the A matrix segment. Row major.
+     * @param B first column of the B matrix segment. Column major.
+     * @param[out] C first element of the output matrix segment. Row major.
+     * @param CountM the number of rows of the A chunk.
+     * @param CountN the number of columns of the B chunk.
+     * @param CountK the number of columns of the A chunk and the number of rows of the B chunk.
+     * @param lda the leading dimension of A.
+     * @param ldb the leading dimension of B.
+     * @param ldc the leading dimension of C.
+     * @param alpha the alpha scalar value.
+     * @param beta the beta scalar value.
+     */
+    typedef void(HGemmKernel_TransposedB_Fn)(
+        const MLAS_FP16* A,
+        const MLAS_FP16* B,
+        MLAS_FP16* C,
+        size_t CountM,
+        size_t CountN,
+        size_t CountK,
+        size_t lda,
+        size_t ldb,
+        size_t ldc,
+        _mlas_fp16_ alpha,
+        _mlas_fp16_ beta
+    );
+
+    HGemmKernel_TransposedB_Fn* HGemmKernel_TransposedB = nullptr;
+
+    /**
+     * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B has been packed using
+     *        HPackBKernel_TransposedB_Fn. Used when M is large.
+     *
+     * @param A first row of the A matrix segment. Row major.
+     * @param PackedB first element of the packed B buffer.
+     * @param[out] C first element of the output matrix segment. Row major.
+     * @param CountM the number of rows of the A chunk.
+     * @param CountN the number of columns of the B chunk.
+     * @param CountK the number of columns of the A chunk and the number of rows of the B chunk.
+     * @param lda the leading dimension of A.
+     * @param ldc the leading dimension of C.
+     * @param alpha the alpha scalar value.
+     * @param beta the beta scalar value.
+     */
+    typedef void(HGemmKernel_TransposedPackedB_Fn)(
+        const MLAS_FP16* A,
+        const MLAS_FP16* PackedB,
+        MLAS_FP16* C,
+        size_t CountM,
+        size_t CountN,
+        size_t CountK,
+        size_t lda,
+        size_t ldc,
+        _mlas_fp16_ alpha,
+        _mlas_fp16_ beta
+    );
+
+    HGemmKernel_TransposedPackedB_Fn* HGemmKernel_TransposedPackedB = nullptr;
+};
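+
+// A platform kernel file is expected to point this dispatch at its kernels.
+// A sketch for the NEON build (the variable name here is illustrative, not
+// part of this header):
+//
+//   const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = []() {
+//       MLAS_HGEMM_DISPATCH d;
+//       d.HPackBKernel_TransposedB = hgemm_neon::HPackB_TransposedB_Kernel;
+//       d.HGemmKernel_TransposedB = hgemm_neon::HGemm_TransposedB_Kernel;
+//       d.HGemmKernel_TransposedPackedB = hgemm_neon::HGemm_TransposedPackedB_Kernel;
+//       return d;
+//   }();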
diff --git a/onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp
new file mode 100644
index 0000000000000..02ce38fcb21d6
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/halfgemm_kernel_neon_fp16.cpp
@@ -0,0 +1,1572 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    halfgemm_kernel_neon_fp16.cpp
+
+Abstract:
+
+    This module implements the half precision GEMM kernels for NEON.
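+
+    The packed-B layout produced by HPackB_TransposedB_Kernel groups columns
+    16 (then 8) at a time; within a group, each k step stores one fp16 value
+    per column contiguously, so the compute kernels can load whole vectors of
+    B with unit stride.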
+
+--*/
+
+#include <arm_neon.h>
+
+#include "halfgemm.h"
+#include "fp16_common.h"
+
+namespace hgemm_neon {
+
+void HPackB_TransposedB_Kernel(
+    const MLAS_FP16* B,
+    MLAS_FP16* PackedB,
+    size_t CountN,
+    size_t CountK,
+    size_t ldb
+) {
+    const _mlas_fp16_* B_data = reinterpret_cast<const _mlas_fp16_*>(B);
+    _mlas_fp16_* PackedB_data = reinterpret_cast<_mlas_fp16_*>(PackedB);
+
+    for (; CountN >= 16; CountN -= 16, B_data += 16 * ldb) {
+        const _mlas_fp16_* b = B_data;
+        size_t k = CountK;
+        constexpr size_t step = 8 * 16; // pack 8 * 16
+        for (; k >= 8; k -= 8, b += 8, PackedB_data += step) {
+            float16x8_t v0 = MlasLoadFloat16x8(b);
+            float16x8_t v1 = MlasLoadFloat16x8(b + ldb);
+            float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb);
+            float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb);
+            float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb);
+            float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb);
+            float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb);
+            float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb);
+            float16x8_t v8 = MlasLoadFloat16x8(b + 8 * ldb);
+            float16x8_t v9 = MlasLoadFloat16x8(b + 9 * ldb);
+            float16x8_t vA = MlasLoadFloat16x8(b + 10 * ldb);
+            float16x8_t vB = MlasLoadFloat16x8(b + 11 * ldb);
+            float16x8_t vC = MlasLoadFloat16x8(b + 12 * ldb);
+            float16x8_t vD = MlasLoadFloat16x8(b + 13 * ldb);
+            float16x8_t vE = MlasLoadFloat16x8(b + 14 * ldb);
+            float16x8_t vF = MlasLoadFloat16x8(b + 15 * ldb);
+            Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7);
+            Transpose8x8(v8, v9, vA, vB, vC, vD, vE, vF);
+
+            MlasStoreFloat16x8(PackedB_data, v0);
+            MlasStoreFloat16x8(PackedB_data + 8, v8);
+            MlasStoreFloat16x8(PackedB_data + 16, v1);
+            MlasStoreFloat16x8(PackedB_data + 24, v9);
+            MlasStoreFloat16x8(PackedB_data + 32, v2);
+            MlasStoreFloat16x8(PackedB_data + 40, vA);
+            MlasStoreFloat16x8(PackedB_data + 48, v3);
+            MlasStoreFloat16x8(PackedB_data + 56, vB);
+            MlasStoreFloat16x8(PackedB_data + 64, v4);
+            MlasStoreFloat16x8(PackedB_data + 72, vC);
+            MlasStoreFloat16x8(PackedB_data + 80, v5);
+            MlasStoreFloat16x8(PackedB_data + 88, vD);
+            MlasStoreFloat16x8(PackedB_data + 96, v6);
+            MlasStoreFloat16x8(PackedB_data + 104, vE);
+            MlasStoreFloat16x8(PackedB_data + 112, v7);
+            MlasStoreFloat16x8(PackedB_data + 120, vF);
+        }
+
+        if (k & 4) {
+            float16x4_t v0 = MlasLoadFloat16x4(b);
+            float16x4_t v1 = MlasLoadFloat16x4(b + ldb);
+            float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb);
+            float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb);
+            float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb);
+            float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb);
+            float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb);
+            float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb);
+            float16x4_t v8 = MlasLoadFloat16x4(b + 8 * ldb);
+            float16x4_t v9 = MlasLoadFloat16x4(b + 9 * ldb);
+            float16x4_t vA = MlasLoadFloat16x4(b + 10 * ldb);
+            float16x4_t vB = MlasLoadFloat16x4(b + 11 * ldb);
+            float16x4_t vC = MlasLoadFloat16x4(b + 12 * ldb);
+            float16x4_t vD = MlasLoadFloat16x4(b + 13 * ldb);
+            float16x4_t vE = MlasLoadFloat16x4(b + 14 * ldb);
+            float16x4_t vF = MlasLoadFloat16x4(b + 15 * ldb);
+            Transpose4x4(v0, v1, v2, v3);
+            Transpose4x4(v4, v5, v6, v7);
+            Transpose4x4(v8, v9, vA, vB);
+            Transpose4x4(vC, vD, vE, vF);
+            MlasStoreFloat16x4(PackedB_data, v0);
+            MlasStoreFloat16x4(PackedB_data + 4, v4);
+            MlasStoreFloat16x4(PackedB_data + 8, v8);
+            MlasStoreFloat16x4(PackedB_data + 12, vC);
+            MlasStoreFloat16x4(PackedB_data + 16, v1);
+            MlasStoreFloat16x4(PackedB_data + 20, v5);
+            MlasStoreFloat16x4(PackedB_data + 24, v9);
+            MlasStoreFloat16x4(PackedB_data + 28, vD);
+            MlasStoreFloat16x4(PackedB_data + 32, v2);
MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + MlasStoreFloat16x4(PackedB_data + 48, v3); + MlasStoreFloat16x4(PackedB_data + 52, v7); + MlasStoreFloat16x4(PackedB_data + 56, vB); + MlasStoreFloat16x4(PackedB_data + 60, vF); + + k -= 4, b += 4, PackedB_data += 4 * 16; + } + + if (k > 0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + float16x4_t v8 = MlasLoadPartialFloat16x4(b + 8 * ldb, k); + float16x4_t v9 = MlasLoadPartialFloat16x4(b + 9 * ldb, k); + float16x4_t vA = MlasLoadPartialFloat16x4(b + 10 * ldb, k); + float16x4_t vB = MlasLoadPartialFloat16x4(b + 11 * ldb, k); + float16x4_t vC = MlasLoadPartialFloat16x4(b + 12 * ldb, k); + float16x4_t vD = MlasLoadPartialFloat16x4(b + 13 * ldb, k); + float16x4_t vE = MlasLoadPartialFloat16x4(b + 14 * ldb, k); + float16x4_t vF = MlasLoadPartialFloat16x4(b + 15 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + if (k > 1) { + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + } + if (k > 2) { + MlasStoreFloat16x4(PackedB_data + 32, v2); + MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + } + + PackedB_data += k * 16; + } + } + + if (CountN & 8) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v1); + MlasStoreFloat16x8(PackedB_data + 16, v2); + MlasStoreFloat16x8(PackedB_data + 24, v3); + MlasStoreFloat16x8(PackedB_data + 32, v4); + MlasStoreFloat16x8(PackedB_data + 40, v5); + MlasStoreFloat16x8(PackedB_data + 48, v6); + MlasStoreFloat16x8(PackedB_data + 56, v7); + } + + if (k & 4) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + 
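+            // After the two 4x4 transposes, v0..v3 hold k rows 0..3 of columns 0-3
+            // and v4..v7 the same k rows of columns 4-7; the interleaved stores
+            // below therefore emit one k row (8 column values) per 8 elements.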
MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + MlasStoreFloat16x4(PackedB_data + 24, v3); + MlasStoreFloat16x4(PackedB_data + 28, v7); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (k > 0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + if (k > 1) { + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + } + if (k > 2) { + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + } + + PackedB_data += k * 8; + } + + B_data += 8 * ldb; + CountN -= 8; + } + + if (CountN > 0) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack extended 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x8(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x8(); + } + Transpose8x8(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); + MlasStoreFloat16x8(PackedB_data, v[0]); + MlasStoreFloat16x8(PackedB_data + 8, v[1]); + MlasStoreFloat16x8(PackedB_data + 16, v[2]); + MlasStoreFloat16x8(PackedB_data + 24, v[3]); + MlasStoreFloat16x8(PackedB_data + 32, v[4]); + MlasStoreFloat16x8(PackedB_data + 40, v[5]); + MlasStoreFloat16x8(PackedB_data + 48, v[6]); + MlasStoreFloat16x8(PackedB_data + 56, v[7]); + } + + if (k & 4) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + MlasStoreFloat16x4(PackedB_data + 24, v[3]); + MlasStoreFloat16x4(PackedB_data + 28, v[7]); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (k > 0) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + if (k > 1) { + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + } + if (k > 2) { + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + } + } + } +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x4(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3) { + v0 = vaddq_f16(v0, v1); + v2 = vaddq_f16(v2, v3); 
+    v0 = vaddq_f16(v0, v2);
+    return v0;
+}
+
+MLAS_FORCEINLINE
+float16x8_t addq_f16x8(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3,
+                       float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7) {
+    return vaddq_f16(addq_f16x4(v0, v1, v2, v3), addq_f16x4(v4, v5, v6, v7));
+}
+
+MLAS_FORCEINLINE
+float16x8_t maq_lane_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3,
+                              float16x4_t a0) {
+    accu0 = vfmaq_lane_f16(accu0, v0, a0, 0);
+    accu0 = vfmaq_lane_f16(accu0, v1, a0, 1);
+    accu0 = vfmaq_lane_f16(accu0, v2, a0, 2);
+    accu0 = vfmaq_lane_f16(accu0, v3, a0, 3);
+    return accu0;
+}
+
+MLAS_FORCEINLINE
+float16x8_t maq_laneq_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3,
+                               float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7, float16x8_t a0) {
+    accu0 = vfmaq_laneq_f16(accu0, v0, a0, 0);
+    accu0 = vfmaq_laneq_f16(accu0, v1, a0, 1);
+    accu0 = vfmaq_laneq_f16(accu0, v2, a0, 2);
+    accu0 = vfmaq_laneq_f16(accu0, v3, a0, 3);
+    accu0 = vfmaq_laneq_f16(accu0, v4, a0, 4);
+    accu0 = vfmaq_laneq_f16(accu0, v5, a0, 5);
+    accu0 = vfmaq_laneq_f16(accu0, v6, a0, 6);
+    accu0 = vfmaq_laneq_f16(accu0, v7, a0, 7);
+    return accu0;
+}
+
+MLAS_FORCEINLINE
+float16x4_t ma_lane_f16_accu(float16x4_t accu, float16x4_t v0, float16x4_t v1, float16x4_t v2, float16x4_t v3,
+                             float16x4_t a0) {
+    accu = vfma_lane_f16(accu, v0, a0, 0);
+    accu = vfma_lane_f16(accu, v1, a0, 1);
+    accu = vfma_lane_f16(accu, v2, a0, 2);
+    accu = vfma_lane_f16(accu, v3, a0, 3);
+    return accu;
+}
+
+template <int beta_behavior> // 0: beta == 0.0f16, 1: beta == 1.0f16, 2: beta != 0.0f16 && beta != 1.0f16
+void HGemm_TransposedB_Kernel_M1(
+    const _mlas_fp16_* A_data,
+    const _mlas_fp16_* B_data,
+    _mlas_fp16_* C_data,
+    size_t CountN,
+    size_t CountK,
+    size_t ldb,
+    _mlas_fp16_ alpha,
+    _mlas_fp16_ beta
+) {
+    for (; CountN >= 8; CountN -= 8, B_data += 8 * ldb, C_data += 8) {
+        const auto* a = A_data;
+        const auto* b = B_data;
+        size_t k = CountK;
+        float16x8_t accu0 = MlasZeroFloat16x8();
+        float16x8_t accu1 = MlasZeroFloat16x8();
+        float16x8_t accu2 = MlasZeroFloat16x8();
+        float16x8_t accu3 = MlasZeroFloat16x8();
+        float16x8_t accu4 = MlasZeroFloat16x8();
+        float16x8_t accu5 = MlasZeroFloat16x8();
+        float16x8_t accu6 = MlasZeroFloat16x8();
+        float16x8_t accu7 = MlasZeroFloat16x8();
+        for (; k >= 8; k -= 8, a += 8, b += 8) {
+            float16x8_t b0 = MlasLoadFloat16x8(b);
+            float16x8_t b1 = MlasLoadFloat16x8(b + ldb);
+            float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb);
+            float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb);
+            float16x8_t b4 = MlasLoadFloat16x8(b + 4 * ldb);
+            float16x8_t b5 = MlasLoadFloat16x8(b + 5 * ldb);
+            float16x8_t b6 = MlasLoadFloat16x8(b + 6 * ldb);
+            float16x8_t b7 = MlasLoadFloat16x8(b + 7 * ldb);
+            float16x8_t a0 = MlasLoadFloat16x8(a);
+            accu0 = vfmaq_f16(accu0, b0, a0);
+            accu1 = vfmaq_f16(accu1, b1, a0);
+            accu2 = vfmaq_f16(accu2, b2, a0);
+            accu3 = vfmaq_f16(accu3, b3, a0);
+            accu4 = vfmaq_f16(accu4, b4, a0);
+            accu5 = vfmaq_f16(accu5, b5, a0);
+            accu6 = vfmaq_f16(accu6, b6, a0);
+            accu7 = vfmaq_f16(accu7, b7, a0);
+        }
+        Transpose8x8(accu0, accu1, accu2, accu3, accu4, accu5, accu6, accu7);
+        accu0 = addq_f16x8(accu0, accu1, accu2, accu3, accu4, accu5, accu6, accu7); // accumulator of 8 columns
+
+        if (k & 4) {
+            float16x4_t b0 = MlasLoadFloat16x4(b);
+            float16x4_t b1 = MlasLoadFloat16x4(b + ldb);
+            float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb);
+            float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb);
+            float16x4_t b4 =
MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t b5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t b6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t b7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x8_t v1 = vcombine_f16(b1, b5); + float16x8_t v2 = vcombine_f16(b2, b6); + float16x8_t v3 = vcombine_f16(b3, b7); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, v0, v1, v2, v3, a0); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t b4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t b5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t b6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t b7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4), v1, v2; + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + accu0 = vfmaq_lane_f16(accu0, v0, a0, 0); + if (k > 1) { + v1 = vcombine_f16(b1, b5); + accu0 = vfmaq_lane_f16(accu0, v1, a0, 1); + } + if (k > 2) { + v2 = vcombine_f16(b2, b6); + accu0 = vfmaq_lane_f16(accu0, v2, a0, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x8_t c = MlasLoadFloat16x8(C_data); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vfmaq_f16(c, accu0, alpha_v); + MlasStoreFloat16x8(C_data, accu0); + } else if constexpr (beta_behavior == 2) { + float16x8_t c = MlasLoadFloat16x8(C_data); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu0 = vfmaq_f16(vmulq_f16(c, beta_v), accu0, alpha_v); + MlasStoreFloat16x8(C_data, accu0); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vmulq_f16(accu0, alpha_v); + MlasStoreFloat16x8(C_data, accu0); + } + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + float16x8_t accu2 = MlasZeroFloat16x8(); + float16x8_t accu3 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = vfmaq_f16(accu0, b0, a0); + accu1 = vfmaq_f16(accu1, b1, a0); + accu2 = vfmaq_f16(accu2, b2, a0); + accu3 = vfmaq_f16(accu3, b3, a0); + } + Transpose4x8(accu0, accu1, accu2, accu3); + accu0 = addq_f16x4(accu0, accu1, accu2, accu3); // accumulator of 4 columns + float16x4_t accu = vadd_f16(vget_low_f16(accu0), vget_high_f16(accu0)); + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu = ma_lane_f16_accu(accu, b0, b1, b2, b3, a0); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = 
MlasLoadPartialFloat16x4(b + 3 * ldb, k);
+            Transpose4x4(b0, b1, b2, b3);
+            float16x4_t a0 = MlasLoadPartialFloat16x4(a, k);
+            accu = vfma_lane_f16(accu, b0, a0, 0);
+            if (k > 1) {
+                accu = vfma_lane_f16(accu, b1, a0, 1);
+            }
+            if (k > 2) {
+                accu = vfma_lane_f16(accu, b2, a0, 2);
+            }
+        }
+
+        if constexpr (beta_behavior == 1) {
+            float16x4_t c = MlasLoadFloat16x4(C_data);
+            float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha);
+            accu = vfma_f16(c, accu, alpha_v);
+            MlasStoreFloat16x4(C_data, accu);
+        } else if constexpr (beta_behavior == 2) {
+            float16x4_t c = MlasLoadFloat16x4(C_data);
+            float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha);
+            float16x4_t beta_v = MlasBroadcastFloat16x4(beta);
+            accu = vfma_f16(vmul_f16(c, beta_v), accu, alpha_v);
+            MlasStoreFloat16x4(C_data, accu);
+        } else {
+            float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha);
+            accu = vmul_f16(accu, alpha_v);
+            MlasStoreFloat16x4(C_data, accu);
+        }
+
+        CountN -= 4, B_data += 4 * ldb, C_data += 4;
+    }
+
+    if (CountN > 0) {
+        const auto* a = A_data;
+        const auto* b = B_data;
+        size_t k = CountK;
+        float16x8_t accus[4];
+        size_t i = 0;
+        for (i = 0; i < 4; ++i) {
+            accus[i] = MlasZeroFloat16x8();
+        }
+        for (; k >= 8; k -= 8, a += 8, b += 8) {
+            float16x8_t a0 = MlasLoadFloat16x8(a);
+            for (i = 0; i < CountN; ++i) {
+                accus[i] = vfmaq_f16(accus[i], MlasLoadFloat16x8(b + i * ldb), a0);
+            }
+        }
+        Transpose4x8(accus[0], accus[1], accus[2], accus[3]);
+        float16x8_t accu0 = addq_f16x4(accus[0], accus[1], accus[2], accus[3]); // accumulator of 4 columns
+        float16x4_t accu = vadd_f16(vget_low_f16(accu0), vget_high_f16(accu0));
+
+        if (k & 4) {
+            float16x4_t bs[4];
+            for (i = 0; i < CountN; ++i) {
+                bs[i] = MlasLoadFloat16x4(b + i * ldb);
+            }
+            for (; i < 4; ++i) {
+                bs[i] = MlasZeroFloat16x4();
+            }
+            Transpose4x4(bs[0], bs[1], bs[2], bs[3]);
+            float16x4_t a0 = MlasLoadFloat16x4(a);
+            accu = ma_lane_f16_accu(accu, bs[0], bs[1], bs[2], bs[3], a0);
+            k -= 4, a += 4, b += 4;
+        }
+
+        if (k > 0) {
+            float16x4_t bs[4];
+            for (i = 0; i < CountN; ++i) {
+                bs[i] = MlasLoadPartialFloat16x4(b + i * ldb, k);
+            }
+            for (; i < 4; ++i) {
+                bs[i] = MlasZeroFloat16x4();
+            }
+            Transpose4x4(bs[0], bs[1], bs[2], bs[3]);
+            float16x4_t a0 = MlasLoadPartialFloat16x4(a, k);
+            accu = vfma_lane_f16(accu, bs[0], a0, 0);
+            if (k > 1) {
+                accu = vfma_lane_f16(accu, bs[1], a0, 1);
+            }
+            if (k > 2) {
+                accu = vfma_lane_f16(accu, bs[2], a0, 2);
+            }
+        }
+
+        if constexpr (beta_behavior == 1) {
+            float16x4_t c = MlasLoadPartialFloat16x4(C_data, CountN);
+            float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha);
+            accu = vfma_f16(c, accu, alpha_v);
+            MlasStorePartialFloat16x4(C_data, accu, CountN);
+        } else if constexpr (beta_behavior == 2) {
+            float16x4_t c = MlasLoadPartialFloat16x4(C_data, CountN);
+            float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha);
+            float16x4_t beta_v = MlasBroadcastFloat16x4(beta);
+            accu = vfma_f16(vmul_f16(c, beta_v), accu, alpha_v);
+            MlasStorePartialFloat16x4(C_data, accu, CountN);
+        } else {
+            float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha);
+            accu = vmul_f16(accu, alpha_v);
+            MlasStorePartialFloat16x4(C_data, accu, CountN);
+        }
+    }
+}
+
+template <int beta_behavior> // 0: beta == 0, 1: beta == 1, 2: beta != 0 && beta != 1
+void HGemm_TransposedB_Kernel_M2(
+    const _mlas_fp16_* A_data,
+    const _mlas_fp16_* B_data,
+    _mlas_fp16_* C_data,
+    size_t CountN,
+    size_t CountK,
+    size_t lda,
+    size_t ldb,
+    size_t ldc,
+    _mlas_fp16_ alpha,
+    _mlas_fp16_ beta
+) {
+    for (; CountN >= 8; CountN -= 8, B_data += 8 * ldb, C_data += 8) {
+        const auto*
a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu04 = MlasZeroFloat16x8(); + float16x8_t accu05 = MlasZeroFloat16x8(); + float16x8_t accu06 = MlasZeroFloat16x8(); + float16x8_t accu07 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + float16x8_t accu14 = MlasZeroFloat16x8(); + float16x8_t accu15 = MlasZeroFloat16x8(); + float16x8_t accu16 = MlasZeroFloat16x8(); + float16x8_t accu17 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t b5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t b6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t b7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu04 = vfmaq_f16(accu04, b4, a0); + accu05 = vfmaq_f16(accu05, b5, a0); + accu06 = vfmaq_f16(accu06, b6, a0); + accu07 = vfmaq_f16(accu07, b7, a0); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + accu14 = vfmaq_f16(accu14, b4, a1); + accu15 = vfmaq_f16(accu15, b5, a1); + accu16 = vfmaq_f16(accu16, b6, a1); + accu17 = vfmaq_f16(accu17, b7, a1); + } + Transpose8x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + Transpose8x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + accu00 = addq_f16x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + accu10 = addq_f16x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t b4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t b5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t b6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t b7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x8_t v1 = vcombine_f16(b1, b5); + float16x8_t v2 = vcombine_f16(b2, b6); + float16x8_t v3 = vcombine_f16(b3, b7); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu00 = maq_lane_f16_accu(accu00, v0, v1, v2, v3, a0); + accu10 = maq_lane_f16_accu(accu10, v0, v1, v2, v3, a1); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t b4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t b5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t b6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t b7 
= MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu00 = vfmaq_lane_f16(accu00, v0, a0, 0); + accu10 = vfmaq_lane_f16(accu10, v0, a1, 0); + if (k > 1) { + float16x8_t v1 = vcombine_f16(b1, b5); + accu00 = vfmaq_lane_f16(accu00, v1, a0, 1); + accu10 = vfmaq_lane_f16(accu10, v1, a1, 1); + } + if (k > 2) { + float16x8_t v2 = vcombine_f16(b2, b6); + accu00 = vfmaq_lane_f16(accu00, v2, a0, 2); + accu10 = vfmaq_lane_f16(accu10, v2, a1, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C_data); + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vfmaq_f16(c0, accu00, alpha_v); + accu10 = vfmaq_f16(c1, accu10, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + MlasStoreFloat16x8(C_data + ldc, accu10); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C_data); + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v), accu00, alpha_v); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v), accu10, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + MlasStoreFloat16x8(C_data + ldc, accu10); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + accu10 = vmulq_f16(accu10, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + } + Transpose4x8(accu00, accu01, accu02, accu03); + Transpose4x8(accu10, accu11, accu12, accu13); + accu00 = addq_f16x4(accu00, accu01, accu02, accu03); + accu10 = addq_f16x4(accu10, accu11, accu12, accu13); + float16x4_t accu0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)); + float16x4_t accu1 = vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + + if (k & 4) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu0 = ma_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + accu1 = 
ma_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu0 = vfma_lane_f16(accu0, b0, a0, 0); + accu1 = vfma_lane_f16(accu1, b0, a1, 0); + if (k > 1) { + accu0 = vfma_lane_f16(accu0, b1, a0, 1); + accu1 = vfma_lane_f16(accu1, b1, a1, 1); + } + if (k > 2) { + accu0 = vfma_lane_f16(accu0, b2, a0, 2); + accu1 = vfma_lane_f16(accu1, b2, a1, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C_data); + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu0 = vfma_f16(c0, accu0, alpha_v); + accu1 = vfma_f16(c1, accu1, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + MlasStoreFloat16x4(C_data + ldc, accu1); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C_data); + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + accu0 = vfma_f16(vmul_f16(c0, beta_v), accu0, alpha_v); + accu1 = vfma_f16(vmul_f16(c1, beta_v), accu1, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + MlasStoreFloat16x4(C_data + ldc, accu1); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu0 = vmul_f16(accu0, alpha_v); + accu1 = vmul_f16(accu1, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + + CountN -= 4, B_data += 4 * ldb, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0[4]; + float16x8_t accu1[4]; + size_t i = 0; + for (i = 0; i < 4; ++i) { + accu0[i] = MlasZeroFloat16x8(); + accu1[i] = MlasZeroFloat16x8(); + } + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + for (i = 0; i < CountN; ++i) { + float16x8_t bi = MlasLoadFloat16x8(b + i * ldb); + accu0[i] = vfmaq_f16(accu0[i], bi, a0); + accu1[i] = vfmaq_f16(accu1[i], bi, a1); + } + } + Transpose4x8(accu0[0], accu0[1], accu0[2], accu0[3]); + Transpose4x8(accu1[0], accu1[1], accu1[2], accu1[3]); + float16x8_t accu00 = addq_f16x4(accu0[0], accu0[1], accu0[2], accu0[3]); + float16x4_t accu_0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)); + float16x8_t accu10 = addq_f16x4(accu1[0], accu1[1], accu1[2], accu1[3]); + float16x4_t accu_1 = vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + + if (k & 4) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu_0 = ma_lane_f16_accu(accu_0, bs[0], bs[1], bs[2], bs[3], a0); + accu_1 = ma_lane_f16_accu(accu_1, bs[0], bs[1], bs[2], bs[3], a1); + k -= 4, a += 4, b += 4; + } + + if (k > 0) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = 
MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu_0 = vfma_lane_f16(accu_0, bs[0], a0, 0); + accu_1 = vfma_lane_f16(accu_1, bs[0], a1, 0); + if (k > 1) { + accu_0 = vfma_lane_f16(accu_0, bs[1], a0, 1); + accu_1 = vfma_lane_f16(accu_1, bs[1], a1, 1); + } + if (k > 2) { + accu_0 = vfma_lane_f16(accu_0, bs[2], a0, 2); + accu_1 = vfma_lane_f16(accu_1, bs[2], a1, 2); + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu_0 = vfma_f16(c0, accu_0, alpha_v); + accu_1 = vfma_f16(c1, accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + accu_0 = vfma_f16(vmul_f16(c0, beta_v), accu_0, alpha_v); + accu_1 = vfma_f16(vmul_f16(c1, beta_v), accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu_0 = vmul_f16(accu_0, alpha_v); + accu_1 = vmul_f16(accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } +} + +// Full K. Directly save to C. +void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedB_Kernel only support <= 2 rows"); + } + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* B_data = reinterpret_cast<const _mlas_fp16_*>(B); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_M1<0>(A_data, B_data, C_data, CountN, CountK, ldb, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_M1<1>(A_data, B_data, C_data, CountN, CountK, ldb, alpha, beta); + } else { + HGemm_TransposedB_Kernel_M1<2>(A_data, B_data, C_data, CountN, CountK, ldb, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_M2<0>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_M2<1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else { + HGemm_TransposedB_Kernel_M2<2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } +} + +template <int beta_behavior> // 0: beta == 0, 1: beta == 1, 2: beta != 0 && beta != 1 +void HGemm_TransposedPackedB_Kernel_M1( + const _mlas_fp16_* A, + const _mlas_fp16_* PackedB, + _mlas_fp16_* C, + size_t CountN, + size_t CountK, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + for (; CountN >= 16; CountN -= 16, C += 16) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 16) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = 
MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t b40 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b41 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b50 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b51 = MlasLoadFloat16x8(PackedB + 88); + float16x8_t b60 = MlasLoadFloat16x8(PackedB + 96); + float16x8_t b61 = MlasLoadFloat16x8(PackedB + 104); + float16x8_t b70 = MlasLoadFloat16x8(PackedB + 112); + float16x8_t b71 = MlasLoadFloat16x8(PackedB + 120); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = maq_laneq_f16_accu(accu0, b00, b10, b20, b30, b40, b50, b60, b70, a0); + accu1 = maq_laneq_f16_accu(accu1, b01, b11, b21, b31, b41, b51, b61, b71, a0); + } + + if (k & 4) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, b00, b10, b20, b30, a0); + accu1 = maq_lane_f16_accu(accu1, b01, b11, b21, b31, a0); + k -= 4, a += 4, PackedB += 4 * 16; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b00, a0, 0); + accu1 = vfmaq_lane_f16(accu1, b01, a0, 0); + if (k > 1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + accu0 = vfmaq_lane_f16(accu0, b10, a0, 1); + accu1 = vfmaq_lane_f16(accu1, b11, a0, 1); + } + if (k > 2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + accu0 = vfmaq_lane_f16(accu0, b20, a0, 2); + accu1 = vfmaq_lane_f16(accu1, b21, a0, 2); + } + + PackedB += k * 16; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vfmaq_f16(c0, accu0, alpha_v); + accu1 = vfmaq_f16(c1, accu1, alpha_v); + MlasStoreFloat16x8(C, accu0); + MlasStoreFloat16x8(C + 8, accu1); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu0 = vfmaq_f16(vmulq_f16(c0, beta_v), accu0, alpha_v); + accu1 = vfmaq_f16(vmulq_f16(c1, beta_v), accu1, alpha_v); + MlasStoreFloat16x8(C, accu0); + MlasStoreFloat16x8(C + 8, accu1); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vmulq_f16(accu0, alpha_v); + accu1 = vmulq_f16(accu1, alpha_v); + MlasStoreFloat16x8(C, accu0); + MlasStoreFloat16x8(C + 8, accu1); + } + } + + if (CountN & 8) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = 
MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = maq_laneq_f16_accu(accu0, b0, b1, b2, b3, b4, b5, b6, b7, a0); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + } + PackedB += k * 8; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vfmaq_f16(c0, accu0, alpha_v); + MlasStoreFloat16x8(C, accu0); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu0 = vfmaq_f16(vmulq_f16(c0, beta_v), accu0, alpha_v); + MlasStoreFloat16x8(C, accu0); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu0 = vmulq_f16(accu0, alpha_v); + MlasStoreFloat16x8(C, accu0); + } + + CountN -= 8, C += 8; + } + + if (CountN > 0) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + accu0 = maq_laneq_f16_accu(accu0, b0, b1, b2, b3, b4, b5, b6, b7, a0); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + } + PackedB += k * 8; + } + + float16x4_t accu_low = vget_low_f16(accu0); + float16x4_t accu_high = vget_high_f16(accu0); + + if (CountN & 4) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C); + 
float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vfma_f16(c0, accu_low, alpha_v)); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStoreFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu_low, alpha_v)); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vmul_f16(accu_low, alpha_v)); + } + + CountN -= 4, C += 4; + accu_low = accu_high; + } + + if (CountN) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vfma_f16(c0, accu_low, alpha_v), CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStorePartialFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu_low, alpha_v), CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vmul_f16(accu_low, alpha_v), CountN); + } + } + } +} + +template <int beta_behavior> // 0: beta == 0, 1: beta == 1, 2: beta != 0 && beta != 1 +void HGemm_TransposedPackedB_Kernel_M2( + const _mlas_fp16_* A, + const _mlas_fp16_* PackedB, + _mlas_fp16_* C, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + for (; CountN >= 16; CountN -= 16, C += 16) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 16) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t b40 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b41 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b50 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b51 = MlasLoadFloat16x8(PackedB + 88); + float16x8_t b60 = MlasLoadFloat16x8(PackedB + 96); + float16x8_t b61 = MlasLoadFloat16x8(PackedB + 104); + float16x8_t b70 = MlasLoadFloat16x8(PackedB + 112); + float16x8_t b71 = MlasLoadFloat16x8(PackedB + 120); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = maq_laneq_f16_accu(accu00, b00, b10, b20, b30, b40, b50, b60, b70, a0); + accu01 = maq_laneq_f16_accu(accu01, b01, b11, b21, b31, b41, b51, b61, b71, a0); + accu10 = maq_laneq_f16_accu(accu10, b00, b10, b20, b30, b40, b50, b60, b70, a1); + accu11 = maq_laneq_f16_accu(accu11, b01, b11, b21, b31, b41, b51, b61, b71, a1); + } + + if (k & 4) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB 
+ 56); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + k -= 4, a += 4, PackedB += 4 * 16; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + if (k > 1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + if (k > 2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + PackedB += k * 16; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vfmaq_f16(c00, accu00, alpha_v); + accu01 = vfmaq_f16(c01, accu01, alpha_v); + accu10 = vfmaq_f16(c10, accu10, alpha_v); + accu11 = vfmaq_f16(c11, accu11, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } else if constexpr (beta_behavior == 2) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu00 = vfmaq_f16(vmulq_f16(c00, beta_v), accu00, alpha_v); + accu01 = vfmaq_f16(vmulq_f16(c01, beta_v), accu01, alpha_v); + accu10 = vfmaq_f16(vmulq_f16(c10, beta_v), accu10, alpha_v); + accu11 = vfmaq_f16(vmulq_f16(c11, beta_v), accu11, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + accu01 = vmulq_f16(accu01, alpha_v); + accu10 = vmulq_f16(accu10, alpha_v); + accu11 = vmulq_f16(accu11, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } + + if (CountN & 8) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t 
b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu00 = maq_laneq_f16_accu(accu00, b0, b1, b2, b3, b4, b5, b6, b7, a0); + accu10 = maq_laneq_f16_accu(accu10, b0, b1, b2, b3, b4, b5, b6, b7, a1); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu00 = maq_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + accu10 = maq_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + PackedB += k * 8; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vfmaq_f16(c0, accu00, alpha_v); + accu10 = vfmaq_f16(c1, accu10, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + ldc, accu10); + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v), accu00, alpha_v); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v), accu10, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + ldc, accu10); + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + accu10 = vmulq_f16(accu10, alpha_v); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + ldc, accu10); + } + + CountN -= 8, C += 8; + } + + if (CountN > 0) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(); + float16x8_t accu1 = MlasZeroFloat16x8(); + for (; k >= 8; k -= 8, a += 8, PackedB += 8 * 8) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b4 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b5 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b6 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b7 = MlasLoadFloat16x8(PackedB + 56); + float16x8_t a0 = MlasLoadFloat16x8(a); + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu0 = maq_laneq_f16_accu(accu0, b0, b1, b2, b3, b4, b5, b6, b7, a0); + accu1 = maq_laneq_f16_accu(accu1, b0, b1, b2, b3, b4, b5, b6, b7, a1); + } + + if (k & 4) { + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 
8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + float16x4_t a0 = MlasLoadFloat16x4(a); + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + accu1 = maq_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (k > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k); + float16x4_t a1 = MlasLoadPartialFloat16x4(a + lda, k); + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + accu1 = vfmaq_lane_f16(accu1, b0, a1, 0); + if (k > 1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + accu1 = vfmaq_lane_f16(accu1, b1, a1, 1); + } + if (k > 2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + accu1 = vfmaq_lane_f16(accu1, b2, a1, 2); + } + PackedB += k * 8; + } + + float16x4_t accu0_low = vget_low_f16(accu0); + float16x4_t accu0_high = vget_high_f16(accu0); + float16x4_t accu1_low = vget_low_f16(accu1); + float16x4_t accu1_high = vget_high_f16(accu1); + + if (CountN & 4) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C); + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v)); + MlasStoreFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v)); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C); + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStoreFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu0_low, alpha_v)); + MlasStoreFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v), accu1_low, alpha_v)); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStoreFloat16x4(C, vmul_f16(accu0_low, alpha_v)); + MlasStoreFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v)); + } + CountN -= 4, C += 4; + accu0_low = accu0_high; + accu1_low = accu1_high; + } + + if (CountN) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v), CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v), CountN); + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + MlasStorePartialFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v), accu0_low, alpha_v), CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v), accu1_low, alpha_v), CountN); + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + MlasStorePartialFloat16x4(C, vmul_f16(accu0_low, alpha_v), CountN); + MlasStorePartialFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v), CountN); + } + } + } +} + +void HGemm_TransposedPackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedPackedB_Kernel only support <= 2 rows"); + } 
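+    // The beta_behavior template argument below selects an epilogue at compile time:
+    // 0 never reads C, 1 adds C without scaling it, and 2 applies the general
+    // C = alpha * accumulator + beta * C update.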
+ + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* PackedB_data = reinterpret_cast<const _mlas_fp16_*>(PackedB); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_TransposedPackedB_Kernel_M1<0>(A_data, PackedB_data, C_data, CountN, CountK, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedPackedB_Kernel_M1<1>(A_data, PackedB_data, C_data, CountN, CountK, alpha, beta); + } else { + HGemm_TransposedPackedB_Kernel_M1<2>(A_data, PackedB_data, C_data, CountN, CountK, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_TransposedPackedB_Kernel_M2<0>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedPackedB_Kernel_M2<1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else { + HGemm_TransposedPackedB_Kernel_M2<2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } + } +} + +} // namespace hgemm_neon diff --git a/onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..5b131a8e41f21 --- /dev/null +++ b/onnxruntime/core/mlas/lib/hgemm_kernel_neon.cpp @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = [](){ + MLAS_HGEMM_DISPATCH d; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HPackBKernel_TransposedB = hgemm_neon::HPackB_TransposedB_Kernel; + d.HGemmKernel_TransposedB = hgemm_neon::HGemm_TransposedB_Kernel; + d.HGemmKernel_TransposedPackedB = hgemm_neon::HGemm_TransposedPackedB_Kernel; +#endif + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp index 69e37d2b916d1..5b1f9d7d4a2dc 100644 --- a/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -93,39 +93,6 @@ Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, v7 = vreinterpret_u8_u32(c3.val[1]); } -MLAS_FORCEINLINE void -Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) -{ - // |v00|v01|v02|v03|v04|v05|v06|v07| - // |v10|v11|v12|v13|v14|v15|v16|v17| - // |v20|v21|v22|v23|v24|v25|v26|v27| - // |v30|v31|v32|v33|v34|v35|v36|v37| - // => - // |v00|v10|v20|v30|v04|v14|v24|v34| - // |v01|v11|v21|v31|v05|v15|v25|v35| - // |v02|v12|v22|v32|v06|v16|v26|v36| - // |v03|v13|v23|v33|v07|v17|v27|v37| - float16x8x2_t t01 = vtrnq_f16(v0, v1); - float16x8x2_t t23 = vtrnq_f16(v2, v3); - - v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); - v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); - v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); - v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); -} - -MLAS_FORCEINLINE void -Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) -{ - float16x4x2_t t01 = vtrn_f16(v0, v1); -
float16x4x2_t t23 = vtrn_f16(v2, v3); - - v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); - v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); - v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); - v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); -} - void HQ4BitGemmPackQuantBData_CompFp16( size_t N, diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 100d7d47751aa..56fad6bb3412a 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -301,6 +301,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the default strides to step through slices of the input matrices. // +#define MLAS_HGEMM_STRIDEN 32 +#define MLAS_HGEMM_STRIDEK 512 #define MLAS_SGEMM_STRIDEN 128 #define MLAS_SGEMM_STRIDEK 128 #define MLAS_SGEMM_PACKED_STRIDEN 128 @@ -317,6 +319,7 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // the effort at this time. // +#define MLAS_HGEMM_STRIDEN_THREAD_ALIGN 16 #define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 #define MLAS_DGEMM_STRIDEN_THREAD_ALIGN 8 #define MLAS_QGEMM_STRIDEN_THREAD_ALIGN 16 @@ -944,6 +947,7 @@ extern "C" { #define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#define MLAS_HGEMM_THREAD_COMPLEXITY 65536 #if defined(__aarch64__) && defined(__linux__) #define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) @@ -1055,6 +1059,12 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; struct MLAS_ROPE_DISPATCH; extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; +// +// half gemm dispatch structure +// +struct MLAS_HGEMM_DISPATCH; +extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon; + // // Quantized depthwise convolution kernels. @@ -1217,6 +1227,7 @@ struct MLAS_PLATFORM { MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; + const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ec572a4150292..026a954bbc6c2 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -544,6 +544,7 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; this->RopeDispatch = &MlasRopeDispatchNeon; + this->HGemmDispatch = &MlasHGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/test/mlas/bench/bench_hgemm.cpp b/onnxruntime/test/mlas/bench/bench_hgemm.cpp new file mode 100644 index 0000000000000..1e8b0eb7c34d6 --- /dev/null +++ b/onnxruntime/test/mlas/bench/bench_hgemm.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
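+ +// A minimal sketch of the fp16 MlasGemm call this file benchmarks (single +// threaded, B transposed, mirroring the arguments used in the loop below): +// MLAS_FP16 alpha(1.0f), beta(0.0f); +// MlasGemm(CblasNoTrans, CblasTrans, M, N, K, A.data(), K, B.data(), K, +// C.data(), N, alpha.val, beta.val, nullptr);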
+ +#include "mlas.h" +#include "bench_util.h" +#include "core/util/thread_utils.h" + +#include +#include + +static const std::vector hgemm_bench_arg_names = {"M", "N", "K"}; + +void HGEMM(benchmark::State& state, bool transA, bool transB) { + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + + auto A = RandomVectorUniform(static_cast(M * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + auto B = RandomVectorUniform(static_cast(N * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + std::vector C(static_cast(M * N)); + + MLAS_FP16 alpha = MLAS_FP16(1.0f); + MLAS_FP16 beta = MLAS_FP16(0.0f); + OrtThreadPoolParams tpo; + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + + for (auto _ : state) { + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + } +} + +static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); +} +BENCHMARK_CAPTURE(HGEMM, GEMV_TransB, false, true)->Apply(GemmSizeWithOne)->UseRealTime(); + +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); +} +BENCHMARK_CAPTURE(HGEMM, NORMAL_TransB, false, true)->Apply(GemmSizeProducts)->UseRealTime(); + +static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); +} +BENCHMARK_CAPTURE(HGEMM, LLM, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp b/onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp new file mode 100644 index 0000000000000..4f3d690b432bf --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_hgemm_neon.cpp @@ -0,0 +1,393 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hgemm_neon.cpp + +Abstract: + + Tests for MLAS fp16 GEMM on ARM CPU. 
+ +--*/ + +#include <cmath> +#include <random> + +#include "test/mlas/unittest/test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/halfgemm.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonHGemmPackBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> input_, ref_, packed_; + + template <size_t N, size_t K> + MLAS_FORCEINLINE void PackB(const MLAS_FP16* src, MLAS_FP16* dst) { + size_t i = 0; + for (; i + 16 <= N; i += 16) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 16; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + } + if (i + 8 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 8; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + i += 8; + } + if (i < N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < N - i; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + dst += 8 - (N - i); + } + } + } + + template <size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* packed, const MLAS_FP16* ref) { + size_t n = ((N + 7) & ~7) * K; + for (size_t i = 0; i < n; ++i) { + ASSERT_EQ(packed[i].val, ref[i].val) << " seed " << seed_ << " i " << i; + } + } + + template <size_t N, size_t K> + void TestPackB() { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(N * K, InitializeBuffer); + auto* packed = packed_.GetBuffer(K * ((N + 7) & ~7), true); + auto* ref = ref_.GetBuffer(K * ((N + 7) & ~7), true); + hgemm_neon::HPackB_TransposedB_Kernel(input, packed, N, K, K); + PackB<N, K>(input, ref); + Check<N, K>(packed, ref); + } + + public: + MlasNeonHGemmPackBTest() + : seed_(rd_()), gen_(seed_), distrib_(-100.f, 100.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmPackB"; + } + + void ExecuteShort(void) override { + TestPackB<1, 1>(); + TestPackB<1, 15>(); + TestPackB<1, 31>(); + TestPackB<8, 1>(); + TestPackB<8, 16>(); + TestPackB<9, 31>(); + TestPackB<9, 33>(); + TestPackB<15, 33>(); + TestPackB<17, 67>(); + TestPackB<17, 96>(); + TestPackB<265, 263>(); + } +}; + +class MlasNeonHGemmTransposedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> A_, B_, ref_, C_; + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k].ToFloat()); + } + C[m * N + n] = MLAS_FP16(accu * alphaf + C[m * N + n].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref) { + size_t n = M * N; + for (size_t i = 0; i < n; ++i) { + ASSERT_TRUE(FloatEqual(C[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i + << " M " << M << " N " << N << " K " << K + << " v0 " << C[i] << " v1 " << ref[i]; + } + } + + 
template <size_t M, size_t N, size_t K> + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* A = A_.GetFilledBuffer(M * K, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * N, InitializeBuffer); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + hgemm_neon::HGemm_TransposedB_Kernel(A, B, C, M, N, K, K, K, N, alpha.val, beta.val); + HGemm<M, N, K>(A, B, ref, alpha, beta); + Check<M, N, K>(C, ref); + } + + public: + MlasNeonHGemmTransposedBTest() + : seed_(1928375), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmTransposedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmTransposedPackedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> A_, B_, ref_, C_; + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + size_t n = 0; + for (; n + 16 <= N; n += 16) { + for (size_t i = 0; i < 16; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k * 16 + i].ToFloat()); + } + C[m * N + n + i] = MLAS_FP16(accu * alphaf + C[m * N + n + i].ToFloat() * betaf); + } + } + } + if (n + 8 <= N) { + for (size_t i = 0; i < 8; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * N + n + i] = MLAS_FP16(accu * alphaf + C[m * N + n + i].ToFloat() * betaf); + } + } + n += 8; + } + if (n < N) { + for (size_t i = 0; i < N - n; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * K + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * N + n + i] = MLAS_FP16(accu * alphaf + C[m * N + n + i].ToFloat() * betaf); + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref) { + size_t n = M * N; + for (size_t i = 0; i < n; ++i) { + ASSERT_TRUE(FloatEqual(C[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i + << " M " << M << " K " << K << " N " << N + << " v0 " << C[i] << " v1 " << ref[i]; + } + } + + template <size_t M, size_t N, size_t K> + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t 
count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* A = A_.GetFilledBuffer(M * K, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * ((N + 7) & ~7), InitializeBuffer); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + hgemm_neon::HGemm_TransposedPackedB_Kernel(A, B, C, M, N, K, K, N, alpha.val, beta.val); + HGemm<M, N, K>(A, B, ref, alpha, beta); + Check<M, N, K>(C, ref); + } + + public: + MlasNeonHGemmTransposedPackedBTest() + : seed_(1928372), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmTransposedPackedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution<float> distrib_; + MatrixGuardBuffer<MLAS_FP16> A_, B_, ref_, C_; + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[i * K + k].ToFloat()) * (B[j * K + k].ToFloat()); + } + C[i * N + j] = MLAS_FP16(accu * alphaf + C[i * N + j].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template <size_t M, size_t N, size_t K> + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + ASSERT_TRUE(FloatEqual(C[i * N + j], ref[i * N + j], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " K " << K << " N " << N + << " v0 " << C[i * N + j] << " v1 " << ref[i * N + j]; + } + } + } + + template <size_t M, size_t N, size_t K> + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* A = A_.GetFilledBuffer(M * K, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * N, InitializeBuffer); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + MlasGemm(CblasNoTrans, CblasTrans, M, N, K, A, K, B, K, C, N, alpha.val, beta.val, nullptr); + HGemm<M, N, K>(A, B, ref, alpha, beta); + Check<M, N, K>(C, ref); + } + + public: + MlasNeonHGemmTest() + : seed_(192837), gen_(seed_), distrib_(-0.25f, 0.25f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemm"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 128, 
512>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 128, 513>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 128, 511>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 129, 512>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 127, 512>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 513, 1023>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 511, 1025>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<127, 513, 1023>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<129, 511, 1025>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests<MlasNeonHGemmPackBTest>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasNeonHGemmTransposedBTest>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasNeonHGemmTransposedPackedBTest>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests<MlasNeonHGemmTest>::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) From 1fc9c4823d7c2e8f0d07a09315a0755dd7c58ef8 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 24 Jan 2025 18:18:37 -0800 Subject: [PATCH 2/2] Enable coremltools for Linux build (#23481) ### Description Enable coremltools for Linux build. In order to do this, I did: 1. Add uuid-devel to the Linux images and regenerate them. 2. Patch the coremltools code a little bit to add some missing header files. ### Motivation and Context To make the code simpler. Later on I will create another PR to remove the COREML_ENABLE_MLPROGRAM C/C++ macro. Also, after this PR I will bring more changes to onnxruntime_provider_coreml.cmake to make it work with vcpkg. --- cmake/onnxruntime_providers_coreml.cmake | 85 +++++++++---------- .../coremltools/crossplatformbuild.patch | 81 ++++++++++++------ .../azure-pipelines/bigmodels-ci-pipeline.yml | 2 +- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 4 +- .../linux-gpu-tensorrt-ci-pipeline.yml | 4 +- ...-gpu-tensorrt-cuda-minimal-ci-pipeline.yml | 4 +- .../py-cuda-alt-package-test-pipeline.yml | 2 +- .../py-cuda-package-test-pipeline.yml | 2 +- .../stages/java-cuda-packaging-stage.yml | 4 +- .../jobs/py-linux-cuda-package-test-job.yml | 4 +- .../stages/py-gpu-packaging-stage.yml | 4 +- .../linux/docker/Dockerfile.manylinux2_28_cpu | 2 +- .../inference/aarch64/default/cpu/Dockerfile | 2 +- .../inference/aarch64/python/cpu/Dockerfile | 2 +- .../inference/x86_64/default/cpu/Dockerfile | 2 +- .../x86_64/default/cuda11/Dockerfile | 2 +- .../x86_64/default/cuda12/Dockerfile | 2 +- .../inference/x86_64/python/cpu/Dockerfile | 2 +- 18 files changed, 116 insertions(+), 94 deletions(-) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index ec7bc7a98969e..18048c8cdce2f 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -8,25 +8,18 @@ endif() add_compile_definitions(USE_COREML=1) # Check if we can build the coremltools code for creating an mlpackage with an mlprogram. -# The coremltools source requires std::filesystem::path which is only available from iOS 13 on. -set(_enable_ML_PROGRAM ON) -if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0) - message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.") - set(_enable_ML_PROGRAM OFF) -elseif(LINUX) - # uuid-dev is required. we don't bother installing on CIs as it's really for manual developer testing. 
+if(LINUX) find_library(LibUUID_LIBRARY NAMES uuid) find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h) if (NOT LibUUID_INCLUDE_DIR) - message(STATUS "uuid/uuid.h was not found as is required for ML Program support. " + message(FATAL "uuid/uuid.h was not found as is required for ML Program support. " "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. ") - set(_enable_ML_PROGRAM OFF) endif() endif() -if (_enable_ML_PROGRAM) - add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) -endif() + +add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) + # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) @@ -93,10 +86,10 @@ file(GLOB_RECURSE "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" ) -if(_enable_ML_PROGRAM) + # Add helpers to create mlpackage weights. limit to just the files we need to minimize the changes to make them # build on Windows and Linux. - file(GLOB +file(GLOB onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp" "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp" @@ -105,22 +98,22 @@ if(_enable_ML_PROGRAM) "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp" "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp" "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp" - ) +) - # Add helpers to create mlpackage - file(GLOB +# Add helpers to create mlpackage +file(GLOB onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" "${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp" - ) +) - set(coremltools_srcs +set(coremltools_srcs ${onnxruntime_providers_coreml_milblob_cc_srcs} ${onnxruntime_providers_coreml_modelpackage_cc_srcs} - ) +) + +source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) - source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) -endif() # Add CoreML objective c++ source code if (APPLE) @@ -174,34 +167,34 @@ if (APPLE) target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__) endif() -if (_enable_ML_PROGRAM) - # Setup coremltools fp16 and json dependencies for creating an mlpackage. - # - # fp16 depends on psimd - FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) - onnxruntime_fetchcontent_makeavailable(psimd) - set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) - FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) - set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") - set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") - onnxruntime_fetchcontent_makeavailable(fp16) - - # need to tweak the include paths to match what the coreml source code expects - target_include_directories(onnxruntime_providers_coreml PRIVATE - ${fp16_SOURCE_DIR}/include - ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann - ${coremltools_SOURCE_DIR} - ${coremltools_SOURCE_DIR}/mlmodel/src/ - ${coremltools_SOURCE_DIR}/modelpackage/src/ - ) - add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) +# Setup coremltools fp16 and json dependencies for creating an mlpackage. 
+# +# fp16 depends on psimd +FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) +onnxruntime_fetchcontent_makeavailable(psimd) +set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) +FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) +set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") +set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +onnxruntime_fetchcontent_makeavailable(fp16) + +# need to tweak the include paths to match what the coreml source code expects +target_include_directories(onnxruntime_providers_coreml PRIVATE + ${fp16_SOURCE_DIR}/include + ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann + ${coremltools_SOURCE_DIR} + ${coremltools_SOURCE_DIR}/mlmodel/src/ + ${coremltools_SOURCE_DIR}/modelpackage/src/ +) - if (LINUX) - target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) - endif() +add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) + +if (LINUX) + target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) endif() + if (APPLE) target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML") endif() diff --git a/cmake/patches/coremltools/crossplatformbuild.patch b/cmake/patches/coremltools/crossplatformbuild.patch index 7f2268f50c82e..832191b366d4d 100644 --- a/cmake/patches/coremltools/crossplatformbuild.patch +++ b/cmake/patches/coremltools/crossplatformbuild.patch @@ -3,7 +3,7 @@ index adc7bfcf..7b2bf9cc 100644 --- a/mlmodel/src/MILBlob/Blob/FileWriter.cpp +++ b/mlmodel/src/MILBlob/Blob/FileWriter.cpp @@ -8,8 +8,12 @@ - + #include #include + @@ -12,17 +12,31 @@ index adc7bfcf..7b2bf9cc 100644 #include #include +#endif - + using namespace MILBlob; using namespace MILBlob::Blob; +diff --git a/mlmodel/src/MILBlob/Blob/FileWriter.hpp b/mlmodel/src/MILBlob/Blob/FileWriter.hpp +index 2bc99403..49239513 100644 +--- a/mlmodel/src/MILBlob/Blob/FileWriter.hpp ++++ b/mlmodel/src/MILBlob/Blob/FileWriter.hpp +@@ -6,7 +6,8 @@ + #pragma once + + #include "MILBlob/Util/Span.hpp" +- ++// ORT_EDIT: add missing header ++#include + #include + #include + #include diff --git a/mlmodel/src/MILBlob/Fp16.cpp b/mlmodel/src/MILBlob/Fp16.cpp index ae1e71a1..77a7161f 100644 --- a/mlmodel/src/MILBlob/Fp16.cpp +++ b/mlmodel/src/MILBlob/Fp16.cpp @@ -5,6 +5,8 @@ - + #include "MILBlob/Fp16.hpp" - + +// ORT_EDIT: Exclude clang specific pragmas from other builds +#if defined(__clang__) // fp16 lib code has some conversion warnings we don't want to globally ignore @@ -35,11 +49,11 @@ index ae1e71a1..77a7161f 100644 +#else +#include "fp16/fp16.h" +#endif - + using namespace MILBlob; - + diff --git a/modelpackage/src/ModelPackage.cpp b/modelpackage/src/ModelPackage.cpp -index 8fee56b9..99e0d8d6 100644 +index 8fee56b9..5508e316 100644 --- a/modelpackage/src/ModelPackage.cpp +++ b/modelpackage/src/ModelPackage.cpp @@ -26,7 +26,14 @@ namespace std { @@ -55,22 +69,22 @@ index 8fee56b9..99e0d8d6 100644 #include +#endif #include - + #if defined(__cplusplus) @@ -187,7 +194,10 @@ public: ModelPackageItemInfo createFile(const std::string& name, const std::string& author, const std::string& description); }; - + +// ORT_EDIT: pragma only available on APPLE platforms +#if defined(__APPLE__) #pragma mark ModelPackageImpl +#endif - + ModelPackageImpl::ModelPackageImpl(const std::filesystem::path& path, bool createIfNecessary, bool readOnly) : m_packagePath(path), @@ -372,6 +382,20 @@ std::filesystem::path ModelPackageImpl::getItemPath(const std::string& name, con } - + std::string 
ModelPackageImpl::generateIdentifier() const { +// ORT_EDIT: Use built-in UUID generation on Windows +#if defined(_WIN32) @@ -87,20 +101,20 @@ index 8fee56b9..99e0d8d6 100644 + return uuidStrCpp; +#else uuid_t uuid; - + // uuid_unparse generates a 36-character null-terminated string (37 bytes). @@ -383,6 +407,7 @@ std::string ModelPackageImpl::generateIdentifier() const { uuid_unparse(uuid, buf); - + return std::string(buf); +#endif } - + ModelPackageItemInfo ModelPackageImpl::createFile(const std::string& name, const std::string& author, const std::string& description) { -@@ -468,7 +493,13 @@ std::shared_ptr ModelPackageImpl::findItem(const std::stri +@@ -468,7 +493,14 @@ std::shared_ptr ModelPackageImpl::findItem(const std::stri auto author = itemInfoEntry->getString(kModelPackageItemInfoAuthorKey); auto description = itemInfoEntry->getString(kModelPackageItemInfoDescriptionKey); - + +// ORT_EDIT: need to use path.string() on Windows +#if defined(_WIN32) + return std::make_shared(std::make_shared(identifier, path.string(), name, author, description)); @@ -108,12 +122,13 @@ index 8fee56b9..99e0d8d6 100644 +#else return std::make_shared(std::make_shared(identifier, path, name, author, description)); +#endif ++ } - + std::shared_ptr ModelPackageImpl::findItem(const std::string& name, const std::string& author) const -@@ -514,7 +545,9 @@ void ModelPackageImpl::removeItem(const std::string& identifier) +@@ -514,7 +546,9 @@ void ModelPackageImpl::removeItem(const std::string& identifier) } - + auto path = m_packageDataDirPath / itemInfoEntry->getString(kModelPackageItemInfoPathKey); - if (0 != std::remove(path.c_str())) { + // ORT_EDIT: std::remove doesn't work on Windows. Use std::filesystem::remove instead. @@ -121,8 +136,8 @@ index 8fee56b9..99e0d8d6 100644 + if (!std::filesystem::remove(path)) { throw std::runtime_error("Failed to remove file at path: " + path.string()); } - -@@ -525,13 +558,16 @@ bool ModelPackageImpl::isValid(const std::filesystem::path& path) + +@@ -525,13 +559,16 @@ bool ModelPackageImpl::isValid(const std::filesystem::path& path) { try { ModelPackageImpl(path, false, true); @@ -132,16 +147,16 @@ index 8fee56b9..99e0d8d6 100644 } return true; } - + +// ORT_EDIT: pragma only available on APPLE platforms +#if defined(__APPLE__) #pragma mark ModelPackage +#endif - + ModelPackage::ModelPackage(const std::string& packagePath, bool createIfNecessary, bool readOnly) : m_modelPackageImpl(std::make_shared(packagePath, createIfNecessary, readOnly)) -@@ -544,7 +580,12 @@ ModelPackage::~ModelPackage() - +@@ -544,7 +581,12 @@ ModelPackage::~ModelPackage() + std::string ModelPackage::path() const { +// ORT_EDIT: Windows doesn't automatically convert to std::string as the native format could be char or wchar. 
@@ -151,5 +166,19 @@ index 8fee56b9..99e0d8d6 100644
     return m_modelPackageImpl->path();
 +#endif
 }
- 
+
 std::string ModelPackage::setRootModel(const std::string& path, const std::string& name, const std::string& author, const std::string& description)
+diff --git a/modelpackage/src/utils/JsonMap.hpp b/modelpackage/src/utils/JsonMap.hpp
+index 0d7dc3f4..b700cfd5 100644
+--- a/modelpackage/src/utils/JsonMap.hpp
++++ b/modelpackage/src/utils/JsonMap.hpp
+@@ -10,7 +10,8 @@
+ #include
+ #include
+ #include
+-
++// ORT_EDIT: add missing header
++#include <cstdint>
+ class JsonMapImpl;
+ 
+ class JsonMap {
diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
index 59deb0d4975fe..0eaaea562ca36 100644
--- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml
@@ -41,7 +41,7 @@ parameters:
 
 variables:
   - name: docker_base_image
-    value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+    value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
   - name: linux_trt_version
     value: 10.3.0.26-1.cuda11.8
   - name: Repository
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
index 518aec8c2f92a..71f7ab6e49b70 100644
--- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
@@ -49,9 +49,9 @@ parameters:
 
 variables:
   - name: docker_base_image
     ${{ if eq(parameters.CudaVersion, '11.8') }}:
-      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
     ${{ if eq(parameters.CudaVersion, '12.2') }}:
-      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
   - name: Repository
     ${{ if eq(parameters.CudaVersion, '11.8') }}:
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml
index 9025f084d5982..c08eaaaa1308d 100644
--- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml
@@ -40,9 +40,9 @@ variables:
   - template: templates/common-variables.yml
   - name: docker_base_image
     ${{ if eq(parameters.CudaVersion, '11.8') }}:
-      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
     ${{ if eq(parameters.CudaVersion, '12.2') }}:
-      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
   - name: linux_trt_version
     ${{ if eq(parameters.CudaVersion, '11.8') }}:
       value: ${{ variables.linux_trt_version_cuda11 }}
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml
index 8d42e7201411b..4a86da167ff1f 100644
--- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml
@@ -40,9 +40,9 @@ variables:
   - template: templates/common-variables.yml
   - name: docker_base_image
     ${{ if eq(parameters.CudaVersion, '11.8') }}:
-      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
     ${{ if eq(parameters.CudaVersion, '12.2') }}:
-      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+      value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
   - name: linux_trt_version
     ${{ if eq(parameters.CudaVersion, '11.8') }}:
       value: ${{ variables.linux_trt_version_cuda11 }}
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml
index 4b94ffc7e302e..960b59f93bee0 100644
--- a/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-alt-package-test-pipeline.yml
@@ -18,7 +18,7 @@ stages:
       machine_pool: 'Onnxruntime-Linux-GPU'
       python_wheel_suffix: '_gpu'
       timeout: 480
-      docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+      docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
       cuda_version: '11.8'
 
 - stage: Republish_Wheels
diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml
index 48d1e6b1ac7a7..021f7c5ece140 100644
--- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml
@@ -18,7 +18,7 @@ stages:
       machine_pool: 'Onnxruntime-Linux-GPU'
       python_wheel_suffix: '_gpu'
       timeout: 480
-      docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+      docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
       cuda_version: '12.2'
 
 - stage: Republish_Wheels
diff --git a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml
index 0517fec3bad04..b081b39ad9bcc 100644
--- a/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/java-cuda-packaging-stage.yml
@@ -142,9 +142,9 @@ stages:
           value: false
       - name: docker_base_image
         ${{ if eq(parameters.CudaVersion, '11.8') }}:
-          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
         ${{ if eq(parameters.CudaVersion, '12.2') }}:
-          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
       timeoutInMinutes: 60
       steps:
diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml
index 4adf41d3db4e5..85366ffc28b3a 100644
--- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml
@@ -45,9 +45,9 @@ jobs:
       - template: ../../templates/common-variables.yml
       - name: docker_base_image
         ${{ if eq(parameters.CudaVersion, '11.8') }}:
-          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
         ${{ if eq(parameters.CudaVersion, '12.2') }}:
-          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+          value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
       - name: linux_trt_version
         ${{ if eq(parameters.CudaVersion, '11.8') }}:
           value: ${{ variables.linux_trt_version_cuda11 }}
diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml
index a3c804055d8fb..f48573abd3dba 100644
--- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml
@@ -68,9 +68,9 @@ stages:
       cmake_build_type: ${{ parameters.cmake_build_type }}
       cuda_version: ${{ parameters.cuda_version }}
       ${{ if eq(parameters.cuda_version, '11.8') }}:
-        docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250109.1
+        docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20250124.1
       ${{ if eq(parameters.cuda_version, '12.2') }}:
-        docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250109.1
+        docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1
 - ${{ if eq(parameters.enable_windows_dml, true) }}:
   - ${{ each python_version in parameters.PythonVersions }}:
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu
index 72912acce885e..02938f015ec57 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu
+++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu
@@ -1,4 +1,4 @@
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1
 
 ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17
diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile
index 9569aa2fcda63..f9d84e3b0e130 100644
--- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/Dockerfile
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc14_dotnet:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc14_dotnet:20250124.1
 
 ENV LANG=en_US.UTF-8
 ENV LC_ALL=en_US.UTF-8
diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile
index 589bd869ba89f..20b9a6c224120 100644
--- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile
@@ -1,4 +1,4 @@
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc14:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_ubi8_gcc14:20250124.1
 
 ADD scripts /tmp/scripts
 RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts
diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile
index 1c1f716d81e95..d94e7562f19d4 100644
--- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 
 # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14_dotnet:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14_dotnet:20250124.1
 
 ENV LANG=en_US.UTF-8
 ENV LC_ALL=en_US.UTF-8
diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile
index 6caf21c475545..24287fd34d3ea 100644
--- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda11/Dockerfile
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 
 # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11_dotnet:20250124.1
 
 ARG TRT_VERSION
 #Install TensorRT only if TRT_VERSION is not empty
diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile
index a5dda5904de49..764a79135d7a3 100644
--- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile
@@ -2,7 +2,7 @@
 # Licensed under the MIT License.
 # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12_dotnet:20250124.1
 
 ARG TRT_VERSION
 #Install TensorRT only if TRT_VERSION is not empty
diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile
index 04c6398e061b7..7590d5dd18347 100644
--- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile
+++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile
@@ -1,4 +1,4 @@
-FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250109.1
+FROM onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_ubi8_gcc14:20250124.1
 
 ADD scripts /tmp/scripts
 RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && rm -rf /tmp/scripts
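
Notes (not part of the patch):

The fp16 dependency fetched in the CMake change above is the header-only FP16 library, which the patched coremltools Fp16.cpp pulls in via "fp16/fp16.h" on non-clang builds. A minimal sketch of the conversion entry points it exposes, assuming the fetched ${fp16_SOURCE_DIR}/include is on the include path:

// fp16_roundtrip.cpp -- sketch of the FP16 library's scalar conversion API.
#include <cstdint>
#include <cstdio>

#include "fp16/fp16.h"

int main() {
    float value = 3.14159f;
    // Encode an fp32 value as IEEE 754 half-precision bits.
    uint16_t half = fp16_ieee_from_fp32_value(value);
    // Decode the half-precision bits back to fp32.
    float back = fp16_ieee_to_fp32_value(half);
    std::printf("fp32=%f fp16 bits=0x%04x round-trip=%f\n", value, half, back);
    return 0;
}

Since the library is header-only, nothing needs to be linked; the CMake change only has to expose the include directory, which is exactly what the target_include_directories tweak does.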
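The Windows branch of ModelPackageImpl::generateIdentifier() is largely elided between the hunks above; only the trailing "return uuidStrCpp;" is visible. A hedged sketch of one way to build that identifier with the Win32 RPC runtime (UuidCreate/UuidToStringA, linked against rpcrt4.lib; GenerateIdentifier is a stand-in name, and error handling is omitted for brevity):

// win_uuid.cpp -- sketch of Win32 UUID generation of the kind the
// ORT_EDIT switches to on _WIN32. Not the exact patched body.
#include <iostream>
#include <string>

#include <windows.h>
#include <rpc.h>
#pragma comment(lib, "rpcrt4.lib")

std::string GenerateIdentifier() {
    UUID uuid;
    UuidCreate(&uuid);  // status checks omitted in this sketch

    RPC_CSTR uuidStr = nullptr;
    UuidToStringA(&uuid, &uuidStr);  // yields a 36-char string like "xxxxxxxx-xxxx-..."
    std::string uuidStrCpp(reinterpret_cast<char*>(uuidStr));
    RpcStringFreeA(&uuidStr);        // free the RPC-allocated buffer

    return uuidStrCpp;
}

int main() {
    std::cout << GenerateIdentifier() << "\n";
    return 0;
}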
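The removeItem() hunk swaps std::remove, which takes a narrow const char* and therefore fails for wide paths on Windows, for std::filesystem::remove, which accepts the path object directly and returns whether anything was deleted. The same pattern in isolation (removeItemFile is a stand-in name):

// remove_file.cpp -- sketch of the std::filesystem::remove pattern from the patch.
#include <filesystem>
#include <stdexcept>
#include <string>

void removeItemFile(const std::filesystem::path& path) {
    // std::filesystem::remove returns false when nothing was removed,
    // mirroring the failure check the patched removeItem() performs.
    if (!std::filesystem::remove(path)) {
        throw std::runtime_error("Failed to remove file at path: " + path.string());
    }
}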