From 4504276bde1fdd5efa6b84e57cdc475177bd7b9e Mon Sep 17 00:00:00 2001 From: Qingping Hou Date: Sun, 15 Aug 2021 15:20:21 -0700 Subject: [PATCH] add cross join support to ballista --- ballista/rust/core/proto/ballista.proto | 14 +++++++++++++- .../rust/core/src/serde/logical_plan/from_proto.rs | 11 +++++++++-- .../rust/core/src/serde/logical_plan/to_proto.rs | 13 ++++++++++++- .../core/src/serde/physical_plan/from_proto.rs | 7 +++++++ .../rust/core/src/serde/physical_plan/to_proto.rs | 12 ++++++++++++ datafusion/src/physical_plan/planner.rs | 5 +++-- 6 files changed, 56 insertions(+), 6 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index a1608c652dba..12a27ba878e4 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -249,9 +249,10 @@ message LogicalPlanNode { RepartitionNode repartition = 9; EmptyRelationNode empty_relation = 10; CreateExternalTableNode create_external_table = 11; - AnalyzeNode analyze = 14; ExplainNode explain = 12; WindowNode window = 13; + AnalyzeNode analyze = 14; + CrossJoinNode cross_join = 15; } } @@ -399,6 +400,11 @@ message JoinNode { repeated Column right_join_column = 6; } +message CrossJoinNode { + LogicalPlanNode left = 1; + LogicalPlanNode right = 2; +} + message LimitNode { LogicalPlanNode input = 1; uint32 limit = 2; @@ -432,6 +438,7 @@ message PhysicalPlanNode { RepartitionExecNode repartition = 16; WindowAggExecNode window = 17; ShuffleWriterExecNode shuffle_writer = 18; + CrossJoinExecNode cross_join = 19; } } @@ -593,6 +600,11 @@ message HashJoinExecNode { PartitionMode partition_mode = 6; } +message CrossJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; +} + message PhysicalColumn { string name = 1; uint32 index = 2; diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index f9761a201541..ade2cb40adb7 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -41,8 +41,6 @@ use std::{ unimplemented, }; -// use uuid::Uuid; - impl TryInto for &protobuf::LogicalPlanNode { type Error = BallistaError; @@ -290,6 +288,15 @@ impl TryInto for &protobuf::LogicalPlanNode { builder.build().map_err(|e| e.into()) } + LogicalPlanType::CrossJoin(crossjoin) => { + let left = convert_box_required!(crossjoin.left)?; + let right = convert_box_required!(crossjoin.right)?; + + LogicalPlanBuilder::from(left) + .cross_join(&right)? + .build() + .map_err(|e| e.into()) + } } } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index e1c7f53cf9cf..5877ced5f561 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -955,7 +955,18 @@ impl TryInto for &LogicalPlan { } LogicalPlan::Extension { .. } => unimplemented!(), LogicalPlan::Union { .. } => unimplemented!(), - LogicalPlan::CrossJoin { .. } => unimplemented!(), + LogicalPlan::CrossJoin { left, right, .. } => { + let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?; + let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?; + Ok(protobuf::LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CrossJoin(Box::new( + protobuf::CrossJoinNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + }, + ))), + }) + } } } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 678bcde8fa73..46815db056a1 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -51,6 +51,7 @@ use datafusion::physical_plan::window_functions::{ use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ coalesce_batches::CoalesceBatchesExec, + cross_join::CrossJoinExec, csv::CsvExec, empty::EmptyExec, expressions::{ @@ -372,6 +373,12 @@ impl TryInto> for &protobuf::PhysicalPlanNode { partition_mode, )?)) } + PhysicalPlanType::CrossJoin(crossjoin) => { + let left: Arc = convert_box_required!(crossjoin.left)?; + let right: Arc = + convert_box_required!(crossjoin.right)?; + Ok(Arc::new(CrossJoinExec::try_new(left, right)?)) + } PhysicalPlanType::ShuffleWriter(shuffle_writer) => { let input: Arc = convert_box_required!(shuffle_writer.input)?; diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 48b21345525b..8d8f917461a9 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -28,6 +28,7 @@ use std::{ use datafusion::logical_plan::JoinType; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion::physical_plan::cross_join::CrossJoinExec; use datafusion::physical_plan::csv::CsvExec; use datafusion::physical_plan::expressions::{ CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr, @@ -155,6 +156,17 @@ impl TryInto for Arc { }, ))), }) + } else if let Some(exec) = plan.downcast_ref::() { + let left: protobuf::PhysicalPlanNode = exec.left().to_owned().try_into()?; + let right: protobuf::PhysicalPlanNode = exec.right().to_owned().try_into()?; + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( + protobuf::CrossJoinExecNode { + left: Some(Box::new(left)), + right: Some(Box::new(right)), + }, + ))), + }) } else if let Some(exec) = plan.downcast_ref::() { let groups = exec .group_expr() diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 256a43b205e5..02ab15d1a652 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -19,8 +19,8 @@ use super::analyze::AnalyzeExec; use super::{ - aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary, - functions, hash_join::PartitionMode, udaf, union::UnionExec, windows, + aggregates, empty::EmptyExec, expressions::binary, functions, + hash_join::PartitionMode, udaf, union::UnionExec, windows, }; use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ @@ -29,6 +29,7 @@ use crate::logical_plan::{ UserDefinedLogicalNode, }; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; +use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions; use crate::physical_plan::expressions::{CaseExpr, Column, Literal, PhysicalSortExpr};