Skip to content

Commit 24bac8b

Browse files
committed
Add unit tests for the physical plan of get_indexed_field
1 parent b2012b6 commit 24bac8b

File tree

4 files changed

+139
-20
lines changed

4 files changed

+139
-20
lines changed

datafusion/src/field_util.rs

+8 -10
Original file line number | Diff line number | Diff line change
@@ -31,21 +31,19 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Fiel
3131
match (data_type, key) {
3232
(DataType::List(lt), ScalarValue::Int64(Some(i))) => {
3333
if *i < 0 {
34-
Err(DataFusionError::Plan(
35-
format!("List based indexed access requires a positive int, was {0}", i),
36-
))
34+
Err(DataFusionError::Plan(format!(
35+
"List based indexed access requires a positive int, was {0}",
36+
i
37+
)))
3738
} else {
3839
Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
3940
}
4041
}
41-
(DataType::List(_), _) => {
42-
Err(DataFusionError::Plan(
43-
"Only ints are valid as an indexed field in a list"
44-
.to_string(),
45-
))
46-
}
42+
(DataType::List(_), _) => Err(DataFusionError::Plan(
43+
"Only ints are valid as an indexed field in a list".to_string(),
44+
)),
4745
_ => Err(DataFusionError::Plan(
48-
"The expression to get an indexed field is only valid for `List` or 'Dictionary'"
46+
"The expression to get an indexed field is only valid for `List` types"
4947
.to_string(),
5048
)),
5149
}

datafusion/src/lib.rs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -231,7 +231,7 @@ pub mod variable;
231231
pub use arrow;
232232
pub use parquet;
233233

234-
pub mod field_util;
234+
pub(crate) mod field_util;
235235
#[cfg(test)]
236236
pub mod test;
237237
pub mod test_util;

datafusion/src/logical_plan/expr.rs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -246,7 +246,7 @@ pub enum Expr {
246246
IsNull(Box<Expr>),
247247
/// arithmetic negation of an expression, the operand must be of a signed numeric data type
248248
Negative(Box<Expr>),
249-
/// Returns the field of a [`ListArray`] or ['DictionaryArray'] by name
249+
/// Returns the field of a [`ListArray`] by key
250250
GetIndexedField {
251251
/// the expression to take the field from
252252
expr: Box<Expr>,

datafusion/src/physical_plan/expressions/get_indexed_field.rs

+129 -8
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! get field of a struct array
18+
//! get field of a `ListArray`
1919
20+
use std::convert::TryInto;
2021
use std::{any::Any, sync::Arc};
2122

2223
use arrow::{
@@ -80,25 +81,145 @@ impl PhysicalExpr for GetIndexedFieldExpr {
8081
let arg = self.arg.evaluate(batch)?;
8182
match arg {
8283
ColumnarValue::Array(array) => match (array.data_type(), &self.key) {
84+
(DataType::List(_), _) if self.key.is_null() => {
85+
let scalar_null: ScalarValue = array.data_type().try_into()?;
86+
Ok(ColumnarValue::Scalar(scalar_null))
87+
}
8388
(DataType::List(_), ScalarValue::Int64(Some(i))) => {
8489
let as_list_array =
8590
array.as_any().downcast_ref::<ListArray>().unwrap();
86-
let x: Vec<Arc<dyn Array>> = as_list_array
91+
if as_list_array.is_empty() {
92+
let scalar_null: ScalarValue = array.data_type().try_into()?;
93+
return Ok(ColumnarValue::Scalar(scalar_null))
94+
}
95+
let sliced_array: Vec<Arc<dyn Array>> = as_list_array
8796
.iter()
8897
.filter_map(|o| o.map(|list| list.slice(*i as usize, 1)))
8998
.collect();
90-
let vec = x.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>();
99+
let vec = sliced_array.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>();
91100
let iter = concat(vec.as_slice()).unwrap();
92101
Ok(ColumnarValue::Array(iter))
93102
}
94-
(dt, _) => Err(DataFusionError::NotImplemented(format!(
95-
"get indexed field is not implemented for {}",
96-
dt
97-
))),
103+
(dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))),
98104
},
99105
ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
100-
"field is not yet implemented for scalar values".to_string(),
106+
"field access is not yet implemented for scalar values".to_string(),
101107
)),
102108
}
103109
}
104110
}
111+
112+
#[cfg(test)]
113+
mod tests {
114+
use super::*;
115+
use crate::error::Result;
116+
use crate::physical_plan::expressions::{col, lit};
117+
use arrow::array::{ListBuilder, StringBuilder};
118+
use arrow::{array::StringArray, datatypes::Field};
119+
120+
fn get_indexed_field_test(
121+
list_of_lists: Vec<Vec<Option<&str>>>,
122+
index: i64,
123+
expected: Vec<Option<&str>>,
124+
) -> Result<()> {
125+
let schema = list_schema("l");
126+
let builder = StringBuilder::new(3);
127+
let mut lb = ListBuilder::new(builder);
128+
for values in list_of_lists {
129+
let builder = lb.values();
130+
for value in values {
131+
match value {
132+
None => builder.append_null(),
133+
Some(v) => builder.append_value(v),
134+
}
135+
.unwrap()
136+
}
137+
lb.append(true).unwrap();
138+
}
139+
140+
let expr = col("l", &schema).unwrap();
141+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
142+
143+
let key = ScalarValue::Int64(Some(index));
144+
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
145+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
146+
let result = result
147+
.as_any()
148+
.downcast_ref::<StringArray>()
149+
.expect("failed to downcast to StringArray");
150+
let expected = &StringArray::from(expected);
151+
assert_eq!(expected, result);
152+
Ok(())
153+
}
154+
155+
fn list_schema(col: &str) -> Schema {
156+
Schema::new(vec![Field::new(
157+
col,
158+
DataType::List(Box::new(Field::new("item", DataType::Utf8, true))),
159+
true,
160+
)])
161+
}
162+
163+
#[test]
164+
fn get_indexed_field_list() -> Result<()> {
165+
let list_of_lists = vec![
166+
vec![Some("a"), Some("b"), None],
167+
vec![None, Some("c"), Some("d")],
168+
vec![Some("e"), None, Some("f")],
169+
];
170+
let expected_list = vec![
171+
vec![Some("a"), None, Some("e")],
172+
vec![Some("b"), Some("c"), None],
173+
vec![None, Some("d"), Some("f")],
174+
];
175+
176+
for (i, expected) in expected_list.into_iter().enumerate() {
177+
get_indexed_field_test(list_of_lists.clone(), i as i64, expected)?;
178+
}
179+
Ok(())
180+
}
181+
182+
#[test]
183+
fn get_indexed_field_empty_list() -> Result<()> {
184+
let schema = list_schema("l");
185+
let builder = StringBuilder::new(0);
186+
let mut lb = ListBuilder::new(builder);
187+
let expr = col("l", &schema).unwrap();
188+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
189+
let key = ScalarValue::Int64(Some(0));
190+
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
191+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
192+
assert!(result.is_empty());
193+
Ok(())
194+
}
195+
196+
fn get_indexed_field_test_failure(
197+
schema: Schema,
198+
expr: Arc<dyn PhysicalExpr>,
199+
key: ScalarValue,
200+
expected: &str,
201+
) -> Result<()> {
202+
let builder = StringBuilder::new(3);
203+
let mut lb = ListBuilder::new(builder);
204+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
205+
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
206+
let r = expr.evaluate(&batch).map(|_| ());
207+
assert!(r.is_err());
208+
assert_eq!(format!("{}", r.unwrap_err()), expected);
209+
Ok(())
210+
}
211+
212+
#[test]
213+
fn get_indexed_field_invalid_scalar() -> Result<()> {
214+
let schema = list_schema("l");
215+
let expr = lit(ScalarValue::Utf8(Some("a".to_string())));
216+
get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is not yet implemented for scalar values")
217+
}
218+
219+
#[test]
220+
fn get_indexed_field_invalid_list_index() -> Result<()> {
221+
let schema = list_schema("l");
222+
let expr = col("l", &schema).unwrap();
223+
get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index")
224+
}
225+
}

0 commit comments

Comments
 (0)