Skip to content

Commit

Permalink
fix and rename _Within_limits, add tests (#3247)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicole Mazzuca <[email protected]>
  • Loading branch information
strega-nil-ms and strega-nil authored Dec 6, 2022
1 parent ac11067 commit e5b008c
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 26 deletions.
2 changes: 1 addition & 1 deletion stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ namespace ranges {
if constexpr (is_same_v<_Pj, identity> && _Vector_alg_in_find_is_safe<_It, _Ty>
&& sized_sentinel_for<_Se, _It>) {
if (!_STD is_constant_evaluated()) {
if (!_Within_limits(_First, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<_It>(_Val)) {
return 0;
}

Expand Down
59 changes: 35 additions & 24 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5544,17 +5544,33 @@ _NODISCARD constexpr auto lexicographical_compare_three_way(
}
#endif // __cpp_lib_concepts

template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe = // Can we activate the vector algorithms for find/count?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& disjunction_v< // And one of the following conditions must be met:
#ifdef __cpp_lib_byte
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>, // We're finding a std::byte in a range of std::byte.
#endif // __cpp_lib_byte
conjunction<is_integral<_Ty>, is_integral<_Elem>>, // We're finding an integer in a range of integers.
// The integer types can be different, which requires careful handling.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>>; // We're finding a U* in a range of U* (identical types).

template <class _InIt, class _Ty>
_NODISCARD constexpr bool _Within_limits(const _InIt&, const _Ty& _Val) {
_NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
// check whether _Val is within the limits of _Elem
_STL_INTERNAL_STATIC_ASSERT(_Vector_alg_in_find_is_safe<_InIt, _Ty>);

if constexpr (disjunction_v<
#ifdef __cpp_lib_byte
is_same<_Ty, byte>,
#endif // __cpp_lib_byte
is_same<_Ty, bool>, is_pointer<_Ty>, is_null_pointer<_Ty>>) {
is_same<_Ty, bool>, is_pointer<_Ty>>) {
return true;
} else {
using _Elem = _Iter_value_t<_InIt>;
_STL_INTERNAL_STATIC_ASSERT(is_integral_v<_Elem> && is_integral_v<_Ty>);

if constexpr (is_same_v<_Elem, bool>) {
return _Val == true || _Val == false;
} else if constexpr (is_signed_v<_Elem>) {
Expand All @@ -5566,40 +5582,35 @@ _NODISCARD constexpr bool _Within_limits(const _InIt&, const _Ty& _Val) {
// signed _Elem, signed _Ty
return _Min <= _Val && _Val <= _Max;
} else {
if constexpr (-1 == static_cast<_Ty>(-1)) {
// signed _Elem, unsigned _Ty, -1 == static_cast<_Ty>(-1)
// signed _Elem, unsigned _Ty
if constexpr (_Elem{-1} == static_cast<_Ty>(-1)) {
// negative values of _Elem can compare equal to values of _Ty
return _Val <= _Max || static_cast<_Ty>(_Min) <= _Val;
} else {
// signed _Elem, unsigned _Ty, -1 != static_cast<_Ty>(-1)
// negative values of _Elem cannot compare equal to values of _Ty
return _Val <= _Max;
}
}
} else {
constexpr _Elem _Max = static_cast<_Elem>(~_Elem{0});

if constexpr (is_signed_v<_Ty>) {
// unsigned _Elem, signed _Ty
return 0 <= _Val && static_cast<make_unsigned_t<_Ty>>(_Val) <= _Max;
} else {
if constexpr (is_unsigned_v<_Ty>) {
// unsigned _Elem, unsigned _Ty
return _Val <= _Max;
} else {
// unsigned _Elem, signed _Ty
if constexpr (_Ty{-1} == static_cast<_Elem>(-1)) {
// negative values of _Ty can compare equal to values of _Elem
return _Val <= _Max;
} else {
// negative values of _Ty cannot compare equal to values of _Elem
return 0 <= _Val && _Val <= _Max;
}
}
}
}
}

template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe = // Can we activate the vector algorithms for find/count?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& disjunction_v< // And one of the following conditions must be met:
#ifdef __cpp_lib_byte
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>, // We're finding a std::byte in a range of std::byte.
#endif // __cpp_lib_byte
conjunction<is_integral<_Ty>, is_integral<_Elem>>, // We're finding an integer in a range of integers.
// The integer types can be different, which requires careful handling.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>>; // We're finding a U* in a range of U* (identical types).

template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(_InIt _First, const _InIt _Last, const _Ty& _Val) {
// find first matching _Val; choose optimization
Expand All @@ -5609,7 +5620,7 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(_InIt _First, const _InIt _Last, c
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (!_Within_limits(_First, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<_InIt>(_Val)) {
return _Last;
}
#if _USE_STD_VECTOR_ALGORITHMS
Expand Down Expand Up @@ -5678,7 +5689,7 @@ namespace ranges {
if constexpr (_Vector_alg_in_find_is_safe<_It, _Ty> && _Sized_or_unreachable_sentinel_for<_Se, _It>
&& same_as<_Pj, identity>) {
if (!_STD is_constant_evaluated()) {
if (!_Within_limits(_First, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<_It>(_Val)) {
if constexpr (_Is_sized) {
return _RANGES next(_STD move(_First), _Last);
} else {
Expand Down Expand Up @@ -5792,7 +5803,7 @@ _NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> count(const _InIt _First, const _InI
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (!_Within_limits(_UFirst, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<decltype(_UFirst)>(_Val)) {
return 0;
}

Expand Down
52 changes: 51 additions & 1 deletion tests/std/tests/Dev11_0316853_find_memchr_optimization/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ using namespace std;

constexpr auto long_min = numeric_limits<long>::min();
constexpr auto long_max = numeric_limits<long>::max();
constexpr auto uint_max = numeric_limits<unsigned int>::max();

#define STATIC_ASSERT(...) static_assert(__VA_ARGS__, #__VA_ARGS__)

Expand Down Expand Up @@ -62,6 +63,22 @@ void test_limit_check_elements_impl() {
assert(find(begin(sc), end(sc), ValueType{-1}) == begin(sc) + 4);

assert(count(begin(sc), end(sc), ValueType{-1}) == 2);
} else {
constexpr auto max_vt = numeric_limits<ValueType>::max();
if constexpr (ElementType{-1} == max_vt) {
// ugly conversions :(
assert(find(begin(sc), end(sc), max_vt) == begin(sc) + 4);
assert(find(begin(sc), end(sc), max_vt - 1) == begin(sc) + 3);

assert(count(begin(sc), end(sc), max_vt) == 2);
assert(count(begin(sc), end(sc), max_vt - 1) == 1);
} else {
assert(find(begin(sc), end(sc), max_vt) == end(sc));
assert(find(begin(sc), end(sc), max_vt - 1) == end(sc));

assert(count(begin(sc), end(sc), max_vt) == 0);
assert(count(begin(sc), end(sc), max_vt - 1) == 0);
}
}

assert(count(begin(sc), end(sc), ValueType{0}) == 1);
Expand Down Expand Up @@ -91,6 +108,23 @@ void test_limit_check_elements_impl() {
assert(find(begin(uc), end(uc), ValueType{2}) == begin(uc) + 3);
assert(find(begin(uc), end(uc), ValueType{6}) == end(uc));

if constexpr (is_signed_v<ValueType>) {
if constexpr (ValueType{-1} == max_val) {
// ugly conversions :(
assert(find(begin(uc), end(uc), ValueType{-1}) == begin(uc) + 6);
assert(find(begin(uc), end(uc), ValueType{-2}) == begin(uc) + 5);

assert(count(begin(uc), end(uc), ValueType{-1}) == 1);
assert(count(begin(uc), end(uc), ValueType{-2}) == 1);
} else {
assert(find(begin(uc), end(uc), ValueType{-1}) == end(uc));
assert(find(begin(uc), end(uc), ValueType{-2}) == end(uc));

assert(count(begin(uc), end(uc), ValueType{-1}) == 0);
assert(count(begin(uc), end(uc), ValueType{-2}) == 0);
}
}

if constexpr (max_val <= max_vt) {
assert(find(begin(uc), end(uc), ValueType{max_val - 3}) == end(uc));
assert(find(begin(uc), end(uc), ValueType{max_val - 2}) == begin(uc) + 4);
Expand Down Expand Up @@ -133,6 +167,7 @@ int main() {
#ifdef __cpp_lib_concepts
static_assert(_Vector_alg_in_find_is_safe<decltype(v.begin()), decltype(33)>, "should optimize");
#endif // __cpp_lib_concepts
static_assert(_Could_compare_equal_to_value_type<signed char*>(33), "should be within limits");

assert(find(v.begin(), v.end(), 33) - v.begin() == 1);
assert(find(v.begin(), v.end(), -1) - v.begin() == 2);
Expand Down Expand Up @@ -413,6 +448,16 @@ int main() {
assert(find(begin(sl), end(sl), 0xFFFFFFFF00000000ULL) == end(sl));
}

{ // unsigned int == int, weird conversions yay! (GH-3244)
const unsigned int ui[] = {0, 1, 2, uint_max - 2, uint_max - 1, uint_max};

assert(find(begin(ui), end(ui), 0) == begin(ui));
assert(find(begin(ui), end(ui), 2) == begin(ui) + 2);
assert(find(begin(ui), end(ui), 3) == end(ui));
assert(find(begin(ui), end(ui), -2) == begin(ui) + 4);
assert(find(begin(ui), end(ui), -1) == begin(ui) + 5);
}

{ // Test bools
const bool arr[]{true, true, true, false, true, false};

Expand All @@ -429,16 +474,21 @@ int main() {

{ // Test pointers
const char* s = "xxxyyy";
const char* arr[]{s, s + 1, s + 1, s + 5, s, s + 4};
const char* arr[]{s, s + 1, s + 1, s + 5, s, s + 4, nullptr};

static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), decltype(s + 1)>, "should optimize");
static_assert(!_Vector_alg_in_find_is_safe<decltype(begin(arr)), nullptr_t>, "should not optimize");

assert(find(begin(arr), end(arr), s) == begin(arr));
assert(find(begin(arr), end(arr), s + 1) == begin(arr) + 1);
assert(find(begin(arr), end(arr), s + 3) == end(arr));
assert(find(begin(arr), end(arr), static_cast<const char*>(nullptr)) == begin(arr) + 6);
assert(find(begin(arr), end(arr), nullptr) == begin(arr) + 6);

assert(count(begin(arr), end(arr), s + 1) == 2);
assert(count(begin(arr), end(arr), s + 5) == 1);
assert(count(begin(arr), end(arr), s + 3) == 0);
assert(count(begin(arr), end(arr), static_cast<const char*>(nullptr)) == 1);
assert(count(begin(arr), end(arr), nullptr) == 1);
}
}

0 comments on commit e5b008c

Please sign in to comment.