@@ -78,6 +78,8 @@ pub trait ContextProvider {
78
78
fn get_function_meta ( & self , name : & str ) -> Option < Arc < ScalarUDF > > ;
79
79
/// Getter for a UDAF description
80
80
fn get_aggregate_meta ( & self , name : & str ) -> Option < Arc < AggregateUDF > > ;
81
+ /// Getter for system/user-defined variable type
82
+ fn get_variable_type ( & self , variable_names : & [ String ] ) -> Option < DataType > ;
81
83
}
82
84
83
85
/// SQL query planner
@@ -1412,7 +1414,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
1412
1414
if id. value . starts_with ( '@' ) {
1413
1415
// TODO: figure out if ScalarVariables should be insensitive.
1414
1416
let var_names = vec ! [ id. value. clone( ) ] ;
1415
- Ok ( Expr :: ScalarVariable ( var_names) )
1417
+ let ty = self
1418
+ . schema_provider
1419
+ . get_variable_type ( & var_names)
1420
+ . ok_or_else ( || {
1421
+ DataFusionError :: Execution ( format ! (
1422
+ "variable {:?} has no type information" ,
1423
+ var_names
1424
+ ) )
1425
+ } ) ?;
1426
+ Ok ( Expr :: ScalarVariable ( ty, var_names) )
1416
1427
} else {
1417
1428
// Don't use `col()` here because it will try to
1418
1429
// interpret names with '.' as if they were
@@ -1440,7 +1451,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
1440
1451
let mut var_names: Vec < _ > = ids. iter ( ) . map ( normalize_ident) . collect ( ) ;
1441
1452
1442
1453
if & var_names[ 0 ] [ 0 ..1 ] == "@" {
1443
- Ok ( Expr :: ScalarVariable ( var_names) )
1454
+ let ty = self
1455
+ . schema_provider
1456
+ . get_variable_type ( & var_names)
1457
+ . ok_or_else ( || {
1458
+ DataFusionError :: Execution ( format ! (
1459
+ "variable {:?} has no type information" ,
1460
+ var_names
1461
+ ) )
1462
+ } ) ?;
1463
+ Ok ( Expr :: ScalarVariable ( ty, var_names) )
1444
1464
} else {
1445
1465
match ( var_names. pop ( ) , var_names. pop ( ) ) {
1446
1466
( Some ( name) , Some ( relation) ) if var_names. is_empty ( ) => {
@@ -3938,6 +3958,10 @@ mod tests {
3938
3958
fn get_aggregate_meta ( & self , _name : & str ) -> Option < Arc < AggregateUDF > > {
3939
3959
unimplemented ! ( )
3940
3960
}
3961
+
3962
+ fn get_variable_type ( & self , _: & [ String ] ) -> Option < DataType > {
3963
+ unimplemented ! ( )
3964
+ }
3941
3965
}
3942
3966
3943
3967
#[ test]
0 commit comments