From a5d5c45d7fea14bea0da038b0fdada5eafa60e5d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 5 Dec 2022 15:58:40 -0500 Subject: [PATCH 1/6] Add tests for names with period --- datafusion/common/src/dfschema.rs | 2 +- datafusion/common/src/lib.rs | 2 +- datafusion/common/src/table_reference.rs | 89 +++- datafusion/core/src/catalog/listing_schema.rs | 8 +- datafusion/core/src/datasource/view.rs | 6 +- datafusion/core/src/execution/context.rs | 28 +- datafusion/core/tests/sql/errors.rs | 7 +- .../tests/sqllogictests/src/insert/mod.rs | 21 +- .../core/tests/sqllogictests/src/main.rs | 12 +- .../test_files/{insert.slt => ddl.slt} | 31 +- datafusion/expr/src/logical_plan/plan.rs | 14 +- datafusion/proto/proto/datafusion.proto | 29 +- datafusion/proto/src/from_proto.rs | 33 +- datafusion/proto/src/generated/pbjson.rs | 466 +++++++++++++++++- datafusion/proto/src/generated/prost.rs | 48 +- datafusion/proto/src/logical_plan.rs | 27 +- datafusion/proto/src/to_proto.rs | 32 +- datafusion/sql/src/planner.rs | 269 ++++++---- datafusion/sql/src/utils.rs | 8 + 19 files changed, 959 insertions(+), 173 deletions(-) rename datafusion/core/tests/sqllogictests/test_files/{insert.slt => ddl.slt} (69%) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 8ad179c5bd41..764a5743f014 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -206,7 +206,7 @@ impl DFSchema { (Some(qq), None) => { // the original field may now be aliased with a name that matches the // original qualified name - let table_ref: TableReference = field.name().as_str().into(); + let table_ref = TableReference::parse_str(field.name().as_str()); match table_ref { TableReference::Partial { schema, table } => { schema == qq && table == name diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 4e272d3c6662..60d69324913b 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -36,7 +36,7 @@ pub 
use error::{field_not_found, DataFusionError, Result, SchemaError}; pub use parsers::parse_interval; pub use scalar::{ScalarType, ScalarValue}; pub use stats::{ColumnStatistics, Statistics}; -pub use table_reference::{ResolvedTableReference, TableReference}; +pub use table_reference::{OwnedTableReference, ResolvedTableReference, TableReference}; /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 27905163dc7e..d18dbbd35a00 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -52,6 +52,73 @@ pub enum TableReference<'a> { }, } +/// Represents a path to a table that may require further resolution +/// that owns the underlying names +#[derive(Debug, Clone)] +pub enum OwnedTableReference { + /// An unqualified table reference, e.g. "table" + Bare { + /// The table name + table: String, + }, + /// A partially resolved table reference, e.g. "schema.table" + Partial { + /// The schema containing the table + schema: String, + /// The table name + table: String, + }, + /// A fully resolved table reference, e.g. 
"catalog.schema.table" + Full { + /// The catalog (aka database) containing the table + catalog: String, + /// The schema containing the table + schema: String, + /// The table name + table: String, + }, +} + +impl OwnedTableReference { + /// Return a `TableReference` view of this `OwnedTableReference` + pub fn as_table_reference(&self) -> TableReference<'_> { + match self { + Self::Bare { table } => TableReference::Bare { table }, + Self::Partial { schema, table } => TableReference::Partial { schema, table }, + Self::Full { + catalog, + schema, + table, + } => TableReference::Full { + catalog, + schema, + table, + }, + } + } + + /// Return a string suitable for display + pub fn display_string(&self) -> String { + match self { + OwnedTableReference::Bare { table } => table.clone(), + OwnedTableReference::Partial { schema, table } => format!("{schema}.{table}"), + OwnedTableReference::Full { + catalog, + schema, + table, + } => format!("{catalog}.{schema}.{table}"), + } + } +} + +/// Convert `OwnedTableReference` into a `TableReference`. Somewhat +/// awkward to use but 'idiomatic': `(&table_ref).into()` +impl<'a> From<&'a OwnedTableReference> for TableReference<'a> { + fn from(r: &'a OwnedTableReference) -> Self { + r.as_table_reference() + } +} + impl<'a> TableReference<'a> { /// Retrieve the actual table name, regardless of qualification pub fn table(&self) -> &str { @@ -90,10 +157,17 @@ impl<'a> TableReference<'a> { }, } } -} -impl<'a> From<&'a str> for TableReference<'a> { - fn from(s: &'a str) -> Self { + /// Split `s` on periods. + /// + /// Note that this function does NOT handle periods or name + /// normalization correctly (e.g. `"foo.bar"` will be parsed as + /// `"foo`.`bar"` and `Foo` will be parsed as `Foo`, not `foo`). 
+ /// + /// Instead, you should use the SQL parser for this + /// + /// The improvement is tracked in TODO FILE DATAFUSION TICKET + pub fn parse_str(s: &'a str) -> Self { let parts: Vec<&str> = s.split('.').collect(); match parts.len() { @@ -112,6 +186,15 @@ impl<'a> From<&'a str> for TableReference<'a> { } } +/// Parse a string by splitting it on "." :( +/// +/// See caveats on `parse_str` +impl<'a> From<&'a str> for TableReference<'a> { + fn from(s: &'a str) -> Self { + Self::parse_str(s) + } +} + impl<'a> From> for TableReference<'a> { fn from(resolved: ResolvedTableReference<'a>) -> Self { Self::Full { diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 6e8dcd5b2b1a..5fc2f48adda0 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -20,7 +20,7 @@ use crate::catalog::schema::SchemaProvider; use crate::datasource::datasource::TableProviderFactory; use crate::datasource::TableProvider; use crate::execution::context::SessionState; -use datafusion_common::{DFSchema, DataFusionError}; +use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference}; use datafusion_expr::CreateExternalTable; use futures::TryStreamExt; use itertools::Itertools; @@ -115,16 +115,20 @@ impl ListingSchemaProvider { let table_path = table.to_str().ok_or_else(|| { DataFusionError::Internal("Cannot parse file name!".to_string()) })?; + if !self.table_exist(table_name) { let table_url = format!("{}/{}", self.authority, table_path); + let name = OwnedTableReference::Bare { + table: table_name.to_string(), + }; let provider = self .factory .create( state, &CreateExternalTable { schema: Arc::new(DFSchema::empty()), - name: table_name.to_string(), + name, location: table_url, file_type: self.format.clone(), has_header: self.has_header, diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 5939c25e6930..ce262f55ddfc 100644 --- 
a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -502,7 +502,7 @@ mod tests { let actual = format!("{}", plan.display_indent()); let expected = "\ Explain\ - \n CreateView: \"xyz\"\ + \n CreateView: Bare { table: \"xyz\" }\ \n Projection: abc.column1, abc.column2, abc.column3\ \n TableScan: abc projection=[column1, column2, column3]"; assert_eq!(expected, actual); @@ -516,7 +516,7 @@ mod tests { let actual = format!("{}", plan.display_indent()); let expected = "\ Explain\ - \n CreateView: \"xyz\"\ + \n CreateView: Bare { table: \"xyz\" }\ \n Projection: abc.column1, abc.column2, abc.column3\ \n Filter: abc.column2 = Int64(5)\ \n TableScan: abc projection=[column1, column2, column3]"; @@ -531,7 +531,7 @@ mod tests { let actual = format!("{}", plan.display_indent()); let expected = "\ Explain\ - \n CreateView: \"xyz\"\ + \n CreateView: Bare { table: \"xyz\" }\ \n Projection: abc.column1, abc.column2\ \n Filter: abc.column2 = Int64(5)\ \n TableScan: abc projection=[column1, column2]"; diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 0117fee37e0f..f54eac6cbc9a 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -225,7 +225,7 @@ impl SessionContext { batch: RecordBatch, ) -> Result>> { let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - self.register_table(table_name, Arc::new(table)) + self.register_table(TableReference::Bare { table: table_name }, Arc::new(table)) } /// Return the [RuntimeEnv] used to run queries with this [SessionContext] @@ -265,12 +265,12 @@ impl SessionContext { if_not_exists, or_replace, }) => { - let table = self.table(name.as_str()); + let table = self.table(&name); match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), (false, true, Ok(_)) => { - self.deregister_table(name.as_str())?; + self.deregister_table(&name)?; let physical = 
Arc::new(DataFrame::new(self.state.clone(), &input)); @@ -280,7 +280,7 @@ impl SessionContext { batches, )?); - self.register_table(name.as_str(), table)?; + self.register_table(&name, table)?; self.return_empty_dataframe() } (true, true, Ok(_)) => Err(DataFusionError::Internal( @@ -296,7 +296,7 @@ impl SessionContext { batches, )?); - self.register_table(name.as_str(), table)?; + self.register_table(&name, table)?; self.return_empty_dataframe() } (false, false, Ok(_)) => Err(DataFusionError::Execution(format!( @@ -312,22 +312,22 @@ impl SessionContext { or_replace, definition, }) => { - let view = self.table(name.as_str()); + let view = self.table(&name); match (or_replace, view) { (true, Ok(_)) => { - self.deregister_table(name.as_str())?; + self.deregister_table(&name)?; let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?); - self.register_table(name.as_str(), table)?; + self.register_table(&name, table)?; self.return_empty_dataframe() } (_, Err(_)) => { let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?); - self.register_table(name.as_str(), table)?; + self.register_table(&name, table)?; self.return_empty_dataframe() } (false, Ok(_)) => Err(DataFusionError::Execution(format!( @@ -340,7 +340,7 @@ impl SessionContext { LogicalPlan::DropTable(DropTable { name, if_exists, .. }) => { - let result = self.find_and_deregister(name.as_str(), TableType::Base); + let result = self.find_and_deregister(&name, TableType::Base); match (result, if_exists) { (Ok(true), _) => self.return_empty_dataframe(), (_, true) => self.return_empty_dataframe(), @@ -354,7 +354,7 @@ impl SessionContext { LogicalPlan::DropView(DropView { name, if_exists, .. 
}) => { - let result = self.find_and_deregister(name.as_str(), TableType::View); + let result = self.find_and_deregister(&name, TableType::View); match (result, if_exists) { (Ok(true), _) => self.return_empty_dataframe(), (_, true) => self.return_empty_dataframe(), @@ -497,11 +497,11 @@ impl SessionContext { let table_provider: Arc = self.create_custom_table(cmd).await?; - let table = self.table(cmd.name.as_str()); + let table = self.table(&cmd.name); match (cmd.if_not_exists, table) { (true, Ok(_)) => self.return_empty_dataframe(), (_, Err(_)) => { - self.register_table(cmd.name.as_str(), table_provider)?; + self.register_table(&cmd.name, table_provider)?; self.return_empty_dataframe() } (false, Ok(_)) => Err(DataFusionError::Execution(format!( @@ -765,7 +765,7 @@ impl SessionContext { .with_listing_options(options) .with_schema(resolved_schema); let table = ListingTable::try_new(config)?.with_definition(sql_definition); - self.register_table(name, Arc::new(table))?; + self.register_table(TableReference::Bare { table: name }, Arc::new(table))?; Ok(()) } diff --git a/datafusion/core/tests/sql/errors.rs b/datafusion/core/tests/sql/errors.rs index 412ad108c4dc..f3761320bf60 100644 --- a/datafusion/core/tests/sql/errors.rs +++ b/datafusion/core/tests/sql/errors.rs @@ -132,7 +132,12 @@ async fn invalid_qualified_table_references() -> Result<()> { "way.too.many.namespaces.as.ident.prefixes.aggregate_test_100", ] { let sql = format!("SELECT COUNT(*) FROM {}", table_ref); - assert!(matches!(ctx.sql(&sql).await, Err(DataFusionError::Plan(_)))); + let result = ctx.sql(&sql).await; + assert!( + matches!(result, Err(DataFusionError::Plan(_))), + "result was: {:?}", + result + ); } Ok(()) } diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs index 695b6d26d56b..6caac0a36b69 100644 --- a/datafusion/core/tests/sqllogictests/src/insert/mod.rs +++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs @@ 
-24,22 +24,21 @@ use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::{DFSchema, DataFusionError}; use datafusion_expr::Expr as DFExpr; -use datafusion_sql::planner::{PlannerContext, SqlToRel}; +use datafusion_sql::planner::{object_name_to_table_reference, PlannerContext, SqlToRel}; use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement}; use std::sync::Arc; -pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result { +pub async fn insert(ctx: &SessionContext, insert_stmt: SQLStatement) -> Result { // First, use sqlparser to get table name and insert values - let table_name; + let table_reference; let insert_values: Vec>; match insert_stmt { SQLStatement::Insert { - table_name: name, - source, - .. + table_name, source, .. } => { + table_reference = object_name_to_table_reference(table_name)?; + // Todo: check columns match table schema - table_name = name.to_string(); match &*source.body { SetExpr::Values(values) => { insert_values = values.0.clone(); @@ -54,9 +53,9 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result< } // Second, get batches in table and destroy the old table - let mut origin_batches = ctx.table(table_name.as_str())?.collect().await?; - let schema = ctx.table_provider(table_name.as_str())?.schema(); - ctx.deregister_table(table_name.as_str())?; + let mut origin_batches = ctx.table(&table_reference)?.collect().await?; + let schema = ctx.table_provider(&table_reference)?.schema(); + ctx.deregister_table(&table_reference)?; // Third, transfer insert values to `RecordBatch` // Attention: schema info can be ignored. (insert values don't contain schema info) @@ -84,7 +83,7 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result< // Final, create new memtable with same schema. 
let new_provider = MemTable::try_new(schema, vec![origin_batches])?; - ctx.register_table(table_name.as_str(), Arc::new(new_provider))?; + ctx.register_table(&table_reference, Arc::new(new_provider))?; Ok("".to_string()) } diff --git a/datafusion/core/tests/sqllogictests/src/main.rs b/datafusion/core/tests/sqllogictests/src/main.rs index ad7cb20e976b..759733cfd013 100644 --- a/datafusion/core/tests/sqllogictests/src/main.rs +++ b/datafusion/core/tests/sqllogictests/src/main.rs @@ -187,15 +187,17 @@ fn format_batches(batches: Vec) -> Result { async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { let sql = sql.into(); // Check if the sql is `insert` - if let Ok(statements) = DFParser::parse_sql(&sql) { - if let Statement::Statement(statement) = &statements[0] { - if let SQLStatement::Insert { .. } = &**statement { + if let Ok(mut statements) = DFParser::parse_sql(&sql) { + let statement0 = statements.pop_front().expect("at least one SQL statement"); + if let Statement::Statement(statement) = statement0 { + let statement = *statement; + if matches!(&statement, SQLStatement::Insert { .. }) { return insert(ctx, statement).await; } } } - let df = ctx.sql(sql.as_str()).await.unwrap(); - let results: Vec = df.collect().await.unwrap(); + let df = ctx.sql(sql.as_str()).await?; + let results: Vec = df.collect().await?; let formatted_batches = format_batches(results)?; Ok(formatted_batches) } diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/ddl.slt similarity index 69% rename from datafusion/core/tests/sqllogictests/test_files/insert.slt rename to datafusion/core/tests/sqllogictests/test_files/ddl.slt index 0927b3777ddc..3420e47ee4da 100644 --- a/datafusion/core/tests/sqllogictests/test_files/insert.slt +++ b/datafusion/core/tests/sqllogictests/test_files/ddl.slt @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+## DDL Tests +## + statement ok CREATE TABLE users AS VALUES(1,2),(2,3); @@ -45,6 +48,32 @@ select * from users; 2 4 11 20 -# Test insert into a undefined table + +# Dropping table +statement ok +DROP TABLE users; + +# Table is gone +statement error +select * from users; + +# Can not drop it again +statement error +DROP TABLE user; + +# Can not insert into an undefined table statement error insert into user values(1, 20); + + +# Verify that creating tables with periods in their name works +# (note "foo.bar" is the table name, NOT table "bar" in schema "foo") +statement ok +CREATE TABLE "foo.bar" AS VALUES(1,2),(2,3); + +# Should be able to select from it as well +query II rowsort +select * from "foo.bar"; +---- +1 2 +2 3 diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ed5711a7fb57..d8df322c38cb 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,7 +25,9 @@ use crate::utils::{ }; use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::{plan_err, Column, DFSchema, DFSchemaRef, DataFusionError}; +use datafusion_common::{ + plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, +}; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; @@ -1058,7 +1060,7 @@ pub struct CreateCatalogSchema { #[derive(Clone)] pub struct DropTable { /// The table name - pub name: String, + pub name: OwnedTableReference, /// If the table exists pub if_exists: bool, /// Dummy schema @@ -1069,7 +1071,7 @@ pub struct DropTable { #[derive(Clone)] pub struct DropView { /// The view name - pub name: String, + pub name: OwnedTableReference, /// If the view exists pub if_exists: bool, /// Dummy schema @@ -1309,7 +1311,7 @@ pub struct Union { #[derive(Clone)] pub struct CreateMemoryTable { /// The 
table name - pub name: String, + pub name: OwnedTableReference, /// The logical plan pub input: Arc, /// Option to not error if table already exists @@ -1322,7 +1324,7 @@ pub struct CreateMemoryTable { #[derive(Clone)] pub struct CreateView { /// The table name - pub name: String, + pub name: OwnedTableReference, /// The logical plan pub input: Arc, /// Option to not error if table already exists @@ -1337,7 +1339,7 @@ pub struct CreateExternalTable { /// The table schema pub schema: DFSchemaRef, /// The table name - pub name: String, + pub name: OwnedTableReference, /// The physical location pub location: String, /// The file type of physical file diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e529626a6c62..d5284ef5956d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -167,7 +167,8 @@ message EmptyRelationNode { } message CreateExternalTableNode { - string name = 1; + reserved 1; // was string name + OwnedTableReference name = 12; string location = 2; string file_type = 3; bool has_header = 4; @@ -193,7 +194,8 @@ message CreateCatalogNode { } message CreateViewNode { - string name = 1; + reserved 1; // was string name + OwnedTableReference name = 5; LogicalPlanNode input = 2; bool or_replace = 3; string definition = 4; @@ -915,6 +917,29 @@ message StringifiedPlan { string plan = 2; } +message BareTableReference { + string table = 1; +} + +message PartialTableReference { + string schema = 1; + string table = 2; +} + +message FullTableReference { + string catalog = 1; + string schema = 2; + string table = 3; +} + +message OwnedTableReference { + oneof table_reference_enum { + BareTableReference bare = 1; + PartialTableReference partial = 2; + FullTableReference full = 3; + } +} + ///////////////////////////////////////////////////////////////////////////////////////////////// // PhysicalPlanNode is a nested type diff --git a/datafusion/proto/src/from_proto.rs 
b/datafusion/proto/src/from_proto.rs index 935dd4f44892..98ffbd240a3b 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -29,7 +29,8 @@ use arrow::datatypes::{ }; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + ScalarValue, }; use datafusion_expr::expr::{BinaryExpr, Cast}; use datafusion_expr::{ @@ -202,6 +203,36 @@ impl From for WindowFrameUnits { } } +impl TryFrom for OwnedTableReference { + type Error = Error; + + fn try_from(value: protobuf::OwnedTableReference) -> Result { + use protobuf::owned_table_reference::TableReferenceEnum; + let table_reference_enum = value + .table_reference_enum + .ok_or_else(|| Error::required("table_reference_enum"))?; + + match table_reference_enum { + TableReferenceEnum::Bare(protobuf::BareTableReference { table }) => { + Ok(OwnedTableReference::Bare { table }) + } + TableReferenceEnum::Partial(protobuf::PartialTableReference { + schema, + table, + }) => Ok(OwnedTableReference::Partial { schema, table }), + TableReferenceEnum::Full(protobuf::FullTableReference { + catalog, + schema, + table, + }) => Ok(OwnedTableReference::Full { + catalog, + schema, + table, + }), + } + } +} + impl TryFrom<&protobuf::ArrowType> for DataType { type Error = Error; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8c98b76c0852..97b796257c0a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1717,6 +1717,97 @@ impl<'de> serde::Deserialize<'de> for AvroScanExecNode { deserializer.deserialize_struct("datafusion.AvroScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for BareTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { 
+ use serde::ser::SerializeStruct; + let mut len = 0; + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.BareTableReference", len)?; + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for BareTableReference { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "table", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Table, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "table" => Ok(GeneratedField::Table), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = BareTableReference; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.BareTableReference") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut table__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); + } + table__ = Some(map.next_value()?); + } + } + } + Ok(BareTableReference { + table: table__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.BareTableReference", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for BetweenNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -3228,7 +3319,7 @@ impl serde::Serialize for CreateExternalTableNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if self.name.is_some() { len += 1; } if !self.location.is_empty() { @@ -3262,8 +3353,8 @@ impl serde::Serialize for CreateExternalTableNode { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; } if !self.location.is_empty() { struct_ser.serialize_field("location", &self.location)?; @@ -3404,7 +3495,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = map.next_value()?; } GeneratedField::Location => { if location__.is_some() { @@ -3471,7 +3562,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } } Ok(CreateExternalTableNode { - name: name__.unwrap_or_default(), + name: name__, location: location__.unwrap_or_default(), file_type: file_type__.unwrap_or_default(), has_header: has_header__.unwrap_or_default(), @@ -3496,7 +3587,7 @@ impl serde::Serialize for CreateViewNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if self.name.is_some() { len += 1; } if self.input.is_some() { @@ -3509,8 +3600,8 @@ impl serde::Serialize for CreateViewNode { 
len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CreateViewNode", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; } if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; @@ -3598,7 +3689,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { if name__.is_some() { return Err(serde::de::Error::duplicate_field("name")); } - name__ = Some(map.next_value()?); + name__ = map.next_value()?; } GeneratedField::Input => { if input__.is_some() { @@ -3621,7 +3712,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { } } Ok(CreateViewNode { - name: name__.unwrap_or_default(), + name: name__, input: input__, or_replace: or_replace__.unwrap_or_default(), definition: definition__.unwrap_or_default(), @@ -7313,6 +7404,131 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { deserializer.deserialize_struct("datafusion.FixedSizeList", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FullTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.catalog.is_empty() { + len += 1; + } + if !self.schema.is_empty() { + len += 1; + } + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; + if !self.catalog.is_empty() { + struct_ser.serialize_field("catalog", &self.catalog)?; + } + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; + } + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FullTableReference { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] 
= &[ + "catalog", + "schema", + "table", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Catalog, + Schema, + Table, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "catalog" => Ok(GeneratedField::Catalog), + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FullTableReference; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.FullTableReference") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut catalog__ = None; + let mut schema__ = None; + let mut table__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Catalog => { + if catalog__.is_some() { + return Err(serde::de::Error::duplicate_field("catalog")); + } + catalog__ = Some(map.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = Some(map.next_value()?); + } + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); + } + table__ = Some(map.next_value()?); + } + } + } + Ok(FullTableReference { + catalog: catalog__.unwrap_or_default(), + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for GetIndexedField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -12124,6 +12340,128 @@ impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { deserializer.deserialize_struct("datafusion.OptimizedPhysicalPlanType", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for OwnedTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.table_reference_enum.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.OwnedTableReference", len)?; + if let Some(v) = self.table_reference_enum.as_ref() { + match v { + owned_table_reference::TableReferenceEnum::Bare(v) => { + struct_ser.serialize_field("bare", v)?; + } + owned_table_reference::TableReferenceEnum::Partial(v) => { + struct_ser.serialize_field("partial", v)?; + } + owned_table_reference::TableReferenceEnum::Full(v) => { + struct_ser.serialize_field("full", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for OwnedTableReference { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> 
std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "bare", + "partial", + "full", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Bare, + Partial, + Full, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "bare" => Ok(GeneratedField::Bare), + "partial" => Ok(GeneratedField::Partial), + "full" => Ok(GeneratedField::Full), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = OwnedTableReference; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.OwnedTableReference") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut table_reference_enum__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Bare => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("bare")); + } + table_reference_enum__ = map.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Bare) +; + } + GeneratedField::Partial => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("partial")); + } + table_reference_enum__ = map.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Partial) +; + } + GeneratedField::Full => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("full")); + } + table_reference_enum__ = map.next_value::<::std::option::Option<_>>()?.map(owned_table_reference::TableReferenceEnum::Full) +; + } + } + } + Ok(OwnedTableReference { + table_reference_enum: table_reference_enum__, + }) + } + } + deserializer.deserialize_struct("datafusion.OwnedTableReference", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ParquetFormat { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -12326,6 +12664,114 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PartialTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.schema.is_empty() { + len += 1; + } + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; + } + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PartialTableReference { + 
#[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "schema", + "table", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Schema, + Table, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartialTableReference; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PartialTableReference") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut schema__ = None; + let mut table__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = Some(map.next_value()?); + } + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); + } + table__ = Some(map.next_value()?); + } + } + } + Ok(PartialTableReference { + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PartitionId { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ea281b717d09..6bfb6b96c32a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -248,8 +248,8 @@ pub struct EmptyRelationNode { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateExternalTableNode { - #[prost(string, tag = "1")] - pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "12")] + pub name: ::core::option::Option, #[prost(string, tag = "2")] pub location: ::prost::alloc::string::String, #[prost(string, tag = "3")] @@ -294,8 +294,8 @@ pub struct CreateCatalogNode { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateViewNode { - #[prost(string, tag = "1")] - pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "5")] + pub name: ::core::option::Option, #[prost(message, optional, boxed, tag = "2")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(bool, tag = "3")] @@ -1148,6 +1148,46 @@ pub struct StringifiedPlan { #[prost(string, tag = "2")] pub plan: ::prost::alloc::string::String, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct BareTableReference { + #[prost(string, tag = "1")] + pub table: ::prost::alloc::string::String, +} 
+#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PartialTableReference { + #[prost(string, tag = "1")] + pub schema: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub table: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FullTableReference { + #[prost(string, tag = "1")] + pub catalog: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub schema: ::prost::alloc::string::String, + #[prost(string, tag = "3")] + pub table: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct OwnedTableReference { + #[prost(oneof = "owned_table_reference::TableReferenceEnum", tags = "1, 2, 3")] + pub table_reference_enum: ::core::option::Option< + owned_table_reference::TableReferenceEnum, + >, +} +/// Nested message and enum types in `OwnedTableReference`. +pub mod owned_table_reference { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum TableReferenceEnum { + #[prost(message, tag = "1")] + Bare(super::BareTableReference), + #[prost(message, tag = "2")] + Partial(super::PartialTableReference), + #[prost(message, tag = "3")] + Full(super::FullTableReference), + } +} /// PhysicalPlanNode is a nested type #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalPlanNode { diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs index 99d5b8ebd671..38c18673680d 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan.rs @@ -38,7 +38,7 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; -use datafusion_common::{context, Column, DataFusionError}; +use datafusion_common::{context, Column, DataFusionError, OwnedTableReference}; use datafusion_expr::logical_plan::builder::{project, subquery_alias_owned}; use datafusion_expr::{ logical_plan::{ @@ -264,6 +264,20 @@ impl From for protobuf::JoinConstraint { } } +fn 
from_owned_table_reference( + table_ref: Option<&protobuf::OwnedTableReference>, + error_context: &str, +) -> Result { + let table_ref = table_ref.ok_or_else(|| { + DataFusionError::Internal(format!( + "Protobuf deserialization error, {} was missing required field name.", + error_context + )) + })?; + + Ok(table_ref.clone().try_into()?) +} + impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -580,7 +594,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlan::CreateExternalTable(CreateExternalTable { schema: pb_schema.try_into()?, - name: create_extern_table.name.clone(), + name: from_owned_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, location: create_extern_table.location.clone(), file_type: create_extern_table.file_type.clone(), has_header: create_extern_table.has_header, @@ -609,7 +623,10 @@ impl AsLogicalPlan for LogicalPlanNode { }; Ok(LogicalPlan::CreateView(CreateView { - name: create_view.name.clone(), + name: from_owned_table_reference( + create_view.name.as_ref(), + "CreateView", + )?, input: Arc::new(plan), or_replace: create_view.or_replace, definition, @@ -1215,7 +1232,7 @@ impl AsLogicalPlan for LogicalPlanNode { }) => Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { - name: name.clone(), + name: Some(name.clone().into()), location: location.clone(), file_type: file_type.clone(), has_header: *has_header, @@ -1237,7 +1254,7 @@ impl AsLogicalPlan for LogicalPlanNode { }) => Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { - name: name.clone(), + name: Some(name.clone().into()), input: Some(Box::new(LogicalPlanNode::try_from_logical_plan( input, extension_codec, diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 8dc72a1f2f48..133e2f89d54e 100644 --- a/datafusion/proto/src/to_proto.rs +++ 
b/datafusion/proto/src/to_proto.rs @@ -33,7 +33,7 @@ use arrow::datatypes::{ DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, }; -use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue}; +use datafusion_common::{Column, DFField, DFSchemaRef, OwnedTableReference, ScalarValue}; use datafusion_expr::expr::{ Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, }; @@ -1292,6 +1292,36 @@ impl From<&IntervalUnit> for protobuf::IntervalUnit { } } +impl From for protobuf::OwnedTableReference { + fn from(t: OwnedTableReference) -> Self { + use protobuf::owned_table_reference::TableReferenceEnum; + let table_reference_enum = match t { + OwnedTableReference::Bare { table } => { + TableReferenceEnum::Bare(protobuf::BareTableReference { table }) + } + OwnedTableReference::Partial { schema, table } => { + TableReferenceEnum::Partial(protobuf::PartialTableReference { + schema, + table, + }) + } + OwnedTableReference::Full { + catalog, + schema, + table, + } => TableReferenceEnum::Full(protobuf::FullTableReference { + catalog, + schema, + table, + }), + }; + + protobuf::OwnedTableReference { + table_reference_enum: Some(table_reference_enum), + } + } +} + /// Creates a scalar protobuf value from an optional value (T), and /// encoding None as the appropriate datatype fn create_proto_scalar protobuf::scalar_value::Value>( diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 77fc3a82cfcc..81ba44559ead 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -37,11 +37,11 @@ use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; use datafusion_common::parsers::parse_interval; -use datafusion_common::TableReference; use datafusion_common::ToDFSchema; use datafusion_common::{ field_not_found, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; +use datafusion_common::{OwnedTableReference, 
TableReference}; use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; use datafusion_expr::expr_rewriter::normalize_col; use datafusion_expr::expr_rewriter::normalize_col_with_schemas; @@ -72,7 +72,9 @@ use datafusion_expr::{ }; use crate::parser::{CreateExternalTable, DescribeTable, Statement as DFStatement}; -use crate::utils::{make_decimal_type, normalize_ident, resolve_columns}; +use crate::utils::{ + make_decimal_type, normalize_ident, normalize_ident_owned, resolve_columns, +}; use super::{ parser::DFParser, @@ -262,7 +264,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; Ok(LogicalPlan::CreateMemoryTable(CreateMemoryTable { - name: name.to_string(), + name: object_name_to_table_reference(name)?, input: Arc::new(plan), if_not_exists, or_replace, @@ -280,7 +282,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan = Self::apply_expr_alias(plan, &columns)?; Ok(LogicalPlan::CreateView(CreateView { - name: name.to_string(), + name: object_name_to_table_reference(name)?, input: Arc::new(plan), or_replace, definition: sql, @@ -291,7 +293,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .to_string(), )), Statement::ShowCreate { obj_type, obj_name } => match obj_type { - ShowCreateObject::Table => self.show_create_table_to_plan(&obj_name), + ShowCreateObject::Table => self.show_create_table_to_plan(obj_name), _ => Err(DataFusionError::NotImplemented( "Only `SHOW CREATE TABLE ...` statement is supported".to_string(), )), @@ -316,34 +318,40 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::Drop { object_type, if_exists, - names, + mut names, cascade: _, restrict: _, purge: _, + } => { // We don't support cascade and purge for now. // nor do we support multiple object names - } => match object_type { - ObjectType::Table => Ok(LogicalPlan::DropTable(DropTable { - name: names - .get(0) - .ok_or_else(|| ParserError("Missing table name.".to_string()))? 
- .to_string(), - if_exists, - schema: DFSchemaRef::new(DFSchema::empty()), - })), - ObjectType::View => Ok(LogicalPlan::DropView(DropView { - name: names - .get(0) - .ok_or_else(|| ParserError("Missing table name.".to_string()))? - .to_string(), - if_exists, - schema: DFSchemaRef::new(DFSchema::empty()), - })), - _ => Err(DataFusionError::NotImplemented( - "Only `DROP TABLE/VIEW ...` statement is supported currently" - .to_string(), - )), - }, + + let name = match names.len() { + 0 => Err(ParserError("Missing table name.".to_string()).into()), + 1 => object_name_to_table_reference(names.pop().unwrap()), + _ => { + Err(ParserError("Multiple objects not supported".to_string()) + .into()) + } + }?; + + match object_type { + ObjectType::Table => Ok(LogicalPlan::DropTable(DropTable { + name, + if_exists, + schema: DFSchemaRef::new(DFSchema::empty()), + })), + ObjectType::View => Ok(LogicalPlan::DropView(DropView { + name, + if_exists, + schema: DFSchemaRef::new(DFSchema::empty()), + })), + _ => Err(DataFusionError::NotImplemented( + "Only `DROP TABLE/VIEW ...` statement is supported currently" + .to_string(), + )), + } + } Statement::ShowTables { extended, @@ -357,7 +365,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { full, table_name, filter, - } => self.show_columns_to_plan(extended, full, &table_name, filter.as_ref()), + } => self.show_columns_to_plan(extended, full, table_name, filter), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL statement: {:?}", sql @@ -534,13 +542,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, statement: DescribeTable, ) -> Result { - let table_name = statement.table_name.to_string(); - let table_ref: TableReference = table_name.as_str().into(); + let DescribeTable { table_name } = statement; - // check if table_name exists - let _ = self.schema_provider.get_table_provider(table_ref)?; + let where_clause = object_name_to_qualifier(&table_name); + let table_ref = object_name_to_table_reference(table_name)?; 
- let where_clause = object_name_to_qualifier(&statement.table_name); + // check if table_name exists + let _ = self + .schema_provider + .get_table_provider((&table_ref).into())?; if self.has_table("information_schema", "tables") { let sql = format!( @@ -592,6 +602,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let schema = self.build_schema(columns)?; + // External tables do not support schemas at the moment, so the name is just a table name + let name = OwnedTableReference::Bare { table: name }; + Ok(LogicalPlan::CreateExternalTable(PlanCreateExternalTable { schema: schema.to_dfschema_ref()?, name, @@ -891,18 +904,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { outer_query_schema: Option<&DFSchema>, ) -> Result { let (plan, alias) = match relation { - TableFactor::Table { - name: ref sql_object_name, - alias, - .. - } => { + TableFactor::Table { name, alias, .. } => { // normalize name and alias - let table_name = normalize_sql_object_name(sql_object_name); - let table_ref: TableReference = table_name.as_str().into(); + let table_ref = object_name_to_table_reference(name)?; + let table_name = table_ref.display_string(); let table_alias = alias.as_ref().map(|a| normalize_ident(&a.name)); let cte = planner_context.ctes.get(&table_name); ( - match (cte, self.schema_provider.get_table_provider(table_ref)) { + match ( + cte, + self.schema_provider.get_table_provider((&table_ref).into()), + ) { (Some(cte_plan), _) => match table_alias { Some(cte_alias) => subquery_alias(cte_plan, &cte_alias), _ => Ok(cte_plan.clone()), @@ -1910,6 +1922,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // interpret names with '.' as if they were // compound identifiers, but this is not a compound // identifier. (e.g. 
it is "foo.bar" not foo.bar) + Ok(Expr::Column(Column { relation: None, name: normalize_ident(&id), @@ -1934,9 +1947,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::CompoundIdentifier(ids) => { - let mut var_names: Vec<_> = ids.into_iter().map(|s| normalize_ident(&s)).collect(); - - if var_names[0].get(0..1) == Some("@") { + if ids[0].value.starts_with('@') { + let var_names: Vec<_> = ids.into_iter().map(|s| normalize_ident(&s)).collect(); let ty = self .schema_provider .get_variable_type(&var_names) @@ -1948,37 +1960,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { })?; Ok(Expr::ScalarVariable(ty, var_names)) } else { - match (var_names.pop(), var_names.pop()) { - (Some(name), Some(relation)) if var_names.is_empty() => { - match schema.field_with_qualified_name(&relation, &name) { - Ok(_) => { - // found an exact match on a qualified name so this is a table.column identifier - Ok(Expr::Column(Column { - relation: Some(relation), - name, - })) - } - Err(_) => { - if let Some(field) = schema.fields().iter().find(|f| f.name().eq(&relation)) { - // Access to a field of a column which is a structure, example: SELECT my_struct.key - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(Expr::Column(field.qualified_column())), - ScalarValue::Utf8(Some(name)), - ))) - } else { - // table.column identifier - Ok(Expr::Column(Column { - relation: Some(relation), - name, - })) - } - } + // only support "schema.table" type identifiers here + let (name, relation) = match idents_to_table_reference(ids)? { + OwnedTableReference::Partial { schema, table } => (table, schema), + r @ OwnedTableReference::Bare { .. } | + r @ OwnedTableReference::Full { .. 
} => { + return Err(DataFusionError::Plan(format!( + "Unsupported compound identifier '{:?}'", r, + ))) + } + }; + + // Try and find the reference in schema + match schema.field_with_qualified_name(&relation, &name) { + Ok(_) => { + // found an exact match on a qualified name so this is a table.column identifier + Ok(Expr::Column(Column { + relation: Some(relation), + name, + })) + } + Err(_) => { + if let Some(field) = schema.fields().iter().find(|f| f.name().eq(&relation)) { + // Access to a field of a column which is a structure, example: SELECT my_struct.key + Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(Expr::Column(field.qualified_column())), + ScalarValue::Utf8(Some(name)), + ))) + } else { + // table.column identifier + Ok(Expr::Column(Column { + relation: Some(relation), + name, + })) } } - _ => Err(DataFusionError::NotImplemented(format!( - "Unsupported compound identifier '{:?}'", - var_names, - ))), } } } @@ -2706,8 +2722,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, extended: bool, full: bool, - sql_table_name: &ObjectName, - filter: Option<&ShowStatementFilter>, + sql_table_name: ObjectName, + filter: Option, ) -> Result { if filter.is_some() { return Err(DataFusionError::Plan( @@ -2721,13 +2737,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .to_string(), )); } - let table_name = normalize_sql_object_name(sql_table_name); - let table_ref: TableReference = table_name.as_str().into(); - - let _ = self.schema_provider.get_table_provider(table_ref)?; - // Figure out the where clause - let where_clause = object_name_to_qualifier(sql_table_name); + let where_clause = object_name_to_qualifier(&sql_table_name); + + // Do a table lookup to verify the table exists + let table_ref = object_name_to_table_reference(sql_table_name)?; + let _ = self + .schema_provider + .get_table_provider((&table_ref).into())?; // treat both FULL and EXTENDED as the same let select_list = if full || extended { @@ -2748,7 +2765,7 @@ impl<'a, S: 
ContextProvider> SqlToRel<'a, S> { fn show_create_table_to_plan( &self, - sql_table_name: &ObjectName, + sql_table_name: ObjectName, ) -> Result { if !self.has_table("information_schema", "tables") { return Err(DataFusionError::Plan( @@ -2756,13 +2773,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .to_string(), )); } - let table_name = normalize_sql_object_name(sql_table_name); - let table_ref: TableReference = table_name.as_str().into(); - - let _ = self.schema_provider.get_table_provider(table_ref)?; - // Figure out the where clause - let where_clause = object_name_to_qualifier(sql_table_name); + let where_clause = object_name_to_qualifier(&sql_table_name); + + // Do a table lookup to verify the table exists + let table_ref = object_name_to_table_reference(sql_table_name)?; + let _ = self + .schema_provider + .get_table_provider((&table_ref).into())?; let query = format!( "SELECT table_catalog, table_schema, table_name, definition FROM information_schema.views WHERE {}", @@ -2995,14 +3013,61 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } -/// Normalize a SQL object name -fn normalize_sql_object_name(sql_object_name: &ObjectName) -> String { - sql_object_name - .0 - .iter() - .map(normalize_ident) - .collect::>() - .join(".") +/// Create a [`OwnedTableReference`] after normalizing the specified ObjectName +/// +/// For example, TODO make doc example +/// ['foo'] -> +/// ['"foo.bar"]] -> +/// ['foo', 'Bar'] -> +/// ['foo', 'bar'] -> +/// ['foo', '"Bar"'] -> +pub fn object_name_to_table_reference( + object_name: ObjectName, +) -> Result { + // use destructure to make it clear no fields on ObjectName are ignored + let ObjectName(idents) = object_name; + idents_to_table_reference(idents) +} + +/// Create a [`OwnedTableReference`] after normalizing the specified identifier +fn idents_to_table_reference(idents: Vec) -> Result { + struct IdentTaker(Vec); + /// take the next identifier from the back of idents, panic'ing if + /// there are none left + impl 
IdentTaker { + fn take(&mut self) -> String { + let ident = self.0.pop().expect("no more identifiers"); + normalize_ident_owned(ident) + } + } + + let mut taker = IdentTaker(idents); + + match taker.0.len() { + 1 => { + let table = taker.take(); + Ok(OwnedTableReference::Bare { table }) + } + 2 => { + let table = taker.take(); + let schema = taker.take(); + Ok(OwnedTableReference::Partial { schema, table }) + } + 3 => { + let table = taker.take(); + let schema = taker.take(); + let catalog = taker.take(); + Ok(OwnedTableReference::Full { + catalog, + schema, + table, + }) + } + _ => Err(DataFusionError::Plan(format!( + "Unsupported compound identifier '{:?}'", + taker.0, + ))), + } } /// Construct a WHERE qualifier suitable for e.g. information_schema filtering @@ -4478,21 +4543,21 @@ mod tests { #[test] fn create_external_table_csv() { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; - let expected = "CreateExternalTable: \"t\""; + let expected = "CreateExternalTable: Bare { table: \"t\" }"; quick_test(sql, expected); } #[test] fn create_external_table_custom() { let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';"; - let expected = r#"CreateExternalTable: "dt""#; + let expected = r#"CreateExternalTable: Bare { table: "dt" }"#; quick_test(sql, expected); } #[test] fn create_external_table_csv_no_schema() { let sql = "CREATE EXTERNAL TABLE t STORED AS CSV LOCATION 'foo.csv'"; - let expected = "CreateExternalTable: \"t\""; + let expected = "CreateExternalTable: Bare { table: \"t\" }"; quick_test(sql, expected); } @@ -4506,7 +4571,7 @@ mod tests { "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON COMPRESSION TYPE BZIP2 LOCATION 'foo.json.bz2'", ]; for sql in sqls { - let expected = "CreateExternalTable: \"t\""; + let expected = "CreateExternalTable: Bare { table: \"t\" }"; quick_test(sql, expected); } @@ -4540,7 +4605,7 @@ mod tests { #[test] fn create_external_table_parquet_no_schema() { let sql = 
"CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'"; - let expected = "CreateExternalTable: \"t\""; + let expected = "CreateExternalTable: Bare { table: \"t\" }"; quick_test(sql, expected); } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 1cc6c1f00086..6a0dd7c3f581 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -542,3 +542,11 @@ pub(crate) fn normalize_ident(id: &Ident) -> String { None => id.value.to_ascii_lowercase(), } } + +// Normalize an owned identifier to a lowercase string unless the identifier is quoted. +pub(crate) fn normalize_ident_owned(id: Ident) -> String { + match id.quote_style { + Some(_) => id.value, + None => id.value.to_ascii_lowercase(), + } +} From 5169621027b2760dadbb10032ede9ceae707d9fb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Dec 2022 14:56:17 -0500 Subject: [PATCH 2/6] adjust docstrings --- datafusion/sql/src/planner.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 81ba44559ead..197d7d244252 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -3015,12 +3015,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a [`OwnedTableReference`] after normalizing the specified ObjectName /// -/// For example, TODO make doc example -/// ['foo'] -> -/// ['"foo.bar"]] -> -/// ['foo', 'Bar'] -> -/// ['foo', 'bar'] -> -/// ['foo', '"Bar"'] -> +/// Examples +/// ```text +/// ['foo'] -> Bare { table: "foo" } +/// ['"foo.bar"]] -> Bare { table: "foo.bar" } +/// ['foo', 'Bar'] -> Partial { schema: "foo", table: "bar" } <-- note lower case "bar" +/// ['foo', 'bar'] -> Partial { schema: "foo", table: "bar" } +/// ['foo', '"Bar"'] -> Partial { schema: "foo", table: "Bar" } +/// ``` pub fn object_name_to_table_reference( object_name: ObjectName, ) -> Result { From 9c9c313fcabb1a8f4d2d2c3099db6fc007937bd9 Mon Sep 17 00:00:00 2001 
From: Andrew Lamb Date: Tue, 6 Dec 2022 15:11:18 -0500 Subject: [PATCH 3/6] Improve docstrings --- datafusion/common/src/table_reference.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index d18dbbd35a00..9b811e401294 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -158,15 +158,16 @@ impl<'a> TableReference<'a> { } } - /// Split `s` on periods. + /// Forms a [`TableReference`] by splitting `s` on periods `.`. /// /// Note that this function does NOT handle periods or name /// normalization correctly (e.g. `"foo.bar"` will be parsed as /// `"foo`.`bar"`. and `Foo` will be parsed as `Foo` (not `foo`). /// - /// Instead, you should use SQL parser for this + /// If you need to handle such identifiers correctly, you should + /// use a SQL oarser or form the [`OwnedTableReference`] directly. /// - /// The improvement is tracked in TODO FILE DATAFUSION TICKET + /// See more detail in pub fn parse_str(s: &'a str) -> Self { let parts: Vec<&str> = s.split('.').collect(); @@ -186,9 +187,9 @@ impl<'a> TableReference<'a> { } } -/// parse any string with "." 
:( +/// Parse a string into a TableReference, by splitting on `.` /// -/// See caveats on parse_str +/// See caveats on [`TableReference::parse_str`] impl<'a> From<&'a str> for TableReference<'a> { fn from(s: &'a str) -> Self { Self::parse_str(s) From ad007bf7d5739839b5612f94c976d40179f052cc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 6 Dec 2022 15:20:01 -0500 Subject: [PATCH 4/6] Add tests coverage --- .../tests/sqllogictests/test_files/ddl.slt | 81 ++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/ddl.slt b/datafusion/core/tests/sqllogictests/test_files/ddl.slt index 3420e47ee4da..b672d22ba3fa 100644 --- a/datafusion/core/tests/sqllogictests/test_files/ddl.slt +++ b/datafusion/core/tests/sqllogictests/test_files/ddl.slt @@ -15,8 +15,9 @@ # specific language governing permissions and limitations # under the License. +########## ## DDL Tests -## +########## statement ok CREATE TABLE users AS VALUES(1,2),(2,3); @@ -65,6 +66,9 @@ DROP TABLE user; statement error insert into user values(1, 20); +########## +# Multipart identifier test (CREATE/DROP TABLE) +########## # Verify that creating tables with periods in their name works # (note "foo.bar" is the table name, NOT table "bar" in schema "foo") @@ -77,3 +81,78 @@ select * from "foo.bar"; ---- 1 2 2 3 + +# Can not select from non existent foo.bar table +statement error +select * from foo.bar; + +# Error if wrong capitalization to select +statement error +select * from "Foo.bar"; + +# Should be able to drop the table +statement ok +drop table "foo.bar"; + +########## +# Multipart identifier test (CREATE/DROP VIEW) +########## + +# Verify that creating views with periods in their name works +# (note "foo.bar" is the view name, NOT view "bar" in schema "foo") +statement ok +CREATE VIEW "foo.bar" AS VALUES(1,2),(2,3); + +# Should be able to select from it as well +query II rowsort +select * from "foo.bar"; +---- +1 2 +2 3 + +# Can 
not select from non existent foo.bar view +statement error +select * from foo.bar; + +# Error if wrong capitalization to select +statement error +select * from "Foo.bar"; + +# Should be able to drop the view +statement ok +drop view "foo.bar"; + + +########## +# Drop view error tests +########## + +statement ok +CREATE VIEW foo AS VALUES(1,2),(2,3); + +statement ok +CREATE VIEW bar AS VALUES(3,4),(4,5); + +# Should be able to select from it as well +query II rowsort +select * from "foo"; +---- +1 2 +2 3 + +query II rowsort +select * from "bar"; +---- +3 4 +4 5 + +# multiple drops not supported +statement error +DROP VIEW foo, bar; + +# dropping the views one at a time works +statement ok +DROP VIEW foo; + +statement ok +DROP VIEW bar; From 569979c843897c0a86db7b65f964dfbadcb2f5c3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 7 Dec 2022 07:00:38 -0500 Subject: [PATCH 5/6] Update datafusion/common/src/table_reference.rs Co-authored-by: Nga Tran --- datafusion/common/src/table_reference.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index 9b811e401294..250aa3129a7f 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -165,7 +165,7 @@ impl<'a> TableReference<'a> { /// `"foo`.`bar"`. and `Foo` will be parsed as `Foo` (not `foo`). /// /// If you need to handle such identifiers correctly, you should - /// use a SQL oarser or form the [`OwnedTableReference`] directly. + /// use a SQL parser or form the [`OwnedTableReference`] directly. 
/// /// See more detail in pub fn parse_str(s: &'a str) -> Self { From c03cf2b08edb2841fb5d7245f70a8fa263e5ecec Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 7 Dec 2022 07:05:04 -0500 Subject: [PATCH 6/6] Add tests for creating tables with three periods --- .../tests/sqllogictests/test_files/ddl.slt | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/datafusion/core/tests/sqllogictests/test_files/ddl.slt b/datafusion/core/tests/sqllogictests/test_files/ddl.slt index b672d22ba3fa..eda589e5b397 100644 --- a/datafusion/core/tests/sqllogictests/test_files/ddl.slt +++ b/datafusion/core/tests/sqllogictests/test_files/ddl.slt @@ -94,6 +94,20 @@ select * from "Foo.bar"; statement ok drop table "foo.bar"; +# Verify that creating tables with three periods also works +statement ok +CREATE TABLE "foo.bar.baz" AS VALUES(8,9); + +# Should be able to select from it as well +query II rowsort +select * from "foo.bar.baz"; +---- +8 9 + +# And drop +statement ok +drop table "foo.bar.baz" + ########## # Multipart identifier test (CREATE/DROP VIEW) ########## @@ -122,6 +136,20 @@ select * from "Foo.bar"; statement ok drop view "foo.bar"; +# Verify that creating views with three periods also works +statement ok +CREATE VIEW "foo.bar.baz" AS VALUES(8,9); + +# Should be able to select from it as well +query II rowsort +select * from "foo.bar.baz"; +---- +8 9 + +# And drop +statement ok +drop view "foo.bar.baz" + ########## # Drop view error tests