Skip to content

Commit

Permalink
vdk-oracle: support type inference (#2948)
Browse files Browse the repository at this point in the history
## Why?

Inferring the type when a string is passed in the payload instead of
a value of the corresponding data type improves the user experience.
Type conversion can be delegated to vdk instead of pre-processing
large payloads inside data jobs.

## What?

Add column data types to the column cache.
When a string is passed, infer the Python data type from the
Oracle column data type.
When math.nan is passed, write NULL to the Oracle table.

## How was this tested?

Ran functional tests locally
CI/CD

## What kind of change is this?

Feature/non-breaking

Signed-off-by: Dilyan Marinov <[email protected]>
  • Loading branch information
DeltaMichael authored Nov 30, 2023
1 parent 7694f22 commit 43f5bdc
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import datetime
import logging
import math
from decimal import Decimal
from typing import Any
from typing import Dict
Expand All @@ -22,7 +23,7 @@ def __init__(self, connections: ManagedConnectionRouter):
self.conn: PEP249Connection = connections.open_connection("ORACLE").connect()
self.cursor: ManagedCursor = self.conn.cursor()
self.table_cache: Set[str] = set() # Cache to store existing tables
self.column_cache: Dict[str, str] = {} # New cache for columns
self.column_cache: Dict[str, Dict[str, str]] = {} # New cache for columns

@staticmethod
def _get_oracle_type(value: Any) -> str:
Expand Down Expand Up @@ -62,10 +63,13 @@ def _create_table(self, table_name: str, row: Dict[str, Any]) -> None:
def _cache_columns(self, table_name: str) -> None:
try:
self.cursor.execute(
f"SELECT column_name FROM user_tab_columns WHERE table_name = '{table_name.upper()}'"
f"SELECT column_name, data_type, data_scale FROM user_tab_columns WHERE table_name = '{table_name.upper()}'"
)
result = self.cursor.fetchall()
self.column_cache[table_name.upper()] = {column[0] for column in result}
self.column_cache[table_name.upper()] = {
col: ("DECIMAL" if data_type == "NUMBER" and data_scale else data_type)
for (col, data_type, data_scale) in result
}
except Exception as e:
# TODO: https://github.com/vmware/versatile-data-kit/issues/2932
log.error(
Expand All @@ -81,10 +85,9 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:

# Find unique new columns from all rows in the payload
all_columns = {col.upper() for row in payload for col in row.keys()}
new_columns = all_columns - existing_columns

new_columns = all_columns - existing_columns.keys()
column_defs = []
if new_columns:
column_defs = []
for col in new_columns:
sample_value = next(
(row[col] for row in payload if row.get(col) is not None), None
Expand All @@ -94,19 +97,41 @@ def _add_columns(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
if sample_value is not None
else "VARCHAR2(255)"
)
column_defs.append(f"{col} {column_type}")
column_defs.append((col, column_type))

string_defs = [f"{col_def[0]} {col_def[1]}" for col_def in column_defs]
alter_sql = (
f"ALTER TABLE {table_name.upper()} ADD ({', '.join(column_defs)})"
f"ALTER TABLE {table_name.upper()} ADD ({', '.join(string_defs)})"
)
self.cursor.execute(alter_sql)
self.column_cache[table_name.upper()].update(new_columns)
self.column_cache[table_name.upper()].update(column_defs)

# TODO: https://github.com/vmware/versatile-data-kit/issues/2929
# TODO: https://github.com/vmware/versatile-data-kit/issues/2930
def _cast_to_correct_type(self, value: Any) -> Any:
if type(value) is Decimal:
def _cast_to_correct_type(self, table: str, column: str, value: Any) -> Any:
def cast_string_to_type(db_type: str, payload_value: str) -> Any:
if db_type == "FLOAT" or db_type == "DECIMAL":
return float(payload_value)
if db_type == "NUMBER":
payload_value = payload_value.capitalize()
return (
bool(payload_value)
if payload_value in ["True", "False"]
else int(payload_value)
)
if "TIMESTAMP" in db_type:
return datetime.datetime.strptime(payload_value, "%Y-%m-%dT%H:%M:%S")
if db_type == "BLOB":
return payload_value.encode("utf-8")
return payload_value

if (isinstance(value, float) or isinstance(value, int)) and math.isnan(value):
return None
if isinstance(value, Decimal):
return float(value)
if isinstance(value, str):
col_type = self.column_cache.get(table.upper()).get(column.upper())
return cast_string_to_type(col_type, value)
return value

# TODO: Look into potential optimizations
Expand All @@ -132,7 +157,10 @@ def _insert_data(self, table_name: str, payload: List[Dict[str, Any]]) -> None:
queries.append(insert_sql)
temp_data = []
for row in batch:
temp = [self._cast_to_correct_type(row[col]) for col in columns]
temp = [
self._cast_to_correct_type(table_name, col, row[col])
for col in columns
]
temp_data.append(temp)
batch_data.append(temp_data)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
begin
-- Drop test_table; a missing table is not treated as an error.
execute immediate 'drop table test_table';
-- sqlcode -942 is ORA-00942 (table or view does not exist); anything else is re-raised.
exception when others then if sqlcode <> -942 then raise; end if;
end;
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
create table test_table (
-- Column names match the keys of the payload sent by the type-inference job step.
id number,
str_data varchar2(255),
int_data number,
-- The job step sends math.nan for this column.
nan_int_data number,
float_data float,
-- The job step sends the string "False" for this column.
bool_data number(1),
timestamp_data timestamp,
-- Declared with a non-zero scale; presumably reported as NUMBER with data_scale 8
-- in user_tab_columns -- TODO confirm against the column cache logic.
decimal_data decimal(14,8),
primary key(id))
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021-2023 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
import datetime
import math
from decimal import Decimal


def run(job_input):
    """
    Send a single row to ``test_table`` for ingestion.

    Every value except ``nan_int_data`` is passed as a string;
    ``nan_int_data`` is passed as ``math.nan``.

    :param job_input: object exposing ``send_object_for_ingestion``.
    """
    row = dict(
        id="5",
        str_data="string",
        int_data="12",
        nan_int_data=math.nan,
        float_data="1.2",
        bool_data="False",
        timestamp_data="2023-11-21T08:12:53",
        decimal_data="0.1",
    )
    job_input.send_object_for_ingestion(destination_table="test_table", payload=row)
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,6 @@


def run(job_input):
# TODO: https://github.com/vmware/versatile-data-kit/issues/2929
# setup different data types (all passed initially as strings) are cast correctly
# payload = {
# "id": "",
# "str_data": "string",
# "int_data": "12",
# "float_data": "1.2",
# "bool_data": "True",
# # TODO: add timestamp
# # TODO: add decimal
# }

# for i in range(5):
# local_payload = payload.copy()
# local_payload["id"] = i
# job_input.send_object_for_ingestion(
# payload=local_payload, destination_table="test_table"
# )

payload_with_types = {
"id": 5,
"str_data": "string",
Expand All @@ -37,12 +18,3 @@ def run(job_input):
job_input.send_object_for_ingestion(
payload=payload_with_types, destination_table="test_table"
)

# TODO: https://github.com/vmware/versatile-data-kit/issues/2930
# this setup:
# a) partial payload (only few columns are included)
# b) includes float data which is NaN
# payload2 = {"id": 6, "float_data": math.nan, "int_data": math.nan}
# job_input.send_object_for_ingestion(
# payload=payload2, destination_table="test_table"
# )
23 changes: 23 additions & 0 deletions projects/vdk-plugins/vdk-oracle/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def test_oracle_ingest_existing_table(self):
cli_assert_equal(0, result)
_verify_ingest_execution(runner)

def test_oracle_ingest_type_inference(self):
    """Run the type-inference ingest job and verify the ingested row."""
    runner = CliEntryBasedTestRunner(oracle_plugin)
    job_dir = jobs_path_from_caller_directory("oracle-ingest-job-type-inference")
    cli_args = ["run", job_dir]
    result: Result = runner.invoke(cli_args)
    # The job must exit with code 0 before the table contents are checked.
    cli_assert_equal(0, result)
    _verify_ingest_execution_type_inference(runner)

def test_oracle_ingest_no_table(self):
runner = CliEntryBasedTestRunner(oracle_plugin)
result: Result = runner.invoke(
Expand Down Expand Up @@ -117,6 +125,21 @@ def _verify_ingest_execution(runner):
assert check_result.output == expected


def _verify_ingest_execution_type_inference(runner):
    """
    Query ``test_table`` through the CLI and compare the printed table with
    the expected output of the type-inference job.

    :param runner: test runner exposing ``invoke``.
    """
    query_args = ["oracle-query", "--query", "SELECT * FROM test_table"]
    query_result = runner.invoke(query_args)
    # NOTE(review): the job sends bool_data as the string "False", yet the
    # expected row below shows 1 for BOOL_DATA -- confirm this is intended.
    header = (
        " ID STR_DATA INT_DATA NAN_INT_DATA FLOAT_DATA BOOL_DATA "
        "TIMESTAMP_DATA DECIMAL_DATA\n"
    )
    separator = (
        "---- ---------- ---------- -------------- ------------ ----------- "
        "------------------- --------------\n"
    )
    data_row = (
        " 5 string 12 1.2 1 "
        "2023-11-21 08:12:53 0.1\n"
    )
    expected = header + separator + data_row
    assert query_result.output == expected


def _verify_ingest_execution_no_table(runner):
check_result = runner.invoke(
["oracle-query", "--query", "SELECT * FROM test_table"]
Expand Down

0 comments on commit 43f5bdc

Please sign in to comment.