diff --git a/dbt/adapters/athena/column.py b/dbt/adapters/athena/column.py
index b198cbfa..a220bf3b 100644
--- a/dbt/adapters/athena/column.py
+++ b/dbt/adapters/athena/column.py
@@ -1,3 +1,4 @@
+import re
 from dataclasses import dataclass
 from typing import ClassVar, Dict
 
@@ -28,6 +29,9 @@ def is_binary(self) -> bool:
     def is_timestamp(self) -> bool:
         return self.dtype.lower() in {"timestamp"}
 
+    def is_array(self) -> bool:
+        return self.dtype.lower().startswith("array")  # type: ignore
+
     @classmethod
     def string_type(cls, size: int) -> str:
         return f"varchar({size})" if size > 0 else "varchar"
@@ -39,6 +43,23 @@ def binary_type(cls) -> str:
     def timestamp_type(self) -> str:
         return "timestamp(6)" if self.is_iceberg() else "timestamp"
 
+    @classmethod
+    def array_type(cls, inner_type: str) -> str:
+        return f"array({inner_type})"
+
+    def array_inner_type(self) -> str:
+        if not self.is_array():
+            raise DbtRuntimeError("Called array_inner_type() on non-array field!")
+        # Match either `array<inner_type>` or `array(inner_type)` (any case, like
+        # is_array()). Nested arrays are not parsed here: the caller is
+        # responsible for formatting the inner type, including nested arrays
+        pattern = r"^array[<(](.*)[>)]$"
+        match = re.match(pattern, self.dtype, re.IGNORECASE)
+        if match:
+            return match.group(1)
+        # If for some reason there's no match, fall back to the original string
+        return self.dtype  # type: ignore
+
     def string_size(self) -> int:
         if not self.is_string():
             raise DbtRuntimeError("Called string_size() on non-string field!")
@@ -59,4 +80,23 @@ def data_type(self) -> str:
         if self.is_timestamp():
             return self.timestamp_type()
 
+        if self.is_array():
+            # Resolve the inner type of the array, using an AthenaColumn
+            # instance to properly convert the inner type. Note that this will
+            # cause recursion in cases of nested arrays
+            inner_type = self.array_inner_type()
+            if inner_type == self.dtype:
+                # No inner type could be parsed (e.g. a bare `array`), so
+                # recursing would never terminate; keep the dtype unchanged
+                return self.dtype  # type: ignore
+            inner_type_col = AthenaColumn(
+                column=self.column,
+                dtype=inner_type,
+                table_type=self.table_type,
+                char_size=self.char_size,
+                numeric_precision=self.numeric_precision,
+                numeric_scale=self.numeric_scale,
+            )
+            return self.array_type(inner_type_col.data_type)
+
         return self.dtype  # type: ignore
diff --git a/tests/unit/test_column.py b/tests/unit/test_column.py
new file mode 100644
index 00000000..a2b15bf9
--- /dev/null
+++ b/tests/unit/test_column.py
@@ -0,0 +1,111 @@
+import pytest
+from dbt_common.exceptions import DbtRuntimeError
+
+from dbt.adapters.athena.column import AthenaColumn
+from dbt.adapters.athena.relation import TableType
+
+
+class TestAthenaColumn:
+    def setup_column(self, **kwargs):
+        base_kwargs = {"column": "foo", "dtype": "varchar"}
+        return AthenaColumn(**{**base_kwargs, **kwargs})
+
+    @pytest.mark.parametrize(
+        "table_type,expected",
+        [
+            pytest.param(TableType.TABLE, False),
+            pytest.param(TableType.ICEBERG, True),
+        ],
+    )
+    def test_is_iceberg(self, table_type, expected):
+        column = self.setup_column(table_type=table_type)
+        assert column.is_iceberg() is expected
+
+    @pytest.mark.parametrize(
+        "dtype,expected_type_func",
+        [
+            pytest.param("varchar", "is_string"),
+            pytest.param("string", "is_string"),
+            pytest.param("binary", "is_binary"),
+            pytest.param("varbinary", "is_binary"),
+            pytest.param("timestamp", "is_timestamp"),
+            pytest.param("array<string>", "is_array"),
+            pytest.param("array(string)", "is_array"),
+        ],
+    )
+    def test_is_type(self, dtype, expected_type_func):
+        column = self.setup_column(dtype=dtype)
+        for type_func in ["is_string", "is_binary", "is_timestamp", "is_array"]:
+            if type_func == expected_type_func:
+                assert getattr(column, type_func)()
+            else:
+                assert not getattr(column, type_func)()
+
+    @pytest.mark.parametrize("size,expected", [pytest.param(1, "varchar(1)"), pytest.param(0, "varchar")])
+    def test_string_type(self, size, expected):
+        assert AthenaColumn.string_type(size) == expected
+
+    @pytest.mark.parametrize(
+        "table_type,expected",
+        [pytest.param(TableType.TABLE, "timestamp"), pytest.param(TableType.ICEBERG, "timestamp(6)")],
+    )
+    def test_timestamp_type(self, table_type, expected):
+        column = self.setup_column(table_type=table_type)
+        assert column.timestamp_type() == expected
+
+    def test_array_type(self):
+        assert AthenaColumn.array_type("varchar") == "array(varchar)"
+
+    @pytest.mark.parametrize(
+        "dtype,expected",
+        [
+            pytest.param("array<string>", "string"),
+            pytest.param("array<varchar(10)>", "varchar(10)"),
+            pytest.param("array<array<int>>", "array<int>"),
+            pytest.param("ARRAY<STRING>", "STRING"),
+            pytest.param("array", "array"),
+        ],
+    )
+    def test_array_inner_type(self, dtype, expected):
+        column = self.setup_column(dtype=dtype)
+        assert column.array_inner_type() == expected
+
+    def test_array_inner_type_raises_for_non_array_type(self):
+        column = self.setup_column(dtype="varchar")
+        with pytest.raises(DbtRuntimeError, match=r"Called array_inner_type\(\) on non-array field!"):
+            column.array_inner_type()
+
+    @pytest.mark.parametrize(
+        "char_size,expected",
+        [
+            pytest.param(10, 10),
+            pytest.param(None, 0),
+        ],
+    )
+    def test_string_size(self, char_size, expected):
+        column = self.setup_column(dtype="varchar", char_size=char_size)
+        assert column.string_size() == expected
+
+    def test_string_size_raises_for_non_string_type(self):
+        column = self.setup_column(dtype="int")
+        with pytest.raises(DbtRuntimeError, match=r"Called string_size\(\) on non-string field!"):
+            column.string_size()
+
+    @pytest.mark.parametrize(
+        "dtype,expected",
+        [
+            pytest.param("string", "varchar(10)"),
+            pytest.param("decimal", "decimal(1,2)"),
+            pytest.param("binary", "varbinary"),
+            pytest.param("timestamp", "timestamp(6)"),
+            pytest.param("array<string>", "array(varchar(10))"),
+            pytest.param("array<array<string>>", "array(array(varchar(10)))"),
+            pytest.param("array<timestamp>", "array(timestamp(6))"),
+            pytest.param("array", "array"),
+        ],
+    )
+    def test_data_type(self, dtype, expected):
+        column = self.setup_column(
+            table_type=TableType.ICEBERG, dtype=dtype, char_size=10, numeric_precision=1, numeric_scale=2
+        )
+        assert column.data_type == expected