From 4fd1e56c3f56c1f1a2ff9382e1a2f92e7827b00c Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 5 Sep 2024 21:05:00 +0000 Subject: [PATCH] Chore: Clean-up `substract_bf16x32_genoa` Relates to #160 --- include/simsimd/spatial.h | 51 +++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/include/simsimd/spatial.h b/include/simsimd/spatial.h index 2de78530..6ec48b3b 100644 --- a/include/simsimd/spatial.h +++ b/include/simsimd/spatial.h @@ -1128,30 +1128,33 @@ SIMSIMD_INTERNAL __m512i simsimd_substract_bf16x32_genoa(__m512i a_i16, __m512i b.ivec = b_i16; // Let's perform the subtraction with single-precision, while the dot-product with half-precision. - // For that we need to perform a couple of casts - each is a bitshift. To convert `bf16` to `f32`, - // expand it to 32-bit integers, then shift the bits by 16 to the left. Then subtract as floats, - // and shift back. During expansion, we will double the space, and should use separate registers - // for top and bottom halves. - a_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(a_i16)), 16)); - b_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(b_i16)), 16)); - - // Some compilers don't have `_mm512_extracti32x8_epi32`, so we need to use `_mm512_extracti64x4_epi64` - a_f32_top.fvec = - _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a_i16, 1)), 16)); - b_f32_top.fvec = - _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(b_i16, 1)), 16)); - - // Subtract in single precision - d_f32_top.fvec = _mm512_sub_ps(a_f32_top.fvec, b_f32_top.fvec); - d_f32_bot.fvec = _mm512_sub_ps(a_f32_bot.fvec, b_f32_bot.fvec); - - // Now, let's populate one ZMM register with the top 16 bits of every 32-bit float, - // in the "top", followed by the top parts of the "bottom" floats. Instead of using multple - // shifts and blends, we can achieve that with cheap `_mm512_mask_shuffle_epi8`, or a more - // expensive `_mm512_permutex2var_epi16`. - d.ivec = _mm512_castsi256_si512(_mm512_cvtepi32_epi16(_mm512_srli_epi32(_mm512_castps_si512(d_f32_bot.fvec), 16))); - d.ivec = _mm512_inserti64x4(d.ivec, - _mm512_cvtepi32_epi16(_mm512_srli_epi32(_mm512_castps_si512(d_f32_top.fvec), 16)), 1); + // + // There are several approaches to achieve this this. The first one is: + // + // Perform a couple of casts - each is a bitshift. To convert `bf16` to `f32`, + // expand it to 32-bit integers, then shift the bits by 16 to the left. + // Then subtract as floats, and shift back. During expansion, we will double the space, + // and should use separate registers for top and bottom halves. + // Some compilers don't have `_mm512_extracti32x8_epi32`, so we use `_mm512_extracti64x4_epi64`: + // + // a_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32( + // _mm512_cvtepu16_epi32(_mm512_castsi512_si256(a_i16)), 16)); + // b_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32( + // _mm512_cvtepu16_epi32(_mm512_castsi512_si256(b_i16)), 16)); + // a_f32_top.fvec =_mm512_castsi512_ps( + // _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a_i16, 1)), 16)); + // b_f32_top.fvec =_mm512_castsi512_ps( + // _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(b_i16, 1)), 16)); + // d_f32_top.fvec = _mm512_sub_ps(a_f32_top.fvec, b_f32_top.fvec); + // d_f32_bot.fvec = _mm512_sub_ps(a_f32_bot.fvec, b_f32_bot.fvec); + // d.ivec = _mm512_castsi256_si512(_mm512_cvtepi32_epi16( + // _mm512_srli_epi32(_mm512_castps_si512(d_f32_bot.fvec), 16))); + // d.ivec = _mm512_inserti64x4(d.ivec, _mm512_cvtepi32_epi16( + // _mm512_srli_epi32(_mm512_castps_si512(d_f32_top.fvec), 16)), 1); + // + // Instead of using multple shifts and an insertion, we can achieve similar result with fewer expensive + // calls to `_mm512_permutex2var_epi16`, or a cheap `_mm512_mask_shuffle_epi8` and blend: + return d.ivec; }