From 156e21651c281bd3c0d900fe5e9937995b2c3684 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 <xzhangxian@foxmail.com> Date: Wed, 28 Dec 2022 16:24:17 +0800 Subject: [PATCH] Support regexp_replace function (#6370) close pingcap/tiflash#6115 --- dbms/src/Common/OptimizedRegularExpression.h | 7 + .../Common/OptimizedRegularExpression.inl.h | 171 +++++- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 2 +- dbms/src/Functions/FunctionsRegexp.cpp | 271 +--------- dbms/src/Functions/FunctionsRegexp.h | 485 ++++++++++++++++-- dbms/src/Functions/tests/gtest_regexp.cpp | 308 +++++++---- tests/fullstack-test/expr/regexp.test | 26 +- 7 files changed, 847 insertions(+), 423 deletions(-) diff --git a/dbms/src/Common/OptimizedRegularExpression.h b/dbms/src/Common/OptimizedRegularExpression.h index 661b1233cfb..1cb3daab368 100644 --- a/dbms/src/Common/OptimizedRegularExpression.h +++ b/dbms/src/Common/OptimizedRegularExpression.h @@ -14,6 +14,7 @@ #pragma once +#include <Columns/ColumnString.h> #include <Common/config.h> #include <common/StringRef.h> #include <common/types.h> @@ -117,6 +118,7 @@ class OptimizedRegularExpressionImpl Int64 instr(const char * subject, size_t subject_size, Int64 pos, Int64 occur, Int64 ret_op); std::optional<StringRef> substr(const char * subject, size_t subject_size, Int64 pos, Int64 occur); + void replace(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 pos, Int64 occur); private: Int64 processInstrEmptyStringExpr(const char * expr, size_t expr_size, size_t byte_pos, Int64 occur); @@ -125,6 +127,11 @@ class OptimizedRegularExpressionImpl std::optional<StringRef> processSubstrEmptyStringExpr(const char * expr, size_t expr_size, size_t byte_pos, Int64 occur); std::optional<StringRef> substrImpl(const char * subject, size_t subject_size, Int64 byte_pos, Int64 occur); + void processReplaceEmptyStringExpr(const char * subject, size_t subject_size, 
DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos, Int64 occur); + void replaceImpl(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos, Int64 occur); + void replaceOneImpl(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos, Int64 occur); + void replaceAllImpl(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos); + bool is_trivial; bool required_substring_is_prefix; bool is_case_insensitive; diff --git a/dbms/src/Common/OptimizedRegularExpression.inl.h b/dbms/src/Common/OptimizedRegularExpression.inl.h index a73e78562b5..21e5d635ef9 100644 --- a/dbms/src/Common/OptimizedRegularExpression.inl.h +++ b/dbms/src/Common/OptimizedRegularExpression.inl.h @@ -21,6 +21,7 @@ #include <common/defines.h> #include <common/types.h> +#include <cstring> #include <iostream> #include <optional> @@ -499,22 +500,67 @@ std::optional<StringRef> OptimizedRegularExpressionImpl<thread_safe>::processSub return std::optional<StringRef>(StringRef(matched_str.data(), matched_str.size())); } -static inline void checkInstrArgs(Int64 utf8_total_len, size_t subject_size, Int64 pos, Int64 ret_op) +template <bool thread_safe> +void OptimizedRegularExpressionImpl<thread_safe>::processReplaceEmptyStringExpr(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos, Int64 occur) { - RUNTIME_CHECK_MSG(!(ret_op != 0 && ret_op != 1), "Incorrect argument to regexp function: return_option must be 1 or 0"); - RUNTIME_CHECK_MSG(!(pos <= 0 || (pos > utf8_total_len && subject_size != 0)), "Index out of bounds in regular 
function."); + if (occur > 1 || byte_pos != 1) + { + res_data.resize(res_data.size() + 1); + res_data[res_offset++] = '\0'; + return; + } + + StringPieceType expr_sp(subject, subject_size); + StringPieceType matched_str; + bool success = RegexType::FindAndConsume(&expr_sp, *re2, &matched_str); + if (!success) + { + res_data.resize(res_data.size() + 1); + } + else + { + res_data.resize(res_data.size() + repl.size + 1); + memcpy(&res_data[res_offset], repl.data, repl.size); + res_offset += repl.size; + } + + res_data[res_offset++] = '\0'; } -static inline void checkSubstrArgs(Int64 utf8_total_len, size_t subject_size, Int64 pos) +namespace FunctionsRegexp +{ +inline void checkArgPos(Int64 utf8_total_len, size_t subject_size, Int64 pos) { RUNTIME_CHECK_MSG(!(pos <= 0 || (pos > utf8_total_len && subject_size != 0)), "Index out of bounds in regular function."); } -static inline void makeOccurValid(Int64 & occur) +inline void checkArgsInstr(Int64 utf8_total_len, size_t subject_size, Int64 pos, Int64 ret_op) +{ + RUNTIME_CHECK_MSG(!(ret_op != 0 && ret_op != 1), "Incorrect argument to regexp function: return_option must be 1 or 0"); + checkArgPos(utf8_total_len, subject_size, pos); +} + +inline void checkArgsSubstr(Int64 utf8_total_len, size_t subject_size, Int64 pos) +{ + checkArgPos(utf8_total_len, subject_size, pos); +} + +inline void checkArgsReplace(Int64 utf8_total_len, size_t subject_size, Int64 pos) +{ + checkArgPos(utf8_total_len, subject_size, pos); +} + +inline void makeOccurValid(Int64 & occur) { occur = occur < 1 ? 1 : occur; } +inline void makeReplaceOccurValid(Int64 & occur) +{ + occur = occur < 0 ? 
1 : occur; +} +} // namespace FunctionsRegexp + template <bool thread_safe> Int64 OptimizedRegularExpressionImpl<thread_safe>::instrImpl(const char * subject, size_t subject_size, Int64 byte_pos, Int64 occur, Int64 ret_op) { @@ -557,13 +603,95 @@ std::optional<StringRef> OptimizedRegularExpressionImpl<thread_safe>::substrImpl return std::optional<StringRef>(StringRef(matched_str.data(), matched_str.size())); } +template <bool thread_safe> +void OptimizedRegularExpressionImpl<thread_safe>::replaceAllImpl(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos) +{ + size_t byte_offset = byte_pos - 1; // This is a offset for bytes, not utf8 + StringPieceType expr_sp(subject + byte_offset, subject_size - byte_offset); + StringPieceType matched_str; + size_t prior_offset = 0; + + while (true) + { + bool success = RegexType::FindAndConsume(&expr_sp, *re2, &matched_str); + if (!success) + break; + + auto skipped_byte_size = reinterpret_cast<Int64>(matched_str.data() - (subject + prior_offset)); + res_data.resize(res_data.size() + skipped_byte_size); + memcpy(&res_data[res_offset], subject + prior_offset, skipped_byte_size); // copy the skipped bytes + res_offset += skipped_byte_size; + + res_data.resize(res_data.size() + repl.size); + memcpy(&res_data[res_offset], repl.data, repl.size); // replace the matched string + res_offset += repl.size; + + prior_offset = expr_sp.data() - subject; + } + + size_t suffix_byte_size = subject_size - prior_offset; + res_data.resize(res_data.size() + suffix_byte_size + 1); + memcpy(&res_data[res_offset], subject + prior_offset, suffix_byte_size); // Copy suffix string + res_offset += suffix_byte_size; + res_data[res_offset++] = 0; +} + +template <bool thread_safe> +void OptimizedRegularExpressionImpl<thread_safe>::replaceOneImpl(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & 
res_offset, const StringRef & repl, Int64 byte_pos, Int64 occur) +{ + size_t byte_offset = byte_pos - 1; // This is a offset for bytes, not utf8 + StringPieceType expr_sp(subject + byte_offset, subject_size - byte_offset); + StringPieceType matched_str; + + while (occur > 0) + { + bool success = RegexType::FindAndConsume(&expr_sp, *re2, &matched_str); + if (!success) + { + res_data.resize(res_data.size() + subject_size + 1); + memcpy(&res_data[res_offset], subject, subject_size); + res_offset += subject_size; + res_data[res_offset++] = 0; + return; + } + + --occur; + } + + auto prefix_byte_size = reinterpret_cast<Int64>(matched_str.data() - subject); + res_data.resize(res_data.size() + prefix_byte_size); + memcpy(&res_data[res_offset], subject, prefix_byte_size); // Copy prefix string + res_offset += prefix_byte_size; + + res_data.resize(res_data.size() + repl.size); + memcpy(&res_data[res_offset], repl.data, repl.size); // Replace the matched string + res_offset += repl.size; + + const char * suffix_str = subject + prefix_byte_size + matched_str.size(); + size_t suffix_byte_size = subject_size - prefix_byte_size - matched_str.size(); + res_data.resize(res_data.size() + suffix_byte_size + 1); + memcpy(&res_data[res_offset], suffix_str, suffix_byte_size); // Copy suffix string + res_offset += suffix_byte_size; + + res_data[res_offset++] = 0; +} + +template <bool thread_safe> +void OptimizedRegularExpressionImpl<thread_safe>::replaceImpl(const char * subject, size_t subject_size, DB::ColumnString::Chars_t & res_data, DB::ColumnString::Offset & res_offset, const StringRef & repl, Int64 byte_pos, Int64 occur) +{ + if (occur == 0) + return replaceAllImpl(subject, subject_size, res_data, res_offset, repl, byte_pos); + else + return replaceOneImpl(subject, subject_size, res_data, res_offset, repl, byte_pos, occur); +} + template <bool thread_safe> Int64 OptimizedRegularExpressionImpl<thread_safe>::instr(const char * subject, size_t subject_size, Int64 pos, Int64 occur, 
Int64 ret_op) { Int64 utf8_total_len = DB::UTF8::countCodePoints(reinterpret_cast<const UInt8 *>(subject), subject_size); ; - checkInstrArgs(utf8_total_len, subject_size, pos, ret_op); - makeOccurValid(occur); + FunctionsRegexp::checkArgsInstr(utf8_total_len, subject_size, pos, ret_op); + FunctionsRegexp::makeOccurValid(occur); if (unlikely(subject_size == 0)) return processInstrEmptyStringExpr(subject, subject_size, pos, occur); @@ -576,8 +704,8 @@ template <bool thread_safe> std::optional<StringRef> OptimizedRegularExpressionImpl<thread_safe>::substr(const char * subject, size_t subject_size, Int64 pos, Int64 occur) { Int64 utf8_total_len = DB::UTF8::countCodePoints(reinterpret_cast<const UInt8 *>(subject), subject_size); - checkSubstrArgs(utf8_total_len, subject_size, pos); - makeOccurValid(occur); + FunctionsRegexp::checkArgsSubstr(utf8_total_len, subject_size, pos); + FunctionsRegexp::makeOccurValid(occur); if (unlikely(subject_size == 0)) return processSubstrEmptyStringExpr(subject, subject_size, pos, occur); @@ -586,5 +714,30 @@ std::optional<StringRef> OptimizedRegularExpressionImpl<thread_safe>::substr(con return substrImpl(subject, subject_size, byte_pos, occur); } +template <bool thread_safe> +void OptimizedRegularExpressionImpl<thread_safe>::replace( + const char * subject, + size_t subject_size, + DB::ColumnString::Chars_t & res_data, + DB::ColumnString::Offset & res_offset, + const StringRef & repl, + Int64 pos, + Int64 occur) +{ + Int64 utf8_total_len = DB::UTF8::countCodePoints(reinterpret_cast<const UInt8 *>(subject), subject_size); + ; + FunctionsRegexp::checkArgsReplace(utf8_total_len, subject_size, pos); + FunctionsRegexp::makeReplaceOccurValid(occur); + + if (unlikely(subject_size == 0)) + { + processReplaceEmptyStringExpr(subject, subject_size, res_data, res_offset, repl, pos, occur); + return; + } + + size_t byte_pos = DB::UTF8::utf8Pos2bytePos(reinterpret_cast<const UInt8 *>(subject), pos); + replaceImpl(subject, subject_size, res_data, 
res_offset, repl, byte_pos, occur); +} + #undef MIN_LENGTH_FOR_STRSTR #undef MAX_SUBPATTERNS diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 3fe11851304..02445972ecd 100755 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -435,7 +435,7 @@ const std::unordered_map<tipb::ScalarFuncSig, String> scalar_func_map({ {tipb::ScalarFuncSig::RegexpUTF8Sig, "regexp"}, {tipb::ScalarFuncSig::RegexpLikeSig, "regexp_like"}, {tipb::ScalarFuncSig::RegexpInStrSig, "regexp_instr"}, - // {tipb::ScalarFuncSig::RegexpReplaceSig, "regexp_replace"}, + {tipb::ScalarFuncSig::RegexpReplaceSig, "regexp_replace"}, {tipb::ScalarFuncSig::RegexpSubstrSig, "regexp_substr"}, {tipb::ScalarFuncSig::JsonExtractSig, "json_extract"}, diff --git a/dbms/src/Functions/FunctionsRegexp.cpp b/dbms/src/Functions/FunctionsRegexp.cpp index bc25b91569b..a812744489f 100644 --- a/dbms/src/Functions/FunctionsRegexp.cpp +++ b/dbms/src/Functions/FunctionsRegexp.cpp @@ -20,285 +20,18 @@ namespace DB { -/** Replace all matches of regexp 'needle' to string 'replacement'. 'needle' and 'replacement' are constants. - * 'replacement' could contain substitutions, for example: '\2-\3-\1' - */ -template <bool replace_one = false> -struct ReplaceRegexpImpl -{ - static constexpr bool support_non_const_needle = false; - static constexpr bool support_non_const_replacement = false; - /// need customized escape char when do the string search - static const bool need_customized_escape_char = false; - /// support match type when do the string search, used in regexp - static const bool support_match_type = true; - - /// Sequence of instructions, describing how to get resulting string. 
- /// Each element is either: - /// - substitution (in that case first element of pair is their number and second element is empty) - /// - string that need to be inserted (in that case, first element of pair is -1 and second element is that string) - using Instructions = std::vector<std::pair<int, std::string>>; - - static const size_t max_captures = 10; - - static Instructions createInstructions(const std::string & s, int num_captures) - { - Instructions instructions; - - String now; - for (size_t i = 0; i < s.size(); ++i) - { - if (s[i] == '\\' && i + 1 < s.size()) - { - if (isNumericASCII(s[i + 1])) /// Substitution - { - if (!now.empty()) - { - instructions.emplace_back(-1, now); - now = ""; - } - instructions.emplace_back(s[i + 1] - '0', String()); - } - else - now += s[i + 1]; /// Escaping - ++i; - } - else - now += s[i]; /// Plain character - } - - if (!now.empty()) - { - instructions.emplace_back(-1, now); - now = ""; - } - - for (const auto & it : instructions) - if (it.first >= num_captures) - throw Exception("Invalid replace instruction in replacement string. Id: " + toString(it.first) + ", but regexp has only " - + toString(num_captures - 1) - + " subpatterns", - ErrorCodes::BAD_ARGUMENTS); - - return instructions; - } - - - static void processString(const re2_st::StringPiece & input, - ColumnString::Chars_t & res_data, - ColumnString::Offset & res_offset, - const Int64 & pos, - const Int64 & occ, - re2_st::RE2 & searcher, - int num_captures, - const Instructions & instructions) - { - re2_st::StringPiece matches[max_captures]; - - size_t start_pos = pos <= 0 ? 
0 : pos - 1; - Int64 match_occ = 0; - size_t prefix_length = std::min(start_pos, static_cast<size_t>(input.length())); - if (prefix_length > 0) - { - /// Copy prefix - res_data.resize(res_data.size() + prefix_length); - memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], input.data(), prefix_length); - res_offset += prefix_length; - } - while (start_pos < static_cast<size_t>(input.length())) - { - /// If no more replacements possible for current string - bool can_finish_current_string = false; - - if (searcher.Match(input, start_pos, input.length(), re2_st::RE2::Anchor::UNANCHORED, matches, num_captures)) - { - match_occ++; - /// if occ == 0, it will replace all the match expr, otherwise it only replace the occ-th match - if (occ == 0 || match_occ == occ) - { - const auto & match = matches[0]; - size_t bytes_to_copy = (match.data() - input.data()) - start_pos; - - /// Copy prefix before matched regexp without modification - res_data.resize(res_data.size() + bytes_to_copy); - memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], input.data() + start_pos, bytes_to_copy); - res_offset += bytes_to_copy; - start_pos += bytes_to_copy + match.length(); - - /// Do substitution instructions - for (const auto & it : instructions) - { - if (it.first >= 0) - { - res_data.resize(res_data.size() + matches[it.first].length()); - memcpy(&res_data[res_offset], matches[it.first].data(), matches[it.first].length()); - res_offset += matches[it.first].length(); - } - else - { - res_data.resize(res_data.size() + it.second.size()); - memcpy(&res_data[res_offset], it.second.data(), it.second.size()); - res_offset += it.second.size(); - } - } - - /// when occ > 0, just replace the occ-th match even if replace_one is false - if (replace_one || match.length() == 0) /// Stop after match of zero length, to avoid infinite loop. 
- can_finish_current_string = true; - } - else - { - const auto & match = matches[0]; - size_t bytes_to_copy = (match.data() - input.data()) - start_pos + match.length(); - - /// Copy the matched string without modification - res_data.resize(res_data.size() + bytes_to_copy); - memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], input.data() + start_pos, bytes_to_copy); - res_offset += bytes_to_copy; - start_pos += bytes_to_copy; - if (match.length() == 0) - can_finish_current_string = true; - } - } - else - can_finish_current_string = true; - - /// If ready, append suffix after match to end of string. - if (can_finish_current_string) - { - res_data.resize(res_data.size() + input.length() - start_pos); - memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], input.data() + start_pos, input.length() - start_pos); - res_offset += input.length() - start_pos; - start_pos = input.length(); - } - } - - res_data.resize(res_data.size() + 1); - res_data[res_offset] = 0; - ++res_offset; - } - - - static void vector(const ColumnString::Chars_t & data, - const ColumnString::Offsets & offsets, - const std::string & needle, - const std::string & replacement, - const Int64 & pos, - const Int64 & occ, - const std::string & match_type, - TiDB::TiDBCollatorPtr collator, - ColumnString::Chars_t & res_data, - ColumnString::Offsets & res_offsets) - { - ColumnString::Offset res_offset = 0; - res_data.reserve(data.size()); - size_t size = offsets.size(); - res_offsets.resize(size); - - if (needle.empty()) - { - /// Copy all the data without changing. 
- res_data.resize(data.size()); - const UInt8 * begin = &data[0]; - memcpy(&res_data[0], begin, data.size()); - memcpy(&res_offsets[0], &offsets[0], size * sizeof(UInt64)); - return; - } - - String updated_needle = needle; - if (!match_type.empty() || collator != nullptr) - { - String mode_modifiers = re2Util::getRE2ModeModifiers(match_type, collator); - if (!mode_modifiers.empty()) - updated_needle = mode_modifiers + updated_needle; - } - re2_st::RE2 searcher(updated_needle); - int num_captures = std::min(searcher.NumberOfCapturingGroups() + 1, static_cast<int>(max_captures)); - - Instructions instructions = createInstructions(replacement, num_captures); - - /// Cannot perform search for whole block. Will process each string separately. - for (size_t i = 0; i < size; ++i) - { - int from = i > 0 ? offsets[i - 1] : 0; - re2_st::StringPiece input(reinterpret_cast<const char *>(&data[0] + from), offsets[i] - from - 1); - - processString(input, res_data, res_offset, pos, occ, searcher, num_captures, instructions); - res_offsets[i] = res_offset; - } - } - - static void vectorFixed(const ColumnString::Chars_t & data, - size_t n, - const std::string & needle, - const std::string & replacement, - const Int64 & pos, - const Int64 & occ, - const std::string & match_type, - TiDB::TiDBCollatorPtr collator, - ColumnString::Chars_t & res_data, - ColumnString::Offsets & res_offsets) - { - ColumnString::Offset res_offset = 0; - size_t size = data.size() / n; - res_data.reserve(data.size()); - res_offsets.resize(size); - - if (needle.empty()) - { - /// TODO: copy all the data without changing - throw Exception("Length of the second argument of function replace must be greater than 0.", ErrorCodes::ARGUMENT_OUT_OF_BOUND); - } - - String updated_needle = needle; - if (!match_type.empty() || collator != nullptr) - { - String mode_modifiers = re2Util::getRE2ModeModifiers(match_type, collator); - if (!mode_modifiers.empty()) - updated_needle = mode_modifiers + updated_needle; - } - 
re2_st::RE2 searcher(updated_needle); - int num_captures = std::min(searcher.NumberOfCapturingGroups() + 1, static_cast<int>(max_captures)); - - Instructions instructions = createInstructions(replacement, num_captures); - - for (size_t i = 0; i < size; ++i) - { - int from = i * n; - re2_st::StringPiece input(reinterpret_cast<const char *>(&data[0] + from), n); - - processString(input, res_data, res_offset, pos, occ, searcher, num_captures, instructions); - res_offsets[i] = res_offset; - } - } - static void constant(const String & input, const String & needle, const String & replacement, const Int64 & pos, const Int64 & occ, const String & match_type, TiDB::TiDBCollatorPtr collator, String & output) - { - ColumnString::Chars_t input_data; - input_data.insert(input_data.end(), input.begin(), input.end()); - ColumnString::Offsets input_offsets; - input_offsets.push_back(input_data.size() + 1); - ColumnString::Chars_t output_data; - ColumnString::Offsets output_offsets; - vector(input_data, input_offsets, needle, replacement, pos, occ, match_type, collator, output_data, output_offsets); - output = String(reinterpret_cast<const char *>(&output_data[0]), output_offsets[0] - 1); - } -}; - using FunctionTiDBRegexp = FunctionStringRegexp<NameTiDBRegexp>; using FunctionRegexpLike = FunctionStringRegexp<NameRegexpLike>; using FunctionRegexpInstr = FunctionStringRegexpInstr<NameRegexpInstr>; using FunctionRegexpSubstr = FunctionStringRegexpSubstr<NameRegexpSubstr>; -using FunctionReplaceRegexpOne = FunctionStringReplace<ReplaceRegexpImpl<true>, NameReplaceRegexpOne>; -using FunctionReplaceRegexpAll = FunctionStringReplace<ReplaceRegexpImpl<false>, NameReplaceRegexpAll>; +using FunctionRegexpReplace = FunctionStringRegexpReplace<NameRegexpReplace>; void registerFunctionsRegexp(FunctionFactory & factory) { - factory.registerFunction<FunctionReplaceRegexpOne>(); - factory.registerFunction<FunctionReplaceRegexpAll>(); factory.registerFunction<FunctionTiDBRegexp>(); 
factory.registerFunction<FunctionRegexpLike>(); factory.registerFunction<FunctionRegexpInstr>(); factory.registerFunction<FunctionRegexpSubstr>(); + factory.registerFunction<FunctionRegexpReplace>(); } - } // namespace DB diff --git a/dbms/src/Functions/FunctionsRegexp.h b/dbms/src/Functions/FunctionsRegexp.h index b1b167a4f15..4458677a5dc 100644 --- a/dbms/src/Functions/FunctionsRegexp.h +++ b/dbms/src/Functions/FunctionsRegexp.h @@ -80,17 +80,29 @@ struct NameRegexpSubstr { static constexpr auto name = "regexp_substr"; }; -struct NameReplaceRegexpOne +struct NameRegexpReplace { - static constexpr auto name = "replaceRegexpOne"; + static constexpr auto name = "regexp_replace"; }; -struct NameReplaceRegexpAll + +static constexpr std::string_view regexp_like_name(NameRegexpLike::name); + +enum class IntType { - static constexpr auto name = "replaceRegexpAll"; + UInt8 = 0, + UInt16, + UInt32, + UInt64, + Int8, + Int16, + Int32, + Int64 }; -static constexpr std::string_view regexp_like_name(NameRegexpLike::name); +using GetIntFuncPointerType = Int64 (*)(const void *, size_t); +namespace FunctionsRegexp +{ inline int getDefaultFlags() { int flags = 0; @@ -120,7 +132,7 @@ inline constexpr bool check_int_type() return std::is_same_v<T, UInt8> || std::is_same_v<T, UInt16> || std::is_same_v<T, UInt32> || std::is_same_v<T, UInt64> || std::is_same_v<T, Int8> || std::is_same_v<T, Int16> || std::is_same_v<T, Int32> || std::is_same_v<T, Int64>; } -Int64 getIntFromField(Field & field) +inline Int64 getIntFromField(Field & field) { switch (field.getType()) { @@ -133,28 +145,14 @@ Int64 getIntFromField(Field & field) } } -enum class IntType -{ - UInt8 = 0, - UInt16, - UInt32, - UInt64, - Int8, - Int16, - Int32, - Int64 -}; - template <typename T> -Int64 getInt(const void * container, size_t idx) +inline Int64 getInt(const void * container, size_t idx) { const auto * tmp = reinterpret_cast<const typename ColumnVector<T>::Container *>(container); return 
static_cast<Int64>((*tmp)[idx]); } -using GetIntFuncPointerType = Int64 (*)(const void *, size_t); - -GetIntFuncPointerType getGetIntFuncPointer(IntType int_type) +inline GetIntFuncPointerType getGetIntFuncPointer(IntType int_type) { switch (int_type) { @@ -194,6 +192,7 @@ inline void fillColumnStringWhenAllNull(decltype(ColumnString::create()) & col_r col_res_offsets[i] = offset; } } +} // namespace FunctionsRegexp template <bool is_const> class ParamString @@ -349,7 +348,7 @@ class ParamInt return const_int_val; else { - const auto * tmp = reinterpret_cast<const typename ColumnVector<std::enable_if_t<check_int_type<T>(), T>>::Container *>(int_container); + const auto * tmp = reinterpret_cast<const typename ColumnVector<std::enable_if_t<FunctionsRegexp::check_int_type<T>(), T>>::Container *>(int_container); return static_cast<Int64>((*tmp)[idx]); } } @@ -481,7 +480,7 @@ class Param private: // When this is a nullable param, we should ensure the null_map is not nullptr - inline void checkNullableLogic() + void checkNullableLogic() { if (is_nullable && (null_map == nullptr)) throw Exception("Nullable Param with nullptr null_map"); @@ -656,7 +655,7 @@ class ParamVariant { Field field; col_const->get(0, field); - auto data_int64 = field.isNull() ? -1 : getIntFromField(field); + auto data_int64 = field.isNull() ? 
-1 : FunctionsRegexp::getIntFromField(field); const auto & col_const_data = col_const->getDataColumnPtr(); if (col_const_data->isColumnNullable()) { @@ -757,6 +756,7 @@ class ParamVariant #define OCCUR_PV_VAR_NAME occur_pv #define RET_OP_PV_VAR_NAME return_option_pv #define MATCH_TYPE_PV_VAR_NAME match_type_pv +#define REPL_PV_VAR_NAME repl_pv #define EXPR_PARAM_PTR_VAR_NAME expr_param #define PAT_PARAM_PTR_VAR_NAME pat_param @@ -764,6 +764,7 @@ class ParamVariant #define OCCUR_PARAM_PTR_VAR_NAME occur_param #define RET_OP_PARAM_PTR_VAR_NAME return_option_param #define MATCH_TYPE_PARAM_PTR_VAR_NAME match_type_param +#define REPL_PARAM_PTR_VAR_NAME repl_param #define RES_ARG_VAR_NAME res_arg @@ -802,6 +803,7 @@ class FunctionStringRegexpBase { public: static constexpr size_t REGEXP_MIN_PARAM_NUM = 2; + static constexpr size_t REGEXP_REPLACE_MIN_PARAM_NUM = 3; // Max parameter number the regexp_xxx function could receive static constexpr size_t REGEXP_MAX_PARAM_NUM = 2; @@ -827,9 +829,9 @@ class FunctionStringRegexpBase final_pattern = fmt::format("({})", final_pattern); String match_type = match_type_param.getString(0); - final_pattern = addMatchTypeForPattern(final_pattern, match_type, collator); + final_pattern = FunctionsRegexp::addMatchTypeForPattern(final_pattern, match_type, collator); - int flags = getDefaultFlags(); + int flags = FunctionsRegexp::getDefaultFlags(); return std::make_unique<Regexps::Regexp>(final_pattern, flags); } @@ -998,7 +1000,7 @@ class FunctionStringRegexp : public FunctionStringRegexpBase return; } - int flags = getDefaultFlags(); + int flags = FunctionsRegexp::getDefaultFlags(); String expr = expr_param.getString(0); String pat = pat_param.getString(0); if (unlikely(pat.empty())) @@ -1006,7 +1008,7 @@ class FunctionStringRegexp : public FunctionStringRegexpBase String match_type = match_type_param.getString(0); - Regexps::Regexp regexp(addMatchTypeForPattern(pat, match_type, collator), flags); + Regexps::Regexp 
regexp(FunctionsRegexp::addMatchTypeForPattern(pat, match_type, collator), flags); ResultType res{regexp.match(expr)}; res_arg.column = res_arg.type->createColumnConst(col_size, toField(res)); return; @@ -1015,7 +1017,8 @@ class FunctionStringRegexp : public FunctionStringRegexpBase // Initialize result column auto col_res = ColumnVector<ResultType>::create(); typename ColumnVector<ResultType>::Container & vec_res = col_res->getData(); - vec_res.resize(col_size, 0); + ResultType default_val = 0; + vec_res.assign(col_size, default_val); constexpr bool has_nullable_col = ExprT::isNullableCol() || PatT::isNullableCol() || MatchTypeT::isNullableCol(); @@ -1030,7 +1033,8 @@ class FunctionStringRegexp : public FunctionStringRegexpBase { auto nullmap_col = ColumnUInt8::create(); typename ColumnUInt8::Container & nullmap = nullmap_col->getData(); - nullmap.resize(col_size, 1); + UInt8 default_val = 1; + nullmap.assign(col_size, default_val); res_arg.column = ColumnNullable::create(std::move(col_res), std::move(nullmap_col)); return; } @@ -1101,7 +1105,7 @@ class FunctionStringRegexp : public FunctionStringRegexpBase if (unlikely(pat.empty())) throw Exception(EMPTY_PAT_ERR_MSG); - auto regexp = createRegexpWithMatchType(pat, match_type, collator); + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); vec_res[i] = regexp.match(expr_ref.data, expr_ref.size); // match } @@ -1121,7 +1125,7 @@ class FunctionStringRegexp : public FunctionStringRegexpBase if (unlikely(pat.empty())) throw Exception(EMPTY_PAT_ERR_MSG); - auto regexp = createRegexpWithMatchType(pat, match_type, collator); + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); vec_res[i] = regexp.match(expr_ref.data, expr_ref.size); // match } @@ -1277,9 +1281,9 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase size_t col_size = expr_param.getDataNum(); // Get function pointers to process the specific int type - GetIntFuncPointerType 
get_pos_func = getGetIntFuncPointer(pos_param.getIntType()); - GetIntFuncPointerType get_occur_func = getGetIntFuncPointer(occur_param.getIntType()); - GetIntFuncPointerType get_ret_op_func = getGetIntFuncPointer(ret_op_param.getIntType()); + GetIntFuncPointerType get_pos_func = FunctionsRegexp::getGetIntFuncPointer(pos_param.getIntType()); + GetIntFuncPointerType get_occur_func = FunctionsRegexp::getGetIntFuncPointer(occur_param.getIntType()); + GetIntFuncPointerType get_ret_op_func = FunctionsRegexp::getGetIntFuncPointer(ret_op_param.getIntType()); // Container will not be used when parm is const const void * pos_container = pos_param.getContainer(); @@ -1300,7 +1304,7 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase return; } - int flags = getDefaultFlags(); + int flags = FunctionsRegexp::getDefaultFlags(); String expr = expr_param.getString(0); String match_type = match_type_param.getString(0); String pat = pat_param.getString(0); @@ -1308,7 +1312,7 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase throw Exception(EMPTY_PAT_ERR_MSG); pat = fmt::format("({})", pat); - Regexps::Regexp regexp(addMatchTypeForPattern(pat, match_type, collator), flags); + Regexps::Regexp regexp(FunctionsRegexp::addMatchTypeForPattern(pat, match_type, collator), flags); ResultType res = regexp.instr(expr.c_str(), expr.size(), pos_const_val, occur_const_val, ret_op_const_val); res_arg.column = res_arg.type->createColumnConst(col_size, toField(res)); return; @@ -1317,7 +1321,8 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase // Initialize result column auto col_res = ColumnVector<ResultType>::create(); typename ColumnVector<ResultType>::Container & vec_res = col_res->getData(); - vec_res.resize(col_size, 0); + ResultType default_val = 0; + vec_res.assign(col_size, default_val); constexpr bool has_nullable_col = ExprT::isNullableCol() || PatT::isNullableCol() || PosT::isNullableCol() || OccurT::isNullableCol() || 
RetOpT::isNullableCol() || MatchTypeT::isNullableCol(); @@ -1366,7 +1371,8 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase { auto nullmap_col = ColumnUInt8::create(); typename ColumnUInt8::Container & nullmap = nullmap_col->getData(); - nullmap.resize(col_size, 1); + UInt8 default_val = 1; + nullmap.assign(col_size, default_val); res_arg.column = ColumnNullable::create(std::move(col_res), std::move(nullmap_col)); return; } @@ -1439,7 +1445,7 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase GET_OCCUR_VALUE(i) GET_RET_OP_VALUE(i) match_type = match_type_param.getString(i); - auto regexp = createRegexpWithMatchType(pat, match_type, collator); + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); vec_res[i] = regexp.instr(expr_ref.data, expr_ref.size, pos, occur, ret_op); } @@ -1459,7 +1465,7 @@ class FunctionStringRegexpInstr : public FunctionStringRegexpBase GET_OCCUR_VALUE(i) GET_RET_OP_VALUE(i) match_type = match_type_param.getString(i); - auto regexp = createRegexpWithMatchType(pat, match_type, collator); + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); vec_res[i] = regexp.instr(expr_ref.data, expr_ref.size, pos, occur, ret_op); } @@ -1628,8 +1634,8 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase size_t col_size = expr_param.getDataNum(); // Get function pointers to process the specific int type - GetIntFuncPointerType get_pos_func = getGetIntFuncPointer(pos_param.getIntType()); - GetIntFuncPointerType get_occur_func = getGetIntFuncPointer(occur_param.getIntType()); + GetIntFuncPointerType get_pos_func = FunctionsRegexp::getGetIntFuncPointer(pos_param.getIntType()); + GetIntFuncPointerType get_occur_func = FunctionsRegexp::getGetIntFuncPointer(occur_param.getIntType()); // Container will not be used when parm is const const void * pos_container = pos_param.getContainer(); @@ -1652,12 +1658,12 @@ class FunctionStringRegexpSubstr : 
public FunctionStringRegexpBase if (unlikely(pat.empty())) throw Exception(EMPTY_PAT_ERR_MSG); - int flags = getDefaultFlags(); + int flags = FunctionsRegexp::getDefaultFlags(); String expr = expr_param.getString(0); String match_type = match_type_param.getString(0); pat = fmt::format("({})", pat); - Regexps::Regexp regexp(addMatchTypeForPattern(pat, match_type, collator), flags); + Regexps::Regexp regexp(FunctionsRegexp::addMatchTypeForPattern(pat, match_type, collator), flags); auto res = regexp.substr(expr.c_str(), expr.size(), pos_const_val, occur_const_val); if (res) res_arg.column = res_arg.type->createColumnConst(col_size, toField(res.value().toString())); @@ -1698,7 +1704,8 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase auto null_map_col = ColumnUInt8::create(); typename ColumnUInt8::Container & null_map = null_map_col->getData(); - null_map.resize(col_size, 1); + UInt8 default_val = 1; + null_map.assign(col_size, default_val); // Start to execute substr if (canMemorize<PatT, MatchTypeT>()) @@ -1709,7 +1716,7 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase regexp = memorize<true>(pat_param, match_type_param, collator); if (regexp == nullptr) { - fillColumnStringWhenAllNull(col_res, col_size); + FunctionsRegexp::fillColumnStringWhenAllNull(col_res, col_size); res_arg.column = ColumnNullable::create(std::move(col_res), std::move(null_map_col)); return; } @@ -1768,7 +1775,7 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase match_type = match_type_param.getString(i); pat = fmt::format("({})", pat); - auto regexp = createRegexpWithMatchType(pat, match_type, collator); + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); executeAndSetResult(regexp, col_res, null_map, i, expr_ref.data, expr_ref.size, pos, occur); } } @@ -1786,7 +1793,7 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase match_type = match_type_param.getString(i); pat = 
fmt::format("({})", pat); - auto regexp = createRegexpWithMatchType(pat, match_type, collator); + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); executeAndSetResult(regexp, col_res, null_map, i, expr_ref.data, expr_ref.size, pos, occur); } } @@ -1858,8 +1865,8 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase } else { + // 1 has been assigned when null_map is resized col_res->insertData("", 0); - null_map[idx] = 1; } } @@ -1874,11 +1881,382 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase #undef GET_MATCH_TYPE_ACTUAL_PARAM #undef EXECUTE_REGEXP_SUBSTR +#define EXECUTE_REGEXP_REPLACE() \ + do \ + { \ + REGEXP_CLASS_MEM_FUNC_IMPL_NAME(RES_ARG_VAR_NAME, *(EXPR_PARAM_PTR_VAR_NAME), *(PAT_PARAM_PTR_VAR_NAME), *(REPL_PARAM_PTR_VAR_NAME), *(POS_PARAM_PTR_VAR_NAME), *(OCCUR_PARAM_PTR_VAR_NAME), *(MATCH_TYPE_PARAM_PTR_VAR_NAME)); \ + } while (0); + +// Method to get actual match type param +#define GET_MATCH_TYPE_ACTUAL_PARAM() \ + do \ + { \ + GET_ACTUAL_STRING_PARAM(MATCH_TYPE_PV_VAR_NAME, MATCH_TYPE_PARAM_PTR_VAR_NAME, ({EXECUTE_REGEXP_REPLACE()})) \ + } while (0); + +// Method to get actual occur param +#define GET_OCCUR_ACTUAL_PARAM() \ + do \ + { \ + GET_ACTUAL_INT_PARAM(OCCUR_PV_VAR_NAME, OCCUR_PARAM_PTR_VAR_NAME, ({GET_MATCH_TYPE_ACTUAL_PARAM()})) \ + } while (0); + +// Method to get actual position param +#define GET_POS_ACTUAL_PARAM() \ + do \ + { \ + GET_ACTUAL_INT_PARAM(POS_PV_VAR_NAME, POS_PARAM_PTR_VAR_NAME, ({GET_OCCUR_ACTUAL_PARAM()})) \ + } while (0); + +// Method to get actual repl param +#define GET_REPL_ACTUAL_PARAM() \ + do \ + { \ + GET_ACTUAL_STRING_PARAM(REPL_PV_VAR_NAME, REPL_PARAM_PTR_VAR_NAME, ({GET_POS_ACTUAL_PARAM()})) \ + } while (0); + +// Method to get actual pattern param +#define GET_PAT_ACTUAL_PARAM() \ + do \ + { \ + GET_ACTUAL_STRING_PARAM(PAT_PV_VAR_NAME, PAT_PARAM_PTR_VAR_NAME, ({GET_REPL_ACTUAL_PARAM()})) \ + } while (0); + +// Method to get actual 
expression param +#define GET_EXPR_ACTUAL_PARAM() \ + do \ + { \ + GET_ACTUAL_STRING_PARAM(EXPR_PV_VAR_NAME, EXPR_PARAM_PTR_VAR_NAME, ({GET_PAT_ACTUAL_PARAM()})) \ + } while (0); + +// The entry to get actual params and execute regexp functions +#define GET_ACTUAL_PARAMS_AND_EXECUTE() \ + do \ + { \ + GET_EXPR_ACTUAL_PARAM() \ + } while (0); + +// Implementation of regexp_replace function +template <typename Name> +class FunctionStringRegexpReplace : public FunctionStringRegexpBase + , public IFunction +{ +public: + using ResultType = String; + static constexpr auto name = Name::name; + + static FunctionPtr create(const Context &) { return std::make_shared<FunctionStringRegexpReplace>(); } + String getName() const override { return name; } + bool isVariadic() const override { return true; } + void setCollator(const TiDB::TiDBCollatorPtr & collator_) override { collator = collator_; } + bool useDefaultImplementationForNulls() const override { return false; } + size_t getNumberOfArguments() const override { return 0; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + size_t arg_num = arguments.size(); + if (arg_num < REGEXP_REPLACE_MIN_PARAM_NUM) + throw Exception("Too few arguments", ErrorCodes::TOO_LESS_ARGUMENTS_FOR_FUNCTION); + else if (arg_num > REGEXP_REPLACE_MAX_PARAM_NUM) + throw Exception("Too many arguments", ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION); + + bool has_nullable_col = false; + bool has_data_type_nothing = false; + bool is_str_arg; + + // Check type of arguments + for (size_t i = 0; i < arg_num; ++i) + { + // Index at 0, 1 and 4 arguments should be string type, otherwise int type. 
+ is_str_arg = (i <= 2 || i == 5); + checkInputArg(arguments[i], is_str_arg, &has_nullable_col, &has_data_type_nothing); + } + + if (has_data_type_nothing) + return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()); + + if (has_nullable_col) + return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeString>()); + else + return std::make_shared<DataTypeString>(); + } + + template <typename ExprT, typename PatT, typename ReplT, typename PosT, typename OccurT, typename MatchTypeT> + void REGEXP_CLASS_MEM_FUNC_IMPL_NAME(ColumnWithTypeAndName & res_arg, const ExprT & expr_param, const PatT & pat_param, const ReplT & repl_param, const PosT & pos_param, const OccurT & occur_param, const MatchTypeT & match_type_param) const + { + size_t col_size = expr_param.getDataNum(); + + // Get function pointers to process the specific int type + GetIntFuncPointerType get_pos_func = FunctionsRegexp::getGetIntFuncPointer(pos_param.getIntType()); + GetIntFuncPointerType get_occur_func = FunctionsRegexp::getGetIntFuncPointer(occur_param.getIntType()); + + // Container will not be used when parm is const + const void * pos_container = pos_param.getContainer(); + const void * occur_container = occur_param.getContainer(); + + // Const value will not be used when param is not const + Int64 pos_const_val = PosT::isConst() ? pos_param.template getInt<Int64>(0) : -1; + Int64 occur_const_val = OccurT::isConst() ? 
occur_param.template getInt<Int64>(0) : -1; + + // Check if args are all const columns + if constexpr (ExprT::isConst() && PatT::isConst() && ReplT::isConst() && PosT::isConst() && OccurT::isConst() && MatchTypeT::isConst()) + { + if (expr_param.isNullAt(0) || pat_param.isNullAt(0) || repl_param.isNullAt(0) || pos_param.isNullAt(0) || occur_param.isNullAt(0) || match_type_param.isNullAt(0)) + { + res_arg.column = res_arg.type->createColumnConst(col_size, Null()); + return; + } + + String pat = pat_param.getString(0); + if (unlikely(pat.empty())) + throw Exception(EMPTY_PAT_ERR_MSG); + + pat = fmt::format("({})", pat); + StringRef expr_ref; + StringRef repl_ref; + expr_param.getStringRef(0, expr_ref); + repl_param.getStringRef(0, repl_ref); + String match_type = match_type_param.getString(0); + + ColumnString::Chars_t res_data; + IColumn::Offset offset = 0; + + Regexps::Regexp regexp(FunctionsRegexp::addMatchTypeForPattern(pat, match_type, collator), FunctionsRegexp::getDefaultFlags()); + regexp.replace(expr_ref.data, expr_ref.size, res_data, offset, repl_ref, pos_const_val, occur_const_val); + res_arg.column = res_arg.type->createColumnConst(col_size, toField(String(reinterpret_cast<const char *>(&res_data[0]), offset - 1))); + return; + } + + // Initialize result column + auto col_res = ColumnString::create(); + col_res->reserve(col_size); + + auto & res_data = col_res->getChars(); + auto & res_offsets = col_res->getOffsets(); + res_offsets.resize(col_size); + ColumnString::Offset res_offset = 0; + + constexpr bool has_nullable_col = ExprT::isNullableCol() || PatT::isNullableCol() || ReplT::isNullableCol() || PosT::isNullableCol() || OccurT::isNullableCol() || MatchTypeT::isNullableCol(); + +#define GET_POS_VALUE(idx) \ + do \ + { \ + if constexpr (PosT::isConst()) \ + pos = pos_const_val; \ + else \ + pos = get_pos_func(pos_container, idx); \ + } while (0); + +#define GET_OCCUR_VALUE(idx) \ + do \ + { \ + if constexpr (OccurT::isConst()) \ + occur = 
occur_const_val; \ + else \ + occur = get_occur_func(occur_container, idx); \ + } while (0); + + // Start to execute replace + if (canMemorize<PatT, MatchTypeT>()) + { + std::unique_ptr<Regexps::Regexp> regexp; + if (col_size > 0) + { + regexp = memorize<true>(pat_param, match_type_param, collator); + if (regexp == nullptr) + { + auto null_map_col = ColumnUInt8::create(); + typename ColumnUInt8::Container & null_map = null_map_col->getData(); + UInt8 default_val = 1; + null_map.assign(col_size, default_val); + FunctionsRegexp::fillColumnStringWhenAllNull(col_res, col_size); + res_arg.column = ColumnNullable::create(std::move(col_res), std::move(null_map_col)); + return; + } + } + + StringRef expr_ref; + StringRef repl_ref; + String pat; + Int64 pos; + Int64 occur; + String match_type; + + if constexpr (has_nullable_col) + { + auto null_map_col = ColumnUInt8::create(); + typename ColumnUInt8::Container & null_map = null_map_col->getData(); + null_map.resize(col_size); + + for (size_t i = 0; i < col_size; ++i) + { + if (expr_param.isNullAt(i) || repl_param.isNullAt(i) || pos_param.isNullAt(i) || occur_param.isNullAt(i)) + { + null_map[i] = 1; + res_data.resize(res_data.size() + 1); + res_data[res_offset++] = 0; + res_offsets[i] = res_offset; + continue; + } + + null_map[i] = 0; + expr_param.getStringRef(i, expr_ref); + repl_param.getStringRef(i, repl_ref); + GET_POS_VALUE(i) + GET_OCCUR_VALUE(i) + + regexp->replace(expr_ref.data, expr_ref.size, res_data, res_offset, repl_ref, pos, occur); + res_offsets[i] = res_offset; + } + res_arg.column = ColumnNullable::create(std::move(col_res), std::move(null_map_col)); + } + else + { + for (size_t i = 0; i < col_size; ++i) + { + expr_param.getStringRef(i, expr_ref); + repl_param.getStringRef(i, repl_ref); + GET_POS_VALUE(i) + GET_OCCUR_VALUE(i) + + regexp->replace(expr_ref.data, expr_ref.size, res_data, res_offset, repl_ref, pos, occur); + res_offsets[i] = res_offset; + } + res_arg.column = std::move(col_res); + } + } + else + 
{ + StringRef expr_ref; + StringRef repl_ref; + String pat; + Int64 pos; + Int64 occur; + String match_type; + + if constexpr (has_nullable_col) + { + auto null_map_col = ColumnUInt8::create(); + typename ColumnUInt8::Container & null_map = null_map_col->getData(); + null_map.resize(col_size); + + for (size_t i = 0; i < col_size; ++i) + { + if (expr_param.isNullAt(i) || pat_param.isNullAt(i) || repl_param.isNullAt(i) || pos_param.isNullAt(i) || occur_param.isNullAt(i) || match_type_param.isNullAt(i)) + { + null_map[i] = 1; + res_data.resize(res_data.size() + 1); + res_data[res_offset++] = 0; + res_offsets[i] = res_offset; + continue; + } + + pat = pat_param.getString(i); + if (unlikely(pat.empty())) + throw Exception(EMPTY_PAT_ERR_MSG); + + null_map[i] = 0; + pat = fmt::format("({})", pat); + expr_param.getStringRef(i, expr_ref); + repl_param.getStringRef(i, repl_ref); + GET_POS_VALUE(i) + GET_OCCUR_VALUE(i) + match_type = match_type_param.getString(i); + + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); + regexp.replace(expr_ref.data, expr_ref.size, res_data, res_offset, repl_ref, pos, occur); + res_offsets[i] = res_offset; + } + res_arg.column = ColumnNullable::create(std::move(col_res), std::move(null_map_col)); + } + else + { + for (size_t i = 0; i < col_size; ++i) + { + pat = pat_param.getString(i); + if (unlikely(pat.empty())) + throw Exception(EMPTY_PAT_ERR_MSG); + + pat = fmt::format("({})", pat); + expr_param.getStringRef(i, expr_ref); + repl_param.getStringRef(i, repl_ref); + GET_POS_VALUE(i) + GET_OCCUR_VALUE(i) + match_type = match_type_param.getString(i); + + auto regexp = FunctionsRegexp::createRegexpWithMatchType(pat, match_type, collator); + regexp.replace(expr_ref.data, expr_ref.size, res_data, res_offset, repl_ref, pos, occur); + res_offsets[i] = res_offset; + } + res_arg.column = std::move(col_res); + } + } +#undef GET_OCCUR_VALUE +#undef GET_POS_VALUE + } + + void executeImpl(Block & block, const 
ColumnNumbers & arguments, size_t result) const override + { + // Do something related with nullable columns + NullPresence null_presence = getNullPresense(block, arguments); + + if (null_presence.has_null_constant) + { + block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(block.rows(), Null()); + return; + } + + const ColumnPtr & col_expr = block.getByPosition(arguments[0]).column; + const ColumnPtr & col_pat = block.getByPosition(arguments[1]).column; + const ColumnPtr & col_rep = block.getByPosition(arguments[2]).column; + + size_t arg_num = arguments.size(); + auto & RES_ARG_VAR_NAME = block.getByPosition(result); + + ColumnPtr col_pos; + ColumnPtr col_occur; + ColumnPtr col_match_type; + + // Go through cases to get arguments + switch (arg_num) + { + case REGEXP_REPLACE_MAX_PARAM_NUM: + col_match_type = block.getByPosition(arguments[5]).column; + case REGEXP_REPLACE_MAX_PARAM_NUM - 1: + col_occur = block.getByPosition(arguments[4]).column; + case REGEXP_REPLACE_MAX_PARAM_NUM - 2: + col_pos = block.getByPosition(arguments[3]).column; + }; + + size_t col_size = col_expr->size(); + + ParamVariant EXPR_PV_VAR_NAME(col_expr, col_size, StringRef("", 0)); + ParamVariant PAT_PV_VAR_NAME(col_pat, col_size, StringRef("", 0)); + ParamVariant REPL_PV_VAR_NAME(col_rep, col_size, StringRef("", 0)); + ParamVariant POS_PV_VAR_NAME(col_pos, col_size, 1); + ParamVariant OCCUR_PV_VAR_NAME(col_occur, col_size, 0); + ParamVariant MATCH_TYPE_PV_VAR_NAME(col_match_type, col_size, StringRef("", 0)); + + GET_ACTUAL_PARAMS_AND_EXECUTE() + } + +private: + TiDB::TiDBCollatorPtr collator = nullptr; +}; + +#undef GET_ACTUAL_PARAMS_AND_EXECUTE +#undef GET_EXPR_ACTUAL_PARAM +#undef GET_PAT_ACTUAL_PARAM +#undef GET_REPL_ACTUAL_PARAM +#undef GET_POS_ACTUAL_PARAM +#undef GET_OCCUR_ACTUAL_PARAM +#undef GET_MATCH_TYPE_ACTUAL_PARAM +#undef EXECUTE_REGEXP_REPLACE + #undef GET_ACTUAL_INT_PARAM #undef GET_ACTUAL_STRING_PARAM #undef REGEXP_CLASS_MEM_FUNC_IMPL_NAME 
-#undef RES_ARG_VAR_NAME - +#undef REPL_PARAM_PTR_VAR_NAME #undef MATCH_TYPE_PARAM_PTR_VAR_NAME #undef RET_OP_PARAM_PTR_VAR_NAME #undef OCCUR_PARAM_PTR_VAR_NAME @@ -1886,6 +2264,7 @@ class FunctionStringRegexpSubstr : public FunctionStringRegexpBase #undef PAT_PARAM_PTR_VAR_NAME #undef EXPR_PARAM_PTR_VAR_NAME +#undef REPL_PV_VAR_NAME #undef MATCH_TYPE_PV_VAR_NAME #undef RET_OP_PV_VAR_NAME #undef OCCUR_PV_VAR_NAME diff --git a/dbms/src/Functions/tests/gtest_regexp.cpp b/dbms/src/Functions/tests/gtest_regexp.cpp index bdbd61f92a5..d49cca814c5 100644 --- a/dbms/src/Functions/tests/gtest_regexp.cpp +++ b/dbms/src/Functions/tests/gtest_regexp.cpp @@ -2243,6 +2243,17 @@ std::vector<String> getPatVec(const std::vector<T> & test_cases) return vecs; } +template <typename T> +std::vector<String> getReplVec(const std::vector<T> & test_cases) +{ + std::vector<String> vecs; + vecs.reserve(test_cases.size()); + for (const auto & elem : test_cases) + vecs.push_back(elem.replacement); + + return vecs; +} + template <typename T> std::vector<Int64> getPosVec(const std::vector<T> & test_cases) { @@ -2749,7 +2760,15 @@ struct RegexpSubstrCase , match_type(mt) {} - static void setVecsWithoutNullMap(int param_num, const std::vector<RegexpSubstrCase> test_cases, std::vector<String> & results, std::vector<String> & exprs, std::vector<String> & pats, std::vector<Int64> & positions, std::vector<Int64> & occurs, std::vector<String> & match_types) + static void setVecsWithoutNullMap( + int param_num, + const std::vector<RegexpSubstrCase> test_cases, + std::vector<String> & results, + std::vector<String> & exprs, + std::vector<String> & pats, + std::vector<Int64> & positions, + std::vector<Int64> & occurs, + std::vector<String> & match_types) { results = getResultVec<String>(test_cases); switch (param_num) @@ -2769,7 +2788,16 @@ struct RegexpSubstrCase } } - static void setVecsWithNullMap(int param_num, const std::vector<RegexpSubstrCase> test_cases, std::vector<String> & results, 
std::vector<std::vector<UInt8>> & null_map, std::vector<String> & exprs, std::vector<String> & pats, std::vector<Int64> & positions, std::vector<Int64> & occurs, std::vector<String> & match_types) + static void setVecsWithNullMap( + int param_num, + const std::vector<RegexpSubstrCase> test_cases, + std::vector<String> & results, + std::vector<std::vector<UInt8>> & null_map, + std::vector<String> & exprs, + std::vector<String> & pats, + std::vector<Int64> & positions, + std::vector<Int64> & occurs, + std::vector<String> & match_types) { null_map.clear(); null_map.resize(REGEXP_SUBSTR_MAX_PARAM_NUM); @@ -3096,74 +3124,105 @@ TEST_F(Regexp, RegexpSubstr) } } -TEST_F(Regexp, testRegexpReplaceMatchType) +struct RegexpReplaceCase { - String res; - const auto * binary_collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::BINARY); - const auto * ci_collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::UTF8MB4_GENERAL_CI); - DB::ReplaceRegexpImpl<false>::constant("a\nB\nc", "(?m)(?i)^b", "xxx", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "a\nxxx\nc"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\nc", "^b", "xxx", 1, 0, "mi", nullptr, res); - ASSERT_TRUE(res == "a\nxxx\nc"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\nc", "^b", "xxx", 1, 0, "m", ci_collator, res); - ASSERT_TRUE(res == "a\nxxx\nc"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\nc", "^b", "xxx", 1, 0, "mi", binary_collator, res); - ASSERT_TRUE(res == "a\nB\nc"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\nc", "^b", "xxx", 1, 0, "i", nullptr, res); - ASSERT_TRUE(res == "a\nB\nc"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\nc", "^b", "xxx", 1, 0, "m", nullptr, res); - ASSERT_TRUE(res == "a\nB\nc"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\n", "^a.*b", "xxx", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "a\nB\n"); - DB::ReplaceRegexpImpl<false>::constant("a\nB\n", "^a.*B", "xxx", 1, 0, "s", nullptr, res); - ASSERT_TRUE(res == "xxx\n"); - 
DB::ReplaceRegexpImpl<false>::constant("a\nB\n", "^a.*b", "xxx", 1, 0, "is", nullptr, res); - ASSERT_TRUE(res == "xxx\n"); -} + RegexpReplaceCase(const String & res, const String & expr, const String & pat, const String & repl, Int64 pos = 1, Int64 occur = 1, const String & mt = "") + : result(res) + , expression(expr) + , pattern(pat) + , replacement(repl) + , position(pos) + , occurrence(occur) + , match_type(mt) + {} -TEST_F(Regexp, testRegexpReplaceMySQLCases) -{ - // Test based on https://github.com/mysql/mysql-server/blob/mysql-cluster-8.0.17/mysql-test/t/regular_expressions_utf-8.test - String res; - DB::ReplaceRegexpImpl<false>::constant("aaa", "a", "X", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "XXX"); - DB::ReplaceRegexpImpl<false>::constant("abc", "b", "X", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "aXc"); - DB::ReplaceRegexpImpl<false>::constant("aaabbccbbddaa", "b+", "X", 1, 1, "", nullptr, res); - ASSERT_TRUE(res == "aaaXccbbddaa"); - DB::ReplaceRegexpImpl<false>::constant("aaabbccbbddaa", "b+", "X", 1, 2, "", nullptr, res); - ASSERT_TRUE(res == "aaabbccXddaa"); - DB::ReplaceRegexpImpl<false>::constant("aaabbccbbddaa", "(b+)", "<\\1>", 1, 2, "", nullptr, res); - ASSERT_TRUE(res == "aaabbcc<bb>ddaa"); - DB::ReplaceRegexpImpl<false>::constant("aaabbccbbddaa", "x+", "x", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "aaabbccbbddaa"); - DB::ReplaceRegexpImpl<false>::constant("aaabbccbbddaa", "b+", "x", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "aaaxccxddaa"); - DB::ReplaceRegexpImpl<false>::constant("aaab", "b", "x", 1, 2, "", nullptr, res); - ASSERT_TRUE(res == "aaab"); - DB::ReplaceRegexpImpl<false>::constant("aaabccc", "b", "x", 1, 2, "", nullptr, res); - ASSERT_TRUE(res == "aaabccc"); - DB::ReplaceRegexpImpl<false>::constant("abcbdb", "b", "X", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "aXcXdX"); - DB::ReplaceRegexpImpl<false>::constant("aaabcbdb", "b", "X", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "aaaXcXdX"); - 
DB::ReplaceRegexpImpl<false>::constant("aaabcbdb", "b", "X", 2, 0, "", nullptr, res); - ASSERT_TRUE(res == "aaaXcXdX"); - DB::ReplaceRegexpImpl<false>::constant("aaabcbdb", "b", "X", 3, 0, "", nullptr, res); - ASSERT_TRUE(res == "aaaXcXdX"); - DB::ReplaceRegexpImpl<false>::constant("aaa", "a", "X", 2, 0, "", nullptr, res); - ASSERT_TRUE(res == "aXX"); - DB::ReplaceRegexpImpl<false>::constant("aaa", "a", "XX", 2, 0, "", nullptr, res); - ASSERT_TRUE(res == "aXXXX"); - DB::ReplaceRegexpImpl<false>::constant("c b b", "^([[:alpha:]]+)[[:space:]].*$", "\\1", 1, 0, "", nullptr, res); - ASSERT_TRUE(res == "c"); - DB::ReplaceRegexpImpl<false>::constant("\U0001F450\U0001F450\U0001F450", ".", "a", 2, 0, "", nullptr, res); - ASSERT_TRUE(res == "\U0001F450aa"); - DB::ReplaceRegexpImpl<false>::constant("\U0001F450\U0001F450\U0001F450", ".", "a", 2, 2, "", nullptr, res); - ASSERT_TRUE(res == "\U0001F450\U0001F450a"); -} + RegexpReplaceCase(const String & res, const std::vector<UInt8> & null_map_, const String & expr, const String & pat, const String & repl, Int64 pos = 1, Int64 occur = 1, const String & mt = "") + : result(res) + , null_map(null_map_) + , expression(expr) + , pattern(pat) + , replacement(repl) + , position(pos) + , occurrence(occur) + , match_type(mt) + {} + + static void setVecsWithoutNullMap( + int param_num, + const std::vector<RegexpReplaceCase> test_cases, + std::vector<String> & results, + std::vector<String> & exprs, + std::vector<String> & pats, + std::vector<String> & repls, + std::vector<Int64> & positions, + std::vector<Int64> & occurs, + std::vector<String> & match_types) + { + results = getResultVec<String>(test_cases); + switch (param_num) + { + case 6: + match_types = getMatchTypeVec(test_cases); + case 5: + occurs = getOccurVec(test_cases); + case 4: + positions = getPosVec(test_cases); + case 3: + repls = getReplVec(test_cases); + pats = getPatVec(test_cases); + exprs = getExprVec(test_cases); + break; + default: + throw DB::Exception("Invalid 
param_num"); + } + } -TEST_F(Regexp, testRegexpReplace) + static void setVecsWithNullMap( + int param_num, + const std::vector<RegexpReplaceCase> test_cases, + std::vector<String> & results, + std::vector<std::vector<UInt8>> & null_map, + std::vector<String> & exprs, + std::vector<String> & pats, + std::vector<String> repls, + std::vector<Int64> & positions, + std::vector<Int64> & occurs, + std::vector<String> & match_types) + { + null_map.clear(); + null_map.resize(REGEXP_REPLACE_MAX_PARAM_NUM); + for (const auto & elem : test_cases) + { + null_map[EXPR_NULL_MAP_IDX].push_back(elem.null_map[EXPR_NULL_MAP_IDX]); + null_map[PAT_NULL_MAP_IDX].push_back(elem.null_map[PAT_NULL_MAP_IDX]); + null_map[POS_NULL_MAP_IDX].push_back(elem.null_map[POS_NULL_MAP_IDX]); + null_map[REPLACE_NULL_MAP_IDX].push_back(elem.null_map[REPLACE_NULL_MAP_IDX]); + null_map[OCCUR_NULL_MAP_IDX].push_back(elem.null_map[OCCUR_NULL_MAP_IDX]); + null_map[MATCH_TYPE_NULL_MAP_IDX].push_back(elem.null_map[MATCH_TYPE_NULL_MAP_IDX]); + } + + setVecsWithoutNullMap(param_num, test_cases, results, exprs, pats, repls, positions, occurs, match_types); + } + + const static UInt8 REGEXP_REPLACE_MAX_PARAM_NUM = 6; + const static UInt8 EXPR_NULL_MAP_IDX = 0; + const static UInt8 PAT_NULL_MAP_IDX = 1; + const static UInt8 REPLACE_NULL_MAP_IDX = 2; + const static UInt8 POS_NULL_MAP_IDX = 3; + const static UInt8 OCCUR_NULL_MAP_IDX = 4; + const static UInt8 MATCH_TYPE_NULL_MAP_IDX = 5; + + String result; + std::vector<UInt8> null_map; + String expression; + String pattern; + String replacement; + Int64 position; + Int64 occurrence; + String match_type; +}; + +TEST_F(Regexp, RegexpReplace) { const auto * binary_collator = TiDB::ITiDBCollator::getCollator(TiDB::ITiDBCollator::BINARY); auto string_type = std::make_shared<DataTypeString>(); @@ -3211,23 +3270,23 @@ TEST_F(Regexp, testRegexpReplace) { /// test regexp_replace(str, pattern, replacement) ASSERT_COLUMN_EQ(createConstColumn<String>(row_size, results[i]), - 
executeFunction("replaceRegexpAll", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i])}, nullptr, true)); + executeFunction("regexp_replace", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos) ASSERT_COLUMN_EQ(createConstColumn<String>(row_size, results_with_pos[i]), - executeFunction("replaceRegexpAll", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i])}, nullptr, true)); + executeFunction("regexp_replace", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ) ASSERT_COLUMN_EQ(createConstColumn<String>(row_size, results_with_pos_occ[i]), - executeFunction("replaceRegexpAll", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i]), createConstColumn<Int64>(row_size, occ[i])}, nullptr, true)); + executeFunction("regexp_replace", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i]), createConstColumn<Int64>(row_size, occ[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) ASSERT_COLUMN_EQ(createConstColumn<String>(row_size, results_with_pos_occ_match_type[i]), - 
executeFunction("replaceRegexpAll", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i]), createConstColumn<Int64>(row_size, occ[i]), createConstColumn<String>(row_size, match_types[i])}, nullptr, true)); + executeFunction("regexp_replace", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i]), createConstColumn<Int64>(row_size, occ[i]), createConstColumn<String>(row_size, match_types[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) with binary collator ASSERT_COLUMN_EQ(createConstColumn<String>(row_size, results_with_pos_occ_match_type_binary[i]), - executeFunction("replaceRegexpAll", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i]), createConstColumn<Int64>(row_size, occ[i]), createConstColumn<String>(row_size, match_types[i])}, binary_collator, true)); + executeFunction("regexp_replace", {createConstColumn<String>(row_size, input_strings[i]), createConstColumn<String>(row_size, patterns[i]), createConstColumn<String>(row_size, replacements[i]), createConstColumn<Int64>(row_size, pos[i]), createConstColumn<Int64>(row_size, occ[i]), createConstColumn<String>(row_size, match_types[i])}, binary_collator, true)); } /// case 2. regexp_replace(const, const, const [, const, const ,const]) with null value @@ -3236,74 +3295,147 @@ TEST_F(Regexp, testRegexpReplace) /// test regexp_replace(str, pattern, replacement) bool null_result = input_string_nulls[i] || pattern_nulls[i] || replacement_nulls[i]; ASSERT_COLUMN_EQ(null_result ? 
const_string_null_column : createConstColumn<Nullable<String>>(row_size, results[i]), - executeFunction("replaceRegexpAll", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i])}, nullptr, true)); + executeFunction("regexp_replace", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos) null_result = null_result || pos_nulls[i]; ASSERT_COLUMN_EQ(null_result ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, results_with_pos[i]), - executeFunction("replaceRegexpAll", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i])}, nullptr, true)); + executeFunction("regexp_replace", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? 
const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ) null_result = null_result || occ_nulls[i]; ASSERT_COLUMN_EQ(null_result ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, results_with_pos_occ[i]), - executeFunction("replaceRegexpAll", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i]), occ_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, occ[i])}, nullptr, true)); + executeFunction("regexp_replace", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i]), occ_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, occ[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) null_result = null_result || match_type_nulls[i]; ASSERT_COLUMN_EQ(null_result ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, results_with_pos_occ_match_type[i]), - executeFunction("replaceRegexpAll", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? 
const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i]), occ_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, occ[i]), match_type_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, match_types[i])}, nullptr, true)); + executeFunction("regexp_replace", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i]), occ_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, occ[i]), match_type_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, match_types[i])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) with binary collator ASSERT_COLUMN_EQ(null_result ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, results_with_pos_occ_match_type_binary[i]), - executeFunction("replaceRegexpAll", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i]), occ_nulls[i] ? 
const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, occ[i]), match_type_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, match_types[i])}, binary_collator, true)); + executeFunction("regexp_replace", {input_string_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, input_strings[i]), pattern_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, patterns[i]), replacement_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, replacements[i]), pos_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, pos[i]), occ_nulls[i] ? const_int64_null_column : createConstColumn<Nullable<Int64>>(row_size, occ[i]), match_type_nulls[i] ? const_string_null_column : createConstColumn<Nullable<String>>(row_size, match_types[i])}, binary_collator, true)); } /// case 3 regexp_replace(vector, const, const[, const, const, const]) { /// test regexp_replace(str, pattern, replacement) ASSERT_COLUMN_EQ(createColumn<String>(vec_results), - executeFunction("replaceRegexpAll", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0])}, nullptr, true)); + executeFunction("regexp_replace", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos) ASSERT_COLUMN_EQ(createColumn<String>(vec_results_with_pos), - executeFunction("replaceRegexpAll", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0])}, nullptr, true)); + executeFunction("regexp_replace", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), 
createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ) ASSERT_COLUMN_EQ(createColumn<String>(vec_results_with_pos_occ), - executeFunction("replaceRegexpAll", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0])}, nullptr, true)); + executeFunction("regexp_replace", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) ASSERT_COLUMN_EQ(createColumn<String>(vec_results_with_pos_occ_match_type), - executeFunction("replaceRegexpAll", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, nullptr, true)); + executeFunction("regexp_replace", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) with binary collator ASSERT_COLUMN_EQ(createColumn<String>(vec_results_with_pos_occ_match_type_binary), - executeFunction("replaceRegexpAll", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), 
createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, binary_collator, true)); + executeFunction("regexp_replace", {createColumn<String>(input_strings), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, binary_collator, true)); } /// case 4 regexp_replace(vector, const, const[, const, const, const]) with null value { /// test regexp_replace(str, pattern, replacement) ASSERT_COLUMN_EQ(createNullableVectorColumn<String>(vec_results, input_string_nulls), - executeFunction("replaceRegexpAll", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0])}, nullptr, true)); + executeFunction("regexp_replace", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos) ASSERT_COLUMN_EQ(createNullableVectorColumn<String>(vec_results_with_pos, input_string_nulls), - executeFunction("replaceRegexpAll", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0])}, nullptr, true)); + executeFunction("regexp_replace", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ) 
ASSERT_COLUMN_EQ(createNullableVectorColumn<String>(vec_results_with_pos_occ, input_string_nulls), - executeFunction("replaceRegexpAll", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0])}, nullptr, true)); + executeFunction("regexp_replace", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) ASSERT_COLUMN_EQ(createNullableVectorColumn<String>(vec_results_with_pos_occ_match_type, input_string_nulls), - executeFunction("replaceRegexpAll", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, nullptr, true)); - + executeFunction("regexp_replace", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, nullptr, true)); /// test regexp_replace(str, pattern, replacement, pos, occ, match_type) with binary collator ASSERT_COLUMN_EQ(createNullableVectorColumn<String>(vec_results_with_pos_occ_match_type_binary, input_string_nulls), - executeFunction("replaceRegexpAll", {createNullableVectorColumn<String>(input_strings, input_string_nulls), 
createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, binary_collator, true)); + executeFunction("regexp_replace", {createNullableVectorColumn<String>(input_strings, input_string_nulls), createConstColumn<String>(row_size, patterns[0]), createConstColumn<String>(row_size, replacements[0]), createConstColumn<Int64>(row_size, pos[0]), createConstColumn<Int64>(row_size, occ[0]), createConstColumn<String>(row_size, match_types[0])}, binary_collator, true)); + } + + std::vector<RegexpReplaceCase> test_cases; + std::vector<std::vector<UInt8>> null_maps; + std::vector<String> exprs; + std::vector<String> repls; + std::vector<Int64> positions; + std::vector<Int64> occurs; + + /// case 5 regexp_replace(vector, vector, vector[, vector, vector, vector]) + { + test_cases = {{"taa", "ttifl", "tifl", "aa", 1, 0, ""}, + {"aaaaaa", "121212", "1.", "aa", 1, 0, ""}, + {"aa1212", "121212", "1.", "aa", 1, 1, ""}, + {"12aa12", "121212", "1.", "aa", 1, 2, ""}, + {"1212aa", "121212", "1.", "aa", 1, 3, ""}, + {"121212", "121212", "1.", "aa", 1, 4, ""}, + {"啊ah好a哈哈", "啊a哈a哈哈", "哈", "h好", 1, 1, ""}, + {"啊a哈ah好哈", "啊a哈a哈哈", "哈", "h好", 4, 1, ""}, + {"啊a哈a哈哈", "啊a哈a哈哈", "哈", "h好", 4, 5, ""}, + {"aa", "\n", ".", "aa", 1, 0, "s"}, + {"12aa34", "12\n34", ".", "aa", 3, 1, "s"}}; + RegexpReplaceCase::setVecsWithoutNullMap(6, test_cases, results, exprs, patterns, repls, positions, occurs, match_types); + results = getResultVec<String>(test_cases); + ASSERT_COLUMN_EQ(createColumn<String>(results), + executeFunction( + "regexp_replace", + createColumn<String>(exprs), + createColumn<String>(patterns), + createColumn<String>(repls), + createColumn<Int32>(positions), + createColumn<Int32>(occurs), + createColumn<String>(match_types))); + } + + /// case 6 regexp_replace(vector, vector, vector[, vector, vector, vector])
with null value + { + test_cases = {{"taa", {0, 0, 1, 0, 0, 0}, "ttifl", "tifl", "aa", 1, 0, ""}, + {"aaaaaa", {0, 0, 0, 0, 0, 0}, "121212", "1.", "aa", 1, 0, ""}, + {"aa1212", {0, 1, 0, 0, 0, 0}, "121212", "1.", "aa", 1, 1, ""}, + {"12aa12", {0, 0, 0, 0, 0, 0}, "121212", "1.", "aa", 1, 2, ""}, + {"1212aa", {0, 1, 0, 0, 0, 0}, "121212", "1.", "aa", 1, 3, ""}, + {"121212", {0, 0, 0, 0, 0, 0}, "121212", "1.", "aa", 1, 4, ""}, + {"啊ah好a哈哈", {0, 1, 0, 0, 0, 0}, "啊a哈a哈哈", "哈", "h好", 1, 1, ""}, + {"啊a哈ah好哈", {0, 0, 0, 0, 0, 0}, "啊a哈a哈哈", "哈", "h好", 4, 1, ""}, + {"啊a哈a哈哈", {0, 1, 0, 0, 0, 0}, "啊a哈a哈哈", "哈", "h好", 4, 5, ""}, + {"aa", {0, 1, 0, 0, 0, 0}, "\n", ".", "aa", 1, 0, "s"}, + {"12aa34", {0, 0, 0, 0, 0, 0}, "12\n34", ".", "aa", 3, 1, "s"}}; + RegexpReplaceCase::setVecsWithNullMap(6, test_cases, results, null_maps, exprs, patterns, repls, positions, occurs, match_types); + results = getResultVec<String>(test_cases); + ASSERT_COLUMN_EQ(createNullableVectorColumn<String>(results, null_maps[RegexpReplaceCase::PAT_NULL_MAP_IDX]), + executeFunction( + "regexp_replace", + createColumn<String>(exprs), + createNullableVectorColumn<String>(patterns, null_maps[RegexpReplaceCase::PAT_NULL_MAP_IDX]), + createColumn<String>(repls), + createColumn<Int32>(positions), + createColumn<Int32>(occurs), + createColumn<String>(match_types))); + } + + /// case 7: test some special cases + { + // test empty expr + ASSERT_COLUMN_EQ(createColumn<String>({"aa", "aa", "aa", "", ""}), + executeFunction( + "regexp_replace", + {createColumn<String>({"", "", "", "", ""}), + createColumn<String>({"^$", "^$", "^$", "^$", "12"}), + createColumn<String>({"aa", "aa", "aa", "aa", "aa"}), + createColumn<Int64>({1, 1, 1, 1, 1}), + createColumn<Int64>({-1, 0, 1, 2, 3})}, + nullptr, + true)); } } } // namespace tests diff --git a/tests/fullstack-test/expr/regexp.test b/tests/fullstack-test/expr/regexp.test index e4f8e002e15..5f09b489fec 100644 --- a/tests/fullstack-test/expr/regexp.test +++ 
b/tests/fullstack-test/expr/regexp.test @@ -14,7 +14,7 @@ # test regexp and regexp_like mysql> drop table if exists test.t -mysql> create table test.t (data varchar(30), data_not_null varchar(30) not null, pattern varchar(30), pattern_not_null varchar(30) not null); +mysql> create table test.t (data varchar(30), data_not_null varchar(30) not null, pattern varchar(30), pattern_not_null varchar(30) not null); mysql> insert into test.t values ('aaaa', 'AAAA', '^a.*', '^A.*'), ('abcd', 'abcd', null, '^a..d$'), (null, 'bbb', 'bb$', 'bb$'),('中文测试','中文测试','中文','^....$'),('中English混合','中English混合','^中English','^..nglish..$'); mysql> alter table test.t set tiflash replica 1 func> wait_table test t @@ -71,7 +71,7 @@ mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp=1; se +---------------------+------------------------------+------------------------------+---------------------------------------+ mysql> drop table if exists test.t -mysql> create table test.t (data varchar(30), pattern varchar(30)); +mysql> create table test.t (data varchar(30), pattern varchar(30)); mysql> insert into test.t values ('abcd', 'abcd'); mysql> alter table test.t set tiflash replica 1 func> wait_table test t @@ -83,7 +83,7 @@ mysql> set @@tidb_isolation_read_engines='tiflash'; set @@tidb_enforce_mpp=1; se +---------------------+-------------------+ mysql> drop table if exists test.t; -mysql> create table test.t (data varchar(30), pattern varchar(30), match_type varchar(30)); +mysql> create table test.t (data varchar(30), pattern varchar(30), match_type varchar(30)); mysql> insert into test.t values ('a', 'A', 'i'), ('\n', '.', 's'), ('ab\nabc', '^abc$', 'm'); mysql> alter table test.t set tiflash replica 1; func> wait_table test t @@ -135,3 +135,23 @@ mysql> set tidb_enforce_mpp=1; select regexp_substr(expr, pattern, 1, 1, match_t | Bb | | abc | +------+ + +# test regexp_replace +mysql> drop table if exists test.t; +mysql> create table test.t (expr varchar(30), pattern 
varchar(30), repl varchar(30), pos int, occur int, match_type varchar(30)); +mysql> alter table test.t set tiflash replica 1; +func> wait_table test t +mysql> set tidb_enforce_mpp=1; select regexp_replace(_utf8mb4'1', _utf8mb4'1', repl, pos, occur, match_type) as res from test.t; +mysql> set tidb_enforce_mpp=1; select regexp_replace(_utf8mb4'1', _utf8mb4'', repl, pos, occur, match_type) as res from test.t; + +mysql> insert into test.t values (_utf8mb4'123', _utf8mb4'12.', _utf8mb4'233', 1, 1, _utf8mb4''), (_utf8mb4'aBb', _utf8mb4'bb', _utf8mb4'xzx', 1, 1, _utf8mb4'i'), (_utf8mb4'ababc', _utf8mb4'^abc$', _utf8mb4'123', 1, 1, _utf8mb4'c'); +mysql> alter table test.t set tiflash replica 1; +func> wait_table test t +mysql> set tidb_enforce_mpp=1; select regexp_replace(expr, pattern, repl, 1, 1, match_type) as res from test.t; ++-------+ +| res | ++-------+ +| 233 | +| axzx | +| ababc | ++-------+