diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs
index 4a5fe6d30bfe..8db9a0c05821 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -561,4 +561,52 @@ mod tests {
         let df = context.sql(sql).await.unwrap();
         assert!(!df.collect().await.unwrap().is_empty());
     }
+
+    /// Runs `UNION` and `UNION ALL` through a standalone Ballista context and
+    /// checks that the former yields one row and the latter yields two.
+    #[tokio::test]
+    #[cfg(feature = "standalone")]
+    async fn test_union_and_union_all() {
+        use super::*;
+        use ballista_core::config::{
+            BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
+        };
+        use datafusion::assert_batches_eq;
+        let config = BallistaConfigBuilder::default()
+            .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
+            .build()
+            .unwrap();
+        let context = BallistaContext::standalone(&config, 1).await.unwrap();
+
+        // UNION: the two identical rows collapse into one.
+        let df = context
+            .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
+            .await
+            .unwrap();
+        let res1 = df.collect().await.unwrap();
+        let expected1 = vec![
+            "+--------+",
+            "| number |",
+            "+--------+",
+            "| 1      |",
+            "+--------+",
+        ];
+        assert_batches_eq!(expected1, &res1);
+
+        // UNION ALL: both rows are kept.
+        let expected2 = vec![
+            "+--------+",
+            "| number |",
+            "+--------+",
+            "| 1      |",
+            "| 1      |",
+            "+--------+",
+        ];
+        let df = context
+            .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
+            .await
+            .unwrap();
+        let res2 = df.collect().await.unwrap();
+        assert_batches_eq!(expected2, &res2);
+    }
 }
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 5bb12890ccc8..4b493105d933 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -50,6 +50,7 @@ message LogicalPlanNode {
     ValuesNode values = 16;
     LogicalExtensionNode extension = 17;
     CreateCatalogSchemaNode create_catalog_schema = 18;
+    UnionNode union = 19;
   }
 }
 
@@ -212,6 +213,11 @@ message JoinNode {
   bool null_equals_null = 7;
 }
 
+// Logical union of the plans listed in `inputs`.
+message UnionNode {
+  repeated LogicalPlanNode inputs = 1;
+}
+
 message CrossJoinNode {
   LogicalPlanNode left = 1;
   LogicalPlanNode right = 2;
@@ -253,6 +259,7 @@ message PhysicalPlanNode {
     CrossJoinExecNode cross_join = 19;
     AvroScanExecNode avro_scan = 20;
     PhysicalExtensionNode extension = 21;
+    UnionExecNode union = 22;
   }
 }
 
@@ -433,6 +440,11 @@ message HashJoinExecNode {
   bool null_equals_null = 7;
 }
 
+// Physical union of the plans listed in `inputs`.
+message UnionExecNode {
+  repeated PhysicalPlanNode inputs = 1;
+}
+
 message CrossJoinExecNode {
   PhysicalPlanNode left = 1;
   PhysicalPlanNode right = 2;
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index 82e47874fae8..4198f452909a 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -406,6 +406,27 @@ impl AsLogicalPlan for LogicalPlanNode {
 
                 builder.build().map_err(|e| e.into())
             }
+            LogicalPlanType::Union(union) => {
+                let mut input_plans: Vec<LogicalPlan> = union
+                    .inputs
+                    .iter()
+                    .map(|i| i.try_into_logical_plan(ctx, extension_codec))
+                    .collect::<Result<_, _>>()?;
+
+                if input_plans.len() < 2 {
+                    return Err(BallistaError::General(String::from(
+                        "Protobuf deserialization error, Union requires at least two inputs.",
+                    )));
+                }
+
+                // Seed the builder with the FIRST input and union the rest in order;
+                // seeding with `pop()` would start from the last serialized input.
+                let mut builder = LogicalPlanBuilder::from(input_plans.remove(0));
+                for plan in input_plans {
+                    builder = builder.union(plan)?;
+                }
+                builder.build().map_err(|e| e.into())
+            }
             LogicalPlanType::CrossJoin(crossjoin) => {
                 let left = into_logical_plan!(crossjoin.left, &ctx, extension_codec)?;
                 let right = into_logical_plan!(crossjoin.right, &ctx, extension_codec)?;
@@ -815,7 +836,24 @@ impl AsLogicalPlan for LogicalPlanNode {
                     ))),
                 })
             }
-            LogicalPlan::Union(_) => unimplemented!(),
+            LogicalPlan::Union(union) => {
+                // Encode every input plan in order; the first failure is returned.
+                let inputs: Vec<protobuf::LogicalPlanNode> = union
+                    .inputs
+                    .iter()
+                    .map(|i| {
+                        protobuf::LogicalPlanNode::try_from_logical_plan(
+                            i,
+                            extension_codec,
+                        )
+                    })
+                    .collect::<Result<_, _>>()?;
+                Ok(protobuf::LogicalPlanNode {
+                    logical_plan_type: Some(LogicalPlanType::Union(
+                        protobuf::UnionNode { inputs },
+                    )),
+                })
+            }
             LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
                 let left = protobuf::LogicalPlanNode::try_from_logical_plan(
                     left.as_ref(),
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index e7d803d54dd9..2f9ecda7a517 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -52,6 +52,7 @@ use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
+use datafusion::physical_plan::union::UnionExec;
 use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
 use datafusion::physical_plan::{
     AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
@@ -382,6 +383,14 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     &hashjoin.null_equals_null,
                 )?))
             }
+            PhysicalPlanType::Union(union) => {
+                // Decode each input plan in order, returning the first error.
+                let mut inputs: Vec<Arc<dyn ExecutionPlan>> = vec![];
+                for input in &union.inputs {
+                    inputs.push(input.try_into_physical_plan(ctx, extension_codec)?);
+                }
+                Ok(Arc::new(UnionExec::new(inputs)))
+            }
             PhysicalPlanType::CrossJoin(crossjoin) => {
                 let left: Arc<dyn ExecutionPlan> =
                     into_physical_plan!(crossjoin.left, ctx, extension_codec)?;
@@ -866,6 +875,20 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     },
                 )),
             })
+        } else if let Some(union) = plan.downcast_ref::<UnionExec>() {
+            // Encode each input plan in order, returning the first error.
+            let mut inputs: Vec<protobuf::PhysicalPlanNode> = vec![];
+            for input in union.inputs() {
+                inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan(
+                    input.to_owned(),
+                    extension_codec,
+                )?);
+            }
+            Ok(protobuf::PhysicalPlanNode {
+                physical_plan_type: Some(PhysicalPlanType::Union(
+                    protobuf::UnionExecNode { inputs },
+                )),
+            })
         } else {
             let mut buf: Vec<u8> = vec![];
             extension_codec.try_encode(plan_clone.clone(), &mut buf)?;
diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs
index fb25cf30e8dc..bf6fb7ce6623 100644
--- a/datafusion/src/physical_plan/union.rs
+++ b/datafusion/src/physical_plan/union.rs
@@ -56,6 +56,11 @@ impl UnionExec {
             metrics: ExecutionPlanMetricsSet::new(),
         }
     }
+
+    /// Returns a reference to the input execution plans of this union.
+    pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
+        &self.inputs
+    }
 }
 
 #[async_trait]