Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix and rename _Within_limits, add tests #3247

Merged
merged 3 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ namespace ranges {
if constexpr (is_same_v<_Pj, identity> && _Vector_alg_in_find_is_safe<_It, _Ty>
&& sized_sentinel_for<_Se, _It>) {
if (!_STD is_constant_evaluated()) {
if (!_Within_limits(_First, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<_It>(_Val)) {
return 0;
}

Expand Down
59 changes: 35 additions & 24 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5544,17 +5544,33 @@ _NODISCARD constexpr auto lexicographical_compare_three_way(
}
#endif // __cpp_lib_concepts

template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe = // Can we activate the vector algorithms for find/count?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& disjunction_v< // And one of the following conditions must be met:
#ifdef __cpp_lib_byte
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>, // We're finding a std::byte in a range of std::byte.
#endif // __cpp_lib_byte
conjunction<is_integral<_Ty>, is_integral<_Elem>>, // We're finding an integer in a range of integers.
// The integer types can be different, which requires careful handling.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>>; // We're finding a U* in a range of U* (identical types).

template <class _InIt, class _Ty>
_NODISCARD constexpr bool _Within_limits(const _InIt&, const _Ty& _Val) {
_NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
// check whether _Val is within the limits of _Elem
_STL_INTERNAL_STATIC_ASSERT(_Vector_alg_in_find_is_safe<_InIt, _Ty>);
strega-nil-ms marked this conversation as resolved.
Show resolved Hide resolved

if constexpr (disjunction_v<
#ifdef __cpp_lib_byte
is_same<_Ty, byte>,
#endif // __cpp_lib_byte
is_same<_Ty, bool>, is_pointer<_Ty>, is_null_pointer<_Ty>>) {
is_same<_Ty, bool>, is_pointer<_Ty>>) {
return true;
} else {
using _Elem = _Iter_value_t<_InIt>;
_STL_INTERNAL_STATIC_ASSERT(is_integral_v<_Elem> && is_integral_v<_Ty>);

if constexpr (is_same_v<_Elem, bool>) {
return _Val == true || _Val == false;
} else if constexpr (is_signed_v<_Elem>) {
Expand All @@ -5566,40 +5582,35 @@ _NODISCARD constexpr bool _Within_limits(const _InIt&, const _Ty& _Val) {
// signed _Elem, signed _Ty
return _Min <= _Val && _Val <= _Max;
} else {
if constexpr (-1 == static_cast<_Ty>(-1)) {
// signed _Elem, unsigned _Ty, -1 == static_cast<_Ty>(-1)
// signed _Elem, unsigned _Ty
if constexpr (_Elem{-1} == static_cast<_Ty>(-1)) {
// negative values of _Elem can compare equal to values of _Ty
return _Val <= _Max || static_cast<_Ty>(_Min) <= _Val;
} else {
// signed _Elem, unsigned _Ty, -1 != static_cast<_Ty>(-1)
// negative values of _Elem cannot compare equal to values of _Ty
return _Val <= _Max;
}
}
} else {
constexpr _Elem _Max = static_cast<_Elem>(~_Elem{0});

if constexpr (is_signed_v<_Ty>) {
// unsigned _Elem, signed _Ty
return 0 <= _Val && static_cast<make_unsigned_t<_Ty>>(_Val) <= _Max;
} else {
if constexpr (is_unsigned_v<_Ty>) {
// unsigned _Elem, unsigned _Ty
return _Val <= _Max;
} else {
// unsigned _Elem, signed _Ty
if constexpr (_Ty{-1} == static_cast<_Elem>(-1)) {
// negative values of _Ty can compare equal to values of _Elem
return _Val <= _Max;
} else {
// negative values of _Ty cannot compare equal to values of _Elem
return 0 <= _Val && _Val <= _Max;
}
}
}
}
}

template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe = // Can we activate the vector algorithms for find/count?
_Iterator_is_contiguous<_Iter> // The iterator must be contiguous so we can get raw pointers.
&& !_Iterator_is_volatile<_Iter> // The iterator must not be volatile.
&& disjunction_v< // And one of the following conditions must be met:
#ifdef __cpp_lib_byte
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>, // We're finding a std::byte in a range of std::byte.
#endif // __cpp_lib_byte
conjunction<is_integral<_Ty>, is_integral<_Elem>>, // We're finding an integer in a range of integers.
// The integer types can be different, which requires careful handling.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>>; // We're finding a U* in a range of U* (identical types).

template <class _InIt, class _Ty>
_NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(_InIt _First, const _InIt _Last, const _Ty& _Val) {
// find first matching _Val; choose optimization
Expand All @@ -5609,7 +5620,7 @@ _NODISCARD _CONSTEXPR20 _InIt _Find_unchecked(_InIt _First, const _InIt _Last, c
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (!_Within_limits(_First, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<_InIt>(_Val)) {
return _Last;
}
#if _USE_STD_VECTOR_ALGORITHMS
Expand Down Expand Up @@ -5678,7 +5689,7 @@ namespace ranges {
if constexpr (_Vector_alg_in_find_is_safe<_It, _Ty> && _Sized_or_unreachable_sentinel_for<_Se, _It>
&& same_as<_Pj, identity>) {
if (!_STD is_constant_evaluated()) {
if (!_Within_limits(_First, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<_It>(_Val)) {
if constexpr (_Is_sized) {
return _RANGES next(_STD move(_First), _Last);
} else {
Expand Down Expand Up @@ -5792,7 +5803,7 @@ _NODISCARD _CONSTEXPR20 _Iter_diff_t<_InIt> count(const _InIt _First, const _InI
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (!_Within_limits(_UFirst, _Val)) {
if (!_STD _Could_compare_equal_to_value_type<decltype(_UFirst)>(_Val)) {
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ using namespace std;

constexpr auto long_min = numeric_limits<long>::min();
constexpr auto long_max = numeric_limits<long>::max();
constexpr auto uint_max = numeric_limits<unsigned int>::max();

#define STATIC_ASSERT(...) static_assert(__VA_ARGS__, #__VA_ARGS__)

Expand Down Expand Up @@ -62,6 +63,22 @@ void test_limit_check_elements_impl() {
assert(find(begin(sc), end(sc), ValueType{-1}) == begin(sc) + 4);

assert(count(begin(sc), end(sc), ValueType{-1}) == 2);
} else {
constexpr auto max_vt = numeric_limits<ValueType>::max();
if constexpr (ElementType{-1} == max_vt) {
// ugly conversions :(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-inclusive conversion shaming! (No change requested.)

assert(find(begin(sc), end(sc), max_vt) == begin(sc) + 4);
assert(find(begin(sc), end(sc), max_vt - 1) == begin(sc) + 3);

assert(count(begin(sc), end(sc), max_vt) == 2);
assert(count(begin(sc), end(sc), max_vt - 1) == 1);
} else {
assert(find(begin(sc), end(sc), max_vt) == end(sc));
assert(find(begin(sc), end(sc), max_vt - 1) == end(sc));

assert(count(begin(sc), end(sc), max_vt) == 0);
assert(count(begin(sc), end(sc), max_vt - 1) == 0);
}
}

assert(count(begin(sc), end(sc), ValueType{0}) == 1);
Expand Down Expand Up @@ -91,6 +108,23 @@ void test_limit_check_elements_impl() {
assert(find(begin(uc), end(uc), ValueType{2}) == begin(uc) + 3);
assert(find(begin(uc), end(uc), ValueType{6}) == end(uc));

if constexpr (is_signed_v<ValueType>) {
strega-nil-ms marked this conversation as resolved.
Show resolved Hide resolved
if constexpr (ValueType{-1} == max_val) {
// ugly conversions :(
assert(find(begin(uc), end(uc), ValueType{-1}) == begin(uc) + 6);
assert(find(begin(uc), end(uc), ValueType{-2}) == begin(uc) + 5);

assert(count(begin(uc), end(uc), ValueType{-1}) == 1);
assert(count(begin(uc), end(uc), ValueType{-2}) == 1);
} else {
assert(find(begin(uc), end(uc), ValueType{-1}) == end(uc));
assert(find(begin(uc), end(uc), ValueType{-2}) == end(uc));

assert(count(begin(uc), end(uc), ValueType{-1}) == 0);
assert(count(begin(uc), end(uc), ValueType{-2}) == 0);
}
}

if constexpr (max_val <= max_vt) {
assert(find(begin(uc), end(uc), ValueType{max_val - 3}) == end(uc));
assert(find(begin(uc), end(uc), ValueType{max_val - 2}) == begin(uc) + 4);
Expand Down Expand Up @@ -133,6 +167,7 @@ int main() {
#ifdef __cpp_lib_concepts
static_assert(_Vector_alg_in_find_is_safe<decltype(v.begin()), decltype(33)>, "should optimize");
#endif // __cpp_lib_concepts
static_assert(_Could_compare_equal_to_value_type<signed char*>(33), "should be within limits");

assert(find(v.begin(), v.end(), 33) - v.begin() == 1);
assert(find(v.begin(), v.end(), -1) - v.begin() == 2);
Expand Down Expand Up @@ -413,6 +448,16 @@ int main() {
assert(find(begin(sl), end(sl), 0xFFFFFFFF00000000ULL) == end(sl));
}

{ // unsigned int == int, weird conversions yay! (GH-3244)
const unsigned int ui[] = {0, 1, 2, uint_max - 2, uint_max - 1, uint_max};

assert(find(begin(ui), end(ui), 0) == begin(ui));
assert(find(begin(ui), end(ui), 2) == begin(ui) + 2);
assert(find(begin(ui), end(ui), 3) == end(ui));
assert(find(begin(ui), end(ui), -2) == begin(ui) + 4);
assert(find(begin(ui), end(ui), -1) == begin(ui) + 5);
}

{ // Test bools
const bool arr[]{true, true, true, false, true, false};

Expand All @@ -429,16 +474,21 @@ int main() {

{ // Test pointers
const char* s = "xxxyyy";
const char* arr[]{s, s + 1, s + 1, s + 5, s, s + 4};
const char* arr[]{s, s + 1, s + 1, s + 5, s, s + 4, nullptr};

static_assert(_Vector_alg_in_find_is_safe<decltype(begin(arr)), decltype(s + 1)>, "should optimize");
static_assert(!_Vector_alg_in_find_is_safe<decltype(begin(arr)), nullptr_t>, "should not optimize");

assert(find(begin(arr), end(arr), s) == begin(arr));
assert(find(begin(arr), end(arr), s + 1) == begin(arr) + 1);
assert(find(begin(arr), end(arr), s + 3) == end(arr));
assert(find(begin(arr), end(arr), static_cast<const char*>(nullptr)) == begin(arr) + 6);
assert(find(begin(arr), end(arr), nullptr) == begin(arr) + 6);

assert(count(begin(arr), end(arr), s + 1) == 2);
assert(count(begin(arr), end(arr), s + 5) == 1);
assert(count(begin(arr), end(arr), s + 3) == 0);
assert(count(begin(arr), end(arr), static_cast<const char*>(nullptr)) == 1);
assert(count(begin(arr), end(arr), nullptr) == 1);
}
}
strega-nil-ms marked this conversation as resolved.
Show resolved Hide resolved