
Commit

Merge remote-tracking branch 'origin/main' into snnn/vcpkg
snnn committed Jan 25, 2025
2 parents 4a643dd + 1fc9c48 commit d897c46
Showing 14 changed files with 2,624 additions and 44 deletions.
5 changes: 5 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -95,6 +95,8 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
@@ -374,6 +376,7 @@ else()
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -394,6 +397,7 @@ else()
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -406,6 +410,7 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
40 changes: 30 additions & 10 deletions cmake/onnxruntime_providers_coreml.cmake
@@ -8,6 +8,17 @@ endif()
add_compile_definitions(USE_COREML=1)
add_compile_definitions(COREML_ENABLE_MLPROGRAM=1)


# Check if we can build the coremltools code for creating an mlpackage with an mlprogram.
if(LINUX)
find_library(LibUUID_LIBRARY NAMES uuid)
find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h)
if (NOT LibUUID_INCLUDE_DIR)
message(FATAL_ERROR "uuid/uuid.h was not found. It is required for ML Program support. "
"Run `sudo apt install uuid-dev` if you need to test ML Program-related CoreML EP code.")
endif()
endif()

# Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto
set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format)
file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto")
@@ -75,7 +86,7 @@ file(GLOB_RECURSE

# Add helpers to create mlpackage weights. Limit to just the files we need, to minimize the changes required to make them
# build on Windows and Linux.
file(GLOB
file(GLOB
onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS
"${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp"
"${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp"
@@ -84,22 +95,21 @@ file(GLOB_RECURSE
"${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp"
"${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp"
"${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp"
)
)

# Add helpers to create mlpackage
file(GLOB
# Add helpers to create mlpackage
file(GLOB
onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS
"${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp"
"${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp"
)
)

set(coremltools_srcs
set(coremltools_srcs
${onnxruntime_providers_coreml_milblob_cc_srcs}
${onnxruntime_providers_coreml_modelpackage_cc_srcs}
)

source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs})
)

source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs})

# Add CoreML objective c++ source code
if (APPLE)
@@ -173,20 +183,30 @@ if (APPLE)
endif()



# need to tweak the include paths to match what the coreml source code expects
target_include_directories(onnxruntime_providers_coreml PRIVATE
${coremltools_SOURCE_DIR}
${coremltools_SOURCE_DIR}/mlmodel/src/
${coremltools_SOURCE_DIR}/modelpackage/src/
)

add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16)
# need to tweak the include paths to match what the coreml source code expects
target_include_directories(onnxruntime_providers_coreml PRIVATE
${fp16_SOURCE_DIR}/include
${nlohmann_json_SOURCE_DIR}/single_include/nlohmann
${coremltools_SOURCE_DIR}
${coremltools_SOURCE_DIR}/mlmodel/src/
${coremltools_SOURCE_DIR}/modelpackage/src/
)


if (LINUX)
target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid)
endif()



if (APPLE)
target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML")
endif()
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -75,6 +75,7 @@ class GQAAttentionBase {
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

// Compute the attention score.
// TODO(fajin): type depends on kernel supportability
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -198,6 +199,11 @@ class GQAAttentionBase {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
// TODO(fajin): update later
// } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
// MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
// q, static_cast<int>(head_size), k, static_cast<int>(head_size), output,
// static_cast<int>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
102 changes: 101 additions & 1 deletion onnxruntime/core/mlas/inc/mlas.h
@@ -1458,7 +1458,107 @@ MlasRotaryEmbedOneRow(
T* output
);

/**
/**
 * @brief Supplies matrix data information to the half-precision GEMM functions.
*/
struct MLAS_HGEMM_DATA_PARAMS {
const MLAS_FP16* A; /**< Supplies the address of matrix A */
size_t lda; /**< Supplies the first dimension of matrix A. */
const MLAS_FP16* B; /**< Supplies the address of matrix B */
size_t ldb; /**< Supplies the first dimension of matrix B. */
MLAS_FP16* C; /**< Supplies the address of matrix C */
size_t ldc; /**< Supplies the first dimension of matrix C. */
uint16_t alpha; /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */
uint16_t beta; /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */
};

/**
 * @brief Checks whether the current CPU supports half-precision GEMM.
*/
bool
MLASCALL
MlasHGemmSupported(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB
);

/**
 * @brief Batched half-precision matrix/matrix multiply operation (HGEMM).
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
 * @param Data Supplies an array of matrix data parameters.
 * @param BatchSize Supplies the number of multiplications in this batch.
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the
 *                   base library threading support should be used.
*/
void
MLASCALL
MlasGemmBatch(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_HGEMM_DATA_PARAMS* Data,
size_t BatchSize,
MLAS_THREADPOOL* ThreadPool
);
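
A minimal sketch of calling the batched entry point directly (not part of this diff; A, B, C, M, N, K and thread_pool are placeholder names, matrices are assumed row-major and densely packed, and 0x3C00/0x0000 are used as the FP16 encodings of alpha = 1.0f and beta = 0.0f):

MLAS_HGEMM_DATA_PARAMS data[2];
for (size_t i = 0; i < 2; ++i) {
    data[i].A = A[i];        // i-th A matrix, M x K, row-major
    data[i].lda = K;
    data[i].B = B[i];        // i-th B matrix, N x K, row-major (consumed as CblasTrans)
    data[i].ldb = K;
    data[i].C = C[i];        // i-th C matrix, M x N, row-major
    data[i].ldc = N;
    data[i].alpha = 0x3C00;  // FP16 encoding of 1.0f
    data[i].beta = 0x0000;   // FP16 encoding of 0.0f
}
MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data, 2, thread_pool);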

/**
 * @brief Half-precision matrix/matrix multiply operation (HGEMM).
 *        C = alpha * op(A) * op(B) + beta * C
 *
 * @param TransA Supplies the transpose operation for matrix A. Currently only CblasNoTrans is supported.
 * @param TransB Supplies the transpose operation for matrix B. Currently only CblasTrans is supported.
* @param M Supplies the number of rows of matrix A and matrix C.
* @param N Supplies the number of columns of matrix B and matrix C.
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
* @param A Supplies the address of matrix A
* @param lda Supplies the first dimension of matrix A.
* @param B Supplies the address of matrix B
* @param ldb Supplies the first dimension of matrix B.
* @param C Supplies the address of matrix C
* @param ldc Supplies the first dimension of matrix C.
 * @param alpha Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding.
 * @param beta Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding.
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
* should be used.
*/
inline
void
MlasGemm(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_FP16* A,
size_t lda,
const MLAS_FP16* B,
size_t ldb,
MLAS_FP16* C,
size_t ldc,
uint16_t alpha,
uint16_t beta,
MLAS_THREADPOOL* ThreadPool
) {
MLAS_HGEMM_DATA_PARAMS Data;
Data.A = A;
Data.lda = lda;
Data.B = B;
Data.ldb = ldb;
Data.C = C;
Data.ldc = ldc;
Data.alpha = alpha;
Data.beta = beta;
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
}
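
A minimal usage sketch of the convenience wrapper above (not part of this diff; A, B, C, M, N, K and thread_pool are placeholder names, matrices are assumed row-major and densely packed):

if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
    // C (M x N) = A (M x K) * B^T, where B is stored N x K; alpha = 1.0f, beta = 0.0f in FP16 encoding.
    MlasGemm(CblasNoTrans, CblasTrans, M, N, K,
             A, /*lda*/ K, B, /*ldb*/ K, C, /*ldc*/ N,
             /*alpha*/ 0x3C00, /*beta*/ 0x0000, thread_pool);
}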

/**
 * @brief Whether the current CPU supports FP16 acceleration.
*/
bool MLASCALL
99 changes: 99 additions & 0 deletions onnxruntime/core/mlas/lib/fp16_common.h
@@ -349,4 +349,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT
return vbsl_f16(select, ones, zeros);
}

MLAS_FORCEINLINE
void
Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3,
MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7)
{
// |v00|v01|v02|v03|v04|v05|v06|v07|
// |v10|v11|v12|v13|v14|v15|v16|v17|
// |v20|v21|v22|v23|v24|v25|v26|v27|
// |v30|v31|v32|v33|v34|v35|v36|v37|
// |v40|v41|v42|v43|v44|v45|v46|v47|
// |v50|v51|v52|v53|v54|v55|v56|v57|
// |v60|v61|v62|v63|v64|v65|v66|v67|
// |v70|v71|v72|v73|v74|v75|v76|v77|
float16x8x2_t t01 = vtrnq_f16(v0, v1);
float16x8x2_t t23 = vtrnq_f16(v2, v3);
float16x8x2_t t45 = vtrnq_f16(v4, v5);
float16x8x2_t t67 = vtrnq_f16(v6, v7);
// |v00|v10|v02|v12|v04|v14|v06|v16|
// |v01|v11|v03|v13|v05|v15|v07|v17|
// |v20|v30|v22|v32|v24|v34|v26|v36|
// |v21|v31|v23|v33|v25|v35|v27|v37|
// |v40|v50|v42|v52|v44|v54|v46|v56|
// |v41|v51|v43|v53|v45|v55|v47|v57|
// |v60|v70|v62|v72|v64|v74|v66|v76|
// |v61|v71|v63|v73|v65|v75|v67|v77|
float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]));
float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]));
float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0]));
float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1]));
// |v00|v10|v20|v30|v04|v14|v24|v34|
// |v01|v11|v21|v31|v05|v15|v25|v35|
// |v02|v12|v22|v32|v06|v16|v26|v36|
// |v03|v13|v23|v33|v07|v17|v27|v37|
// |v40|v50|v60|v70|v44|v54|v64|v74|
// |v41|v51|v61|v71|v45|v55|v65|v75|
// |v42|v52|v62|v72|v46|v56|v66|v76|
// |v43|v53|v63|v73|v47|v57|v67|v77|
v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
// |v00|v10|v20|v30|v40|v50|v60|v70|
// |v01|v11|v21|v31|v41|v51|v61|v71|
// |v02|v12|v22|v32|v42|v52|v62|v72|
// |v03|v13|v23|v33|v43|v53|v63|v73|
// |v04|v14|v24|v34|v44|v54|v64|v74|
// |v05|v15|v25|v35|v45|v55|v65|v75|
// |v06|v16|v26|v36|v46|v56|v66|v76|
// |v07|v17|v27|v37|v47|v57|v67|v77|
}

MLAS_FORCEINLINE
void
Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3)
{
// |v00|v01|v02|v03|v04|v05|v06|v07|
// |v10|v11|v12|v13|v14|v15|v16|v17|
// |v20|v21|v22|v23|v24|v25|v26|v27|
// |v30|v31|v32|v33|v34|v35|v36|v37|
// =>
// |v00|v10|v20|v30|v04|v14|v24|v34|
// |v01|v11|v21|v31|v05|v15|v25|v35|
// |v02|v12|v22|v32|v06|v16|v26|v36|
// |v03|v13|v23|v33|v07|v17|v27|v37|
float16x8x2_t t01 = vtrnq_f16(v0, v1);
float16x8x2_t t23 = vtrnq_f16(v2, v3);

v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
}

MLAS_FORCEINLINE
void
Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3)
{
// |v00|v01|v02|v03|
// |v10|v11|v12|v13|
// |v20|v21|v22|v23|
// |v30|v31|v32|v33|
// =>
// |v00|v10|v20|v30|
// |v01|v11|v21|v31|
// |v02|v12|v22|v32|
// |v03|v13|v23|v33|
float16x4x2_t t01 = vtrn_f16(v0, v1);
float16x4x2_t t23 = vtrn_f16(v2, v3);

v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
}
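
A minimal usage sketch for the 4x4 helper (not part of this diff; assumes MLAS_FLOAT16X4 is the NEON float16x4_t, and src/dst are placeholder float16_t pointers to row-major 4x4 FP16 tiles):

float16x4_t r0 = vld1_f16(src + 0);
float16x4_t r1 = vld1_f16(src + 4);
float16x4_t r2 = vld1_f16(src + 8);
float16x4_t r3 = vld1_f16(src + 12);
Transpose4x4(r0, r1, r2, r3);  // r0..r3 now hold the columns of the original tile
vst1_f16(dst + 0, r0);
vst1_f16(dst + 4, r1);
vst1_f16(dst + 8, r2);
vst1_f16(dst + 12, r3);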

#endif // fp16 vector intrinsic supported

