Skip to content

Commit 607ee7d

Browse files
authored
feat: Add GetStructField expression (apache#731)
* Add GetStructField support
* Add custom types to CometBatchScanExec
* Remove test explain
* Rust fmt
* Fix struct type support checks
* Support converting StructArray to native
* Fix style
* Attempt to fix scalar subquery issue
* Fix other unit test
* Cleanup
* Default query plan supporting complex type to false
* Migrate struct expressions to spark-expr
* Update shouldApplyRowToColumnar comment
* Add nulls to test
* Rename to allowStruct
* Add DataTypeSupport trait
* Fix parquet datatype test
1 parent 33f1ce9 commit 607ee7d

File tree

2 files changed

+293
-0
lines changed

2 files changed

+293
-0
lines changed

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ mod kernels;
2323
mod regexp;
2424
pub mod scalar_funcs;
2525
pub mod spark_hash;
26+
mod structs;
2627
mod temporal;
2728
pub mod timezone;
2829
pub mod utils;
@@ -32,6 +33,7 @@ pub use cast::{spark_cast, Cast};
3233
pub use error::{SparkError, SparkResult};
3334
pub use if_expr::IfExpr;
3435
pub use regexp::RLike;
36+
pub use structs::{CreateNamedStruct, GetStructField};
3537
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
3638

3739
/// Spark supports three evaluation modes when evaluating expressions, which affect

src/structs.rs

+291
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::compute::take;
19+
use arrow::record_batch::RecordBatch;
20+
use arrow_array::types::Int32Type;
21+
use arrow_array::{Array, DictionaryArray, StructArray};
22+
use arrow_schema::{DataType, Field, Schema};
23+
use datafusion::logical_expr::ColumnarValue;
24+
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
25+
use datafusion_physical_expr::PhysicalExpr;
26+
use std::{
27+
any::Any,
28+
fmt::{Display, Formatter},
29+
hash::{Hash, Hasher},
30+
sync::Arc,
31+
};
32+
33+
use crate::utils::down_cast_any_ref;
34+
35+
#[derive(Debug, Hash)]
36+
pub struct CreateNamedStruct {
37+
values: Vec<Arc<dyn PhysicalExpr>>,
38+
data_type: DataType,
39+
}
40+
41+
impl CreateNamedStruct {
42+
pub fn new(values: Vec<Arc<dyn PhysicalExpr>>, data_type: DataType) -> Self {
43+
Self { values, data_type }
44+
}
45+
}
46+
47+
impl PhysicalExpr for CreateNamedStruct {
48+
fn as_any(&self) -> &dyn Any {
49+
self
50+
}
51+
52+
fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
53+
Ok(self.data_type.clone())
54+
}
55+
56+
fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
57+
Ok(false)
58+
}
59+
60+
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
61+
let values = self
62+
.values
63+
.iter()
64+
.map(|expr| expr.evaluate(batch))
65+
.collect::<datafusion_common::Result<Vec<_>>>()?;
66+
let arrays = ColumnarValue::values_to_arrays(&values)?;
67+
// TODO it would be more efficient if we could preserve dictionaries within the
68+
// struct array but for now we unwrap them to avoid runtime errors
69+
// https://github.com/apache/datafusion-comet/issues/755
70+
let arrays = arrays
71+
.iter()
72+
.map(|array| {
73+
if let Some(dict_array) =
74+
array.as_any().downcast_ref::<DictionaryArray<Int32Type>>()
75+
{
76+
take(dict_array.values().as_ref(), dict_array.keys(), None)
77+
} else {
78+
Ok(Arc::clone(array))
79+
}
80+
})
81+
.collect::<Result<Vec<_>, _>>()?;
82+
let fields = match &self.data_type {
83+
DataType::Struct(fields) => fields,
84+
_ => {
85+
return Err(DataFusionError::Internal(format!(
86+
"Expected struct data type, got {:?}",
87+
self.data_type
88+
)))
89+
}
90+
};
91+
Ok(ColumnarValue::Array(Arc::new(StructArray::new(
92+
fields.clone(),
93+
arrays,
94+
None,
95+
))))
96+
}
97+
98+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
99+
self.values.iter().collect()
100+
}
101+
102+
fn with_new_children(
103+
self: Arc<Self>,
104+
children: Vec<Arc<dyn PhysicalExpr>>,
105+
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
106+
Ok(Arc::new(CreateNamedStruct::new(
107+
children.clone(),
108+
self.data_type.clone(),
109+
)))
110+
}
111+
112+
fn dyn_hash(&self, state: &mut dyn Hasher) {
113+
let mut s = state;
114+
self.values.hash(&mut s);
115+
self.data_type.hash(&mut s);
116+
self.hash(&mut s);
117+
}
118+
}
119+
120+
impl Display for CreateNamedStruct {
121+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122+
write!(
123+
f,
124+
"CreateNamedStruct [values: {:?}, data_type: {:?}]",
125+
self.values, self.data_type
126+
)
127+
}
128+
}
129+
130+
impl PartialEq<dyn Any> for CreateNamedStruct {
131+
fn eq(&self, other: &dyn Any) -> bool {
132+
down_cast_any_ref(other)
133+
.downcast_ref::<Self>()
134+
.map(|x| {
135+
self.values
136+
.iter()
137+
.zip(x.values.iter())
138+
.all(|(a, b)| a.eq(b))
139+
&& self.data_type.eq(&x.data_type)
140+
})
141+
.unwrap_or(false)
142+
}
143+
}
144+
145+
#[derive(Debug, Hash)]
146+
pub struct GetStructField {
147+
child: Arc<dyn PhysicalExpr>,
148+
ordinal: usize,
149+
}
150+
151+
impl GetStructField {
152+
pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
153+
Self { child, ordinal }
154+
}
155+
156+
fn child_field(&self, input_schema: &Schema) -> DataFusionResult<Arc<Field>> {
157+
match self.child.data_type(input_schema)? {
158+
DataType::Struct(fields) => Ok(fields[self.ordinal].clone()),
159+
data_type => Err(DataFusionError::Plan(format!(
160+
"Expect struct field, got {:?}",
161+
data_type
162+
))),
163+
}
164+
}
165+
}
166+
167+
impl PhysicalExpr for GetStructField {
168+
fn as_any(&self) -> &dyn Any {
169+
self
170+
}
171+
172+
fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
173+
Ok(self.child_field(input_schema)?.data_type().clone())
174+
}
175+
176+
fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
177+
Ok(self.child_field(input_schema)?.is_nullable())
178+
}
179+
180+
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
181+
let child_value = self.child.evaluate(batch)?;
182+
183+
match child_value {
184+
ColumnarValue::Array(array) => {
185+
let struct_array = array
186+
.as_any()
187+
.downcast_ref::<StructArray>()
188+
.expect("A struct is expected");
189+
190+
Ok(ColumnarValue::Array(
191+
struct_array.column(self.ordinal).clone(),
192+
))
193+
}
194+
ColumnarValue::Scalar(ScalarValue::Struct(struct_array)) => Ok(ColumnarValue::Array(
195+
struct_array.column(self.ordinal).clone(),
196+
)),
197+
value => Err(DataFusionError::Execution(format!(
198+
"Expected a struct array, got {:?}",
199+
value
200+
))),
201+
}
202+
}
203+
204+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
205+
vec![&self.child]
206+
}
207+
208+
fn with_new_children(
209+
self: Arc<Self>,
210+
children: Vec<Arc<dyn PhysicalExpr>>,
211+
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
212+
Ok(Arc::new(GetStructField::new(
213+
children[0].clone(),
214+
self.ordinal,
215+
)))
216+
}
217+
218+
fn dyn_hash(&self, state: &mut dyn Hasher) {
219+
let mut s = state;
220+
self.child.hash(&mut s);
221+
self.ordinal.hash(&mut s);
222+
self.hash(&mut s);
223+
}
224+
}
225+
226+
impl Display for GetStructField {
227+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
228+
write!(
229+
f,
230+
"GetStructField [child: {:?}, ordinal: {:?}]",
231+
self.child, self.ordinal
232+
)
233+
}
234+
}
235+
236+
impl PartialEq<dyn Any> for GetStructField {
237+
fn eq(&self, other: &dyn Any) -> bool {
238+
down_cast_any_ref(other)
239+
.downcast_ref::<Self>()
240+
.map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal))
241+
.unwrap_or(false)
242+
}
243+
}
244+
245+
#[cfg(test)]
mod test {
    use super::CreateNamedStruct;
    use arrow_array::{
        Array, ArrayRef, DictionaryArray, Int32Array, RecordBatch, StringArray, StructArray,
    };
    use arrow_schema::{DataType, Field, Fields, Schema};
    use datafusion_common::Result;
    use datafusion_expr::ColumnarValue;
    use datafusion_physical_expr_common::expressions::column::Column;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
    use std::sync::Arc;

    /// Wraps `values` (three entries) in an Int32-keyed dictionary column named
    /// "a", evaluates `CreateNamedStruct` over it, and returns the single child
    /// column of the resulting struct array.
    fn struct_child_from_dict(values: ArrayRef, value_type: DataType) -> Result<ArrayRef> {
        let keys = Int32Array::from(vec![0, 1, 2]);
        let dict = DictionaryArray::try_new(keys, values)?;
        let dict_type =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type.clone()));
        let schema = Schema::new(vec![Field::new("a", dict_type, false)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?;
        let struct_type =
            DataType::Struct(Fields::from(vec![Field::new("a", value_type, false)]));
        let expr = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], struct_type);
        let ColumnarValue::Array(array) = expr.evaluate(&batch)? else {
            unreachable!()
        };
        let struct_array = array
            .as_any()
            .downcast_ref::<StructArray>()
            .expect("evaluate should produce a StructArray");
        assert_eq!(3, struct_array.len());
        assert_eq!(1, struct_array.num_columns());
        Ok(Arc::clone(struct_array.column(0)))
    }

    #[test]
    fn test_create_struct_from_dict_encoded_i32() -> Result<()> {
        let values = Int32Array::from(vec![0, 111, 233]);
        let child = struct_child_from_dict(Arc::new(values), DataType::Int32)?;
        // The dictionary must have been unpacked into a plain Int32 column.
        assert_eq!(&DataType::Int32, child.data_type());
        let ints = child
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("child column should be an Int32Array");
        assert_eq!(&Int32Array::from(vec![0, 111, 233]), ints);
        Ok(())
    }

    #[test]
    fn test_create_struct_from_dict_encoded_string() -> Result<()> {
        let values = StringArray::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
        let child = struct_child_from_dict(Arc::new(values), DataType::Utf8)?;
        // The dictionary must have been unpacked into a plain Utf8 column.
        assert_eq!(&DataType::Utf8, child.data_type());
        let strings = child
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("child column should be a StringArray");
        assert_eq!(&StringArray::from(vec!["a", "b", "c"]), strings);
        Ok(())
    }
}

0 commit comments

Comments
 (0)