diff --git a/ggml-quants.c b/ggml-quants.c index eec564dcc71d57..527ff926c953c6 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3406,6 +3406,82 @@ static inline __m128i get_scale_shuffle(int i) { } #endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) +void ggml_vec_mmla_q4_0_q8_0(const int n, float * restrict s0, float * restrict s1, const void * restrict lhs0, + const void * restrict lhs1, const void * restrict rhs0, const void * restrict rhs1) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_0 * restrict vx0 = lhs0; + const block_q4_0 * restrict vx1 = lhs1; + + const block_q8_0 * restrict vy0 = rhs0; + const block_q8_0 * restrict vy1 = rhs1; + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_0 * restrict b_x0 = &vx0[i]; + const block_q4_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); + const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); + const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); + const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32x4_t scale = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s0, vget_low_f32(sumv2)); + vst1_f32(s1, vget_high_f32(sumv2)); +} +#endif + void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml-quants.h b/ggml-quants.h index 4272edcfec5beb..ab02b88cd8f3cb 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -247,6 +247,8 @@ void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict // mmla void ggml_vec_mmla_q8_0_q8_0(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); +void ggml_vec_mmla_q4_0_q8_0(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); #endif // diff --git a/ggml.c b/ggml.c index 12df5b4c582e71..c204e10146e857 100644 --- a/ggml.c +++ b/ggml.c @@ -448,6 +448,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, .vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, +#if defined(__ARM_FEATURE_MATMUL_INT8) + .vec_mmla = ggml_vec_mmla_q4_0_q8_0, +#endif }, [GGML_TYPE_Q4_1] = { .type_name = "q4_1",