Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/refactor noodle masked load (WIP) #216

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
67 changes: 58 additions & 9 deletions src/util/supervector/arch/x86/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,28 @@ really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint
{
SuperVector mask = Ones_vshr(16 -len);
SuperVector v = _mm_loadu_si128((const m128 *)ptr);
return mask & v;
return v & mask;
}

/**
 * Masked unaligned 16-byte load.
 *
 * Byte i of the result is the byte at ptr[i] when bit i of mask is set, and
 * zero otherwise.
 *
 * With HAVE_AVX512 the selection is done by _mm_maskz_loadu_epi8. Without it,
 * all 16 bytes are loaded with _mm_loadu_si128 and the unselected bytes are
 * cleared afterwards, so both build variants return the same value.
 *
 * NOTE(review): the fallback reads the full 16 bytes at ptr regardless of
 * mask; presumably callers guarantee those bytes are readable -- confirm.
 *
 * @param ptr  source address, no alignment requirement
 * @param mask one bit per byte lane; only the low 16 bits are consulted
 * @return the loaded vector with unselected bytes zeroed
 */
template <>
really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
{
#ifdef HAVE_AVX512
    SuperVector<16> v = _mm_maskz_loadu_epi8(mask, (const m128 *)ptr);
    v.print8("v");
    return v;
#else
    DEBUG_PRINTF("mask = %08x\n", mask);
    // Expand every mask bit into a whole 0xff / 0x00 byte so the bitmask can
    // be applied with a plain vector AND.
    uint8_t keep_bytes[16];
    for (unsigned i = 0; i < 16; i++) {
        keep_bytes[i] = ((mask >> i) & 1) ? 0xff : 0x00;
    }
    SuperVector<16> keep = _mm_loadu_si128((const m128 *)keep_bytes);
    SuperVector<16> v = _mm_loadu_si128((const m128 *)ptr);
    return v & keep;
#endif
}

// 16-byte specialisation of findLSB: forwards the caller's mask to
// findAndClearLSB_32. z is handed over by address, so whatever update the
// helper makes to it is visible to the caller.
template<>
really_inline typename SuperVector<16>::comparemask_type SuperVector<16>::findLSB(typename SuperVector<16>::comparemask_type &z)
{
    typename SuperVector<16>::comparemask_type const lsb = findAndClearLSB_32(&z);
    return lsb;
}

template<>
Expand Down Expand Up @@ -1126,22 +1147,35 @@ really_inline SuperVector<32> SuperVector<32>::load(void const *ptr)
// Length-based masked load for 32-byte vectors: performs a full unaligned
// 32-byte load from ptr and ANDs it with the vector returned by
// Ones_vshr(32 - len). Both the mask and the loaded data are dumped through
// print8 before the result is returned.
// NOTE(review): presumably Ones_vshr(32 - len) leaves exactly the low len
// bytes set -- confirm against its definition.
template <>
really_inline SuperVector<32> SuperVector<32>::loadu_maskz(void const *ptr, uint8_t const len)
{
    SuperVector<32> keep = Ones_vshr(32 - len);
    keep.print8("mask");

    const m256 *src = static_cast<const m256 *>(ptr);
    SuperVector<32> loaded = _mm256_loadu_si256(src);
    loaded.print8("v");

    return loaded & keep;
}

/**
 * Masked unaligned 32-byte load.
 *
 * Byte i of the result is the byte at ptr[i] when bit i of mask is set, and
 * zero otherwise.
 *
 * With HAVE_AVX512 the selection is done by _mm256_maskz_loadu_epi8. Without
 * it, all 32 bytes are loaded with _mm256_loadu_si256 and the unselected
 * bytes are cleared afterwards, so both build variants return the same value.
 *
 * NOTE(review): the fallback reads the full 32 bytes at ptr regardless of
 * mask; presumably callers guarantee those bytes are readable -- confirm.
 *
 * @param ptr  source address, no alignment requirement
 * @param mask one bit per byte lane; only the low 32 bits are consulted
 * @return the loaded vector with unselected bytes zeroed
 */
template <>
really_inline SuperVector<32> SuperVector<32>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
{
    DEBUG_PRINTF("mask = %08llx\n", mask);
#ifdef HAVE_AVX512
    SuperVector<32> v = _mm256_maskz_loadu_epi8(mask, (const m256 *)ptr);
    v.print8("v");
    return v;
#else
    // Expand every mask bit into a whole 0xff / 0x00 byte so the bitmask can
    // be applied with a plain vector AND.
    uint8_t keep_bytes[32];
    for (unsigned i = 0; i < 32; i++) {
        keep_bytes[i] = ((mask >> i) & 1) ? 0xff : 0x00;
    }
    SuperVector<32> keep = _mm256_loadu_si256((const m256 *)keep_bytes);
    SuperVector<32> v = _mm256_loadu_si256((const m256 *)ptr);
    v.print8("v");
    return v & keep;
#endif
}

// 32-byte specialisation of findLSB: forwards the caller's mask to
// findAndClearLSB_64. z is handed over by address, so whatever update the
// helper makes to it is visible to the caller.
template<>
really_inline typename SuperVector<32>::comparemask_type SuperVector<32>::findLSB(typename SuperVector<32>::comparemask_type &z)
{
    typename SuperVector<32>::comparemask_type const lsb = findAndClearLSB_64(&z);
    return lsb;
}

template<>
really_inline SuperVector<32> SuperVector<32>::alignr(SuperVector<32> &other, int8_t offset)
{
Expand Down Expand Up @@ -1778,11 +1812,26 @@ really_inline SuperVector<64> SuperVector<64>::loadu_maskz(void const *ptr, uint
{
u64a mask = (~0ULL) >> (64 - len);
DEBUG_PRINTF("mask = %016llx\n", mask);
SuperVector<64> v = _mm512_mask_loadu_epi8(Zeroes().u.v512[0], mask, (const m512 *)ptr);
SuperVector<64> v = _mm512_maskz_loadu_epi8(mask, (const m512 *)ptr);
v.print8("v");
return v;
}

// Bitmask-based masked load for 64-byte vectors: logs the mask, then lets
// _mm512_maskz_loadu_epi8 load the bytes whose mask bit is set and zero the
// rest. The loaded vector is dumped through print8 before being returned.
template <>
really_inline SuperVector<64> SuperVector<64>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
{
    DEBUG_PRINTF("mask = %016llx\n", mask);

    const m512 *src = static_cast<const m512 *>(ptr);
    SuperVector<64> loaded = _mm512_maskz_loadu_epi8(mask, src);
    loaded.print8("v");

    return loaded;
}

// 64-byte specialisation of findLSB: forwards the caller's mask to
// findAndClearLSB_64. z is handed over by address, so whatever update the
// helper makes to it is visible to the caller.
template<>
really_inline typename SuperVector<64>::comparemask_type SuperVector<64>::findLSB(typename SuperVector<64>::comparemask_type &z)
{
    typename SuperVector<64>::comparemask_type const lsb = findAndClearLSB_64(&z);
    return lsb;
}

template<>
template<>
really_inline SuperVector<64> SuperVector<64>::pshufb<true>(SuperVector<64> b)
Expand Down
28 changes: 8 additions & 20 deletions src/util/supervector/supervector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,34 +46,18 @@
#endif
#endif // VS_SIMDE_BACKEND

#include <util/bitutils.h>

// Selects Z_TYPE, the Z_* constants and the *_LOAD_MASK(l) helper macros
// according to which HAVE_SIMD_*_BITS macro is defined.
//
// Both mask forms used below evaluate to a value with the low l bits set:
//   ((~0ULL) >> (Z_BITS - (l)))  -- with Z_BITS == 64 this is undefined for
//                                   l == 0 (shift of a 64-bit value by 64)
//   (((1ULL) << (l)) - 1ULL)     -- undefined for l == 64 for the same reason
#if defined(HAVE_SIMD_512_BITS)
// 512-bit vectors: 64-bit Z_TYPE.
using Z_TYPE = u64a;
#define Z_BITS 64
#define Z_SHIFT 63
#define Z_POSSHIFT 0
#define DOUBLE_LOAD_MASK(l) ((~0ULL) >> (Z_BITS -(l)))
#define SINGLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL)
#elif defined(HAVE_SIMD_256_BITS)
// 256-bit vectors: 32-bit Z_TYPE.
using Z_TYPE = u32;
#define Z_BITS 32
#define Z_SHIFT 31
#define Z_POSSHIFT 0
#define DOUBLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL)
#define SINGLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL)
#elif defined(HAVE_SIMD_128_BITS)
// 128-bit vectors: native (non-SIMDE) ARM32/AArch64 builds use a 64-bit
// Z_TYPE with Z_POSSHIFT 2; every other 128-bit build uses a 32-bit Z_TYPE.
#if !defined(VS_SIMDE_BACKEND) && (defined(ARCH_ARM32) || defined(ARCH_AARCH64))
using Z_TYPE = u64a;
#define Z_BITS 64
#define Z_POSSHIFT 2
#define DOUBLE_LOAD_MASK(l) ((~0ULL) >> (Z_BITS - (l)))
#else
using Z_TYPE = u32;
#define Z_BITS 32
#define Z_POSSHIFT 0
#define DOUBLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL)
#endif
// Shared by both 128-bit variants.
#define Z_SHIFT 15
#define SINGLE_LOAD_MASK(l) (((1ULL) << (l)) - 1ULL)
#endif

// Define a common assume_aligned using an appropriate compiler built-in, if
Expand Down Expand Up @@ -138,7 +122,7 @@ struct BaseVector<64>
static constexpr u16 previous_size = 32;
};

// 256 bit implementation
template <>
struct BaseVector<32>
{
Expand All @@ -158,7 +142,7 @@ struct BaseVector<16>
static constexpr bool is_valid = true;
static constexpr u16 size = 16;
using type = m128;
using comparemask_type = u64a;
using comparemask_type = u32;
static constexpr bool has_previous = false;
using previous_type = u64a;
static constexpr u16 previous_size = 8;
Expand Down Expand Up @@ -257,9 +241,13 @@ class SuperVector : public BaseVector<SIZE>
static typename base_type::comparemask_type
iteration_mask(typename base_type::comparemask_type mask);

// Returns a mask with the low len bits set (converted to comparemask_type).
// Any len >= 64 yields an all-ones 64-bit value before conversion.
static typename base_type::comparemask_type single_load_mask(uint8_t const len) {
    // 1ULL << 64 is undefined behaviour, so the full-width mask is produced
    // explicitly instead of by shifting.
    return (len >= 64) ? ~0ULL : (((1ULL) << (len)) - 1ULL);
}
// Returns a mask with the low len bits set (converted to comparemask_type).
// Any len >= 64 yields an all-ones 64-bit value before conversion.
static typename base_type::comparemask_type double_load_mask(uint8_t const len) {
    // 1ULL << 64 is undefined behaviour, so the full-width mask is produced
    // explicitly instead of by shifting.
    return (len >= 64) ? ~0ULL : (((1ULL) << (len)) - 1ULL);
}
static typename base_type::comparemask_type findLSB(typename base_type::comparemask_type &z);
static SuperVector loadu(void const *ptr);
static SuperVector load(void const *ptr);
static SuperVector loadu_maskz(void const *ptr, uint8_t const len);
static SuperVector loadu_maskz(void const *ptr, typename base_type::comparemask_type const len);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you add the implementation for ARM later on, but I don't see any implementation for ppc64.

SuperVector alignr(SuperVector &other, int8_t offset);

template<bool emulateIntel=true>
Expand Down