ggml: aarch64: implement mmla kernel for q8_0_q8_0 quantized gemm
armv8.2-a and above support MMLA instructions, which have better
throughput than DOT. This commit adds an MMLA kernel for
q8_0_q8_0 GEMM. It is disabled by default; build llama.cpp with
the following option to enable it:
"LLAMA_ARM_MMLA=ON"
On AWS Graviton3 processors this kernel delivered up to a 1.5x
improvement in prompt-evaluation throughput over the default
sdot kernel.
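
For reference, enabling it typically looks like this (a sketch assuming a
standard CMake or Makefile build of this tree; adjust to your environment):

    cmake -B build -DLLAMA_ARM_MMLA=ON
    cmake --build build

    # or, with the Makefile build:
    make LLAMA_ARM_MMLA=1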
snadampal committed Jan 16, 2024
1 parent 3e5ca79 commit d924089
Showing 8 changed files with 198 additions and 28 deletions.
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -107,6 +107,9 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ON)

# armv8.2-a+ extensions
option(LLAMA_ARM_MMLA "llama: enable aarch64 mmla kernels" OFF)

# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/scripts/build-info.cmake)

@@ -626,6 +629,10 @@ if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATC
# Raspberry Pi 3, 4, Zero 2 (32-bit)
add_compile_options(-mno-unaligned-access)
endif()
if (LLAMA_ARM_MMLA)
add_compile_options(-march=armv8.2-a+i8mm)
add_compile_definitions(__ARM_FEATURE_MMLA)
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" )
message(STATUS "x86 detected")
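
Note: with LLAMA_ARM_MMLA=ON the whole build is compiled with
-march=armv8.2-a+i8mm, so the resulting binaries require a CPU that
implements the i8mm extension; the feature is selected at compile time, not
dispatched at runtime. The flag also assumes a toolchain recent enough to
recognize the +i8mm architecture extension.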
7 changes: 7 additions & 0 deletions Makefile
@@ -498,6 +498,13 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h
$(CC) $(CFLAGS) -c $< -o $@
endif # LLAMA_MPI

ifdef LLAMA_ARM_MMLA
MK_CPPFLAGS += -D__ARM_FEATURE_MMLA
MK_CFLAGS += -D__ARM_FEATURE_MMLA
MK_CXXFLAGS += -D__ARM_FEATURE_MMLA
endif # LLAMA_ARM_MMLA


GF_CC := $(CC)
include scripts/get-flags.mk

1 change: 1 addition & 0 deletions common/common.cpp
@@ -1450,6 +1450,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_neon_mmla: %s\n", ggml_cpu_has_neon_mmla() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
62 changes: 62 additions & 0 deletions ggml-quants.c
@@ -4421,6 +4421,68 @@ void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restri
#endif
}

#if defined(__ARM_FEATURE_MMLA)
void ggml_vec_mmla_q8_0_q8_0(const int n, float * restrict s0, float * restrict s1, const void * restrict lhs0,
const void * restrict lhs1, const void * restrict rhs0, const void * restrict rhs1) {
const int qk = QK8_0;
const int nb = n / qk;

assert(n % qk == 0);

const block_q8_0 * restrict vx0 = lhs0;
const block_q8_0 * restrict vx1 = lhs1;
const block_q8_0 * restrict vy0 = rhs0;
const block_q8_0 * restrict vy1 = rhs1;

float32x4_t sumv0 = vdupq_n_f32(0.0f);

for (int i = 0; i < nb; i++) {
const block_q8_0 * restrict b_x0 = &vx0[i];
const block_q8_0 * restrict b_y0 = &vy0[i];

const block_q8_0 * restrict b_x1 = &vx1[i];
const block_q8_0 * restrict b_y1 = &vy1[i];

const int8x16_t x0_l = vld1q_s8(b_x0->qs);
const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
const int8x16_t x1_l = vld1q_s8(b_x1->qs);
const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);

// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

// per-tile scales; lane order matches the 2x2 SMMLA output tile below:
// { x0*y0, x0*y1, x1*y0, x1*y1 }
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
                     GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
                     GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
                     GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};

// interleave the two rows 64 bits at a time so that each vector holds one
// row-major 2x8 int8 tile, the layout SMMLA expects
int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

// four SMMLA steps cover the 32-byte block; each vmmlaq_s32 accumulates a
// 2x2 int32 tile (acc += l * r^T), which is then scaled and added to sumv0
sumv0 = vmlaq_f32(sumv0, vcvtq_f32_s32(
            vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(vmmlaq_s32(
                vdupq_n_s32(0), l0, r0), l1, r1), l2, r2), l3, r3)), scale);
}

// sumv0 lanes are { x0.y0, x0.y1, x1.y0, x1.y1 }; shuffle into per-column
// pairs so each output pointer receives both row results for one y vector
float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

vst1_f32(s0, vget_low_f32(sumv2));  // s0 = { x0.y0, x1.y0 }
vst1_f32(s1, vget_high_f32(sumv2)); // s1 = { x0.y1, x1.y1 }
}
#endif
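
To make the intrinsic semantics concrete: each vmmlaq_s32 (SMMLA) call above
multiplies a 2x8 int8 tile by the transpose of another 2x8 int8 tile and
accumulates the 2x2 int32 product, which is why the kernel interleaves the
rows 64 bits at a time first. A scalar model of one call (smmla_ref is an
illustrative name, not part of this commit):

#include <stdint.h>

// reference for vmmlaq_s32: acc (2x2 int32) += a (2x8 int8) * b^T
static void smmla_ref(int32_t acc[2][2], const int8_t a[2][8], const int8_t b[2][8]) {
    for (int i = 0; i < 2; i++) {          // row of a
        for (int j = 0; j < 2; j++) {      // row of b, i.e. column of b^T
            for (int k = 0; k < 8; k++) {  // 8-wide dot product
                acc[i][j] += (int32_t) a[i][k] * (int32_t) b[j][k];
            }
        }
    }
}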


void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_0;
const int nb = n / qk;
6 changes: 6 additions & 0 deletions ggml-quants.h
@@ -243,6 +243,12 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx,
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);

#if defined(__ARM_FEATURE_MMLA)
// mmla: 2x2 tiled variant of vec_dot; one call computes the dot products of two
// quantized rows (vx0, vx1) against two quantized vectors (vy0, vy1):
// s0 receives { vx0.vy0, vx1.vy0 } and s1 receives { vx0.vy1, vx1.vy1 }
void ggml_vec_mmla_q8_0_q8_0(int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1);
#endif

//
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
//
134 changes: 106 additions & 28 deletions ggml.c
@@ -514,6 +514,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
.vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
#if defined(__ARM_FEATURE_MMLA)
.vec_mmla = ggml_vec_mmla_q8_0_q8_0,
#endif
},
[GGML_TYPE_Q8_1] = {
.type_name = "q8_1",
@@ -9801,6 +9804,9 @@ static void ggml_compute_forward_mul_mat(
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
#if defined(__ARM_FEATURE_MMLA)
ggml_vec_mmla_t const vec_mmla = type_traits[type].vec_mmla;
#endif

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
@@ -9952,43 +9958,107 @@

// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16];
#if defined(__ARM_FEATURE_MMLA)
float tmp1[16];
float tmp2[16];

if ((vec_mmla != NULL) && (nr0 % 2 == 0) && (nr1 % 2 == 0)) {
    for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
        for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += 2) {
                const int64_t i13 = (ir1/(ne12*ne11));
                const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
                const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);

                // broadcast src0 into src1
                const int64_t i03 = i13/r3;
                const int64_t i02 = i12/r2;

                const int64_t i1 = i11;
                const int64_t i2 = i12;
                const int64_t i3 = i13;

                const int64_t i13_ = ((ir1+1)/(ne12*ne11));
                const int64_t i12_ = ((ir1+1) - (i13_)*ne12*ne11)/ne11;
                const int64_t i11_ = ((ir1+1) - (i13_)*ne12*ne11 - (i12_)*ne11);

                // broadcast src0 into src1
                const int64_t i03_ = i13_/r3;
                const int64_t i02_ = i12_/r2;

                const int64_t i1_ = i11_;
                const int64_t i2_ = i12_;
                const int64_t i3_ = i13_;

                const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);

                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
                // the original src1 data pointer, so we should index using the indices directly
                // TODO: this is a bit of a hack, we should probably have a better way to handle this
                const char * src1_col = (const char *) wdata +
                    (src1_cont || src1->type != vec_dot_type
                     ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
                     : (i11*nb11 + i12*nb12 + i13*nb13));

                const char * src1_col_ = (const char *) wdata +
                    (src1_cont || src1->type != vec_dot_type
                     ? ((i11_) + (i12_)*ne11 + (i13_)*ne12*ne11)*row_size
                     : ((i11_)*nb11 + (i12_)*nb12 + (i13_)*nb13));

                float * dst_col  = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
                float * dst_col_ = (float *) ((char *) dst->data + ((i1_)*nb1 + (i2_)*nb2 + (i3_)*nb3));

                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += 2) {
                    vec_mmla(ne00, &tmp1[ir0 - iir0], &tmp2[ir0 - iir0], src0_row + ir0*nb01,
                             src0_row + (ir0+1)*nb01, src1_col, src1_col_);
                }

                memcpy(&dst_col[iir0],  tmp1, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                memcpy(&dst_col_[iir0], tmp2, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
            }
        }
    }
} else
#endif
{
    for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
        for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
                const int64_t i13 = (ir1/(ne12*ne1));
                const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
                const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);

                // broadcast src0 into src1
                const int64_t i03 = i13/r3;
                const int64_t i02 = i12/r2;

                const int64_t i1 = i11;
                const int64_t i2 = i12;
                const int64_t i3 = i13;

                const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);

                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
                // the original src1 data pointer, so we should index using the indices directly
                // TODO: this is a bit of a hack, we should probably have a better way to handle this
                const char * src1_col = (const char *) wdata +
                    (src1_cont || src1->type != vec_dot_type
                     ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
                     : (i11*nb11 + i12*nb12 + i13*nb13));

                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));

                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                //}

                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                    vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
                }
                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
            }
        }
    }
}
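
Concretely, each vec_mmla call in the branch above fills a 2x2 tile of the
output: two src0 rows (ir0, ir0+1) against two src1 rows (ir1, ir1+1). tmp1
collects the results for the first src1 row and tmp2 for the second, and the
two memcpy calls then write them out to the two destination rows. The
(nr0 % 2 == 0) && (nr1 % 2 == 0) guard exists because this path has no tail
handling for odd row counts; such shapes fall through to the original
vec_dot loop.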
@@ -19965,6 +20035,14 @@ int ggml_cpu_has_arm_fma(void) {
#endif
}

int ggml_cpu_has_neon_mmla(void) {
#if defined(__ARM_FEATURE_MMLA)
return 1;
#else
return 0;
#endif
}

int ggml_cpu_has_metal(void) {
#if defined(GGML_USE_METAL)
return 1;
8 changes: 8 additions & 0 deletions ggml.h
@@ -2217,6 +2217,7 @@ extern "C" {
GGML_API int ggml_cpu_has_fma (void);
GGML_API int ggml_cpu_has_neon (void);
GGML_API int ggml_cpu_has_arm_fma (void);
GGML_API int ggml_cpu_has_neon_mmla (void);
GGML_API int ggml_cpu_has_metal (void);
GGML_API int ggml_cpu_has_f16c (void);
GGML_API int ggml_cpu_has_fp16_va (void);
@@ -2242,6 +2243,10 @@ extern "C" {
typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
#if defined(__ARM_FEATURE_MMLA)
typedef void (*ggml_vec_mmla_t) (const int n, float * GGML_RESTRICT s0, float * GGML_RESTRICT s1, const void * GGML_RESTRICT x0,
const void * GGML_RESTRICT x1, const void * GGML_RESTRICT y0, const void * GGML_RESTRICT y1);
#endif

typedef struct {
const char * type_name;
Expand All @@ -2253,6 +2258,9 @@ extern "C" {
ggml_from_float_t from_float_reference;
ggml_vec_dot_t vec_dot;
enum ggml_type vec_dot_type;
#if defined(__ARM_FEATURE_MMLA)
ggml_vec_mmla_t vec_mmla;
#endif
} ggml_type_traits_t;

GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
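
As a usage sketch (hypothetical caller code, not part of this commit;
x0/x1/y0/y1 stand for pointers to two q8_0 rows and two q8_0 vectors of n
elements), the trait can be used the same way the matmul path uses it,
falling back to vec_dot when the MMLA kernel is absent:

ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_Q8_0);

float s0[2]; // results against y0: { x0.y0, x1.y0 }
float s1[2]; // results against y1: { x0.y1, x1.y1 }
#if defined(__ARM_FEATURE_MMLA)
if (tt.vec_mmla != NULL) {
    tt.vec_mmla(n, s0, s1, x0, x1, y0, y1);
} else
#endif
{
    tt.vec_dot(n, &s0[0], x0, y0);
    tt.vec_dot(n, &s0[1], x1, y0);
    tt.vec_dot(n, &s1[0], x0, y1);
    tt.vec_dot(n, &s1[1], x1, y1);
}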
1 change: 1 addition & 0 deletions llama.cpp
@@ -10541,6 +10541,7 @@ const char * llama_print_system_info(void) {
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "NEON_MMLA = " + std::to_string(ggml_cpu_has_neon_mmla()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
