Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add loongarch lsx and lasx optimize code #6454

Merged
merged 8 commits into from
May 20, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
opt bytes_from_nibbles_32 and sum_i16_pairs_float
  • Loading branch information
MQ-mengqing authored and junchao-loongson committed May 18, 2024
commit e8ed67052adae06d22aac3ea6cdf9fc05a33f49f
16 changes: 5 additions & 11 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -653,21 +653,15 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
//
// Unpacks the 16 bytes at rsi into 32 four-bit values, one per output byte.
// NOTE(review): this listing is a diff rendered without +/- markers, so two
// versions of the body appear back to back. The first (ending at the first
// return) is presumably the removed code and the second the replacement;
// compiled literally, everything after the first return would be unreachable.
// NOTE(review): MM256_SET_M128I is defined elsewhere; it presumably places its
// first argument in the high 128 bits and its second in the low 128 bits, so
// low nibbles land in the low half and high nibbles in the high half - confirm.
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// --- first version: mask each 128-bit half separately, then combine ---
const __m128i tmp = __lsx_vld((const __m128i *)rsi, 0);
// Shift right by 4 so each byte's high nibble moves into its low-nibble position.
__m128i tmp2 = __lsx_vsrli_h(tmp, 4);
// 0x0f replicated into every byte; used to drop whatever sits above bit 3.
__m128i lowMask = __lsx_vreplgr2vr_b(0xf);
__m128i tmpl = __lsx_vand_v(tmp, lowMask);
__m128i tmph = __lsx_vand_v(tmp2, lowMask);
return MM256_SET_M128I(tmph, tmpl);
// --- second version: combine the halves first, then mask once with 0xf ---
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
// The shift works on 16-bit lanes, so bits from the neighbouring byte leak
// into the upper nibble; the single AND with 0xf below clears them.
__m128i hi = __lsx_vsrli_h(lo, 4);
return __lasx_xvandi_b(MM256_SET_M128I(hi, lo), 0xf);
}

// add int16_t pairwise and return as float vector
//
// NOTE(review): this listing is a diff rendered without +/- markers, so two
// implementations are interleaved. The lines through the first summed_pairs
// are presumably the removed code and the two lines after it the replacement;
// compiled literally, summed_pairs would be declared twice.
static inline __m256 sum_i16_pairs_float(const __m256i x) {
// --- first version: multiply-accumulate against a vector of ones ---
const __m256i ones = __lasx_xvreplgr2vr_h(1);

__m256i zero256 = __lasx_xvldi(0);
// Presumably 0 + 1 * (even-indexed int16 lanes of x), widened to int32 - confirm
// against the LASX intrinsics reference.
const __m256i tmp1 = __lasx_xvmaddwev_w_h(zero256, ones, x);
// Presumably adds 1 * (odd-indexed int16 lanes of x) on top, giving the pair sums.
const __m256i summed_pairs = __lasx_xvmaddwod_w_h(tmp1, ones, x);
// --- second version: no constant vectors needed ---
// Presumably moves the odd-indexed int16 lanes of x into even positions so the
// widening even-lane add below sums each (even, odd) pair - confirm.
__m256i v = __lasx_xvpackod_h(x, x);
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
// Convert the pair sums to the float vector that is returned.
return __lasx_xvffint_s_w(summed_pairs);
}

Expand Down
Loading