Skip to content

Commit

Permalink
Vectorize find_first_of for 8 and 16 bit elements with SSE4.2 `pcmpestri` (#4466)
Browse files Browse the repository at this point in the history

Co-authored-by: Stephan T. Lavavej <[email protected]>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Mar 21, 2024
1 parent 1e7d7f8 commit 9d761bd
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 0 deletions.
1 change: 1 addition & 0 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ endfunction()

add_benchmark(bitset_to_string src/bitset_to_string.cpp)
add_benchmark(find_and_count src/find_and_count.cpp)
add_benchmark(find_first_of src/find_first_of.cpp)
add_benchmark(locale_classic src/locale_classic.cpp)
add_benchmark(minmax_element src/minmax_element.cpp)
add_benchmark(path_lexically_normal src/path_lexically_normal.cpp)
Expand Down
46 changes: 46 additions & 0 deletions benchmarks/src/find_first_of.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <benchmark/benchmark.h>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

using namespace std;

// Benchmarks find_first_of() over a haystack of HSize filler elements ('.') in which the element at index Pos
// has been overwritten with needle element number Which. The needle consists of NSize consecutive values
// starting at 'a', so (as long as the needle stays small enough not to wrap around to '.') the search has to
// walk the haystack up to index Pos before it succeeds.
template <class T, size_t Pos, size_t NSize, size_t HSize = Pos * 2, size_t Which = 0>
void bm(benchmark::State& state) {
    // Reject configurations that would index outside of the haystack or the needle
    static_assert(Pos < HSize);
    static_assert(Which < NSize);

    vector<T> haystack(HSize, T{'.'});

    vector<T> needle(NSize);
    iota(needle.begin(), needle.end(), T{'a'});

    // Plant the element that the search is expected to stop at
    haystack[Pos] = needle[Which];

    for (auto _ : state) {
        // The result is passed to DoNotOptimize so that the call cannot be discarded as dead code
        benchmark::DoNotOptimize(find_first_of(haystack.begin(), haystack.end(), needle.begin(), needle.end()));
    }
}

// Each registration is bm<T, Pos, NSize>: element type, index of the matching haystack element, needle length.
// HSize and Which keep their defaults, so the haystack has Pos * 2 elements and the match is the first needle
// element. Every configuration is run for both 8-bit and 16-bit elements.

// 4-element haystack, 3-element needle
BENCHMARK(bm<uint8_t, 2, 3>);
BENCHMARK(bm<uint16_t, 2, 3>);

// 14-element haystack, 4-element needle
BENCHMARK(bm<uint8_t, 7, 4>);
BENCHMARK(bm<uint16_t, 7, 4>);

// 18-element haystack, 3-element needle
BENCHMARK(bm<uint8_t, 9, 3>);
BENCHMARK(bm<uint16_t, 9, 3>);

// 44-element haystack, 5-element needle
BENCHMARK(bm<uint8_t, 22, 5>);
BENCHMARK(bm<uint16_t, 22, 5>);

// 6112-element haystack, 7-element needle
BENCHMARK(bm<uint8_t, 3056, 7>);
BENCHMARK(bm<uint16_t, 3056, 7>);

// 2022-element haystack, 11-element needle
BENCHMARK(bm<uint8_t, 1011, 11>);
BENCHMARK(bm<uint16_t, 1011, 11>);

BENCHMARK_MAIN();
69 changes: 69 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ const void* __stdcall __std_find_last_trivial_2(const void* _First, const void*
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

// Separately compiled find_first_of helpers; the numeric suffix is the element size in bytes.
// Each returns a pointer to the first element of [_First1, _Last1) that is equal to some element of
// [_First2, _Last2), or _Last1 when there is no such element.
const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
Expand Down Expand Up @@ -160,6 +165,29 @@ _Ty* __std_find_last_trivial(_Ty* const _First, _Ty* const _Last, const _TVal _V
static_assert(_Always_false<_Ty>, "Unexpected size");
}
}

// Typed front end for the separately compiled find_first_of helpers: forwards to the helper that matches
// sizeof(_Ty1) and converts the untyped result back to _Ty1*. Only 1-byte and 2-byte elements are supported;
// any other size is rejected at compile time.
template <class _Ty1, class _Ty2>
_Ty1* __std_find_first_of_trivial(
    _Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, _Ty2* const _Last2) noexcept {
    const void* _Found = nullptr;

    if constexpr (sizeof(_Ty1) == 1) {
        _Found = ::__std_find_first_of_trivial_1(_First1, _Last1, _First2, _Last2);
    } else if constexpr (sizeof(_Ty1) == 2) {
        _Found = ::__std_find_first_of_trivial_2(_First1, _Last1, _First2, _Last2);
    } else {
        static_assert(_Always_false<_Ty1>, "Unexpected size");
    }

    // The helpers traffic in const void*; restore the element type and the caller's constness
    return const_cast<_Ty1*>(static_cast<const _Ty1*>(_Found));
}

// find_first_of vectorization is likely to be a win after this size (in elements).
// The callers compare the haystack length against this value; shorter haystacks use the plain nested loops.
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

// Can we activate the vector algorithms for find_first_of?
// Both conditions are required: the predicate must be replaceable by a bitwise comparison of the elements,
// and the haystack element must be 1 or 2 bytes wide, since __std_find_first_of_trivial only dispatches
// on those two sizes.
template <class _It1, class _It2, class _Pr>
_INLINE_VAR constexpr bool _Vector_alg_in_find_first_of_is_safe =
_Equal_memcmp_is_safe<_It1, _It2, _Pr> // can replace value comparison with bitwise comparison
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible size
_STD_END
#endif // _USE_STD_VECTOR_ALGORITHMS

Expand Down Expand Up @@ -3321,6 +3349,24 @@ _NODISCARD _CONSTEXPR20 _FwdIt1 find_first_of(
const auto _ULast1 = _STD _Get_unwrapped(_Last1);
const auto _UFirst2 = _STD _Get_unwrapped(_First2);
const auto _ULast2 = _STD _Get_unwrapped(_Last2);
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_find_first_of_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated() && _ULast1 - _UFirst1 >= _Threshold_find_first_of) {
const auto _First1_ptr = _STD _To_address(_UFirst1);
const auto _Result = _STD __std_find_first_of_trivial(
_First1_ptr, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), _STD _To_address(_ULast2));

if constexpr (is_pointer_v<decltype(_UFirst1)>) {
_UFirst1 = _Result;
} else {
_UFirst1 += _Result - _First1_ptr;
}
_STD _Seek_wrapped(_First1, _UFirst1);
return _First1;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _UFirst1 != _ULast1; ++_UFirst1) {
for (auto _UMid2 = _UFirst2; _UMid2 != _ULast2; ++_UMid2) {
if (_Pred(*_UFirst1, *_UMid2)) {
Expand Down Expand Up @@ -3398,6 +3444,29 @@ namespace ranges {
_STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se2, _It2>);
_STL_INTERNAL_STATIC_ASSERT(indirectly_comparable<_It1, _It2, _Pr, _Pj1, _Pj2>);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_find_first_of_is_safe<_It1, _It2, _Pr> && sized_sentinel_for<_Se1, _It1>
&& sized_sentinel_for<_Se2, _It2> && is_same_v<_Pj1, identity> && is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated() && _Last1 - _First1 >= _Threshold_find_first_of) {
const auto _Count1 = _Last1 - _First1;
const auto _First1_ptr = _STD _To_address(_First1);
const auto _Last1_ptr = _First1_ptr + _Count1;

const auto _Count2 = _Last2 - _First2;
const auto _First2_ptr = _STD _To_address(_First2);
const auto _Last2_ptr = _First2_ptr + _Count2;

const auto _Result =
_STD __std_find_first_of_trivial(_First1_ptr, _Last1_ptr, _First2_ptr, _Last2_ptr);

if constexpr (is_pointer_v<_It1>) {
return _Result;
} else {
return _First1 + (_Result - _First1_ptr);
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (; _First1 != _Last1; ++_First1) {
for (auto _Mid2 = _First2; _Mid2 != _Last2; ++_Mid2) {
if (_STD invoke(_Pred, _STD invoke(_Proj1, *_First1), _STD invoke(_Proj2, *_Mid2))) {
Expand Down
78 changes: 78 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2009,6 +2009,74 @@ namespace {
}
return _Result;
}

// Finds the first element of the haystack [_First1, _Last1) that is equal to any element of the needle
// [_First2, _Last2), comparing elements of type _Ty. Returns a pointer to that element, or _Last1 when no
// haystack element occurs in the needle (which includes an empty haystack and an empty needle).
//
// When SSE4.2 is usable and the needle fits into a single 16-byte register, the haystack is scanned 16 bytes
// at a time with pcmpestri; otherwise the classic nested loops are used.
template <class _Ty>
const void* __stdcall __std_find_first_of_trivial_impl(
    const void* _First1, const void* const _Last1, const void* const _First2, const void* const _Last2) noexcept {
    if (_First2 == _Last2) {
        // Nothing can match an empty needle, so the answer is known without reading the haystack.
        // Returning here also guarantees that the needle memcpy() below is never handed a null source
        // pointer together with a zero size, which the C standard leaves undefined.
        return _Last1;
    }

#ifndef _M_ARM64EC
    const size_t _Needle_length = _Byte_length(_First2, _Last2);

    // The whole needle has to fit into one 16-byte register
    if (_Use_sse42() && _Needle_length <= 16) {
        // "Equal any": report a haystack element that equals any needle element;
        // "least significant": report the lowest matching index, i.e. the first match
        constexpr int _Op =
            (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT;
        constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8; // elements per full 16-byte part

        const int _Needle_length_el = static_cast<int>(_Needle_length / sizeof(_Ty));

        // Copy the needle into an aligned buffer. The bytes past _Needle_length stay uninitialized, which is
        // fine because the explicit-length compare is told how many elements are valid.
        alignas(16) uint8_t _Tmp1[16];
        memcpy(_Tmp1, _First2, _Needle_length);
        const __m128i _Needle = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp1));

        // Split the haystack into full 16-byte parts followed by a shorter tail
        const size_t _Haystack_length = _Byte_length(_First1, _Last1);
        const void* _Stop_at          = _First1;
        _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF});

        while (_First1 != _Stop_at) {
            const __m128i _Haystack_part = _mm_loadu_si128(static_cast<const __m128i*>(_First1));

            if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) {
                const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op);
                _Advance_bytes(_First1, _Pos * sizeof(_Ty));
                return _First1;
            }

            _Advance_bytes(_First1, 16);
        }

        const size_t _Last_part_size = _Haystack_length & 0xF;

        // Skip the tail entirely when the haystack is a whole number of 16-byte parts. Besides saving a
        // compare, this avoids a zero-sized memcpy() from a possibly null pointer for an empty haystack.
        if (_Last_part_size != 0) {
            const int _Last_part_size_el = static_cast<int>(_Last_part_size / sizeof(_Ty));

            // Reading 16 bytes straight from the haystack could run past its end, so stage the tail in a
            // local buffer; only the first _Last_part_size_el elements take part in the compare.
            alignas(16) uint8_t _Tmp2[16];
            memcpy(_Tmp2, _First1, _Last_part_size);
            const __m128i _Haystack_last_part = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp2));

            if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_last_part, _Last_part_size_el, _Op)) {
                const int _Pos =
                    _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_last_part, _Last_part_size_el, _Op);
                _Advance_bytes(_First1, _Pos * sizeof(_Ty));
                return _First1;
            }

            _Advance_bytes(_First1, _Last_part_size);
        }

        // No match: _First1 has been advanced to the end of the haystack
        return _First1;
    }
#endif // !_M_ARM64EC

    // Fallback: compare every haystack element against every needle element
    auto _Ptr_haystack           = static_cast<const _Ty*>(_First1);
    const auto _Ptr_haystack_end = static_cast<const _Ty*>(_Last1);
    const auto _Ptr_needle       = static_cast<const _Ty*>(_First2);
    const auto _Ptr_needle_end   = static_cast<const _Ty*>(_Last2);

    for (; _Ptr_haystack != _Ptr_haystack_end; ++_Ptr_haystack) {
        for (auto _Ptr = _Ptr_needle; _Ptr != _Ptr_needle_end; ++_Ptr) {
            if (*_Ptr_haystack == *_Ptr) {
                return _Ptr_haystack;
            }
        }
    }

    return _Ptr_haystack;
}

} // unnamed namespace

extern "C" {
Expand Down Expand Up @@ -2094,6 +2162,16 @@ __declspec(noalias) size_t
return __std_count_trivial_impl<_Find_traits_8>(_First, _Last, _Val);
}

// find_first_of over 1-byte elements: forwards to the shared implementation instantiated for uint8_t.
const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept {
return __std_find_first_of_trivial_impl<uint8_t>(_First1, _Last1, _First2, _Last2);
}

// find_first_of over 2-byte elements: forwards to the shared implementation instantiated for uint16_t.
const void* __stdcall __std_find_first_of_trivial_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept {
return __std_find_first_of_trivial_impl<uint16_t>(_First1, _Last1, _First2, _Last2);
}

} // extern "C"

#ifndef _M_ARM64EC
Expand Down
81 changes: 81 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ auto last_known_good_find_last(FwdIt first, FwdIt last, T v) {
}
}

// Reference implementation of find_first_of, written as plain nested loops so that the results of the library
// algorithm can be checked against it. Returns the first position in [h_first, h_last) whose element equals
// some element of [n_first, n_last), or h_last if there is none.
template <class FwdItH, class FwdItN>
auto last_known_good_find_first_of(FwdItH h_first, FwdItH h_last, FwdItN n_first, FwdItN n_last) {
    while (h_first != h_last) {
        FwdItN candidate = n_first;
        while (candidate != n_last) {
            if (*h_first == *candidate) {
                return h_first;
            }
            ++candidate;
        }
        ++h_first;
    }

    // Ran off the end of the haystack without a match; h_first == h_last here
    return h_first;
}

template <class T>
void test_case_find(const vector<T>& input, T v) {
auto expected = last_known_good_find(input.begin(), input.end(), v);
Expand Down Expand Up @@ -211,6 +223,57 @@ void test_find_last(mt19937_64& gen) {
}
#endif // _HAS_CXX23

// Runs one haystack/needle pair through find_first_of (and ranges::find_first_of when C++20 is available) and
// checks each result against the reference implementation.
template <class T>
void test_case_find_first_of(const vector<T>& input_haystack, const vector<T>& input_needle) {
    const auto haystack_begin = input_haystack.begin();
    const auto haystack_end   = input_haystack.end();
    const auto needle_begin   = input_needle.begin();
    const auto needle_end     = input_needle.end();

    const auto expected = last_known_good_find_first_of(haystack_begin, haystack_end, needle_begin, needle_end);

    const auto actual = find_first_of(haystack_begin, haystack_end, needle_begin, needle_end);
    assert(expected == actual);
#if _HAS_CXX20
    const auto ranges_actual = ranges::find_first_of(input_haystack, input_needle);
    assert(expected == ranges_actual);
#endif // _HAS_CXX20
}

// Randomized find_first_of test: for every haystack length from 0 up to dataCount, checks the algorithm
// against needles of every length from 0 up to needleDataCount. Both sequences are filled with random values
// in ['a', 'z'].
template <class T>
void test_find_first_of(mt19937_64& gen) {
    constexpr size_t needleDataCount = 30;

    // Draw through int for 1-byte T (presumably because uniform_int_distribution does not accept
    // character-sized types) and convert each value to T afterwards
    using TD = conditional_t<sizeof(T) == 1, int, T>;
    uniform_int_distribution<TD> dis('a', 'z');

    vector<T> input_haystack;
    vector<T> input_needle;
    input_haystack.reserve(dataCount);
    input_needle.reserve(needleDataCount);

    while (true) {
        // Start over with an empty needle for the current haystack, then grow it one element at a time
        input_needle.clear();
        test_case_find_first_of(input_haystack, input_needle);

        while (input_needle.size() != needleDataCount) {
            input_needle.push_back(static_cast<T>(dis(gen)));
            test_case_find_first_of(input_haystack, input_needle);
        }

        // The longest haystack has just been covered
        if (input_haystack.size() == dataCount) {
            break;
        }

        input_haystack.push_back(static_cast<T>(dis(gen)));
    }
}

// Checks find_first_of for a given pair of container types (possibly const, possibly with different element
// types) using a fixed input: the first haystack element that also occurs in the needle is 'T', at index 6.
template <class C1, class C2>
void test_find_first_of_containers() {
    C1 haystack{'m', 'e', 'o', 'w', 'C', 'A', 'T', 'S'};
    C2 needle{'R', 'S', 'T'};

    const auto expected = haystack.begin() + 6;

    const auto result = find_first_of(haystack.begin(), haystack.end(), needle.begin(), needle.end());
    assert(result == expected);
#if _HAS_CXX20
    const auto ranges_result = ranges::find_first_of(haystack, needle);
    assert(ranges_result == expected);
#endif // _HAS_CXX20
}

template <class T>
void test_min_max_element(mt19937_64& gen) {
using Limits = numeric_limits<T>;
Expand Down Expand Up @@ -437,6 +500,24 @@ void test_vector_algorithms(mt19937_64& gen) {
test_find_last<unsigned long long>(gen);
#endif // _HAS_CXX23

test_find_first_of<char>(gen);
test_find_first_of<signed char>(gen);
test_find_first_of<unsigned char>(gen);
test_find_first_of<short>(gen);
test_find_first_of<unsigned short>(gen);
test_find_first_of<int>(gen);
test_find_first_of<unsigned int>(gen);
test_find_first_of<long long>(gen);
test_find_first_of<unsigned long long>(gen);

test_find_first_of_containers<vector<char>, vector<signed char>>();
test_find_first_of_containers<vector<char>, vector<unsigned char>>();
test_find_first_of_containers<vector<wchar_t>, vector<char>>();
test_find_first_of_containers<const vector<char>, const vector<char>>();
test_find_first_of_containers<vector<char>, const vector<char>>();
test_find_first_of_containers<const vector<wchar_t>, vector<wchar_t>>();
test_find_first_of_containers<vector<char>, vector<int>>();

test_min_max_element<char>(gen);
test_min_max_element<signed char>(gen);
test_min_max_element<unsigned char>(gen);
Expand Down

0 comments on commit 9d761bd

Please sign in to comment.