Skip to content

Commit

Permalink
Remove physical expr of NamedStructField, convert to get_field func…
Browse files Browse the repository at this point in the history
…tion call (apache#9563)

* feat: replace namedstruct with ScalarUDF

* fix typo

* delete indexed_field file

* fix cargo check

* fix cargo check

* cargo update in CLI

* feat: add getfield func

* fix struct fun

* stage commit

* fix test

* refresh CI

* resolve strange bug

* fix clippy

* use values_to_arrays

* delete for merge

* use function_rewrite feature
  • Loading branch information
yyy1000 authored Mar 13, 2024
1 parent 4c9e787 commit 3b61004
Show file tree
Hide file tree
Showing 16 changed files with 183 additions and 630 deletions.
31 changes: 16 additions & 15 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
let expr = create_physical_name(expr, false)?;
Ok(format!("{expr} IS NOT UNKNOWN"))
}
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
let expr = create_physical_name(expr, false)?;
let name = match field {
GetFieldAccess::NamedStructField { name } => format!("{expr}[{name}]"),
Expr::GetIndexedField(GetIndexedField { expr: _, field }) => {
match field {
GetFieldAccess::NamedStructField { name: _ } => {
unreachable!(
"NamedStructField should have been rewritten in OperatorToFunction"
)
}
GetFieldAccess::ListIndex { key: _ } => {
unreachable!(
"ListIndex should have been rewritten in OperatorToFunction"
Expand All @@ -222,12 +225,10 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
stride: _,
} => {
unreachable!(
"ListIndex should have been rewritten in OperatorToFunction"
"ListRange should have been rewritten in OperatorToFunction"
)
}
};

Ok(name)
}
Expr::ScalarFunction(fun) => {
// function should be resolved during `AnalyzerRule`s
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-functions = { workspace = true }
itertools = { version = "0.12", features = ["use_std"] }
log = { workspace = true }
paste = "1.0.14"
10 changes: 10 additions & 0 deletions datafusion/functions-array/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::{
BinaryExpr, BuiltinScalarFunction, Expr, GetFieldAccess, GetIndexedField, Operator,
};
use datafusion_functions::expr_fn::get_field;

/// Rewrites expressions into function calls to array functions
pub(crate) struct ArrayFunctionRewriter {}
Expand Down Expand Up @@ -147,6 +148,15 @@ impl FunctionRewrite for ArrayFunctionRewriter {
Transformed::yes(array_prepend(*left, *right))
}

Expr::GetIndexedField(GetIndexedField {
expr,
field: GetFieldAccess::NamedStructField { name },
}) => {
let expr = *expr.clone();
let name = Expr::Literal(name);
Transformed::yes(get_field(expr, name.clone()))
}

// expr[idx] ==> array_element(expr, idx)
Expr::GetIndexedField(GetIndexedField {
expr,
Expand Down
129 changes: 129 additions & 0 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use arrow_array::{Scalar, StringArray};
use datafusion_common::cast::{as_map_array, as_struct_array};
use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue};
use datafusion_expr::field_util::GetFieldAccessSchema;
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

#[derive(Debug)]
pub(super) struct GetFieldFunc {
signature: Signature,
}

impl GetFieldFunc {
pub fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}

// get_field(struct_array, field_name)
impl ScalarUDFImpl for GetFieldFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"get_field"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
todo!()
}

fn return_type_from_exprs(
&self,
args: &[Expr],
schema: &dyn ExprSchema,
_arg_types: &[DataType],
) -> Result<DataType> {
if args.len() != 2 {
return exec_err!(
"get_field function requires 2 arguments, got {}",
args.len()
);
}

let name = match &args[1] {
Expr::Literal(name) => name,
_ => {
return exec_err!(
"get_field function requires the argument field_name to be a string"
);
}
};
let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() };
let arg_dt = args[0].get_type(schema)?;
access_schema
.get_accessed_field(&arg_dt)
.map(|f| f.data_type().clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return exec_err!(
"get_field function requires 2 arguments, got {}",
args.len()
);
}

let arrays = ColumnarValue::values_to_arrays(args)?;
let array = arrays[0].clone();

let name = match &args[1] {
ColumnarValue::Scalar(name) => name,
_ => {
return exec_err!(
"get_field function requires the argument field_name to be a string"
);
}
};
match (array.data_type(), name) {
(DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => {
let map_array = as_map_array(array.as_ref())?;
let key_scalar = Scalar::new(StringArray::from(vec![k.clone()]));
let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?;
let entries = arrow::compute::filter(map_array.entries(), &keys)?;
let entries_struct_array = as_struct_array(entries.as_ref())?;
Ok(ColumnarValue::Array(entries_struct_array.column(1).clone()))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = as_struct_array(&array)?;
match as_struct_array.column_by_name(k) {
None => exec_err!(
"get indexed field {k} not found in struct"),
Some(col) => Ok(ColumnarValue::Array(col.clone()))
}
}
(DataType::Struct(_), name) => exec_err!(
"get indexed field is only possible on struct with utf8 indexes. \
Tried with {name:?} index"),
(dt, name) => exec_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {name:?} index"),
}
}
}
7 changes: 5 additions & 2 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,26 @@
//! "core" DataFusion functions
mod arrowtypeof;
mod getfield;
mod nullif;
mod nvl;
mod nvl2;
pub mod r#struct;
mod r#struct;

// create UDFs
make_udf_function!(nullif::NullIfFunc, NULLIF, nullif);
make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);
make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof);
make_udf_function!(r#struct::StructFunc, STRUCT, r#struct);
make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
(r#struct, args, "Returns a struct with the given arguments")
(r#struct, args, "Returns a struct with the given arguments"),
(get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct")
);
8 changes: 1 addition & 7 deletions datafusion/functions/src/core/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?))
}
#[derive(Debug)]
pub struct StructFunc {
pub(super) struct StructFunc {
signature: Signature,
}

Expand All @@ -73,12 +73,6 @@ impl StructFunc {
}
}

impl Default for StructFunc {
fn default() -> Self {
Self::new()
}
}

impl ScalarUDFImpl for StructFunc {
fn as_any(&self) -> &dyn Any {
self
Expand Down
Loading

0 comments on commit 3b61004

Please sign in to comment.