From 63f013cdb36d05f6f96a145aff3c6232470f2d02 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Thu, 24 Nov 2022 22:12:31 -0800 Subject: [PATCH] ARROW-17966: [C++] Adjust to new format for Substrait optional arguments (#14415) Lead-authored-by: Weston Pace Co-authored-by: Benjamin Kietzman Signed-off-by: Weston Pace --- .../engine/substrait/expression_internal.cc | 49 +++-- .../arrow/engine/substrait/extension_set.cc | 145 +++++++++---- .../arrow/engine/substrait/extension_set.h | 25 ++- .../arrow/engine/substrait/function_test.cc | 190 ++++++++++++++---- .../arrow/engine/substrait/plan_internal.cc | 20 ++ cpp/src/arrow/engine/substrait/serde.cc | 10 + cpp/src/arrow/engine/substrait/serde_test.cc | 93 ++++++--- .../engine/substrait/test_plan_builder.cc | 30 ++- .../engine/substrait/test_plan_builder.h | 2 + cpp/thirdparty/versions.txt | 4 +- python/pyarrow/tests/test_substrait.py | 5 + 11 files changed, 419 insertions(+), 154 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 4cabd3647131b..8b33b6729a302 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -54,22 +54,11 @@ Id NormalizeFunctionName(Id id) { } // namespace -Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx, - SubstraitCall* call, const ExtensionSet& ext_set, +Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call, + const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { if (arg.has_enum_()) { - const substrait::FunctionArgument::Enum& enum_val = arg.enum_(); - switch (enum_val.enum_kind_case()) { - case substrait::FunctionArgument::Enum::EnumKindCase::kSpecified: - call->SetEnumArg(idx, enum_val.specified()); - break; - case substrait::FunctionArgument::Enum::EnumKindCase::kUnspecified: - call->SetEnumArg(idx, std::nullopt); - break; - default: - return Status::Invalid("Unrecognized enum kind case: ", - enum_val.enum_kind_case()); - } + call->SetEnumArg(idx, arg.enum_()); } else if (arg.has_value()) { ARROW_ASSIGN_OR_RAISE(compute::Expression expr, FromProto(arg.value(), ext_set, conversion_options)); @@ -82,6 +71,19 @@ Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx, return Status::OK(); } +Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) { + std::vector prefs; + if (opt.preference_size() == 0) { + return Status::Invalid("Invalid Substrait plan. The option ", opt.name(), + " is specified but does not list any choices"); + } + for (const auto& preference : opt.preference()) { + prefs.push_back(preference); + } + call->SetOption(opt.name(), prefs); + return Status::OK(); +} + Result DecodeScalarFunction( Id id, const substrait::Expression::ScalarFunction& scalar_fn, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { @@ -89,8 +91,11 @@ Result DecodeScalarFunction( FromProto(scalar_fn.output_type(), ext_set, conversion_options)); SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second); for (int i = 0; i < scalar_fn.arguments_size(); i++) { - ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast(i), &call, - ext_set, conversion_options)); + ARROW_RETURN_NOT_OK( + DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options)); + } + for (const auto& opt : scalar_fn.options()) { + ARROW_RETURN_NOT_OK(DecodeOption(opt, &call)); } return std::move(call); } @@ -929,17 +934,11 @@ Result> EncodeSubstraitCa ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options)); scalar_fn->set_allocated_output_type(output_type.release()); - for (uint32_t i = 0; i < call.size(); i++) { + for (int i = 0; i < call.size(); i++) { substrait::FunctionArgument* arg = scalar_fn->add_arguments(); if (call.HasEnumArg(i)) { - auto enum_val = std::make_unique(); - ARROW_ASSIGN_OR_RAISE(std::optional enum_arg, call.GetEnumArg(i)); - if (enum_arg) { - enum_val->set_specified(std::string(*enum_arg)); - } else { - enum_val->set_allocated_unspecified(new google::protobuf::Empty()); - } - arg->set_allocated_enum_(enum_val.release()); + ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(i)); + arg->set_enum_(std::string(enum_val)); } else if (call.HasValueArg(i)) { ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i)); ARROW_ASSIGN_OR_RAISE(std::unique_ptr value_expr, diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index e7b61b5bc75eb..1c0a92715614c 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -25,6 +25,7 @@ #include "arrow/engine/substrait/expression_internal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" +#include "arrow/util/string.h" namespace arrow { namespace engine { @@ -121,7 +122,7 @@ class IdStorageImpl : public IdStorage { std::unique_ptr IdStorage::Make() { return std::make_unique(); } -Result> SubstraitCall::GetEnumArg(uint32_t index) const { +Result SubstraitCall::GetEnumArg(int index) const { if (index >= size_) { return Status::Invalid("Expected Substrait call to have an enum argument at index ", index, " but it did not have enough arguments"); @@ -134,16 +135,16 @@ Result> SubstraitCall::GetEnumArg(uint32_t index return enum_arg_it->second; } -bool SubstraitCall::HasEnumArg(uint32_t index) const { +bool SubstraitCall::HasEnumArg(int index) const { return enum_args_.find(index) != enum_args_.end(); } -void SubstraitCall::SetEnumArg(uint32_t index, std::optional enum_arg) { +void SubstraitCall::SetEnumArg(int index, std::string enum_arg) { size_ = std::max(size_, index + 1); enum_args_[index] = std::move(enum_arg); } -Result SubstraitCall::GetValueArg(uint32_t index) const { +Result SubstraitCall::GetValueArg(int index) const { if (index >= size_) { return Status::Invalid("Expected Substrait call to have a value argument at index ", index, " but it did not have enough arguments"); @@ -156,15 +157,32 @@ Result SubstraitCall::GetValueArg(uint32_t index) const { return value_arg_it->second; } -bool SubstraitCall::HasValueArg(uint32_t index) const { +bool SubstraitCall::HasValueArg(int index) const { return value_args_.find(index) != value_args_.end(); } -void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) { +void SubstraitCall::SetValueArg(int index, compute::Expression value_arg) { size_ = std::max(size_, index + 1); value_args_[index] = std::move(value_arg); } +std::optional const*> SubstraitCall::GetOption( + std::string_view option_name) const { + auto opt = options_.find(std::string(option_name)); + if (opt == options_.end()) { + return std::nullopt; + } + return &opt->second; +} + +void SubstraitCall::SetOption(std::string_view option_name, + const std::vector& option_preferences) { + auto& prefs = options_[std::string(option_name)]; + for (std::string_view pref : option_preferences) { + prefs.emplace_back(pref); + } +} + // A builder used when creating a Substrait plan from an Arrow execution plan. In // that situation we do not have a set of anchor values already defined so we keep // a map of what Ids we have seen. @@ -645,50 +663,91 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { }; template -using EnumParser = std::function(std::optional)>; - -template -EnumParser GetEnumParser(const std::vector& options) { - std::unordered_map parse_map; - for (std::size_t i = 0; i < options.size(); i++) { - parse_map[options[i]] = static_cast(i + 1); +class EnumParser { + public: + explicit EnumParser(const std::vector& options) { + for (std::size_t i = 0; i < options.size(); i++) { + parse_map_[options[i]] = static_cast(i + 1); + reverse_map_[static_cast(i + 1)] = options[i]; + } } - return [parse_map](std::optional enum_val) -> Result { - if (!enum_val) { - // Assumes 0 is always kUnspecified in Enum - return static_cast(0); + + Result Parse(std::string_view enum_val) const { + auto it = parse_map_.find(std::string(enum_val)); + if (it == parse_map_.end()) { + return Status::NotImplemented("The value ", enum_val, + " is not an expected enum value"); } - auto maybe_parsed = parse_map.find(std::string(*enum_val)); - if (maybe_parsed == parse_map.end()) { - return Status::Invalid("The value ", *enum_val, " is not an expected enum value"); + return it->second; + } + + std::string ImplementedOptionsAsString( + const std::vector& implemented_opts) const { + std::vector opt_strs; + for (const Enum& implemented_opt : implemented_opts) { + auto it = reverse_map_.find(implemented_opt); + if (it == reverse_map_.end()) { + opt_strs.emplace_back("Unknown"); + } else { + opt_strs.emplace_back(it->second); + } } - return maybe_parsed->second; - }; -} + return arrow::internal::JoinStrings(opt_strs, ", "); + } + + private: + std::unordered_map parse_map_; + std::unordered_map reverse_map_; +}; enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond }; static std::vector kTemporalComponentOptions = {"YEAR", "MONTH", "DAY", "SECOND"}; -static EnumParser kTemporalComponentParser = - GetEnumParser(kTemporalComponentOptions); +static EnumParser kTemporalComponentParser(kTemporalComponentOptions); enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError }; static std::vector kOverflowOptions = {"SILENT", "SATURATE", "ERROR"}; -static EnumParser kOverflowParser = - GetEnumParser(kOverflowOptions); +static EnumParser kOverflowParser(kOverflowOptions); template -Result ParseEnumArg(const SubstraitCall& call, uint32_t arg_index, +Result ParseOptionOrElse(const SubstraitCall& call, std::string_view option_name, + const EnumParser& parser, + const std::vector& implemented_options, + Enum fallback) { + std::optional const*> enum_arg = call.GetOption(option_name); + if (!enum_arg.has_value()) { + return fallback; + } + std::vector const* prefs = *enum_arg; + for (const std::string& pref : *prefs) { + ARROW_ASSIGN_OR_RAISE(Enum parsed, parser.Parse(pref)); + for (Enum implemented_opt : implemented_options) { + if (implemented_opt == parsed) { + return parsed; + } + } + } + + // Prepare error message + return Status::NotImplemented( + "During a call to a function with id ", call.id().uri, "#", call.id().name, + " the plan requested the option ", option_name, " to be one of [", + arrow::internal::JoinStrings(*prefs, ", "), + "] but the only supported options are [", + parser.ImplementedOptionsAsString(implemented_options), "]"); +} + +template +Result ParseEnumArg(const SubstraitCall& call, int arg_index, const EnumParser& parser) { - ARROW_ASSIGN_OR_RAISE(std::optional enum_arg, - call.GetEnumArg(arg_index)); - return parser(enum_arg); + ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(arg_index)); + return parser.Parse(enum_val); } Result> GetValueArgs(const SubstraitCall& call, int start_index) { std::vector expressions; - for (uint32_t index = start_index; index < call.size(); index++) { + for (int index = start_index; index < call.size(); index++) { ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index)); expressions.push_back(arg); } @@ -698,13 +757,13 @@ Result> GetValueArgs(const SubstraitCall& call, ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic( const std::string& function_name) { return [function_name](const SubstraitCall& call) -> Result { - ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior, - ParseEnumArg(call, 0, kOverflowParser)); + ARROW_ASSIGN_OR_RAISE( + OverflowBehavior overflow_behavior, + ParseOptionOrElse(call, "overflow", kOverflowParser, + {OverflowBehavior::kSilent, OverflowBehavior::kError}, + OverflowBehavior::kSilent)); ARROW_ASSIGN_OR_RAISE(std::vector value_args, - GetValueArgs(call, 1)); - if (overflow_behavior == OverflowBehavior::kUnspecified) { - overflow_behavior = OverflowBehavior::kSilent; - } + GetValueArgs(call, 0)); if (overflow_behavior == OverflowBehavior::kSilent) { return arrow::compute::call(function_name, std::move(value_args)); } else if (overflow_behavior == OverflowBehavior::kError) { @@ -736,12 +795,12 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), /*nullable=*/true); if (kChecked) { - substrait_call.SetEnumArg(0, "ERROR"); + substrait_call.SetOption("overflow", {"ERROR"}); } else { - substrait_call.SetEnumArg(0, "SILENT"); + substrait_call.SetOption("overflow", {"SILENT"}); } for (std::size_t i = 0; i < call.arguments.size(); i++) { - substrait_call.SetValueArg(static_cast(i + 1), call.arguments[i]); + substrait_call.SetValueArg(static_cast(i), call.arguments[i]); } return std::move(substrait_call); }; @@ -755,14 +814,14 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), /*nullable=*/true); for (std::size_t i = 0; i < call.arguments.size(); i++) { - substrait_call.SetValueArg(static_cast(i), call.arguments[i]); + substrait_call.SetValueArg(static_cast(i), call.arguments[i]); } return std::move(substrait_call); }; } ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( - const std::string& function_name, uint32_t max_args) { + const std::string& function_name, int max_args) { return [function_name, max_args](const SubstraitCall& call) -> Result { if (call.size() > max_args) { diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index fd49f182f2749..12aa40115b15b 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -125,13 +125,17 @@ class SubstraitCall { bool output_nullable() const { return output_nullable_; } bool is_hash() const { return is_hash_; } - bool HasEnumArg(uint32_t index) const; - Result> GetEnumArg(uint32_t index) const; - void SetEnumArg(uint32_t index, std::optional enum_arg); - Result GetValueArg(uint32_t index) const; - bool HasValueArg(uint32_t index) const; - void SetValueArg(uint32_t index, compute::Expression value_arg); - uint32_t size() const { return size_; } + bool HasEnumArg(int index) const; + Result GetEnumArg(int index) const; + void SetEnumArg(int index, std::string enum_arg); + Result GetValueArg(int index) const; + bool HasValueArg(int index) const; + void SetValueArg(int index, compute::Expression value_arg); + std::optional const*> GetOption( + std::string_view option_name) const; + void SetOption(std::string_view option_name, + const std::vector& option_preferences); + int size() const { return size_; } private: Id id_; @@ -140,9 +144,10 @@ class SubstraitCall { // Only needed when converting from Substrait -> Arrow aggregates. The // Arrow function name depends on whether or not there are any groups bool is_hash_; - std::unordered_map> enum_args_; - std::unordered_map value_args_; - uint32_t size_ = 0; + std::unordered_map enum_args_; + std::unordered_map value_args_; + std::unordered_map> options_; + int size_ = 0; }; /// Substrait identifies functions and custom data types using a (uri, name) pair. diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc index f4c59aad7642f..2c00e70ff7ba4 100644 --- a/cpp/src/arrow/engine/substrait/function_test.cc +++ b/cpp/src/arrow/engine/substrait/function_test.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "arrow/array.h" @@ -43,6 +44,7 @@ namespace engine { struct FunctionTestCase { Id function_id; std::vector arguments; + std::unordered_map> options; std::vector> data_types; // For a test case that should fail just use the empty string std::string expected_output; @@ -98,10 +100,11 @@ Result> PlanFromTestCase( const FunctionTestCase& test_case, std::shared_ptr* output_table) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr
input_table, GetInputTable(test_case.arguments, test_case.data_types)); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr substrait, - internal::CreateScanProjectSubstrait( - test_case.function_id, input_table, test_case.arguments, - test_case.data_types, *test_case.expected_output_type)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr substrait, + internal::CreateScanProjectSubstrait( + test_case.function_id, input_table, test_case.arguments, test_case.options, + test_case.data_types, *test_case.expected_output_type)); std::shared_ptr consumer = std::make_shared(output_table, default_memory_pool()); @@ -144,6 +147,8 @@ void CheckValidTestCases(const std::vector& valid_cases) { void CheckErrorTestCases(const std::vector& error_cases) { for (const FunctionTestCase& test_case : error_cases) { + ARROW_SCOPED_TRACE("func=", test_case.function_id.uri, "#", + test_case.function_id.name); std::shared_ptr
output_table; ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, PlanFromTestCase(test_case, &output_table)); @@ -151,232 +156,347 @@ void CheckErrorTestCases(const std::vector& error_cases) { } } +template +void CheckNonYetImplementedTestCase(const FunctionTestCase& test_case, + ErrorMatcher error_matcher) { + ARROW_SCOPED_TRACE("func=", test_case.function_id.uri, "#", test_case.function_id.name); + std::shared_ptr
output_table; + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, error_matcher, + PlanFromTestCase(test_case, &output_table)); +} + +static const std::unordered_map> kNoOptions; + // These are not meant to be an exhaustive test of Substrait // conformance. Instead, we should test just enough to ensure // we are mapping to the correct function TEST(FunctionMapping, ValidCases) { - const std::vector valid_test_cases = { + const std::initializer_list valid_test_cases = { {{kSubstraitArithmeticFunctionsUri, "add"}, - {"SILENT", "127", "10"}, - {nullptr, int8(), int8()}, + {"127", "10"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8(), int8()}, "-119", int8()}, {{kSubstraitArithmeticFunctionsUri, "subtract"}, - {"SILENT", "-119", "10"}, - {nullptr, int8(), int8()}, + {"-119", "10"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8(), int8()}, "127", int8()}, {{kSubstraitArithmeticFunctionsUri, "multiply"}, - {"SILENT", "10", "13"}, - {nullptr, int8(), int8()}, + {"10", "13"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8(), int8()}, "-126", int8()}, {{kSubstraitArithmeticFunctionsUri, "divide"}, - {"SILENT", "-128", "-1"}, - {nullptr, int8(), int8()}, + {"-128", "-1"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8(), int8()}, "0", int8()}, - {{kSubstraitArithmeticFunctionsUri, "sign"}, {"-1"}, {int8()}, "-1", int8()}, + {{kSubstraitArithmeticFunctionsUri, "sign"}, + {"-1"}, + kNoOptions, + {int8()}, + "-1", + int8()}, {{kSubstraitArithmeticFunctionsUri, "power"}, - {"SILENT", "2", "2"}, - {nullptr, int8(), int8()}, + {"2", "2"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8(), int8()}, "4", int8()}, {{kSubstraitArithmeticFunctionsUri, "sqrt"}, - {"SILENT", "4"}, - {nullptr, int8()}, + {"4"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8()}, "2", float64()}, {{kSubstraitArithmeticFunctionsUri, "exp"}, {"1"}, + kNoOptions, {float64()}, "2.718281828459045", float64()}, {{kSubstraitArithmeticFunctionsUri, "abs"}, - {"SILENT", "-1"}, - {nullptr, int8()}, + {"-1"}, + {{"overflow", {"SILENT", "ERROR"}}}, + {int8()}, "1", int8()}, {{kSubstraitBooleanFunctionsUri, "or"}, {"1", ""}, + kNoOptions, {boolean(), boolean()}, "1", boolean()}, {{kSubstraitBooleanFunctionsUri, "and"}, {"1", ""}, + kNoOptions, {boolean(), boolean()}, "", boolean()}, {{kSubstraitBooleanFunctionsUri, "xor"}, {"1", "1"}, + kNoOptions, {boolean(), boolean()}, "0", boolean()}, - {{kSubstraitBooleanFunctionsUri, "not"}, {"1"}, {boolean()}, "0", boolean()}, + {{kSubstraitBooleanFunctionsUri, "not"}, + {"1"}, + kNoOptions, + {boolean()}, + "0", + boolean()}, {{kSubstraitComparisonFunctionsUri, "equal"}, {"57", "57"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, - {{kSubstraitComparisonFunctionsUri, "is_null"}, {"abc"}, {utf8()}, "0", boolean()}, + {{kSubstraitComparisonFunctionsUri, "is_null"}, + {"abc"}, + kNoOptions, + {utf8()}, + "0", + boolean()}, {{kSubstraitComparisonFunctionsUri, "is_not_null"}, {"57"}, + kNoOptions, {int8()}, "1", boolean()}, {{kSubstraitComparisonFunctionsUri, "not_equal"}, {"57", "57"}, + kNoOptions, {int8(), int8()}, "0", boolean()}, {{kSubstraitComparisonFunctionsUri, "lt"}, {"57", "80"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, {{kSubstraitComparisonFunctionsUri, "lt"}, {"57", "57"}, + kNoOptions, {int8(), int8()}, "0", boolean()}, {{kSubstraitComparisonFunctionsUri, "gt"}, {"57", "30"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, {{kSubstraitComparisonFunctionsUri, "gt"}, {"57", "57"}, + kNoOptions, {int8(), int8()}, "0", boolean()}, {{kSubstraitComparisonFunctionsUri, "lte"}, {"57", "57"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, {{kSubstraitComparisonFunctionsUri, "lte"}, {"50", "57"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, {{kSubstraitComparisonFunctionsUri, "gte"}, {"57", "57"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, {{kSubstraitComparisonFunctionsUri, "gte"}, {"60", "57"}, + kNoOptions, {int8(), int8()}, "1", boolean()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"YEAR", "2022-07-15T14:33:14"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO)}, "2022", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"MONTH", "2022-07-15T14:33:14"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO)}, "7", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"DAY", "2022-07-15T14:33:14"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO)}, "15", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"SECOND", "2022-07-15T14:33:14"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO)}, "14", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"YEAR", "2022-07-15T14:33:14Z"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, "2022", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"MONTH", "2022-07-15T14:33:14Z"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, "7", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"DAY", "2022-07-15T14:33:14Z"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, "15", int64()}, {{kSubstraitDatetimeFunctionsUri, "extract"}, {"SECOND", "2022-07-15T14:33:14Z"}, + kNoOptions, {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, "14", int64()}, {{kSubstraitDatetimeFunctionsUri, "lt"}, {"2022-07-15T14:33:14", "2022-07-15T14:33:20"}, + kNoOptions, {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, "1", boolean()}, {{kSubstraitDatetimeFunctionsUri, "lte"}, {"2022-07-15T14:33:14", "2022-07-15T14:33:14"}, + kNoOptions, {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, "1", boolean()}, {{kSubstraitDatetimeFunctionsUri, "gt"}, {"2022-07-15T14:33:30", "2022-07-15T14:33:14"}, + kNoOptions, {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, "1", boolean()}, {{kSubstraitDatetimeFunctionsUri, "gte"}, {"2022-07-15T14:33:14", "2022-07-15T14:33:14"}, + kNoOptions, {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, "1", boolean()}, {{kSubstraitStringFunctionsUri, "concat"}, {"abc", "def"}, + kNoOptions, {utf8(), utf8()}, "abcdef", utf8()}, - {{kSubstraitLogarithmicFunctionsUri, "ln"}, {"1"}, {int8()}, "0", float64()}, - {{kSubstraitLogarithmicFunctionsUri, "log10"}, {"10"}, {int8()}, "1", float64()}, - {{kSubstraitLogarithmicFunctionsUri, "log2"}, {"2"}, {int8()}, "1", float64()}, + {{kSubstraitLogarithmicFunctionsUri, "ln"}, + {"1"}, + kNoOptions, + {int8()}, + "0", + float64()}, + {{kSubstraitLogarithmicFunctionsUri, "log10"}, + {"10"}, + kNoOptions, + {int8()}, + "1", + float64()}, + {{kSubstraitLogarithmicFunctionsUri, "log2"}, + {"2"}, + kNoOptions, + {int8()}, + "1", + float64()}, {{kSubstraitLogarithmicFunctionsUri, "log1p"}, {"1"}, + kNoOptions, {int8()}, "0.6931471805599453", float64()}, {{kSubstraitLogarithmicFunctionsUri, "logb"}, {"10", "10"}, + kNoOptions, {int8(), int8()}, "1", float64()}, - {{kSubstraitRoundingFunctionsUri, "floor"}, {"3.1"}, {float64()}, "3", float64()}, - {{kSubstraitRoundingFunctionsUri, "ceil"}, {"3.1"}, {float64()}, "4", float64()}}; + {{kSubstraitRoundingFunctionsUri, "floor"}, + {"3.1"}, + kNoOptions, + {float64()}, + "3", + float64()}, + {{kSubstraitRoundingFunctionsUri, "ceil"}, + {"3.1"}, + kNoOptions, + {float64()}, + "4", + float64()}}; CheckValidTestCases(valid_test_cases); } TEST(FunctionMapping, ErrorCases) { const std::vector error_test_cases = { {{kSubstraitArithmeticFunctionsUri, "add"}, - {"ERROR", "127", "10"}, - {nullptr, int8(), int8()}, + {"127", "10"}, + {{"overflow", {"ERROR", "SILENT"}}}, + {int8(), int8()}, "", int8()}, {{kSubstraitArithmeticFunctionsUri, "subtract"}, - {"ERROR", "-119", "10"}, - {nullptr, int8(), int8()}, + {"-119", "10"}, + {{"overflow", {"ERROR", "SILENT"}}}, + {int8(), int8()}, "", int8()}, {{kSubstraitArithmeticFunctionsUri, "multiply"}, - {"ERROR", "10", "13"}, - {nullptr, int8(), int8()}, + {"10", "13"}, + {{"overflow", {"ERROR", "SILENT"}}}, + {int8(), int8()}, "", int8()}, {{kSubstraitArithmeticFunctionsUri, "divide"}, - {"ERROR", "-128", "-1"}, - {nullptr, int8(), int8()}, + {"-128", "-1"}, + {{"overflow", {"ERROR", "SILENT"}}}, + {int8(), int8()}, "", int8()}}; CheckErrorTestCases(error_test_cases); } +TEST(FunctionMapping, UnrecognizedOptions) { + CheckNonYetImplementedTestCase( + {{kSubstraitArithmeticFunctionsUri, "add"}, + {"-119", "10"}, + {{"overflow", {"NEW_OVERFLOW_TYPE", "SILENT"}}}, + {int8(), int8()}, + "", + int8()}, + ::testing::HasSubstr("The value NEW_OVERFLOW_TYPE is not an expected enum value")); + CheckNonYetImplementedTestCase( + {{kSubstraitArithmeticFunctionsUri, "add"}, + {"-119", "10"}, + {{"overflow", {"SATURATE"}}}, + {int8(), int8()}, + "", + int8()}, + ::testing::HasSubstr( + "During a call to a function with id " + + std::string(kSubstraitArithmeticFunctionsUri) + + "#add the plan requested the option overflow to be one of [SATURATE] but the " + "only supported options are [SILENT, ERROR]")); +} + // For each aggregate test case we take in three values. We compute the // aggregate both on the entire set (all three values) and on groups. The // first two rows will be in the first group and the last row will be in the diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 18915868ee0ef..e675a2c2ab7f6 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -17,6 +17,7 @@ #include "arrow/engine/substrait/plan_internal.h" +#include "arrow/config.h" #include "arrow/dataset/plan.h" #include "arrow/engine/substrait/relation_internal.h" #include "arrow/result.h" @@ -132,10 +133,29 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, conversion_options, registry); } +namespace { + +// FIXME Is there some way to get these from the cmake files? +constexpr uint32_t kSubstraitMajorVersion = 0; +constexpr uint32_t kSubstraitMinorVersion = 20; +constexpr uint32_t kSubstraitPatchVersion = 0; + +std::unique_ptr CreateVersion() { + auto version = std::make_unique(); + version->set_major_number(kSubstraitMajorVersion); + version->set_minor_number(kSubstraitMinorVersion); + version->set_patch_number(kSubstraitPatchVersion); + version->set_producer("Acero " + GetBuildInfo().version_string); + return version; +} + +} // namespace + Result> PlanToProto( const compute::Declaration& declr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { auto subs_plan = std::make_unique(); + subs_plan->set_allocated_version(CreateVersion().release()); auto plan_rel = std::make_unique(); auto rel_root = std::make_unique(); ARROW_ASSIGN_OR_RAISE(auto rel, ToProto(declr, ext_set, conversion_options)); diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index f8c846c5a2389..b90bc98aab919 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -139,12 +139,22 @@ DeclarationFactory MakeWriteDeclarationFactory( }; } +// FIXME - Replace with actual version that includes the change +constexpr uint32_t kMinimumMajorVersion = 0; +constexpr uint32_t kMinimumMinorVersion = 19; + Result> DeserializePlans( const Buffer& buf, DeclarationFactory declaration_factory, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer(buf)); + if (plan.version().major_number() < kMinimumMajorVersion && + plan.version().minor_number() < kMinimumMinorVersion) { + return Status::Invalid("Can only parse plans with a version >= ", + kMinimumMajorVersion, ".", kMinimumMinorVersion); + } + ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, conversion_options, registry)); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 4c2c860b07c94..210edcf75e71f 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -849,6 +849,7 @@ TEST(Substrait, RelWithHint) { TEST(Substrait, ExtensionSetFromPlan) { std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [ {"rel": { "read": { @@ -998,6 +999,7 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) { TEST(Substrait, ExtensionSetFromPlanRegisterFunc) { std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [], "extension_uris": [ { @@ -1048,6 +1050,7 @@ Result GetSubstraitJSON() { auto file_path = file_name->ToString(); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [ {"rel": { "read": { @@ -1179,6 +1182,7 @@ TEST(Substrait, GetRecordBatchReader) { TEST(Substrait, InvalidPlan) { std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [ ] })"; @@ -1189,8 +1193,34 @@ TEST(Substrait, InvalidPlan) { }); } +TEST(Substrait, InvalidMinimumVersion) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 0, "minor_number": 18, "patch_number": 0 }, + "relations": [{ + "rel": { + "read": { + "base_schema": { + "names": ["A"], + "struct": { + "types": [{ + "i32": {} + }] + } + }, + "named_table": { "names": ["x"] } + } + } + }], + "extensionUris": [], + "extensions": [], + })")); + + ASSERT_RAISES(Invalid, DeserializePlans(*buf, [] { return kNullConsumer; })); +} + TEST(Substrait, JoinPlanBasic) { std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "join": { @@ -1336,6 +1366,7 @@ TEST(Substrait, JoinPlanBasic) { TEST(Substrait, JoinPlanInvalidKeyCmp) { std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "join": { @@ -1454,6 +1485,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { TEST(Substrait, JoinPlanInvalidExpression) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "join": { @@ -1523,6 +1555,7 @@ TEST(Substrait, JoinPlanInvalidExpression) { TEST(Substrait, JoinPlanInvalidKeys) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "join": { @@ -1597,6 +1630,7 @@ TEST(Substrait, JoinPlanInvalidKeys) { TEST(Substrait, AggregateBasic) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -1692,6 +1726,7 @@ TEST(Substrait, AggregateBasic) { TEST(Substrait, AggregateInvalidRel) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -1718,6 +1753,7 @@ TEST(Substrait, AggregateInvalidRel) { TEST(Substrait, AggregateInvalidFunction) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -1781,6 +1817,7 @@ TEST(Substrait, AggregateInvalidFunction) { TEST(Substrait, AggregateInvalidAggFuncArgs) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -1822,7 +1859,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { "measures": [{ "measure": { "functionReference": 0, - "args": [], + "arguments": [], "sorts": [], "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", "invocation": "AGGREGATION_INVOCATION_ALL", @@ -1854,6 +1891,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { TEST(Substrait, AggregateWithFilter) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -1895,7 +1933,7 @@ TEST(Substrait, AggregateWithFilter) { "measures": [{ "measure": { "functionReference": 0, - "args": [], + "arguments": [], "sorts": [], "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", "invocation": "AGGREGATION_INVOCATION_ALL", @@ -1927,6 +1965,7 @@ TEST(Substrait, AggregateWithFilter) { TEST(Substrait, AggregateBadPhase) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -1968,7 +2007,7 @@ TEST(Substrait, AggregateBadPhase) { "measures": [{ "measure": { "functionReference": 0, - "args": [], + "arguments": [], "sorts": [], "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", "invocation": "AGGREGATION_INVOCATION_DISTINCT", @@ -1997,7 +2036,7 @@ TEST(Substrait, AggregateBadPhase) { ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); } -TEST(Substrait, BasicPlanRoundTripping) { +TEST(SubstraitRoundTrip, BasicPlan) { compute::ExecContext exec_context; arrow::dataset::internal::Initialize(); @@ -2108,7 +2147,7 @@ TEST(Substrait, BasicPlanRoundTripping) { } } -TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { +TEST(SubstraitRoundTrip, BasicPlanEndToEnd) { compute::ExecContext exec_context; arrow::dataset::internal::Initialize(); @@ -2221,7 +2260,7 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } -TEST(Substrait, ProjectRel) { +TEST(SubstraitRoundTrip, ProjectRel) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2237,6 +2276,7 @@ TEST(Substrait, ProjectRel) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "project": { @@ -2335,7 +2375,7 @@ TEST(Substrait, ProjectRel) { buf, {}, conversion_options); } -TEST(Substrait, ProjectRelOnFunctionWithEmit) { +TEST(SubstraitRoundTrip, ProjectRelOnFunctionWithEmit) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2351,6 +2391,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "project": { @@ -2453,7 +2494,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { buf, {}, conversion_options); } -TEST(Substrait, ReadRelWithEmit) { +TEST(SubstraitRoundTrip, ReadRelWithEmit) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2465,6 +2506,7 @@ TEST(Substrait, ReadRelWithEmit) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "read": { @@ -2511,7 +2553,7 @@ TEST(Substrait, ReadRelWithEmit) { buf, {}, conversion_options); } -TEST(Substrait, FilterRelWithEmit) { +TEST(SubstraitRoundTrip, FilterRelWithEmit) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32()), field("D", int32())}); @@ -2528,6 +2570,7 @@ TEST(Substrait, FilterRelWithEmit) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "filter": { @@ -2628,7 +2671,7 @@ TEST(Substrait, FilterRelWithEmit) { buf, {}, conversion_options); } -TEST(Substrait, JoinRelEndToEnd) { +TEST(SubstraitRoundTrip, JoinRel) { compute::ExecContext exec_context; auto left_schema = schema({field("A", int32()), field("B", int32())}); @@ -2648,6 +2691,7 @@ TEST(Substrait, JoinRelEndToEnd) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "join": { @@ -2777,7 +2821,7 @@ TEST(Substrait, JoinRelEndToEnd) { buf, {}, conversion_options); } -TEST(Substrait, JoinRelWithEmit) { +TEST(SubstraitRoundTrip, JoinRelWithEmit) { compute::ExecContext exec_context; auto left_schema = schema({field("A", int32()), field("B", int32())}); @@ -2797,6 +2841,7 @@ TEST(Substrait, JoinRelWithEmit) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "join": { @@ -2928,7 +2973,7 @@ TEST(Substrait, JoinRelWithEmit) { buf, {}, conversion_options); } -TEST(Substrait, AggregateRel) { +TEST(SubstraitRoundTrip, AggregateRel) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -2945,6 +2990,7 @@ TEST(Substrait, AggregateRel) { ])"}); std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -3036,7 +3082,7 @@ TEST(Substrait, AggregateRel) { buf, {}, conversion_options); } -TEST(Substrait, AggregateRelEmit) { +TEST(SubstraitRoundTrip, AggregateRelEmit) { compute::ExecContext exec_context; auto dummy_schema = schema({field("A", int32()), field("B", int32()), field("C", int32())}); @@ -3054,6 +3100,7 @@ TEST(Substrait, AggregateRelEmit) { // TODO: fixme https://issues.apache.org/jira/browse/ARROW-17484 std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "relations": [{ "rel": { "aggregate": { @@ -3153,10 +3200,8 @@ TEST(Substrait, AggregateRelEmit) { TEST(Substrait, IsthmusPlan) { // This is a plan generated from Isthmus // isthmus -c "CREATE TABLE T1(foo int)" "SELECT foo + 1 FROM T1" - // - // The plan had to be modified slightly to introduce the missing enum - // argument that isthmus did not put there. std::string substrait_json = R"({ + "version": { "major_number": 9999, "minor_number": 9999, "patch_number": 9999 }, "extensionUris": [{ "extensionUriAnchor": 1, "uri": "/functions_arithmetic.yaml" @@ -3165,7 +3210,7 @@ TEST(Substrait, IsthmusPlan) { "extensionFunction": { "extensionUriReference": 1, "functionAnchor": 0, - "name": "add:opt_i32_i32" + "name": "add:i32_i32" } }], "relations": [{ @@ -3204,7 +3249,6 @@ TEST(Substrait, IsthmusPlan) { "expressions": [{ "scalarFunction": { "functionReference": 0, - "args": [], "outputType": { "i32": { "typeVariationReference": 0, @@ -3212,10 +3256,6 @@ TEST(Substrait, IsthmusPlan) { } }, "arguments": [{ - "enum": { - "unspecified": {} - } - }, { "value": { "selection": { "directReference": { @@ -3282,7 +3322,7 @@ TEST(Substrait, ProjectWithMultiFieldExpressions) { "extensionFunction": { "extensionUriReference": 1, "functionAnchor": 0, - "name": "add:opt_i32_i32" + "name": "add:i32_i32" } }], "relations": [{ @@ -3361,7 +3401,6 @@ TEST(Substrait, ProjectWithMultiFieldExpressions) { },{ "scalarFunction": { "functionReference": 0, - "args": [], "outputType": { "i32": { "typeVariationReference": 0, @@ -3369,10 +3408,6 @@ TEST(Substrait, ProjectWithMultiFieldExpressions) { } }, "arguments": [{ - "enum": { - "unspecified": {} - } - }, { "value": { "selection": { "directReference": { @@ -3472,7 +3507,6 @@ TEST(Substrait, NestedProjectWithMultiFieldExpressions) { "functionReference": 2, "outputType": {"i32": {}}, "arguments": [ - {"enum": {"unspecified": {}}}, {"value": {"selection": {"directReference": {"structField": {"field": 0}}}}}, {"value": {"literal": {"fp64": 10}}} ] @@ -3561,7 +3595,6 @@ TEST(Substrait, NestedEmitProjectWithMultiFieldExpressions) { "functionReference": 2, "outputType": {"i32": {}}, "arguments": [ - {"enum": {"unspecified": {}}}, {"value": {"selection": {"directReference": {"structField": {"field": 0}}}}}, {"value": {"literal": {"fp64": 10}}} ] diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc index d175006c63bce..2643b5c46986d 100644 --- a/cpp/src/arrow/engine/substrait/test_plan_builder.cc +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc @@ -67,6 +67,7 @@ void CreateDirectReference(int32_t index, substrait::Expression* expr) { Result> CreateProject( Id function_id, const std::vector& arguments, + const std::unordered_map> options, const std::vector>& arg_types, const DataType& output_type, ExtensionSet* ext_set) { auto project = std::make_unique(); @@ -87,17 +88,17 @@ Result> CreateProject( } else { // If it doesn't have a type then it's an enum const std::string& enum_value = arguments[arg_index]; - auto enum_ = std::make_unique(); - if (enum_value.size() > 0) { - enum_->set_specified(enum_value); - } else { - auto unspecified = std::make_unique(); - enum_->set_allocated_unspecified(unspecified.release()); - } - argument->set_allocated_enum_(enum_.release()); + argument->set_enum_(enum_value); } arg_index++; } + for (const auto& opt : options) { + substrait::FunctionOption* option = call->add_options(); + option->set_name(opt.first); + for (const std::string& pref : opt.second) { + option->add_preference(pref); + } + } ARROW_ASSIGN_OR_RAISE( std::unique_ptr output_type_substrait, @@ -150,9 +151,19 @@ Result> CreateAgg(Id function_id, return agg; } +std::unique_ptr CreateTestVersion() { + auto version = std::make_unique(); + version->set_major_number(std::numeric_limits::max()); + version->set_minor_number(std::numeric_limits::max()); + version->set_patch_number(std::numeric_limits::max()); + version->set_producer("Arrow unit test"); + return version; +} + Result> CreatePlan(std::unique_ptr root, ExtensionSet* ext_set) { auto plan = std::make_unique(); + plan->set_allocated_version(CreateTestVersion().release()); substrait::PlanRel* plan_rel = plan->add_relations(); auto rel_root = std::make_unique(); @@ -166,6 +177,7 @@ Result> CreatePlan(std::unique_ptr> CreateScanProjectSubstrait( Id function_id, const std::shared_ptr
& input_table, const std::vector& arguments, + const std::unordered_map>& options, const std::vector>& data_types, const DataType& output_type) { ExtensionSet ext_set; @@ -173,7 +185,7 @@ Result> CreateScanProjectSubstrait( CreateRead(*input_table, &ext_set)); ARROW_ASSIGN_OR_RAISE( std::unique_ptr project, - CreateProject(function_id, arguments, data_types, output_type, &ext_set)); + CreateProject(function_id, arguments, options, data_types, output_type, &ext_set)); auto read_rel = std::make_unique(); read_rel->set_allocated_read(read.release()); diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.h b/cpp/src/arrow/engine/substrait/test_plan_builder.h index 9d2d97a8cc9cc..5f6629e9054ce 100644 --- a/cpp/src/arrow/engine/substrait/test_plan_builder.h +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.h @@ -27,6 +27,7 @@ #include #include +#include #include #include "arrow/buffer.h" @@ -55,6 +56,7 @@ namespace internal { ARROW_ENGINE_EXPORT Result> CreateScanProjectSubstrait( Id function_id, const std::shared_ptr
& input_table, const std::vector& arguments, + const std::unordered_map>& options, const std::vector>& data_types, const DataType& output_type); diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 0cc496e93856c..dc9a3ad1d660c 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -81,8 +81,8 @@ ARROW_RE2_BUILD_SHA256_CHECKSUM=f89c61410a072e5cbcf8c27e3a778da7d6fd2f2b5b1445cd # 1.1.9 is patched to implement https://github.com/google/snappy/pull/148 if this is bumped, remove the patch ARROW_SNAPPY_BUILD_VERSION=1.1.9 ARROW_SNAPPY_BUILD_SHA256_CHECKSUM=75c1fbb3d618dd3a0483bff0e26d0a92b495bbe5059c8b4f1c962b478b6e06e7 -ARROW_SUBSTRAIT_BUILD_VERSION=v0.6.0 -ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=7b8583b9684477e9027f417bbfb4febb8acfeb01923dcaa7cf0fd3f921d69c88 +ARROW_SUBSTRAIT_BUILD_VERSION=v0.20.0 +ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=5ceaa559ccef29a7825b5e5d4b5e7eed384830294f08bec913feecdd903a94cf ARROW_THRIFT_BUILD_VERSION=0.16.0 ARROW_THRIFT_BUILD_SHA256_CHECKSUM=f460b5c1ca30d8918ff95ea3eb6291b3951cf518553566088f3f2be8981f6209 ARROW_UCX_BUILD_VERSION=1.12.1 diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index 030e4aad8203f..e6358666f44ad 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -43,6 +43,7 @@ def _write_dummy_data_to_disk(tmpdir, file_name, table): def test_run_serialized_query(tmpdir): substrait_query = """ { + "version": { "major": 9999 }, "relations": [ {"rel": { "read": { @@ -116,6 +117,7 @@ def test_invalid_plan(): def test_binary_conversion_with_json_options(tmpdir): substrait_query = """ { + "version": { "major": 9999 }, "relations": [ {"rel": { "read": { @@ -195,6 +197,7 @@ def table_provider(names): substrait_query = """ { + "version": { "major": 9999 }, "relations": [ {"rel": { "read": { @@ -236,6 +239,7 @@ def table_provider(names): substrait_query = """ { + "version": { "major": 9999 }, "relations": [ {"rel": { "read": { @@ -277,6 +281,7 @@ def table_provider(names): substrait_query = """ { + "version": { "major": 9999 }, "relations": [ {"rel": { "read": {