From c5e2594e99b01c12d4f6903cb998a62a5479455c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 9 Jan 2023 17:27:41 +0100 Subject: [PATCH] Add DataFrame::into_view instead of implementing TableProvider (#2659) (#4778) --- datafusion/core/src/dataframe.rs | 41 ++++++++++++--------- datafusion/core/src/datasource/view.rs | 4 +- datafusion/expr/src/logical_plan/builder.rs | 10 +++++ 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index fe417593a234..e9773dbdf372 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -528,6 +528,15 @@ impl DataFrame { self.session_state.optimize(&self.plan) } + /// Converts this [`DataFrame`] into a [`TableProvider`] that can be registered + /// as a table view using [`SessionContext::register_table`]. + /// + /// Note: This discards the [`SessionState`] associated with this + /// [`DataFrame`] in favour of the one passed to [`TableProvider::scan`] + pub fn into_view(self) -> Arc { + Arc::new(DataFrameTableProvider { plan: self.plan }) + } + /// Return the optimized logical plan represented by this DataFrame. /// /// Note: This method should not be used outside testing, as it loses the snapshot @@ -766,9 +775,12 @@ impl DataFrame { } } -// TODO: This will introduce a ref cycle (#2659) +struct DataFrameTableProvider { + plan: LogicalPlan, +} + #[async_trait] -impl TableProvider for DataFrame { +impl TableProvider for DataFrameTableProvider { fn as_any(&self) -> &dyn Any { self } @@ -796,20 +808,14 @@ impl TableProvider for DataFrame { async fn scan( &self, - _state: &SessionState, + state: &SessionState, projection: Option<&Vec>, filters: &[Expr], limit: Option, ) -> Result> { - let mut expr = self.clone(); + let mut expr = LogicalPlanBuilder::from(self.plan.clone()); if let Some(p) = projection { - let schema = TableProvider::schema(&expr).project(p)?; - let names = schema - .fields() - .iter() - .map(|field| field.name().as_str()) - .collect::>(); - expr = expr.select_columns(names.as_slice())?; + expr = expr.select(p.iter().copied())? } // Add filter when given @@ -817,13 +823,12 @@ impl TableProvider for DataFrame { if let Some(filter) = filter { expr = expr.filter(filter)? } + // add a limit if given if let Some(l) = limit { expr = expr.limit(0, Some(l))? } - // add a limit if given - Self::new(self.session_state.clone(), expr.plan) - .create_physical_plan() - .await + let plan = expr.build()?; + state.create_physical_plan(&plan).await } } @@ -1098,7 +1103,7 @@ mod tests { let df_impl = DataFrame::new(ctx.state(), df.plan.clone()); // register a dataframe as a table - ctx.register_table("test_table", Arc::new(df_impl.clone()))?; + ctx.register_table("test_table", df_impl.clone().into_view())?; // pull the table out let table = ctx.table("test_table").await?; @@ -1297,7 +1302,7 @@ mod tests { let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; let ctx = SessionContext::new(); - let table = Arc::new(df); + let table = df.into_view(); ctx.register_table("t1", table.clone())?; ctx.register_table("t2", table)?; let df = ctx @@ -1386,7 +1391,7 @@ mod tests { ) .await?; - ctx.register_table("t1", Arc::new(ctx.table("test").await?))?; + ctx.register_table("t1", ctx.table("test").await?.into_view())?; let df = ctx .table("t1") diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 2d2f33dc2051..524ad9f5c2ad 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -428,7 +428,7 @@ mod tests { ) .await?; - ctx.register_table("t1", Arc::new(ctx.table("test").await?))?; + ctx.register_table("t1", ctx.table("test").await?.into_view())?; ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?; @@ -458,7 +458,7 @@ mod tests { ) .await?; - ctx.register_table("t1", Arc::new(ctx.table("test").await?))?; + ctx.register_table("t1", ctx.table("test").await?.into_view())?; ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 63783a110667..428a31f14997 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -288,6 +288,16 @@ impl LogicalPlanBuilder { Ok(Self::from(project(self.plan, expr)?)) } + /// Select the given column indices + pub fn select(self, indices: impl IntoIterator) -> Result { + let fields = self.plan.schema().fields(); + let exprs: Vec<_> = indices + .into_iter() + .map(|x| Expr::Column(fields[x].qualified_column())) + .collect(); + self.project(exprs) + } + /// Apply a filter pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?;