Merge branch 'main' into all-contributors/add-octiva

svdimchenko committed Jan 4, 2024
2 parents b5955a7 + ebb4b16 · commit 12f474d
Showing 8 changed files with 315 additions and 32 deletions.
20 changes: 20 additions & 0 deletions .all-contributorsrc
@@ -132,6 +132,26 @@
"bug"
]
},
{
"login": "sanromeo",
"name": "Roman Korsun",
"avatar_url": "https://avatars.githubusercontent.com/u/44975602?v=4",
"profile": "https://github.com/sanromeo",
"contributions": [
"code",
"bug"
]
},
{
"login": "Danya-Fpnk",
"name": "DanyaF",
"avatar_url": "https://avatars.githubusercontent.com/u/122433975?v=4",
"profile": "https://github.com/Danya-Fpnk",
"contributions": [
"code",
"bug"
]
},
{
"login": "octiva",
"name": "Spencer",
2 changes: 2 additions & 0 deletions README.md
@@ -690,6 +690,8 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center" valign="top" width="14.28%"><a href="https://github.com/henriblancke"><img src="https://avatars.githubusercontent.com/u/1708162?v=4?s=100" width="100px;" alt="Henri Blancke"/><br /><sub><b>Henri Blancke</b></sub></a><br /><a href="https://github.com/dbt-athena/dbt-athena/commits?author=henriblancke" title="Code">💻</a> <a href="https://github.com/dbt-athena/dbt-athena/issues?q=author%3Ahenriblancke" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/svdimchenko"><img src="https://avatars.githubusercontent.com/u/39801237?v=4?s=100" width="100px;" alt="Serhii Dimchenko"/><br /><sub><b>Serhii Dimchenko</b></sub></a><br /><a href="https://github.com/dbt-athena/dbt-athena/commits?author=svdimchenko" title="Code">💻</a> <a href="https://github.com/dbt-athena/dbt-athena/issues?q=author%3Asvdimchenko" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/chrischin478"><img src="https://avatars.githubusercontent.com/u/47199426?v=4?s=100" width="100px;" alt="chrischin478"/><br /><sub><b>chrischin478</b></sub></a><br /><a href="https://github.com/dbt-athena/dbt-athena/commits?author=chrischin478" title="Code">💻</a> <a href="https://github.com/dbt-athena/dbt-athena/issues?q=author%3Achrischin478" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/sanromeo"><img src="https://avatars.githubusercontent.com/u/44975602?v=4?s=100" width="100px;" alt="Roman Korsun"/><br /><sub><b>Roman Korsun</b></sub></a><br /><a href="https://github.com/dbt-athena/dbt-athena/commits?author=sanromeo" title="Code">💻</a> <a href="https://github.com/dbt-athena/dbt-athena/issues?q=author%3Asanromeo" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/Danya-Fpnk"><img src="https://avatars.githubusercontent.com/u/122433975?v=4?s=100" width="100px;" alt="DanyaF"/><br /><sub><b>DanyaF</b></sub></a><br /><a href="https://github.com/dbt-athena/dbt-athena/commits?author=Danya-Fpnk" title="Code">💻</a> <a href="https://github.com/dbt-athena/dbt-athena/issues?q=author%3ADanya-Fpnk" title="Bug reports">🐛</a></td>
<td align="center" valign="top" width="14.28%"><a href="https://github.com/octiva"><img src="https://avatars.githubusercontent.com/u/53303191?v=4?s=100" width="100px;" alt="Spencer"/><br /><sub><b>Spencer</b></sub></a><br /><a href="https://github.com/dbt-athena/dbt-athena/commits?author=octiva" title="Code">💻</a></td>
</tr>
</tbody>
53 changes: 51 additions & 2 deletions dbt/adapters/athena/impl.py
@@ -2,8 +2,10 @@
import os
import posixpath as path
import re
import struct
import tempfile
from dataclasses import dataclass
from datetime import date, datetime
from itertools import chain
from textwrap import dedent
from threading import Lock
@@ -12,6 +14,7 @@
from uuid import uuid4

import agate
import mmh3
from botocore.exceptions import ClientError
from mypy_boto3_athena.type_defs import DataCatalogTypeDef
from mypy_boto3_glue.type_defs import (
@@ -118,6 +121,7 @@ class AthenaConfig(AdapterConfig):
class AthenaAdapter(SQLAdapter):
BATCH_CREATE_PARTITION_API_LIMIT = 100
BATCH_DELETE_PARTITION_API_LIMIT = 25
INTEGER_MAX_VALUE_32_BIT_SIGNED = 0x7FFFFFFF

ConnectionManager = AthenaConnectionManager
Relation = AthenaRelation
@@ -1262,6 +1266,51 @@ def format_partition_keys(self, partition_keys: List[str]) -> str:

@available
def format_one_partition_key(self, partition_key: str) -> str:
"""Check if partition key uses Iceberg hidden partitioning"""
"""Check if partition key uses Iceberg hidden partitioning or bucket partitioning"""
hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower())
return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower()
bucket = re.search(r"bucket\((.+),", partition_key.lower())
if hidden:
return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})"
elif bucket:
return bucket.group(1)
else:
return partition_key.lower()

@available
def murmur3_hash(self, value: Any, num_buckets: int) -> int:
"""
Computes a hash for the given value using the MurmurHash3 algorithm and returns a bucket number.
    This method was adapted from https://github.com/apache/iceberg-python/blob/main/pyiceberg/transforms.py#L240
"""
if isinstance(value, int): # int, long
hash_value = mmh3.hash(struct.pack("<q", value))
elif isinstance(value, (datetime, date)): # date, time, timestamp, timestampz
timestamp = int(value.timestamp()) if isinstance(value, datetime) else int(value.strftime("%s"))
hash_value = mmh3.hash(struct.pack("<q", timestamp))
elif isinstance(value, (str, bytes)): # string
hash_value = mmh3.hash(value)
else:
raise TypeError(f"Need to add support data type for hashing: {type(value)}")

return int((hash_value & self.INTEGER_MAX_VALUE_32_BIT_SIGNED) % num_buckets)

@available
def format_value_for_partition(self, value: Any, column_type: str) -> Tuple[str, str]:
"""Formats a value based on its column type for inclusion in a SQL query."""
comp_func = "=" # Default comparison function
if value is None:
return "null", " is "
elif column_type == "integer":
return str(value), comp_func
elif column_type == "string":
# Properly escape single quotes in the string value
escaped_value = str(value).replace("'", "''")
return f"'{escaped_value}'", comp_func
elif column_type == "date":
return f"DATE'{value}'", comp_func
elif column_type == "timestamp":
return f"TIMESTAMP'{value}'", comp_func
else:
# Raise an error for unsupported column types
raise ValueError(f"Unsupported column type: {column_type}")
@@ -1,9 +1,11 @@
{% macro get_partition_batches(sql, as_subquery=True) -%}
{# Retrieve partition configuration and set default partition limit #}
{%- set partitioned_by = config.get('partitioned_by') -%}
{%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%}
{%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%}
{% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %}

{# Retrieve distinct partitions from the given SQL #}
{% call statement('get_partitions', fetch_result=True) %}
{%- if as_subquery -%}
select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }};
@@ -12,48 +14,73 @@
{%- endif -%}
{% endcall %}

{# Initialize variables to store partition info #}
{%- set table = load_result('get_partitions').table -%}
{%- set rows = table.rows -%}
{%- set partitions = {} -%}
{% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %}
{%- set partitions_batches = [] -%}
{%- set ns = namespace(partitions = [], bucket_conditions = {}, bucket_numbers = [], bucket_column = None, is_bucketed = false) -%}

{# Process each partition row #}
{%- for row in rows -%}
{%- set single_partition = [] -%}
{%- for col in row -%}


{%- set column_type = adapter.convert_type(table, loop.index0) -%}
{%- set comp_func = '=' -%}
{%- if col is none -%}
{%- set value = 'null' -%}
{%- set comp_func = ' is ' -%}
{%- elif column_type == 'integer' or column_type is none -%}
{%- set value = col | string -%}
{%- elif column_type == 'string' -%}
{%- set value = "'" + col + "'" -%}
{%- elif column_type == 'date' -%}
{%- set value = "DATE'" + col | string + "'" -%}
{%- elif column_type == 'timestamp' -%}
{%- set value = "TIMESTAMP'" + col | string + "'" -%}
{%- else -%}
{%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
{# Use Namespace to hold the counter for loop index #}
{%- set counter = namespace(value=0) -%}
{# Loop through each column in the row #}
{%- for col, partition_key in zip(row, partitioned_by) -%}
{# Process bucketed columns using the new macro with the index #}
{%- do process_bucket_column(col, partition_key, table, ns, counter.value) -%}

{# Logic for non-bucketed columns #}
{%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%}
{%- if not bucket_match -%}
{# For non-bucketed columns, format partition key and value #}
{%- set column_type = adapter.convert_type(table, counter.value) -%}
{%- set value, comp_func = adapter.format_value_for_partition(col, column_type) -%}
{%- set partition_key_formatted = adapter.format_one_partition_key(partitioned_by[counter.value]) -%}
{%- do single_partition.append(partition_key_formatted + comp_func + value) -%}
{%- endif -%}
{%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
{%- do single_partition.append(partition_key + comp_func + value) -%}
{# Increment the counter #}
{%- set counter.value = counter.value + 1 -%}
{%- endfor -%}

{# Concatenate conditions for a single partition #}
{%- set single_partition_expression = single_partition | join(' and ') -%}
{%- if single_partition_expression not in ns.partitions %}
{%- do ns.partitions.append(single_partition_expression) -%}
{%- endif -%}
{%- endfor -%}

{%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%}
{% if not batch_number in partitions %}
{% do partitions.update({batch_number: []}) %}
{% endif %}
{# Calculate total batches based on bucketing and partitioning #}
{%- if ns.is_bucketed -%}
{%- set total_batches = ns.partitions | length * ns.bucket_numbers | length -%}
{%- else -%}
{%- set total_batches = ns.partitions | length -%}
{%- endif -%}
{% do log('TOTAL PARTITIONS TO PROCESS: ' ~ total_batches) %}

{%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%}
{%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%}
{%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%}
{# Determine the number of batches per partition limit #}
{%- set batches_per_partition_limit = (total_batches // athena_partitions_limit) + (total_batches % athena_partitions_limit > 0) -%}

{# Create conditions for each batch #}
{%- set partitions_batches = [] -%}
{%- for i in range(batches_per_partition_limit) -%}
{%- set batch_conditions = [] -%}
{%- if ns.is_bucketed -%}
{# Combine partition and bucket conditions for each batch #}
{%- for partition_expression in ns.partitions -%}
{%- for bucket_num in ns.bucket_numbers -%}
{%- set bucket_condition = ns.bucket_column + " IN (" + ns.bucket_conditions[bucket_num] | join(", ") + ")" -%}
{%- set combined_condition = "(" + partition_expression + ' and ' + bucket_condition + ")" -%}
{%- do batch_conditions.append(combined_condition) -%}
{%- endfor -%}
{%- endfor -%}
{%- else -%}
{# Extend batch conditions with partitions for non-bucketed columns #}
{%- do batch_conditions.extend(ns.partitions) -%}
{%- endif -%}
{# Calculate batch start and end index and append batch conditions #}
{%- set start_index = i * athena_partitions_limit -%}
{%- set end_index = start_index + athena_partitions_limit -%}
{%- do partitions_batches.append(batch_conditions[start_index:end_index] | join(' or ')) -%}
{%- endfor -%}

{{ return(partitions_batches) }}
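The reworked macro now computes the number of batches up front (a ceiling division of the distinct partition expressions by `partitions_limit`) and slices the combined conditions accordingly. Below is a rough Python model of that batching arithmetic, offered only as an illustration; the sample expressions and limit are made up, and the real logic runs in Jinja inside the macro.

```python
# Rough Python model of the batching arithmetic in get_partition_batches
# (illustrative only; the real logic runs in Jinja). The partition
# expressions and limit below are made-up sample values.
partitions = [f"date_column = DATE'2023-01-{day:02d}'" for day in range(1, 8)]
athena_partitions_limit = 3

total_batches = len(partitions)
# Ceiling division: one extra batch if there is a remainder.
batches_per_partition_limit = (
    total_batches // athena_partitions_limit
    + (total_batches % athena_partitions_limit > 0)
)

partitions_batches = []
for i in range(batches_per_partition_limit):
    start_index = i * athena_partitions_limit
    end_index = start_index + athena_partitions_limit
    partitions_batches.append(" or ".join(partitions[start_index:end_index]))

for batch in partitions_batches:
    print(batch)  # each batch becomes one predicate in the generated DML
```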
@@ -0,0 +1,20 @@
{% macro process_bucket_column(col, partition_key, table, ns, col_index) %}
{# Extract bucket information from the partition key #}
{%- set bucket_match = modules.re.search('bucket\((.+?),\s*(\d+)\)', partition_key) -%}

{%- if bucket_match -%}
{# For bucketed columns, compute bucket numbers and conditions #}
{%- set column_type = adapter.convert_type(table, col_index) -%}
{%- set ns.is_bucketed = true -%}
{%- set ns.bucket_column = bucket_match[1] -%}
{%- set bucket_num = adapter.murmur3_hash(col, bucket_match[2] | int) -%}
{%- set formatted_value, comp_func = adapter.format_value_for_partition(col, column_type) -%}

{%- if bucket_num not in ns.bucket_numbers %}
{%- do ns.bucket_numbers.append(bucket_num) %}
{%- do ns.bucket_conditions.update({bucket_num: [formatted_value]}) -%}
{%- elif formatted_value not in ns.bucket_conditions[bucket_num] %}
{%- do ns.bucket_conditions[bucket_num].append(formatted_value) -%}
{%- endif -%}
{%- endif -%}
{% endmacro %}
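In plain terms, `process_bucket_column` detects a `bucket(col, N)` partition key, hashes each value into a bucket number via `adapter.murmur3_hash`, and groups the formatted values per bucket so they can later be emitted as `col IN (...)` predicates. A hedged Python equivalent of that bookkeeping follows; `murmur3_bucket` is a hypothetical stand-in and does not reproduce the real hash values.

```python
# Hedged Python equivalent of the namespace bookkeeping in
# process_bucket_column. `murmur3_bucket` is a hypothetical stand-in for
# adapter.murmur3_hash and does not reproduce the real hash values.
import re


def murmur3_bucket(value, num_buckets):
    return hash(value) % num_buckets  # placeholder only


partition_key = "bucket(non_random_str, 5)"
values = ["ABE", "BCF", "CDG", "ABE"]

bucket_numbers, bucket_conditions = [], {}
match = re.search(r"bucket\((.+?),\s*(\d+)\)", partition_key)
if match:
    bucket_column, num_buckets = match.group(1), int(match.group(2))
    for col in values:
        bucket_num = murmur3_bucket(col, num_buckets)
        formatted = "'" + col.replace("'", "''") + "'"
        if bucket_num not in bucket_numbers:
            bucket_numbers.append(bucket_num)
            bucket_conditions[bucket_num] = [formatted]
        elif formatted not in bucket_conditions[bucket_num]:
            bucket_conditions[bucket_num].append(formatted)
    # Each bucket later becomes: non_random_str IN ('ABE', ...)
    print(bucket_conditions)
```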
1 change: 1 addition & 0 deletions setup.py
@@ -56,6 +56,7 @@ def _get_package_version() -> str:
"boto3~=1.26",
"boto3-stubs[athena,glue,lakeformation,sts]~=1.26",
"dbt-core~=1.7.0",
"mmh3~=4.0.1",
"pyathena>=2.25,<4.0",
"pydantic>=1.10,<3.0",
"tenacity~=8.2",
86 changes: 86 additions & 0 deletions tests/functional/adapter/test_partitions.py
@@ -78,6 +78,28 @@
NULL as date_column
"""

test_bucket_partitions_sql = """
with non_random_strings as (
select
chr(cast(65 + (row_number() over () % 26) as bigint)) ||
chr(cast(65 + ((row_number() over () + 1) % 26) as bigint)) ||
chr(cast(65 + ((row_number() over () + 4) % 26) as bigint)) as non_random_str
from
(select 1 union all select 2 union all select 3) as temp_table
)
select
cast(date_column as date) as date_column,
doy(date_column) as doy,
nrnd.non_random_str
from (
values (
sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-24'), interval '1' day)
)
) as t1(date_array)
cross join unnest(date_array) as t2(date_column)
join non_random_strings nrnd on true
"""


class TestHiveTablePartitions:
@pytest.fixture(scope="class")
@@ -264,3 +286,67 @@ def test__check_run_with_partitions(self, project):
records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_first_run == 202


class TestIcebergTablePartitionsBuckets:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+table_type": "iceberg",
"+materialized": "table",
"+partitioned_by": ["DAY(date_column)", "doy", "bucket(non_random_str, 5)"],
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"test_bucket_partitions.sql": test_bucket_partitions_sql,
}

def test__check_run_with_bucket_and_partitions(self, project):
relation_name = "test_bucket_partitions"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"

first_model_run = run_dbt(["run", "--select", relation_name])
first_model_run_result = first_model_run.results[0]

        # check that the model ran successfully
assert first_model_run_result.status == RunStatus.Success

records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_first_run == 615


class TestIcebergTableBuckets:
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+table_type": "iceberg",
"+materialized": "table",
"+partitioned_by": ["bucket(non_random_str, 5)"],
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"test_bucket_partitions.sql": test_bucket_partitions_sql,
}

def test__check_run_with_bucket_in_partitions(self, project):
relation_name = "test_bucket_partitions"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"

first_model_run = run_dbt(["run", "--select", relation_name])
first_model_run_result = first_model_run.results[0]

        # check that the model ran successfully
assert first_model_run_result.status == RunStatus.Success

records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]

assert records_count_first_run == 615