Skip to content

Commit

Permalink
MLAS: add U8U8 MatMul operation (#1644)
Browse files Browse the repository at this point in the history
Implement the first round of changes for quantization inside MLAS. This adds a MatMul operation for U8xU8=S32 for x86/x64 processors.
  • Loading branch information
tracysh authored Aug 19, 2019
1 parent fbd790f commit bc72c2d
Show file tree
Hide file tree
Showing 20 changed files with 4,958 additions and 59 deletions.
47 changes: 23 additions & 24 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
set(mlas_common_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/platform.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/threading.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/qgemm.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/sgemm.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/convolve.cpp
${ONNXRUNTIME_ROOT}/core/mlas/lib/pooling.cpp
Expand All @@ -16,9 +17,7 @@ set(mlas_common_srcs
)

if(MSVC)

if(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM64")

set(asm_filename ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/sgemma.asm)
set(pre_filename ${CMAKE_CURRENT_BINARY_DIR}/sgemma.i)
set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/sgemma.obj)
Expand All @@ -36,20 +35,18 @@ if(MSVC)
COMMAND
armasm64.exe ${ARMASM_FLAGS} ${pre_filename} ${obj_filename}
)

set(mlas_platform_srcs ${obj_filename})

elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM")

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp
)

elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "x64" OR CMAKE_GENERATOR MATCHES "Win64")

enable_language(ASM_MASM)

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512BW.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/QgemmU8U8KernelAvx512Vnni.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/SgemmKernelSse2.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/SgemmKernelAvx.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/SgemmKernelFma3.asm
Expand All @@ -67,17 +64,14 @@ if(MSVC)
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/TanhKernelFma3.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/ErfKernelFma3.asm
)

else()

enable_language(ASM_MASM)

set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh")

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/i386/sgemma.asm
)

endif()
else()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
Expand All @@ -95,8 +89,7 @@ else()
COMMAND ${CMAKE_C_COMPILER} -dumpmachine
OUTPUT_VARIABLE dumpmachine_output
ERROR_QUIET
)

)
if(dumpmachine_output MATCHES "^arm.*")
set(ARM TRUE)
elseif(dumpmachine_output MATCHES "^aarch64.*")
Expand All @@ -108,39 +101,39 @@ else()
endif()
endif()

if (ARM)
if(ARM)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp
)
elseif (ARM64)
)
elseif(ARM64)
enable_language(ASM)

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/aarch64/sgemma.s
)
elseif (X86)
)
elseif(X86)
enable_language(ASM)

set(mlas_platform_srcs_sse2
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86/SgemmKernelSse2.S
)
)
set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2")

set(mlas_platform_srcs_avx
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86/SgemmKernelAvx.S
)
)
set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx")

set(mlas_platform_srcs
${mlas_platform_srcs_sse2}
${mlas_platform_srcs_avx}
)
elseif (X86_64)
)
elseif(X86_64)
enable_language(ASM)

# The LLVM assmebler does not support the .arch directive to enable instruction
# The LLVM assembler does not support the .arch directive to enable instruction
# set extensions and also doesn't support AVX-512F instructions without
# turning on support via command-line option. Group the sources by the
# instruction set extension and explicitly set the compiler flag as appropriate.
Expand All @@ -164,6 +157,7 @@ else()
set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx")

set(mlas_platform_srcs_avx2
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx2.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SgemmKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/SconvKernelFma3.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/LogisticKernelFma3.S
Expand All @@ -179,15 +173,20 @@ else()
)
set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f")

set(mlas_platform_srcs_avx512bw
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512BW.S
${ONNXRUNTIME_ROOT}/core/mlas/lib/x86_64/QgemmU8U8KernelAvx512Vnni.S
)
set_source_files_properties(${mlas_platform_srcs_avx512bw} PROPERTIES COMPILE_FLAGS "-mavx512bw")

set(mlas_platform_srcs
${mlas_platform_srcs_sse2}
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
${mlas_platform_srcs_avx512f}
${mlas_platform_srcs_avx512bw}
)

endif()

endif()

add_library(onnxruntime_mlas STATIC ${mlas_common_srcs} ${mlas_platform_srcs})
Expand Down
21 changes: 21 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,27 @@ MlasSgemm(
MLAS_THREADPOOL* ThreadPool
);

//
// Quantized integer matrix/matrix multiply routine.
//

void
MLASCALL
MlasQgemm(
size_t M,
size_t N,
size_t K,
const uint8_t* A,
size_t lda,
uint8_t offa,
const uint8_t* B,
size_t ldb,
uint8_t offb,
int32_t* C,
size_t ldc,
MLAS_THREADPOOL* ThreadPool
);

//
// Convolution routines.
//
Expand Down
Loading

0 comments on commit bc72c2d

Please sign in to comment.