Skip to content

Commit

Permalink
ggml: aarch64: implement mmla kernel for q4_1_q8_1 quantized gemm
Browse files Browse the repository at this point in the history
armv8.2-a and above supports MMLA instructions that have better
throughput than DOT. this commit adds support for mmla kernel for
q4_1_q8_1 gemm. This feature is disabled by default. Please build llama.cpp
with the following option to enable it.
"LLAMA_ARM_MMLA=ON"
on AWS Graviton3 processors this kernel resulted up to 1.5x
improvement for prompt evaluation throughput compared to the
default sdot kernel.
  • Loading branch information
snadampal committed Jan 16, 2024
1 parent 4f7b61f commit 72cad33
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 0 deletions.
76 changes: 76 additions & 0 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -3772,6 +3772,82 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx,
#endif
}

#if defined(__ARM_FEATURE_MMLA)
void ggml_vec_mmla_q4_1_q8_1(const int n, float * restrict s0, float * restrict s1, const void * restrict lhs0,
const void * restrict lhs1, const void * restrict rhs0, const void * restrict rhs1) {
const int qk = QK8_1;
const int nb = n / qk;

assert(n % qk == 0);

const block_q4_1 * restrict vx0 = lhs0;
const block_q4_1 * restrict vx1 = lhs1;
const block_q8_1 * restrict vy0 = rhs0;
const block_q8_1 * restrict vy1 = rhs1;

float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t summs0 = vdupq_n_f32(0.0f);

for (int i = 0; i < nb; i++) {
const block_q4_1 * restrict b_x0 = &vx0[i];
const block_q4_1 * restrict b_x1 = &vx1[i];
const block_q8_1 * restrict b_y0 = &vy0[i];
const block_q8_1 * restrict b_y1 = &vy1[i];

float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s,
GGML_FP16_TO_FP32(b_x1->m) * b_y0->s,
GGML_FP16_TO_FP32(b_x0->m) * b_y1->s,
GGML_FP16_TO_FP32(b_x1->m) * b_y1->s};
summs0 += summs_t;

const uint8x16_t m4b = vdupq_n_u8(0x0F);

const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);

// 4-bit -> 8-bit
const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

// load y
const int8x16_t y0_l = vld1q_s8(b_y0->qs);
const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
const int8x16_t y1_l = vld1q_s8(b_y1->qs);
const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

// mmla into int32x4_t
float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};

int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));

int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));

int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));

int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
l1, r1)), l2, r2)), l3, r3))), scale);
}

float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
sumv2 = sumv2 + summs0;

vst1_f32(s0, vget_low_f32(sumv2));
vst1_f32(s1, vget_high_f32(sumv2));
}
#endif

void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_1;
const int nb = n / qk;
Expand Down
2 changes: 2 additions & 0 deletions ggml-quants.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ void ggml_vec_mmla_q8_0_q8_0(const int n, float * restrict s0, float * restrict
const void * restrict vy0, const void * restrict vy1);
void ggml_vec_mmla_q4_0_q8_0(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1);
void ggml_vec_mmla_q4_1_q8_1(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1,
const void * restrict vy0, const void * restrict vy1);
#endif

//
Expand Down
3 changes: 3 additions & 0 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
.vec_dot = ggml_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1,
#if defined(__ARM_FEATURE_MMLA)
.vec_mmla = ggml_vec_mmla_q4_1_q8_1,
#endif
},
[4] = { // GGML_TYPE_Q4_2
.type_name = "DEPRECATED",
Expand Down

0 comments on commit 72cad33

Please sign in to comment.