Skip to content

Commit

Permalink
Chore: Clean-up substract_bf16x32_genoa
Browse files Browse the repository at this point in the history
Relates to #160
  • Loading branch information
ashvardanian committed Sep 5, 2024
1 parent 61806d8 commit 4fd1e56
Showing 1 changed file with 27 additions and 24 deletions.
51 changes: 27 additions & 24 deletions include/simsimd/spatial.h
Original file line number Diff line number Diff line change
Expand Up @@ -1128,30 +1128,33 @@ SIMSIMD_INTERNAL __m512i simsimd_substract_bf16x32_genoa(__m512i a_i16, __m512i
b.ivec = b_i16;

// Let's perform the subtraction with single-precision, while the dot-product with half-precision.
// For that we need to perform a couple of casts - each is a bitshift. To convert `bf16` to `f32`,
// expand it to 32-bit integers, then shift the bits by 16 to the left. Then subtract as floats,
// and shift back. During expansion, we will double the space, and should use separate registers
// for top and bottom halves.
a_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(a_i16)), 16));
b_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(b_i16)), 16));

// Some compilers don't have `_mm512_extracti32x8_epi32`, so we need to use `_mm512_extracti64x4_epi64`
a_f32_top.fvec =
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a_i16, 1)), 16));
b_f32_top.fvec =
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(b_i16, 1)), 16));

// Subtract in single precision
d_f32_top.fvec = _mm512_sub_ps(a_f32_top.fvec, b_f32_top.fvec);
d_f32_bot.fvec = _mm512_sub_ps(a_f32_bot.fvec, b_f32_bot.fvec);

// Now, let's populate one ZMM register with the top 16 bits of every 32-bit float,
// in the "top", followed by the top parts of the "bottom" floats. Instead of using multiple
// shifts and blends, we can achieve that with cheap `_mm512_mask_shuffle_epi8`, or a more
// expensive `_mm512_permutex2var_epi16`.
d.ivec = _mm512_castsi256_si512(_mm512_cvtepi32_epi16(_mm512_srli_epi32(_mm512_castps_si512(d_f32_bot.fvec), 16)));
d.ivec = _mm512_inserti64x4(d.ivec,
_mm512_cvtepi32_epi16(_mm512_srli_epi32(_mm512_castps_si512(d_f32_top.fvec), 16)), 1);
//
// There are several approaches to achieve this. The first one is:
//
// Perform a couple of casts - each is a bitshift. To convert `bf16` to `f32`,
// expand it to 32-bit integers, then shift the bits by 16 to the left.
// Then subtract as floats, and shift back. During expansion, we will double the space,
// and should use separate registers for top and bottom halves.
// Some compilers don't have `_mm512_extracti32x8_epi32`, so we use `_mm512_extracti64x4_epi64`:
//
// a_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(
// _mm512_cvtepu16_epi32(_mm512_castsi512_si256(a_i16)), 16));
// b_f32_bot.fvec = _mm512_castsi512_ps(_mm512_slli_epi32(
// _mm512_cvtepu16_epi32(_mm512_castsi512_si256(b_i16)), 16));
// a_f32_top.fvec =_mm512_castsi512_ps(
// _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a_i16, 1)), 16));
// b_f32_top.fvec =_mm512_castsi512_ps(
// _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(b_i16, 1)), 16));
// d_f32_top.fvec = _mm512_sub_ps(a_f32_top.fvec, b_f32_top.fvec);
// d_f32_bot.fvec = _mm512_sub_ps(a_f32_bot.fvec, b_f32_bot.fvec);
// d.ivec = _mm512_castsi256_si512(_mm512_cvtepi32_epi16(
// _mm512_srli_epi32(_mm512_castps_si512(d_f32_bot.fvec), 16)));
// d.ivec = _mm512_inserti64x4(d.ivec, _mm512_cvtepi32_epi16(
// _mm512_srli_epi32(_mm512_castps_si512(d_f32_top.fvec), 16)), 1);
//
// Instead of using multiple shifts and an insertion, we can achieve a similar result with fewer, but more
// expensive, calls to `_mm512_permutex2var_epi16`, or with a cheap `_mm512_mask_shuffle_epi8` and a blend:

return d.ivec;
}

Expand Down

0 comments on commit 4fd1e56

Please sign in to comment.