15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
- use arrow:: compute:: take;
19
18
use arrow:: record_batch:: RecordBatch ;
20
- use arrow_array:: types:: Int32Type ;
21
- use arrow_array:: { Array , DictionaryArray , StructArray } ;
19
+ use arrow_array:: { Array , StructArray } ;
22
20
use arrow_schema:: { DataType , Field , Schema } ;
23
21
use datafusion:: logical_expr:: ColumnarValue ;
24
22
use datafusion_common:: { DataFusionError , Result as DataFusionResult , ScalarValue } ;
@@ -35,12 +33,24 @@ use crate::utils::down_cast_any_ref;
35
33
#[ derive( Debug , Hash ) ]
36
34
pub struct CreateNamedStruct {
37
35
values : Vec < Arc < dyn PhysicalExpr > > ,
38
- data_type : DataType ,
36
+ names : Vec < String > ,
39
37
}
40
38
41
39
impl CreateNamedStruct {
42
- pub fn new ( values : Vec < Arc < dyn PhysicalExpr > > , data_type : DataType ) -> Self {
43
- Self { values, data_type }
40
+ pub fn new ( values : Vec < Arc < dyn PhysicalExpr > > , names : Vec < String > ) -> Self {
41
+ Self { values, names }
42
+ }
43
+
44
+ fn fields ( & self , schema : & Schema ) -> DataFusionResult < Vec < Field > > {
45
+ self . values
46
+ . iter ( )
47
+ . zip ( & self . names )
48
+ . map ( |( expr, name) | {
49
+ let data_type = expr. data_type ( schema) ?;
50
+ let nullable = expr. nullable ( schema) ?;
51
+ Ok ( Field :: new ( name, data_type, nullable) )
52
+ } )
53
+ . collect ( )
44
54
}
45
55
}
46
56
@@ -49,8 +59,9 @@ impl PhysicalExpr for CreateNamedStruct {
49
59
self
50
60
}
51
61
52
- fn data_type ( & self , _input_schema : & Schema ) -> DataFusionResult < DataType > {
53
- Ok ( self . data_type . clone ( ) )
62
+ fn data_type ( & self , input_schema : & Schema ) -> DataFusionResult < DataType > {
63
+ let fields = self . fields ( input_schema) ?;
64
+ Ok ( DataType :: Struct ( fields. into ( ) ) )
54
65
}
55
66
56
67
fn nullable ( & self , _input_schema : & Schema ) -> DataFusionResult < bool > {
@@ -64,32 +75,9 @@ impl PhysicalExpr for CreateNamedStruct {
64
75
. map ( |expr| expr. evaluate ( batch) )
65
76
. collect :: < datafusion_common:: Result < Vec < _ > > > ( ) ?;
66
77
let arrays = ColumnarValue :: values_to_arrays ( & values) ?;
67
- // TODO it would be more efficient if we could preserve dictionaries within the
68
- // struct array but for now we unwrap them to avoid runtime errors
69
- // https://github.com/apache/datafusion-comet/issues/755
70
- let arrays = arrays
71
- . iter ( )
72
- . map ( |array| {
73
- if let Some ( dict_array) =
74
- array. as_any ( ) . downcast_ref :: < DictionaryArray < Int32Type > > ( )
75
- {
76
- take ( dict_array. values ( ) . as_ref ( ) , dict_array. keys ( ) , None )
77
- } else {
78
- Ok ( Arc :: clone ( array) )
79
- }
80
- } )
81
- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
82
- let fields = match & self . data_type {
83
- DataType :: Struct ( fields) => fields,
84
- _ => {
85
- return Err ( DataFusionError :: Internal ( format ! (
86
- "Expected struct data type, got {:?}" ,
87
- self . data_type
88
- ) ) )
89
- }
90
- } ;
78
+ let fields = self . fields ( & batch. schema ( ) ) ?;
91
79
Ok ( ColumnarValue :: Array ( Arc :: new ( StructArray :: new (
92
- fields. clone ( ) ,
80
+ fields. into ( ) ,
93
81
arrays,
94
82
None ,
95
83
) ) ) )
@@ -105,14 +93,14 @@ impl PhysicalExpr for CreateNamedStruct {
105
93
) -> datafusion_common:: Result < Arc < dyn PhysicalExpr > > {
106
94
Ok ( Arc :: new ( CreateNamedStruct :: new (
107
95
children. clone ( ) ,
108
- self . data_type . clone ( ) ,
96
+ self . names . clone ( ) ,
109
97
) ) )
110
98
}
111
99
112
100
fn dyn_hash ( & self , state : & mut dyn Hasher ) {
113
101
let mut s = state;
114
102
self . values . hash ( & mut s) ;
115
- self . data_type . hash ( & mut s) ;
103
+ self . names . hash ( & mut s) ;
116
104
self . hash ( & mut s) ;
117
105
}
118
106
}
@@ -121,8 +109,8 @@ impl Display for CreateNamedStruct {
121
109
fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
122
110
write ! (
123
111
f,
124
- "CreateNamedStruct [values: {:?}, data_type : {:?}]" ,
125
- self . values, self . data_type
112
+ "CreateNamedStruct [values: {:?}, names : {:?}]" ,
113
+ self . values, self . names
126
114
)
127
115
}
128
116
}
@@ -136,7 +124,9 @@ impl PartialEq<dyn Any> for CreateNamedStruct {
136
124
. iter ( )
137
125
. zip ( x. values . iter ( ) )
138
126
. all ( |( a, b) | a. eq ( b) )
139
- && self . data_type . eq ( & x. data_type )
127
+ && self . values . len ( ) == x. values . len ( )
128
+ && self . names . iter ( ) . zip ( x. names . iter ( ) ) . all ( |( a, b) | a. eq ( b) )
129
+ && self . names . len ( ) == x. names . len ( )
140
130
} )
141
131
. unwrap_or ( false )
142
132
}
@@ -246,7 +236,7 @@ impl PartialEq<dyn Any> for GetStructField {
246
236
mod test {
247
237
use super :: CreateNamedStruct ;
248
238
use arrow_array:: { Array , DictionaryArray , Int32Array , RecordBatch , StringArray } ;
249
- use arrow_schema:: { DataType , Field , Fields , Schema } ;
239
+ use arrow_schema:: { DataType , Field , Schema } ;
250
240
use datafusion_common:: Result ;
251
241
use datafusion_expr:: ColumnarValue ;
252
242
use datafusion_physical_expr_common:: expressions:: column:: Column ;
@@ -261,9 +251,8 @@ mod test {
261
251
let data_type = DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Int32 ) ) ;
262
252
let schema = Schema :: new ( vec ! [ Field :: new( "a" , data_type, false ) ] ) ;
263
253
let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( dict) ] ) ?;
264
- let data_type =
265
- DataType :: Struct ( Fields :: from ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ) ;
266
- let x = CreateNamedStruct :: new ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , data_type) ;
254
+ let field_names = vec ! [ "a" . to_string( ) ] ;
255
+ let x = CreateNamedStruct :: new ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , field_names) ;
267
256
let ColumnarValue :: Array ( x) = x. evaluate ( & batch) ? else {
268
257
unreachable ! ( )
269
258
} ;
@@ -279,9 +268,8 @@ mod test {
279
268
let data_type = DataType :: Dictionary ( Box :: new ( DataType :: Int32 ) , Box :: new ( DataType :: Utf8 ) ) ;
280
269
let schema = Schema :: new ( vec ! [ Field :: new( "a" , data_type, false ) ] ) ;
281
270
let batch = RecordBatch :: try_new ( Arc :: new ( schema) , vec ! [ Arc :: new( dict) ] ) ?;
282
- let data_type =
283
- DataType :: Struct ( Fields :: from ( vec ! [ Field :: new( "a" , DataType :: Utf8 , false ) ] ) ) ;
284
- let x = CreateNamedStruct :: new ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , data_type) ;
271
+ let field_names = vec ! [ "a" . to_string( ) ] ;
272
+ let x = CreateNamedStruct :: new ( vec ! [ Arc :: new( Column :: new( "a" , 0 ) ) ] , field_names) ;
285
273
let ColumnarValue :: Array ( x) = x. evaluate ( & batch) ? else {
286
274
unreachable ! ( )
287
275
} ;
0 commit comments