Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update AthenaColumn to parse array types #672

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions dbt/adapters/athena/column.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from dataclasses import dataclass
from typing import ClassVar, Dict

Expand Down Expand Up @@ -28,6 +29,9 @@ def is_binary(self) -> bool:
def is_timestamp(self) -> bool:
return self.dtype.lower() in {"timestamp"}

def is_array(self) -> bool:
nicor88 marked this conversation as resolved.
Show resolved Hide resolved
return self.dtype.lower().startswith("array") # type: ignore
nicor88 marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def string_type(cls, size: int) -> str:
return f"varchar({size})" if size > 0 else "varchar"
Expand All @@ -39,6 +43,23 @@ def binary_type(cls) -> str:
def timestamp_type(self) -> str:
return "timestamp(6)" if self.is_iceberg() else "timestamp"

@classmethod
def array_type(cls, inner_type: str) -> str:
return f"array({inner_type})"

def array_inner_type(self) -> str:
if not self.is_array():
raise DbtRuntimeError("Called array_inner_type() on non-array field!")
nicor88 marked this conversation as resolved.
Show resolved Hide resolved
# Match either `array<inner_type>` or `array(inner_type)`. Don't bother
# parsing nested arrays here, since we will expect the caller to be
# responsible for formatting the inner type, including nested arrays
pattern = r"^array[<(](.*)[>)]$"
match = re.match(pattern, self.dtype)
if match:
return match.group(1)
# If for some reason there's no match, fall back to the original string
return self.dtype # type: ignore

def string_size(self) -> int:
if not self.is_string():
raise DbtRuntimeError("Called string_size() on non-string field!")
Expand All @@ -59,4 +80,18 @@ def data_type(self) -> str:
if self.is_timestamp():
return self.timestamp_type()

if self.is_array():
# Resolve the inner type of the array, using an AthenaColumn
# instance to properly convert the inner type. Note that this will
# cause recursion in cases of nested arrays
inner_type = self.array_inner_type()
inner_type_col = AthenaColumn(
column=self.column,
dtype=inner_type,
char_size=self.char_size,
numeric_precision=self.numeric_precision,
numeric_scale=self.numeric_scale,
)
nicor88 marked this conversation as resolved.
Show resolved Hide resolved
return self.array_type(inner_type_col.data_type)

return self.dtype # type: ignore
108 changes: 108 additions & 0 deletions tests/unit/test_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import pytest
from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.athena.column import AthenaColumn
from dbt.adapters.athena.relation import TableType


class TestAthenaColumn:
nicor88 marked this conversation as resolved.
Show resolved Hide resolved
def setup_column(self, **kwargs):
base_kwargs = {"column": "foo", "dtype": "varchar"}
return AthenaColumn(**{**base_kwargs, **kwargs})

@pytest.mark.parametrize(
"table_type,expected",
[
pytest.param(TableType.TABLE, False),
pytest.param(TableType.ICEBERG, True),
],
)
def test_is_iceberg(self, table_type, expected):
column = self.setup_column(table_type=table_type)
assert column.is_iceberg() is expected

@pytest.mark.parametrize(
"dtype,expected_type_func",
[
pytest.param("varchar", "is_string"),
pytest.param("string", "is_string"),
pytest.param("binary", "is_binary"),
pytest.param("varbinary", "is_binary"),
pytest.param("timestamp", "is_timestamp"),
pytest.param("array<string>", "is_array"),
pytest.param("array(string)", "is_array"),
],
)
def test_is_type(self, dtype, expected_type_func):
column = self.setup_column(dtype=dtype)
for type_func in ["is_string", "is_binary", "is_timestamp", "is_array"]:
if type_func == expected_type_func:
assert getattr(column, type_func)()
else:
assert not getattr(column, type_func)()

@pytest.mark.parametrize("size,expected", [pytest.param(1, "varchar(1)"), pytest.param(0, "varchar")])
def test_string_type(self, size, expected):
assert AthenaColumn.string_type(size) == expected

@pytest.mark.parametrize(
"table_type,expected",
[pytest.param(TableType.TABLE, "timestamp"), pytest.param(TableType.ICEBERG, "timestamp(6)")],
)
def test_timestamp_type(self, table_type, expected):
column = self.setup_column(table_type=table_type)
assert column.timestamp_type() == expected

def test_array_type(self):
assert AthenaColumn.array_type("varchar") == "array(varchar)"

@pytest.mark.parametrize(
"dtype,expected",
[
pytest.param("array<string>", "string"),
pytest.param("array<varchar(10)>", "varchar(10)"),
pytest.param("array<array<int>>", "array<int>"),
pytest.param("array", "array"),
],
)
def test_array_inner_type(self, dtype, expected):
column = self.setup_column(dtype=dtype)
assert column.array_inner_type() == expected

def test_array_inner_type_raises_for_non_array_type(self):
column = self.setup_column(dtype="varchar")
with pytest.raises(DbtRuntimeError, match=r"Called array_inner_type\(\) on non-array field!"):
column.array_inner_type()

@pytest.mark.parametrize(
"char_size,expected",
[
pytest.param(10, 10),
pytest.param(None, 0),
],
)
def test_string_size(self, char_size, expected):
column = self.setup_column(dtype="varchar", char_size=char_size)
assert column.string_size() == expected

def test_string_size_raises_for_non_string_type(self):
column = self.setup_column(dtype="int")
with pytest.raises(DbtRuntimeError, match=r"Called string_size\(\) on non-string field!"):
column.string_size()

@pytest.mark.parametrize(
"dtype,expected",
[
pytest.param("string", "varchar(10)"),
pytest.param("decimal", "decimal(1,2)"),
pytest.param("binary", "varbinary"),
pytest.param("timestamp", "timestamp(6)"),
pytest.param("array<string>", "array(varchar(10))"),
pytest.param("array<array<string>>", "array(array(varchar(10)))"),
],
)
def test_data_type(self, dtype, expected):
column = self.setup_column(
table_type=TableType.ICEBERG, dtype=dtype, char_size=10, numeric_precision=1, numeric_scale=2
)
assert column.data_type == expected
Loading