Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 22, 2024
1 parent e1456c9 commit 70d6f92
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
import logging
import math
import re

from decimal import Decimal
from typing import Any, Collection, Dict, List, Optional, Set
from typing import Any
from typing import Collection

Check warning on line 9 in projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

projects/vdk-plugins/vdk-oracle/src/vdk/plugin/oracle/ingest_to_oracle.py#L9

Unused Collection imported from typing
from typing import Dict
from typing import List
from typing import Optional
from typing import Set

from vdk.api.plugin.plugin_input import PEP249Connection
from vdk.internal.builtin_plugins.connection.impl.router import ManagedConnectionRouter
Expand All @@ -23,8 +27,10 @@ class InvalidOracleIdentifierError(BaseException):

# Raise error identifier it starts with a number
if re.match("\\d", identifier):
raise InvalidOracleIdentifierError(f"{identifier} is not a valid Oracle identifier. "
f"https://docs.oracle.com/en/error-help/db/ora-00904/")
raise InvalidOracleIdentifierError(
f"{identifier} is not a valid Oracle identifier. "
f"https://docs.oracle.com/en/error-help/db/ora-00904/"
)

# https://docs.oracle.com/en/error-help/db/ora-00904/
# Alphanumeric that doesn't start with a number
Expand All @@ -36,17 +42,17 @@ class InvalidOracleIdentifierError(BaseException):
def _normalize_identifier(identifier: str) -> str:
return identifier.upper() if _is_plain_identifier(identifier) else identifier


def _escape_special_chars(value: str) -> str:
return value if _is_plain_identifier(value) else f'"{value}"'

class TableCache:

class TableCache:
def __init__(self, cursor: ManagedCursor):
self._tables: Dict[str, Dict[str, str]] = {}
self._cursor = cursor

def cache_columns(self, table: str) -> None:

# exit if the table columns have already been cached
if table.upper() in self._tables and self._tables[table.upper()]:
return
Expand Down Expand Up @@ -74,7 +80,8 @@ def update_from_col_defs(self, table: str, col_defs) -> None:

def get_col_type(self, table: str, col: str) -> str:
return self._tables.get(table.upper()).get(
col.upper() if _is_plain_identifier(col) else col)
col.upper() if _is_plain_identifier(col) else col
)

def table_exists(self, table: str) -> bool:
if table.upper() in self._tables:
Expand Down Expand Up @@ -112,7 +119,10 @@ def _get_oracle_type(value: Any) -> str:
return type_mappings.get(type(value), "VARCHAR2(255)")

def _create_table(self, table_name: str, row: Dict[str, Any]) -> None:
column_defs = [f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}" for col in row.keys()]
column_defs = [
f"{_escape_special_chars(col)} {self._get_oracle_type(row[col])}"
for col in row.keys()
]
create_table_sql = (
f"CREATE TABLE {table_name.upper()} ({', '.join(column_defs)})"
)
Expand All @@ -123,7 +133,9 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
existing_columns = self.table_cache.get_columns(table_name)

# Find unique new columns from all rows in the payload
all_columns = {_normalize_identifier(col) for row in payload for col in row.keys()}
all_columns = {
_normalize_identifier(col) for row in payload for col in row.keys()
}
new_columns = all_columns - existing_columns.keys()
column_defs = []
if new_columns:
Expand All @@ -138,7 +150,10 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
)
column_defs.append((col, column_type))

string_defs = [f"{_escape_special_chars(col_def[0])} {col_def[1]}" for col_def in column_defs]
string_defs = [
f"{_escape_special_chars(col_def[0])} {col_def[1]}"
for col_def in column_defs
]
alter_sql = (
f"ALTER TABLE {table_name.upper()} ADD ({', '.join(string_defs)})"
)
Expand Down Expand Up @@ -209,12 +224,12 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
self.cursor.executemany(queries[i], batch_data[i])

def ingest_payload(
self,
payload: List[Dict[str, Any]],
destination_table: Optional[str] = None,
target: str = None,
collection_id: Optional[str] = None,
metadata: Optional[IIngesterPlugin.IngestionMetadata] = None,
self,
payload: List[Dict[str, Any]],
destination_table: Optional[str] = None,
target: str = None,
collection_id: Optional[str] = None,
metadata: Optional[IIngesterPlugin.IngestionMetadata] = None,
) -> None:
if not payload:
return None
Expand Down
57 changes: 39 additions & 18 deletions projects/vdk-plugins/vdk-oracle/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@ def test_oracle_ingest_no_table(self):
def test_oracle_ingest_no_table_special_chars(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
["run", jobs_path_from_caller_directory("oracle-ingest-job-no-table-special-chars")]
[
"run",
jobs_path_from_caller_directory(
"oracle-ingest-job-no-table-special-chars"
),
]
)
cli_assert_equal(0, result)
_verify_ingest_execution_no_table_special_chars(runner)
Expand Down Expand Up @@ -123,6 +128,7 @@ def test_oracle_ingest_different_payloads_no_table(self):
)
cli_assert_equal(0, result)
_verify_ingest_execution_different_payloads_no_table(runner)

def test_oracle_ingest_different_payloads_no_table_special_chars(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
Expand All @@ -135,6 +141,7 @@ def test_oracle_ingest_different_payloads_no_table_special_chars(self):
)
cli_assert_equal(0, result)
_verify_ingest_execution_different_payloads_no_table_special_chars(runner)

def test_oracle_ingest_blob(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
Expand Down Expand Up @@ -176,12 +183,14 @@ def _verify_ingest_execution_special_chars(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)
expected = (' ID @str_data %int_data *float*data* BOOL_DATA '
'TIMESTAMP_DATA DECIMAL_DATA\n'
'---- ----------- ----------- -------------- ----------- '
'------------------- --------------\n'
' 5 string 12 1.2 1 2023-11-21 '
'08:12:53 0.1\n')
expected = (
" ID @str_data %int_data *float*data* BOOL_DATA "
"TIMESTAMP_DATA DECIMAL_DATA\n"
"---- ----------- ----------- -------------- ----------- "
"------------------- --------------\n"
" 5 string 12 1.2 1 2023-11-21 "
"08:12:53 0.1\n"
)
assert check_result.output == expected


Expand Down Expand Up @@ -223,16 +232,18 @@ def _verify_ingest_execution_no_table_special_chars(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)
expected = (' ID @str_data %int_data *float*data* BOOL_DATA '
'TIMESTAMP_DATA DECIMAL_DATA\n'
'---- ----------- ----------- -------------- ----------- '
'------------------- --------------\n'
' 0 string 12 1.2 1 '
'2023-11-21T08:12:53 1.1\n'
' 1 string 12 1.2 1 '
'2023-11-21T08:12:53 1.1\n'
' 2 string 12 1.2 1 '
'2023-11-21T08:12:53 1.1\n')
expected = (
" ID @str_data %int_data *float*data* BOOL_DATA "
"TIMESTAMP_DATA DECIMAL_DATA\n"
"---- ----------- ----------- -------------- ----------- "
"------------------- --------------\n"
" 0 string 12 1.2 1 "
"2023-11-21T08:12:53 1.1\n"
" 1 string 12 1.2 1 "
"2023-11-21T08:12:53 1.1\n"
" 2 string 12 1.2 1 "
"2023-11-21T08:12:53 1.1\n"
)
assert check_result.output == expected


Expand All @@ -242,13 +253,22 @@ def _verify_ingest_execution_different_payloads_no_table(runner):
)
expected = " COUNT(*)\n----------\n 8\n"
assert check_result.output == expected


def _verify_ingest_execution_different_payloads_no_table_special_chars(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)

actual_columns = check_result.output.split("\n")[0].split()
expected_columns = ['ID', '&timestamp_data', '^bool_data', '@int_data', '%float_data', '?str_data']
expected_columns = [
"ID",
"&timestamp_data",
"^bool_data",
"@int_data",
"%float_data",
"?str_data",
]
for expected_col in expected_columns:
assert expected_col in actual_columns

Expand All @@ -258,6 +278,7 @@ def _verify_ingest_execution_different_payloads_no_table_special_chars(runner):
)
assert check_result.output == expected_count


def _verify_ingest_execution_different_payloads(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
Expand Down

0 comments on commit 70d6f92

Please sign in to comment.