Skip to content

Commit 73ea6e1

Browse files
authored
[Ballista] Support Union in ballista. (#2098)
* add union in ballista.proto * add ballista plan to proto * fix clippy * add ut * fix clippy * fix fmt * fix comment
1 parent ff110d6 commit 73ea6e1

File tree

5 files changed

+132
-1
lines changed

5 files changed

+132
-1
lines changed

ballista/rust/client/src/context.rs

+60
Original file line numberDiff line numberDiff line change
@@ -561,4 +561,64 @@ mod tests {
561561
let df = context.sql(sql).await.unwrap();
562562
assert!(!df.collect().await.unwrap().is_empty());
563563
}
564+
565+
/// Runs `UNION` and `UNION ALL` over two identical single-row selects on a
/// standalone Ballista context and checks the pretty-printed result:
/// `UNION` yields one row, `UNION ALL` yields both rows.
#[tokio::test]
#[cfg(feature = "standalone")]
async fn test_union_and_union_all() {
    use super::*;
    use ballista_core::config::{
        BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
    };
    use datafusion::assert_batches_eq;

    let config = BallistaConfigBuilder::default()
        .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
        .build()
        .unwrap();
    let context = BallistaContext::standalone(&config, 1).await.unwrap();

    // `UNION`: the two identical rows collapse into one.
    let df = context
        .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
        .await
        .unwrap();
    let res1 = df.collect().await.unwrap();
    // Data cells are padded to the width of the `number` header so every
    // line is as wide as the `+--------+` border.
    let expected1 = vec![
        "+--------+",
        "| number |",
        "+--------+",
        "| 1      |",
        "+--------+",
    ];
    assert_batches_eq!(expected1, &res1);

    // `UNION ALL`: both rows are kept.
    let df = context
        .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
        .await
        .unwrap();
    let res2 = df.collect().await.unwrap();
    let expected2 = vec![
        "+--------+",
        "| number |",
        "+--------+",
        "| 1      |",
        "| 1      |",
        "+--------+",
    ];
    assert_batches_eq!(expected2, &res2);
}
564624
}

ballista/rust/core/proto/ballista.proto

+10
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ message LogicalPlanNode {
5050
ValuesNode values = 16;
5151
LogicalExtensionNode extension = 17;
5252
CreateCatalogSchemaNode create_catalog_schema = 18;
53+
UnionNode union = 19;
5354
}
5455
}
5556

@@ -212,6 +213,10 @@ message JoinNode {
212213
bool null_equals_null = 7;
213214
}
214215

216+
// Logical-plan node for a union: wraps the list of logical plans that act as
// the union's inputs.
message UnionNode {
217+
// Input plans of the union, one entry per input.
repeated LogicalPlanNode inputs = 1;
218+
}
219+
215220
message CrossJoinNode {
216221
LogicalPlanNode left = 1;
217222
LogicalPlanNode right = 2;
@@ -253,6 +258,7 @@ message PhysicalPlanNode {
253258
CrossJoinExecNode cross_join = 19;
254259
AvroScanExecNode avro_scan = 20;
255260
PhysicalExtensionNode extension = 21;
261+
UnionExecNode union = 22;
256262
}
257263
}
258264

@@ -433,6 +439,10 @@ message HashJoinExecNode {
433439
bool null_equals_null = 7;
434440
}
435441

442+
// Physical-plan node for a union: wraps the list of physical plans that act
// as the union's inputs.
message UnionExecNode {
443+
// Input plans of the union, one entry per input.
repeated PhysicalPlanNode inputs = 1;
444+
}
445+
436446
message CrossJoinExecNode {
437447
PhysicalPlanNode left = 1;
438448
PhysicalPlanNode right = 2;

ballista/rust/core/src/serde/logical_plan/mod.rs

+36-1
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,25 @@ impl AsLogicalPlan for LogicalPlanNode {
406406

407407
builder.build().map_err(|e| e.into())
408408
}
409+
LogicalPlanType::Union(union) => {
    // Decode every input plan; the first failure aborts the conversion.
    let input_plans: Vec<LogicalPlan> = union
        .inputs
        .iter()
        .map(|i| i.try_into_logical_plan(ctx, extension_codec))
        .collect::<Result<_, BallistaError>>()?;

    // A union of fewer than two plans is rejected as malformed input.
    if input_plans.len() < 2 {
        return Err(BallistaError::General(String::from(
            "Protobuf deserialization error, Union requires at least two inputs.",
        )));
    }

    // Seed the builder with the FIRST input and append the remaining ones
    // in order, so the rebuilt union lists its inputs in the same order
    // as the serialized message.
    let mut remaining = input_plans.into_iter();
    let first = remaining
        .next()
        .expect("at least two inputs were checked above");
    let mut builder = LogicalPlanBuilder::from(first);
    for plan in remaining {
        builder = builder.union(plan)?;
    }
    builder.build().map_err(|e| e.into())
}
409428
LogicalPlanType::CrossJoin(crossjoin) => {
410429
let left = into_logical_plan!(crossjoin.left, &ctx, extension_codec)?;
411430
let right = into_logical_plan!(crossjoin.right, &ctx, extension_codec)?;
@@ -815,7 +834,23 @@ impl AsLogicalPlan for LogicalPlanNode {
815834
))),
816835
})
817836
}
818-
LogicalPlan::Union(_) => unimplemented!(),
837+
LogicalPlan::Union(union) => {
    // Serialize each input plan in order, stopping at the first failure.
    let mut inputs: Vec<LogicalPlanNode> = Vec::new();
    for child in union.inputs.iter() {
        let node = protobuf::LogicalPlanNode::try_from_logical_plan(
            child,
            extension_codec,
        )?;
        inputs.push(node);
    }

    let union_node = protobuf::UnionNode { inputs };
    Ok(protobuf::LogicalPlanNode {
        logical_plan_type: Some(LogicalPlanType::Union(union_node)),
    })
}
819854
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
820855
let left = protobuf::LogicalPlanNode::try_from_logical_plan(
821856
left.as_ref(),

ballista/rust/core/src/serde/physical_plan/mod.rs

+21
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
5252
use datafusion::physical_plan::projection::ProjectionExec;
5353
use datafusion::physical_plan::repartition::RepartitionExec;
5454
use datafusion::physical_plan::sorts::sort::SortExec;
55+
use datafusion::physical_plan::union::UnionExec;
5556
use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
5657
use datafusion::physical_plan::{
5758
AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr,
@@ -382,6 +383,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
382383
&hashjoin.null_equals_null,
383384
)?))
384385
}
386+
PhysicalPlanType::Union(union) => {
    // Decode the input plans in order, stopping at the first failure.
    // NOTE(review): an empty `inputs` list is handed to `UnionExec::new`
    // unchanged; confirm that this is acceptable for callers.
    let children = union
        .inputs
        .iter()
        .map(|child| child.try_into_physical_plan(ctx, extension_codec))
        .collect::<Result<Vec<Arc<dyn ExecutionPlan>>, _>>()?;
    Ok(Arc::new(UnionExec::new(children)))
}
385393
PhysicalPlanType::CrossJoin(crossjoin) => {
386394
let left: Arc<dyn ExecutionPlan> =
387395
into_physical_plan!(crossjoin.left, ctx, extension_codec)?;
@@ -866,6 +874,19 @@ impl AsExecutionPlan for PhysicalPlanNode {
866874
},
867875
)),
868876
})
877+
} else if let Some(union) = plan.downcast_ref::<UnionExec>() {
    // Serialize every input of the union in order; the first failure
    // aborts the conversion.
    let inputs = union
        .inputs()
        .iter()
        .map(|child| {
            protobuf::PhysicalPlanNode::try_from_physical_plan(
                child.to_owned(),
                extension_codec,
            )
        })
        .collect::<Result<Vec<PhysicalPlanNode>, _>>()?;
    let union_node = protobuf::UnionExecNode { inputs };
    Ok(protobuf::PhysicalPlanNode {
        physical_plan_type: Some(PhysicalPlanType::Union(union_node)),
    })
869890
} else {
870891
let mut buf: Vec<u8> = vec![];
871892
extension_codec.try_encode(plan_clone.clone(), &mut buf)?;

datafusion/src/physical_plan/union.rs

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ impl UnionExec {
5656
metrics: ExecutionPlanMetricsSet::new(),
5757
}
5858
}
59+
60+
/// Returns a shared reference to the input execution plans stored in this
/// `UnionExec`.
61+
pub fn inputs(&self) -> &Vec<Arc<dyn ExecutionPlan>> {
62+
&self.inputs
63+
}
5964
}
6065

6166
#[async_trait]

0 commit comments

Comments
 (0)