Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(athena-adapter): support custom retry configurations for boto3 calls #494

Merged
merged 2 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
add num_boto3_retries to athena credentials and enable passing it to …
…get_boto3_config. add unit test to cover implementation
  • Loading branch information
lukealexmiller committed Nov 9, 2023
commit f98064428c174cd92db02a881ff76f496ad31d48
7 changes: 5 additions & 2 deletions dbt/adapters/athena/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@


@lru_cache()
def get_boto3_config() -> config.Config:
return config.Config(user_agent_extra="dbt-athena-community/" + importlib.metadata.version("dbt-athena-community"))
def get_boto3_config(num_retries: int) -> config.Config:
return config.Config(
user_agent_extra="dbt-athena-community/" + importlib.metadata.version("dbt-athena-community"),
retries={"max_attempts": num_retries, "mode": "standard"},
)
7 changes: 6 additions & 1 deletion dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class AthenaCredentials(Credentials):
debug_query_state: bool = False
_ALIASES = {"catalog": "database"}
num_retries: int = 5
num_boto3_retries: Optional[int] = None
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "schema_table_unique"
s3_tmp_table_dir: Optional[str] = None
Expand All @@ -74,6 +75,10 @@ def type(self) -> str:
def unique_field(self) -> str:
return f"athena-{hashlib.md5(self.s3_staging_dir.encode()).hexdigest()}"

@property
def get_effective_num_retries(self) -> int:
return self.num_boto3_retries if self.num_boto3_retries is not None else self.num_retries

def _connection_keys(self) -> Tuple[str, ...]:
return (
"s3_staging_dir",
Expand Down Expand Up @@ -235,7 +240,7 @@ def open(cls, connection: Connection) -> Connection:
attempt=creds.num_retries + 1,
exceptions=("ThrottlingException", "TooManyRequestsException", "InternalServerException"),
),
config=get_boto3_config(),
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

connection.state = ConnectionState.OPEN
Expand Down
142 changes: 121 additions & 21 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
@available
def add_lf_tags_to_database(self, relation: AthenaRelation) -> None:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
if lf_tags := conn.credentials.lf_tags_database:
config = LfTagsConfig(enabled=True, tags=lf_tags)
with boto3_client_lock:
lf_client = client.session.client("lakeformation", client.region_name, config=get_boto3_config())
lf_client = client.session.client(
"lakeformation",
client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
manager = LfTagsManager(lf_client, relation, config)
manager.process_lf_tags_database()
else:
Expand All @@ -166,9 +171,14 @@ def add_lf_tags(self, relation: AthenaRelation, lf_tags_config: Dict[str, Any])
config = LfTagsConfig(**lf_tags_config)
if config.enabled:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
with boto3_client_lock:
lf_client = client.session.client("lakeformation", client.region_name, config=get_boto3_config())
lf_client = client.session.client(
"lakeformation",
client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
manager = LfTagsManager(lf_client, relation, config)
manager.process_lf_tags()
return
Expand All @@ -179,9 +189,14 @@ def apply_lf_grants(self, relation: AthenaRelation, lf_grants_config: Dict[str,
lf_config = LfGrantsConfig(**lf_grants_config)
if lf_config.data_cell_filters.enabled:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
with boto3_client_lock:
lf = client.session.client("lakeformation", region_name=client.region_name, config=get_boto3_config())
lf = client.session.client(
"lakeformation",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(catalog)
lf_permissions = LfPermissions(catalog_id, relation, lf) # type: ignore
Expand All @@ -195,7 +210,11 @@ def is_work_group_output_location_enforced(self) -> bool:
client = conn.handle

with boto3_client_lock:
athena_client = client.session.client("athena", region_name=client.region_name, config=get_boto3_config())
athena_client = client.session.client(
"athena",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

if creds.work_group:
work_group = athena_client.get_work_group(WorkGroup=creds.work_group)
Expand Down Expand Up @@ -284,13 +303,18 @@ def get_glue_table(self, relation: AthenaRelation) -> Optional[GetTableResponseT
Helper function to get a relation via Glue
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

try:
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)
Expand Down Expand Up @@ -339,13 +363,18 @@ def get_glue_table_location(self, relation: AthenaRelation) -> Optional[str]:
@available
def clean_up_partitions(self, relation: AthenaRelation, where_condition: str) -> None:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
paginator = glue_client.get_paginator("get_partitions")
partition_params = {
"CatalogId": catalog_id,
Expand Down Expand Up @@ -381,6 +410,7 @@ def upload_seed_to_s3(
seed_s3_upload_args: Optional[Dict[str, Any]] = None,
) -> str:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

# TODO: consider using the workgroup default location when configured
Expand All @@ -393,7 +423,11 @@ def upload_seed_to_s3(
object_name = path.join(prefix, file_name)

with boto3_client_lock:
s3_client = client.session.client("s3", region_name=client.region_name, config=get_boto3_config())
s3_client = client.session.client(
"s3",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
# This ensures cross-platform support, tempfile.NamedTemporaryFile does not
tmpfile = os.path.join(tempfile.gettempdir(), os.urandom(24).hex())
table.to_csv(tmpfile, quoting=csv.QUOTE_NONNUMERIC)
Expand All @@ -410,10 +444,15 @@ def delete_from_s3(self, s3_path: str) -> None:
a DbtRuntimeError in case it included errors.
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
bucket_name, prefix = self._parse_s3_path(s3_path)
if self._s3_path_exists(bucket_name, prefix):
s3_resource = client.session.resource("s3", region_name=client.region_name, config=get_boto3_config())
s3_resource = client.session.resource(
"s3",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
s3_bucket = s3_resource.Bucket(bucket_name)
LOGGER.debug(f"Deleting table data: path='{s3_path}', bucket='{bucket_name}', prefix='{prefix}'")
response = s3_bucket.objects.filter(Prefix=prefix).delete()
Expand Down Expand Up @@ -449,9 +488,14 @@ def _parse_s3_path(s3_path: str) -> Tuple[str, str]:
def _s3_path_exists(self, s3_bucket: str, s3_prefix: str) -> bool:
"""Checks whether a given s3 path exists."""
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
with boto3_client_lock:
s3_client = client.session.client("s3", region_name=client.region_name, config=get_boto3_config())
s3_client = client.session.client(
"s3",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
response = s3_client.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)
return True if "Contents" in response else False

Expand Down Expand Up @@ -535,10 +579,15 @@ def _get_one_catalog(
data_catalog_type = get_catalog_type(data_catalog)

conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
if data_catalog_type == AthenaCatalogType.GLUE:
with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

catalog = []
paginator = glue_client.get_paginator("get_tables")
Expand All @@ -561,7 +610,9 @@ def _get_one_catalog(
else:
with boto3_client_lock:
athena_client = client.session.client(
"athena", region_name=client.region_name, config=get_boto3_config()
"athena",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

catalog = []
Expand Down Expand Up @@ -602,14 +653,23 @@ def _get_catalog_schemas(self, manifest: Manifest) -> AthenaSchemaSearchMap:
def _get_data_catalog(self, database: str) -> Optional[DataCatalogTypeDef]:
if database:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
if database.lower() == "awsdatacatalog":
with boto3_client_lock:
sts = client.session.client("sts", region_name=client.region_name, config=get_boto3_config())
sts = client.session.client(
"sts",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
catalog_id = sts.get_caller_identity()["Account"]
return {"Name": database, "Type": "GLUE", "Parameters": {"catalog-id": catalog_id}}
with boto3_client_lock:
athena = client.session.client("athena", region_name=client.region_name, config=get_boto3_config())
athena = client.session.client(
"athena",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
return athena.get_data_catalog(Name=database)["DataCatalog"]
return None

Expand All @@ -621,9 +681,14 @@ def list_relations_without_caching(self, schema_relation: AthenaRelation) -> Lis
return super().list_relations_without_caching(schema_relation) # type: ignore

conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)
paginator = glue_client.get_paginator("get_tables")

kwargs = {
Expand Down Expand Up @@ -694,13 +759,18 @@ def _get_one_catalog_by_relations(
@available
def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelation) -> None:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(src_relation.database)
src_catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

src_table = glue_client.get_table(
CatalogId=src_catalog_id, DatabaseName=src_relation.schema, Name=src_relation.identifier
Expand Down Expand Up @@ -777,10 +847,15 @@ def _get_glue_table_versions_to_expire(self, relation: AthenaRelation, to_keep:
Given a table and the amount of its version to keep, it returns the versions to delete
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

paginator = glue_client.get_paginator("get_table_versions")
response_iterator = paginator.paginate(
Expand All @@ -799,13 +874,18 @@ def expire_glue_table_versions(
self, relation: AthenaRelation, to_keep: int, delete_s3: bool
) -> List[TableVersionTypeDef]:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

versions_to_delete = self._get_glue_table_versions_to_expire(relation, to_keep)
LOGGER.debug(f"Versions to delete: {[v['VersionId'] for v in versions_to_delete]}")
Expand Down Expand Up @@ -855,13 +935,18 @@ def persist_docs_to_glue(
Every dbt run should create not more than one table version.
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

# By default, there is no need to update Glue Table
need_to_update_table = False
Expand Down Expand Up @@ -941,10 +1026,15 @@ def persist_docs_to_glue(
@available
def list_schemas(self, database: str) -> List[str]:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

paginator = glue_client.get_paginator("get_databases")
result = []
Expand All @@ -964,13 +1054,18 @@ def _is_current_column(col: ColumnTypeDef) -> bool:
@available
def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn]:
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

try:
table = glue_client.get_table(CatalogId=catalog_id, DatabaseName=relation.schema, Name=relation.identifier)[
Expand Down Expand Up @@ -1001,13 +1096,18 @@ def delete_from_glue_catalog(self, relation: AthenaRelation) -> None:
table_name = relation.identifier

conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle

data_catalog = self._get_data_catalog(relation.database)
catalog_id = get_catalog_id(data_catalog)

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
glue_client = client.session.client(
"glue",
region_name=client.region_name,
config=get_boto3_config(num_retries=creds.get_effective_num_retries),
)

try:
glue_client.delete_table(CatalogId=catalog_id, DatabaseName=schema_name, Name=table_name)
Expand Down
Loading