diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index a30e7b323aba..12ea8f863be3 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -806,11 +806,15 @@ pub fn parse_expr( .ok_or_else(|| Error::unknown("BuiltInWindowFunction", *i))? .into(); + let args = parse_optional_expr(&expr.expr, registry)? + .map(|e| vec![e]) + .unwrap_or_else(Vec::new); + Ok(Expr::WindowFunction { fun: datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction( built_in_function, ), - args: vec![parse_required_expr(&expr.expr, registry, "expr")?], + args, partition_by, order_by, window_frame, @@ -1234,16 +1238,14 @@ impl TryFrom for WindowFrameBound { })?; match bound_type { protobuf::WindowFrameBoundType::CurrentRow => Ok(Self::CurrentRow), - protobuf::WindowFrameBoundType::Preceding => { - // FIXME implement bound value parsing - // https://github.com/apache/arrow-datafusion/issues/361 - Ok(Self::Preceding(ScalarValue::UInt64(Some(1)))) - } - protobuf::WindowFrameBoundType::Following => { - // FIXME implement bound value parsing - // https://github.com/apache/arrow-datafusion/issues/361 - Ok(Self::Following(ScalarValue::UInt64(Some(1)))) - } + protobuf::WindowFrameBoundType::Preceding => match bound.bound_value { + Some(x) => Ok(Self::Preceding(ScalarValue::try_from(&x)?)), + None => Ok(Self::Preceding(ScalarValue::UInt64(None))), + }, + protobuf::WindowFrameBoundType::Following => match bound.bound_value { + Some(x) => Ok(Self::Following(ScalarValue::try_from(&x)?)), + None => Ok(Self::Following(ScalarValue::UInt64(None))), + }, } } } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index bf4b777ffab3..12c2a5e784a6 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -70,7 +70,6 @@ mod roundtrip_tests { }; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; - use datafusion_expr::create_udaf; use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ @@ -78,6 +77,9 @@ mod roundtrip_tests { BuiltinScalarFunction::{Sqrt, Substr}, Expr, LogicalPlan, Operator, Volatility, }; + use datafusion_expr::{ + create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + }; use prost::Message; use std::any::Any; use std::collections::HashMap; @@ -1331,4 +1333,67 @@ mod roundtrip_tests { roundtrip_expr_test(test_expr, ctx.clone()); roundtrip_expr_test(test_expr_with_count, ctx); } + #[test] + fn roundtrip_window() { + let ctx = SessionContext::new(); + + // 1. without window_frame + let test_expr1 = Expr::WindowFunction { + fun: WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + args: vec![], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: None, + }; + + // 2. with default window_frame + let test_expr2 = Expr::WindowFunction { + fun: WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + args: vec![], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: Some(WindowFrame::default()), + }; + + // 3. with window_frame with row numbers + let range_number_frame = WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }; + + let test_expr3 = Expr::WindowFunction { + fun: WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + args: vec![], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: Some(range_number_frame), + }; + + // 4. test with AggregateFunction + let row_number_frame = WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }; + + let test_expr4 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("col1")], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: Some(row_number_frame), + }; + + roundtrip_expr_test(test_expr1, ctx.clone()); + roundtrip_expr_test(test_expr2, ctx.clone()); + roundtrip_expr_test(test_expr3, ctx.clone()); + roundtrip_expr_test(test_expr4, ctx); + } }