Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support DataFrame.collect for Ballista DataFrames #785

Merged
merged 15 commits into from
Jul 28, 2021
Merged
Prev Previous commit
Next Next commit
DataFrame.collect now works with Ballista DataFrames
  • Loading branch information
andygrove committed Jul 27, 2021
commit 79de6d5da2c0add12ef6f7ae1d199eb849d1235f
13 changes: 1 addition & 12 deletions ballista-examples/src/bin/ballista-dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,7 @@ async fn main() -> Result<()> {
.select_columns(&["id", "bool_col", "timestamp_col"])?
.filter(col("id").gt(lit(1)))?;

// execute the query - note that calling collect on the DataFrame
// trait will execute the query with DataFusion so we have to call
// collect on the BallistaContext instead and pass it the DataFusion
// logical plan
let mut stream = ctx.collect(&df.to_logical_plan()).await?;

// print the results
let mut results = vec![];
while let Some(batch) = stream.next().await {
let batch = batch?;
results.push(batch);
}
let results = df.collect().await?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is certainly nicer 👍

pretty::print_batches(&results)?;

Ok(())
Expand Down
13 changes: 1 addition & 12 deletions ballista-examples/src/bin/ballista-sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,7 @@ async fn main() -> Result<()> {
GROUP BY c1",
)?;

// execute the query - note that calling collect on the DataFrame
// trait will execute the query with DataFusion so we have to call
// collect on the BallistaContext instead and pass it the DataFusion
// logical plan
let mut stream = ctx.collect(&df.to_logical_plan()).await?;

// print the results
let mut results = vec![];
while let Some(batch) = stream.next().await {
let batch = batch?;
results.push(batch);
}
let results = df.collect().await?;
pretty::print_batches(&results)?;

Ok(())
Expand Down
19 changes: 1 addition & 18 deletions ballista/rust/client/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,16 @@
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, Mutex};

use ballista_core::config::BallistaConfig;
use ballista_core::{datasource::DfTableAdapter, utils::create_datafusion_context};

use ballista_core::execution_plans::DistributedQueryExec;
use datafusion::catalog::TableReference;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion::logical_plan::LogicalPlan;
use datafusion::physical_plan::csv::CsvReadOptions;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::{dataframe::DataFrame, physical_plan::RecordBatchStream};

struct BallistaContextState {
/// Ballista configuration
Expand Down Expand Up @@ -208,20 +205,6 @@ impl BallistaContext {
}
ctx.sql(sql)
}

pub async fn collect(
&self,
plan: &LogicalPlan,
) -> Result<Pin<Box<dyn RecordBatchStream + Send + Sync>>> {
let distributed_query = {
let state = self.state.lock().unwrap();
let scheduler_url =
format!("http://{}:{}", state.scheduler_host, state.scheduler_port);
DistributedQueryExec::new(scheduler_url, state.config.clone(), plan.clone())
};

distributed_query.execute(0).await
}
}

#[cfg(test)]
Expand Down
15 changes: 10 additions & 5 deletions ballista/rust/scheduler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ use self::state::{ConfigBackendClient, SchedulerState};
use ballista_core::config::BallistaConfig;
use ballista_core::execution_plans::ShuffleWriterExec;
use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
use ballista_core::utils::create_datafusion_context;
use datafusion::physical_plan::parquet::ParquetExec;
use datafusion::prelude::{ExecutionConfig, ExecutionContext};
use std::time::{Instant, SystemTime, UNIX_EPOCH};

#[derive(Clone)]
Expand Down Expand Up @@ -341,8 +341,7 @@ impl SchedulerGrpc for SchedulerServer {
Query::Sql(sql) => {
//TODO we can't just create a new context because we need a context that has
// tables registered from previous SQL statements that have been executed
//TODO scheduler host and port
let mut ctx = create_datafusion_context("", 50050, &config);
let mut ctx = create_datafusion_context2(&config);
let df = ctx.sql(&sql).map_err(|e| {
let msg = format!("Error parsing SQL: {}", e);
error!("{}", msg);
Expand Down Expand Up @@ -378,8 +377,7 @@ impl SchedulerGrpc for SchedulerServer {
let job_id_spawn = job_id.clone();
tokio::spawn(async move {
// create physical plan using DataFusion
//TODO scheduler url
let datafusion_ctx = create_datafusion_context("", 50050, &config);
let datafusion_ctx = create_datafusion_context2(&config);
macro_rules! fail_job {
($code :expr) => {{
match $code {
Expand Down Expand Up @@ -513,6 +511,13 @@ impl SchedulerGrpc for SchedulerServer {
}
}

/// Create a DataFusion context that is compatible with Ballista.
///
/// The returned context's concurrency (target partition count) is taken
/// from the Ballista config's default shuffle partition setting, so that
/// physical plans produced by DataFusion line up with Ballista's shuffle
/// partitioning. No scheduler connection details are configured here.
pub fn create_datafusion_context2(config: &BallistaConfig) -> ExecutionContext {
    // Derive the execution configuration from the Ballista settings before
    // constructing the context.
    let partitions = config.default_shuffle_partitions();
    let exec_config = ExecutionConfig::new().with_concurrency(partitions);
    ExecutionContext::with_config(exec_config)
}

#[cfg(all(test, feature = "sled"))]
mod test {
use std::{
Expand Down
11 changes: 2 additions & 9 deletions benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ use std::{
time::Instant,
};

use futures::StreamExt;

use ballista::context::BallistaContext;
use ballista::prelude::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS};

Expand Down Expand Up @@ -312,15 +310,10 @@ async fn benchmark_ballista(opt: BallistaBenchmarkOpt) -> Result<()> {
let df = ctx
.sql(&sql)
.map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
let mut batches = vec![];
let mut stream = ctx
.collect(&df.to_logical_plan())
let batches = df
.collect()
.await
.map_err(|e| DataFusionError::Plan(format!("{:?}", e)))?;
while let Some(result) = stream.next().await {
let batch = result?;
batches.push(batch);
}
let elapsed = start.elapsed().as_secs_f64() * 1000.0;
millis.push(elapsed as f64);
println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed);
Expand Down