diff --git a/.all-contributorsrc b/.all-contributorsrc
index 9d8013ee..fe7897c4 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -132,6 +132,26 @@
        "bug"
      ]
    },
+    {
+      "login": "sanromeo",
+      "name": "Roman Korsun",
+      "avatar_url": "https://avatars.githubusercontent.com/u/44975602?v=4",
+      "profile": "https://github.com/sanromeo",
+      "contributions": [
+        "code",
+        "bug"
+      ]
+    },
+    {
+      "login": "Danya-Fpnk",
+      "name": "DanyaF",
+      "avatar_url": "https://avatars.githubusercontent.com/u/122433975?v=4",
+      "profile": "https://github.com/Danya-Fpnk",
+      "contributions": [
+        "code",
+        "bug"
+      ]
+    },
    {
      "login": "octiva",
      "name": "Spencer",
diff --git a/README.md b/README.md
index 972bb4df..98d80406 100644
--- a/README.md
+++ b/README.md
@@ -690,6 +690,8 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
Tuple[str, str]:
+        """Formats a value based on its column type for inclusion in a SQL query."""
+        comp_func = "="  # Default comparison function
+        if value is None:
+            return "null", " is "
+        elif column_type == "integer":
+            return str(value), comp_func
+        elif column_type == "string":
+            # Properly escape single quotes in the string value
+            escaped_value = str(value).replace("'", "''")
+            return f"'{escaped_value}'", comp_func
+        elif column_type == "date":
+            return f"DATE'{value}'", comp_func
+        elif column_type == "timestamp":
+            return f"TIMESTAMP'{value}'", comp_func
+        else:
+            # Raise an error for unsupported column types
+            raise ValueError(f"Unsupported column type: {column_type}")
diff --git a/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql
index 3f64cc59..c1bf6505 100644
--- a/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql
+++ b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql
@@ -1,9 +1,11 @@
{% macro get_partition_batches(sql, as_subquery=True) -%}
+    {# Retrieve partition configuration and set default partition limit #}
    {%- set partitioned_by = config.get('partitioned_by') -%}
    {%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%}
    {%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%}
    {% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %}

+    {# Retrieve distinct partitions from the given SQL #}
    {% call statement('get_partitions', fetch_result=True) %}
        {%- if as_subquery -%}
            select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }};
@@ -12,48 +14,73 @@
        {%- endif -%}
    {% endcall %}

+    {# Initialize variables to store partition info #}
    {%- set table = load_result('get_partitions').table -%}
    {%- set rows = table.rows -%}
-    {%- set partitions = {} -%}
-    {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %}
-    {%- set partitions_batches = [] -%}
+    {%- set ns = namespace(partitions = [], bucket_conditions = {}, bucket_numbers = [], bucket_column = None, is_bucketed = false) -%}

+    {# Process each partition row #}
    {%- for row in rows -%}
        {%- set single_partition = [] -%}
-        {%- for col in row -%}
-
-
-            {%- set column_type = adapter.convert_type(table, loop.index0) -%}
-            {%- set comp_func = '=' -%}
-            {%- if col is none -%}
-                {%- set value = 'null' -%}
-                {%- set comp_func = ' is ' -%}
-            {%- elif column_type == 'integer' or column_type is none -%}
-                {%- set value = col | string -%}
-            {%- elif column_type == 'string' -%}
-                {%- set value = "'" + col + "'" -%}
-            {%- elif column_type == 'date' -%}
-                {%- set value = "DATE'" + col | string + "'" -%}
-            {%- elif column_type == 'timestamp' -%}
-                {%- set value = "TIMESTAMP'" + col | string + "'" -%}
-            {%- else -%}
-                {%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
+        {# Use Namespace to hold the counter for loop index #}
+        {%- set counter = namespace(value=0) -%}
+        {# Loop through each column in the row #}
+        {%- for col, partition_key in zip(row, partitioned_by) -%}
+            {# Process bucketed columns using the new macro with the index #}
+            {%- do process_bucket_column(col, partition_key, table, ns, counter.value) -%}
+
+            {# Logic for non-bucketed columns #}
+            {%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%}
+            {%- if not bucket_match -%}
+                {# For non-bucketed columns, format partition key and value #}
+                {%- set column_type = adapter.convert_type(table, counter.value) -%}
+                {%- set value, comp_func = adapter.format_value_for_partition(col, column_type) -%}
+                {%- set partition_key_formatted = adapter.format_one_partition_key(partitioned_by[counter.value]) -%}
+                {%- do single_partition.append(partition_key_formatted + comp_func + value) -%}
            {%- endif -%}
-            {%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
-            {%- do single_partition.append(partition_key + comp_func + value) -%}
+            {# Increment the counter #}
+            {%- set counter.value = counter.value + 1 -%}
        {%- endfor -%}

+        {# Concatenate conditions for a single partition #}
        {%- set single_partition_expression = single_partition | join(' and ') -%}
+        {%- if single_partition_expression not in ns.partitions %}
+            {%- do ns.partitions.append(single_partition_expression) -%}
+        {%- endif -%}
+    {%- endfor -%}

-        {%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%}
-        {% if not batch_number in partitions %}
-            {% do partitions.update({batch_number: []}) %}
-        {% endif %}
+    {# Calculate total batches based on bucketing and partitioning #}
+    {%- if ns.is_bucketed -%}
+        {%- set total_batches = ns.partitions | length * ns.bucket_numbers | length -%}
+    {%- else -%}
+        {%- set total_batches = ns.partitions | length -%}
+    {%- endif -%}
+    {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ total_batches) %}

-        {%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%}
-        {%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%}
-            {%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%}
+    {# Determine the number of batches per partition limit #}
+    {%- set batches_per_partition_limit = (total_batches // athena_partitions_limit) + (total_batches % athena_partitions_limit > 0) -%}
+
+    {# Create conditions for each batch #}
+    {%- set partitions_batches = [] -%}
+    {%- for i in range(batches_per_partition_limit) -%}
+        {%- set batch_conditions = [] -%}
+        {%- if ns.is_bucketed -%}
+            {# Combine partition and bucket conditions for each batch #}
+            {%- for partition_expression in ns.partitions -%}
+                {%- for bucket_num in ns.bucket_numbers -%}
+                    {%- set bucket_condition = ns.bucket_column + " IN (" + ns.bucket_conditions[bucket_num] | join(", ") + ")" -%}
+                    {%- set combined_condition = "(" + partition_expression + ' and ' + bucket_condition + ")" -%}
+                    {%- do batch_conditions.append(combined_condition) -%}
+                {%- endfor -%}
+            {%- endfor -%}
+        {%- else -%}
+            {# Extend batch conditions with partitions for non-bucketed columns #}
+            {%- do batch_conditions.extend(ns.partitions) -%}
        {%- endif -%}
+        {# Calculate batch start and end index and append batch conditions #}
+        {%- set start_index = i * athena_partitions_limit -%}
+        {%- set end_index = start_index + athena_partitions_limit -%}
+        {%- do partitions_batches.append(batch_conditions[start_index:end_index] | join(' or ')) -%}
    {%- endfor -%}

    {{ return(partitions_batches) }}
diff --git a/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql b/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql
new file mode 100644
index 00000000..3790fbba
--- /dev/null
+++ b/dbt/include/athena/macros/materializations/models/helpers/process_bucket_column.sql
@@ -0,0 +1,20 @@
+{% macro process_bucket_column(col, partition_key, table, ns, col_index) %}
+    {# Extract bucket information from the partition key #}
+    {%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%}
+
+    {%- if bucket_match -%}
+        {# For bucketed columns, compute bucket numbers and conditions #}
+        {%- set column_type = adapter.convert_type(table, col_index) -%}
+        {%- set ns.is_bucketed = true -%}
+        {%- set ns.bucket_column = bucket_match[1] -%}
+        {%- set bucket_num = adapter.murmur3_hash(col, bucket_match[2] | int) -%}
+        {%- set formatted_value, comp_func = adapter.format_value_for_partition(col, column_type) -%}
+
+        {%- if bucket_num not in ns.bucket_numbers %}
+            {%- do ns.bucket_numbers.append(bucket_num) %}
+            {%- do ns.bucket_conditions.update({bucket_num: [formatted_value]}) -%}
+        {%- elif formatted_value not in ns.bucket_conditions[bucket_num] %}
+            {%- do ns.bucket_conditions[bucket_num].append(formatted_value) -%}
+        {%- endif -%}
+    {%- endif -%}
+{% endmacro %}
diff --git a/setup.py b/setup.py
index a2509409..1d4e90ec 100644
--- a/setup.py
+++ b/setup.py
@@ -56,6 +56,7 @@ def _get_package_version() -> str:
        "boto3~=1.26",
        "boto3-stubs[athena,glue,lakeformation,sts]~=1.26",
        "dbt-core~=1.7.0",
+        "mmh3~=4.0.1",
        "pyathena>=2.25,<4.0",
        "pydantic>=1.10,<3.0",
        "tenacity~=8.2",
diff --git a/tests/functional/adapter/test_partitions.py b/tests/functional/adapter/test_partitions.py
index f5f1e6d3..da2e5955 100644
--- a/tests/functional/adapter/test_partitions.py
+++ b/tests/functional/adapter/test_partitions.py
@@ -78,6 +78,28 @@
    NULL as date_column
"""

+test_bucket_partitions_sql = """
+with non_random_strings as (
+    select
+        chr(cast(65 + (row_number() over () % 26) as bigint)) ||
+        chr(cast(65 + ((row_number() over () + 1) % 26) as bigint)) ||
+        chr(cast(65 + ((row_number() over () + 4) % 26) as bigint)) as non_random_str
+    from
+        (select 1 union all select 2 union all select 3) as temp_table
+)
+select
+    cast(date_column as date) as date_column,
+    doy(date_column) as doy,
+    nrnd.non_random_str
+from (
+    values (
+        sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-24'), interval '1' day)
+    )
+) as t1(date_array)
+cross join unnest(date_array) as t2(date_column)
+join non_random_strings nrnd on true
+"""
+

class TestHiveTablePartitions:
    @pytest.fixture(scope="class")
@@ -264,3 +286,67 @@ def test__check_run_with_partitions(self, project):
        records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

        assert records_count_first_run == 202
+
+
+class TestIcebergTablePartitionsBuckets:
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+table_type": "iceberg",
+                "+materialized": "table",
+                "+partitioned_by": ["DAY(date_column)", "doy", "bucket(non_random_str, 5)"],
+            }
+        }
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "test_bucket_partitions.sql": test_bucket_partitions_sql,
+        }
+
+    def test__check_run_with_bucket_and_partitions(self, project):
+        relation_name = "test_bucket_partitions"
+        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+        first_model_run = run_dbt(["run", "--select", relation_name])
+        first_model_run_result = first_model_run.results[0]
+
+        # check that the model ran successfully
+        assert first_model_run_result.status == RunStatus.Success
+
+        records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert records_count_first_run == 615
+
+
+class TestIcebergTableBuckets:
+    @pytest.fixture(scope="class")
+    def project_config_update(self):
+        return {
+            "models": {
+                "+table_type": "iceberg",
+                "+materialized": "table",
+                "+partitioned_by": ["bucket(non_random_str, 5)"],
+            }
+        }
+
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {
+            "test_bucket_partitions.sql": test_bucket_partitions_sql,
+        }
+
+    def test__check_run_with_bucket_in_partitions(self, project):
+        relation_name = "test_bucket_partitions"
+        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+        first_model_run = run_dbt(["run", "--select", relation_name])
+        first_model_run_result = first_model_run.results[0]
+
+        # check that the model ran successfully
+        assert first_model_run_result.status == RunStatus.Success
+
+        records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert records_count_first_run == 615
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 168fa3cc..1c8c1422 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -1,3 +1,4 @@
+import datetime
import decimal
from unittest import mock
from unittest.mock import patch
@@ -1442,6 +1443,83 @@ def test__get_relation_type_iceberg(self, dbt_debug_caplog, mock_aws_service):
    def test__is_current_column(self, column, expected):
        assert self.adapter._is_current_column(column) == expected

+    @pytest.mark.parametrize(
+        "partition_keys, expected_result",
+        [
+            (
+                ["year(date_col)", "bucket(col_name, 10)", "default_partition_key"],
+                "date_trunc('year', date_col), col_name, default_partition_key",
+            ),
+        ],
+    )
+    def test_format_partition_keys(self, partition_keys, expected_result):
+        assert self.adapter.format_partition_keys(partition_keys) == expected_result
+
+    @pytest.mark.parametrize(
+        "partition_key, expected_result",
+        [
+            ("month(hidden)", "date_trunc('month', hidden)"),
+            ("bucket(bucket_col, 10)", "bucket_col"),
+            ("regular_col", "regular_col"),
+        ],
+    )
+    def test_format_one_partition_key(self, partition_key, expected_result):
+        assert self.adapter.format_one_partition_key(partition_key) == expected_result
+
+    def test_murmur3_hash_with_int(self):
+        bucket_number = self.adapter.murmur3_hash(123, 100)
+        assert isinstance(bucket_number, int)
+        assert 0 <= bucket_number < 100
+        assert bucket_number == 54
+
+    def test_murmur3_hash_with_date(self):
+        d = datetime.date.today()
+        bucket_number = self.adapter.murmur3_hash(d, 100)
+        assert isinstance(d, datetime.date)
+        assert isinstance(bucket_number, int)
+        assert 0 <= bucket_number < 100
+
+    def test_murmur3_hash_with_datetime(self):
+        dt = datetime.datetime.now()
+        bucket_number = self.adapter.murmur3_hash(dt, 100)
+        assert isinstance(dt, datetime.datetime)
+        assert isinstance(bucket_number, int)
+        assert 0 <= bucket_number < 100
+
+    def test_murmur3_hash_with_str(self):
+        bucket_number = self.adapter.murmur3_hash("test_string", 100)
+        assert isinstance(bucket_number, int)
+        assert 0 <= bucket_number < 100
+        assert bucket_number == 88
+
+    def test_murmur3_hash_uniqueness(self):
+        # Ensuring different inputs produce different hashes
+        hash1 = self.adapter.murmur3_hash("string1", 100)
+        hash2 = self.adapter.murmur3_hash("string2", 100)
+        assert hash1 != hash2
+
+    def test_murmur3_hash_with_unsupported_type(self):
+        with pytest.raises(TypeError):
+            self.adapter.murmur3_hash([1, 2, 3], 100)
+
+    @pytest.mark.parametrize(
+        "value, column_type, expected_result",
+        [
+            (None, "integer", ("null", " is ")),
+            (42, "integer", ("42", "=")),
+            ("O'Reilly", "string", ("'O''Reilly'", "=")),
+            ("test", "string", ("'test'", "=")),
+            ("2021-01-01", "date", ("DATE'2021-01-01'", "=")),
+            ("2021-01-01 12:00:00", "timestamp", ("TIMESTAMP'2021-01-01 12:00:00'", "=")),
"timestamp", ("TIMESTAMP'2021-01-01 12:00:00'", "=")), + ], + ) + def test_format_value_for_partition(self, value, column_type, expected_result): + assert self.adapter.format_value_for_partition(value, column_type) == expected_result + + def test_format_unsupported_type(self): + with pytest.raises(ValueError): + self.adapter.format_value_for_partition("test", "unsupported_type") + class TestAthenaFilterCatalog: def test__catalog_filter_table(self):