diff --git a/cpp/tests/utility/vector_equality.hpp b/cpp/tests/utility/vector_equality.hpp index b964a02ea..4082300ff 100644 --- a/cpp/tests/utility/vector_equality.hpp +++ b/cpp/tests/utility/vector_equality.hpp @@ -32,8 +32,14 @@ namespace cuspatial { namespace test { +/** + * @brief Compare two floats are close within N ULPs + * + * N is predefined by GoogleTest + * https://google.github.io/googletest/reference/assertions.html#EXPECT_FLOAT_EQ + */ template -auto floating_eq(T val) +auto floating_eq_by_ulp(T val) { if constexpr (std::is_same_v) { return ::testing::FloatEq(val); @@ -42,14 +48,43 @@ auto floating_eq(T val) } } +/** + * @brief Compare two floats are close within `abs_error` + */ +template +auto floating_eq_by_abs_error(T val, T abs_error) +{ + if constexpr (std::is_same_v) { + return ::testing::FloatNear(val, abs_error); + } else { + return ::testing::FloatNear(val, abs_error); + } +} + MATCHER(vec_2d_matcher, std::string(negation ? "are not" : "are") + " approximately equal vec_2d structs") { auto lhs = std::get<0>(arg); auto rhs = std::get<1>(arg); - if (::testing::Matches(floating_eq(rhs.x))(lhs.x) && - ::testing::Matches(floating_eq(rhs.y))(lhs.y)) + if (::testing::Matches(floating_eq_by_ulp(rhs.x))(lhs.x) && + ::testing::Matches(floating_eq_by_ulp(rhs.y))(lhs.y)) + return true; + + *result_listener << lhs << " != " << rhs; + + return false; +} + +MATCHER_P(vec_2d_near_matcher, + abs_error, + std::string(negation ? "are not" : "are") + " approximately equal vec_2d structs") +{ + auto lhs = std::get<0>(arg); + auto rhs = std::get<1>(arg); + + if (::testing::Matches(floating_eq_by_abs_error(rhs.x, abs_error))(lhs.x) && + ::testing::Matches(floating_eq_by_abs_error(rhs.y, abs_error))(lhs.y)) return true; *result_listener << lhs << " != " << rhs; @@ -62,7 +97,21 @@ MATCHER(float_matcher, std::string(negation ? "are not" : "are") + " approximate auto lhs = std::get<0>(arg); auto rhs = std::get<1>(arg); - if (::testing::Matches(floating_eq(rhs))(lhs)) return true; + if (::testing::Matches(floating_eq_by_ulp(rhs))(lhs)) return true; + + *result_listener << std::setprecision(18) << lhs << " != " << rhs; + + return false; +} + +MATCHER_P(float_near_matcher, + abs_error, + std::string(negation ? "are not" : "are") + " approximately equal floats") +{ + auto lhs = std::get<0>(arg); + auto rhs = std::get<1>(arg); + + if (::testing::Matches(floating_eq_by_abs_error(rhs, abs_error))(lhs)) return true; *result_listener << std::setprecision(18) << lhs << " != " << rhs; @@ -103,5 +152,28 @@ inline void expect_vector_equivalent(Vector1 const& lhs, Vector2 const& rhs) } } +template +inline void expect_vector_equivalent(Vector1 const& lhs, Vector2 const& rhs, T abs_error) +{ + static_assert(std::is_same_v, "Value type mismatch."); + static_assert(!std::is_integral_v, "Integral types cannot be compared with an error."); + + if constexpr (cuspatial::is_vec_2d()) { + EXPECT_THAT(to_host(lhs), + ::testing::Pointwise(vec_2d_near_matcher(abs_error), to_host(rhs))); + } else if constexpr (std::is_floating_point_v) { + EXPECT_THAT(to_host(lhs), + ::testing::Pointwise(float_near_matcher(abs_error), to_host(rhs))); + } else { + EXPECT_EQ(lhs, rhs); + } +} + +#define CUSPATIAL_EXPECT_VECTORS_EQUIVALENT(lhs, rhs, ...) \ + do { \ + SCOPED_TRACE(" <-- line of failure\n"); \ + expect_vector_equivalent(lhs, rhs, ##__VA_ARGS__); \ + } while (0) + } // namespace test } // namespace cuspatial