Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/refactor noodle masked load (WIP) #216

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
17 changes: 13 additions & 4 deletions src/hwlm/noodle_engine_simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,21 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
DEBUG_PRINTF("d - d0: %ld \n", d - d0);
#if defined(HAVE_MASKED_LOADS)
uint8_t l = d - d0;
typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::single_load_mask(l);
typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::load_mask(l);
SuperVector<S> chars = SuperVector<S>::loadu_maskz(d0, mask) & caseMask;
typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
DEBUG_PRINTF("mask: %08llx\n", mask);
hwlm_error_t rv = single_zscan<S>(n, d0, buf, z, len, cbi);
#else
uint8_t l = d0 + S - d;
DEBUG_PRINTF("l: %d \n", l);
SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
chars.print8("chars");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

debug print?

typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
DEBUG_PRINTF("z: %08llx\n", (u64a) z);
z = SuperVector<S>::iteration_mask(z);
DEBUG_PRINTF("z: %08llx\n", (u64a) z);

hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
#endif
chars.print32("chars");
markos marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -125,6 +131,8 @@ hwlm_error_t scanSingleMain(const struct noodTable *n, const u8 *buf,
uint8_t l = buf_end - d;
SuperVector<S> chars = SuperVector<S>::loadu_maskz(d, l) & caseMask;
typename SuperVector<S>::comparemask_type z = mask1.eqmask(chars);
z = SuperVector<S>::iteration_mask(z);

hwlm_error_t rv = single_zscan<S>(n, d, buf, z, len, cbi);
RETURN_IF_TERMINATED(rv);
}
Expand Down Expand Up @@ -160,12 +168,12 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
const u8 *d0 = ROUNDDOWN_PTR(d, S);
#if defined(HAVE_MASKED_LOADS)
uint8_t l = d - d0;
typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::double_load_mask(l);
typename SuperVector<S>::comparemask_type mask = ~SuperVector<S>::load_mask(l);
SuperVector<S> chars = SuperVector<S>::loadu_maskz(d0, mask) & caseMask;
typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);
typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width()) & z2;
DEBUG_PRINTF("z: %0llx\n", z);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why delete the debug print here, when you added more of the same kind to the scanSingle function?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

z has a different type on each architecture, so this DEBUG_PRINTF fails to compile on some of them; I need to make it work and compile on all architectures.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What confuses me is that in the previous function you added DEBUG_PRINTF("z: %08llx\n", (u64a) z);, so I believe you could have modified this print to work the same way by casting z?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's what I did locally after I realized I could just cast it :)
Unfortunately I had some other things to fix before the holidays and this was left unfinished, along with other fixes that I have locally. I will be committing more fixes over the next few days.

z = SuperVector<S>::iteration_mask(z);
lastz1 = z1 >> (S - 1);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this assumes SuperVector<S>::mask_width() == 1, which is not always the case (for arm/neon it's 4).


DEBUG_PRINTF("mask: %08llx\n", mask);
Expand All @@ -176,8 +184,9 @@ hwlm_error_t scanDoubleMain(const struct noodTable *n, const u8 *buf,
chars.print8("chars");
markos marked this conversation as resolved.
Show resolved Hide resolved
typename SuperVector<S>::comparemask_type z1 = mask1.eqmask(chars);
typename SuperVector<S>::comparemask_type z2 = mask2.eqmask(chars);

typename SuperVector<S>::comparemask_type z = (z1 << SuperVector<S>::mask_width()) & z2;
z = SuperVector<S>::iteration_mask(z);

hwlm_error_t rv = double_zscan<S>(n, d, buf, z, len, cbi);
lastz1 = z1 >> (l - 1);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same issue with mask_width()

#endif
Expand Down
19 changes: 17 additions & 2 deletions src/util/supervector/arch/arm/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,11 +525,26 @@ really_inline SuperVector<16> SuperVector<16>::load(void const *ptr)
template <>
really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, uint8_t const len)
{
SuperVector mask = Ones_vshr(16 -len);
SuperVector<16> v = loadu(ptr);
SuperVector mask = Ones_vshr(16 - len);
SuperVector v = loadu(ptr);
return mask & v;
}

template <>
really_inline SuperVector<16> SuperVector<16>::loadu_maskz(void const *ptr, typename base_type::comparemask_type const mask)
{
DEBUG_PRINTF("mask = %08llx\n", mask);
SuperVector v = loadu(ptr);
(void)mask;
return v; // FIXME: & mask

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FIXME

}

/**
 * 16-byte specialisation of findLSB for this (ARM) implementation.
 *
 * Passes the compare mask `z` by address to findAndClearLSB_64 and returns
 * that helper's result shifted right by two bits (i.e. divided by 4).
 *
 * NOTE(review): judging by its name and the pointer argument, the helper
 * presumably clears the bit it reports inside `z` -- confirm against its
 * definition. The divide-by-4 looks tied to a compare mask that uses four
 * bits per vector lane on NEON -- TODO confirm against mask_width().
 */
template<>
really_inline typename SuperVector<16>::comparemask_type SuperVector<16>::findLSB(typename SuperVector<16>::comparemask_type &z)
{
const auto lowest_bit_pos = findAndClearLSB_64(&z);
return lowest_bit_pos >> 2;
}

template<>
really_inline SuperVector<16> SuperVector<16>::alignr(SuperVector<16> &other, int8_t offset)
{
Expand Down
7 changes: 5 additions & 2 deletions src/util/supervector/supervector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,11 @@ struct BaseVector<16>
static constexpr bool is_valid = true;
static constexpr u16 size = 16;
using type = m128;
#if defined(ARCH_ARM32) || defined(ARCH_AARCH64)
using comparemask_type = u64a;
#else
using comparemask_type = u32;
#endif
static constexpr bool has_previous = false;
using previous_type = u64a;
static constexpr u16 previous_size = 8;
Expand Down Expand Up @@ -229,8 +233,7 @@ class SuperVector : public BaseVector<SIZE>
static typename base_type::comparemask_type
iteration_mask(typename base_type::comparemask_type mask);

static typename base_type::comparemask_type single_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
static typename base_type::comparemask_type double_load_mask(uint8_t const len) { return (((1ULL) << (len)) - 1ULL); }
// Builds a mask whose `len` least-significant bits are set, computed in
// 64-bit unsigned arithmetic and then converted to comparemask_type.
// NOTE(review): the shift is undefined for len >= 64; callers presumably
// always pass a smaller value -- confirm at the call sites.
static typename base_type::comparemask_type load_mask(uint8_t const len)
{
const unsigned long long low_bits = ((1ULL) << (len)) - 1ULL;
return low_bits;
}
static typename base_type::comparemask_type findLSB(typename base_type::comparemask_type &z);
static SuperVector loadu(void const *ptr);
static SuperVector load(void const *ptr);
Expand Down