Skip to content

Commit 767eeb0

Browse files
authored
closing up type checks (#506)
1 parent ee2b9ef commit 767eeb0

File tree

12 files changed

+512
-62
lines changed

12 files changed

+512
-62
lines changed

ballista/rust/core/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ futures = "0.3"
3535
log = "0.4"
3636
prost = "0.7"
3737
serde = {version = "1", features = ["derive"]}
38-
sqlparser = "0.8"
38+
sqlparser = "0.9.0"
3939
tokio = "1.0"
4040
tonic = "0.4"
4141
uuid = { version = "0.8", features = ["v4"] }

ballista/rust/core/proto/ballista.proto

+3-3
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,9 @@ message WindowExprNode {
177177
// repeated LogicalExprNode partition_by = 5;
178178
repeated LogicalExprNode order_by = 6;
179179
// repeated LogicalExprNode filter = 7;
180-
// oneof window_frame {
181-
// WindowFrame frame = 8;
182-
// }
180+
oneof window_frame {
181+
WindowFrame frame = 8;
182+
}
183183
}
184184

185185
message BetweenNode {

ballista/rust/core/src/serde/logical_plan/from_proto.rs

+27-22
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,24 @@
2020
use crate::error::BallistaError;
2121
use crate::serde::{proto_error, protobuf};
2222
use crate::{convert_box_required, convert_required};
23-
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
24-
use std::{
25-
convert::{From, TryInto},
26-
unimplemented,
27-
};
28-
2923
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
3024
use datafusion::logical_plan::{
3125
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
3226
sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
3327
};
3428
use datafusion::physical_plan::aggregates::AggregateFunction;
3529
use datafusion::physical_plan::csv::CsvReadOptions;
30+
use datafusion::physical_plan::window_frames::{
31+
WindowFrame, WindowFrameBound, WindowFrameUnits,
32+
};
3633
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
3734
use datafusion::scalar::ScalarValue;
3835
use protobuf::logical_plan_node::LogicalPlanType;
3936
use protobuf::{logical_expr_node::ExprType, scalar_type};
37+
use std::{
38+
convert::{From, TryInto},
39+
unimplemented,
40+
};
4041

4142
// use uuid::Uuid;
4243

@@ -83,20 +84,6 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
8384
.iter()
8485
.map(|expr| expr.try_into())
8586
.collect::<Result<Vec<_>, _>>()?;
86-
87-
// let partition_by_expr = window
88-
// .partition_by_expr
89-
// .iter()
90-
// .map(|expr| expr.try_into())
91-
// .collect::<Result<Vec<_>, _>>()?;
92-
// let order_by_expr = window
93-
// .order_by_expr
94-
// .iter()
95-
// .map(|expr| expr.try_into())
96-
// .collect::<Result<Vec<_>, _>>()?;
97-
// // FIXME: add filter by expr
98-
// // FIXME: parse the window_frame data
99-
// let window_frame = None;
10087
LogicalPlanBuilder::from(&input)
10188
.window(window_expr)?
10289
.build()
@@ -929,6 +916,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
929916
.map(|e| e.try_into())
930917
.into_iter()
931918
.collect::<Result<Vec<_>, _>>()?;
919+
let window_frame = expr
920+
.window_frame
921+
.as_ref()
922+
.map::<Result<WindowFrame, _>, _>(|e| match e {
923+
window_expr_node::WindowFrame::Frame(frame) => {
924+
frame.clone().try_into()
925+
}
926+
})
927+
.transpose()?;
932928
match window_function {
933929
window_expr_node::WindowFunction::AggrFunction(i) => {
934930
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
@@ -945,6 +941,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
945941
),
946942
args: vec![parse_required_expr(&expr.expr)?],
947943
order_by,
944+
window_frame,
948945
})
949946
}
950947
window_expr_node::WindowFunction::BuiltInFunction(i) => {
@@ -964,6 +961,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
964961
),
965962
args: vec![parse_required_expr(&expr.expr)?],
966963
order_by,
964+
window_frame,
967965
})
968966
}
969967
}
@@ -1333,8 +1331,15 @@ impl TryFrom<protobuf::WindowFrame> for WindowFrame {
13331331
)
13341332
})?
13351333
.try_into()?;
1336-
// FIXME parse end bound
1337-
let end_bound = None;
1334+
let end_bound = window
1335+
.end_bound
1336+
.map(|end_bound| match end_bound {
1337+
protobuf::window_frame::EndBound::Bound(end_bound) => {
1338+
end_bound.try_into()
1339+
}
1340+
})
1341+
.transpose()?
1342+
.unwrap_or(WindowFrameBound::CurrentRow);
13381343
Ok(WindowFrame {
13391344
units,
13401345
start_bound,

ballista/rust/core/src/serde/logical_plan/to_proto.rs

+38-18
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,17 @@ use std::{
2424
convert::{TryFrom, TryInto},
2525
};
2626

27+
use super::super::proto_error;
2728
use crate::datasource::DfTableAdapter;
2829
use crate::serde::{protobuf, BallistaError};
2930
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
3031
use datafusion::datasource::CsvFile;
3132
use datafusion::logical_plan::{Expr, JoinType, LogicalPlan};
3233
use datafusion::physical_plan::aggregates::AggregateFunction;
34+
use datafusion::physical_plan::functions::BuiltinScalarFunction;
35+
use datafusion::physical_plan::window_frames::{
36+
WindowFrame, WindowFrameBound, WindowFrameUnits,
37+
};
3338
use datafusion::physical_plan::window_functions::{
3439
BuiltInWindowFunction, WindowFunction,
3540
};
@@ -38,10 +43,6 @@ use protobuf::{
3843
arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType,
3944
ScalarListValue, ScalarType,
4045
};
41-
use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
42-
43-
use super::super::proto_error;
44-
use datafusion::physical_plan::functions::BuiltinScalarFunction;
4546

4647
impl protobuf::IntervalUnit {
4748
pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
@@ -1007,6 +1008,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10071008
ref fun,
10081009
ref args,
10091010
ref order_by,
1011+
ref window_frame,
10101012
..
10111013
} => {
10121014
let window_function = match fun {
@@ -1026,10 +1028,16 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10261028
.iter()
10271029
.map(|e| e.try_into())
10281030
.collect::<Result<Vec<_>, _>>()?;
1031+
let window_frame = window_frame.map(|window_frame| {
1032+
protobuf::window_expr_node::WindowFrame::Frame(
1033+
window_frame.clone().into(),
1034+
)
1035+
});
10291036
let window_expr = Box::new(protobuf::WindowExprNode {
10301037
expr: Some(Box::new(arg.try_into()?)),
10311038
window_function: Some(window_function),
10321039
order_by,
1040+
window_frame,
10331041
});
10341042
Ok(protobuf::LogicalExprNode {
10351043
expr_type: Some(ExprType::WindowExpr(window_expr)),
@@ -1256,23 +1264,35 @@ impl From<WindowFrameUnits> for protobuf::WindowFrameUnits {
12561264
}
12571265
}
12581266

1259-
impl TryFrom<WindowFrameBound> for protobuf::WindowFrameBound {
1260-
type Error = BallistaError;
1261-
1262-
fn try_from(_bound: WindowFrameBound) -> Result<Self, Self::Error> {
1263-
Err(BallistaError::NotImplemented(
1264-
"WindowFrameBound => protobuf::WindowFrameBound".to_owned(),
1265-
))
1267+
impl From<WindowFrameBound> for protobuf::WindowFrameBound {
1268+
fn from(bound: WindowFrameBound) -> Self {
1269+
match bound {
1270+
WindowFrameBound::CurrentRow => protobuf::WindowFrameBound {
1271+
window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow
1272+
.into(),
1273+
bound_value: None,
1274+
},
1275+
WindowFrameBound::Preceding(v) => protobuf::WindowFrameBound {
1276+
window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(),
1277+
bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value),
1278+
},
1279+
WindowFrameBound::Following(v) => protobuf::WindowFrameBound {
1280+
window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(),
1281+
bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value),
1282+
},
1283+
}
12661284
}
12671285
}
12681286

1269-
impl TryFrom<WindowFrame> for protobuf::WindowFrame {
1270-
type Error = BallistaError;
1271-
1272-
fn try_from(_window: WindowFrame) -> Result<Self, Self::Error> {
1273-
Err(BallistaError::NotImplemented(
1274-
"WindowFrame => protobuf::WindowFrame".to_owned(),
1275-
))
1287+
impl From<WindowFrame> for protobuf::WindowFrame {
1288+
fn from(window: WindowFrame) -> Self {
1289+
protobuf::WindowFrame {
1290+
window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(),
1291+
start_bound: Some(window.start_bound.into()),
1292+
end_bound: Some(protobuf::window_frame::EndBound::Bound(
1293+
window.end_bound.into(),
1294+
)),
1295+
}
12761296
}
12771297
}
12781298

ballista/rust/core/src/serde/physical_plan/from_proto.rs

+1
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
237237
fun,
238238
args,
239239
order_by,
240+
..
240241
} => {
241242
let arg = df_planner
242243
.create_physical_expr(

datafusion/src/logical_plan/expr.rs

+38-12
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,19 @@
1919
//! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct.
2020
2121
pub use super::Operator;
22-
23-
use std::fmt;
24-
use std::sync::Arc;
25-
26-
use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
27-
use arrow::{compute::can_cast_types, datatypes::DataType};
28-
2922
use crate::error::{DataFusionError, Result};
3023
use crate::logical_plan::{DFField, DFSchema};
3124
use crate::physical_plan::{
3225
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
33-
window_functions,
26+
window_frames, window_functions,
3427
};
3528
use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
29+
use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
30+
use arrow::{compute::can_cast_types, datatypes::DataType};
3631
use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
3732
use std::collections::HashSet;
33+
use std::fmt;
34+
use std::sync::Arc;
3835

3936
/// `Expr` is a central struct of DataFusion's query API, and
4037
/// represent logical expressions such as `A + 1`, or `CAST(c1 AS
@@ -199,6 +196,8 @@ pub enum Expr {
199196
args: Vec<Expr>,
200197
/// List of order by expressions
201198
order_by: Vec<Expr>,
199+
/// Window frame
200+
window_frame: Option<window_frames::WindowFrame>,
202201
},
203202
/// aggregate function
204203
AggregateUDF {
@@ -735,10 +734,12 @@ impl Expr {
735734
args,
736735
fun,
737736
order_by,
737+
window_frame,
738738
} => Expr::WindowFunction {
739739
args: rewrite_vec(args, rewriter)?,
740740
fun,
741741
order_by: rewrite_vec(order_by, rewriter)?,
742+
window_frame,
742743
},
743744
Expr::AggregateFunction {
744745
args,
@@ -1283,8 +1284,23 @@ impl fmt::Debug for Expr {
12831284
Expr::ScalarUDF { fun, ref args, .. } => {
12841285
fmt_function(f, &fun.name, false, args)
12851286
}
1286-
Expr::WindowFunction { fun, ref args, .. } => {
1287-
fmt_function(f, &fun.to_string(), false, args)
1287+
Expr::WindowFunction {
1288+
fun,
1289+
ref args,
1290+
window_frame,
1291+
..
1292+
} => {
1293+
fmt_function(f, &fun.to_string(), false, args)?;
1294+
if let Some(window_frame) = window_frame {
1295+
write!(
1296+
f,
1297+
" {} BETWEEN {} AND {}",
1298+
window_frame.units,
1299+
window_frame.start_bound,
1300+
window_frame.end_bound
1301+
)?;
1302+
}
1303+
Ok(())
12881304
}
12891305
Expr::AggregateFunction {
12901306
fun,
@@ -1401,8 +1417,18 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
14011417
Expr::ScalarUDF { fun, args, .. } => {
14021418
create_function_name(&fun.name, false, args, input_schema)
14031419
}
1404-
Expr::WindowFunction { fun, args, .. } => {
1405-
create_function_name(&fun.to_string(), false, args, input_schema)
1420+
Expr::WindowFunction {
1421+
fun,
1422+
args,
1423+
window_frame,
1424+
..
1425+
} => {
1426+
let fun_name =
1427+
create_function_name(&fun.to_string(), false, args, input_schema)?;
1428+
Ok(match window_frame {
1429+
Some(window_frame) => format!("{} {}", fun_name, window_frame),
1430+
None => fun_name,
1431+
})
14061432
}
14071433
Expr::AggregateFunction {
14081434
fun,

datafusion/src/optimizer/utils.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
337337
fun: fun.clone(),
338338
args: expressions.to_vec(),
339339
}),
340-
Expr::WindowFunction { fun, .. } => {
340+
Expr::WindowFunction {
341+
fun, window_frame, ..
342+
} => {
341343
let index = expressions
342344
.iter()
343345
.position(|expr| {
@@ -353,6 +355,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
353355
fun: fun.clone(),
354356
args: expressions[..index].to_vec(),
355357
order_by: expressions[index + 1..].to_vec(),
358+
window_frame: *window_frame,
356359
})
357360
}
358361
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {

datafusion/src/physical_plan/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -617,5 +617,6 @@ pub mod udf;
617617
#[cfg(feature = "unicode_expressions")]
618618
pub mod unicode_expressions;
619619
pub mod union;
620+
pub mod window_frames;
620621
pub mod window_functions;
621622
pub mod windows;

datafusion/src/physical_plan/planner.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
//! Physical query planner
1919
20-
use std::sync::Arc;
21-
2220
use super::{
2321
aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary,
2422
functions, hash_join::PartitionMode, udaf, union::UnionExec, windows,
@@ -56,6 +54,7 @@ use arrow::datatypes::{Schema, SchemaRef};
5654
use arrow::{compute::can_cast_types, datatypes::DataType};
5755
use expressions::col;
5856
use log::debug;
57+
use std::sync::Arc;
5958

6059
/// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`].
6160
pub trait ExtensionPlanner {

0 commit comments

Comments
 (0)