From 07b5d73837d98770584697960a8a3be8437fcea4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 17 Sep 2024 09:46:21 +0300 Subject: [PATCH] Also apply to iq2_tn --- ggml/src/iqk/iqk_mul_mat.cpp | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 66f12a582..d76fd70e3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1198,8 +1198,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q); // The scale is supposed to be per per tensor, so we can use the same scale auto vd = _mm512_set1_ps(d*q8.scale(iy, i)); - accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]); - accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]); + accd[iy+ 0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+ 0]); + accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]); // Leaving this here just in case ternary models start using per row scales //accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]); //accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]); @@ -1207,9 +1207,21 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0])); - info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1])); + if constexpr (nrc_y == 8) { + __m256 sums[8]; + for (int iy = 0; iy < nrc_y; ++iy) { + sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)); + } + store_8(ix+0, sums, info); + for (int iy = 0; iy < nrc_y; ++iy) { + sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1)); + } + store_8(ix+1, sums, info); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+ 0])); + info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y])); + } } }