Skip to content

Commit

Permalink
vdk-oracle: support type inference
Browse files Browse the repository at this point in the history
Why?

Inferring the type when a value is passed in the
payload as a string instead of as the corresponding
data type improves the user experience.

Type conversion can be delegated to vdk instead of
pre-processing large payloads inside data jobs.

What?

Add column data types to column cache
In case a string is passed, infer the python data type based on the Oracle column data type
In case math.nan is passed, write NULL to the Oracle table

How was this tested?

Ran functional tests locally
CI/CD

What kind of change is this?

Feature/non-breaking

Signed-off-by: Dilyan Marinov <[email protected]>
  • Loading branch information
Dilyan Marinov committed Nov 30, 2023
1 parent 6951f6f commit 566bc6e
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import datetime
import logging
import math
from decimal import Decimal
from typing import Any
from typing import Dict
Expand All @@ -22,7 +23,7 @@ def __init__(self, connections: ManagedConnectionRouter):
self.conn: PEP249Connection = connections.open_connection("ORACLE").connect()
self.cursor: ManagedCursor = self.conn.cursor()
self.table_cache: Set[str] = set() # Cache to store existing tables
self.column_cache: Dict[str, str] = {} # New cache for columns
self.column_cache: Dict[str, Dict[str, str]] = {} # New cache for columns

@staticmethod
def _get_oracle_type(value: Any) -> str:
Expand Down Expand Up @@ -62,10 +63,13 @@ def _create_table(self, table_name: str, row: Dict[str, Any]) -> None:
def _cache_columns(self, table_name: str) -> None:
try:
self.cursor.execute(
f"SELECT column_name FROM user_tab_columns WHERE table_name = '{table_name.upper()}'"
f"SELECT column_name, data_type, data_scale FROM user_tab_columns WHERE table_name = '{table_name.upper()}'"
)
result = self.cursor.fetchall()
self.column_cache[table_name.upper()] = {column[0] for column in result}
self.column_cache[table_name.upper()] = {
col: ("DECIMAL" if data_type == "NUMBER" and data_scale else data_type)
for (col, data_type, data_scale) in result
}
except Exception as e:
# TODO: https://github.com/vmware/versatile-data-kit/issues/2932
log.error(
Expand All @@ -81,10 +85,9 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:

# Find unique new columns from all rows in the payload
all_columns = {col.upper() for row in payload for col in row.keys()}
new_columns = all_columns - existing_columns

new_columns = all_columns - existing_columns.keys()
column_defs = []
if new_columns:
column_defs = []
for col in new_columns:
sample_value = next(
(row[col] for row in payload if row.get(col) is not None), None
Expand All @@ -94,19 +97,41 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
if sample_value is not None
else "VARCHAR2(255)"
)
column_defs.append(f"{col} {column_type}")
column_defs.append((col, column_type))

string_defs = [f"{col_def[0]} {col_def[1]}" for col_def in column_defs]
alter_sql = (
f"ALTER TABLE {table_name.upper()} ADD ({', '.join(column_defs)})"
f"ALTER TABLE {table_name.upper()} ADD ({', '.join(string_defs)})"
)
self.cursor.execute(alter_sql)
self.column_cache[table_name.upper()].update(new_columns)
self.column_cache[table_name.upper()].update(column_defs)

# TODO: https://github.com/vmware/versatile-data-kit/issues/2929
# TODO: https://github.com/vmware/versatile-data-kit/issues/2930
def _cast_to_correct_type(self, value: Any) -> Any:
if type(value) is Decimal:
def _cast_to_correct_type(self, table: str, column: str, value: Any) -> Any:
    """
    Convert a single payload value into a value suitable for its target column.

    Conversion rules, applied in order:

    * float NaN and Decimal NaN become ``None`` so that NULL is written
      instead of NaN;
    * any other ``Decimal`` becomes ``float``;
    * a ``str`` is converted according to the column type stored in
      ``self.column_cache`` for ``table``/``column`` (looked up upper-cased);
      if no type is cached for that column the string is returned unchanged;
    * every other value is returned unchanged.

    :param table: name of the destination table (case-insensitive)
    :param column: name of the destination column (case-insensitive)
    :param value: the payload value to convert
    :return: the converted value, or ``None`` for NaN
    :raises ValueError: if a string cannot be parsed as the cached column type
    """

    def cast_string_to_type(db_type: str, payload_value: str) -> Any:
        if db_type in ("FLOAT", "DECIMAL"):
            return float(payload_value)
        if db_type == "NUMBER":
            normalized = payload_value.capitalize()
            if normalized in ("True", "False"):
                # bool() of any non-empty string is True, so "False" must be
                # detected by comparison rather than by truthiness.
                return normalized == "True"
            return int(normalized)
        # Substring match so that types carrying a precision suffix
        # (anything containing "TIMESTAMP") are handled as well.
        if "TIMESTAMP" in db_type:
            return datetime.datetime.strptime(payload_value, "%Y-%m-%dT%H:%M:%S")
        if db_type == "BLOB":
            return payload_value.encode("utf-8")
        return payload_value

    # Only floats can be NaN; skipping ints also avoids the OverflowError that
    # math.isnan raises for integers too large to convert to float.
    if isinstance(value, float) and math.isnan(value):
        return None
    if isinstance(value, Decimal):
        # Decimal NaN is treated like float NaN instead of leaking through
        # as float("nan") (or raising for a signaling NaN).
        return None if value.is_nan() else float(value)
    if isinstance(value, str):
        # The table or the column may be missing from the cache; in that case
        # there is nothing to infer from, so the string is passed through.
        col_type = self.column_cache.get(table.upper(), {}).get(column.upper())
        if col_type is None:
            return value
        return cast_string_to_type(col_type, value)
    return value

# TODO: Look into potential optimizations
Expand All @@ -132,7 +157,10 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
queries.append(insert_sql)
temp_data = []
for row in batch:
temp = [self._cast_to_correct_type(row[col]) for col in columns]
temp = [
self._cast_to_correct_type(table_name, col, row[col])
for col in columns
]
temp_data.append(temp)
batch_data.append(temp_data)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
begin
-- Drop test_table so that the following steps start from a clean state.
execute immediate 'drop table test_table';
-- Error code -942 (ORA-00942, table or view does not exist) is ignored so the
-- script also succeeds when the table is absent; any other error is re-raised.
exception when others then if sqlcode <> -942 then raise; end if;
end;
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
create table test_table (
-- One column per data type; id is the primary key.
id number,
str_data varchar2(255),
int_data number,
nan_int_data number,
float_data float,
-- NOTE(review): presumably holds booleans as 0/1 -- confirm against the ingestion code.
bool_data number(1),
timestamp_data timestamp,
decimal_data decimal(14,8),
primary key(id))
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import datetime
import math
from decimal import Decimal


# Data job step: sends one record to "test_table" for ingestion.
# job_input is the object supplied to the step; only its
# send_object_for_ingestion method is used here.
def run(job_input):
# Every value except nan_int_data is passed as a string, including the ones
# that look like numbers, a boolean and a timestamp.
# NOTE(review): presumably the ingestion plugin is expected to convert these
# strings to the destination column types -- confirm against the plugin.
payload = {
"id": "5",
"str_data": "string",
"int_data": "12",
# The only non-string value: a float NaN.
"nan_int_data": math.nan,
"float_data": "1.2",
"bool_data": "False",
"timestamp_data": "2023-11-21T08:12:53",
"decimal_data": "0.1",
}

job_input.send_object_for_ingestion(payload=payload, destination_table="test_table")
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,6 @@


def run(job_input):
# TODO: https://github.com/vmware/versatile-data-kit/issues/2929
# setup different data types (all passed initially as strings) are cast correctly
# payload = {
# "id": "",
# "str_data": "string",
# "int_data": "12",
# "float_data": "1.2",
# "bool_data": "True",
# # TODO: add timestamp
# # TODO: add decimal
# }

# for i in range(5):
# local_payload = payload.copy()
# local_payload["id"] = i
# job_input.send_object_for_ingestion(
# payload=local_payload, destination_table="test_table"
# )

payload_with_types = {
"id": 5,
"str_data": "string",
Expand All @@ -37,12 +18,3 @@ def run(job_input):
job_input.send_object_for_ingestion(
payload=payload_with_types, destination_table="test_table"
)

# TODO: https://github.com/vmware/versatile-data-kit/issues/2930
# this setup:
# a) partial payload (only few columns are included)
# b) includes float data which is NaN
# payload2 = {"id": 6, "float_data": math.nan, "int_data": math.nan}
# job_input.send_object_for_ingestion(
# payload=payload2, destination_table="test_table"
# )
23 changes: 23 additions & 0 deletions projects/vdk-plugins/vdk-oracle/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def test_oracle_ingest_existing_table(self):
cli_assert_equal(0, result)
_verify_ingest_execution(runner)

# Runs the "oracle-ingest-job-type-inference" data job through the CLI test
# runner, expects exit code 0, then checks the contents of the table.
def test_oracle_ingest_type_inference(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
["run", jobs_path_from_caller_directory("oracle-ingest-job-type-inference")]
)
cli_assert_equal(0, result)
# Reuses the same runner to query the table and compare the printed output.
_verify_ingest_execution_type_inference(runner)

def test_oracle_ingest_no_table(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
Expand Down Expand Up @@ -117,6 +125,21 @@ def _verify_ingest_execution(runner):
assert check_result.output == expected


# Queries all rows of test_table via the "oracle-query" command and asserts
# that the printed table matches the expected text exactly.
def _verify_ingest_execution_type_inference(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
)
# Expected output: a header line, a separator line and a single data row.
# NOTE(review): the row appears to show BOOL_DATA as 1, while the
# type-inference job visible in this change sends bool_data as "False" --
# confirm whether 1 is really the intended value.
expected = (
" ID STR_DATA INT_DATA NAN_INT_DATA FLOAT_DATA BOOL_DATA "
"TIMESTAMP_DATA DECIMAL_DATA\n"
"---- ---------- ---------- -------------- ------------ ----------- "
"------------------- --------------\n"
" 5 string 12 1.2 1 "
"2023-11-21 08:12:53 0.1\n"
)
assert check_result.output == expected


def _verify_ingest_execution_no_table(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
Expand Down

0 comments on commit 566bc6e

Please sign in to comment.