diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h index 752d3b16802..0bc9cda343a 100644 --- a/src/layer/x86/lstm_int8.h +++ b/src/layer/x86/lstm_int8.h @@ -1676,11 +1676,14 @@ static void lstm_dynamic_quantize_scale2int8(const float* ptr, int size, float s __m128 _p = _mm_loadu_ps(ptr); _p = _mm_mul_ps(_p, _scale); *(int32_t*)outptr = float2int8_sse(_p); +#ifndef _MSC_VER + // vs2019 crash on 128bit vnni :L --- nihui #if __AVXVNNI__ || __AVX512VNNI__ outptr[0] += 127; outptr[1] += 127; outptr[2] += 127; outptr[3] += 127; +#endif #endif ptr += 4; outptr += 4; @@ -1781,10 +1784,17 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d } for (; i + 3 < num_output; i += 4) { +#ifdef _MSC_VER + hs[0] = 0; + hs[1] = 0; + hs[2] = 0; + hs[3] = 0; +#else hs[0] = 127; hs[1] = 127; hs[2] = 127; hs[3] = 127; +#endif hs += 4; } for (; i < num_output; i++) @@ -1899,6 +1909,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m512i _xi = _mm512_set1_epi32(((const int*)(x + i))[0]); __m512i _w = _mm512_loadu_si512((const __m512i*)kptr); +#ifdef _MSC_VER + _xi = _mm512_add_epi32(_xi, _mm512_set1_epi8(127)); +#endif _lstm_IFOGx0 = _mm512_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); kptr += 64; @@ -2083,6 +2096,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m512i _h_cont = _mm512_set1_epi32(((const int*)(hs + i))[0]); __m512i _w = _mm512_loadu_si512((const __m512i*)kptr); +#ifdef _MSC_VER + _h_cont = _mm512_add_epi32(_h_cont, _mm512_set1_epi8(127)); +#endif _lstm_IFOGh0 = _mm512_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); kptr += 64; @@ -2295,6 +2311,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); +#ifdef _MSC_VER + _xi = _mm256_add_epi32(_xi, _mm256_set1_epi8(127)); +#endif _lstm_IFOGx0 = _mm256_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); kptr += 32; @@ -2453,6 +2472,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); +#ifdef _MSC_VER + _h_cont = _mm256_add_epi32(_h_cont, _mm256_set1_epi8(127)); +#endif _lstm_IFOGh0 = _mm256_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); kptr += 32; @@ -2652,6 +2674,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); __m128i _w = _mm_loadu_si128((const __m128i*)kptr); +#ifdef _MSC_VER + _xi = _mm_add_epi32(_xi, _mm_set1_epi8(127)); +#endif _lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); kptr += 16; @@ -2858,6 +2883,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); __m128i _w = _mm_loadu_si128((const __m128i*)kptr); +#ifdef _MSC_VER + _h_cont = _mm_add_epi32(_h_cont, _mm_set1_epi8(127)); +#endif _lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); kptr += 16;