diff --git a/README.md b/README.md
index b1e34927..b7b23289 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,10 @@ athena:
   - For incremental models using insert overwrite strategy on hive table
   - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid
   - Useful if you are looking to run multiple dbt build inserting in the same table in parallel
+- `temp_schema` (`default=none`)
+  - For incremental models, it allows you to define a schema to hold the temporary tables
+    created during incremental model runs
+  - The schema will be created in the model's target database if it does not exist
 - `lf_tags_config` (`default=none`)
   - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns
   - `enabled` (`default=False`) whether LF tags management is enabled for a model
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index 95ee0fc2..a5108b39 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -99,6 +99,7 @@ class AthenaConfig(AdapterConfig):
         partitions_limit: Maximum numbers of partitions when batching.
         force_batch: Skip creating the table as ctas and run the operation directly in batch insert mode.
         unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp.
+        temp_schema: Define the schema in which to create the temporary tables used in incremental runs.
     """

     work_group: Optional[str] = None
@@ -120,6 +121,7 @@
     partitions_limit: Optional[int] = None
     force_batch: bool = False
     unique_tmp_table_suffix: bool = False
+    temp_schema: Optional[str] = None


 class AthenaAdapter(SQLAdapter):
diff --git a/dbt/include/athena/macros/adapters/relation.sql b/dbt/include/athena/macros/adapters/relation.sql
index 26cd5e75..611ffc59 100644
--- a/dbt/include/athena/macros/adapters/relation.sql
+++ b/dbt/include/athena/macros/adapters/relation.sql
@@ -36,6 +36,21 @@
   {%- endcall %}
 {%- endmacro %}

+{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %}
+  {%- set temp_identifier = base_relation.identifier ~ suffix -%}
+  {%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%}
+
+  {%- if temp_schema is not none -%}
+    {%- set temp_relation = temp_relation.incorporate(path={
+        "identifier": temp_identifier,
+        "schema": temp_schema
+      }) -%}
+    {%- do create_schema(temp_relation) -%}
+  {% endif %}
+
+  {{ return(temp_relation) }}
+{% endmacro %}
+
 {% macro athena__rename_relation(from_relation, to_relation) %}
   {% call statement('rename_relation') -%}
     alter table {{ from_relation.render_hive() }} rename to `{{ to_relation.schema }}`.`{{ to_relation.identifier }}`
diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql
index 887b450e..b9f5994e 100644
--- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql
+++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql
@@ -10,6 +10,7 @@
   {% set partitioned_by = config.get('partitioned_by') %}
   {% set force_batch = config.get('force_batch', False) | as_bool -%}
   {% set unique_tmp_table_suffix = config.get('unique_tmp_table_suffix', False) | as_bool -%}
+  {% set temp_schema = config.get('temp_schema') %}
   {% set target_relation = this.incorporate(type='table') %}
   {% set existing_relation = load_relation(this) %}
   -- If using insert_overwrite on Hive table, allow to set a unique tmp table suffix
@@ -22,7 +23,7 @@
   {% set old_tmp_relation = adapter.get_relation(identifier=target_relation.identifier ~ tmp_table_suffix,
                                                  schema=schema,
                                                  database=database) %}
-  {% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix) %}
+  {% set tmp_relation = make_temp_relation(target_relation, suffix=tmp_table_suffix, temp_schema=temp_schema) %}

   -- If no partitions are used with insert_overwrite, we fall back to append mode.
   {% if partitioned_by is none and strategy == 'insert_overwrite' %}
diff --git a/tests/functional/adapter/test_incremental_tmp_schema.py b/tests/functional/adapter/test_incremental_tmp_schema.py
new file mode 100644
index 00000000..d06e95f2
--- /dev/null
+++ b/tests/functional/adapter/test_incremental_tmp_schema.py
@@ -0,0 +1,108 @@
+import pytest
+import yaml
+from tests.functional.adapter.utils.parse_dbt_run_output import (
+    extract_create_statement_table_names,
+    extract_running_create_statements,
+)
+
+from dbt.contracts.results import RunStatus
+from dbt.tests.util import run_dbt
+
+models__schema_tmp_sql = """
+{{ config(
+        materialized='incremental',
+        incremental_strategy='insert_overwrite',
+        partitioned_by=['date_column'],
+        temp_schema=var('temp_schema_name')
+    )
+}}
+select
+    random() as rnd,
+    cast(from_iso8601_date('{{ var('logical_date') }}') as date) as date_column
+"""
+
+
+class TestIncrementalTmpSchema:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"schema_tmp.sql": models__schema_tmp_sql}
+
+    def test__schema_tmp(self, project, capsys):
+        relation_name = "schema_tmp"
+        temp_schema_name = f"{project.test_schema}_tmp"
+        drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade"
+        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+        vars_dict = {
+            "temp_schema_name": temp_schema_name,
+            "logical_date": "2024-01-01",
+        }
+
+        first_model_run = run_dbt(
+            [
+                "run",
+                "--select",
+                relation_name,
+                "--vars",
+                yaml.safe_dump(vars_dict),
+                "--log-level",
+                "debug",
+                "--log-format",
+                "json",
+            ]
+        )
+
+        first_model_run_result = first_model_run.results[0]
+
+        assert first_model_run_result.status == RunStatus.Success
+
+        records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert records_count_first_run == 1
+
+        out, _ = capsys.readouterr()
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+
+        assert len(athena_running_create_statements) == 1
+
+        incremental_model_run_result_table_name = extract_create_statement_table_names(
+            athena_running_create_statements[0]
+        )[0]
+
+        assert temp_schema_name not in incremental_model_run_result_table_name
+
+        vars_dict["logical_date"] = "2024-01-02"
+        incremental_model_run = run_dbt(
+            [
+                "run",
+                "--select",
+                relation_name,
+                "--vars",
+                yaml.safe_dump(vars_dict),
+                "--log-level",
+                "debug",
+                "--log-format",
+                "json",
+            ]
+        )
+
+        incremental_model_run_result = incremental_model_run.results[0]
+
+        assert incremental_model_run_result.status == RunStatus.Success
+
+        records_count_incremental_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+        assert records_count_incremental_run == 2
+
+        out, _ = capsys.readouterr()
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+
+        assert len(athena_running_create_statements) == 1
+
+        incremental_model_run_result_table_name = extract_create_statement_table_names(
+            athena_running_create_statements[0]
+        )[0]
+
+        assert temp_schema_name == incremental_model_run_result_table_name.split(".")[1].strip('"')
+
+        project.run_sql(drop_temp_schema)
diff --git a/tests/functional/adapter/test_unique_tmp_table_suffix.py b/tests/functional/adapter/test_unique_tmp_table_suffix.py
index 1f6dcec3..563e5dcb 100644
--- a/tests/functional/adapter/test_unique_tmp_table_suffix.py
+++ b/tests/functional/adapter/test_unique_tmp_table_suffix.py
@@ -1,8 +1,10 @@
-import json
 import re
-from typing import List

 import pytest
+from tests.functional.adapter.utils.parse_dbt_run_output import (
+    extract_create_statement_table_names,
+    extract_running_create_statements,
+)

 from dbt.contracts.results import RunStatus
 from dbt.tests.util import run_dbt
@@ -21,39 +23,6 @@
 """


-def extract_running_create_statements(dbt_run_capsys_output: str) -> List[str]:
-    sql_create_statements = []
-    # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
-    for events_msg in dbt_run_capsys_output.split("\n")[1:]:
-        base_msg_data = None
-        # Best effort solution to avoid invalid records and blank lines
-        try:
-            base_msg_data = json.loads(events_msg).get("data")
-        except json.JSONDecodeError:
-            pass
-        """First run will not produce data.sql object in the execution logs, only data.base_msg
-        containing the "Running Athena query:" initial create statement.
-        Subsequent incremental runs will only contain the insert from the tmp table into the model
-        table destination.
-        Since we want to compare both run create statements, we need to handle both cases"""
-        if base_msg_data:
-            base_msg = base_msg_data.get("base_msg")
-            if "Running Athena query:" in str(base_msg):
-                if "create table" in base_msg:
-                    sql_create_statements.append(base_msg)
-
-            if base_msg_data.get("conn_name") == "model.test.unique_tmp_table_suffix" and "sql" in base_msg_data:
-                if "create table" in base_msg_data.get("sql"):
-                    sql_create_statements.append(base_msg_data.get("sql"))
-
-    return sql_create_statements
-
-
-def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
-    table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
-    return [table_name.rstrip() for table_name in table_names]
-
-
 class TestUniqueTmpTableSuffix:
     @pytest.fixture(scope="class")
     def models(self):
@@ -86,7 +55,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert first_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out)
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)

         assert len(athena_running_create_statements) == 1

@@ -118,7 +87,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert incremental_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out)
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)

         assert len(athena_running_create_statements) == 1

@@ -150,7 +119,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert incremental_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out)
+        athena_running_create_statements = extract_running_create_statements(out, relation_name)

         incremental_model_run_result_table_name_2 = extract_create_statement_table_names(
             athena_running_create_statements[0]
diff --git a/tests/functional/adapter/utils/parse_dbt_run_output.py b/tests/functional/adapter/utils/parse_dbt_run_output.py
new file mode 100644
index 00000000..4f448420
--- /dev/null
+++ b/tests/functional/adapter/utils/parse_dbt_run_output.py
@@ -0,0 +1,36 @@
+import json
+import re
+from typing import List
+
+
+def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]:
+    sql_create_statements = []
+    # Skip the first line, e.g. "Invoking dbt with ['run', '--select', ...]", which is not JSON
+    for events_msg in dbt_run_capsys_output.split("\n")[1:]:
+        base_msg_data = None
+        # Best-effort parsing to skip invalid records and blank lines
+        try:
+            base_msg_data = json.loads(events_msg).get("data")
+        except json.JSONDecodeError:
+            pass
+        # The first run does not produce a data.sql object in the execution logs, only a
+        # data.base_msg containing the "Running Athena query:" initial create statement.
+        # Subsequent incremental runs only contain the insert from the tmp table into the
+        # model's destination table. Since we want to compare the create statements of both
+        # runs, we need to handle both cases.
+        if base_msg_data:
+            base_msg = base_msg_data.get("base_msg")
+            if "Running Athena query:" in str(base_msg):
+                if "create table" in base_msg:
+                    sql_create_statements.append(base_msg)
+
+            if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data:
+                if "create table" in base_msg_data.get("sql"):
+                    sql_create_statements.append(base_msg_data.get("sql"))
+
+    return sql_create_statements
+
+
+def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
+    table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
+    return [table_name.rstrip() for table_name in table_names]
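
As a quick sanity check of the regex-based helpers above, here is a minimal, illustrative Python sketch; the sample CTAS text is invented for demonstration (real statements come from dbt's JSON debug logs), and the assertion mirrors the schema check in test_incremental_tmp_schema.py:

    from tests.functional.adapter.utils.parse_dbt_run_output import (
        extract_create_statement_table_names,
    )

    # Hypothetical Athena CTAS text, mimicking what a "Running Athena query:" log line contains
    sample_sql = (
        'create table "awsdatacatalog"."test_schema_tmp"."schema_tmp__dbt_tmp"\n'
        "with (format='parquet') as select 1"
    )

    table_names = extract_create_statement_table_names(sample_sql)
    # -> ['"awsdatacatalog"."test_schema_tmp"."schema_tmp__dbt_tmp"']

    # The schema is the second dot-separated component of the fully qualified name
    schema_name = table_names[0].split(".")[1].strip('"')
    assert schema_name == "test_schema_tmp"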