From ab3cd1c0aa463193bc9b6fc9c53cce77f97674b2 Mon Sep 17 00:00:00 2001 From: Dilyan Marinov Date: Fri, 19 Jan 2024 15:50:14 +0200 Subject: [PATCH] vdk-oracle: escape special chars in column names Why? Special characters in column names currently cause errors. They should be supported in case the plugin is used for tables with unorthodox column names What? Support special characters in oracle column names Refactor column caching logic Special chars are escaped as advised here https://docs.oracle.com/en/error-help/db/ora-00904/ How was this tested? Ran functional tests locally CI/CD What kind of change is this? Feature/non-breaking Signed-off-by: Dilyan Marinov Export identifier normalization to separate function --- .../src/vdk/plugin/oracle/ingest_to_oracle.py | 143 +++++++++++------- .../00_drop_table.sql | 4 + .../10_ingest.py | 58 +++++++ .../00_drop_table.sql | 4 + .../10_ingest.py | 48 ++++++ .../00_drop_table.sql | 4 + .../10_create_table.sql | 9 ++ .../20_ingest.py | 20 +++ .../config.ini | 2 + .../vdk-oracle/tests/test_plugin.py | 73 ++++++++- 10 files changed, 307 insertions(+), 58 deletions(-) create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/00_drop_table.sql create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/10_ingest.py create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/00_drop_table.sql create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/10_ingest.py create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/00_drop_table.sql create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/10_create_table.sql create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/20_ingest.py create mode 100644 projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/config.ini diff --git a/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py b/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py index e2931aed89..12dd89550e 100644 --- a/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py +++ b/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py @@ -3,12 +3,10 @@ import datetime import logging import math +import re + from decimal import Decimal -from typing import Any -from typing import Dict -from typing import List -from typing import Optional -from typing import Set +from typing import Any, Collection, Dict, List, Optional, Set from vdk.api.plugin.plugin_input import PEP249Connection from vdk.internal.builtin_plugins.connection.impl.router import ManagedConnectionRouter @@ -18,12 +16,76 @@ log = logging.getLogger(__name__) +# Functions for escaping special characters +def _is_plain_identifier(identifier: str) -> bool: + # https://docs.oracle.com/en/error-help/db/ora-00904/ + # Alphanumeric that doesn't start with a number + # Can contain and start with $, # and _ + regex = "^[A-Za-z\\$#_][0-9A-Za-z\\$#_]*$" + return bool(re.fullmatch(regex, identifier)) + + +def _normalize_identifier(identifier: str) -> str: + return identifier.upper() if _is_plain_identifier(identifier) else identifier + +def _escape_special_chars(value: str) -> str: + return value if _is_plain_identifier(value) else f'"{value}"' + +class TableCache: + + def __init__(self, cursor: ManagedCursor): + self._tables: Dict[str, Dict[str, str]] = {} + self._cursor = cursor + + def cache_columns(self, table: str) -> None: + + # exit if the table columns have already been cached + if table.upper() in self._tables and self._tables[table.upper()]: + return + try: + self._cursor.execute( + f"SELECT column_name, data_type, data_scale FROM user_tab_columns WHERE table_name = '{table.upper()}'" + ) + result = self._cursor.fetchall() + self._tables[table.upper()] = { + col: ("DECIMAL" if data_type == "NUMBER" and data_scale else data_type) + for (col, data_type, data_scale) in result + } + except Exception as e: + # TODO: https://github.com/vmware/versatile-data-kit/issues/2932 + log.exception("An error occurred while trying to cache columns. Ignoring for now.", e) + + def get_columns(self, table: str) -> Dict[str, str]: + return self._tables[table.upper()] + + def update_from_col_defs(self, table: str, col_defs) -> None: + self._tables[table.upper()].update(col_defs) + + def get_col_type(self, table: str, col: str) -> str: + return self._tables.get(table.upper()).get( + col.upper() if _is_plain_identifier(col) else col) + + def table_exists(self, table: str) -> bool: + if table.upper() in self._tables: + return True + + self._cursor.execute( + f"SELECT COUNT(*) FROM user_tables WHERE table_name = :1", + [table.upper()], + ) + exists = bool(self._cursor.fetchone()[0]) + + if exists: + self._tables[table.upper()] = {} + + return exists + + class IngestToOracle(IIngesterPlugin): def __init__(self, connections: ManagedConnectionRouter): self.conn: PEP249Connection = connections.open_connection("ORACLE").connect() self.cursor: ManagedCursor = self.conn.cursor() - self.table_cache: Set[str] = set() # Cache to store existing tables - self.column_cache: Dict[str, Dict[str, str]] = {} # New cache for columns + self.table_cache: TableCache = TableCache(self.cursor) # New cache for columns @staticmethod def _get_oracle_type(value: Any) -> str: @@ -38,53 +100,19 @@ def _get_oracle_type(value: Any) -> str: } return type_mappings.get(type(value), "VARCHAR2(255)") - def _table_exists(self, table_name: str) -> bool: - if table_name.upper() in self.table_cache: - return True - - self.cursor.execute( - f"SELECT COUNT(*) FROM user_tables WHERE table_name = :1", - [table_name.upper()], - ) - exists = bool(self.cursor.fetchone()[0]) - - if exists: - self.table_cache.add(table_name.upper()) - - return exists - def _create_table(self, table_name: str, row: Dict[str, Any]) -> None: - column_defs = [f"{col} {self._get_oracle_type(row[col])}" for col in row.keys()] + column_defs = [f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}" for col in row.keys()] create_table_sql = ( f"CREATE TABLE {table_name.upper()} ({', '.join(column_defs)})" ) self.cursor.execute(create_table_sql) - def _cache_columns(self, table_name: str) -> None: - try: - self.cursor.execute( - f"SELECT column_name, data_type, data_scale FROM user_tab_columns WHERE table_name = '{table_name.upper()}'" - ) - result = self.cursor.fetchall() - self.column_cache[table_name.upper()] = { - col: ("DECIMAL" if data_type == "NUMBER" and data_scale else data_type) - for (col, data_type, data_scale) in result - } - except Exception as e: - # TODO: https://github.com/vmware/versatile-data-kit/issues/2932 - log.error( - "An exception occurred while trying to cache columns. Ignoring for now." - ) - log.exception(e) - def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None: - if table_name.upper() not in self.column_cache: - self._cache_columns(table_name) - - existing_columns = self.column_cache[table_name.upper()] + self.table_cache.cache_columns(table_name) + existing_columns = self.table_cache.get_columns(table_name) # Find unique new columns from all rows in the payload - all_columns = {col.upper() for row in payload for col in row.keys()} + all_columns = {_normalize_identifier(col) for row in payload for col in row.keys()} new_columns = all_columns - existing_columns.keys() column_defs = [] if new_columns: @@ -99,12 +127,12 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None: ) column_defs.append((col, column_type)) - string_defs = [f"{col_def[0]} {col_def[1]}" for col_def in column_defs] + string_defs = [f"{_escape_special_chars(col_def[0])} {col_def[1]}" for col_def in column_defs] alter_sql = ( f"ALTER TABLE {table_name.upper()} ADD ({', '.join(string_defs)})" ) self.cursor.execute(alter_sql) - self.column_cache[table_name.upper()].update(column_defs) + self.table_cache.update_from_col_defs(table_name, column_defs) # TODO: https://github.com/vmware/versatile-data-kit/issues/2929 # TODO: https://github.com/vmware/versatile-data-kit/issues/2930 @@ -130,7 +158,7 @@ def cast_string_to_type(db_type: str, payload_value: str) -> Any: if isinstance(value, Decimal): return float(value) if isinstance(value, str): - col_type = self.column_cache.get(table.upper()).get(column.upper()) + col_type = self.table_cache.get_col_type(table, column) return cast_string_to_type(col_type, value) return value @@ -153,7 +181,8 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None: batch_data = [] for column_names, batch in batches.items(): columns = list(column_names) - insert_sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join([':' + str(i + 1) for i in range(len(columns))])})" + query_columns = [_escape_special_chars(col) for col in columns] + insert_sql = f"INSERT INTO {table_name} ({', '.join(query_columns)}) VALUES ({', '.join([':' + str(i + 1) for i in range(len(query_columns))])})" queries.append(insert_sql) temp_data = [] for row in batch: @@ -169,21 +198,21 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None: self.cursor.executemany(queries[i], batch_data[i]) def ingest_payload( - self, - payload: List[Dict[str, Any]], - destination_table: Optional[str] = None, - target: str = None, - collection_id: Optional[str] = None, - metadata: Optional[IIngesterPlugin.IngestionMetadata] = None, + self, + payload: List[Dict[str, Any]], + destination_table: Optional[str] = None, + target: str = None, + collection_id: Optional[str] = None, + metadata: Optional[IIngesterPlugin.IngestionMetadata] = None, ) -> None: if not payload: return None if not destination_table: raise ValueError("Destination table must be specified if not in payload.") - if not self._table_exists(destination_table): + if not self.table_cache.table_exists(destination_table): self._create_table(destination_table, payload[0]) - self._cache_columns(destination_table) + self.table_cache.cache_columns(destination_table) self._add_columns(destination_table, payload) self._insert_data(destination_table, payload) diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/00_drop_table.sql b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/00_drop_table.sql new file mode 100644 index 0000000000..865c2f0050 --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/00_drop_table.sql @@ -0,0 +1,4 @@ +begin + execute immediate 'drop table test_table'; + exception when others then if sqlcode <> -942 then raise; end if; +end; diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/10_ingest.py b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/10_ingest.py new file mode 100644 index 0000000000..d4745e724a --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-different-payloads-no-table-special-chars/10_ingest.py @@ -0,0 +1,58 @@ +# Copyright 2021-2024 VMware, Inc. +# SPDX-License-Identifier: Apache-2.0 +import datetime + + +def run(job_input): + payloads = [ + { + "id": 0, + }, + { + "id": 1, + "?str_data": "string", + }, + { + "id": 2, + "?str_data": "string", + "@int_data": 12, + }, + { + "id": 3, + "?str_data": "string", + "@int_data": 12, + "%float_data": 1.2, + }, + { + "id": 4, + "?str_data": "string", + "@int_data": 12, + "%float_data": 1.2, + "^bool_data": True, + }, + { + "id": 5, + "?str_data": "string", + "@int_data": 12, + "%float_data": 1.2, + "^bool_data": True, + "×tamp_data": datetime.datetime.utcfromtimestamp(1700554373), + }, + { + "id": 6, + "?str_data": "string", + "@int_data": 12, + "%float_data": 1.2, + }, + { + "id": 7, + "?str_data": "string", + "@int_data": 12, + "%float_data": 1.2, + "^bool_data": True, + }, + ] + for payload in payloads: + job_input.send_object_for_ingestion( + payload=payload, destination_table="test_table" + ) diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/00_drop_table.sql b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/00_drop_table.sql new file mode 100644 index 0000000000..865c2f0050 --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/00_drop_table.sql @@ -0,0 +1,4 @@ +begin + execute immediate 'drop table test_table'; + exception when others then if sqlcode <> -942 then raise; end if; +end; diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/10_ingest.py b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/10_ingest.py new file mode 100644 index 0000000000..4aca29b04d --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-no-table-special-chars/10_ingest.py @@ -0,0 +1,48 @@ +# Copyright 2021-2024 VMware, Inc. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from decimal import Decimal + + +def run(job_input): + col_names = [ + "id", + "@str_data", + "%int_data", + "*float*data*", + "bool_data", + "timestamp_data", + "decimal_data", + ] + row_data = [ + [ + 0, + "string", + 12, + 1.2, + True, + datetime.datetime.utcfromtimestamp(1700554373), + Decimal(1.1), + ], + [ + 1, + "string", + 12, + 1.2, + True, + datetime.datetime.utcfromtimestamp(1700554373), + Decimal(1.1), + ], + [ + 2, + "string", + 12, + 1.2, + True, + datetime.datetime.utcfromtimestamp(1700554373), + Decimal(1.1), + ], + ] + job_input.send_tabular_data_for_ingestion( + rows=row_data, column_names=col_names, destination_table="test_table" + ) diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/00_drop_table.sql b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/00_drop_table.sql new file mode 100644 index 0000000000..865c2f0050 --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/00_drop_table.sql @@ -0,0 +1,4 @@ +begin + execute immediate 'drop table test_table'; + exception when others then if sqlcode <> -942 then raise; end if; +end; diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/10_create_table.sql b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/10_create_table.sql new file mode 100644 index 0000000000..8316ae0ac2 --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/10_create_table.sql @@ -0,0 +1,9 @@ +create table test_table ( + id number, + "@str_data" varchar2(255), + "%int_data" number, + "*float*data*" float, + bool_data number(1), + timestamp_data timestamp, + decimal_data decimal(14,8), + primary key(id)) diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/20_ingest.py b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/20_ingest.py new file mode 100644 index 0000000000..10aac3eee2 --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/20_ingest.py @@ -0,0 +1,20 @@ +# Copyright 2021-2024 VMware, Inc. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from decimal import Decimal + + +def run(job_input): + payload_with_types = { + "id": 5, + "@str_data": "string", + "%int_data": 12, + "*float*data*": 1.2, + "bool_data": True, + "timestamp_data": datetime.datetime.utcfromtimestamp(1700554373), + "decimal_data": Decimal(0.1), + } + + job_input.send_object_for_ingestion( + payload=payload_with_types, destination_table="test_table" + ) diff --git a/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/config.ini b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/config.ini new file mode 100644 index 0000000000..ab363d528b --- /dev/null +++ b/projects/vdk-plugins/vdk-oracle/tests/jobs/oracle-ingest-job-special-chars/config.ini @@ -0,0 +1,2 @@ +[owner] +team = test-team diff --git a/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py b/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py index 6b6766f1ed..5bc729e337 100644 --- a/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py +++ b/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py @@ -68,6 +68,14 @@ def test_oracle_ingest_existing_table(self): cli_assert_equal(0, result) _verify_ingest_execution(runner) + def test_oracle_ingest_existing_table_special_chars(self): + runner = CliEntryBasedTestRunner(oracle_plugin) + result: Result = runner.invoke( + ["run", jobs_path_from_caller_directory("oracle-ingest-job-special-chars")] + ) + cli_assert_equal(0, result) + _verify_ingest_execution_special_chars(runner) + def test_oracle_ingest_type_inference(self): runner = CliEntryBasedTestRunner(oracle_plugin) result: Result = runner.invoke( @@ -84,6 +92,14 @@ def test_oracle_ingest_no_table(self): cli_assert_equal(0, result) _verify_ingest_execution_no_table(runner) + def test_oracle_ingest_no_table_special_chars(self): + runner = CliEntryBasedTestRunner(oracle_plugin) + result: Result = runner.invoke( + ["run", jobs_path_from_caller_directory("oracle-ingest-job-no-table-special-chars")] + ) + cli_assert_equal(0, result) + _verify_ingest_execution_no_table_special_chars(runner) + def test_oracle_ingest_different_payloads(self): runner = CliEntryBasedTestRunner(oracle_plugin) result: Result = runner.invoke( @@ -107,7 +123,18 @@ def test_oracle_ingest_different_payloads_no_table(self): ) cli_assert_equal(0, result) _verify_ingest_execution_different_payloads_no_table(runner) - + def test_oracle_ingest_different_payloads_no_table_special_chars(self): + runner = CliEntryBasedTestRunner(oracle_plugin) + result: Result = runner.invoke( + [ + "run", + jobs_path_from_caller_directory( + "oracle-ingest-job-different-payloads-no-table-special-chars" + ), + ] + ) + cli_assert_equal(0, result) + _verify_ingest_execution_different_payloads_no_table_special_chars(runner) def test_oracle_ingest_blob(self): runner = CliEntryBasedTestRunner(oracle_plugin) result: Result = runner.invoke( @@ -145,6 +172,19 @@ def _verify_ingest_execution(runner): assert check_result.output == expected +def _verify_ingest_execution_special_chars(runner): + check_result = runner.invoke( + ["oracle-query", "--query", "SELECT * FROM test_table"] + ) + expected = (' ID @str_data %int_data *float*data* BOOL_DATA ' + 'TIMESTAMP_DATA DECIMAL_DATA\n' + '---- ----------- ----------- -------------- ----------- ' + '------------------- --------------\n' + ' 5 string 12 1.2 1 2023-11-21 ' + '08:12:53 0.1\n') + assert check_result.output == expected + + def _verify_ingest_execution_type_inference(runner): check_result = runner.invoke( ["oracle-query", "--query", "SELECT * FROM test_table"] @@ -179,13 +219,44 @@ def _verify_ingest_execution_no_table(runner): assert check_result.output == expected +def _verify_ingest_execution_no_table_special_chars(runner): + check_result = runner.invoke( + ["oracle-query", "--query", "SELECT * FROM test_table"] + ) + expected = (' ID @str_data %int_data *float*data* BOOL_DATA ' + 'TIMESTAMP_DATA DECIMAL_DATA\n' + '---- ----------- ----------- -------------- ----------- ' + '------------------- --------------\n' + ' 0 string 12 1.2 1 ' + '2023-11-21T08:12:53 1.1\n' + ' 1 string 12 1.2 1 ' + '2023-11-21T08:12:53 1.1\n' + ' 2 string 12 1.2 1 ' + '2023-11-21T08:12:53 1.1\n') + assert check_result.output == expected + + def _verify_ingest_execution_different_payloads_no_table(runner): check_result = runner.invoke( ["oracle-query", "--query", "SELECT count(*) FROM test_table"] ) expected = " COUNT(*)\n----------\n 8\n" assert check_result.output == expected +def _verify_ingest_execution_different_payloads_no_table_special_chars(runner): + check_result = runner.invoke( + ["oracle-query", "--query", "SELECT * FROM test_table"] + ) + actual_columns = check_result.output.split("\n")[0].split() + expected_columns = ['ID', '×tamp_data', '^bool_data', '@int_data', '%float_data', '?str_data'] + for expected_col in expected_columns: + assert expected_col in actual_columns + + expected_count = " COUNT(*)\n----------\n 8\n" + check_result = runner.invoke( + ["oracle-query", "--query", "SELECT count(*) FROM test_table"] + ) + assert check_result.output == expected_count def _verify_ingest_execution_different_payloads(runner): check_result = runner.invoke(