From 5cef7605da3bec952000e331bd342e39c94c44cc Mon Sep 17 00:00:00 2001 From: codesigner Date: Thu, 17 Nov 2022 11:04:38 +0800 Subject: [PATCH] fix Nan & Infinity comparison And Value's operator/ & operator% --- src/common/datatypes/Value.cpp | 46 ++------ src/common/datatypes/Value.h | 4 + src/common/datatypes/test/ValueTest.cpp | 106 +++++++++++++++++- .../expression/RelationalExpression.cpp | 4 +- .../test/ArithmeticExpressionTest.cpp | 12 +- src/common/expression/test/TestBase.h | 7 +- src/common/function/FunctionManager.cpp | 6 - .../function/test/FunctionManagerTest.cpp | 2 + tests/common/comparator.py | 37 +++--- .../features/bugfix/NaNInfinityFloat.feature | 43 +++++++ tests/tck/features/parser/nebula.feature | 3 +- tests/tck/utils/nbv.py | 20 ++-- 12 files changed, 211 insertions(+), 79 deletions(-) create mode 100644 tests/tck/features/bugfix/NaNInfinityFloat.feature diff --git a/src/common/datatypes/Value.cpp b/src/common/datatypes/Value.cpp index 916f7847c91..7de3ece5103 100644 --- a/src/common/datatypes/Value.cpp +++ b/src/common/datatypes/Value.cpp @@ -2487,12 +2487,7 @@ Value operator/(const Value& lhs, const Value& rhs) { return lVal / denom; } case Value::Type::FLOAT: { - double denom = rhs.getFloat(); - if (std::abs(denom) > kEpsilon) { - return lhs.getInt() / denom; - } else { - return Value::kNullDivByZero; - } + return lhs.getInt() / rhs.getFloat(); } default: { return Value::kNullBadType; @@ -2502,20 +2497,10 @@ Value operator/(const Value& lhs, const Value& rhs) { case Value::Type::FLOAT: { switch (rhs.type()) { case Value::Type::INT: { - int64_t denom = rhs.getInt(); - if (denom != 0) { - return lhs.getFloat() / denom; - } else { - return Value::kNullDivByZero; - } + return lhs.getFloat() / rhs.getInt(); } case Value::Type::FLOAT: { - double denom = rhs.getFloat(); - if (std::abs(denom) > kEpsilon) { - return lhs.getFloat() / denom; - } else { - return Value::kNullDivByZero; - } + return lhs.getFloat() / rhs.getFloat(); } default: { return Value::kNullBadType; @@ -2548,12 +2533,7 @@ Value operator%(const Value& lhs, const Value& rhs) { } } case Value::Type::FLOAT: { - double denom = rhs.getFloat(); - if (std::abs(denom) > kEpsilon) { - return std::fmod(lhs.getInt(), denom); - } else { - return Value::kNullDivByZero; - } + return std::fmod(lhs.getInt(), rhs.getFloat()); } default: { return Value::kNullBadType; @@ -2563,20 +2543,10 @@ Value operator%(const Value& lhs, const Value& rhs) { case Value::Type::FLOAT: { switch (rhs.type()) { case Value::Type::INT: { - int64_t denom = rhs.getInt(); - if (denom != 0) { - return std::fmod(lhs.getFloat(), denom); - } else { - return Value::kNullDivByZero; - } + return std::fmod(lhs.getFloat(), rhs.getInt()); } case Value::Type::FLOAT: { - double denom = rhs.getFloat(); - if (std::abs(denom) > kEpsilon) { - return std::fmod(lhs.getFloat(), denom); - } else { - return Value::kNullDivByZero; - } + return std::fmod(lhs.getFloat(), rhs.getFloat()); } default: { return Value::kNullBadType; @@ -2877,11 +2847,11 @@ bool operator>(const Value& lhs, const Value& rhs) { } bool operator<=(const Value& lhs, const Value& rhs) { - return !(rhs < lhs); + return lhs < rhs || lhs == rhs; } bool operator>=(const Value& lhs, const Value& rhs) { - return !(lhs < rhs); + return lhs > rhs || lhs == rhs; } Value operator&&(const Value& lhs, const Value& rhs) { diff --git a/src/common/datatypes/Value.h b/src/common/datatypes/Value.h index 4f1e1e9076e..06f4e8b2241 100644 --- a/src/common/datatypes/Value.h +++ b/src/common/datatypes/Value.h @@ -367,8 +367,12 @@ struct Value { Value toInt() const; Value toSet() const; + // Expr use this function instead of operator<, because a Value compare to a Null + // return null instead of true or false Value lessThan(const Value& v) const; + // Expr use this function instead of operator==, because a Value compare to a Null + // return null instead of true or false Value equal(const Value& v) const; // Whether the value can be converted to bool implicitly diff --git a/src/common/datatypes/test/ValueTest.cpp b/src/common/datatypes/test/ValueTest.cpp index e04ae18564a..73abce84640 100644 --- a/src/common/datatypes/test/ValueTest.cpp +++ b/src/common/datatypes/test/ValueTest.cpp @@ -6,6 +6,8 @@ #include #include +#include + #include "common/base/Base.h" #include "common/datatypes/CommonCpp2Ops.h" #include "common/datatypes/DataSet.h" @@ -598,7 +600,8 @@ TEST(Value, Arithmetics) { EXPECT_EQ((vFloat1.getFloat() / vFloat2.getFloat()), v.getFloat()); v = vFloat1 / vZero; - EXPECT_EQ(Value::Type::NULLVALUE, v.type()); + EXPECT_EQ(Value::Type::FLOAT, v.type()); + EXPECT_EQ(std::numeric_limits::infinity(), v.getFloat()); v = vInt1 / vZero; EXPECT_EQ(Value::Type::NULLVALUE, v.type()); } @@ -629,7 +632,8 @@ TEST(Value, Arithmetics) { EXPECT_EQ(std::fmod(vFloat1.getFloat(), vFloat2.getFloat()), v.getFloat()); v = vFloat1 % vZero; - EXPECT_EQ(Value::Type::NULLVALUE, v.type()); + EXPECT_EQ(Value::Type::FLOAT, v.type()); + EXPECT_TRUE(std::isnan(v.getFloat())); v = vInt1 % vZero; EXPECT_EQ(Value::Type::NULLVALUE, v.type()); } @@ -682,6 +686,11 @@ TEST(Value, Comparison) { Value vInt2(2); Value vFloat1(3.14); Value vFloat2(2.67); + Value vFloat3(-2.67); + Value vFloatNaN(0 / 0.0); + Value vFloatPositiveInfinity(1 / 0.0); + Value vFloatNegativeInfinity(-1 / 0.0); + Value vStr1("Hello "); Value vStr2("World"); Value vBool1(false); @@ -811,6 +820,99 @@ TEST(Value, Comparison) { v = vFloat1 <= vFloat2; EXPECT_EQ(Value::Type::BOOL, v.type()); EXPECT_EQ(false, v.getBool()); + + // NaN comparison + // https://en.wikipedia.org/wiki/NaN#Comparison_with_NaN + // Comparison between NaN and any floating-point value x (including NaN and ±Inf) + // Comparison NaN ≥ x NaN ≤ x NaN > x NaN < x NaN = x NaN ≠ x + // Result False False False False False True + v = vFloatNaN >= vFloat1; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN > vFloat1; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN < vFloat1; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN.lessThan(vFloat1); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN <= vFloat1; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN >= vFloat3; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN > vFloat3; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN < vFloat3; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN <= vFloat3; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + // NaN != any Value + v = vFloatNaN != vFloat3; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(true, v.getBool()); + v = vFloatNaN == vFloat3; + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + v = vFloatNaN.equal(vFloat3); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNaN.equal(Value(0 / 0.0)); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + } + + { + // -Inf + Value v = vFloatPositiveInfinity.lessThan(vFloatNegativeInfinity); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatNegativeInfinity.lessThan(vFloatPositiveInfinity); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(true, v.getBool()); + + v = vFloatNegativeInfinity.lessThan(vInt1); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(true, v.getBool()); + + v = vFloatNegativeInfinity.lessThan(vFloat1); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(true, v.getBool()); + + // +Inf + v = vFloatPositiveInfinity.lessThan(vInt1); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatPositiveInfinity.lessThan(vFloat1); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + // NaN comparison always false + v = vFloatNegativeInfinity.lessThan(vFloatNaN); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); + + v = vFloatPositiveInfinity.lessThan(vFloatNaN); + EXPECT_EQ(Value::Type::BOOL, v.type()); + EXPECT_EQ(false, v.getBool()); } // int and float diff --git a/src/common/expression/RelationalExpression.cpp b/src/common/expression/RelationalExpression.cpp index 50cdac11079..2baedf5ca25 100644 --- a/src/common/expression/RelationalExpression.cpp +++ b/src/common/expression/RelationalExpression.cpp @@ -29,10 +29,10 @@ const Value& RelationalExpression::eval(ExpressionContext& ctx) { result_ = lhs.lessThan(rhs) || lhs.equal(rhs); break; case Kind::kRelGT: - result_ = !lhs.lessThan(rhs) && !lhs.equal(rhs); + result_ = rhs.lessThan(lhs); break; case Kind::kRelGE: - result_ = !lhs.lessThan(rhs) || lhs.equal(rhs); + result_ = rhs.lessThan(lhs) || lhs.equal(rhs); break; case Kind::kRelREG: { if (lhs.isBadNull() || rhs.isBadNull()) { diff --git a/src/common/expression/test/ArithmeticExpressionTest.cpp b/src/common/expression/test/ArithmeticExpressionTest.cpp index 689f6de5392..5b6e859cf6c 100644 --- a/src/common/expression/test/ArithmeticExpressionTest.cpp +++ b/src/common/expression/test/ArithmeticExpressionTest.cpp @@ -22,13 +22,13 @@ TEST_F(ArithmeticExpressionTest, TestArithmeticExpression) { TEST_EXPR(11 * 2, 22); TEST_EXPR(11 * 2.2, 24.2); TEST_EXPR(100.4 / 4, 25.1); - TEST_EXPR(10.4 % 0, NullType::DIV_BY_ZERO); - TEST_EXPR(10 % 0.0, NullType::DIV_BY_ZERO); - TEST_EXPR(10.4 % 0.0, NullType::DIV_BY_ZERO); + TEST_EXPR(10.4 % 0, Value(std::numeric_limits::quiet_NaN())); + TEST_EXPR(10 % 0.0, Value(std::numeric_limits::quiet_NaN())); + TEST_EXPR(10.4 % 0.0, Value(std::numeric_limits::quiet_NaN())); TEST_EXPR(10 / 0, NullType::DIV_BY_ZERO); - TEST_EXPR(12 / 0.0, NullType::DIV_BY_ZERO); - TEST_EXPR(187. / 0.0, NullType::DIV_BY_ZERO); - TEST_EXPR(17. / 0, NullType::DIV_BY_ZERO); + TEST_EXPR(12 / 0.0, std::numeric_limits::infinity()); + TEST_EXPR(187. / 0.0, std::numeric_limits::infinity()); + TEST_EXPR(17. / 0, std::numeric_limits::infinity()); } { TEST_EXPR(1 + 2 + 3.2, 6.2); diff --git a/src/common/expression/test/TestBase.h b/src/common/expression/test/TestBase.h index 7969f5a70ee..42ff3ca8bfb 100644 --- a/src/common/expression/test/TestBase.h +++ b/src/common/expression/test/TestBase.h @@ -77,7 +77,12 @@ class ExpressionTest : public ::testing::Test { Expression *ep = yieldSentence->yield()->yields()->back()->expr(); auto eval = Expression::eval(ep, gExpCtxt); EXPECT_EQ(eval.type(), expected.type()) << "type check failed: " << ep->toString(); - EXPECT_EQ(eval, expected) << "check failed: " << ep->toString(); + // NaN is not equals to NaN, check equals should use std::isnan() + if (expected.type() == Value::Type::FLOAT && std::isnan(expected.getFloat())) { + EXPECT_TRUE(std::isnan(eval.getFloat())) << "check failed: " << ep->toString(); + } else { + EXPECT_EQ(eval, expected) << "check failed: " << ep->toString(); + } } void testToString(const std::string &exprSymbol, const char *expected) { diff --git a/src/common/function/FunctionManager.cpp b/src/common/function/FunctionManager.cpp index 425a87410eb..ee800b6512b 100644 --- a/src/common/function/FunctionManager.cpp +++ b/src/common/function/FunctionManager.cpp @@ -592,16 +592,10 @@ FunctionManager::FunctionManager() { } case Value::Type::INT: { auto val = args[0].get().getInt(); - if (val < 0) { - return Value::kNullValue; - } return std::sqrt(val); } case Value::Type::FLOAT: { auto val = args[0].get().getFloat(); - if (val < 0) { - return Value::kNullValue; - } return std::sqrt(val); } default: { diff --git a/src/common/function/test/FunctionManagerTest.cpp b/src/common/function/test/FunctionManagerTest.cpp index 05e36f119f9..0c3b025a319 100644 --- a/src/common/function/test/FunctionManagerTest.cpp +++ b/src/common/function/test/FunctionManagerTest.cpp @@ -255,6 +255,8 @@ TEST_F(FunctionManagerTest, functionCall) { { TEST_FUNCTION(sqrt, args_["int"], 2.0); TEST_FUNCTION(sqrt, args_["float"], std::sqrt(1.1)); + TEST_FUNCTION(sqrt, {Value(-1)}, std::sqrt(-1)); + TEST_FUNCTION(sqrt, {Value(0)}, std::sqrt(0)); } { TEST_FUNCTION(cbrt, args_["int"], std::cbrt(4)); diff --git a/tests/common/comparator.py b/tests/common/comparator.py index d8e6790d054..d496a38dcfa 100644 --- a/tests/common/comparator.py +++ b/tests/common/comparator.py @@ -25,13 +25,13 @@ class DataSetComparator: def __init__( - self, - strict=True, - order=False, - contains=CmpType.EQUAL, - first_n_records=-1, - decode_type='utf-8', - vid_fn=None, + self, + strict=True, + order=False, + contains=CmpType.EQUAL, + first_n_records=-1, + decode_type='utf-8', + vid_fn=None, ): self._strict = strict self._order = order @@ -51,9 +51,9 @@ def s(self, b: bytes) -> str: def _whether_return(self, cmp: bool) -> bool: return ( - (self._contains == CmpType.EQUAL and not cmp) - or (self._contains == CmpType.CONTAINS and not cmp) - or (self._contains == CmpType.NOT_CONTAINS and cmp) + (self._contains == CmpType.EQUAL and not cmp) + or (self._contains == CmpType.CONTAINS and not cmp) + or (self._contains == CmpType.NOT_CONTAINS and cmp) ) def compare(self, resp: DataSet, expect: DataSet): @@ -73,7 +73,7 @@ def compare(self, resp: DataSet, expect: DataSet): if self._order and self._contains == CmpType.EQUAL: if self._first_n_records > 0: # just compare the first n records - resp_rows = resp.rows[0 : self._first_n_records] + resp_rows = resp.rows[0: self._first_n_records] else: resp_rows = resp.rows @@ -107,6 +107,11 @@ def compare_value(self, lhs: Value, rhs: Union[Value, Pattern]) -> bool: if lhs.getType() == Value.FVAL: if not rhs.getType() == Value.FVAL: return False + # handle nan & inf + if math.isnan(lhs.get_fVal()): + return math.isnan(rhs.get_fVal()) + if math.isinf(lhs.get_fVal()): + return math.isinf(rhs.get_fVal()) return math.fabs(lhs.get_fVal() - rhs.get_fVal()) < 1.0e-8 if lhs.getType() == Value.SVAL: if not rhs.getType() == Value.SVAL: @@ -236,7 +241,7 @@ def compare_edge(self, lhs: Edge, rhs: Edge): return False rsrc, rdst = self.eid(rhs, lhs.type) if not ( - self.compare_vid(lhs.src, rsrc) and self.compare_vid(lhs.dst, rdst) + self.compare_vid(lhs.src, rsrc) and self.compare_vid(lhs.dst, rdst) ): return False if rhs.props is None or len(lhs.props) != len(rhs.props): @@ -245,7 +250,7 @@ def compare_edge(self, lhs: Edge, rhs: Edge): if rhs.src is not None and rhs.dst is not None: rsrc, rdst = self.eid(rhs, lhs.type) if not ( - self.compare_vid(lhs.src, rsrc) and self.compare_vid(lhs.dst, rdst) + self.compare_vid(lhs.src, rsrc) and self.compare_vid(lhs.dst, rdst) ): return False if rhs.ranking is not None: @@ -262,9 +267,9 @@ def bstr(self, vid) -> bytes: return self.b(vid) if type(vid) == str else vid def _compare_vid( - self, - lid: Union[int, bytes], - rid: Union[int, bytes, str], + self, + lid: Union[int, bytes], + rid: Union[int, bytes, str], ) -> bool: if type(lid) is bytes: return type(rid) in [str, bytes] and lid == self.bstr(rid) diff --git a/tests/tck/features/bugfix/NaNInfinityFloat.feature b/tests/tck/features/bugfix/NaNInfinityFloat.feature new file mode 100644 index 00000000000..173ca600afd --- /dev/null +++ b/tests/tck/features/bugfix/NaNInfinityFloat.feature @@ -0,0 +1,43 @@ +# Copyright (c) 2022 vesoft inc. All rights reserved. +# +# This source code is licensed under Apache 2.0 License. +Feature: NaN/Infinity result test + + # issue https://github.com/vesoft-inc/nebula/issues/3473 + Scenario: NaN/Infinity result test + Given an empty graph + And create a space with following options: + | partition_num | 1 | + | replica_factor | 1 | + | vid_type | FIXED_STRING(30) | + | charset | utf8 | + | collate | utf8_bin | + When executing query: + """ + Yield 0/0.0 + """ + Then the result should be, in any order, with relax comparison: + | (0/0) | + | nan | + When executing query: + """ + Yield 1/0.0 + """ + Then the result should be, in any order, with relax comparison: + | (1/0) | + | inf | + When executing query: + """ + Yield -1/0.0 + """ + Then the result should be, in any order, with relax comparison: + | (-(1)/0) | + | -inf | + When executing query: + """ + Yield sqrt(-1.0) + """ + Then the result should be, in any order, with relax comparison: + | sqrt(-(1)) | + | nan | + Then drop the used space diff --git a/tests/tck/features/parser/nebula.feature b/tests/tck/features/parser/nebula.feature index e410d1ce956..7e802ce4155 100644 --- a/tests/tck/features/parser/nebula.feature +++ b/tests/tck/features/parser/nebula.feature @@ -12,7 +12,8 @@ Feature: Value parsing | format | _type | | EMPTY | EMPTY | | NULL | NULL | - | NaN | NaN | + | NaN | fVal | + | Inf | fVal | | BAD_DATA | BAD_DATA | | BAD_TYPE | BAD_TYPE | | OVERFLOW | ERR_OVERFLOW | diff --git a/tests/tck/utils/nbv.py b/tests/tck/utils/nbv.py index ed694a0ee1f..6bc4e3a1714 100644 --- a/tests/tck/utils/nbv.py +++ b/tests/tck/utils/nbv.py @@ -37,7 +37,8 @@ tokens = ( 'EMPTY', 'NULL', - 'NaN', + 'nan', + 'inf', 'BAD_DATA', 'BAD_TYPE', 'OVERFLOW', @@ -72,18 +73,21 @@ def t_NULL(t): return t -def t_NaN(t): - r'NaN' - t.value = Value(nVal=NullType.NaN) +def t_nan(t): + r'nan' + t.value = Value(fVal=float('nan')) return t +def t_inf(t): + r'inf' + t.value = Value(fVal=float('inf')) + return t def t_BAD_DATA(t): r'BAD_DATA' t.value = Value(nVal=NullType.BAD_DATA) return t - def t_BAD_TYPE(t): r'BAD_TYPE' t.value = Value(nVal=NullType.BAD_TYPE) @@ -230,7 +234,8 @@ def p_expr(p): ''' expr : EMPTY | NULL - | NaN + | nan + | inf | BAD_DATA | BAD_TYPE | OVERFLOW @@ -577,7 +582,8 @@ def parse_row(row): expected = {} expected['EMPTY'] = Value() expected['NULL'] = Value(nVal=NullType.__NULL__) - expected['NaN'] = Value(nVal=NullType.NaN) + expected['nan'] = Value(fVal=float('nan')) + expected['inf'] = Value(fVal=float('inf')) expected['BAD_DATA'] = Value(nVal=NullType.BAD_DATA) expected['BAD_TYPE'] = Value(nVal=NullType.BAD_TYPE) expected['OVERFLOW'] = Value(nVal=NullType.ERR_OVERFLOW)