Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 23, 2024
1 parent 2a09a24 commit 4a71461
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
import logging
import math
import re

from decimal import Decimal
from typing import Any, Collection, Dict, List, Optional, Set
from typing import Any
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Set

from vdk.api.plugin.plugin_input import PEP249Connection
from vdk.internal.builtin_plugins.connection.impl.router import ManagedConnectionRouter
Expand All @@ -28,17 +32,17 @@ def _is_plain_identifier(identifier: str) -> bool:
def _normalize_identifier(identifier: str) -> str:
return identifier.upper() if _is_plain_identifier(identifier) else identifier


def _escape_special_chars(value: str) -> str:
return value if _is_plain_identifier(value) else f'"{value}"'

class TableCache:

class TableCache:
def __init__(self, cursor: ManagedCursor):
self._tables: Dict[str, Dict[str, str]] = {}
self._cursor = cursor

def cache_columns(self, table: str) -> None:

# exit if the table columns have already been cached
if table.upper() in self._tables and self._tables[table.upper()]:
return
Expand All @@ -53,7 +57,9 @@ def cache_columns(self, table: str) -> None:
}
except Exception as e:
# TODO: https://github.com/vmware/versatile-data-kit/issues/2932
log.exception("An error occurred while trying to cache columns. Ignoring for now.", e)
log.exception(
"An error occurred while trying to cache columns. Ignoring for now.", e
)

def get_columns(self, table: str) -> Dict[str, str]:
return self._tables[table.upper()]
Expand All @@ -63,7 +69,8 @@ def update_from_col_defs(self, table: str, col_defs) -> None:

def get_col_type(self, table: str, col: str) -> str:
return self._tables.get(table.upper()).get(
col.upper() if _is_plain_identifier(col) else col)
col.upper() if _is_plain_identifier(col) else col
)

def table_exists(self, table: str) -> bool:
if table.upper() in self._tables:
Expand Down Expand Up @@ -101,7 +108,10 @@ def _get_oracle_type(value: Any) -> str:
return type_mappings.get(type(value), "VARCHAR2(255)")

def _create_table(self, table_name: str, row: Dict[str, Any]) -> None:
column_defs = [f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}" for col in row.keys()]
column_defs = [
f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}"
for col in row.keys()
]
create_table_sql = (
f"CREATE TABLE {table_name.upper()} ({', '.join(column_defs)})"
)
Expand All @@ -112,7 +122,9 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
existing_columns = self.table_cache.get_columns(table_name)

# Find unique new columns from all rows in the payload
all_columns = {_normalize_identifier(col) for row in payload for col in row.keys()}
all_columns = {
_normalize_identifier(col) for row in payload for col in row.keys()
}
new_columns = all_columns - existing_columns.keys()
column_defs = []
if new_columns:
Expand All @@ -127,7 +139,10 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
)
column_defs.append((col, column_type))

string_defs = [f"{_escape_special_chars(col_def[0])} {col_def[1]}" for col_def in column_defs]
string_defs = [
f"{_escape_special_chars(col_def[0])} {col_def[1]}"
for col_def in column_defs
]
alter_sql = (
f"ALTER TABLE {table_name.upper()} ADD ({', '.join(string_defs)})"
)
Expand Down Expand Up @@ -198,12 +213,12 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
self.cursor.executemany(queries[i], batch_data[i])

def ingest_payload(
self,
payload: List[Dict[str, Any]],
destination_table: Optional[str] = None,
target: str = None,
collection_id: Optional[str] = None,
metadata: Optional[IIngesterPlugin.IngestionMetadata] = None,
self,
payload: List[Dict[str, Any]],
destination_table: Optional[str] = None,
target: str = None,
collection_id: Optional[str] = None,
metadata: Optional[IIngesterPlugin.IngestionMetadata] = None,
) -> None:
if not payload:
return None
Expand Down
57 changes: 39 additions & 18 deletions projects/vdk-plugins/vdk-oracle/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def test_oracle_ingest_no_table(self):
def test_oracle_ingest_no_table_special_chars(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
["run", jobs_path_from_caller_directory("oracle-ingest-job-no-table-special-chars")]
[
"run",
jobs_path_from_caller_directory(
"oracle-ingest-job-no-table-special-chars"
),
]
)
cli_assert_equal(0, result)
_verify_ingest_execution_no_table_special_chars(runner)
Expand Down Expand Up @@ -123,6 +128,7 @@ def test_oracle_ingest_different_payloads_no_table(self):
)
cli_assert_equal(0, result)
_verify_ingest_execution_different_payloads_no_table(runner)

def test_oracle_ingest_different_payloads_no_table_special_chars(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
Expand All @@ -135,6 +141,7 @@ def test_oracle_ingest_different_payloads_no_table_special_chars(self):
)
cli_assert_equal(0, result)
_verify_ingest_execution_different_payloads_no_table_special_chars(runner)

def test_oracle_ingest_blob(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
Expand Down Expand Up @@ -176,12 +183,14 @@ def _verify_ingest_execution_special_chars(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)
expected = (' ID @str_data %int_data *float*data* BOOL_DATA '
'TIMESTAMP_DATA DECIMAL_DATA\n'
'---- ----------- ----------- -------------- ----------- '
'------------------- --------------\n'
' 5 string 12 1.2 1 2023-11-21 '
'08:12:53 0.1\n')
expected = (
" ID @str_data %int_data *float*data* BOOL_DATA "
"TIMESTAMP_DATA DECIMAL_DATA\n"
"---- ----------- ----------- -------------- ----------- "
"------------------- --------------\n"
" 5 string 12 1.2 1 2023-11-21 "
"08:12:53 0.1\n"
)
assert check_result.output == expected


Expand Down Expand Up @@ -223,16 +232,18 @@ def _verify_ingest_execution_no_table_special_chars(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)
expected = (' ID @str_data %int_data *float*data* BOOL_DATA '
'TIMESTAMP_DATA DECIMAL_DATA\n'
'---- ----------- ----------- -------------- ----------- '
'------------------- --------------\n'
' 0 string 12 1.2 1 '
'2023-11-21T08:12:53 1.1\n'
' 1 string 12 1.2 1 '
'2023-11-21T08:12:53 1.1\n'
' 2 string 12 1.2 1 '
'2023-11-21T08:12:53 1.1\n')
expected = (
" ID @str_data %int_data *float*data* BOOL_DATA "
"TIMESTAMP_DATA DECIMAL_DATA\n"
"---- ----------- ----------- -------------- ----------- "
"------------------- --------------\n"
" 0 string 12 1.2 1 "
"2023-11-21T08:12:53 1.1\n"
" 1 string 12 1.2 1 "
"2023-11-21T08:12:53 1.1\n"
" 2 string 12 1.2 1 "
"2023-11-21T08:12:53 1.1\n"
)
assert check_result.output == expected


Expand All @@ -242,13 +253,22 @@ def _verify_ingest_execution_different_payloads_no_table(runner):
)
expected = " COUNT(*)\n----------\n 8\n"
assert check_result.output == expected


def _verify_ingest_execution_different_payloads_no_table_special_chars(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)

actual_columns = check_result.output.split("\n")[0].split()
expected_columns = ['ID', '&timestamp_data', '^bool_data', '@int_data', '%float_data', '?str_data']
expected_columns = [
"ID",
"&timestamp_data",
"^bool_data",
"@int_data",
"%float_data",
"?str_data",
]
for expected_col in expected_columns:
assert expected_col in actual_columns

Expand All @@ -258,6 +278,7 @@ def _verify_ingest_execution_different_payloads_no_table_special_chars(runner):
)
assert check_result.output == expected_count


def _verify_ingest_execution_different_payloads(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
Expand Down

0 comments on commit 4a71461

Please sign in to comment.