Skip to content

Commit

Permalink
Add unhex Spark function (facebookincubator#8289)
Browse files Browse the repository at this point in the history
Summary:
Spark's `unhex` function accepts an input argument with an odd length and does not throw on non-ASCII characters.
Spark's implementation details: https://github.com/apache/spark/blob/28da1d853477b306774798d8aa738901221fb804/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L1057-L1089

Pull Request resolved: facebookincubator#8289

Reviewed By: Yuhta

Differential Revision: D54537678

Pulled By: mbasmanova

fbshipit-source-id: 393250a277b24825d9629595d391ce723a08a574
  • Loading branch information
Yohahaha authored and facebook-github-bot committed Mar 6, 2024
1 parent 07a1518 commit b8b48d5
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 0 deletions.
15 changes: 15 additions & 0 deletions velox/docs/functions/spark/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,18 @@ Mathematical Functions
.. spark:function:: unaryminus(x) -> [same as x]
Returns the negative of `x`. Corresponds to Spark's operator ``-``.

.. spark:function:: unhex(x) -> varbinary
Converts hexadecimal varchar ``x`` to varbinary.
``x`` is considered case insensitive and expected to contain only hexadecimal characters 0-9 and A-F.
If ``x`` contains a non-hexadecimal character, the function returns NULL.
When ``x`` contains an even number of characters, each pair is converted to a single byte. The number of bytes in the result is half the number of bytes in the input.
When ``x`` contains a single character, the result contains a single byte whose value matches the hexadecimal character.
When ``x`` contains an odd number of characters greater than 2, the first character is ignored, the remaining pairs of characters are converted to bytes, then a zero byte is added at the end of the output. ::

SELECT unhex("23"); -- #
SELECT unhex("f"); -- \x0F
SELECT unhex("b2323"); -- ##\0
SELECT unhex("G"); -- NULL
SELECT unhex("G23"); -- NULL
55 changes: 55 additions & 0 deletions velox/functions/sparksql/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,4 +350,59 @@ struct ToHexBigintFunction {
ToHexUtil::toHex(input, result);
}
};

namespace detail {
/// Decodes a single hexadecimal digit.
///
/// @param c Character to decode; '0'-'9', 'A'-'F' and 'a'-'f' are accepted.
/// @return The digit's numeric value in [0, 15], or -1 when `c` is not a
/// hexadecimal digit.
FOLLY_ALWAYS_INLINE static int8_t fromHex(char c) {
  if ('0' <= c && c <= '9') {
    return c - '0';
  }

  // Setting bit 0x20 maps 'A'..'F' onto 'a'..'f' and leaves 'a'..'f'
  // unchanged. No other character can land in that range, because a value with
  // bit 0x20 set can only originate from itself or from itself minus 0x20.
  const int folded = c | 0x20;
  if ('a' <= folded && folded <= 'f') {
    return folded - 'a' + 10;
  }

  return -1;
}
} // namespace detail

/// Spark-compatible `unhex`: decodes a hexadecimal varchar into varbinary.
///
/// The result holds (input.size() + 1) / 2 bytes. Even-length input turns each
/// pair of characters into one byte. Odd-length input first stores the leading
/// digit in byte 0 and zeroes the last byte; the remaining pairs are then
/// written starting at byte 0 again. A single character therefore yields its
/// digit value, while longer odd-length input loses the leading digit and ends
/// with a zero byte.
template <typename T>
struct UnHexFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  /// @param result Output buffer; resized to the decoded length.
  /// @param input Hexadecimal characters to decode (either case).
  /// @return false as soon as a character that is not a hexadecimal digit is
  /// encountered, true otherwise.
  FOLLY_ALWAYS_INLINE bool call(
      out_type<Varbinary>& result,
      const arg_type<Varchar>& input) {
    const auto decodedSize = (input.size() + 1) >> 1;
    result.resize(decodedSize);
    const char* src = input.data();
    char* dst = result.data();

    int32_t pos = 0;
    const bool hasOddLength = (input.size() & 0x01) != 0;
    if (hasOddLength) {
      const auto leading = detail::fromHex(src[0]);
      if (leading == -1) {
        return false;
      }
      // resize() on out_type<Varbinary> does not promise zero-initialized
      // bytes, so the trailing byte is cleared explicitly to align with
      // Spark. This store must precede the next one: for a one-character
      // input both target byte 0.
      dst[decodedSize - 1] = 0;
      dst[0] = leading;
      pos = 1;
    }

    // The remaining length is even, so src[pos + 1] is always in bounds.
    for (; pos < input.size(); pos += 2) {
      const auto high = detail::fromHex(src[pos]);
      const auto low = detail::fromHex(src[pos + 1]);
      if (high == -1 || low == -1) {
        return false;
      }
      dst[pos / 2] = (high << 4) | low;
    }
    return true;
  }
};
} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/RegisterArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ void registerArithmeticFunctions(const std::string& prefix) {
{prefix + "round"});
registerFunction<RoundFunction, double, double, int32_t>({prefix + "round"});
registerFunction<RoundFunction, float, float, int32_t>({prefix + "round"});
registerFunction<UnHexFunction, Varbinary, Varchar>({prefix + "unhex"});
// In Spark only long, double, and decimal have ceil/floor
registerFunction<sparksql::CeilFunction, int64_t, int64_t>({prefix + "ceil"});
registerFunction<sparksql::CeilFunction, int64_t, double>({prefix + "ceil"});
Expand Down
21 changes: 21 additions & 0 deletions velox/functions/sparksql/tests/ArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,27 @@ TEST_F(ArithmeticTest, cosh) {
EXPECT_TRUE(std::isnan(cosh(kNan).value_or(0)));
}

TEST_F(ArithmeticTest, unhex) {
  // Evaluates unhex(c0) on a single optional varchar value.
  const auto decode = [&](std::optional<std::string> hex) {
    return evaluateOnce<std::string>("unhex(c0)", hex);
  };

  // Expected outputs that end in a NUL byte; the explicit lengths keep the
  // embedded '\0' from being dropped.
  const std::string hashThenNul("#\0", 2);
  const std::string twoHashesThenNul("##\0", 3);

  // Even-length input: every pair of characters becomes one byte.
  EXPECT_EQ(decode(""), "");
  EXPECT_EQ(decode("23"), "#");
  EXPECT_EQ(decode("ff"), "\xFF");
  EXPECT_EQ(decode("737472696E67"), "string");
  EXPECT_EQ(decode("E4B889E9878DE79A84"), "\u4E09\u91CD\u7684");

  // A single character decodes to a byte holding the digit's value.
  EXPECT_EQ(decode("F"), "\x0F");

  // Longer odd-length input: the leading digit is dropped and a zero byte is
  // appended.
  EXPECT_EQ(decode("123"), hashThenNul);
  EXPECT_EQ(decode("b23"), hashThenNul);
  EXPECT_EQ(decode("b2323"), twoHashesThenNul);

  // Any non-hexadecimal character produces NULL.
  EXPECT_EQ(decode("G"), std::nullopt);
  EXPECT_EQ(decode("GG"), std::nullopt);
  EXPECT_EQ(decode("G23"), std::nullopt);
}

class CeilFloorTest : public SparkFunctionBaseTest {
protected:
template <typename T>
Expand Down

0 comments on commit b8b48d5

Please sign in to comment.