diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 4c0ab4244be3..4e5cc1a7a76b 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -24,21 +24,27 @@ use std::{collections::HashMap, convert::TryInto}; use std::{fs, time::Duration}; use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient; +use ballista_core::serde::protobuf::PartitionLocation; use ballista_core::serde::protobuf::{ execute_query_params::Query, job_status, ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, }; use ballista_core::{ - client::BallistaClient, datasource::DfTableAdapter, memory_stream::MemoryStream, - utils::create_datafusion_context, + client::BallistaClient, datasource::DfTableAdapter, utils::create_datafusion_context, }; use datafusion::arrow::datatypes::Schema; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::arrow::error::Result as ArrowResult; +use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::TableReference; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::csv::CsvReadOptions; use datafusion::{dataframe::DataFrame, physical_plan::RecordBatchStream}; +use futures::future; +use futures::Stream; +use futures::StreamExt; use log::{error, info}; #[allow(dead_code)] @@ -68,6 +74,32 @@ impl BallistaContextState { } } +struct WrappedStream { + stream: Pin> + Send + Sync>>, + schema: SchemaRef, +} + +impl RecordBatchStream for WrappedStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for WrappedStream { + type Item = ArrowResult; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + #[allow(dead_code)] pub struct BallistaContext { @@ -155,6 +187,29 @@ impl BallistaContext { ctx.sql(sql) } + async fn fetch_partition( + location: PartitionLocation, + ) -> Result>> { + let metadata = location.executor_meta.ok_or_else(|| { + DataFusionError::Internal("Received empty executor metadata".to_owned()) + })?; + let partition_id = location.partition_id.ok_or_else(|| { + DataFusionError::Internal("Received empty partition id".to_owned()) + })?; + let mut ballista_client = + BallistaClient::try_new(metadata.host.as_str(), metadata.port as u16) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + Ok(ballista_client + .fetch_partition( + &partition_id.job_id, + partition_id.stage_id as usize, + partition_id.partition_id as usize, + ) + .await + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?) + } + pub async fn collect( &self, plan: &LogicalPlan, @@ -222,45 +277,21 @@ impl BallistaContext { break Err(DataFusionError::Execution(msg)); } job_status::Status::Completed(completed) => { - // TODO: use streaming. Probably need to change the signature of fetch_partition to achieve that - let mut result = vec![]; - for location in completed.partition_location { - let metadata = location.executor_meta.ok_or_else(|| { - DataFusionError::Internal( - "Received empty executor metadata".to_owned(), - ) - })?; - let partition_id = location.partition_id.ok_or_else(|| { - DataFusionError::Internal( - "Received empty partition id".to_owned(), - ) - })?; - let mut ballista_client = BallistaClient::try_new( - metadata.host.as_str(), - metadata.port as u16, - ) - .await - .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; - let stream = ballista_client - .fetch_partition( - &partition_id.job_id, - partition_id.stage_id as usize, - partition_id.partition_id as usize, - ) - .await - .map_err(|e| { - DataFusionError::Execution(format!("{:?}", e)) - })?; - result.append( - &mut datafusion::physical_plan::common::collect(stream) - .await?, - ); - } - break Ok(Box::pin(MemoryStream::try_new( - result, - Arc::new(schema), - None, - )?)); + let result = future::join_all( + completed + .partition_location + .into_iter() + .map(BallistaContext::fetch_partition), + ) + .await + .into_iter() + .collect::>>()?; + + let result = WrappedStream { + stream: Box::pin(futures::stream::iter(result).flatten()), + schema: Arc::new(schema), + }; + break Ok(Box::pin(result)); } }; }