From f434e5221a38ef17921fcc7c4c5268d65599cc37 Mon Sep 17 00:00:00 2001 From: Sunita Nadampalli Date: Mon, 22 Jan 2024 22:12:42 +0000 Subject: [PATCH] ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm Armv8.2-A and above support MMLA instructions, which have better throughput than DOT. This commit adds an MMLA kernel for the q4_1_q8_1 GEMM. The feature is enabled if the platform defines "__ARM_FEATURE_MATMUL_INT8". On AWS Graviton3 processors this kernel resulted in up to a 1.5x improvement in prompt evaluation throughput compared to the default sdot kernel. --- ggml-quants.c | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++- ggml.c | 3 ++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/ggml-quants.c b/ggml-quants.c index f8b56aa04e6cd2..499412d10e0a03 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3979,14 +3979,87 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * vx[res const int nb = n / qk; assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert(nrc <= MAX_NUM_ROWS_COLS_ARM_MATMUL_INT8); +#else assert(nrc == MAX_NUM_ROWS_COLS_DOT_PRODUCT); - UNUSED(nrc); +#endif const block_q4_1 * restrict x = vx[0]; const block_q8_1 * restrict y = vy[0]; +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_1 * restrict vx0 = vx[0]; + const block_q4_1 * restrict vx1 = vx[1]; + const block_q8_1 * restrict vy0 = vy[0]; + const block_q8_1 * restrict vy1 = vy[1]; + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t summs0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_1 * restrict b_x0 = &vx0[i]; + const block_q4_1 * restrict b_x1 = &vx1[i]; + const block_q8_1 * restrict b_y0 = &vy0[i]; + const block_q8_1 * restrict b_y1 = &vy1[i]; + + float32x4_t summs_t = {GGML_FP16_TO_FP32(b_x0->m) * b_y0->s, + GGML_FP16_TO_FP32(b_x1->m) * b_y0->s, + GGML_FP16_TO_FP32(b_x0->m) * b_y1->s, + GGML_FP16_TO_FP32(b_x1->m) * b_y1->s}; + summs0 += summs_t; + + const 
uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + // mmla into int32x4_t + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 
= vzip1q_f32(sumv0, sumv1); + sumv2 = sumv2 + summs0; + + vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s + 16, vget_high_f32(sumv2)); + } else +#endif // TODO: add WASM SIMD #if defined(__ARM_NEON) + { float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -4028,6 +4101,7 @@ void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * vx[res } *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; + } #elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); diff --git a/ggml.c b/ggml.c index 32cc8d3ce93829..d0b653510af189 100644 --- a/ggml.c +++ b/ggml.c @@ -456,6 +456,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, .vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .hw_matmul = true, +#endif }, [4] = { // GGML_TYPE_Q4_2 .type_name = "DEPRECATED",