diff --git a/ggml-quants.c b/ggml-quants.c index 2b5ae8c660e89c..b8ac16177f5e65 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -943,21 +943,15 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) { // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { - const __m128i tmp = __lsx_vld((const __m128i *)rsi, 0); - __m128i tmp2 = __lsx_vsrli_h(tmp, 4); - __m128i lowMask = __lsx_vreplgr2vr_b(0xf); - __m128i tmpl = __lsx_vand_v(tmp, lowMask); - __m128i tmph = __lsx_vand_v(tmp2, lowMask); - return MM256_SET_M128I(tmph, tmpl); + const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); + __m128i hi = __lsx_vsrli_h(lo, 4); + return __lasx_xvandi_b(MM256_SET_M128I(hi, lo), 0xf); } // add int16_t pairwise and return as float vector static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = __lasx_xvreplgr2vr_h(1); - - __m256i zero256 = __lasx_xvldi(0); - const __m256i tmp1 = __lasx_xvmaddwev_w_h(zero256, ones, x); - const __m256i summed_pairs = __lasx_xvmaddwod_w_h(tmp1, ones, x); + __m256i v = __lasx_xvpackod_h(x, x); + __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); return __lasx_xvffint_s_w(summed_pairs); }