Skip to content

Commit

Permalink
f
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Dec 10, 2024
1 parent 8ba8efe commit ca25494
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions src/layer/x86/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2797,7 +2797,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m128i _pp = float2int8_avx(_p0, _p1);

__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15);
_pp = _mm_shuffle_epi8(_pp, _si);

#if __AVX2__
Expand Down Expand Up @@ -2889,7 +2889,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m128i _pp = float2int8_avx(_t0, _t1);

__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15);
_pp = _mm_shuffle_epi8(_pp, _si);

#if __AVX2__
Expand Down Expand Up @@ -3002,7 +3002,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m128i _pp = float2int8_avx(_p0, _p1);

__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15);
_pp = _mm_shuffle_epi8(_pp, _si);

#if __AVX2__
Expand Down Expand Up @@ -4778,33 +4778,29 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
#if __AVX2__
__m256i _t0 = combine4x2_epi32(_pp0, _pp2);
__m256i _t1 = combine4x2_epi32(_pp1, _pp3);

__m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1);
__m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1);
_t0 = _mm256_unpacklo_epi32(_t2, _t3);
_t1 = _mm256_unpackhi_epi32(_t2, _t3);

_t0 = _mm256_permute4x64_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0));
_t1 = _mm256_permute4x64_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0));

_mm256_storeu_si256((__m256i*)pp, _t0);
_mm256_storeu_si256((__m256i*)(pp + 32), _t1);
pp += 64;
#else
__m128i _t0 = _mm_unpacklo_epi16(_pp0, _pp1);
__m128i _t1 = _mm_unpackhi_epi16(_pp0, _pp1);
__m128i _t2 = _mm_unpacklo_epi16(_pp2, _pp3);
__m128i _t3 = _mm_unpackhi_epi16(_pp2, _pp3);
_pp0 = _mm_unpacklo_epi16(_t0, _t1);
_pp1 = _mm_unpackhi_epi16(_t0, _t1);
_pp2 = _mm_unpacklo_epi16(_t2, _t3);
_pp3 = _mm_unpackhi_epi16(_t2, _t3);

__m256i _t4 = combine4x2_epi32(_pp0, _pp1);
__m256i _t5 = combine4x2_epi32(_pp2, _pp3);

_mm256_storeu_si256((__m256i*)pp, _t4);
_mm256_storeu_si256((__m256i*)pp1, _t5);
__m128i _tt0 = _mm_unpacklo_epi16(_pp0, _pp1);
__m128i _tt1 = _mm_unpackhi_epi16(_pp0, _pp1);
__m128i _tt2 = _mm_unpacklo_epi16(_pp2, _pp3);
__m128i _tt3 = _mm_unpackhi_epi16(_pp2, _pp3);
_pp0 = _mm_unpacklo_epi32(_tt0, _tt1);
_pp1 = _mm_unpackhi_epi32(_tt0, _tt1);
_pp2 = _mm_unpacklo_epi32(_tt2, _tt3);
_pp3 = _mm_unpackhi_epi32(_tt2, _tt3);
__m256i _t0 = combine4x2_epi32(_pp0, _pp1);
__m256i _t1 = combine4x2_epi32(_pp2, _pp3);
_mm256_storeu_si256((__m256i*)pp, _t0);
_mm256_storeu_si256((__m256i*)pp1, _t1);
pp += 32;
pp1 += 32;
#endif
Expand Down Expand Up @@ -6482,7 +6478,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i

__m128i _pp = float2int8_avx(_p0, _p1);

__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15);
_pp = _mm_shuffle_epi8(_pp, _si);

_mm_storeu_si128((__m128i*)pp, _pp);
Expand Down Expand Up @@ -6548,7 +6544,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i

__m128i _si = _mm_setr_epi8(0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15);
_pp = _mm_shuffle_epi8(_pp, _si);
#else // __AVX__
#else // __AVX__
__m128 _p0 = _mm_loadu_ps(p0);
__m128 _p1 = _mm_loadu_ps(p0 + 4);
__m128 _p2 = _mm_loadu_ps(p0 + B_hstep * 4);
Expand Down Expand Up @@ -6581,7 +6577,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i
_p = _mm256_mul_ps(_p, _scale);

int64_t v = float2int8_avx(_p);
#else // __AVX__
#else // __AVX__
_p0 = _mm_mul_ps(_p0, _scale);
_p1 = _mm_mul_ps(_p1, _scale);

Expand Down Expand Up @@ -6651,10 +6647,10 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i
__m128i _pp = float2int8_avx(_p0, _p1);

#if __AVX2__
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15);
_pp = _mm_shuffle_epi8(_pp, _si);
#endif
#else // __AVX__
#else // __AVX__
__m128 _p0 = _mm_setr_ps(p0[0], p0[1], p0[B_hstep], p0[B_hstep + 1]);
__m128 _p1 = _mm_setr_ps(p0[B_hstep * 2], p0[B_hstep * 2 + 1], p0[B_hstep * 3], p0[B_hstep * 3 + 1]);
__m128 _p2 = _mm_setr_ps(p0[B_hstep * 4], p0[B_hstep * 4 + 1], p0[B_hstep * 5], p0[B_hstep * 5 + 1]);
Expand Down Expand Up @@ -6687,7 +6683,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i
_p = _mm256_mul_ps(_p, _scale);

int64_t v = float2int8_avx(_p);
#else // __AVX__
#else // __AVX__
__m128 _p0 = _mm_setr_ps(p0[0], p0[B_hstep], p0[B_hstep * 2], p0[B_hstep * 3]);
__m128 _p1 = _mm_setr_ps(p0[B_hstep * 4], p0[B_hstep * 5], p0[B_hstep * 6], p0[B_hstep * 7]);

Expand Down Expand Up @@ -6864,7 +6860,7 @@ static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, i

__m128i _pp = float2int8_avx(_p0, _p1);

__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15);
__m128i _si = _mm_setr_epi8(0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7 ,15);
_pp = _mm_shuffle_epi8(_pp, _si);

_mm_storel_pd((double*)pp, _mm_castsi128_pd(_pp));
Expand Down Expand Up @@ -7715,12 +7711,16 @@ static void transpose_pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int
__m128i _tt1 = _mm_unpackhi_epi16(_pp0, _pp1);
__m128i _tt2 = _mm_unpacklo_epi16(_pp2, _pp3);
__m128i _tt3 = _mm_unpackhi_epi16(_pp2, _pp3);
_pp0 = _mm_unpacklo_epi16(_tt0, _tt1);
_pp1 = _mm_unpackhi_epi16(_tt0, _tt1);
_pp2 = _mm_unpacklo_epi16(_tt2, _tt3);
_pp3 = _mm_unpackhi_epi16(_tt2, _tt3);
__m256i _t0 = combine4x2_epi32(_pp0, _pp1);
__m256i _t1 = combine4x2_epi32(_pp2, _pp3);
_pp0 = _mm_unpacklo_epi32(_tt0, _tt1);
_pp1 = _mm_unpackhi_epi32(_tt0, _tt1);
_pp2 = _mm_unpacklo_epi32(_tt2, _tt3);
_pp3 = _mm_unpackhi_epi32(_tt2, _tt3);
_tt0 = _mm_unpacklo_epi64(_pp0, _pp2);
_tt1 = _mm_unpackhi_epi64(_pp0, _pp2);
_tt2 = _mm_unpacklo_epi64(_pp1, _pp3);
_tt3 = _mm_unpackhi_epi64(_pp1, _pp3);
__m256i _t0 = combine4x2_epi32(_tt0, _tt1);
__m256i _t1 = combine4x2_epi32(_tt2, _tt3);
#endif
_mm256_storeu_si256((__m256i*)pp, _t0);
_mm256_storeu_si256((__m256i*)(pp + 32), _t1);
Expand Down

0 comments on commit ca25494

Please sign in to comment.