Merge pull request #567 from robertknight/simplify-optimize-x64-zip
Simplify and optimize i8/i16 interleaving SIMD ops for x64
robertknight authored Feb 1, 2025
2 parents 387277d + 0e3fbb5 commit 36d7986
Showing 2 changed files with 86 additions and 80 deletions.
121 changes: 51 additions & 70 deletions rten-simd/src/arch/x86_64.rs
@@ -1,16 +1,15 @@
use std::arch::x86_64::{
__m128i, __m256, __m256i, _mm256_add_epi32, _mm256_add_ps, _mm256_and_si256, _mm256_andnot_ps,
_mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi128_si256,
_mm256_castsi256_ps, _mm256_castsi256_si128, _mm256_cmp_ps, _mm256_cmpeq_epi32,
_mm256_cmpgt_epi32, _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps,
_mm256_extractf128_ps, _mm256_extractf128_si256, _mm256_fmadd_ps, _mm256_insertf128_si256,
_mm256_loadu_ps, _mm256_loadu_si256, _mm256_max_epi32, _mm256_max_ps, _mm256_min_epi32,
_mm256_min_ps, _mm256_mul_ps, _mm256_mullo_epi32, _mm256_or_si256, _mm256_set1_epi32,
_mm256_set1_ps, _mm256_setr_epi32, _mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256,
_mm256_sub_epi32, _mm256_sub_ps, _mm256_unpackhi_epi16, _mm256_unpackhi_epi8,
_mm256_unpacklo_epi16, _mm256_unpacklo_epi8, _mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32,
_mm_loadl_epi64, _mm_movehl_ps, _mm_prefetch, _mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ,
_CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
_mm256_blendv_epi8, _mm256_blendv_ps, _mm256_castps256_ps128, _mm256_castsi256_ps,
_mm256_castsi256_si128, _mm256_cmp_ps, _mm256_cmpeq_epi32, _mm256_cmpgt_epi32,
_mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_div_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
_mm256_insertf128_si256, _mm256_loadu_ps, _mm256_loadu_si256, _mm256_max_epi32, _mm256_max_ps,
_mm256_min_epi32, _mm256_min_ps, _mm256_mul_ps, _mm256_mullo_epi32, _mm256_or_si256,
_mm256_permute2x128_si256, _mm256_set1_epi32, _mm256_set1_ps, _mm256_setr_epi32,
_mm256_slli_epi32, _mm256_storeu_ps, _mm256_storeu_si256, _mm256_sub_epi32, _mm256_sub_ps,
_mm256_unpackhi_epi16, _mm256_unpackhi_epi8, _mm256_unpacklo_epi16, _mm256_unpacklo_epi8,
_mm256_xor_si256, _mm_add_ps, _mm_cvtss_f32, _mm_loadl_epi64, _mm_movehl_ps, _mm_prefetch,
_mm_shuffle_ps, _CMP_GE_OQ, _CMP_LE_OQ, _CMP_LT_OQ, _MM_HINT_ET0, _MM_HINT_T0,
};
use std::mem::{transmute, MaybeUninit};

@@ -211,38 +210,34 @@ impl SimdInt for __m256i {

#[inline]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
// Interleave from low half of each 128-bit block.
let lo = _mm256_unpacklo_epi8(self, rhs);
// Interleave from high half of each 128-bit block.
let hi = _mm256_unpackhi_epi8(self, rhs);
// Combine elements from low and high half of first 128-bit block in
// `self` and `rhs`.
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm256_unpacklo_epi8(self, rhs); // AB0 AB2
let hi = _mm256_unpackhi_epi8(self, rhs); // AB1 AB3
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1) // AB0 AB1
}

#[inline]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
let lo = _mm256_unpacklo_epi8(self, rhs);
let hi = _mm256_unpackhi_epi8(self, rhs);
let lo_hi = _mm256_castsi128_si256(_mm256_extractf128_si256(lo, 1));
let hi_hi = _mm256_extractf128_si256(hi, 1);
_mm256_insertf128_si256(lo_hi, hi_hi, 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm256_unpacklo_epi8(self, rhs); // AB0 AB2
let hi = _mm256_unpackhi_epi8(self, rhs); // AB1 AB3
_mm256_permute2x128_si256(lo, hi, 0x31) // AB2 AB3
}

#[inline]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
let lo = _mm256_unpacklo_epi16(self, rhs);
let hi = _mm256_unpackhi_epi16(self, rhs);
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm256_unpacklo_epi16(self, rhs); // AB0 AB2
let hi = _mm256_unpackhi_epi16(self, rhs); // AB1 AB3
_mm256_insertf128_si256(lo, _mm256_castsi256_si128(hi), 1) // AB0 AB1
}

#[inline]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
let lo = _mm256_unpacklo_epi16(self, rhs);
let hi = _mm256_unpackhi_epi16(self, rhs);
let lo_hi = _mm256_castsi128_si256(_mm256_extractf128_si256(lo, 1));
let hi_hi = _mm256_extractf128_si256(hi, 1);
_mm256_insertf128_si256(lo_hi, hi_hi, 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm256_unpacklo_epi16(self, rhs); // AB0 AB2
let hi = _mm256_unpackhi_epi16(self, rhs); // AB1 AB3
_mm256_permute2x128_si256(lo, hi, 0x31) // AB2 AB3
}
}
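
For reference, the interleaving these functions implement can be pinned down with a plain scalar model. The sketch below is editorial rather than part of the commit; it assumes the 256-bit i8 case (32 lanes) and uses hypothetical `*_ref` names: `zip_lo` pairs up the low 16 elements of `self` and `rhs`, `zip_hi` the high 16.

// Editorial scalar reference for the 256-bit i8 case; for illustration only.
fn zip_lo_i8_ref(a: [i8; 32], b: [i8; 32]) -> [i8; 32] {
    let mut out = [0i8; 32];
    for i in 0..16 {
        out[2 * i] = a[i]; // element i of `self`
        out[2 * i + 1] = b[i]; // element i of `rhs`
    }
    out
}

fn zip_hi_i8_ref(a: [i8; 32], b: [i8; 32]) -> [i8; 32] {
    let mut out = [0i8; 32];
    for i in 0..16 {
        out[2 * i] = a[16 + i];
        out[2 * i + 1] = b[16 + i];
    }
    out
}

In those terms, `_mm256_unpacklo_epi8`/`_mm256_unpackhi_epi8` operate per 128-bit lane and yield AB0 AB2 and AB1 AB3, so `zip_lo` only has to place AB1 after AB0 (one `_mm256_insertf128_si256`), and `zip_hi` only has to gather the two high lanes (one `_mm256_permute2x128_si256` with selector 0x31, which picks the high 128-bit lane of each source), replacing the earlier extract/cast/insert sequence.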

@@ -414,9 +409,11 @@ use std::arch::x86_64::{
_mm512_cvttps_epi32, _mm512_div_ps, _mm512_fmadd_ps, _mm512_loadu_ps, _mm512_loadu_si512,
_mm512_mask_blend_epi32, _mm512_mask_blend_ps, _mm512_mask_i32gather_ps, _mm512_max_epi32,
_mm512_max_ps, _mm512_min_epi32, _mm512_min_ps, _mm512_mul_ps, _mm512_mullo_epi32,
_mm512_reduce_add_ps, _mm512_set1_epi32, _mm512_set1_ps, _mm512_setzero_si512,
_mm512_sllv_epi32, _mm512_storeu_epi32, _mm512_storeu_ps, _mm512_sub_epi32, _mm512_sub_ps,
_mm512_xor_si512, _MM_CMPINT_EQ, _MM_CMPINT_LE, _MM_CMPINT_LT,
_mm512_permutex2var_epi32, _mm512_reduce_add_ps, _mm512_set1_epi32, _mm512_set1_ps,
_mm512_setr_epi32, _mm512_setzero_si512, _mm512_sllv_epi32, _mm512_storeu_epi32,
_mm512_storeu_ps, _mm512_sub_epi32, _mm512_sub_ps, _mm512_unpackhi_epi16, _mm512_unpackhi_epi8,
_mm512_unpacklo_epi16, _mm512_unpacklo_epi8, _mm512_xor_si512, _MM_CMPINT_EQ, _MM_CMPINT_LE,
_MM_CMPINT_LT,
};

#[cfg(feature = "avx512")]
@@ -593,57 +590,41 @@ impl SimdInt for __m512i {
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_lo_i8(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_castsi512_si256, _mm512_inserti64x4,
};
let lo_self = _mm512_castsi512_si256(self);
let lo_rhs = _mm512_castsi512_si256(rhs);
let lo = lo_self.zip_lo_i8(lo_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = lo_self.zip_hi_i8(lo_rhs);
_mm512_inserti64x4(lo, hi, 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm512_unpacklo_epi8(self, rhs); // AB0 AB2 AB4 AB6
let hi = _mm512_unpackhi_epi8(self, rhs); // AB1 AB3 AB5 AB7
let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
_mm512_permutex2var_epi32(lo, idx, hi) // AB0 AB1 AB2 AB3
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_hi_i8(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_extracti64x4_epi64, _mm512_inserti64x4,
};
let hi_self = _mm512_extracti64x4_epi64(self, 1);
let hi_rhs = _mm512_extracti64x4_epi64(rhs, 1);
let lo = hi_self.zip_lo_i8(hi_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = hi_self.zip_hi_i8(hi_rhs);
_mm512_inserti64x4(lo, hi, 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm512_unpacklo_epi8(self, rhs); // AB0 AB2 AB4 AB6
let hi = _mm512_unpackhi_epi8(self, rhs); // AB1 AB3 AB5 AB7
let idx = _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
_mm512_permutex2var_epi32(lo, idx, hi) // AB4 AB5 AB6 AB7
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_lo_i16(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_castsi512_si256, _mm512_inserti64x4,
};
let lo_self = _mm512_castsi512_si256(self);
let lo_rhs = _mm512_castsi512_si256(rhs);
let lo = lo_self.zip_lo_i16(lo_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = lo_self.zip_hi_i16(lo_rhs);
_mm512_inserti64x4(lo, hi, 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm512_unpacklo_epi16(self, rhs); // AB0 AB2 AB4 AB6
let hi = _mm512_unpackhi_epi16(self, rhs); // AB1 AB3 AB5 AB7
let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
_mm512_permutex2var_epi32(lo, idx, hi) // AB0 AB1 AB2 AB3
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn zip_hi_i16(self, rhs: Self) -> Self {
use core::arch::x86_64::{
_mm512_castsi256_si512, _mm512_extracti64x4_epi64, _mm512_inserti64x4,
};
let hi_self = _mm512_extracti64x4_epi64(self, 1);
let hi_rhs = _mm512_extracti64x4_epi64(rhs, 1);
let lo = hi_self.zip_lo_i16(hi_rhs);
let lo = _mm512_castsi256_si512(lo);
let hi = hi_self.zip_hi_i16(hi_rhs);
_mm512_inserti64x4(lo, hi, 1)
// AB{N} = Interleaved Nth 64-bit block.
let lo = _mm512_unpacklo_epi16(self, rhs); // AB0 AB2 AB4 AB6
let hi = _mm512_unpackhi_epi16(self, rhs); // AB1 AB3 AB5 AB7
let idx = _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
_mm512_permutex2var_epi32(lo, idx, hi) // AB4 AB5 AB6 AB7
}
}
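
The `idx` constants above follow one pattern: `_mm512_permutex2var_epi32` reads dword indices 0..=15 from `lo` and 16..=31 from `hi`, and each interleaved block ABn occupies four dwords. The helper below is an editorial sketch (the name and function are hypothetical, not part of the commit) that rebuilds both index vectors from that rule.

// Editorial sketch: derive the dword indices used with _mm512_permutex2var_epi32.
// `lo` holds blocks AB0 AB2 AB4 AB6 and `hi` holds AB1 AB3 AB5 AB7; indices
// 0..=15 address `lo` and 16..=31 address `hi`.
fn zip_permute_indices(hi_half: bool) -> [i32; 16] {
    // zip_lo gathers AB0..AB3, zip_hi gathers AB4..AB7.
    let first_block: usize = if hi_half { 4 } else { 0 };
    let mut idx = [0i32; 16];
    for (i, block) in (first_block..first_block + 4).enumerate() {
        // Even blocks sit in `lo` at position block / 2; odd blocks sit in `hi`.
        let src_offset = if block % 2 == 0 { 0 } else { 16 };
        for dword in 0..4 {
            idx[i * 4 + dword] = src_offset + (block / 2) as i32 * 4 + dword as i32;
        }
    }
    idx
}

// zip_permute_indices(false) == [0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23]
// zip_permute_indices(true)  == [8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31]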

45 changes: 35 additions & 10 deletions rten-simd/src/vec.rs
@@ -391,7 +391,7 @@ pub mod tests {
}

#[test]
fn test_zip_lo_hi_i8() {
fn test_zip_lo_i8() {
let a_start = 0i8;
// `b_start` is not i8 to avoid overflow when `LEN` is 64.
let b_start = LEN * 4;
@@ -402,20 +402,33 @@
let b_vec = unsafe { SimdVec::load(b.as_ptr() as *const i32) };

let i8_lo = unsafe { a_vec.zip_lo_i8(b_vec) };
let i8_hi = unsafe { a_vec.zip_hi_i8(b_vec) };

let mut actual_i8_lo = [0i8; LEN * 4];
unsafe { i8_lo.store(actual_i8_lo.as_mut_ptr() as *mut i32) }

let mut actual_i8_hi = [0i8; LEN * 4];
unsafe { i8_hi.store(actual_i8_hi.as_mut_ptr() as *mut i32) }

let expected_i8_lo: Vec<_> = (a_start..)
.zip(b_start..)
.flat_map(|(a, b)| [a, b as i8])
.take(LEN * 4)
.collect();
assert_eq!(actual_i8_lo, expected_i8_lo.as_slice());
}

#[test]
fn test_zip_hi_i8() {
let a_start = 0i8;
// `b_start` is not i8 to avoid overflow when `LEN` is 64.
let b_start = LEN * 4;
let a: Vec<_> = (a_start..).take(LEN * 4).collect();
let b: Vec<_> = (b_start..).map(|x| x as i8).take(LEN * 4).collect();

let a_vec = unsafe { SimdVec::load(a.as_ptr() as *const i32) };
let b_vec = unsafe { SimdVec::load(b.as_ptr() as *const i32) };

let i8_hi = unsafe { a_vec.zip_hi_i8(b_vec) };

let mut actual_i8_hi = [0i8; LEN * 4];
unsafe { i8_hi.store(actual_i8_hi.as_mut_ptr() as *mut i32) }

let expected_i8_hi: Vec<_> = (a_start + LEN as i8 * 2..)
.zip(b_start + LEN * 2..)
@@ -426,7 +439,7 @@
}
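
As a concrete check of the offset arithmetic in `test_zip_hi_i8` (an editorial example assuming the AVX2 case, where `LEN` is 8 and a register holds 32 i8 lanes): `a` covers 0..32 and `b` covers 32..64, so the upper halves that `zip_hi_i8` interleaves begin at 16 and 48.

// Editorial example of the offsets above, assuming LEN = 8 (the AVX2 case).
fn zip_hi_i8_offsets_example() {
    const LEN: usize = 8;
    let a_start = 0i8;
    let b_start = LEN * 4; // 32
    assert_eq!(a_start + LEN as i8 * 2, 16); // first upper-half element of `a`
    assert_eq!(b_start + LEN * 2, 48); // first upper-half element of `b`
    // zip_hi_i8 is therefore expected to start with the pairs (16, 48), (17, 49), ...
}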

#[test]
fn test_zip_lo_hi_i16() {
fn test_zip_lo_i16() {
let a_start = 0i16;
let b_start = LEN as i16 * 2;
let a: Vec<_> = (a_start..).take(LEN * 2).collect();
@@ -436,20 +449,32 @@
let b_vec = unsafe { SimdVec::load(b.as_ptr() as *const i32) };

let i16_lo = unsafe { a_vec.zip_lo_i16(b_vec) };
let i16_hi = unsafe { a_vec.zip_hi_i16(b_vec) };

let mut actual_i16_lo = [0i16; LEN * 2];
unsafe { i16_lo.store(actual_i16_lo.as_mut_ptr() as *mut i32) }

let mut actual_i16_hi = [0i16; LEN * 2];
unsafe { i16_hi.store(actual_i16_hi.as_mut_ptr() as *mut i32) }

let expected_i16_lo: Vec<_> = (a_start..)
.zip(b_start..)
.flat_map(|(a, b)| [a, b])
.take(LEN * 2)
.collect();
assert_eq!(actual_i16_lo, expected_i16_lo.as_slice());
}

#[test]
fn test_zip_hi_i16() {
let a_start = 0i16;
let b_start = LEN as i16 * 2;
let a: Vec<_> = (a_start..).take(LEN * 2).collect();
let b: Vec<_> = (b_start..).take(LEN * 2).collect();

let a_vec = unsafe { SimdVec::load(a.as_ptr() as *const i32) };
let b_vec = unsafe { SimdVec::load(b.as_ptr() as *const i32) };

let i16_hi = unsafe { a_vec.zip_hi_i16(b_vec) };

let mut actual_i16_hi = [0i16; LEN * 2];
unsafe { i16_hi.store(actual_i16_hi.as_mut_ptr() as *mut i32) }

let expected_i16_hi: Vec<_> = (a_start + LEN as i16..)
.zip(b_start + LEN as i16..)
