Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-17966: [C++] Adjust to new format for Substrait optional arguments #14415

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,11 @@ Id NormalizeFunctionName(Id id) {

} // namespace

Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
SubstraitCall* call, const ExtensionSet& ext_set,
Status DecodeArg(const substrait::FunctionArgument& arg, int idx, SubstraitCall* call,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
if (arg.has_enum_()) {
const substrait::FunctionArgument::Enum& enum_val = arg.enum_();
switch (enum_val.enum_kind_case()) {
case substrait::FunctionArgument::Enum::EnumKindCase::kSpecified:
call->SetEnumArg(idx, enum_val.specified());
break;
case substrait::FunctionArgument::Enum::EnumKindCase::kUnspecified:
call->SetEnumArg(idx, std::nullopt);
break;
default:
return Status::Invalid("Unrecognized enum kind case: ",
enum_val.enum_kind_case());
}
call->SetEnumArg(idx, arg.enum_());
} else if (arg.has_value()) {
ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
FromProto(arg.value(), ext_set, conversion_options));
Expand All @@ -80,15 +69,31 @@ Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
return Status::OK();
}

/// \brief Copy one Substrait function option onto a SubstraitCall.
///
/// \param[in] opt the protobuf option message (a name plus a list of preferences)
/// \param[in,out] call the call that receives the option through SetOption
/// \return Status::Invalid when the option lists no preferences, otherwise OK
Status DecodeOption(const substrait::FunctionOption& opt, SubstraitCall* call) {
  // An option that is present in the plan must offer at least one choice.
  // Check this first so nothing is allocated on the error path.
  if (opt.preference_size() == 0) {
    return Status::Invalid("Invalid Substrait plan. The option ", opt.name(),
                           " is specified but does not list any choices");
  }
  // The views borrow from `opt`; they only need to stay valid for the duration
  // of the SetOption call below.
  std::vector<std::string_view> prefs;
  prefs.reserve(static_cast<std::size_t>(opt.preference_size()));
  for (const auto& preference : opt.preference()) {
    prefs.push_back(preference);
  }
  call->SetOption(opt.name(), prefs);
  return Status::OK();
}

/// \brief Convert a Substrait scalar function message into a SubstraitCall.
///
/// The output type is decoded first, then every positional argument is decoded
/// with DecodeArg and every named option with DecodeOption. The first failing
/// step aborts the conversion and its error is returned.
///
/// \param[in] id the function identifier stored on the resulting call
/// \param[in] scalar_fn the protobuf message to decode
/// \param[in] ext_set forwarded to FromProto / DecodeArg
/// \param[in] conversion_options forwarded to FromProto / DecodeArg
/// \return the populated call, or the first error encountered
Result<SubstraitCall> DecodeScalarFunction(
    Id id, const substrait::Expression::ScalarFunction& scalar_fn,
    const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
  ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
                        FromProto(scalar_fn.output_type(), ext_set, conversion_options));
  SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
  // Argument positions are plain `int`s, matching the protobuf accessor, so no
  // cast is needed when forwarding the index.
  for (int i = 0; i < scalar_fn.arguments_size(); i++) {
    ARROW_RETURN_NOT_OK(
        DecodeArg(scalar_fn.arguments(i), i, &call, ext_set, conversion_options));
  }
  // Options are carried separately from positional arguments.
  for (const auto& opt : scalar_fn.options()) {
    ARROW_RETURN_NOT_OK(DecodeOption(opt, &call));
  }
  return std::move(call);
}
Expand Down Expand Up @@ -927,17 +932,11 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
scalar_fn->set_allocated_output_type(output_type.release());

for (uint32_t i = 0; i < call.size(); i++) {
for (int i = 0; i < call.size(); i++) {
substrait::FunctionArgument* arg = scalar_fn->add_arguments();
if (call.HasEnumArg(i)) {
auto enum_val = std::make_unique<substrait::FunctionArgument::Enum>();
ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg, call.GetEnumArg(i));
if (enum_arg) {
enum_val->set_specified(std::string(*enum_arg));
} else {
enum_val->set_allocated_unspecified(new google::protobuf::Empty());
}
arg->set_allocated_enum_(enum_val.release());
ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(i));
arg->set_enum_(std::string(enum_val));
} else if (call.HasValueArg(i)) {
ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i));
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression> value_expr,
Expand Down
145 changes: 102 additions & 43 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/util/hash_util.h"
#include "arrow/util/hashing.h"
#include "arrow/util/string.h"

namespace arrow {
namespace engine {
Expand Down Expand Up @@ -121,7 +122,7 @@ class IdStorageImpl : public IdStorage {

std::unique_ptr<IdStorage> IdStorage::Make() { return std::make_unique<IdStorageImpl>(); }

Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index) const {
Result<std::string_view> SubstraitCall::GetEnumArg(int index) const {
if (index >= size_) {
return Status::Invalid("Expected Substrait call to have an enum argument at index ",
index, " but it did not have enough arguments");
Expand All @@ -134,16 +135,16 @@ Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index
return enum_arg_it->second;
}

bool SubstraitCall::HasEnumArg(uint32_t index) const {
bool SubstraitCall::HasEnumArg(int index) const {
return enum_args_.find(index) != enum_args_.end();
}

void SubstraitCall::SetEnumArg(uint32_t index, std::optional<std::string> enum_arg) {
void SubstraitCall::SetEnumArg(int index, std::string enum_arg) {
size_ = std::max(size_, index + 1);
enum_args_[index] = std::move(enum_arg);
}

Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
Result<compute::Expression> SubstraitCall::GetValueArg(int index) const {
if (index >= size_) {
return Status::Invalid("Expected Substrait call to have a value argument at index ",
index, " but it did not have enough arguments");
Expand All @@ -156,15 +157,32 @@ Result<compute::Expression> SubstraitCall::GetValueArg(uint32_t index) const {
return value_arg_it->second;
}

bool SubstraitCall::HasValueArg(uint32_t index) const {
bool SubstraitCall::HasValueArg(int index) const {
return value_args_.find(index) != value_args_.end();
}

void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
void SubstraitCall::SetValueArg(int index, compute::Expression value_arg) {
size_ = std::max(size_, index + 1);
value_args_[index] = std::move(value_arg);
}

/// \brief Look up the preference list recorded for a named option.
///
/// \param[in] option_name name of the option
/// \return std::nullopt when the option was never set; otherwise a pointer to
///         the stored preferences. The pointer refers to storage inside this
///         SubstraitCall.
///
/// NOTE(review): a reviewer suggested simplifying the return type to a plain
/// `const std::vector<std::string>*` that is `nullptr` when the option is
/// absent. The optional-of-pointer form is kept here because it is what the
/// class declaration and ParseOptionOrElse use.
std::optional<std::vector<std::string> const*> SubstraitCall::GetOption(
    std::string_view option_name) const {
  auto opt = options_.find(std::string(option_name));
  if (opt == options_.end()) {
    return std::nullopt;
  }
  return &opt->second;
}

void SubstraitCall::SetOption(std::string_view option_name,
const std::vector<std::string_view>& option_preferences) {
bkietz marked this conversation as resolved.
Show resolved Hide resolved
auto& prefs = options_[std::string(option_name)];
for (std::string_view pref : option_preferences) {
prefs.emplace_back(pref);
}
}

// A builder used when creating a Substrait plan from an Arrow execution plan. In
// that situation we do not have a set of anchor values already defined so we keep
// a map of what Ids we have seen.
Expand Down Expand Up @@ -645,50 +663,91 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
};

template <typename Enum>
using EnumParser = std::function<Result<Enum>(std::optional<std::string_view>)>;

template <typename Enum>
EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
std::unordered_map<std::string, Enum> parse_map;
for (std::size_t i = 0; i < options.size(); i++) {
parse_map[options[i]] = static_cast<Enum>(i + 1);
class EnumParser {
public:
explicit EnumParser(const std::vector<std::string>& options) {
for (std::size_t i = 0; i < options.size(); i++) {
parse_map_[options[i]] = static_cast<Enum>(i + 1);
reverse_map_[static_cast<Enum>(i + 1)] = options[i];
}
}
return [parse_map](std::optional<std::string_view> enum_val) -> Result<Enum> {
if (!enum_val) {
// Assumes 0 is always kUnspecified in Enum
return static_cast<Enum>(0);

Result<Enum> Parse(std::string_view enum_val) const {
auto it = parse_map_.find(std::string(enum_val));
if (it == parse_map_.end()) {
return Status::NotImplemented("The value ", enum_val,
" is not an expected enum value");
}
auto maybe_parsed = parse_map.find(std::string(*enum_val));
if (maybe_parsed == parse_map.end()) {
return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
return it->second;
}

std::string ImplementedOptionsAsString(
const std::vector<Enum>& implemented_opts) const {
std::vector<std::string_view> opt_strs;
for (const Enum& implemented_opt : implemented_opts) {
auto it = reverse_map_.find(implemented_opt);
if (it == reverse_map_.end()) {
opt_strs.emplace_back("Unknown");
} else {
opt_strs.emplace_back(it->second);
}
}
return maybe_parsed->second;
};
}
return arrow::internal::JoinStrings(opt_strs, ", ");
}

private:
std::unordered_map<std::string, Enum> parse_map_;
std::unordered_map<Enum, std::string> reverse_map_;
};

// Each option vector lists the strings for the enumerators that follow
// kUnspecified, in declaration order; EnumParser maps the i-th string to the
// enumerator with value i + 1, so the two must be kept in sync.
enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond };
static std::vector<std::string> kTemporalComponentOptions = {"YEAR", "MONTH", "DAY",
                                                             "SECOND"};
static EnumParser<TemporalComponent> kTemporalComponentParser(kTemporalComponentOptions);

enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError };
static std::vector<std::string> kOverflowOptions = {"SILENT", "SATURATE", "ERROR"};
static EnumParser<OverflowBehavior> kOverflowParser(kOverflowOptions);

/// \brief Resolve a named option on `call` to a supported enumerator.
///
/// The preferences recorded for the option are visited in order and the first
/// one that is listed in `implemented_options` is returned.
///
/// \param[in] call the call whose options are inspected
/// \param[in] option_name name of the option to resolve
/// \param[in] parser converts preference strings to `Enum`
/// \param[in] implemented_options the enumerators the caller can honour
/// \param[in] fallback value returned when the option is absent from the call
/// \return the chosen enumerator; the parser's error if a visited preference is
///         not a recognised string; Status::NotImplemented if no preference is
///         supported
template <typename Enum>
Result<Enum> ParseOptionOrElse(const SubstraitCall& call, std::string_view option_name,
                               const EnumParser<Enum>& parser,
                               const std::vector<Enum>& implemented_options,
                               Enum fallback) {
  std::optional<std::vector<std::string> const*> enum_arg = call.GetOption(option_name);
  if (!enum_arg.has_value()) {
    return fallback;
  }
  std::vector<std::string> const* prefs = *enum_arg;
  for (const std::string& pref : *prefs) {
    ARROW_ASSIGN_OR_RAISE(Enum parsed, parser.Parse(pref));
    for (Enum implemented_opt : implemented_options) {
      if (implemented_opt == parsed) {
        return parsed;
      }
    }
  }

  // Prepare error message
  return Status::NotImplemented(
      "During a call to a function with id ", call.id().uri, "#", call.id().name,
      " the plan requested the option ", option_name, " to be one of [",
      arrow::internal::JoinStrings(*prefs, ", "),
      "] but the only supported options are [",
      parser.ImplementedOptionsAsString(implemented_options), "]");
}

/// \brief Fetch the enum argument at `arg_index` and convert it with `parser`.
///
/// \param[in] call the call holding the argument
/// \param[in] arg_index position of the enum argument
/// \param[in] parser converts the argument string to `Enum`
/// \return the parsed enumerator, or the error from GetEnumArg / Parse
template <typename Enum>
Result<Enum> ParseEnumArg(const SubstraitCall& call, int arg_index,
                          const EnumParser<Enum>& parser) {
  ARROW_ASSIGN_OR_RAISE(std::string_view enum_val, call.GetEnumArg(arg_index));
  return parser.Parse(enum_val);
}

Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
int start_index) {
std::vector<compute::Expression> expressions;
for (uint32_t index = start_index; index < call.size(); index++) {
for (int index = start_index; index < call.size(); index++) {
ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index));
expressions.push_back(arg);
}
Expand All @@ -698,13 +757,13 @@ Result<std::vector<compute::Expression>> GetValueArgs(const SubstraitCall& call,
ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic(
const std::string& function_name) {
return [function_name](const SubstraitCall& call) -> Result<compute::Expression> {
ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior,
ParseEnumArg(call, 0, kOverflowParser));
ARROW_ASSIGN_OR_RAISE(
OverflowBehavior overflow_behavior,
ParseOptionOrElse(call, "overflow", kOverflowParser,
{OverflowBehavior::kSilent, OverflowBehavior::kError},
OverflowBehavior::kSilent));
ARROW_ASSIGN_OR_RAISE(std::vector<compute::Expression> value_args,
GetValueArgs(call, 1));
if (overflow_behavior == OverflowBehavior::kUnspecified) {
overflow_behavior = OverflowBehavior::kSilent;
}
GetValueArgs(call, 0));
if (overflow_behavior == OverflowBehavior::kSilent) {
return arrow::compute::call(function_name, std::move(value_args));
} else if (overflow_behavior == OverflowBehavior::kError) {
Expand Down Expand Up @@ -736,12 +795,12 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
/*nullable=*/true);
if (kChecked) {
substrait_call.SetEnumArg(0, "ERROR");
substrait_call.SetOption("overflow", {"ERROR"});
} else {
substrait_call.SetEnumArg(0, "SILENT");
substrait_call.SetOption("overflow", {"SILENT"});
}
for (std::size_t i = 0; i < call.arguments.size(); i++) {
substrait_call.SetValueArg(static_cast<uint32_t>(i + 1), call.arguments[i]);
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
}
return std::move(substrait_call);
};
Expand All @@ -755,14 +814,14 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrai
SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
/*nullable=*/true);
for (std::size_t i = 0; i < call.arguments.size(); i++) {
substrait_call.SetValueArg(static_cast<uint32_t>(i), call.arguments[i]);
substrait_call.SetValueArg(static_cast<int>(i), call.arguments[i]);
}
return std::move(substrait_call);
};
}

ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping(
const std::string& function_name, uint32_t max_args) {
const std::string& function_name, int max_args) {
return [function_name,
max_args](const SubstraitCall& call) -> Result<compute::Expression> {
if (call.size() > max_args) {
Expand Down
25 changes: 15 additions & 10 deletions cpp/src/arrow/engine/substrait/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,17 @@ class SubstraitCall {
bool output_nullable() const { return output_nullable_; }
bool is_hash() const { return is_hash_; }

bool HasEnumArg(uint32_t index) const;
Result<std::optional<std::string_view>> GetEnumArg(uint32_t index) const;
void SetEnumArg(uint32_t index, std::optional<std::string> enum_arg);
Result<compute::Expression> GetValueArg(uint32_t index) const;
bool HasValueArg(uint32_t index) const;
void SetValueArg(uint32_t index, compute::Expression value_arg);
uint32_t size() const { return size_; }
bool HasEnumArg(int index) const;
Result<std::string_view> GetEnumArg(int index) const;
void SetEnumArg(int index, std::string enum_arg);
Result<compute::Expression> GetValueArg(int index) const;
bool HasValueArg(int index) const;
void SetValueArg(int index, compute::Expression value_arg);
std::optional<std::vector<std::string> const*> GetOption(
std::string_view option_name) const;
void SetOption(std::string_view option_name,
const std::vector<std::string_view>& option_preferences);
int size() const { return size_; }

private:
Id id_;
Expand All @@ -140,9 +144,10 @@ class SubstraitCall {
// Only needed when converting from Substrait -> Arrow aggregates. The
// Arrow function name depends on whether or not there are any groups
bool is_hash_;
std::unordered_map<uint32_t, std::optional<std::string>> enum_args_;
std::unordered_map<uint32_t, compute::Expression> value_args_;
uint32_t size_ = 0;
std::unordered_map<int, std::string> enum_args_;
std::unordered_map<int, compute::Expression> value_args_;
std::unordered_map<std::string, std::vector<std::string>> options_;
int size_ = 0;
};

/// Substrait identifies functions and custom data types using a (uri, name) pair.
Expand Down
Loading