From b2a7a8288f4d5375912279128b0dcf9e70864e34 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Sun, 22 Aug 2021 18:53:21 +0200 Subject: [PATCH 01/15] enable GetIndexedField for Array and Dictionary --- datafusion/src/logical_plan/expr.rs | 28 +++++ datafusion/src/optimizer/utils.rs | 6 + .../expressions/get_indexed_field.rs | 118 ++++++++++++++++++ .../src/physical_plan/expressions/mod.rs | 4 +- datafusion/src/physical_plan/planner.rs | 8 ++ datafusion/src/sql/planner.rs | 14 +++ datafusion/src/sql/utils.rs | 4 + datafusion/src/utils.rs | 44 +++++++ 8 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 datafusion/src/physical_plan/expressions/get_indexed_field.rs create mode 100644 datafusion/src/utils.rs diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 011068d0e18b..a5d27e7e5121 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -26,6 +26,7 @@ use crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, window_functions, }; +use crate::utils::get_indexed_field; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; @@ -245,6 +246,13 @@ pub enum Expr { IsNull(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), + /// Returns the field of a [`ListArray`] or ['DictionaryArray'] by name + GetIndexedField { + /// the expression to take the field from + expr: Box, + /// The name of the field to take + key: String, + }, /// Whether an expression is between a given range. 
Between { /// The value to compare @@ -433,6 +441,10 @@ impl Expr { Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), + Expr::GetIndexedField { ref expr, key } => { + let data_type = expr.get_type(schema)?; + get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) + } } } @@ -488,6 +500,10 @@ impl Expr { Expr::Wildcard => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), + Expr::GetIndexedField { ref expr, key } => { + let data_type = expr.get_type(input_schema)?; + get_indexed_field(&data_type, key).map(|x| x.is_nullable()) + } } } @@ -763,6 +779,7 @@ impl Expr { .try_fold(visitor, |visitor, arg| arg.accept(visitor)) } Expr::Wildcard => Ok(visitor), + Expr::GetIndexedField { ref expr, .. } => expr.accept(visitor), }?; visitor.post_visit(self) @@ -923,6 +940,10 @@ impl Expr { negated, }, Expr::Wildcard => Expr::Wildcard, + Expr::GetIndexedField { expr, key } => Expr::GetIndexedField { + expr: rewrite_boxed(expr, rewriter)?, + key, + }, }; // now rewrite this expression itself @@ -1799,6 +1820,9 @@ impl fmt::Debug for Expr { } } Expr::Wildcard => write!(f, "*"), + Expr::GetIndexedField { ref expr, key } => { + write!(f, "({:?})[{}]", expr, key) + } } } } @@ -1879,6 +1903,10 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { let expr = create_name(expr, input_schema)?; Ok(format!("{} IS NOT NULL", expr)) } + Expr::GetIndexedField { expr, key } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{}[{}]", expr, key)) + } Expr::ScalarFunction { fun, args, .. 
} => { create_function_name(&fun.to_string(), false, args, input_schema) } diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 00ea31e2a358..39b68a980133 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -85,6 +85,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { Expr::AggregateUDF { .. } => {} Expr::InList { .. } => {} Expr::Wildcard => {} + Expr::GetIndexedField { .. } => {} } Ok(Recursion::Continue(self)) } @@ -337,6 +338,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::Wildcard { .. } => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), + Expr::GetIndexedField { expr, .. } => Ok(vec![expr.as_ref().to_owned()]), } } @@ -496,6 +498,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { Expr::Wildcard { .. } => Err(DataFusionError::Internal( "Wildcard expressions are not valid in a logical query plan".to_owned(), )), + Expr::GetIndexedField { expr: _, key } => Ok(Expr::GetIndexedField { + expr: Box::new(expressions[0].clone()), + key: key.clone(), + }), } } diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs new file mode 100644 index 000000000000..fdb18792109d --- /dev/null +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! get field of a struct array + +use std::{any::Any, sync::Arc}; + +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; + +use crate::{ + error::DataFusionError, + error::Result, + physical_plan::{ColumnarValue, PhysicalExpr}, + utils::get_indexed_field as get_data_type_field, +}; +use arrow::array::ListArray; + +/// expression to get a field of a struct array. +#[derive(Debug)] +pub struct GetIndexedFieldExpr { + arg: Arc, + key: String, +} + +impl GetIndexedFieldExpr { + /// Create new get field expression + pub fn new(arg: Arc, key: String) -> Self { + Self { arg, key } + } + + /// Get the input expression + pub fn arg(&self) -> &Arc { + &self.arg + } +} + +impl std::fmt::Display for GetIndexedFieldExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}).[{}]", self.arg, self.key) + } +} + +impl PhysicalExpr for GetIndexedFieldExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + let data_type = self.arg.data_type(input_schema)?; + get_data_type_field(&data_type, &self.key).map(|f| f.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + let data_type = self.arg.data_type(input_schema)?; + get_data_type_field(&data_type, &self.key).map(|f| f.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg = self.arg.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + if let Some(la) = array.as_any().downcast_ref::() { + Ok(ColumnarValue::Array( + 
la.value(self.key.parse::().unwrap()), + )) + } + /*else if let Some(da) = + array.as_any().downcast_ref::>() + { + if let Some(index) = da.lookup_key::(&self.key) { + Ok(ColumnarValue::Array(Arc::new( + Int32Array::builder(0).finish(), + ))) + } else { + Err(DataFusionError::NotImplemented(format!( + "key not found in dictionnary : {}", + self.key + ))) + } + }*/ + else { + Err(DataFusionError::NotImplemented( + "get indexed field is only possible on dictionnary and list" + .to_string(), + )) + } + } + ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( + "field is not yet implemented for scalar values".to_string(), + )), + } + } +} + +/// Create a `.[field]` expression +pub fn get_indexed_field( + arg: Arc, + key: String, +) -> Result> { + Ok(Arc::new(GetIndexedFieldExpr::new(arg, key))) +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 4ca00367e7fe..bcd326f43993 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -33,8 +33,9 @@ mod case; mod cast; mod coercion; mod column; -mod count; +mod count; mod cume_dist; +mod get_indexed_field; mod in_list; mod is_not_null; mod is_null; @@ -66,6 +67,7 @@ pub use cast::{ pub use column::{col, Column}; pub use count::Count; pub use cume_dist::cume_dist; +pub use get_indexed_field::{get_indexed_field, GetIndexedFieldExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 8cfb907350b5..f58627d56cf4 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -141,6 +141,10 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let expr = create_physical_name(expr, false)?; Ok(format!("{} IS NOT NULL", expr)) } + Expr::GetIndexedField { expr, key 
} => { + let expr = create_physical_name(expr, false)?; + Ok(format!("{}[{}]", expr, key)) + } Expr::ScalarFunction { fun, args, .. } => { create_function_physical_name(&fun.to_string(), false, args) } @@ -989,6 +993,10 @@ impl DefaultPhysicalPlanner { Expr::IsNotNull(expr) => expressions::is_not_null( self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, ), + Expr::GetIndexedField { expr, key } => expressions::get_indexed_field( + self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, + key.clone(), + ), Expr::ScalarFunction { fun, args } => { let physical_args = args .iter() diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 7bdb7b82234b..bb8647e40197 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -1197,6 +1197,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } + SQLExpr::MapAccess { ref column, key } => { + if let SQLExpr::Identifier(ref id) = column.as_ref() { + Ok(Expr::GetIndexedField { + expr: Box::new(col(&id.value)), + key: key.to_string(), + }) + } else { + Err(DataFusionError::NotImplemented(format!( + "map access requires an identifier, found column {} instead", + column + ))) + } + } + SQLExpr::CompoundIdentifier(ids) => { let mut var_names = vec![]; for id in ids { diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 41bcd205800d..45e78cb67c4f 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -368,6 +368,10 @@ where Ok(expr.clone()) } Expr::Wildcard => Ok(Expr::Wildcard), + Expr::GetIndexedField { expr, key } => Ok(Expr::GetIndexedField { + expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), + key: key.clone(), + }), }, } } diff --git a/datafusion/src/utils.rs b/datafusion/src/utils.rs new file mode 100644 index 000000000000..7cf3ab0b9c1d --- /dev/null +++ b/datafusion/src/utils.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or 
more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field}; + +use crate::error::{DataFusionError, Result}; + +/// Returns the a field access indexed by `name` from a [`DataType::List`] or [`DataType::Dictionnary`]. +/// # Error +/// Errors if +/// * the `data_type` is not a Struct or, +/// * there is no field key is not of the required index type +pub fn get_indexed_field<'a>(data_type: &'a DataType, key: &str) -> Result { + match data_type { + DataType::Dictionary(ref kt, ref vt) => { + match kt.as_ref() { + DataType::Utf8 => Ok(Field::new(key, *vt.clone(), true)), + _ => Err(DataFusionError::Plan(format!("The key for a dictionary has to be an utf8 string, was : \"{}\"", key))), + } + }, + DataType::List(lt) => match key.parse::() { + Ok(_) => Ok(Field::new(key, lt.data_type().clone(), false)), + Err(_) => Err(DataFusionError::Plan(format!("The key for a list has to be an integer, was : \"{}\"", key))), + }, + _ => Err(DataFusionError::Plan( + "The expression to get an indexed field is only valid for `List` or 'Dictionary'" + .to_string(), + )), + } +} From 98cb0a8cd4a13e466dd8f7c82a49a8149472f322 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Tue, 24 Aug 2021 14:05:56 +0200 Subject: [PATCH 02/15] fix GetIndexedField which should index slices 
not values --- datafusion/src/lib.rs | 1 + datafusion/src/logical_plan/expr.rs | 3 +- .../expressions/get_indexed_field.rs | 55 +++++++++++-------- datafusion/src/sql/planner.rs | 33 +++++++++-- datafusion/src/utils.rs | 31 ++++++++--- 5 files changed, 86 insertions(+), 37 deletions(-) diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index a4a5a88d16b5..75e8b82f2772 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -234,6 +234,7 @@ pub use parquet; #[cfg(test)] pub mod test; pub mod test_util; +pub mod utils; #[macro_use] #[cfg(feature = "regex_expressions")] diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index a5d27e7e5121..dccfcb01cfc8 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -251,7 +251,7 @@ pub enum Expr { /// the expression to take the field from expr: Box, /// The name of the field to take - key: String, + key: ScalarValue, }, /// Whether an expression is between a given range. 
Between { @@ -443,6 +443,7 @@ impl Expr { )), Expr::GetIndexedField { ref expr, key } => { let data_type = expr.get_type(schema)?; + get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) } } diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index fdb18792109d..77d700b1207e 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -24,24 +24,29 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::arrow::array::Array; +use crate::arrow::compute::concat; +use crate::scalar::ScalarValue; use crate::{ error::DataFusionError, error::Result, physical_plan::{ColumnarValue, PhysicalExpr}, utils::get_indexed_field as get_data_type_field, }; -use arrow::array::ListArray; +use arrow::array::{DictionaryArray, ListArray}; +use arrow::datatypes::Int8Type; +use std::fmt::Debug; /// expression to get a field of a struct array. 
#[derive(Debug)] pub struct GetIndexedFieldExpr { arg: Arc, - key: String, + key: ScalarValue, } impl GetIndexedFieldExpr { /// Create new get field expression - pub fn new(arg: Arc, key: String) -> Self { + pub fn new(arg: Arc, key: ScalarValue) -> Self { Self { arg, key } } @@ -75,33 +80,37 @@ impl PhysicalExpr for GetIndexedFieldExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => { - if let Some(la) = array.as_any().downcast_ref::() { - Ok(ColumnarValue::Array( - la.value(self.key.parse::().unwrap()), - )) + ColumnarValue::Array(array) => match (array.data_type(), &self.key) { + (DataType::List(_), ScalarValue::Int64(Some(i))) => { + let as_list_array = + array.as_any().downcast_ref::().unwrap(); + let x: Vec> = as_list_array + .iter() + .filter_map(|o| o.map(|list| list.slice(*i as usize, 1).clone())) + .collect(); + let vec = x.iter().map(|a| a.as_ref()).collect::>(); + let iter = concat(vec.as_slice()).unwrap(); + Ok(ColumnarValue::Array(iter)) } - /*else if let Some(da) = - array.as_any().downcast_ref::>() - { - if let Some(index) = da.lookup_key::(&self.key) { - Ok(ColumnarValue::Array(Arc::new( - Int32Array::builder(0).finish(), - ))) + (DataType::Dictionary(_, _), ScalarValue::Utf8(Some(s))) => { + let as_dict_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if let Some(index) = as_dict_array.lookup_key(s) { + Ok(ColumnarValue::Array(as_dict_array.slice(index as usize, 1))) } else { Err(DataFusionError::NotImplemented(format!( "key not found in dictionnary : {}", self.key ))) } - }*/ - else { - Err(DataFusionError::NotImplemented( - "get indexed field is only possible on dictionnary and list" - .to_string(), - )) } - } + _ => Err(DataFusionError::NotImplemented( + "get indexed field is only possible on dictionary and list" + .to_string(), + )), + }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( "field is not yet implemented for scalar 
values".to_string(), )), @@ -112,7 +121,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { /// Create a `.[field]` expression pub fn get_indexed_field( arg: Arc, - key: String, + key: ScalarValue, ) -> Result> { Ok(Arc::new(GetIndexedFieldExpr::new(arg, key))) } diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index bb8647e40197..7de534448c70 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -81,6 +81,32 @@ pub struct SqlToRel<'a, S: ContextProvider> { schema_provider: &'a S, } +fn plan_key(key: Value) -> ScalarValue { + match key { + Value::Number(s, _) => ScalarValue::Int64(Some(s.parse().unwrap())), + Value::SingleQuotedString(s) => ScalarValue::Utf8(Some(s)), + _ => unreachable!(), + } +} + +fn plan_indexed(expr: Expr, mut keys: Vec) -> Expr { + if keys.len() == 1 { + let key = keys.pop().unwrap(); + Expr::GetIndexedField { + expr: Box::new(expr), + key: plan_key(key), + } + } else { + // "table.column[key]..." + let key = keys.pop().unwrap(); + let expr = Box::new(plan_indexed(expr, keys)); + Expr::GetIndexedField { + expr, + key: plan_key(key), + } + } +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner pub fn new(schema_provider: &'a S) -> Self { @@ -1197,12 +1223,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } - SQLExpr::MapAccess { ref column, key } => { + SQLExpr::MapAccess { ref column, keys } => { if let SQLExpr::Identifier(ref id) = column.as_ref() { - Ok(Expr::GetIndexedField { - expr: Box::new(col(&id.value)), - key: key.to_string(), - }) + Ok(plan_indexed(col(&id.value), keys.clone())) } else { Err(DataFusionError::NotImplemented(format!( "map access requires an identifier, found column {} instead", diff --git a/datafusion/src/utils.rs b/datafusion/src/utils.rs index 7cf3ab0b9c1d..aab043c5e60e 100644 --- a/datafusion/src/utils.rs +++ b/datafusion/src/utils.rs @@ -18,24 +18,39 @@ use arrow::datatypes::{DataType, Field}; use crate::error::{DataFusionError, 
Result}; +use crate::scalar::ScalarValue; /// Returns the a field access indexed by `name` from a [`DataType::List`] or [`DataType::Dictionnary`]. /// # Error /// Errors if /// * the `data_type` is not a Struct or, /// * there is no field key is not of the required index type -pub fn get_indexed_field<'a>(data_type: &'a DataType, key: &str) -> Result { - match data_type { - DataType::Dictionary(ref kt, ref vt) => { +pub fn get_indexed_field<'a>( + data_type: &'a DataType, + key: &ScalarValue, +) -> Result { + match (data_type, key) { + (DataType::Dictionary(ref kt, ref vt), ScalarValue::Utf8(Some(k))) => { match kt.as_ref() { - DataType::Utf8 => Ok(Field::new(key, *vt.clone(), true)), + DataType::Utf8 => Ok(Field::new(&k, *vt.clone(), true)), _ => Err(DataFusionError::Plan(format!("The key for a dictionary has to be an utf8 string, was : \"{}\"", key))), } }, - DataType::List(lt) => match key.parse::() { - Ok(_) => Ok(Field::new(key, lt.data_type().clone(), false)), - Err(_) => Err(DataFusionError::Plan(format!("The key for a list has to be an integer, was : \"{}\"", key))), - }, + (DataType::Dictionary(_, _), _) => { + Err(DataFusionError::Plan( + "Only uf8 is valid as an indexed field in a dictionary" + .to_string(), + )) + } + (DataType::List(lt), ScalarValue::Int64(Some(i))) => { + Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) + } + (DataType::List(_), _) => { + Err(DataFusionError::Plan( + "Only ints are valid as an indexed field in a list" + .to_string(), + )) + } _ => Err(DataFusionError::Plan( "The expression to get an indexed field is only valid for `List` or 'Dictionary'" .to_string(), From 7b5aec612ed8c84f3f7dfef906447c21a7a92e02 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Wed, 15 Sep 2021 21:25:54 +0200 Subject: [PATCH 03/15] Compat with latest sqlparser --- datafusion/src/utils.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/datafusion/src/utils.rs b/datafusion/src/utils.rs index 
aab043c5e60e..20b160ecb497 100644 --- a/datafusion/src/utils.rs +++ b/datafusion/src/utils.rs @@ -20,15 +20,12 @@ use arrow::datatypes::{DataType, Field}; use crate::error::{DataFusionError, Result}; use crate::scalar::ScalarValue; -/// Returns the a field access indexed by `name` from a [`DataType::List`] or [`DataType::Dictionnary`]. +/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Dictionnary`]. /// # Error /// Errors if /// * the `data_type` is not a Struct or, /// * there is no field key is not of the required index type -pub fn get_indexed_field<'a>( - data_type: &'a DataType, - key: &ScalarValue, -) -> Result { +pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { match (data_type, key) { (DataType::Dictionary(ref kt, ref vt), ScalarValue::Utf8(Some(k))) => { match kt.as_ref() { From c0b0920f93f64ea7814bd0050937ce9543cc82ee Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Wed, 15 Sep 2021 21:57:57 +0200 Subject: [PATCH 04/15] Add two tests for indexed_field access, level one and level two nesting --- datafusion/tests/sql.rs | 85 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index f3dba3fc2ad1..b888c88445dd 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5305,3 +5305,88 @@ async fn case_with_bool_type_result() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn query_get_indexed_field() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new( + "some_list", + DataType::List(Box::new(Field::new("item", DataType::Int64, true))), + false, + )])); + let builder = PrimitiveBuilder::::new(3); + let mut lb = ListBuilder::new(builder); + for int_vec in vec![ + vec![0 as i64, 1, 2], + vec![4 as i64, 5, 6], + vec![7 as i64, 8, 9], + ] { + let builder = lb.values(); + for int in int_vec { + 
builder.append_value(int); + } + lb.append(true); + } + + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("ints", table_a)?; + + // Original column is micros, convert to millis and check timestamp + let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["4"], vec!["7"]]; + assert_eq!(expected, actual); + Ok(()) +} + +#[tokio::test] +async fn query_nested_get_indexed_field() -> Result<()> { + let mut ctx = ExecutionContext::new(); + let schema = Arc::new(Schema::new(vec![Field::new( + "some_list", + DataType::List(Box::new(Field::new( + "item", + DataType::List(Box::new(Field::new("item", DataType::Int64, true))), + true, + ))), + false, + )])); + let builder = PrimitiveBuilder::::new(3); + let nested_lb = ListBuilder::new(builder); + let mut lb = ListBuilder::new(nested_lb); + for int_vec_vec in vec![ + vec![vec![0 as i64, 1], vec![2, 3], vec![3, 4]], + vec![vec![5 as i64, 6], vec![7, 8], vec![9, 10]], + vec![vec![11 as i64, 12], vec![13, 14], vec![15, 16]], + ] { + let nested_builder = lb.values(); + for int_vec in int_vec_vec { + let mut builder = nested_builder.values(); + for int in int_vec { + builder.append_value(int); + } + nested_builder.append(true); + } + lb.append(true); + } + + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; + let table = MemTable::try_new(schema, vec![vec![data]])?; + let table_a = Arc::new(table); + + ctx.register_table("ints", table_a)?; + + // Original column is micros, convert to millis and check timestamp + let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["[0, 1]"], vec!["[5, 6]"], vec!["[11, 12]"]]; + assert_eq!(expected, actual); + let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3"; + 
let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["5"], vec!["11"]]; + assert_eq!(expected, actual); + Ok(()) +} From 60ba85849f26af826ae6a34b2090da4043d35a0e Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Mon, 20 Sep 2021 14:23:07 +0200 Subject: [PATCH 05/15] fix compilation issues --- datafusion/src/optimizer/common_subexpr_eliminate.rs | 4 ++++ datafusion/src/physical_plan/expressions/get_indexed_field.rs | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/src/optimizer/common_subexpr_eliminate.rs b/datafusion/src/optimizer/common_subexpr_eliminate.rs index 8d87b2218304..ea60286b902f 100644 --- a/datafusion/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/src/optimizer/common_subexpr_eliminate.rs @@ -442,6 +442,10 @@ impl ExprIdentifierVisitor<'_> { Expr::Wildcard => { desc.push_str("Wildcard-"); } + Expr::GetIndexedField { key, .. } => { + desc.push_str("GetIndexedField-"); + desc.push_str(&key.to_string()); + } } desc diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 77d700b1207e..d3cc872adf87 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -86,7 +86,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { array.as_any().downcast_ref::().unwrap(); let x: Vec> = as_list_array .iter() - .filter_map(|o| o.map(|list| list.slice(*i as usize, 1).clone())) + .filter_map(|o| o.map(|list| list.slice(*i as usize, 1))) .collect(); let vec = x.iter().map(|a| a.as_ref()).collect::>(); let iter = concat(vec.as_slice()).unwrap(); From c96cc6b2a5e9f58f5ecd03606581355deb2ab598 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Fri, 24 Sep 2021 14:07:57 +0200 Subject: [PATCH 06/15] try fixing dictionary lookup --- datafusion/Cargo.toml | 4 +- .../expressions/get_indexed_field.rs | 59 ++++++++++++++----- 
datafusion/src/utils.rs | 11 ++-- datafusion/tests/sql.rs | 44 ++++++++++---- 4 files changed, 87 insertions(+), 31 deletions(-) diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index e05bc0702bc0..17d5124b8bca 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -46,7 +46,7 @@ unicode_expressions = ["unicode-segmentation"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] # Used to enable the avro format -avro = ["avro-rs", "num-traits"] +avro = ["avro-rs"] [dependencies] ahash = "0.7" @@ -74,7 +74,7 @@ lazy_static = { version = "^1.4.0", optional = true } smallvec = { version = "1.6", features = ["union"] } rand = "0.8" avro-rs = { version = "0.13", features = ["snappy"], optional = true } -num-traits = { version = "0.2", optional = true } +num-traits = { version = "0.2" } [dev-dependencies] criterion = "0.3" diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index d3cc872adf87..f68cfb68eab1 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -26,6 +26,10 @@ use arrow::{ use crate::arrow::array::Array; use crate::arrow::compute::concat; +use crate::arrow::datatypes::{ + ArrowNativeType, Int16Type, Int32Type, Int64Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; use crate::scalar::ScalarValue; use crate::{ error::DataFusionError, @@ -33,8 +37,9 @@ use crate::{ physical_plan::{ColumnarValue, PhysicalExpr}, utils::get_indexed_field as get_data_type_field, }; -use arrow::array::{DictionaryArray, ListArray}; -use arrow::datatypes::Int8Type; +use arrow::array::{ArrayRef, DictionaryArray, ListArray}; +use arrow::datatypes::{ArrowPrimitiveType, Int8Type}; +use num_traits::ToPrimitive; use std::fmt::Debug; /// expression to get a field of a struct array. 
@@ -92,18 +97,22 @@ impl PhysicalExpr for GetIndexedFieldExpr { let iter = concat(vec.as_slice()).unwrap(); Ok(ColumnarValue::Array(iter)) } - (DataType::Dictionary(_, _), ScalarValue::Utf8(Some(s))) => { - let as_dict_array = array - .as_any() - .downcast_ref::>() - .unwrap(); - if let Some(index) = as_dict_array.lookup_key(s) { - Ok(ColumnarValue::Array(as_dict_array.slice(index as usize, 1))) - } else { - Err(DataFusionError::NotImplemented(format!( - "key not found in dictionnary : {}", - self.key - ))) + (DataType::Dictionary(ref kt, _), ScalarValue::Utf8(Some(s))) => { + match **kt { + DataType::Int8 => dict_lookup::(array, s), + DataType::Int16 => dict_lookup::(array, s), + DataType::Int32 => dict_lookup::(array, s), + DataType::Int64 => dict_lookup::(array, s), + DataType::UInt8 => dict_lookup::(array, s), + DataType::UInt16 => dict_lookup::(array, s), + DataType::UInt32 => dict_lookup::(array, s), + DataType::UInt64 => dict_lookup::(array, s), + _ => { + return Err(DataFusionError::NotImplemented( + "dictionary lookup only available for numeric keys" + .to_string(), + )) + } } } _ => Err(DataFusionError::NotImplemented( @@ -125,3 +134,25 @@ pub fn get_indexed_field( ) -> Result> { Ok(Arc::new(GetIndexedFieldExpr::new(arg, key))) } + +fn dict_lookup( + array: ArrayRef, + lookup: &str, +) -> Result +where + T::Native: num_traits::cast::ToPrimitive, +{ + let as_dict_array = array.as_any().downcast_ref::>().unwrap(); + if let Some(index) = as_dict_array.lookup_key(lookup) { + Ok(ColumnarValue::Array( + as_dict_array + .keys() + .slice(ToPrimitive::to_usize(&index).unwrap(), 1), + )) + } else { + Err(DataFusionError::NotImplemented(format!( + "key not found in dictionary for : {}", + lookup + ))) + } +} diff --git a/datafusion/src/utils.rs b/datafusion/src/utils.rs index 20b160ecb497..54cc5375ead7 100644 --- a/datafusion/src/utils.rs +++ b/datafusion/src/utils.rs @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field}; use crate::error::{DataFusionError, 
Result}; use crate::scalar::ScalarValue; -/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Dictionnary`]. +/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Dictionary`]. /// # Error /// Errors if /// * the `data_type` is not a Struct or, @@ -29,13 +29,16 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { match kt.as_ref() { - DataType::Utf8 => Ok(Field::new(&k, *vt.clone(), true)), - _ => Err(DataFusionError::Plan(format!("The key for a dictionary has to be an utf8 string, was : \"{}\"", key))), + DataType::Int8 | DataType::Int16 |DataType::Int32 |DataType::Int64 |DataType::UInt8 + | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(Field::new(&k, *kt.clone(), true)) + }, + _ => Err(DataFusionError::Plan(format!("The key for a dictionary has to be a primitive type, was : \"{}\"", key))), } }, (DataType::Dictionary(_, _), _) => { Err(DataFusionError::Plan( - "Only uf8 is valid as an indexed field in a dictionary" + "Only utf8 types are valid for dictionary lookup" .to_string(), )) } diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index b888c88445dd..1e7c0d4e65a3 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5345,15 +5345,20 @@ async fn query_get_indexed_field() -> Result<()> { #[tokio::test] async fn query_nested_get_indexed_field() -> Result<()> { let mut ctx = ExecutionContext::new(); - let schema = Arc::new(Schema::new(vec![Field::new( - "some_list", - DataType::List(Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int64, true))), - true, - ))), - false, - )])); + let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); + // Nested schema of { "some_list": [[i64]] } + let schema = Arc::new(Schema::new(vec![ + Field::new( + "some_list", + DataType::List(Box::new(Field::new("item", nested_dt.clone(), true))), + false, + ), + Field::new( + 
"some_dict", + DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8)), + false, + ), + ])); let builder = PrimitiveBuilder::::new(3); let nested_lb = ListBuilder::new(builder); let mut lb = ListBuilder::new(nested_lb); @@ -5372,8 +5377,21 @@ async fn query_nested_get_indexed_field() -> Result<()> { } lb.append(true); } - - let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; + let dictionary_values = StringArray::from(vec![Some("a"), Some("b"), Some("c")]); + let mut sb = StringDictionaryBuilder::new_with_dictionary( + PrimitiveBuilder::::new(3), + &dictionary_values, + ) + .unwrap(); + for s in vec!["b", "a", "c"] { + sb.append(s); + } + let array = sb.finish(); + eprintln!("array.keys() = {:?}", array.keys()); + let data = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(lb.finish()), Arc::new(array)], + )?; let table = MemTable::try_new(schema, vec![vec![data]])?; let table_a = Arc::new(table); @@ -5388,5 +5406,9 @@ async fn query_nested_get_indexed_field() -> Result<()> { let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0"], vec!["5"], vec!["11"]]; assert_eq!(expected, actual); + let sql = r#"SELECT some_dict["b"] as i0 FROM ints LIMIT 3"#; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0"], vec!["1"], vec!["0"]]; + assert_eq!(expected, actual); Ok(()) } From 16d190d13da37b32d5a86efc85ee237870dbfa56 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Fri, 24 Sep 2021 14:18:53 +0200 Subject: [PATCH 07/15] address clippy warnings --- datafusion/src/{utils.rs => field_util.rs} | 6 ++- datafusion/src/lib.rs | 2 +- datafusion/src/logical_plan/expr.rs | 2 +- .../expressions/get_indexed_field.rs | 15 +++----- datafusion/src/sql/planner.rs | 2 +- datafusion/tests/sql.rs | 38 +++++++++---------- 6 files changed, 30 insertions(+), 35 deletions(-) rename datafusion/src/{utils.rs => field_util.rs} (92%) diff --git a/datafusion/src/utils.rs b/datafusion/src/field_util.rs 
similarity index 92% rename from datafusion/src/utils.rs rename to datafusion/src/field_util.rs index 54cc5375ead7..795a5c92ec84 100644 --- a/datafusion/src/utils.rs +++ b/datafusion/src/field_util.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! Utility functions for complex field access + use arrow::datatypes::{DataType, Field}; use crate::error::{DataFusionError, Result}; @@ -27,11 +29,11 @@ use crate::scalar::ScalarValue; /// * there is no field key is not of the required index type pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { match (data_type, key) { - (DataType::Dictionary(ref kt, ref vt), ScalarValue::Utf8(Some(k))) => { + (DataType::Dictionary(ref kt, ref _vt), ScalarValue::Utf8(Some(k))) => { match kt.as_ref() { DataType::Int8 | DataType::Int16 |DataType::Int32 |DataType::Int64 |DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(Field::new(&k, *kt.clone(), true)) + Ok(Field::new(k, *kt.clone(), true)) }, _ => Err(DataFusionError::Plan(format!("The key for a dictionary has to be a primitive type, was : \"{}\"", key))), } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index 75e8b82f2772..b91aef9077f3 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -231,10 +231,10 @@ pub mod variable; pub use arrow; pub use parquet; +pub mod field_util; #[cfg(test)] pub mod test; pub mod test_util; -pub mod utils; #[macro_use] #[cfg(feature = "regex_expressions")] diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index dccfcb01cfc8..fe91ba0c3ec8 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -20,13 +20,13 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; +use crate::field_util::get_indexed_field; use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan}; use crate::physical_plan::functions::Volatility; use 
crate::physical_plan::{ aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF, window_functions, }; -use crate::utils::get_indexed_field; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::{compute::can_cast_types, datatypes::DataType}; diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index f68cfb68eab1..4f80b428f45f 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -27,15 +27,14 @@ use arrow::{ use crate::arrow::array::Array; use crate::arrow::compute::concat; use crate::arrow::datatypes::{ - ArrowNativeType, Int16Type, Int32Type, Int64Type, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + Int16Type, Int32Type, Int64Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use crate::scalar::ScalarValue; use crate::{ error::DataFusionError, error::Result, + field_util::get_indexed_field as get_data_type_field, physical_plan::{ColumnarValue, PhysicalExpr}, - utils::get_indexed_field as get_data_type_field, }; use arrow::array::{ArrayRef, DictionaryArray, ListArray}; use arrow::datatypes::{ArrowPrimitiveType, Int8Type}; @@ -107,12 +106,10 @@ impl PhysicalExpr for GetIndexedFieldExpr { DataType::UInt16 => dict_lookup::(array, s), DataType::UInt32 => dict_lookup::(array, s), DataType::UInt64 => dict_lookup::(array, s), - _ => { - return Err(DataFusionError::NotImplemented( - "dictionary lookup only available for numeric keys" - .to_string(), - )) - } + _ => Err(DataFusionError::NotImplemented( + "dictionary lookup only available for numeric keys" + .to_string(), + )), } } _ => Err(DataFusionError::NotImplemented( diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 7de534448c70..60d2da8be2c7 100644 --- a/datafusion/src/sql/planner.rs 
+++ b/datafusion/src/sql/planner.rs @@ -89,6 +89,7 @@ fn plan_key(key: Value) -> ScalarValue { } } +#[allow(clippy::branches_sharing_code)] fn plan_indexed(expr: Expr, mut keys: Vec) -> Expr { if keys.len() == 1 { let key = keys.pop().unwrap(); @@ -97,7 +98,6 @@ fn plan_indexed(expr: Expr, mut keys: Vec) -> Expr { key: plan_key(key), } } else { - // "table.column[key]..." let key = keys.pop().unwrap(); let expr = Box::new(plan_indexed(expr, keys)); Expr::GetIndexedField { diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 1e7c0d4e65a3..c60ca5e5def1 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5316,16 +5316,12 @@ async fn query_get_indexed_field() -> Result<()> { )])); let builder = PrimitiveBuilder::::new(3); let mut lb = ListBuilder::new(builder); - for int_vec in vec![ - vec![0 as i64, 1, 2], - vec![4 as i64, 5, 6], - vec![7 as i64, 8, 9], - ] { + for int_vec in vec![vec![0, 1, 2], vec![4, 5, 6], vec![7, 8, 9]] { let builder = lb.values(); for int in int_vec { - builder.append_value(int); + builder.append_value(int).unwrap(); } - lb.append(true); + lb.append(true).unwrap(); } let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; @@ -5363,34 +5359,34 @@ async fn query_nested_get_indexed_field() -> Result<()> { let nested_lb = ListBuilder::new(builder); let mut lb = ListBuilder::new(nested_lb); for int_vec_vec in vec![ - vec![vec![0 as i64, 1], vec![2, 3], vec![3, 4]], - vec![vec![5 as i64, 6], vec![7, 8], vec![9, 10]], - vec![vec![11 as i64, 12], vec![13, 14], vec![15, 16]], + vec![vec![0, 1], vec![2, 3], vec![3, 4]], + vec![vec![5, 6], vec![7, 8], vec![9, 10]], + vec![vec![11, 12], vec![13, 14], vec![15, 16]], ] { let nested_builder = lb.values(); for int_vec in int_vec_vec { - let mut builder = nested_builder.values(); + let builder = nested_builder.values(); for int in int_vec { - builder.append_value(int); + builder.append_value(int).unwrap(); } - nested_builder.append(true); + 
nested_builder.append(true).unwrap(); } - lb.append(true); + lb.append(true).unwrap(); } + let dictionary_values = StringArray::from(vec![Some("a"), Some("b"), Some("c")]); let mut sb = StringDictionaryBuilder::new_with_dictionary( PrimitiveBuilder::::new(3), &dictionary_values, ) .unwrap(); - for s in vec!["b", "a", "c"] { - sb.append(s); + for s in &["b", "a", "c"] { + sb.append(s).unwrap(); } - let array = sb.finish(); - eprintln!("array.keys() = {:?}", array.keys()); + let data = RecordBatch::try_new( schema.clone(), - vec![Arc::new(lb.finish()), Arc::new(array)], + vec![Arc::new(lb.finish()), Arc::new(sb.finish())], )?; let table = MemTable::try_new(schema, vec![vec![data]])?; let table_a = Arc::new(table); @@ -5406,9 +5402,9 @@ async fn query_nested_get_indexed_field() -> Result<()> { let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0"], vec!["5"], vec!["11"]]; assert_eq!(expected, actual); - let sql = r#"SELECT some_dict["b"] as i0 FROM ints LIMIT 3"#; + let sql = r#"SELECT some_dict["b"], some_dict["a"] FROM ints LIMIT 3"#; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0"], vec!["1"], vec!["0"]]; + let expected = vec![vec!["0"], vec!["1"]]; assert_eq!(expected, actual); Ok(()) } From f04507fa959f899cc88e1a42f54ce9070c12e7b5 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Fri, 24 Sep 2021 15:05:32 +0200 Subject: [PATCH 08/15] fix test --- datafusion/src/physical_plan/expressions/mod.rs | 2 +- datafusion/tests/sql.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index bcd326f43993..0ba545312446 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -33,7 +33,7 @@ mod case; mod cast; mod coercion; mod column; -mod count +mod count; mod cume_dist; mod get_indexed_field; mod in_list; diff --git a/datafusion/tests/sql.rs 
b/datafusion/tests/sql.rs index c60ca5e5def1..247cfdaba63e 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5404,7 +5404,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { assert_eq!(expected, actual); let sql = r#"SELECT some_dict["b"], some_dict["a"] FROM ints LIMIT 3"#; let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0"], vec!["1"]]; + let expected = vec![vec!["0", "1"]]; assert_eq!(expected, actual); Ok(()) } From fb018c02f18ee406a08a260d754a86ac4118e160 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Tue, 12 Oct 2021 10:40:45 +0200 Subject: [PATCH 09/15] Revert dictionary lookup for indexed fields --- datafusion/Cargo.toml | 4 +- datafusion/src/field_util.rs | 15 ------ .../expressions/get_indexed_field.rs | 48 +------------------ datafusion/tests/sql.rs | 37 +++----------- 4 files changed, 11 insertions(+), 93 deletions(-) diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 17d5124b8bca..e05bc0702bc0 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -46,7 +46,7 @@ unicode_expressions = ["unicode-segmentation"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] # Used to enable the avro format -avro = ["avro-rs"] +avro = ["avro-rs", "num-traits"] [dependencies] ahash = "0.7" @@ -74,7 +74,7 @@ lazy_static = { version = "^1.4.0", optional = true } smallvec = { version = "1.6", features = ["union"] } rand = "0.8" avro-rs = { version = "0.13", features = ["snappy"], optional = true } -num-traits = { version = "0.2" } +num-traits = { version = "0.2", optional = true } [dev-dependencies] criterion = "0.3" diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index 795a5c92ec84..f125f18c722b 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -29,21 +29,6 @@ use crate::scalar::ScalarValue; /// * there is no field key is not of the required index type pub fn 
get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { match (data_type, key) { - (DataType::Dictionary(ref kt, ref _vt), ScalarValue::Utf8(Some(k))) => { - match kt.as_ref() { - DataType::Int8 | DataType::Int16 |DataType::Int32 |DataType::Int64 |DataType::UInt8 - | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(Field::new(k, *kt.clone(), true)) - }, - _ => Err(DataFusionError::Plan(format!("The key for a dictionary has to be a primitive type, was : \"{}\"", key))), - } - }, - (DataType::Dictionary(_, _), _) => { - Err(DataFusionError::Plan( - "Only utf8 types are valid for dictionary lookup" - .to_string(), - )) - } (DataType::List(lt), ScalarValue::Int64(Some(i))) => { Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) } diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 4f80b428f45f..d5270b2cb152 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -26,9 +26,6 @@ use arrow::{ use crate::arrow::array::Array; use crate::arrow::compute::concat; -use crate::arrow::datatypes::{ - Int16Type, Int32Type, Int64Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; use crate::scalar::ScalarValue; use crate::{ error::DataFusionError, @@ -36,9 +33,7 @@ use crate::{ field_util::get_indexed_field as get_data_type_field, physical_plan::{ColumnarValue, PhysicalExpr}, }; -use arrow::array::{ArrayRef, DictionaryArray, ListArray}; -use arrow::datatypes::{ArrowPrimitiveType, Int8Type}; -use num_traits::ToPrimitive; +use arrow::array::ListArray; use std::fmt::Debug; /// expression to get a field of a struct array. 
@@ -96,25 +91,8 @@ impl PhysicalExpr for GetIndexedFieldExpr { let iter = concat(vec.as_slice()).unwrap(); Ok(ColumnarValue::Array(iter)) } - (DataType::Dictionary(ref kt, _), ScalarValue::Utf8(Some(s))) => { - match **kt { - DataType::Int8 => dict_lookup::(array, s), - DataType::Int16 => dict_lookup::(array, s), - DataType::Int32 => dict_lookup::(array, s), - DataType::Int64 => dict_lookup::(array, s), - DataType::UInt8 => dict_lookup::(array, s), - DataType::UInt16 => dict_lookup::(array, s), - DataType::UInt32 => dict_lookup::(array, s), - DataType::UInt64 => dict_lookup::(array, s), - _ => Err(DataFusionError::NotImplemented( - "dictionary lookup only available for numeric keys" - .to_string(), - )), - } - } _ => Err(DataFusionError::NotImplemented( - "get indexed field is only possible on dictionary and list" - .to_string(), + "get indexed field is only possible on lists".to_string(), )), }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( @@ -131,25 +109,3 @@ pub fn get_indexed_field( ) -> Result> { Ok(Arc::new(GetIndexedFieldExpr::new(arg, key))) } - -fn dict_lookup( - array: ArrayRef, - lookup: &str, -) -> Result -where - T::Native: num_traits::cast::ToPrimitive, -{ - let as_dict_array = array.as_any().downcast_ref::>().unwrap(); - if let Some(index) = as_dict_array.lookup_key(lookup) { - Ok(ColumnarValue::Array( - as_dict_array - .keys() - .slice(ToPrimitive::to_usize(&index).unwrap(), 1), - )) - } else { - Err(DataFusionError::NotImplemented(format!( - "key not found in dictionary for : {}", - lookup - ))) - } -} diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 247cfdaba63e..f1e988814add 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5343,18 +5343,12 @@ async fn query_nested_get_indexed_field() -> Result<()> { let mut ctx = ExecutionContext::new(); let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); // Nested schema of { "some_list": [[i64]] } - let schema = 
Arc::new(Schema::new(vec![ - Field::new( - "some_list", - DataType::List(Box::new(Field::new("item", nested_dt.clone(), true))), - false, - ), - Field::new( - "some_dict", - DataType::Dictionary(Box::new(DataType::Int64), Box::new(DataType::Utf8)), - false, - ), - ])); + let schema = Arc::new(Schema::new(vec![Field::new( + "some_list", + DataType::List(Box::new(Field::new("item", nested_dt.clone(), true))), + false, + )])); + let builder = PrimitiveBuilder::::new(3); let nested_lb = ListBuilder::new(builder); let mut lb = ListBuilder::new(nested_lb); @@ -5374,20 +5368,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { lb.append(true).unwrap(); } - let dictionary_values = StringArray::from(vec![Some("a"), Some("b"), Some("c")]); - let mut sb = StringDictionaryBuilder::new_with_dictionary( - PrimitiveBuilder::::new(3), - &dictionary_values, - ) - .unwrap(); - for s in &["b", "a", "c"] { - sb.append(s).unwrap(); - } - - let data = RecordBatch::try_new( - schema.clone(), - vec![Arc::new(lb.finish()), Arc::new(sb.finish())], - )?; + let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(lb.finish())])?; let table = MemTable::try_new(schema, vec![vec![data]])?; let table_a = Arc::new(table); @@ -5402,9 +5383,5 @@ async fn query_nested_get_indexed_field() -> Result<()> { let actual = execute(&mut ctx, sql).await; let expected = vec![vec!["0"], vec!["5"], vec!["11"]]; assert_eq!(expected, actual); - let sql = r#"SELECT some_dict["b"], some_dict["a"] FROM ints LIMIT 3"#; - let actual = execute(&mut ctx, sql).await; - let expected = vec![vec!["0", "1"]]; - assert_eq!(expected, actual); Ok(()) } From a630bcd05e3f18dac62a3dd7a71e4ae835ccfcc4 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Tue, 26 Oct 2021 08:15:13 +0200 Subject: [PATCH 10/15] Reject negative ints when accessing list values in get indexed field --- datafusion/src/field_util.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/src/field_util.rs 
b/datafusion/src/field_util.rs index f125f18c722b..e95561263c91 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -30,7 +30,13 @@ use crate::scalar::ScalarValue; pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { match (data_type, key) { (DataType::List(lt), ScalarValue::Int64(Some(i))) => { - Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) + if *i < 0 { + Err(DataFusionError::Plan( + format!("List based indexed access requires a positive int, was {0}", i), + )) + } else { + Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) + } } (DataType::List(_), _) => { Err(DataFusionError::Plan( From f0dfe946ae6f597975931b239911a88bae10a9e5 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Tue, 26 Oct 2021 08:15:37 +0200 Subject: [PATCH 11/15] Fix doc in get_indexed_field --- datafusion/src/field_util.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index e95561263c91..c34233972c0e 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field}; use crate::error::{DataFusionError, Result}; use crate::scalar::ScalarValue; -/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Dictionary`]. 
+/// Returns the field access indexed by `key` from a [`DataType::List`] /// # Error /// Errors if /// * the `data_type` is not a Struct or, From 53d361bfb7a09021e68df0c41e23c549073aa1e6 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Tue, 26 Oct 2021 08:21:10 +0200 Subject: [PATCH 12/15] use GetIndexedFieldExpr directly --- .../expressions/get_indexed_field.rs | 8 -------- .../src/physical_plan/expressions/mod.rs | 2 +- datafusion/src/physical_plan/planner.rs | 20 ++++++++++++++----- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index d5270b2cb152..fe2a930a7946 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -101,11 +101,3 @@ impl PhysicalExpr for GetIndexedFieldExpr { } } } - -/// Create a `.[field]` expression -pub fn get_indexed_field( - arg: Arc, - key: ScalarValue, -) -> Result> { - Ok(Arc::new(GetIndexedFieldExpr::new(arg, key))) -} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 0ba545312446..24a61d2b8483 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -67,7 +67,7 @@ pub use cast::{ pub use column::{col, Column}; pub use count::Count; pub use cume_dist::cume_dist; -pub use get_indexed_field::{get_indexed_field, GetIndexedFieldExpr}; +pub use get_indexed_field::GetIndexedFieldExpr; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index f58627d56cf4..fd0421b5e0ba 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -32,7 +32,9 @@ use 
crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions; -use crate::physical_plan::expressions::{CaseExpr, Column, Literal, PhysicalSortExpr}; +use crate::physical_plan::expressions::{ + CaseExpr, Column, GetIndexedFieldExpr, Literal, PhysicalSortExpr, +}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use crate::physical_plan::hash_join::HashJoinExec; @@ -993,10 +995,18 @@ impl DefaultPhysicalPlanner { Expr::IsNotNull(expr) => expressions::is_not_null( self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, ), - Expr::GetIndexedField { expr, key } => expressions::get_indexed_field( - self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?, - key.clone(), - ), + Expr::GetIndexedField { expr, key } => { + Ok(Arc::new(GetIndexedFieldExpr::new( + self.create_physical_expr( + expr, + input_dfschema, + input_schema, + ctx_state, + )?, + key.clone(), + ))) + } + Expr::ScalarFunction { fun, args } => { let physical_args = args .iter() From b2012b6b456842fda30132f0725a82de739b1db3 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Fri, 29 Oct 2021 10:46:55 +0200 Subject: [PATCH 13/15] return the data type in unavailable field indexation error message --- .../src/physical_plan/expressions/get_indexed_field.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index fe2a930a7946..25ec41cfbd79 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -91,9 +91,10 @@ impl PhysicalExpr for GetIndexedFieldExpr { let iter = concat(vec.as_slice()).unwrap(); Ok(ColumnarValue::Array(iter)) } - 
_ => Err(DataFusionError::NotImplemented( - "get indexed field is only possible on lists".to_string(), - )), + (dt, _) => Err(DataFusionError::NotImplemented(format!( + "get indexed field is not implemented for {}", + dt + ))), }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( "field is not yet implemented for scalar values".to_string(), From 24bac8ba3ecb6865c51d97688e385eba1b70c59f Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Fri, 29 Oct 2021 13:41:10 +0200 Subject: [PATCH 14/15] Add unit tests for the physical plan of get_indexed_field --- datafusion/src/field_util.rs | 18 +-- datafusion/src/lib.rs | 2 +- datafusion/src/logical_plan/expr.rs | 2 +- .../expressions/get_indexed_field.rs | 137 +++++++++++++++++- 4 files changed, 139 insertions(+), 20 deletions(-) diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs index c34233972c0e..9d5facebc0c1 100644 --- a/datafusion/src/field_util.rs +++ b/datafusion/src/field_util.rs @@ -31,21 +31,19 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { if *i < 0 { - Err(DataFusionError::Plan( - format!("List based indexed access requires a positive int, was {0}", i), - )) + Err(DataFusionError::Plan(format!( + "List based indexed access requires a positive int, was {0}", + i + ))) } else { Ok(Field::new(&i.to_string(), lt.data_type().clone(), false)) } } - (DataType::List(_), _) => { - Err(DataFusionError::Plan( - "Only ints are valid as an indexed field in a list" - .to_string(), - )) - } + (DataType::List(_), _) => Err(DataFusionError::Plan( + "Only ints are valid as an indexed field in a list".to_string(), + )), _ => Err(DataFusionError::Plan( - "The expression to get an indexed field is only valid for `List` or 'Dictionary'" + "The expression to get an indexed field is only valid for `List` types" .to_string(), )), } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index b91aef9077f3..fa140cee38f9 100644 --- a/datafusion/src/lib.rs +++ 
b/datafusion/src/lib.rs @@ -231,7 +231,7 @@ pub mod variable; pub use arrow; pub use parquet; -pub mod field_util; +pub(crate) mod field_util; #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index fe91ba0c3ec8..499a8c720dba 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -246,7 +246,7 @@ pub enum Expr { IsNull(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), - /// Returns the field of a [`ListArray`] or ['DictionaryArray'] by name + /// Returns the field of a [`ListArray`] by key GetIndexedField { /// the expression to take the field from expr: Box, diff --git a/datafusion/src/physical_plan/expressions/get_indexed_field.rs b/datafusion/src/physical_plan/expressions/get_indexed_field.rs index 25ec41cfbd79..8a9191e9c346 100644 --- a/datafusion/src/physical_plan/expressions/get_indexed_field.rs +++ b/datafusion/src/physical_plan/expressions/get_indexed_field.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! get field of a struct array +//! 
get field of a `ListArray` +use std::convert::TryInto; use std::{any::Any, sync::Arc}; use arrow::{ @@ -80,25 +81,145 @@ impl PhysicalExpr for GetIndexedFieldExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => match (array.data_type(), &self.key) { + (DataType::List(_), _) if self.key.is_null() => { + let scalar_null: ScalarValue = array.data_type().try_into()?; + Ok(ColumnarValue::Scalar(scalar_null)) + } (DataType::List(_), ScalarValue::Int64(Some(i))) => { let as_list_array = array.as_any().downcast_ref::().unwrap(); - let x: Vec> = as_list_array + if as_list_array.is_empty() { + let scalar_null: ScalarValue = array.data_type().try_into()?; + return Ok(ColumnarValue::Scalar(scalar_null)) + } + let sliced_array: Vec> = as_list_array .iter() .filter_map(|o| o.map(|list| list.slice(*i as usize, 1))) .collect(); - let vec = x.iter().map(|a| a.as_ref()).collect::>(); + let vec = sliced_array.iter().map(|a| a.as_ref()).collect::>(); let iter = concat(vec.as_slice()).unwrap(); Ok(ColumnarValue::Array(iter)) } - (dt, _) => Err(DataFusionError::NotImplemented(format!( - "get indexed field is not implemented for {}", - dt - ))), + (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. 
Tried {} with {} index", dt, key))), }, ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented( - "field is not yet implemented for scalar values".to_string(), + "field access is not yet implemented for scalar values".to_string(), )), } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::physical_plan::expressions::{col, lit}; + use arrow::array::{ListBuilder, StringBuilder}; + use arrow::{array::StringArray, datatypes::Field}; + + fn get_indexed_field_test( + list_of_lists: Vec>>, + index: i64, + expected: Vec>, + ) -> Result<()> { + let schema = list_schema("l"); + let builder = StringBuilder::new(3); + let mut lb = ListBuilder::new(builder); + for values in list_of_lists { + let builder = lb.values(); + for value in values { + match value { + None => builder.append_null(), + Some(v) => builder.append_value(v), + } + .unwrap() + } + lb.append(true).unwrap(); + } + + let expr = col("l", &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + + let key = ScalarValue::Int64(Some(index)); + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + let result = result + .as_any() + .downcast_ref::() + .expect("failed to downcast to StringArray"); + let expected = &StringArray::from(expected); + assert_eq!(expected, result); + Ok(()) + } + + fn list_schema(col: &str) -> Schema { + Schema::new(vec![Field::new( + col, + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), + true, + )]) + } + + #[test] + fn get_indexed_field_list() -> Result<()> { + let list_of_lists = vec![ + vec![Some("a"), Some("b"), None], + vec![None, Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], + ]; + let expected_list = vec![ + vec![Some("a"), None, Some("e")], + vec![Some("b"), Some("c"), None], + vec![None, Some("d"), Some("f")], + ]; + + for (i, expected) in expected_list.into_iter().enumerate() { + 
get_indexed_field_test(list_of_lists.clone(), i as i64, expected)?; + } + Ok(()) + } + + #[test] + fn get_indexed_field_empty_list() -> Result<()> { + let schema = list_schema("l"); + let builder = StringBuilder::new(0); + let mut lb = ListBuilder::new(builder); + let expr = col("l", &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + let key = ScalarValue::Int64(Some(0)); + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + assert!(result.is_empty()); + Ok(()) + } + + fn get_indexed_field_test_failure( + schema: Schema, + expr: Arc, + key: ScalarValue, + expected: &str, + ) -> Result<()> { + let builder = StringBuilder::new(3); + let mut lb = ListBuilder::new(builder); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); + let r = expr.evaluate(&batch).map(|_| ()); + assert!(r.is_err()); + assert_eq!(format!("{}", r.unwrap_err()), expected); + Ok(()) + } + + #[test] + fn get_indexed_field_invalid_scalar() -> Result<()> { + let schema = list_schema("l"); + let expr = lit(ScalarValue::Utf8(Some("a".to_string()))); + get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is not yet implemented for scalar values") + } + + #[test] + fn get_indexed_field_invalid_list_index() -> Result<()> { + let schema = list_schema("l"); + let expr = col("l", &schema).unwrap(); + get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. 
Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") + } +} From 5d61ce58eb29913c88d9c2065267862591120d46 Mon Sep 17 00:00:00 2001 From: Guillaume Balaine Date: Fri, 29 Oct 2021 19:16:23 +0200 Subject: [PATCH 15/15] Fix missing clause for const evaluator --- datafusion/src/optimizer/utils.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 39b68a980133..f36330eac05e 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -656,6 +656,7 @@ impl ConstEvaluator { Expr::Cast { .. } => true, Expr::TryCast { .. } => true, Expr::InList { .. } => true, + Expr::GetIndexedField { .. } => true, } }