-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DataFrame owned SessionState #4633
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,6 @@ use std::any::Any; | |
use std::sync::Arc; | ||
|
||
use async_trait::async_trait; | ||
use parking_lot::RwLock; | ||
use parquet::file::properties::WriterProperties; | ||
|
||
use datafusion_common::{Column, DFSchema}; | ||
|
@@ -74,40 +73,22 @@ use crate::prelude::SessionContext; | |
/// ``` | ||
#[derive(Debug, Clone)] | ||
pub struct DataFrame { | ||
session_state: Arc<RwLock<SessionState>>, | ||
session_state: SessionState, | ||
plan: LogicalPlan, | ||
} | ||
|
||
impl DataFrame { | ||
/// Create a new Table based on an existing logical plan | ||
pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) -> Self { | ||
pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self { | ||
Self { | ||
session_state, | ||
plan, | ||
} | ||
} | ||
|
||
/// Create a physical plan | ||
pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> { | ||
// this function is copied from SessionContext function of the | ||
// same name | ||
let state_cloned = { | ||
let mut state = self.session_state.write(); | ||
state.execution_props.start_execution(); | ||
|
||
// We need to clone `state` to release the lock that is not `Send`. We could | ||
// make the lock `Send` by using `tokio::sync::Mutex`, but that would require to | ||
// propagate async even to the `LogicalPlan` building methods. | ||
// Cloning `state` here is fine as we then pass it as immutable `&state`, which | ||
// means that we avoid write consistency issues as the cloned version will not | ||
// be written to. As for eventual modifications that would be applied to the | ||
// original state after it has been cloned, they will not be picked up by the | ||
// clone but that is okay, as it is equivalent to postponing the state update | ||
// by keeping the lock until the end of the function scope. | ||
state.clone() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In practice the lock meant that we were still cloning SessionState fairly frequently, better to just be explicit about it and optimise from there |
||
}; | ||
|
||
state_cloned.create_physical_plan(&self.plan).await | ||
pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> { | ||
self.session_state.create_physical_plan(&self.plan).await | ||
} | ||
|
||
/// Filter the DataFrame by column. Returns a new DataFrame only containing the | ||
|
@@ -437,8 +418,7 @@ impl DataFrame { | |
} | ||
|
||
fn task_ctx(&self) -> TaskContext { | ||
let lock = self.session_state.read(); | ||
TaskContext::from(&*lock) | ||
TaskContext::from(&self.session_state) | ||
} | ||
|
||
/// Executes this DataFrame and returns a stream over a single partition | ||
|
@@ -527,8 +507,7 @@ impl DataFrame { | |
/// Return the optimized logical plan represented by this DataFrame. | ||
pub fn to_logical_plan(self) -> Result<LogicalPlan> { | ||
// Optimize the plan first for better UX | ||
let state = self.session_state.read().clone(); | ||
state.optimize(&self.plan) | ||
self.session_state.optimize(&self.plan) | ||
} | ||
|
||
/// Return a DataFrame with the explanation of its plan so far. | ||
|
@@ -567,9 +546,8 @@ impl DataFrame { | |
/// # Ok(()) | ||
/// # } | ||
/// ``` | ||
pub fn registry(&self) -> Arc<dyn FunctionRegistry> { | ||
let registry = self.session_state.read().clone(); | ||
Arc::new(registry) | ||
pub fn registry(&self) -> &dyn FunctionRegistry { | ||
&self.session_state | ||
} | ||
|
||
/// Calculate the intersection of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema | ||
|
@@ -621,9 +599,8 @@ impl DataFrame { | |
|
||
/// Write a `DataFrame` to a CSV file. | ||
pub async fn write_csv(self, path: &str) -> Result<()> { | ||
let state = self.session_state.read().clone(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah all the explicit cloning is definitely a warning sign |
||
let plan = self.create_physical_plan().await?; | ||
plan_to_csv(&state, plan, path).await | ||
plan_to_csv(&self.session_state, plan, path).await | ||
} | ||
|
||
/// Write a `DataFrame` to a Parquet file. | ||
|
@@ -632,16 +609,14 @@ impl DataFrame { | |
path: &str, | ||
writer_properties: Option<WriterProperties>, | ||
) -> Result<()> { | ||
let state = self.session_state.read().clone(); | ||
let plan = self.create_physical_plan().await?; | ||
plan_to_parquet(&state, plan, path, writer_properties).await | ||
plan_to_parquet(&self.session_state, plan, path, writer_properties).await | ||
} | ||
|
||
/// Executes a query and writes the results to a partitioned JSON file. | ||
pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> { | ||
let state = self.session_state.read().clone(); | ||
let plan = self.create_physical_plan().await?; | ||
plan_to_json(&state, plan, path).await | ||
plan_to_json(&self.session_state, plan, path).await | ||
} | ||
|
||
/// Add an additional column to the DataFrame. | ||
|
@@ -747,7 +722,7 @@ impl DataFrame { | |
/// # } | ||
/// ``` | ||
pub async fn cache(self) -> Result<DataFrame> { | ||
let context = SessionContext::with_state(self.session_state.read().clone()); | ||
let context = SessionContext::with_state(self.session_state.clone()); | ||
let mem_table = MemTable::try_new( | ||
SchemaRef::from(self.schema().clone()), | ||
self.collect_partitioned().await?, | ||
|
@@ -1029,9 +1004,8 @@ mod tests { | |
// build query with a UDF using DataFrame API | ||
let df = ctx.table("aggregate_test_100")?; | ||
|
||
let f = df.registry(); | ||
|
||
let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?; | ||
let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]); | ||
let df = df.select(vec![expr])?; | ||
|
||
// build query using SQL | ||
let sql_plan = | ||
|
@@ -1088,7 +1062,7 @@ mod tests { | |
async fn register_table() -> Result<()> { | ||
let df = test_table().await?.select_columns(&["c1", "c12"])?; | ||
let ctx = SessionContext::new(); | ||
let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone()); | ||
let df_impl = DataFrame::new(ctx.state(), df.plan.clone()); | ||
|
||
// register a dataframe as a table | ||
ctx.register_table("test_table", Arc::new(df_impl.clone()))?; | ||
|
@@ -1180,7 +1154,7 @@ mod tests { | |
async fn with_column() -> Result<()> { | ||
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; | ||
let ctx = SessionContext::new(); | ||
let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone()); | ||
let df_impl = DataFrame::new(ctx.state(), df.plan.clone()); | ||
|
||
let df = df_impl | ||
.filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -273,7 +273,7 @@ impl SessionContext { | |
(false, true, Ok(_)) => { | ||
self.deregister_table(&name)?; | ||
let schema = Arc::new(input.schema().as_ref().into()); | ||
let physical = DataFrame::new(self.state.clone(), input); | ||
let physical = DataFrame::new(self.state(), input); | ||
|
||
let batches: Vec<_> = physical.collect_partitioned().await?; | ||
let table = Arc::new(MemTable::try_new(schema, batches)?); | ||
|
@@ -286,7 +286,7 @@ impl SessionContext { | |
)), | ||
(_, _, Err(_)) => { | ||
let schema = Arc::new(input.schema().as_ref().into()); | ||
let physical = DataFrame::new(self.state.clone(), input); | ||
let physical = DataFrame::new(self.state(), input); | ||
|
||
let batches: Vec<_> = physical.collect_partitioned().await?; | ||
let table = Arc::new(MemTable::try_new(schema, batches)?); | ||
|
@@ -363,7 +363,8 @@ impl SessionContext { | |
LogicalPlan::SetVariable(SetVariable { | ||
variable, value, .. | ||
}) => { | ||
let config_options = &self.state.write().config.config_options; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm actually somewhat surprised this was compiling, there must be some magic going on to extend the lifetime of the temporary lock guard, which is wild |
||
let state = self.state.write(); | ||
let config_options = &state.config.config_options; | ||
|
||
let old_value = | ||
config_options.read().get(&variable).ok_or_else(|| { | ||
|
@@ -410,6 +411,8 @@ impl SessionContext { | |
)) | ||
} | ||
} | ||
drop(state); | ||
|
||
self.return_empty_dataframe() | ||
} | ||
|
||
|
@@ -475,14 +478,14 @@ impl SessionContext { | |
} | ||
} | ||
|
||
plan => Ok(DataFrame::new(self.state.clone(), plan)), | ||
plan => Ok(DataFrame::new(self.state(), plan)), | ||
} | ||
} | ||
|
||
// return an empty dataframe | ||
fn return_empty_dataframe(&self) -> Result<DataFrame> { | ||
let plan = LogicalPlanBuilder::empty(false).build()?; | ||
Ok(DataFrame::new(self.state.clone(), plan)) | ||
Ok(DataFrame::new(self.state(), plan)) | ||
} | ||
|
||
async fn create_external_table( | ||
|
@@ -661,7 +664,7 @@ impl SessionContext { | |
/// Creates an empty DataFrame. | ||
pub fn read_empty(&self) -> Result<DataFrame> { | ||
Ok(DataFrame::new( | ||
self.state.clone(), | ||
self.state(), | ||
LogicalPlanBuilder::empty(true).build()?, | ||
)) | ||
} | ||
|
@@ -716,7 +719,7 @@ impl SessionContext { | |
/// Creates a [`DataFrame`] for reading a custom [`TableProvider`]. | ||
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> { | ||
Ok(DataFrame::new( | ||
self.state.clone(), | ||
self.state(), | ||
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)? | ||
.build()?, | ||
)) | ||
|
@@ -726,7 +729,7 @@ impl SessionContext { | |
pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> { | ||
let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?; | ||
Ok(DataFrame::new( | ||
self.state.clone(), | ||
self.state(), | ||
LogicalPlanBuilder::scan( | ||
UNNAMED_TABLE, | ||
provider_as_source(Arc::new(provider)), | ||
|
@@ -946,7 +949,7 @@ impl SessionContext { | |
None, | ||
)? | ||
.build()?; | ||
Ok(DataFrame::new(self.state.clone(), plan)) | ||
Ok(DataFrame::new(self.state(), plan)) | ||
} | ||
|
||
/// Return a [`TabelProvider`] for the specified table. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
❤️