From 9790b502e61a4e9110bf2bee8bc7b7a13c0ec064 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 17 Sep 2024 08:35:28 +0300
Subject: [PATCH 1/4] Playing with horizontal sums

---
 ggml/src/iqk/iqk_mul_mat.cpp | 104 +++++++++++++++++++++++++++++------
 1 file changed, 87 insertions(+), 17 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 7543d8957..5d02773ac 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -260,6 +260,19 @@ inline float hmax_float_8(__m256 x) {
     max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
     return _mm_cvtss_f32(max4);
 }
+IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
+    for (int i = 0; i < 4; ++i) {
+        accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
+                                  _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
+    }
+    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
+}
+IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
+    union { __m256 vec; float val[8]; } h;
+    h.vec = hsum_float_8x8(accm);
+    for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, h.val[iy]);
+}
 
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
@@ -1128,9 +1141,17 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
-            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+        if constexpr (nrc_y == 8) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                accm[iy] = _mm256_add_ps(accm[iy], _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)));
+            }
+            store_8(ix, accm, info);
+        }
+        else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+                info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+            }
         }
     }
 
@@ -1230,9 +1251,18 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
-            info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+        if constexpr (nrc_y == 8) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+                accm[iy] = _mm256_add_ps(accm[iy], sum256);
+            }
+            store_8(ix, accm, info);
+        }
+        else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+                info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
+            }
         }
     }
 
@@ -1833,8 +1863,12 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, hsum_float_8(accd[iy]));
+        if constexpr (nrc_y == 8) {
+            store_8(ix, accd, info);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]));
+            }
         }
     }
 
@@ -1877,10 +1911,13 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, hsum_float_8(accd[iy]));
+        if constexpr (nrc_y == 8) {
+            store_8(ix, accd, info);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]));
+            }
         }
-
     }
 }
 
@@ -1926,8 +1963,12 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, hsum_float_8(accd[iy]));
+        if constexpr (nrc_y == 8) {
+            store_8(ix, accd, info);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]));
+            }
         }
     }
 
@@ -2094,8 +2135,12 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
             }
         }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, hsum_float_8(accd[iy]));
+        if constexpr (nrc_y == 8) {
+            store_8(ix, accd, info);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix, iy, hsum_float_8(accd[iy]));
+            }
         }
     }
 }
 
@@ -2999,10 +3044,17 @@ struct ScaleHelperQ_1 {
     }
 };
 
-struct MinusType0 {
+template <int nrc_y> struct MinusType0 {
     inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
    inline float compute(float d, int) const { return d; }
     inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
+    //inline void store(int ix, __m256 * acc, const DataInfo& info) {
+    //    if constexpr (nrc_y == 8) {
+    //        store_8(ix, acc, info);
+    //    } else {
+    //        for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_8(acc[iy]));
+    //    }
+    //}
 };
 
 template <int nrc_y> struct MinusType1 {
@@ -3022,6 +3074,23 @@ template <int nrc_y> struct MinusType1 {
         const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
         return hsum_float_4(_mm_add_ps(sum, accm[iy]));
     }
+    //inline void store(int ix, const __m256 * acc, const DataInfo& info) {
+    //    for (int iy = 0; iy < nrc_y; ++iy) {
+    //        accm[iy] = _mm_add_ps(accm[iy], _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)));
+    //    }
+    //    if constexpr (nrc_y >= 4) {
+    //        union { __m128 vec; float val[4]; } h;
+    //        for (int i = 0; i < nrc_y/4; ++i) {
+    //            accm[4*i+0] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+2]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+2]));
+    //            accm[4*i+1] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+1], accm[4*i+3]), _mm_unpackhi_ps(accm[4*i+1], accm[4*i+3]));
+    //            h.vec = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+1]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+1]));
+    //            for (int j = 0; j < 4; ++j) info.store(ix, 4*i+j, h.val[j]);
+    //        }
+    //        for (int iy = 4*(nrc_y/4); iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy]));
+    //    } else {
+    //        for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy]));
+    //    }
+    //}
 };
 
 template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
@@ -3054,6 +3123,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
                 }
             }
         }
+        //accm.store(ix, acc, info);
         for (int iy = 0; iy < nrc_y; ++iy) {
             info.store(ix, iy, accm.result(acc[iy], iy));
         }
@@ -3061,7 +3131,7 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
 };
 
 template <int nrc_y, bool is_multiple_of_4>
-using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
+using AccumType0 = AccumT<MinusType0<nrc_y>, nrc_y, is_multiple_of_4>;
 template <int nrc_y, bool is_multiple_of_4>
 using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;
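Note on PATCH 1: every kernel changed above follows the same pattern; for the nrc_y == 8 tile the eight row accumulators are no longer reduced one at a time with hsum_float_8, but together with hsum_float_8x8, which acts like an 8x8 transpose fused with pairwise adds and returns a single __m256 whose lane iy holds the horizontal sum of accumulator iy, so store_8 can write all eight results from one vector. The following self-contained sketch checks exactly that property; the main() harness is illustrative and not repository code (compile with e.g. g++ -O2 -mavx):

#include <immintrin.h>
#include <cstdio>

// hsum_float_8x8 as added in PATCH 1: reduces 8 accumulators of 8 floats each
// to one __m256 whose lane i is the horizontal sum of input i. Note that it
// clobbers its input array.
static inline __m256 hsum_float_8x8(__m256 * accm) {
    for (int i = 0; i < 4; ++i) {
        accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
                                  _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
    }
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}

int main() {
    __m256 acc[8];
    float expected[8];
    // Accumulator iy holds the values iy+0 .. iy+7, so its horizontal sum is 8*iy + 28.
    for (int iy = 0; iy < 8; ++iy) {
        float v[8];
        expected[iy] = 0;
        for (int j = 0; j < 8; ++j) { v[j] = (float)(iy + j); expected[iy] += v[j]; }
        acc[iy] = _mm256_loadu_ps(v);
    }
    union { __m256 vec; float val[8]; } h;   // the same trick store_8 uses
    h.vec = hsum_float_8x8(acc);
    for (int iy = 0; iy < 8; ++iy) printf("row %d: got %g, expected %g\n", iy, h.val[iy], expected[iy]);
    return 0;
}

The first loop folds each 256-bit accumulator to a 128-bit partial sum and packs two rows per register; the two unpack/add rounds then finish all eight sums in parallel, so eight 8-to-1 reductions plus eight scalar stores collapse into a fixed network of shuffles and adds.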
From 94cdadd5599c8c43300a3d0e5eb5a4c53442781d Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 17 Sep 2024 09:14:19 +0300
Subject: [PATCH 2/4] Playing with horizontal sums - matrix times vector

---
 ggml/src/iqk/iqk_mul_mat.cpp | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 5d02773ac..66f12a582 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1286,6 +1286,9 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
     __m512i scales[2*k_nx];
 
+    __m256 sums[8];
+
+    int ks = 0;
     for (int ix = 0; ix < nrc_x; ++ix) {
 
         auto accd = _mm512_setzero_ps();
@@ -1319,12 +1322,21 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
         }
 
         if constexpr (std::is_same_v) {
-            info.store(ix, 0, _mm512_reduce_add_ps(accd));
+            sums[ks++] = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+            //info.store(ix, 0, _mm512_reduce_add_ps(accd));
         } else {
             auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
-            info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+            sums[ks++] = _mm256_add_ps(accm, sum256);
+            //info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+        }
+        if (ks == 8) {
+            _mm256_storeu_ps(info.dst_row(0) + ix - 7, hsum_float_8x8(sums));
+            ks = 0;
         }
     }
+    if (ks > 0) {
+        for (int ix = 0; ix < ks; ++ix) info.store(ix, 0, hsum_float_8(sums[ix]));
+    }
 }
 
 #else
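Note on PATCH 2: in the matrix-times-vector case there is only one accumulator per output row, so the patch buffers up to eight folded row results in sums[] and flushes a full group with one _mm256_storeu_ps through info.dst_row(0). Below is a sketch of the same buffering pattern, assuming a plain float* output in place of the DataInfo interface (store_rows, row_acc and dst are illustrative names, and this hsum_float_8 is a functional equivalent of the repository's helper, not a verbatim copy):

#include <immintrin.h>

// Functional equivalent of the repository's hsum_float_8 (the exact shuffle
// sequence there may differ): horizontal sum of one __m256.
static inline float hsum_float_8(__m256 x) {
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1));
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));
    s = _mm_add_ss(s, _mm_movehdup_ps(s));
    return _mm_cvtss_f32(s);
}

// hsum_float_8x8 as in PATCH 1 (clobbers its input).
static inline __m256 hsum_float_8x8(__m256 * accm) {
    for (int i = 0; i < 4; ++i) {
        accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
                                  _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
    }
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}

// row_acc[ix] holds the 8-float partial sums for output row ix.
void store_rows(float * dst, const __m256 * row_acc, int nrc_x) {
    __m256 sums[8];
    int ks = 0;
    for (int ix = 0; ix < nrc_x; ++ix) {
        sums[ks++] = row_acc[ix];
        if (ks == 8) {
            // eight horizontal sums and one vector store instead of eight scalar stores
            _mm256_storeu_ps(dst + ix - 7, hsum_float_8x8(sums));
            ks = 0;
        }
    }
    // Ragged tail (nrc_x not a multiple of 8): these are the *last* ks rows.
    for (int k = 0; k < ks; ++k) dst[nrc_x - ks + k] = hsum_float_8(sums[k]);
}

One detail worth flagging: the tail above writes dst[nrc_x - ks .. nrc_x - 1], whereas the patch's tail loop calls info.store(ix, 0, ...) with ix restarting at 0; the two agree only when nrc_x is a multiple of 8, so the tail indexing in the patch deserves a second look.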
From 07b5d73837d98770584697960a8a3be8437fcea4 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 17 Sep 2024 09:46:21 +0300
Subject: [PATCH 3/4] Also apply to iq2_tn

---
 ggml/src/iqk/iqk_mul_mat.cpp | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 66f12a582..d76fd70e3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -1198,8 +1198,8 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
                 sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
                 // The scale is supposed to be per tensor, so we can use the same scale
                 auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
-                accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
-                accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
+                accd[iy+    0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+    0]);
+                accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
                 // Leaving this here just in case ternary models start using per row scales
                 //accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
                 //accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
@@ -1207,9 +1207,21 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
 
         }
 
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
-            info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
+        if constexpr (nrc_y == 8) {
+            __m256 sums[8];
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
+            }
+            store_8(ix+0, sums, info);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
+            }
+            store_8(ix+1, sums, info);
+        } else {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+    0]));
+                info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
+            }
         }
     }
 }
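Note on PATCH 3: the re-indexing from accd[2*iy+k] to accd[iy + k*nrc_y] makes the nrc_y accumulators belonging to one output row contiguous, so each group can be folded from 512 to 256 bits and handed straight to store_8. The fold preserves the horizontal sum because it simply adds the low and high 256-bit halves. A minimal check of that step, assuming an AVX512F+DQ machine (the harness is illustrative; compile with e.g. -mavx512f -mavx512dq):

#include <immintrin.h>
#include <cstdio>

int main() {
    float v[16], ref = 0;
    for (int i = 0; i < 16; ++i) { v[i] = 0.25f * (i + 1); ref += v[i]; }
    __m512 acc = _mm512_loadu_ps(v);
    // The 512 -> 256 fold used throughout these patches
    __m256 folded = _mm256_add_ps(_mm512_castps512_ps256(acc), _mm512_extractf32x8_ps(acc, 1));
    float tmp[8], sum = 0;
    _mm256_storeu_ps(tmp, folded);
    for (int i = 0; i < 8; ++i) sum += tmp[i];
    // All three numbers must agree (34 for this input)
    printf("hsum after fold = %g, direct sum = %g, _mm512_reduce_add_ps = %g\n", sum, ref, _mm512_reduce_add_ps(acc));
    return 0;
}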
From 5065dcd4a0b80573c14d4e865435d3c70ee52361 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Tue, 17 Sep 2024 10:52:23 +0300
Subject: [PATCH 4/4] Playing with hsums

---
 ggml/src/iqk/iqk_mul_mat.cpp | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d76fd70e3..bb3075598 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -268,11 +268,33 @@ IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
     for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
     return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
 }
+#ifdef HAVE_FANCY_SIMD
 IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
     union { __m256 vec; float val[8]; } h;
     h.vec = hsum_float_8x8(accm);
     for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, h.val[iy]);
 }
+#else
+// Somehow on the AVX2 system that I have available (Ryzen-5975WX), the store_8 version above
+// and the commented out store_8 version below are slower than this.
+IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
+    for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, hsum_float_8(accm[iy]));
+}
+//IQK_ALWAYS_INLINE __m128 hsum_float_4x4(__m128 * a) {
+//    for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
+//    return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
+//}
+//IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
+//    union { __m128 vec; float val[4]; } h;
+//    __m128 a[4];
+//    for (int i = 0; i < 4; ++i) a[i] = _mm_add_ps(_mm256_castps256_ps128(accm[i]), _mm256_extractf128_ps(accm[i], 1));
+//    h.vec = hsum_float_4x4(a);
+//    for (int iy = 0; iy < 4; ++iy) info.store(ix, iy, h.val[iy]);
+//    for (int i = 0; i < 4; ++i) a[i] = _mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1));
+//    h.vec = hsum_float_4x4(a);
+//    for (int iy = 0; iy < 4; ++iy) info.store(ix, iy+4, h.val[iy]);
+//}
+#endif
 
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
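Note on PATCH 4: whether the 8x8 transpose actually wins is evidently microarchitecture dependent; the comment above reports that on a Ryzen-5975WX (AVX2) the plain per-row hsum_float_8 loop is fastest, which is why store_8 is now split by HAVE_FANCY_SIMD. A crude standalone harness to probe this on other machines is sketched below; hsum_float_8x8 is copied from PATCH 1, hsum_float_8 is a functional equivalent, and the timing loop, iteration count and feedback trick are illustrative only (compile with e.g. -O2 -mavx2):

#include <immintrin.h>
#include <chrono>
#include <cstdio>

static inline float hsum_float_8(__m256 x) {
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1));
    s = _mm_add_ps(s, _mm_movehl_ps(s, s));
    s = _mm_add_ss(s, _mm_movehdup_ps(s));
    return _mm_cvtss_f32(s);
}
static inline __m256 hsum_float_8x8(__m256 * accm) {
    for (int i = 0; i < 4; ++i) {
        accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
                                  _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
    }
    for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
    return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}

int main() {
    __m256 acc[8];
    for (int i = 0; i < 8; ++i) acc[i] = _mm256_set1_ps(1.0f + 0.125f*i);
    const int iters = 1 << 22;
    double sink = 0;

    auto t0 = std::chrono::steady_clock::now();
    for (int it = 0; it < iters; ++it) {
        float out[8];
        for (int i = 0; i < 8; ++i) out[i] = hsum_float_8(acc[i]);
        // feed one result back so the reductions cannot be hoisted out of the loop
        acc[0] = _mm256_add_ps(acc[0], _mm256_set1_ps(out[7]*1e-30f));
        sink += out[0];
    }
    auto t1 = std::chrono::steady_clock::now();
    for (int it = 0; it < iters; ++it) {
        __m256 tmp[8];
        for (int i = 0; i < 8; ++i) tmp[i] = acc[i];   // hsum_float_8x8 clobbers its input
        float out[8];
        _mm256_storeu_ps(out, hsum_float_8x8(tmp));
        acc[0] = _mm256_add_ps(acc[0], _mm256_set1_ps(out[7]*1e-30f));
        sink += out[0];
    }
    auto t2 = std::chrono::steady_clock::now();

    printf("8x hsum_float_8: %.1f ms, hsum_float_8x8: %.1f ms (sink %g)\n",
           1e3*std::chrono::duration<double>(t1 - t0).count(),
           1e3*std::chrono::duration<double>(t2 - t1).count(), sink);
    return 0;
}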