diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index d138bb06cd84..12510157dbb4 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -67,11 +67,7 @@ impl Column { } } - /// Deserialize a fully qualified name string into a column - pub fn from_qualified_name(flat_name: impl Into) -> Self { - let flat_name = flat_name.into(); - let mut idents = parse_identifiers_normalized(&flat_name); - + fn from_idents(idents: &mut Vec) -> Option { let (relation, name) = match idents.len() { 1 => (None, idents.remove(0)), 2 => ( @@ -97,9 +93,33 @@ impl Column { ), // any expression that failed to parse or has more than 4 period delimited // identifiers will be treated as an unqualified column name - _ => (None, flat_name), + _ => return None, }; - Self { relation, name } + Some(Self { relation, name }) + } + + /// Deserialize a fully qualified name string into a column + /// + /// Treats the name as a SQL identifier. For example + /// `foo.BAR` would be parsed to a reference to relation `foo`, column name `bar` (lower case) + /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` + pub fn from_qualified_name(flat_name: impl Into) -> Self { + let flat_name: &str = &flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(flat_name, false)) + .unwrap_or_else(|| Self { + relation: None, + name: flat_name.to_owned(), + }) + } + + /// Deserialize a fully qualified name string into a column preserving column text case + pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { + let flat_name: &str = &flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(flat_name, true)) + .unwrap_or_else(|| Self { + relation: None, + name: flat_name.to_owned(), + }) } /// Serialize column into a flat name string diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index cd05f8082dab..55681ece1016 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -299,7 +299,7 @@ impl<'a> TableReference<'a> { /// Forms a [`TableReference`] by parsing `s` as a multipart SQL /// identifier. See docs on [`TableReference`] for more details. pub fn parse_str(s: &'a str) -> Self { - let mut parts = parse_identifiers_normalized(s); + let mut parts = parse_identifiers_normalized(s, false); match parts.len() { 1 => Self::Bare { diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index c324e2079dc6..9ab91bcdd8bf 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -246,13 +246,14 @@ pub fn get_arrayref_at_indices( .collect() } -pub(crate) fn parse_identifiers_normalized(s: &str) -> Vec { +pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() .into_iter() .map(|id| match id.quote_style { Some(_) => id.value, - None => id.value.to_ascii_lowercase(), + None if ignore_case => id.value, + _ => id.value.to_ascii_lowercase(), }) .collect::>() } diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index 5da72f96bd65..e1109b7fc403 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -1014,12 +1014,21 @@ impl DataFrame { /// ``` pub fn with_column_renamed( self, - old_name: impl Into, + old_name: impl Into, new_name: &str, ) -> Result { - let old_name: Column = old_name.into(); + let ident_opts = self + .session_state + .config_options() + .sql_parser + .enable_ident_normalization; + let old_column: Column = if ident_opts { + Column::from_qualified_name(old_name) + } else { + Column::from_qualified_name_ignore_case(old_name) + }; - let field_to_rename = match self.plan.schema().field_from_column(&old_name) { + let field_to_rename = match self.plan.schema().field_from_column(&old_column) { Ok(field) => field, // no-op if field not found Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => { @@ -1830,6 +1839,58 @@ mod tests { Ok(()) } + #[tokio::test] + async fn with_column_renamed_case_sensitive() -> Result<()> { + let config = + SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + "datafusion.sql_parser.enable_ident_normalization".to_owned(), + "false".to_owned(), + )]))?; + let mut ctx = SessionContext::with_config(config); + let name = "aggregate_test_100"; + register_aggregate_csv(&mut ctx, name).await?; + let df = ctx.table(name); + + let df = df + .await? + .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? + .limit(0, Some(1))? + .sort(vec![ + // make the test deterministic + col("c1").sort(true, true), + col("c2").sort(true, true), + col("c3").sort(true, true), + ])? + .select_columns(&["c1"])?; + + let df_renamed = df.clone().with_column_renamed("c1", "CoLuMn1")?; + + let res = &df_renamed.clone().collect().await?; + + assert_batches_sorted_eq!( + vec![ + "+---------+", + "| CoLuMn1 |", + "+---------+", + "| a |", + "+---------+", + ], + res + ); + + let df_renamed = df_renamed + .with_column_renamed("CoLuMn1", "c1")? + .collect() + .await?; + + assert_batches_sorted_eq!( + vec!["+----+", "| c1 |", "+----+", "| a |", "+----+",], + &df_renamed + ); + + Ok(()) + } + #[tokio::test] async fn filter_pushdown_dataframe() -> Result<()> { let ctx = SessionContext::new();