From 4a714617464d4d3cbbe7a34e7da0557556d588bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jan 2024 08:49:14 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../src/vdk/plugin/oracle/ingest_to_oracle.py | 45 ++++++++++----- .../vdk-oracle/tests/test_plugin.py | 57 +++++++++++++------ 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py b/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py index 12dd89550e..101b882c5b 100644 --- a/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py +++ b/projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py @@ -4,9 +4,13 @@ import logging import math import re - from decimal import Decimal -from typing import Any, Collection, Dict, List, Optional, Set +from typing import Any +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Set from vdk.api.plugin.plugin_input import PEP249Connection from vdk.internal.builtin_plugins.connection.impl.router import ManagedConnectionRouter @@ -28,17 +32,17 @@ def _is_plain_identifier(identifier: str) -> bool: def _normalize_identifier(identifier: str) -> str: return identifier.upper() if _is_plain_identifier(identifier) else identifier + def _escape_special_chars(value: str) -> str: return value if _is_plain_identifier(value) else f'"{value}"' -class TableCache: +class TableCache: def __init__(self, cursor: ManagedCursor): self._tables: Dict[str, Dict[str, str]] = {} self._cursor = cursor def cache_columns(self, table: str) -> None: - # exit if the table columns have already been cached if table.upper() in self._tables and self._tables[table.upper()]: return @@ -53,7 +57,9 @@ def cache_columns(self, table: str) -> None: } 
except Exception as e: # TODO: https://github.com/vmware/versatile-data-kit/issues/2932 - log.exception("An error occurred while trying to cache columns. Ignoring for now.", e) + log.exception( + "An error occurred while trying to cache columns. Ignoring for now.", e + ) def get_columns(self, table: str) -> Dict[str, str]: return self._tables[table.upper()] @@ -63,7 +69,8 @@ def update_from_col_defs(self, table: str, col_defs) -> None: def get_col_type(self, table: str, col: str) -> str: return self._tables.get(table.upper()).get( - col.upper() if _is_plain_identifier(col) else col) + col.upper() if _is_plain_identifier(col) else col + ) def table_exists(self, table: str) -> bool: if table.upper() in self._tables: @@ -101,7 +108,10 @@ def _get_oracle_type(value: Any) -> str: return type_mappings.get(type(value), "VARCHAR2(255)") def _create_table(self, table_name: str, row: Dict[str, Any]) -> None: - column_defs = [f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}" for col in row.keys()] + column_defs = [ + f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}" + for col in row.keys() + ] create_table_sql = ( f"CREATE TABLE {table_name.upper()} ({', '.join(column_defs)})" ) @@ -112,7 +122,9 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None: existing_columns = self.table_cache.get_columns(table_name) # Find unique new columns from all rows in the payload - all_columns = {_normalize_identifier(col) for row in payload for col in row.keys()} + all_columns = { + _normalize_identifier(col) for row in payload for col in row.keys() + } new_columns = all_columns - existing_columns.keys() column_defs = [] if new_columns: @@ -127,7 +139,10 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None: ) column_defs.append((col, column_type)) - string_defs = [f"{_escape_special_chars(col_def[0])} {col_def[1]}" for col_def in column_defs] + string_defs = [ + f"{_escape_special_chars(col_def[0])} 
{col_def[1]}" + for col_def in column_defs + ] alter_sql = ( f"ALTER TABLE {table_name.upper()} ADD ({', '.join(string_defs)})" ) @@ -198,12 +213,12 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None: self.cursor.executemany(queries[i], batch_data[i]) def ingest_payload( - self, - payload: List[Dict[str, Any]], - destination_table: Optional[str] = None, - target: str = None, - collection_id: Optional[str] = None, - metadata: Optional[IIngesterPlugin.IngestionMetadata] = None, + self, + payload: List[Dict[str, Any]], + destination_table: Optional[str] = None, + target: str = None, + collection_id: Optional[str] = None, + metadata: Optional[IIngesterPlugin.IngestionMetadata] = None, ) -> None: if not payload: return None diff --git a/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py b/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py index 5bc729e337..dd1d1aff3c 100644 --- a/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py +++ b/projects/vdk-plugins/vdk-oracle/tests/test_plugin.py @@ -95,7 +95,12 @@ def test_oracle_ingest_no_table(self): def test_oracle_ingest_no_table_special_chars(self): runner = CliEntryBasedTestRunner(oracle_plugin) result: Result = runner.invoke( - ["run", jobs_path_from_caller_directory("oracle-ingest-job-no-table-special-chars")] + [ + "run", + jobs_path_from_caller_directory( + "oracle-ingest-job-no-table-special-chars" + ), + ] ) cli_assert_equal(0, result) _verify_ingest_execution_no_table_special_chars(runner) @@ -123,6 +128,7 @@ def test_oracle_ingest_different_payloads_no_table(self): ) cli_assert_equal(0, result) _verify_ingest_execution_different_payloads_no_table(runner) + def test_oracle_ingest_different_payloads_no_table_special_chars(self): runner = CliEntryBasedTestRunner(oracle_plugin) result: Result = runner.invoke( @@ -135,6 +141,7 @@ def test_oracle_ingest_different_payloads_no_table_special_chars(self): ) cli_assert_equal(0, result) 
_verify_ingest_execution_different_payloads_no_table_special_chars(runner) + def test_oracle_ingest_blob(self): runner = CliEntryBasedTestRunner(oracle_plugin) result: Result = runner.invoke( @@ -176,12 +183,14 @@ def _verify_ingest_execution_special_chars(runner): check_result = runner.invoke( ["oracle-query", "--query", "SELECT * FROM test_table"] ) - expected = (' ID @str_data %int_data *float*data* BOOL_DATA ' - 'TIMESTAMP_DATA DECIMAL_DATA\n' - '---- ----------- ----------- -------------- ----------- ' - '------------------- --------------\n' - ' 5 string 12 1.2 1 2023-11-21 ' - '08:12:53 0.1\n') + expected = ( + " ID @str_data %int_data *float*data* BOOL_DATA " + "TIMESTAMP_DATA DECIMAL_DATA\n" + "---- ----------- ----------- -------------- ----------- " + "------------------- --------------\n" + " 5 string 12 1.2 1 2023-11-21 " + "08:12:53 0.1\n" + ) assert check_result.output == expected @@ -223,16 +232,18 @@ def _verify_ingest_execution_no_table_special_chars(runner): check_result = runner.invoke( ["oracle-query", "--query", "SELECT * FROM test_table"] ) - expected = (' ID @str_data %int_data *float*data* BOOL_DATA ' - 'TIMESTAMP_DATA DECIMAL_DATA\n' - '---- ----------- ----------- -------------- ----------- ' - '------------------- --------------\n' - ' 0 string 12 1.2 1 ' - '2023-11-21T08:12:53 1.1\n' - ' 1 string 12 1.2 1 ' - '2023-11-21T08:12:53 1.1\n' - ' 2 string 12 1.2 1 ' - '2023-11-21T08:12:53 1.1\n') + expected = ( + " ID @str_data %int_data *float*data* BOOL_DATA " + "TIMESTAMP_DATA DECIMAL_DATA\n" + "---- ----------- ----------- -------------- ----------- " + "------------------- --------------\n" + " 0 string 12 1.2 1 " + "2023-11-21T08:12:53 1.1\n" + " 1 string 12 1.2 1 " + "2023-11-21T08:12:53 1.1\n" + " 2 string 12 1.2 1 " + "2023-11-21T08:12:53 1.1\n" + ) assert check_result.output == expected @@ -242,13 +253,22 @@ def _verify_ingest_execution_different_payloads_no_table(runner): ) expected = " COUNT(*)\n----------\n 8\n" assert 
check_result.output == expected + + def _verify_ingest_execution_different_payloads_no_table_special_chars(runner): check_result = runner.invoke( ["oracle-query", "--query", "SELECT * FROM test_table"] ) actual_columns = check_result.output.split("\n")[0].split() - expected_columns = ['ID', '&timestamp_data', '^bool_data', '@int_data', '%float_data', '?str_data'] + expected_columns = [ + "ID", + "&timestamp_data", + "^bool_data", + "@int_data", + "%float_data", + "?str_data", + ] for expected_col in expected_columns: assert expected_col in actual_columns @@ -258,6 +278,7 @@ def _verify_ingest_execution_different_payloads_no_table_special_chars(runner): ) assert check_result.output == expected_count + def _verify_ingest_execution_different_payloads(runner): check_result = runner.invoke( ["oracle-query", "--query", "SELECT * FROM test_table"]