|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -//! get field of a struct array |
| 18 | +//! get field of a `ListArray` |
19 | 19 |
|
| 20 | +use std::convert::TryInto; |
20 | 21 | use std::{any::Any, sync::Arc};
|
21 | 22 |
|
22 | 23 | use arrow::{
|
@@ -80,25 +81,145 @@ impl PhysicalExpr for GetIndexedFieldExpr {
|
80 | 81 | let arg = self.arg.evaluate(batch)?;
|
81 | 82 | match arg {
|
82 | 83 | ColumnarValue::Array(array) => match (array.data_type(), &self.key) {
|
| 84 | + (DataType::List(_), _) if self.key.is_null() => { |
| 85 | + let scalar_null: ScalarValue = array.data_type().try_into()?; |
| 86 | + Ok(ColumnarValue::Scalar(scalar_null)) |
| 87 | + } |
83 | 88 | (DataType::List(_), ScalarValue::Int64(Some(i))) => {
|
84 | 89 | let as_list_array =
|
85 | 90 | array.as_any().downcast_ref::<ListArray>().unwrap();
|
86 |
| - let x: Vec<Arc<dyn Array>> = as_list_array |
| 91 | + if as_list_array.is_empty() { |
| 92 | + let scalar_null: ScalarValue = array.data_type().try_into()?; |
| 93 | + return Ok(ColumnarValue::Scalar(scalar_null)) |
| 94 | + } |
| 95 | + let sliced_array: Vec<Arc<dyn Array>> = as_list_array |
87 | 96 | .iter()
|
88 | 97 | .filter_map(|o| o.map(|list| list.slice(*i as usize, 1)))
|
89 | 98 | .collect();
|
90 |
| - let vec = x.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>(); |
| 99 | + let vec = sliced_array.iter().map(|a| a.as_ref()).collect::<Vec<&dyn Array>>(); |
91 | 100 | let iter = concat(vec.as_slice()).unwrap();
|
92 | 101 | Ok(ColumnarValue::Array(iter))
|
93 | 102 | }
|
94 |
| - (dt, _) => Err(DataFusionError::NotImplemented(format!( |
95 |
| - "get indexed field is not implemented for {}", |
96 |
| - dt |
97 |
| - ))), |
| 103 | + (dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))), |
98 | 104 | },
|
99 | 105 | ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
|
100 |
| - "field is not yet implemented for scalar values".to_string(), |
| 106 | + "field access is not yet implemented for scalar values".to_string(), |
101 | 107 | )),
|
102 | 108 | }
|
103 | 109 | }
|
104 | 110 | }
|
| 111 | + |
| 112 | +#[cfg(test)] |
| 113 | +mod tests { |
| 114 | + use super::*; |
| 115 | + use crate::error::Result; |
| 116 | + use crate::physical_plan::expressions::{col, lit}; |
| 117 | + use arrow::array::{ListBuilder, StringBuilder}; |
| 118 | + use arrow::{array::StringArray, datatypes::Field}; |
| 119 | + |
| 120 | + fn get_indexed_field_test( |
| 121 | + list_of_lists: Vec<Vec<Option<&str>>>, |
| 122 | + index: i64, |
| 123 | + expected: Vec<Option<&str>>, |
| 124 | + ) -> Result<()> { |
| 125 | + let schema = list_schema("l"); |
| 126 | + let builder = StringBuilder::new(3); |
| 127 | + let mut lb = ListBuilder::new(builder); |
| 128 | + for values in list_of_lists { |
| 129 | + let builder = lb.values(); |
| 130 | + for value in values { |
| 131 | + match value { |
| 132 | + None => builder.append_null(), |
| 133 | + Some(v) => builder.append_value(v), |
| 134 | + } |
| 135 | + .unwrap() |
| 136 | + } |
| 137 | + lb.append(true).unwrap(); |
| 138 | + } |
| 139 | + |
| 140 | + let expr = col("l", &schema).unwrap(); |
| 141 | + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; |
| 142 | + |
| 143 | + let key = ScalarValue::Int64(Some(index)); |
| 144 | + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); |
| 145 | + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); |
| 146 | + let result = result |
| 147 | + .as_any() |
| 148 | + .downcast_ref::<StringArray>() |
| 149 | + .expect("failed to downcast to StringArray"); |
| 150 | + let expected = &StringArray::from(expected); |
| 151 | + assert_eq!(expected, result); |
| 152 | + Ok(()) |
| 153 | + } |
| 154 | + |
| 155 | + fn list_schema(col: &str) -> Schema { |
| 156 | + Schema::new(vec![Field::new( |
| 157 | + col, |
| 158 | + DataType::List(Box::new(Field::new("item", DataType::Utf8, true))), |
| 159 | + true, |
| 160 | + )]) |
| 161 | + } |
| 162 | + |
| 163 | + #[test] |
| 164 | + fn get_indexed_field_list() -> Result<()> { |
| 165 | + let list_of_lists = vec![ |
| 166 | + vec![Some("a"), Some("b"), None], |
| 167 | + vec![None, Some("c"), Some("d")], |
| 168 | + vec![Some("e"), None, Some("f")], |
| 169 | + ]; |
| 170 | + let expected_list = vec![ |
| 171 | + vec![Some("a"), None, Some("e")], |
| 172 | + vec![Some("b"), Some("c"), None], |
| 173 | + vec![None, Some("d"), Some("f")], |
| 174 | + ]; |
| 175 | + |
| 176 | + for (i, expected) in expected_list.into_iter().enumerate() { |
| 177 | + get_indexed_field_test(list_of_lists.clone(), i as i64, expected)?; |
| 178 | + } |
| 179 | + Ok(()) |
| 180 | + } |
| 181 | + |
| 182 | + #[test] |
| 183 | + fn get_indexed_field_empty_list() -> Result<()> { |
| 184 | + let schema = list_schema("l"); |
| 185 | + let builder = StringBuilder::new(0); |
| 186 | + let mut lb = ListBuilder::new(builder); |
| 187 | + let expr = col("l", &schema).unwrap(); |
| 188 | + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; |
| 189 | + let key = ScalarValue::Int64(Some(0)); |
| 190 | + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); |
| 191 | + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); |
| 192 | + assert!(result.is_empty()); |
| 193 | + Ok(()) |
| 194 | + } |
| 195 | + |
| 196 | + fn get_indexed_field_test_failure( |
| 197 | + schema: Schema, |
| 198 | + expr: Arc<dyn PhysicalExpr>, |
| 199 | + key: ScalarValue, |
| 200 | + expected: &str, |
| 201 | + ) -> Result<()> { |
| 202 | + let builder = StringBuilder::new(3); |
| 203 | + let mut lb = ListBuilder::new(builder); |
| 204 | + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; |
| 205 | + let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); |
| 206 | + let r = expr.evaluate(&batch).map(|_| ()); |
| 207 | + assert!(r.is_err()); |
| 208 | + assert_eq!(format!("{}", r.unwrap_err()), expected); |
| 209 | + Ok(()) |
| 210 | + } |
| 211 | + |
| 212 | + #[test] |
| 213 | + fn get_indexed_field_invalid_scalar() -> Result<()> { |
| 214 | + let schema = list_schema("l"); |
| 215 | + let expr = lit(ScalarValue::Utf8(Some("a".to_string()))); |
| 216 | + get_indexed_field_test_failure(schema, expr, ScalarValue::Int64(Some(0)), "This feature is not implemented: field access is not yet implemented for scalar values") |
| 217 | + } |
| 218 | + |
| 219 | + #[test] |
| 220 | + fn get_indexed_field_invalid_list_index() -> Result<()> { |
| 221 | + let schema = list_schema("l"); |
| 222 | + let expr = col("l", &schema).unwrap(); |
| 223 | + get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index") |
| 224 | + } |
| 225 | +} |
0 commit comments