From c3bb0275cfc59770730114eada575599f8cd150e Mon Sep 17 00:00:00 2001
From: ZENOTME <43447882+ZENOTME@users.noreply.github.com>
Date: Mon, 6 Feb 2023 14:16:39 +0800
Subject: [PATCH] feat(frontend): separate plan_fragmenter into two phases
 (#7581)

To solve #7439, we need to do async operations in the plan_fragmenter. To do
this, I separate the plan_fragmenter into two phases **so that we can do
async operations in phase 2**:

phase 1: BatchPlanFragmenter.split(batch_node) -> PreStageGraph
phase 2: PreStageGraph.complete() -> StageGraph

The difference between PreStageGraph and StageGraph is that StageGraph
contains the exchange_info and parallelism. This information will be filled
in during phase 2.

Approved-By: liurenjie1024
---
 Cargo.lock                                    |   5 +-
 src/frontend/Cargo.toml                       |   1 +
 src/frontend/planner_test/src/lib.rs          |   3 +-
 src/frontend/src/handler/explain.rs           | 232 +++++++------
 src/frontend/src/handler/mod.rs               |   2 +-
 src/frontend/src/handler/query.rs             |   9 +-
 .../src/optimizer/property/distribution.rs    |  27 +-
 .../src/scheduler/distributed/query.rs        |  11 +-
 .../src/scheduler/distributed/stage.rs        |   8 +-
 src/frontend/src/scheduler/local.rs           |   6 +-
 src/frontend/src/scheduler/plan_fragmenter.rs | 320 +++++++++++++++---
 11 files changed, 431 insertions(+), 193 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 21e5ac50e7ed8..1e27f6f2c6555 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -336,9 +336,9 @@ dependencies = [
 
 [[package]]
 name = "async-recursion"
-version = "1.0.0"
+version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2cda8f4bcc10624c4e85bc66b3f452cca98cfa5ca002dc83a16aad2367641bea"
+checksum = "3b015a331cc64ebd1774ba119538573603427eaace0a1950c423ab971f903796"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -6161,6 +6161,7 @@ dependencies = [
  "arc-swap",
  "assert-impl",
  "assert_matches",
+ "async-recursion",
  "async-trait",
  "bk-tree",
  "byteorder",
diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml
index c2b4246867baa..0de7e2df10e4e 100644
--- a/src/frontend/Cargo.toml
+++ b/src/frontend/Cargo.toml
@@ -12,6 +12,7 @@ repository = { workspace = true }
 anyhow = "1"
 arc-swap = "1"
 assert-impl = "0.1"
+async-recursion = "1.0.2"
 async-trait = "0.1"
 bk-tree = "0.4.0"
 byteorder = "1.4"
diff --git a/src/frontend/planner_test/src/lib.rs b/src/frontend/planner_test/src/lib.rs
index 63ff1569bd862..10efc505c9bcf 100644
--- a/src/frontend/planner_test/src/lib.rs
+++ b/src/frontend/planner_test/src/lib.rs
@@ -441,7 +441,8 @@ impl TestCase {
             if result.is_some() {
                 panic!("two queries in one test case");
             }
-            let rsp = explain::handle_explain(handler_args, *statement, options, analyze)?;
+            let rsp =
+                explain::handle_explain(handler_args, *statement, options, analyze).await?;
             let explain_output = get_explain_output(rsp).await;
 
             let ret = TestCaseResult {
diff --git a/src/frontend/src/handler/explain.rs b/src/frontend/src/handler/explain.rs
index f0fab9e1d1754..111bc81ccfcf1 100644
--- a/src/frontend/src/handler/explain.rs
+++ b/src/frontend/src/handler/explain.rs
@@ -34,7 +34,7 @@ use crate::scheduler::BatchPlanFragmenter;
 use crate::stream_fragmenter::build_graph;
 use crate::utils::explain_stream_graph;
 
-pub fn handle_explain(
+pub async fn handle_explain(
     handler_args: HandlerArgs,
     stmt: Statement,
     options: ExplainOptions,

     let session = context.session_ctx().clone();
 
-    let plan = match stmt {
-        Statement::CreateView {
-            or_replace: false,
-            materialized: true,
-            query,
-            name,
-            columns,
-            ..
- } => gen_create_mv_plan(&session, context.into(), *query, name, columns)?.0, + let mut plan_fragmenter = None; + let mut rows = { + let plan = match stmt { + Statement::CreateView { + or_replace: false, + materialized: true, + query, + name, + columns, + .. + } => gen_create_mv_plan(&session, context.into(), *query, name, columns)?.0, - Statement::CreateSink { stmt } => gen_sink_plan(&session, context.into(), stmt)?.0, + Statement::CreateSink { stmt } => gen_sink_plan(&session, context.into(), stmt)?.0, - Statement::CreateTable { - name, - columns, - constraints, - source_schema, - .. - } => match check_create_table_with_source(&handler_args.with_options, source_schema)? { - Some(_) => { - return Err(ErrorCode::NotImplemented( - "explain create table with a connector".to_string(), - None.into(), - ) - .into()) - } - None => { - gen_create_table_plan( - context, - name, - columns, - constraints, - ColumnIdGenerator::new_initial(), - )? - .0 - } - }, + Statement::CreateTable { + name, + columns, + constraints, + source_schema, + .. + } => match check_create_table_with_source(&handler_args.with_options, source_schema)? { + Some(_) => { + return Err(ErrorCode::NotImplemented( + "explain create table with a connector".to_string(), + None.into(), + ) + .into()) + } + None => { + gen_create_table_plan( + context, + name, + columns, + constraints, + ColumnIdGenerator::new_initial(), + )? + .0 + } + }, - Statement::CreateIndex { - name, - table_name, - columns, - include, - distributed_by, - .. - } => { - gen_create_index_plan( - &session, - context.into(), + Statement::CreateIndex { name, table_name, columns, include, distributed_by, - )? - .0 - } + .. + } => { + gen_create_index_plan( + &session, + context.into(), + name, + table_name, + columns, + include, + distributed_by, + )? 
+ .0 + } - stmt => gen_batch_query_plan(&session, context.into(), stmt)?.0, - }; + stmt => gen_batch_query_plan(&session, context.into(), stmt)?.0, + }; - let ctx = plan.plan_base().ctx.clone(); - let explain_trace = ctx.is_explain_trace(); - let explain_verbose = ctx.is_explain_verbose(); + let ctx = plan.plan_base().ctx.clone(); + let explain_trace = ctx.is_explain_trace(); + let explain_verbose = ctx.is_explain_verbose(); - let mut rows = if explain_trace { - let trace = ctx.take_trace(); - trace - .iter() - .flat_map(|s| s.lines()) - .map(|s| Row::new(vec![Some(s.to_string().into())])) - .collect::>() - } else { - vec![] - }; + let mut rows = if explain_trace { + let trace = ctx.take_trace(); + trace + .iter() + .flat_map(|s| s.lines()) + .map(|s| Row::new(vec![Some(s.to_string().into())])) + .collect::>() + } else { + vec![] + }; - match options.explain_type { - ExplainType::DistSql => match plan.convention() { - Convention::Logical => unreachable!(), - Convention::Batch => { - let plan_fragmenter = BatchPlanFragmenter::new( - session.env().worker_node_manager_ref(), - session.env().catalog_reader().clone(), - ); - let query = plan_fragmenter.split(plan)?; - let stage_graph_json = serde_json::to_string_pretty(&query.stage_graph).unwrap(); - rows.extend( - vec![stage_graph_json] - .iter() - .flat_map(|s| s.lines()) - .map(|s| Row::new(vec![Some(s.to_string().into())])), - ); + match options.explain_type { + ExplainType::DistSql => match plan.convention() { + Convention::Logical => unreachable!(), + Convention::Batch => { + plan_fragmenter = Some(BatchPlanFragmenter::new( + session.env().worker_node_manager_ref(), + session.env().catalog_reader().clone(), + plan, + )?); + } + Convention::Stream => { + let graph = build_graph(plan); + rows.extend( + explain_stream_graph(&graph, explain_verbose)? + .lines() + .map(|s| Row::new(vec![Some(s.to_string().into())])), + ); + } + }, + ExplainType::Physical => { + // if explain trace is open, the plan has been in the rows + if !explain_trace { + let output = plan.explain_to_string()?; + rows.extend( + output + .lines() + .map(|s| Row::new(vec![Some(s.to_string().into())])), + ); + } } - Convention::Stream => { - let graph = build_graph(plan); - rows.extend( - explain_stream_graph(&graph, explain_verbose)? 
- .lines() - .map(|s| Row::new(vec![Some(s.to_string().into())])), - ); - } - }, - ExplainType::Physical => { - // if explain trace is open, the plan has been in the rows - if !explain_trace { - let output = plan.explain_to_string()?; - rows.extend( - output - .lines() - .map(|s| Row::new(vec![Some(s.to_string().into())])), - ); - } - } - ExplainType::Logical => { - // if explain trace is open, the plan has been in the rows - if !explain_trace { - let output = plan.ctx().take_logical().ok_or_else(|| { - ErrorCode::InternalError("Logical plan not found for query".into()) - })?; - rows.extend( - output - .lines() - .map(|s| Row::new(vec![Some(s.to_string().into())])), - ); + ExplainType::Logical => { + // if explain trace is open, the plan has been in the rows + if !explain_trace { + let output = plan.ctx().take_logical().ok_or_else(|| { + ErrorCode::InternalError("Logical plan not found for query".into()) + })?; + rows.extend( + output + .lines() + .map(|s| Row::new(vec![Some(s.to_string().into())])), + ); + } } } + rows + }; + + if let Some(plan_fragmenter) = plan_fragmenter { + let query = plan_fragmenter.generate_complete_query().await?; + let stage_graph_json = serde_json::to_string_pretty(&query.stage_graph).unwrap(); + rows.extend( + vec![stage_graph_json] + .iter() + .flat_map(|s| s.lines()) + .map(|s| Row::new(vec![Some(s.to_string().into())])), + ); } Ok(PgResponse::new_for_stream( diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 7644abe6e0089..4740d911e63ca 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -160,7 +160,7 @@ pub async fn handle( statement, analyze, options, - } => explain::handle_explain(handler_args, *statement, options, analyze), + } => explain::handle_explain(handler_args, *statement, options, analyze).await, Statement::CreateSource { stmt } => { create_source::handle_create_source(handler_args, stmt).await } diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs index 6eafe7ec3289e..361ec520859c0 100644 --- a/src/frontend/src/handler/query.rs +++ b/src/frontend/src/handler/query.rs @@ -103,7 +103,7 @@ pub async fn handle_query( let mut notice = String::new(); // Subblock to make sure PlanRef (an Rc) is dropped before `await` below. 
- let (query, query_mode, output_schema) = { + let (plan_fragmenter, query_mode, output_schema) = { let context = OptimizerContext::from_handler_args(handler_args); let (plan, query_mode, schema) = gen_batch_query_plan(&session, context.into(), stmt)?; @@ -116,11 +116,12 @@ pub async fn handle_query( let plan_fragmenter = BatchPlanFragmenter::new( session.env().worker_node_manager_ref(), session.env().catalog_reader().clone(), - ); - let query = plan_fragmenter.split(plan)?; + plan, + )?; context.append_notice(&mut notice); - (query, query_mode, schema) + (plan_fragmenter, query_mode, schema) }; + let query = plan_fragmenter.generate_complete_query().await?; tracing::trace!("Generated query after plan fragmenter: {:?}", &query); let pg_descs = output_schema diff --git a/src/frontend/src/optimizer/property/distribution.rs b/src/frontend/src/optimizer/property/distribution.rs index 3857134d713b0..cdc45c4d12328 100644 --- a/src/frontend/src/optimizer/property/distribution.rs +++ b/src/frontend/src/optimizer/property/distribution.rs @@ -57,10 +57,11 @@ use risingwave_pb::batch_plan::exchange_info::{ use risingwave_pb::batch_plan::ExchangeInfo; use super::super::plan_node::*; +use crate::catalog::catalog_service::CatalogReader; use crate::optimizer::plan_node::stream::StreamPlanRef; use crate::optimizer::property::Order; use crate::optimizer::PlanRef; -use crate::scheduler::BatchPlanFragmenter; +use crate::scheduler::worker_node_manager::WorkerNodeManagerRef; /// the distribution property provided by a operator. #[derive(Debug, Clone, PartialEq)] @@ -108,7 +109,12 @@ pub enum RequiredDist { } impl Distribution { - pub fn to_prost(&self, output_count: u32, fragmenter: &BatchPlanFragmenter) -> ExchangeInfo { + pub fn to_prost( + &self, + output_count: u32, + catalog_reader: &CatalogReader, + worker_node_manager: &WorkerNodeManagerRef, + ) -> ExchangeInfo { ExchangeInfo { mode: match self { Distribution::Single => DistributionMode::Single, @@ -139,8 +145,9 @@ impl Distribution { "hash key should not be empty, use `Single` instead" ); - let vnode_mapping = Self::get_vnode_mapping(fragmenter, table_id) - .expect("vnode_mapping of UpstreamHashShard should not be none"); + let vnode_mapping = + Self::get_vnode_mapping(catalog_reader, worker_node_manager, table_id) + .expect("vnode_mapping of UpstreamHashShard should not be none"); let pu2id_map: HashMap = vnode_mapping .iter_unique() @@ -194,18 +201,14 @@ impl Distribution { #[inline(always)] fn get_vnode_mapping( - fragmenter: &BatchPlanFragmenter, + catalog_reader: &CatalogReader, + worker_node_manager: &WorkerNodeManagerRef, table_id: &TableId, ) -> Option { - fragmenter - .catalog_reader() + catalog_reader .read_guard() .get_table_by_id(table_id) - .map(|table| { - fragmenter - .worker_node_manager() - .get_fragment_mapping(&table.fragment_id) - }) + .map(|table| worker_node_manager.get_fragment_mapping(&table.fragment_id)) .ok() .flatten() } diff --git a/src/frontend/src/scheduler/distributed/query.rs b/src/frontend/src/scheduler/distributed/query.rs index b91327db9d4c5..de3f6c000a4b6 100644 --- a/src/frontend/src/scheduler/distributed/query.rs +++ b/src/frontend/src/scheduler/distributed/query.rs @@ -546,7 +546,7 @@ pub(crate) mod tests { ), ) .into(); - let batch_exchange_node3: PlanRef = BatchExchange::new( + let batch_exchange_node: PlanRef = BatchExchange::new( hash_join_node.clone(), Order::default(), Distribution::Single, @@ -590,8 +590,13 @@ pub(crate) mod tests { catalog.write().insert_table_id_mapping(table_id, 0); let catalog_reader = 
CatalogReader::new(catalog); // Break the plan node into fragments. - let fragmenter = BatchPlanFragmenter::new(worker_node_manager, catalog_reader); - fragmenter.split(batch_exchange_node3.clone()).unwrap() + let fragmenter = BatchPlanFragmenter::new( + worker_node_manager, + catalog_reader, + batch_exchange_node.clone(), + ) + .unwrap(); + fragmenter.generate_complete_query().await.unwrap() } fn generate_parallel_units(start_id: u32, node_id: u32) -> Vec { diff --git a/src/frontend/src/scheduler/distributed/stage.rs b/src/frontend/src/scheduler/distributed/stage.rs index 7826aca1f1de7..cd7709fdfb448 100644 --- a/src/frontend/src/scheduler/distributed/stage.rs +++ b/src/frontend/src/scheduler/distributed/stage.rs @@ -167,7 +167,7 @@ impl StageExecution { catalog_reader: CatalogReader, ctx: ExecutionContextRef, ) -> Self { - let tasks = (0..stage.parallelism) + let tasks = (0..stage.parallelism.unwrap()) .map(|task_id| (task_id, TaskStatusHolder::new(task_id))) .collect(); Self { @@ -328,7 +328,7 @@ impl StageRunner { futures.push(self.schedule_task(task_id, plan_fragment, Some(worker))); } } else if let Some(source_info) = self.stage.source_info.as_ref() { - for (id, split) in source_info.split_info().iter().enumerate() { + for (id, split) in source_info.split_info().unwrap().iter().enumerate() { let task_id = TaskIdProst { query_id: self.stage.query_id.id.clone(), stage_id: self.stage.id, @@ -340,7 +340,7 @@ impl StageRunner { } } else { - for id in 0..self.stage.parallelism { + for id in 0..self.stage.parallelism.unwrap() { let task_id = TaskIdProst { query_id: self.stage.query_id.id.clone(), stage_id: self.stage.id, @@ -720,7 +720,7 @@ impl StageRunner { let plan_node_prost = self.convert_plan_node(&self.stage.root, task_id, partition, identity_id); - let exchange_info = self.stage.exchange_info.clone(); + let exchange_info = self.stage.exchange_info.clone().unwrap(); PlanFragment { root: Some(plan_node_prost), diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index b5b710288b090..170574316ca05 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -155,7 +155,7 @@ impl LocalQueryExecution { fn create_plan_fragment(&self) -> SchedulerResult { let root_stage_id = self.query.root_stage_id(); let root_stage = self.query.stage_graph.stages.get(&root_stage_id).unwrap(); - assert_eq!(root_stage.parallelism, 1); + assert_eq!(root_stage.parallelism.unwrap(), 1); let second_stage_id = self.query.stage_graph.get_child_stages(&root_stage_id); let plan_node_prost = match second_stage_id { None => { @@ -269,7 +269,7 @@ impl LocalQueryExecution { sources.push(exchange_source); } } else if let Some(source_info) = &second_stage.source_info { - for (id,split) in source_info.split_info().iter().enumerate() { + for (id,split) in source_info.split_info().unwrap().iter().enumerate() { let second_stage_plan_node = self.convert_plan_node( &second_stage.root, &mut None, @@ -319,7 +319,7 @@ impl LocalQueryExecution { epoch: Some(self.snapshot.get_batch_query_epoch()), }; - let workers = if second_stage.parallelism == 1 { + let workers = if second_stage.parallelism.unwrap() == 1 { vec![self.front_env.worker_node_manager().next_random()?] 
} else { self.front_env.worker_node_manager().list_worker_nodes() diff --git a/src/frontend/src/scheduler/plan_fragmenter.rs b/src/frontend/src/scheduler/plan_fragmenter.rs index 32eb2c31e1752..265a88ae558cb 100644 --- a/src/frontend/src/scheduler/plan_fragmenter.rs +++ b/src/frontend/src/scheduler/plan_fragmenter.rs @@ -17,8 +17,8 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; use anyhow::anyhow; +use async_recursion::async_recursion; use enum_as_inner::EnumAsInner; -use futures::executor::block_on; use itertools::Itertools; use risingwave_common::buffer::{Bitmap, BitmapBuilder}; use risingwave_common::catalog::TableDesc; @@ -116,10 +116,12 @@ impl ExecutionPlanNode { /// `BatchPlanFragmenter` splits a query plan into fragments. pub struct BatchPlanFragmenter { query_id: QueryId, - stage_graph_builder: StageGraphBuilder, next_stage_id: StageId, worker_node_manager: WorkerNodeManagerRef, catalog_reader: CatalogReader, + + stage_graph_builder: Option, + stage_graph: Option, } impl Default for QueryId { @@ -131,14 +133,36 @@ impl Default for QueryId { } impl BatchPlanFragmenter { - pub fn new(worker_node_manager: WorkerNodeManagerRef, catalog_reader: CatalogReader) -> Self { - Self { + pub fn new( + worker_node_manager: WorkerNodeManagerRef, + catalog_reader: CatalogReader, + batch_node: PlanRef, + ) -> SchedulerResult { + let mut plan_fragmenter = Self { query_id: Default::default(), - stage_graph_builder: StageGraphBuilder::new(), + stage_graph_builder: Some(StageGraphBuilder::new()), next_stage_id: 0, worker_node_manager, catalog_reader, - } + stage_graph: None, + }; + plan_fragmenter.split_into_stage(batch_node)?; + Ok(plan_fragmenter) + } + + /// Split the plan node into each stages, based on exchange node. + fn split_into_stage(&mut self, batch_node: PlanRef) -> SchedulerResult<()> { + let root_stage = self.new_stage( + batch_node, + Some(Distribution::Single.to_prost(1, &self.catalog_reader, &self.worker_node_manager)), + )?; + self.stage_graph = Some( + self.stage_graph_builder + .take() + .unwrap() + .build(root_stage.id), + ); + Ok(()) } } @@ -200,19 +224,57 @@ impl Query { } } +#[derive(Debug, Clone)] +pub struct SourceFetchInfo { + pub connector: ConnectorProperties, + pub timebound: (Option, Option), +} + #[derive(Clone, Debug)] -pub struct SourceScanInfo { +pub enum SourceScanInfo { /// Split Info - split_info: Vec, + Incomplete(SourceFetchInfo), + Complete(Vec), } impl SourceScanInfo { - pub fn new(split_info: Vec) -> Self { - Self { split_info } + pub fn new(fetch_info: SourceFetchInfo) -> Self { + Self::Incomplete(fetch_info) + } + + pub async fn complete(self) -> SchedulerResult { + let fetch_info = match self { + SourceScanInfo::Incomplete(fetch_info) => fetch_info, + SourceScanInfo::Complete(_) => { + unreachable!("Never call complete when SourceScanInfo is already complete") + } + }; + let mut enumerator = SplitEnumeratorImpl::create(fetch_info.connector).await?; + let kafka_enumerator = match enumerator { + SplitEnumeratorImpl::Kafka(ref mut kafka_enumerator) => kafka_enumerator, + _ => { + return Err(SchedulerError::Internal(anyhow!( + "Unsupported to query directly from this source" + ))) + } + }; + let split_info = kafka_enumerator + .list_splits_batch(fetch_info.timebound.0, fetch_info.timebound.1) + .await? 
+ .into_iter() + .map(SplitImpl::Kafka) + .collect_vec(); + + Ok(SourceScanInfo::Complete(split_info)) } - pub fn split_info(&self) -> &Vec { - &self.split_info + pub fn split_info(&self) -> SchedulerResult<&Vec> { + match self { + Self::Incomplete(_) => Err(SchedulerError::Internal(anyhow!( + "Should not get split info from incomplete source scan info" + ))), + Self::Complete(split_info) => Ok(split_info), + } } } @@ -270,16 +332,20 @@ pub enum PartitionInfo { } /// Fragment part of `Query`. +#[derive(Clone)] pub struct QueryStage { pub query_id: QueryId, pub id: StageId, pub root: Arc, - pub exchange_info: ExchangeInfo, - pub parallelism: u32, + pub exchange_info: Option, + pub parallelism: Option, /// Indicates whether this stage contains a table scan node and the table's information if so. pub table_scan_info: Option, pub source_info: Option, pub has_lookup_join: bool, + + /// Used to generage exchange information when complete source scan information. + children_exhchange_distribution: Option>, } impl QueryStage { @@ -295,6 +361,48 @@ impl QueryStage { pub fn has_lookup_join(&self) -> bool { self.has_lookup_join } + + pub fn clone_with_exchange_info(&self, exchange_info: Option) -> Self { + if let Some(exchange_info) = exchange_info { + return Self { + query_id: self.query_id.clone(), + id: self.id, + root: self.root.clone(), + exchange_info: Some(exchange_info), + parallelism: self.parallelism, + table_scan_info: self.table_scan_info.clone(), + source_info: self.source_info.clone(), + has_lookup_join: self.has_lookup_join, + children_exhchange_distribution: self.children_exhchange_distribution.clone(), + }; + } + self.clone() + } + + pub fn clone_with_exchange_info_and_complete_source_info( + &self, + exchange_info: Option, + source_info: SourceScanInfo, + ) -> Self { + assert!(matches!(source_info, SourceScanInfo::Complete(_))); + let exchange_info = if let Some(exchange_info) = exchange_info { + Some(exchange_info) + } else { + self.exchange_info.clone() + }; + + Self { + query_id: self.query_id.clone(), + id: self.id, + root: self.root.clone(), + exchange_info, + parallelism: Some(source_info.split_info().unwrap().len() as u32), + table_scan_info: self.table_scan_info.clone(), + source_info: Some(source_info), + has_lookup_join: self.has_lookup_join, + children_exhchange_distribution: None, + } + } } impl Debug for QueryStage { @@ -327,22 +435,24 @@ struct QueryStageBuilder { query_id: QueryId, id: StageId, root: Option>, - parallelism: u32, - exchange_info: ExchangeInfo, + parallelism: Option, + exchange_info: Option, children_stages: Vec, /// See also [`QueryStage::table_scan_info`]. 
table_scan_info: Option, source_info: Option, has_lookup_join: bool, + + children_exhchange_distribution: HashMap, } impl QueryStageBuilder { fn new( id: StageId, query_id: QueryId, - parallelism: u32, - exchange_info: ExchangeInfo, + parallelism: Option, + exchange_info: Option, table_scan_info: Option, source_info: Option, has_lookup_join: bool, @@ -357,10 +467,16 @@ impl QueryStageBuilder { table_scan_info, source_info, has_lookup_join, + children_exhchange_distribution: HashMap::new(), } } fn finish(self, stage_graph_builder: &mut StageGraphBuilder) -> QueryStageRef { + let children_exhchange_distribution = if self.parallelism.is_none() { + Some(self.children_exhchange_distribution) + } else { + None + }; let stage = Arc::new(QueryStage { query_id: self.query_id, id: self.id, @@ -370,6 +486,7 @@ impl QueryStageBuilder { table_scan_info: self.table_scan_info, source_info: self.source_info, has_lookup_join: self.has_lookup_join, + children_exhchange_distribution, }); stage_graph_builder.add_node(stage.clone()); @@ -418,6 +535,95 @@ impl StageGraph { ret.into_iter().rev() } + + async fn complete( + self, + catalog_reader: &CatalogReader, + worker_node_manager: &WorkerNodeManagerRef, + ) -> SchedulerResult { + let mut complete_stages = HashMap::new(); + self.complete_stage( + self.stages.get(&self.root_stage_id).unwrap().clone(), + None, + &mut complete_stages, + catalog_reader, + worker_node_manager, + ) + .await?; + Ok(StageGraph { + root_stage_id: self.root_stage_id, + stages: complete_stages, + child_edges: self.child_edges, + parent_edges: self.parent_edges, + }) + } + + #[async_recursion] + async fn complete_stage( + &self, + stage: QueryStageRef, + exchange_info: Option, + complete_stages: &mut HashMap, + catalog_reader: &CatalogReader, + worker_node_manager: &WorkerNodeManagerRef, + ) -> SchedulerResult<()> { + let parallelism = if stage.parallelism.is_some() { + // If the stage has parallelism, it means it's a complete stage. + complete_stages.insert( + stage.id, + Arc::new(stage.clone_with_exchange_info(exchange_info)), + ); + None + } else { + assert!(matches!( + stage.source_info, + Some(SourceScanInfo::Incomplete(_)) + )); + let complete_source_info = stage + .source_info + .as_ref() + .unwrap() + .clone() + .complete() + .await?; + + let complete_stage = Arc::new(stage.clone_with_exchange_info_and_complete_source_info( + exchange_info, + complete_source_info, + )); + let parallelism = complete_stage.parallelism; + complete_stages.insert(stage.id, complete_stage); + parallelism + }; + + for child_stage_id in self.child_edges.get(&stage.id).unwrap_or(&HashSet::new()) { + let exchange_info = if let Some(parallelism) = parallelism { + let exchange_distribution = stage + .children_exhchange_distribution + .as_ref() + .unwrap() + .get(child_stage_id) + .expect("Exchange distribution is not consistent with the stage graph"); + Some(exchange_distribution.to_prost( + parallelism, + catalog_reader, + worker_node_manager, + )) + } else { + None + }; + self.complete_stage( + self.stages.get(child_stage_id).unwrap().clone(), + exchange_info, + complete_stages, + catalog_reader, + worker_node_manager, + ) + .await?; + } + + Ok(()) + } } struct StageGraphBuilder { @@ -466,20 +672,26 @@ impl StageGraphBuilder { } impl BatchPlanFragmenter { - /// Split the plan node into each stages, based on exchange node. 
- pub fn split(mut self, batch_node: PlanRef) -> SchedulerResult { - let root_stage = self.new_stage(batch_node, Distribution::Single.to_prost(1, &self))?; - let stage_graph = self.stage_graph_builder.build(root_stage.id); + /// After split, the `stage_graph` in the framenter may has the stage with incomplete source + /// info, we need to fetch the source info to complete the stage in this function. + /// Why separate this two step(`split()` and `generate_complete_query()`)? + /// The step of fetching source info is a async operation so that we can't do it in the split + /// step. + pub async fn generate_complete_query(self) -> SchedulerResult { + let stage_graph = self.stage_graph.unwrap(); + let new_stage_graph = stage_graph + .complete(&self.catalog_reader, &self.worker_node_manager) + .await?; Ok(Query { - stage_graph, query_id: self.query_id, + stage_graph: new_stage_graph, }) } fn new_stage( &mut self, root: PlanRef, - exchange_info: ExchangeInfo, + exchange_info: Option, ) -> SchedulerResult { let next_stage_id = self.next_stage_id; self.next_stage_id += 1; @@ -518,8 +730,10 @@ impl BatchPlanFragmenter { } else { // System table } - } else { - // No table scan + } else if source_info.is_some() { + return Err(SchedulerError::Internal(anyhow!( + "The stage has single distribution, but contains a source operator" + ))); } 1 } @@ -535,18 +749,23 @@ impl BatchPlanFragmenter { { has_lookup_join = true; lookup_join_parallelism - } else if let Some(source_info) = &source_info { - source_info.split_info().len() + } else if source_info.is_some() { + 0 } else { self.worker_node_manager.worker_node_count() } } }; + let parallelism = if parallelism == 0 { + None + } else { + Some(parallelism as u32) + }; let mut builder = QueryStageBuilder::new( next_stage_id, self.query_id.clone(), - parallelism as u32, + parallelism, exchange_info, table_scan_info, source_info, @@ -555,7 +774,7 @@ impl BatchPlanFragmenter { self.visit_node(root, &mut builder, None)?; - Ok(builder.finish(&mut self.stage_graph_builder)) + Ok(builder.finish(self.stage_graph_builder.as_mut().unwrap())) } fn visit_node( @@ -592,9 +811,22 @@ impl BatchPlanFragmenter { parent_exec_node: Option<&mut ExecutionPlanNode>, ) -> SchedulerResult<()> { let mut execution_plan_node = ExecutionPlanNode::from(node.clone()); - let child_exchange_info = node.distribution().to_prost(builder.parallelism, self); + let child_exchange_info = if let Some(parallelism) = builder.parallelism { + Some(node.distribution().to_prost( + parallelism, + &self.catalog_reader, + &self.worker_node_manager, + )) + } else { + None + }; let child_stage = self.new_stage(node.inputs()[0].clone(), child_exchange_info)?; execution_plan_node.source_stage_id = Some(child_stage.id); + if builder.parallelism.is_none() { + builder + .children_exhchange_distribution + .insert(child_stage.id, node.distribution().clone()); + } if let Some(parent) = parent_exec_node { parent.children.push(Arc::new(execution_plan_node)); @@ -620,25 +852,11 @@ impl BatchPlanFragmenter { let source_catalog = source_node.logical().source_catalog(); if let Some(source_catalog) = source_catalog { let property = ConnectorProperties::extract(source_catalog.properties.clone())?; - let mut enumerator = block_on(SplitEnumeratorImpl::create(property))?; - let kafka_enumerator = match enumerator { - SplitEnumeratorImpl::Kafka(ref mut kafka_enumerator) => kafka_enumerator, - _ => { - return Err(SchedulerError::Internal(anyhow!( - "Unsupported to query directly from this source" - ))) - } - }; let timestamp_bound 
= source_node.logical().kafka_timestamp_range_value(); - // println!("Timestamp bound: {:?}", timestamp_bound); - let split_info = block_on( - kafka_enumerator.list_splits_batch(timestamp_bound.0, timestamp_bound.1), - )? - .into_iter() - .map(SplitImpl::Kafka) - .collect_vec(); - // println!("Split info: {:?}", split_info); - return Ok(Some(SourceScanInfo::new(split_info))); + return Ok(Some(SourceScanInfo::new(SourceFetchInfo { + connector: property, + timebound: timestamp_bound, + }))); } } @@ -861,12 +1079,12 @@ mod tests { assert_eq!(root_exchange.root.node_type(), PlanNodeType::BatchExchange); assert_eq!(root_exchange.root.source_stage_id, Some(1)); assert!(matches!(root_exchange.root.node, NodeBody::Exchange(_))); - assert_eq!(root_exchange.parallelism, 1); + assert_eq!(root_exchange.parallelism, Some(1)); assert!(!root_exchange.has_table_scan()); let join_node = query.stage_graph.stages.get(&1).unwrap(); assert_eq!(join_node.root.node_type(), PlanNodeType::BatchHashJoin); - assert_eq!(join_node.parallelism, 3); + assert_eq!(join_node.parallelism, Some(3)); assert!(matches!(join_node.root.node, NodeBody::HashJoin(_))); assert_eq!(join_node.root.source_stage_id, None);
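
The two-phase flow introduced above can be summarized with a minimal usage
sketch, mirroring `handle_query` in src/frontend/src/handler/query.rs (names
as in this patch; `session` and `plan` are assumed to be in scope, error
handling elided):

    // Phase 1 (synchronous): constructing the fragmenter splits the batch
    // plan into a stage graph. Stages that scan a source may still carry
    // SourceScanInfo::Incomplete and have `parallelism: None` at this point.
    let plan_fragmenter = BatchPlanFragmenter::new(
        session.env().worker_node_manager_ref(),
        session.env().catalog_reader().clone(),
        plan,
    )?;

    // Phase 2 (asynchronous): fetch the source split info (e.g. Kafka
    // splits), derive the stage parallelism from the number of splits, and
    // fill in the remaining `exchange_info`, yielding the final `Query`.
    let query = plan_fragmenter.generate_complete_query().await?;

This is also why `QueryStage::parallelism` and `QueryStage::exchange_info`
become `Option`s in this patch: they may be `None` after phase 1 for stages
whose parallelism depends on the source splits, and are expected to be set
once phase 2 completes, which is when the distributed and local schedulers
`unwrap()` them.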