From b8b48d5f47aee9bc746cabdc0879a889e2b38235 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Tue, 5 Mar 2024 20:31:40 -0800 Subject: [PATCH] Add unhex Spark function (#8289) Summary: Spark's `unhex` function allows input arg with an odd length, and not throw for non-ascii char. Spark's implementation details: https://github.com/apache/spark/blob/28da1d853477b306774798d8aa738901221fb804/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L1057-L1089 Pull Request resolved: https://github.com/facebookincubator/velox/pull/8289 Reviewed By: Yuhta Differential Revision: D54537678 Pulled By: mbasmanova fbshipit-source-id: 393250a277b24825d9629595d391ce723a08a574 --- velox/docs/functions/spark/math.rst | 15 +++++ velox/functions/sparksql/Arithmetic.h | 55 +++++++++++++++++++ .../functions/sparksql/RegisterArithmetic.cpp | 1 + .../sparksql/tests/ArithmeticTest.cpp | 21 +++++++ 4 files changed, 92 insertions(+) diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index af05926a75e9..5d1ffde64c82 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -233,3 +233,18 @@ Mathematical Functions .. spark:function:: unaryminus(x) -> [same as x] Returns the negative of `x`. Corresponds to Spark's operator ``-``. + +.. spark:function:: unhex(x) -> varbinary + + Converts hexadecimal varchar ``x`` to varbinary. + ``x`` is considered case insensitive and expected to contain only hexadecimal characters 0-9 and A-F. + If ``x`` contains non-hexadecimal character, the function returns NULL. + When ``x`` contains even number of characters, each pair is converted to a single byte. The number of bytes in the result is half the number of bytes in the input. + When ``x`` contains a single character, the result contains a single byte whose value matches the hexadecimal character. + When ``x`` contains an odd number characters greater than 2, the first character is ignored, the remaining pairs of characters are converted to bytes, then zero byte is added at the end of the output. :: + + SELECT unhex("23"); -- # + SELECT unhex("f"); -- \x0F + SELECT unhex("b2323"); -- ##\0 + SELECT unhex("G"); -- NULL + SELECT unhex("G23"); -- NULL \ No newline at end of file diff --git a/velox/functions/sparksql/Arithmetic.h b/velox/functions/sparksql/Arithmetic.h index 7877389260a0..aedc3d44f21b 100644 --- a/velox/functions/sparksql/Arithmetic.h +++ b/velox/functions/sparksql/Arithmetic.h @@ -350,4 +350,59 @@ struct ToHexBigintFunction { ToHexUtil::toHex(input, result); } }; + +namespace detail { +FOLLY_ALWAYS_INLINE static int8_t fromHex(char c) { + if (c >= '0' && c <= '9') { + return c - '0'; + } + + if (c >= 'A' && c <= 'F') { + return 10 + c - 'A'; + } + + if (c >= 'a' && c <= 'f') { + return 10 + c - 'a'; + } + return -1; +} +} // namespace detail + +template +struct UnHexFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& input) { + const auto resultSize = (input.size() + 1) >> 1; + result.resize(resultSize); + const char* inputBuffer = input.data(); + char* resultBuffer = result.data(); + + int32_t i = 0; + if ((input.size() & 0x01) != 0) { + const auto v = detail::fromHex(inputBuffer[0]); + if (v == -1) { + return false; + } + // out_type resize does not guarantee all chars initialized + // with 0, filling last char with 0 to align with Spark. + resultBuffer[resultSize - 1] = 0; + resultBuffer[0] = v; + i += 1; + } + + while (i < input.size()) { + const auto first = detail::fromHex(inputBuffer[i]); + const auto second = detail::fromHex(inputBuffer[i + 1]); + if (first == -1 || second == -1) { + return false; + } + resultBuffer[i / 2] = (first << 4) | second; + i += 2; + } + return true; + } +}; } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 5aa85c778878..8fd92861495f 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -84,6 +84,7 @@ void registerArithmeticFunctions(const std::string& prefix) { {prefix + "round"}); registerFunction({prefix + "round"}); registerFunction({prefix + "round"}); + registerFunction({prefix + "unhex"}); // In Spark only long, double, and decimal have ceil/floor registerFunction({prefix + "ceil"}); registerFunction({prefix + "ceil"}); diff --git a/velox/functions/sparksql/tests/ArithmeticTest.cpp b/velox/functions/sparksql/tests/ArithmeticTest.cpp index 0ed06df554ce..f44560a4fe25 100644 --- a/velox/functions/sparksql/tests/ArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/ArithmeticTest.cpp @@ -268,6 +268,27 @@ TEST_F(ArithmeticTest, cosh) { EXPECT_TRUE(std::isnan(cosh(kNan).value_or(0))); } +TEST_F(ArithmeticTest, unhex) { + const auto unhex = [&](std::optional a) { + return evaluateOnce("unhex(c0)", a); + }; + + EXPECT_EQ(unhex("737472696E67"), "string"); + EXPECT_EQ(unhex(""), ""); + EXPECT_EQ(unhex("23"), "#"); + std::string b("#\0", 2); + EXPECT_EQ(unhex("123"), b); + EXPECT_EQ(unhex("b23"), b); + b = std::string("##\0", 3); + EXPECT_EQ(unhex("b2323"), b); + EXPECT_EQ(unhex("F"), "\x0F"); + EXPECT_EQ(unhex("ff"), "\xFF"); + EXPECT_EQ(unhex("G"), std::nullopt); + EXPECT_EQ(unhex("GG"), std::nullopt); + EXPECT_EQ(unhex("G23"), std::nullopt); + EXPECT_EQ(unhex("E4B889E9878DE79A84"), "\u4E09\u91CD\u7684"); +} + class CeilFloorTest : public SparkFunctionBaseTest { protected: template