DataFrame owned SessionState #4633

Merged · 4 commits · Dec 17, 2022
Changes from 2 commits
2 changes: 1 addition & 1 deletion datafusion-examples/examples/custom_datasource.rs
@@ -69,7 +69,7 @@ async fn search_accounts(
)?
.build()?;

let mut dataframe = DataFrame::new(ctx.state, logical_plan)
let mut dataframe = DataFrame::new(ctx.state(), logical_plan)
.select_columns(&["id", "bank_account"])?;

if let Some(f) = filter {
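
For context, a hedged sketch of the construction pattern this example migrates to, assuming a tokio runtime; the table name, CSV path, and column are hypothetical. SessionContext::state() now hands out the session state by value, so the DataFrame owns its own SessionState instead of sharing an Arc<RwLock<SessionState>> with the context.

use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("accounts", "accounts.csv", CsvReadOptions::new())
        .await?;

    // `ctx.state()` returns the SessionState by value; the DataFrame owns that
    // copy instead of holding a lock shared with the SessionContext.
    let plan = ctx.table("accounts")?.to_logical_plan()?;
    let df = DataFrame::new(ctx.state(), plan).select_columns(&["id"])?;
    df.show().await?;
    Ok(())
}
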
58 changes: 16 additions & 42 deletions datafusion/core/src/dataframe.rs
@@ -21,7 +21,6 @@ use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
use parking_lot::RwLock;
use parquet::file::properties::WriterProperties;

use datafusion_common::{Column, DFSchema};
@@ -74,40 +73,22 @@ use crate::prelude::SessionContext;
/// ```
#[derive(Debug, Clone)]
pub struct DataFrame {
session_state: Arc<RwLock<SessionState>>,
session_state: SessionState,
plan: LogicalPlan,
}

impl DataFrame {
/// Create a new Table based on an existing logical plan
pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) -> Self {
pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
Contributor comment: ❤️

Self {
session_state,
plan,
}
}

/// Create a physical plan
pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
// this function is copied from SessionContext function of the
// same name
let state_cloned = {
let mut state = self.session_state.write();
state.execution_props.start_execution();

// We need to clone `state` to release the lock that is not `Send`. We could
// make the lock `Send` by using `tokio::sync::Mutex`, but that would require to
// propagate async even to the `LogicalPlan` building methods.
// Cloning `state` here is fine as we then pass it as immutable `&state`, which
// means that we avoid write consistency issues as the cloned version will not
// be written to. As for eventual modifications that would be applied to the
// original state after it has been cloned, they will not be picked up by the
// clone but that is okay, as it is equivalent to postponing the state update
// by keeping the lock until the end of the function scope.
state.clone()
Contributor Author comment: In practice the lock meant that we were still cloning SessionState fairly frequently, better to just be explicit about it and optimise from there
};

state_cloned.create_physical_plan(&self.plan).await
pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
self.session_state.create_physical_plan(&self.plan).await
}
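
The comment block removed above explains the old constraint: a parking_lot lock guard is not Send by default, so it cannot be held across an .await inside a future that must be Send, which is why the old code cloned the state under the lock before planning. Below is a minimal sketch of that constraint, and of how an owned state sidesteps it, assuming tokio and parking_lot as dependencies; the State type and function names are hypothetical, not DataFusion APIs.

use std::sync::Arc;

use parking_lot::RwLock;

#[derive(Clone)]
struct State {
    name: String,
}

// Old shape: clone a snapshot while holding the lock, drop the guard, then await.
async fn plan_with_shared_state(state: Arc<RwLock<State>>) -> String {
    let snapshot = {
        let guard = state.read();
        // Explicit deref so we clone the `State`, not the guard.
        (*guard).clone()
    }; // guard dropped here, before any await point
    // Holding `guard` across this await would make the future non-`Send`,
    // so it could not be handed to `tokio::spawn`.
    tokio::task::yield_now().await;
    snapshot.name
}

// New shape: the caller owns the state, so there is no lock and no hidden clone.
async fn plan_with_owned_state(state: State) -> String {
    tokio::task::yield_now().await;
    state.name
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(RwLock::new(State { name: "shared".into() }));
    println!("{}", plan_with_shared_state(shared).await);
    println!("{}", plan_with_owned_state(State { name: "owned".into() }).await);
}

With the DataFrame owning its SessionState, create_physical_plan can simply borrow self.session_state, as the new body above shows.
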

/// Filter the DataFrame by column. Returns a new DataFrame only containing the
@@ -437,8 +418,7 @@ impl DataFrame {
}

fn task_ctx(&self) -> TaskContext {
let lock = self.session_state.read();
TaskContext::from(&*lock)
TaskContext::from(&self.session_state)
}

/// Executes this DataFrame and returns a stream over a single partition
@@ -527,8 +507,7 @@ impl DataFrame {
/// Return the optimized logical plan represented by this DataFrame.
pub fn to_logical_plan(self) -> Result<LogicalPlan> {
// Optimize the plan first for better UX
let state = self.session_state.read().clone();
state.optimize(&self.plan)
self.session_state.optimize(&self.plan)
}

/// Return a DataFrame with the explanation of its plan so far.
@@ -567,9 +546,8 @@ impl DataFrame {
/// # Ok(())
/// # }
/// ```
pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
let registry = self.session_state.read().clone();
Arc::new(registry)
pub fn registry(&self) -> &dyn FunctionRegistry {
&self.session_state
}

/// Calculate the intersection of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema
@@ -621,9 +599,8 @@ impl DataFrame {

/// Write a `DataFrame` to a CSV file.
pub async fn write_csv(self, path: &str) -> Result<()> {
let state = self.session_state.read().clone();
Contributor comment: yeah all the explicit cloning is definitely a warning sign
let plan = self.create_physical_plan().await?;
plan_to_csv(&state, plan, path).await
plan_to_csv(&self.session_state, plan, path).await
}

/// Write a `DataFrame` to a Parquet file.
@@ -632,16 +609,14 @@
path: &str,
writer_properties: Option<WriterProperties>,
) -> Result<()> {
let state = self.session_state.read().clone();
let plan = self.create_physical_plan().await?;
plan_to_parquet(&state, plan, path, writer_properties).await
plan_to_parquet(&self.session_state, plan, path, writer_properties).await
}

/// Executes a query and writes the results to a partitioned JSON file.
pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
let state = self.session_state.read().clone();
let plan = self.create_physical_plan().await?;
plan_to_json(&state, plan, path).await
plan_to_json(&self.session_state, plan, path).await
}

/// Add an additional column to the DataFrame.
@@ -747,7 +722,7 @@ impl DataFrame {
/// # }
/// ```
pub async fn cache(self) -> Result<DataFrame> {
let context = SessionContext::with_state(self.session_state.read().clone());
let context = SessionContext::with_state(self.session_state.clone());
let mem_table = MemTable::try_new(
SchemaRef::from(self.schema().clone()),
self.collect_partitioned().await?,
@@ -1029,9 +1004,8 @@ mod tests {
// build query with a UDF using DataFrame API
let df = ctx.table("aggregate_test_100")?;

let f = df.registry();

let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
let df = df.select(vec![expr])?;

// build query using SQL
let sql_plan =
@@ -1088,7 +1062,7 @@ impl DataFrame {
async fn register_table() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c12"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());

// register a dataframe as a table
ctx.register_table("test_table", Arc::new(df_impl.clone()))?;
@@ -1180,7 +1154,7 @@ impl DataFrame {
async fn with_column() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
let ctx = SessionContext::new();
let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
let df_impl = DataFrame::new(ctx.state(), df.plan.clone());

let df = df_impl
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
21 changes: 12 additions & 9 deletions datafusion/core/src/execution/context.rs
@@ -273,7 +273,7 @@ impl SessionContext {
(false, true, Ok(_)) => {
self.deregister_table(&name)?;
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state.clone(), input);
let physical = DataFrame::new(self.state(), input);

let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
@@ -286,7 +286,7 @@
)),
(_, _, Err(_)) => {
let schema = Arc::new(input.schema().as_ref().into());
let physical = DataFrame::new(self.state.clone(), input);
let physical = DataFrame::new(self.state(), input);

let batches: Vec<_> = physical.collect_partitioned().await?;
let table = Arc::new(MemTable::try_new(schema, batches)?);
@@ -363,7 +363,8 @@ impl SessionContext {
LogicalPlan::SetVariable(SetVariable {
variable, value, ..
}) => {
let config_options = &self.state.write().config.config_options;
Contributor Author comment: I'm actually somewhat surprised this was compiling, there must be some magic going on to extend the lifetime of the temporary lock guard, which is wild

let state = self.state.write();
let config_options = &state.config.config_options;

let old_value =
config_options.read().get(&variable).ok_or_else(|| {
@@ -410,6 +411,8 @@ impl SessionContext {
))
}
}
drop(state);

self.return_empty_dataframe()
}

@@ -475,14 +478,14 @@
}
}

plan => Ok(DataFrame::new(self.state.clone(), plan)),
plan => Ok(DataFrame::new(self.state(), plan)),
}
}

// return an empty dataframe
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(DataFrame::new(self.state.clone(), plan))
Ok(DataFrame::new(self.state(), plan))
}

async fn create_external_table(
Expand Down Expand Up @@ -661,7 +664,7 @@ impl SessionContext {
/// Creates an empty DataFrame.
pub fn read_empty(&self) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state.clone(),
self.state(),
LogicalPlanBuilder::empty(true).build()?,
))
}
@@ -716,7 +719,7 @@ impl SessionContext {
/// Creates a [`DataFrame`] for reading a custom [`TableProvider`].
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
Ok(DataFrame::new(
self.state.clone(),
self.state(),
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
.build()?,
))
@@ -726,7 +729,7 @@
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
Ok(DataFrame::new(
self.state.clone(),
self.state(),
LogicalPlanBuilder::scan(
UNNAMED_TABLE,
provider_as_source(Arc::new(provider)),
@@ -946,7 +949,7 @@ impl SessionContext {
None,
)?
.build()?;
Ok(DataFrame::new(self.state.clone(), plan))
Ok(DataFrame::new(self.state(), plan))
}

/// Return a [`TableProvider`] for the specified table.