Skip to content

Commit 8db88ba

Browse files
authored
Allow different types of query variables (@@var) rather than just string (#1943)
* Upstream implementation of typed query variables * fix ballista * test that we can have variables that are not just Utf8 * clippy lint
1 parent e160520 commit 8db88ba

File tree

15 files changed

+91
-28
lines changed

15 files changed

+91
-28
lines changed

datafusion-expr/src/expr.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ pub enum Expr {
8686
/// A named reference to a qualified field in a schema.
8787
Column(Column),
8888
/// A named reference to a variable in a registry.
89-
ScalarVariable(Vec<String>),
89+
ScalarVariable(DataType, Vec<String>),
9090
/// A constant value.
9191
Literal(ScalarValue),
9292
/// A binary expression such as "age > 21"
@@ -399,7 +399,7 @@ impl fmt::Debug for Expr {
399399
match self {
400400
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
401401
Expr::Column(c) => write!(f, "{}", c),
402-
Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
402+
Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
403403
Expr::Literal(v) => write!(f, "{:?}", v),
404404
Expr::Case {
405405
expr,
@@ -562,7 +562,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
562562
match e {
563563
Expr::Alias(_, name) => Ok(name.clone()),
564564
Expr::Column(c) => Ok(c.flat_name()),
565-
Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
565+
Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")),
566566
Expr::Literal(value) => Ok(format!("{:?}", value)),
567567
Expr::BinaryExpr { left, op, right } => {
568568
let left = create_name(left, input_schema)?;

datafusion-proto/src/to_proto.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
508508
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
509509
}
510510
}
511-
Expr::ScalarVariable(_) => unimplemented!(),
511+
Expr::ScalarVariable(_, _) => unimplemented!(),
512512
Expr::ScalarFunction { ref fun, ref args } => {
513513
let fun: protobuf::ScalarFunction = fun.try_into()?;
514514
let args: Vec<Self> = args

datafusion/src/datasource/listing/helpers.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
8282
}
8383
Expr::Literal(_)
8484
| Expr::Alias(_, _)
85-
| Expr::ScalarVariable(_)
85+
| Expr::ScalarVariable(_, _)
8686
| Expr::Not(_)
8787
| Expr::IsNotNull(_)
8888
| Expr::IsNull(_)

datafusion/src/execution/context.rs

+25-7
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use std::path::PathBuf;
4545
use std::string::String;
4646
use std::sync::Arc;
4747

48-
use arrow::datatypes::SchemaRef;
48+
use arrow::datatypes::{DataType, SchemaRef};
4949

5050
use crate::catalog::{
5151
catalog::{CatalogProvider, MemoryCatalogProvider},
@@ -1190,6 +1190,23 @@ impl ContextProvider for ExecutionContextState {
11901190
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
    // Look the aggregate function up by name; hand back a new `Arc` handle
    // to the registered entry, or `None` when nothing is stored under `name`.
    let registered = self.aggregate_functions.get(name);
    registered.map(Arc::clone)
}
1193+
1194+
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
1195+
if variable_names.is_empty() {
1196+
return None;
1197+
}
1198+
1199+
let provider_type = if &variable_names[0][0..2] == "@@" {
1200+
VarType::System
1201+
} else {
1202+
VarType::UserDefined
1203+
};
1204+
1205+
self.execution_props
1206+
.var_providers
1207+
.as_ref()
1208+
.and_then(|provider| provider.get(&provider_type)?.get_type(variable_names))
1209+
}
11931210
}
11941211

11951212
impl FunctionRegistry for ExecutionContextState {
@@ -1300,14 +1317,15 @@ mod tests {
13001317
ctx.register_table("dual", provider)?;
13011318

13021319
let results =
1303-
plan_and_collect(&mut ctx, "SELECT @@version, @name FROM dual").await?;
1320+
plan_and_collect(&mut ctx, "SELECT @@version, @name, @integer + 1 FROM dual")
1321+
.await?;
13041322

13051323
let expected = vec![
1306-
"+----------------------+------------------------+",
1307-
"| @@version | @name |",
1308-
"+----------------------+------------------------+",
1309-
"| system-var-@@version | user-defined-var-@name |",
1310-
"+----------------------+------------------------+",
1324+
"+----------------------+------------------------+------------------------+",
1325+
"| @@version | @name | @integer Plus Int64(1) |",
1326+
"+----------------------+------------------------+------------------------+",
1327+
"| system-var-@@version | user-defined-var-@name | 42 |",
1328+
"+----------------------+------------------------+------------------------+",
13111329
];
13121330
assert_batches_eq!(expected, &results);
13131331

datafusion/src/logical_plan/expr_rewriter.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ impl ExprRewritable for Expr {
111111
let expr = match self {
112112
Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
113113
Expr::Column(_) => self.clone(),
114-
Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
114+
Expr::ScalarVariable(ty, names) => Expr::ScalarVariable(ty, names),
115115
Expr::Literal(value) => Expr::Literal(value),
116116
Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
117117
left: rewrite_boxed(left, rewriter)?,

datafusion/src/logical_plan/expr_schema.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl ExprSchemable for Expr {
5858
expr.get_type(schema)
5959
}
6060
Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
61-
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
61+
Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
6262
Expr::Literal(l) => Ok(l.get_datatype()),
6363
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
6464
Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
@@ -162,7 +162,7 @@ impl ExprSchemable for Expr {
162162
}
163163
}
164164
Expr::Cast { expr, .. } => expr.nullable(input_schema),
165-
Expr::ScalarVariable(_)
165+
Expr::ScalarVariable(_, _)
166166
| Expr::TryCast { .. }
167167
| Expr::ScalarFunction { .. }
168168
| Expr::ScalarUDF { .. }

datafusion/src/logical_plan/expr_visitor.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ impl ExprVisitable for Expr {
104104
| Expr::Sort { expr, .. }
105105
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
106106
Expr::Column(_)
107-
| Expr::ScalarVariable(_)
107+
| Expr::ScalarVariable(_, _)
108108
| Expr::Literal(_)
109109
| Expr::Wildcard => Ok(visitor),
110110
Expr::BinaryExpr { left, right, .. } => {

datafusion/src/optimizer/common_subexpr_eliminate.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ impl ExprIdentifierVisitor<'_> {
379379
desc.push_str("Column-");
380380
desc.push_str(&column.flat_name());
381381
}
382-
Expr::ScalarVariable(var_names) => {
382+
Expr::ScalarVariable(_, var_names) => {
383383
desc.push_str("ScalarVariable-");
384384
desc.push_str(&var_names.join("."));
385385
}

datafusion/src/optimizer/simplify_expressions.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ impl<'a> ConstEvaluator<'a> {
373373
Expr::Alias(..)
374374
| Expr::AggregateFunction { .. }
375375
| Expr::AggregateUDF { .. }
376-
| Expr::ScalarVariable(_)
376+
| Expr::ScalarVariable(_, _)
377377
| Expr::Column(_)
378378
| Expr::WindowFunction { .. }
379379
| Expr::Sort { .. }

datafusion/src/optimizer/utils.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
6262
Expr::Column(qc) => {
6363
self.accum.insert(qc.clone());
6464
}
65-
Expr::ScalarVariable(var_names) => {
65+
Expr::ScalarVariable(_, var_names) => {
6666
self.accum.insert(Column::from_name(var_names.join(".")));
6767
}
6868
Expr::Alias(_, _)
@@ -331,7 +331,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
331331
}
332332
Ok(expr_list)
333333
}
334-
Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_) => Ok(vec![]),
334+
Expr::Column(_) | Expr::Literal(_) | Expr::ScalarVariable(_, _) => Ok(vec![]),
335335
Expr::Between {
336336
expr, low, high, ..
337337
} => Ok(vec![
@@ -476,7 +476,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
476476
Expr::Column(_)
477477
| Expr::Literal(_)
478478
| Expr::InList { .. }
479-
| Expr::ScalarVariable(_) => Ok(expr.clone()),
479+
| Expr::ScalarVariable(_, _) => Ok(expr.clone()),
480480
Expr::Sort {
481481
asc, nulls_first, ..
482482
} => Ok(Expr::Sort {

datafusion/src/physical_plan/planner.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
9797
}
9898
}
9999
Expr::Alias(_, name) => Ok(name.clone()),
100-
Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
100+
Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")),
101101
Expr::Literal(value) => Ok(format!("{:?}", value)),
102102
Expr::BinaryExpr { left, op, right } => {
103103
let left = create_physical_name(left, false)?;
@@ -883,7 +883,7 @@ pub fn create_physical_expr(
883883
Ok(Arc::new(Column::new(&c.name, idx)))
884884
}
885885
Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
886-
Expr::ScalarVariable(variable_names) => {
886+
Expr::ScalarVariable(_, variable_names) => {
887887
if &variable_names[0][0..2] == "@@" {
888888
match execution_props.get_var_provider(VarType::System) {
889889
Some(provider) => {

datafusion/src/sql/planner.rs

+26-2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ pub trait ContextProvider {
7878
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
7979
/// Getter for a UDAF description
8080
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>>;
81+
/// Getter for system/user-defined variable type
82+
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType>;
8183
}
8284

8385
/// SQL query planner
@@ -1412,7 +1414,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
14121414
if id.value.starts_with('@') {
14131415
// TODO: figure out if ScalarVariables should be insensitive.
14141416
let var_names = vec![id.value.clone()];
1415-
Ok(Expr::ScalarVariable(var_names))
1417+
let ty = self
1418+
.schema_provider
1419+
.get_variable_type(&var_names)
1420+
.ok_or_else(|| {
1421+
DataFusionError::Execution(format!(
1422+
"variable {:?} has no type information",
1423+
var_names
1424+
))
1425+
})?;
1426+
Ok(Expr::ScalarVariable(ty, var_names))
14161427
} else {
14171428
// Don't use `col()` here because it will try to
14181429
// interpret names with '.' as if they were
@@ -1440,7 +1451,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
14401451
let mut var_names: Vec<_> = ids.iter().map(normalize_ident).collect();
14411452

14421453
if &var_names[0][0..1] == "@" {
1443-
Ok(Expr::ScalarVariable(var_names))
1454+
let ty = self
1455+
.schema_provider
1456+
.get_variable_type(&var_names)
1457+
.ok_or_else(|| {
1458+
DataFusionError::Execution(format!(
1459+
"variable {:?} has no type information",
1460+
var_names
1461+
))
1462+
})?;
1463+
Ok(Expr::ScalarVariable(ty, var_names))
14441464
} else {
14451465
match (var_names.pop(), var_names.pop()) {
14461466
(Some(name), Some(relation)) if var_names.is_empty() => {
@@ -3938,6 +3958,10 @@ mod tests {
39383958
fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
39393959
unimplemented!()
39403960
}
3961+
3962+
fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
3963+
unimplemented!()
3964+
}
39413965
}
39423966

39433967
#[test]

datafusion/src/sql/utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ where
368368
asc: *asc,
369369
nulls_first: *nulls_first,
370370
}),
371-
Expr::Column { .. } | Expr::Literal(_) | Expr::ScalarVariable(_) => {
371+
Expr::Column { .. } | Expr::Literal(_) | Expr::ScalarVariable(_, _) => {
372372
Ok(expr.clone())
373373
}
374374
Expr::Wildcard => Ok(Expr::Wildcard),

datafusion/src/test/variable.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
use crate::error::Result;
2121
use crate::scalar::ScalarValue;
2222
use crate::variable::VarProvider;
23+
use arrow::datatypes::DataType;
2324

2425
/// System variable
2526
#[derive(Default)]
@@ -38,6 +39,10 @@ impl VarProvider for SystemVar {
3839
let s = format!("{}-{}", "system-var", var_names.concat());
3940
Ok(ScalarValue::Utf8(Some(s)))
4041
}
42+
43+
fn get_type(&self, _: &[String]) -> Option<DataType> {
44+
Some(DataType::Utf8)
45+
}
4146
}
4247

4348
/// user defined variable
@@ -54,7 +59,19 @@ impl UserDefinedVar {
5459
impl VarProvider for UserDefinedVar {
5560
/// Get user defined variable value
5661
fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue> {
57-
let s = format!("{}-{}", "user-defined-var", var_names.concat());
58-
Ok(ScalarValue::Utf8(Some(s)))
62+
if var_names[0] != "@integer" {
63+
let s = format!("{}-{}", "user-defined-var", var_names.concat());
64+
Ok(ScalarValue::Utf8(Some(s)))
65+
} else {
66+
Ok(ScalarValue::Int32(Some(41)))
67+
}
68+
}
69+
70+
fn get_type(&self, var_names: &[String]) -> Option<DataType> {
71+
if var_names[0] != "@integer" {
72+
Some(DataType::Utf8)
73+
} else {
74+
Some(DataType::Int32)
75+
}
5976
}
6077
}

datafusion/src/variable/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
2020
use crate::error::Result;
2121
use crate::scalar::ScalarValue;
22+
use arrow::datatypes::DataType;
2223

2324
/// Variable type, system/user defined
2425
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -33,4 +34,7 @@ pub enum VarType {
3334
pub trait VarProvider {
3435
/// Get variable value
3536
fn get_value(&self, var_names: Vec<String>) -> Result<ScalarValue>;
37+
38+
/// Return the type of the given variable
39+
fn get_type(&self, var_names: &[String]) -> Option<DataType>;
3640
}

0 commit comments

Comments
 (0)