Skip to content

Commit d2853ef

Browse files
authored
feat: Optimize CreateNamedStruct to preserve dictionaries (apache#789)
* feat: Optimize CreateNamedStruct to preserve dictionaries Instead of serializing the return data_type we just serialize the field names. The original implementation was done as it led to a slightly simpler implementation, but it is clear from apache#750 that this was the wrong choice and leads to issues with the physical data_type. * Support dictionary data_types in StructVector and MapVector * Add length checks
1 parent e3709ea commit d2853ef

File tree

1 file changed

+33
-45
lines changed

1 file changed

+33
-45
lines changed

src/structs.rs

+33-45
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::compute::take;
1918
use arrow::record_batch::RecordBatch;
20-
use arrow_array::types::Int32Type;
21-
use arrow_array::{Array, DictionaryArray, StructArray};
19+
use arrow_array::{Array, StructArray};
2220
use arrow_schema::{DataType, Field, Schema};
2321
use datafusion::logical_expr::ColumnarValue;
2422
use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue};
@@ -35,12 +33,24 @@ use crate::utils::down_cast_any_ref;
3533
#[derive(Debug, Hash)]
3634
pub struct CreateNamedStruct {
3735
values: Vec<Arc<dyn PhysicalExpr>>,
38-
data_type: DataType,
36+
names: Vec<String>,
3937
}
4038

4139
impl CreateNamedStruct {
42-
pub fn new(values: Vec<Arc<dyn PhysicalExpr>>, data_type: DataType) -> Self {
43-
Self { values, data_type }
40+
pub fn new(values: Vec<Arc<dyn PhysicalExpr>>, names: Vec<String>) -> Self {
41+
Self { values, names }
42+
}
43+
44+
fn fields(&self, schema: &Schema) -> DataFusionResult<Vec<Field>> {
45+
self.values
46+
.iter()
47+
.zip(&self.names)
48+
.map(|(expr, name)| {
49+
let data_type = expr.data_type(schema)?;
50+
let nullable = expr.nullable(schema)?;
51+
Ok(Field::new(name, data_type, nullable))
52+
})
53+
.collect()
4454
}
4555
}
4656

@@ -49,8 +59,9 @@ impl PhysicalExpr for CreateNamedStruct {
4959
self
5060
}
5161

52-
fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
53-
Ok(self.data_type.clone())
62+
fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
63+
let fields = self.fields(input_schema)?;
64+
Ok(DataType::Struct(fields.into()))
5465
}
5566

5667
fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
@@ -64,32 +75,9 @@ impl PhysicalExpr for CreateNamedStruct {
6475
.map(|expr| expr.evaluate(batch))
6576
.collect::<datafusion_common::Result<Vec<_>>>()?;
6677
let arrays = ColumnarValue::values_to_arrays(&values)?;
67-
// TODO it would be more efficient if we could preserve dictionaries within the
68-
// struct array but for now we unwrap them to avoid runtime errors
69-
// https://github.com/apache/datafusion-comet/issues/755
70-
let arrays = arrays
71-
.iter()
72-
.map(|array| {
73-
if let Some(dict_array) =
74-
array.as_any().downcast_ref::<DictionaryArray<Int32Type>>()
75-
{
76-
take(dict_array.values().as_ref(), dict_array.keys(), None)
77-
} else {
78-
Ok(Arc::clone(array))
79-
}
80-
})
81-
.collect::<Result<Vec<_>, _>>()?;
82-
let fields = match &self.data_type {
83-
DataType::Struct(fields) => fields,
84-
_ => {
85-
return Err(DataFusionError::Internal(format!(
86-
"Expected struct data type, got {:?}",
87-
self.data_type
88-
)))
89-
}
90-
};
78+
let fields = self.fields(&batch.schema())?;
9179
Ok(ColumnarValue::Array(Arc::new(StructArray::new(
92-
fields.clone(),
80+
fields.into(),
9381
arrays,
9482
None,
9583
))))
@@ -105,14 +93,14 @@ impl PhysicalExpr for CreateNamedStruct {
10593
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
10694
Ok(Arc::new(CreateNamedStruct::new(
10795
children.clone(),
108-
self.data_type.clone(),
96+
self.names.clone(),
10997
)))
11098
}
11199

112100
fn dyn_hash(&self, state: &mut dyn Hasher) {
113101
let mut s = state;
114102
self.values.hash(&mut s);
115-
self.data_type.hash(&mut s);
103+
self.names.hash(&mut s);
116104
self.hash(&mut s);
117105
}
118106
}
@@ -121,8 +109,8 @@ impl Display for CreateNamedStruct {
121109
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122110
write!(
123111
f,
124-
"CreateNamedStruct [values: {:?}, data_type: {:?}]",
125-
self.values, self.data_type
112+
"CreateNamedStruct [values: {:?}, names: {:?}]",
113+
self.values, self.names
126114
)
127115
}
128116
}
@@ -136,7 +124,9 @@ impl PartialEq<dyn Any> for CreateNamedStruct {
136124
.iter()
137125
.zip(x.values.iter())
138126
.all(|(a, b)| a.eq(b))
139-
&& self.data_type.eq(&x.data_type)
127+
&& self.values.len() == x.values.len()
128+
&& self.names.iter().zip(x.names.iter()).all(|(a, b)| a.eq(b))
129+
&& self.names.len() == x.names.len()
140130
})
141131
.unwrap_or(false)
142132
}
@@ -246,7 +236,7 @@ impl PartialEq<dyn Any> for GetStructField {
246236
mod test {
247237
use super::CreateNamedStruct;
248238
use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray};
249-
use arrow_schema::{DataType, Field, Fields, Schema};
239+
use arrow_schema::{DataType, Field, Schema};
250240
use datafusion_common::Result;
251241
use datafusion_expr::ColumnarValue;
252242
use datafusion_physical_expr_common::expressions::column::Column;
@@ -261,9 +251,8 @@ mod test {
261251
let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32));
262252
let schema = Schema::new(vec![Field::new("a", data_type, false)]);
263253
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?;
264-
let data_type =
265-
DataType::Struct(Fields::from(vec![Field::new("a", DataType::Int32, false)]));
266-
let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], data_type);
254+
let field_names = vec!["a".to_string()];
255+
let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names);
267256
let ColumnarValue::Array(x) = x.evaluate(&batch)? else {
268257
unreachable!()
269258
};
@@ -279,9 +268,8 @@ mod test {
279268
let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
280269
let schema = Schema::new(vec![Field::new("a", data_type, false)]);
281270
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?;
282-
let data_type =
283-
DataType::Struct(Fields::from(vec![Field::new("a", DataType::Utf8, false)]));
284-
let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], data_type);
271+
let field_names = vec!["a".to_string()];
272+
let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names);
285273
let ColumnarValue::Array(x) = x.evaluate(&batch)? else {
286274
unreachable!()
287275
};

0 commit comments

Comments
 (0)