Skip to content

Commit

Permalink
Fix escaping of special characters or reserved words as column names …
Browse files Browse the repository at this point in the history
…in dialects of common sql provider (apache#45640)

* refactor: Make sure reserved words in column names are all escaped in the generate_replace_sql method of MSSQLDialect

---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Jan 26, 2025
1 parent fbf0f0e commit 17d3a60
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class Dialect(LoggingMixin):
"""Generic dialect implementation."""

pattern = re.compile(r'"([a-zA-Z0-9_]+)"')
pattern = re.compile(r"[^\w]")

def __init__(self, hook, **kwargs) -> None:
super().__init__(**kwargs)
Expand All @@ -45,12 +45,6 @@ def __init__(self, hook, **kwargs) -> None:

self.hook: DbApiHook = hook

@classmethod
def remove_quotes(cls, value: str | None) -> str | None:
if value:
return cls.pattern.sub(r"\1", value)
return value

@property
def placeholder(self) -> str:
return self.hook.placeholder
Expand All @@ -60,16 +54,56 @@ def inspector(self) -> Inspector:
return self.hook.inspector

@property
def _insert_statement_format(self) -> str:
return self.hook._insert_statement_format # type: ignore
def insert_statement_format(self) -> str:
return self.hook.insert_statement_format

@property
def replace_statement_format(self) -> str:
return self.hook.replace_statement_format

@property
def _replace_statement_format(self) -> str:
return self.hook._replace_statement_format # type: ignore
def escape_word_format(self) -> str:
return self.hook.escape_word_format

@property
def _escape_column_name_format(self) -> str:
return self.hook._escape_column_name_format # type: ignore
def escape_column_names(self) -> bool:
return self.hook.escape_column_names

def escape_word(self, word: str) -> str:
"""
Escape the word if necessary.
If the word is a reserved word or contains special characters or if the ``escape_column_names``
property is set to True in connection extra field, then the given word will be escaped.
:param word: Name of the column
:return: The escaped word
"""
if word != self.escape_word_format.format(self.unescape_word(word)) and (
self.escape_column_names or word.casefold() in self.reserved_words or self.pattern.search(word)
):
return self.escape_word_format.format(word)
return word

def unescape_word(self, word: str | None) -> str | None:
"""
Remove escape characters from each part of a dotted identifier (e.g., schema.table).
:param word: Escaped schema, table, or column name, potentially with multiple segments.
:return: The word without escaped characters.
"""
if not word:
return word

escape_char_start = self.escape_word_format[0]
escape_char_end = self.escape_word_format[-1]

def unescape_part(part: str) -> str:
if part.startswith(escape_char_start) and part.endswith(escape_char_end):
return part[1:-1]
return part

return ".".join(map(unescape_part, word.split(".")))

@classmethod
def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]:
Expand All @@ -87,8 +121,8 @@ def get_column_names(
for column in filter(
predicate,
self.inspector.get_columns(
table_name=self.remove_quotes(table),
schema=self.remove_quotes(schema) if schema else None,
table_name=self.unescape_word(table),
schema=self.unescape_word(schema) if schema else None,
),
)
)
Expand All @@ -110,8 +144,8 @@ def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] |
if schema is None:
table, schema = self.extract_schema_from_table(table)
primary_keys = self.inspector.get_pk_constraint(
table_name=self.remove_quotes(table),
schema=self.remove_quotes(schema) if schema else None,
table_name=self.unescape_word(table),
schema=self.unescape_word(schema) if schema else None,
).get("constrained_columns", [])
self.log.debug("Primary keys for table '%s': %s", table, primary_keys)
return primary_keys
Expand All @@ -138,20 +172,6 @@ def get_records(
def reserved_words(self) -> set[str]:
return self.hook.reserved_words

def escape_column_name(self, column_name: str) -> str:
"""
Escape the column name if it's a reserved word.
:param column_name: Name of the column
:return: The escaped column name if needed
"""
if (
column_name != self._escape_column_name_format.format(column_name)
and column_name.casefold() in self.reserved_words
):
return self._escape_column_name_format.format(column_name)
return column_name

def _joined_placeholders(self, values) -> str:
placeholders = [
self.placeholder,
Expand All @@ -160,7 +180,7 @@ def _joined_placeholders(self, values) -> str:

def _joined_target_fields(self, target_fields) -> str:
if target_fields:
target_fields = ", ".join(map(self.escape_column_name, target_fields))
target_fields = ", ".join(map(self.escape_word, target_fields))
return f"({target_fields})"
return ""

Expand All @@ -173,7 +193,7 @@ def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str:
:param target_fields: The names of the columns to fill in the table
:return: The generated INSERT SQL statement
"""
return self._insert_statement_format.format(
return self.insert_statement_format.format(
table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
)

Expand All @@ -186,6 +206,6 @@ def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
:param target_fields: The names of the columns to fill in the table
:return: The generated REPLACE SQL statement
"""
return self._replace_statement_format.format(
return self.replace_statement_format.format(
table, self._joined_target_fields(target_fields), self._joined_placeholders(values)
)
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,17 @@ T = TypeVar("T")
class Dialect(LoggingMixin):
hook: Incomplete
def __init__(self, hook, **kwargs) -> None: ...
@classmethod
def remove_quotes(cls, value: str | None) -> str | None: ...
def escape_word(self, column_name: str) -> str: ...
def unescape_word(self, value: str | None) -> str | None: ...
@property
def placeholder(self) -> str: ...
@property
def insert_statement_format(self) -> str: ...
@property
def replace_statement_format(self) -> str: ...
@property
def escape_word_format(self) -> str: ...
@property
def inspector(self) -> Inspector: ...
@classmethod
def extract_schema_from_table(cls, table: str) -> tuple[str, str | None]: ...
Expand All @@ -72,6 +78,5 @@ class Dialect(LoggingMixin):
) -> Any: ...
@property
def reserved_words(self) -> set[str]: ...
def escape_column_name(self, column_name: str) -> str: ...
def generate_insert_sql(self, table, values, target_fields, **kwargs) -> str: ...
def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str: ...
58 changes: 41 additions & 17 deletions providers/common/sql/src/airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.dialects.dialect import Dialect
from airflow.providers.common.sql.hooks import handlers
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
Expand All @@ -67,24 +68,18 @@
def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool | None):
warnings.warn(WARNING_MESSAGE.format("return_single_query_results"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers

return handlers.return_single_query_results(sql, return_last, split_statements)


def fetch_all_handler(cursor) -> list[tuple] | None:
warnings.warn(WARNING_MESSAGE.format("fetch_all_handler"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers

return handlers.fetch_all_handler(cursor)


def fetch_one_handler(cursor) -> list[tuple] | None:
warnings.warn(WARNING_MESSAGE.format("fetch_one_handler"), DeprecationWarning, stacklevel=2)

from airflow.providers.common.sql.hooks import handlers

return handlers.fetch_one_handler(cursor)


Expand Down Expand Up @@ -184,13 +179,10 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa
self.__schema = schema
self.log_sql = log_sql
self.descriptions: list[Sequence[Sequence] | None] = []
self._insert_statement_format: str = kwargs.get(
"insert_statement_format", "INSERT INTO {} {} VALUES ({})"
)
self._replace_statement_format: str = kwargs.get(
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)
self._escape_column_name_format: str = kwargs.get("escape_column_name_format", '"{}"')
self._insert_statement_format: str | None = kwargs.get("insert_statement_format")
self._replace_statement_format: str | None = kwargs.get("replace_statement_format")
self._escape_word_format: str | None = kwargs.get("escape_word_format")
self._escape_column_names: bool | None = kwargs.get("escape_column_names")
self._connection: Connection | None = kwargs.pop("connection", None)

def get_conn_id(self) -> str:
Expand All @@ -212,6 +204,38 @@ def placeholder(self) -> str:
)
return self._placeholder

@property
def insert_statement_format(self) -> str:
"""Return the insert statement format."""
if not self._insert_statement_format:
self._insert_statement_format = self.connection_extra.get(
"insert_statement_format", "INSERT INTO {} {} VALUES ({})"
)
return self._insert_statement_format

@property
def replace_statement_format(self) -> str:
"""Return the replacement statement format."""
if not self._replace_statement_format:
self._replace_statement_format = self.connection_extra.get(
"replace_statement_format", "REPLACE INTO {} {} VALUES ({})"
)
return self._replace_statement_format

@property
def escape_word_format(self) -> str:
"""Return the escape word format."""
if not self._escape_word_format:
self._escape_word_format = self.connection_extra.get("escape_word_format", '"{}"')
return self._escape_word_format

@property
def escape_column_names(self) -> bool:
"""Return the escape column names flag."""
if not self._escape_column_names:
self._escape_column_names = self.connection_extra.get("escape_column_names", False)
return self._escape_column_names

@property
def connection(self) -> Connection:
if self._connection is None:
Expand Down Expand Up @@ -413,7 +437,7 @@ def get_records(
:param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler)
return self.run(sql=sql, parameters=parameters, handler=handlers.fetch_all_handler)

def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
"""
Expand All @@ -422,7 +446,7 @@ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, An
:param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
return self.run(sql=sql, parameters=parameters, handler=fetch_one_handler)
return self.run(sql=sql, parameters=parameters, handler=handlers.fetch_one_handler)

@staticmethod
def strip_sql_string(sql: str) -> str:
Expand Down Expand Up @@ -557,7 +581,7 @@ def run(

if handler is not None:
result = self._make_common_data_structure(handler(cur))
if return_single_query_results(sql, return_last, split_statements):
if handlers.return_single_query_results(sql, return_last, split_statements):
_last_result = result
_last_description = cur.description
else:
Expand All @@ -572,7 +596,7 @@ def run(

if handler is None:
return None
if return_single_query_results(sql, return_last, split_statements):
if handlers.return_single_query_results(sql, return_last, split_statements):
self.descriptions = [_last_description]
return _last_result
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class DbApiHook(BaseHook):
@cached_property
def placeholder(self) -> str: ...
@property
def insert_statement_format(self) -> str: ...
@property
def replace_statement_format(self) -> str: ...
@property
def escape_word_format(self) -> str: ...
@property
def escape_column_names(self) -> bool: ...
@property
def connection(self) -> Connection: ...
@connection.setter
def connection(self, value: Any) -> None: ...
Expand Down
Loading

0 comments on commit 17d3a60

Please sign in to comment.