Skip to content

Commit

Permalink
Merge branch 'main' into temp-schema-for-materialized-tables
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrebzl authored Jul 4, 2024
2 parents 5bc3ce3 + 5b95fcc commit b05d1df
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/athena/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.8.2"
version = "1.8.3"
17 changes: 10 additions & 7 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from dbt_common.clients.agate_helper import table_from_rows
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.exceptions import DbtRuntimeError
from mypy_boto3_athena import AthenaClient
from mypy_boto3_athena.type_defs import DataCatalogTypeDef, GetWorkGroupOutputTypeDef
from mypy_boto3_glue.type_defs import (
ColumnTypeDef,
Expand Down Expand Up @@ -219,14 +218,11 @@ def apply_lf_grants(self, relation: AthenaRelation, lf_grants_config: Dict[str,
lf_permissions.process_permissions(lf_config)

@lru_cache()
def _get_work_group(self, client: AthenaClient, work_group: str) -> GetWorkGroupOutputTypeDef:
def _get_work_group(self, work_group: str) -> GetWorkGroupOutputTypeDef:
"""
helper function to cache the result of the get_work_group to avoid APIs throttling
"""
return client.get_work_group(WorkGroup=work_group)

@available
def is_work_group_output_location_enforced(self) -> bool:
LOGGER.debug("get_work_group for %s", work_group)
conn = self.connections.get_thread_connection()
creds = conn.credentials
client = conn.handle
Expand All @@ -238,8 +234,15 @@ def is_work_group_output_location_enforced(self) -> bool:
config=get_boto3_config(num_retries=creds.effective_num_retries),
)

return athena_client.get_work_group(WorkGroup=work_group)

@available
def is_work_group_output_location_enforced(self) -> bool:
conn = self.connections.get_thread_connection()
creds = conn.credentials

if creds.work_group:
work_group = self._get_work_group(athena_client, creds.work_group)
work_group = self._get_work_group(creds.work_group)
output_location = (
work_group.get("WorkGroup", {})
.get("Configuration", {})
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dbt-tests-adapter~=1.9.1
flake8~=7.1
Flake8-pyproject~=1.2
isort~=5.13
moto~=5.0.9
moto~=5.0.10
pre-commit~=3.5
pyparsing~=3.1.2
pytest~=8.2
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/test_retries_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class TestIcebergRetriesEnabled:
def dbt_profile_target(self):
profile = copy.deepcopy(base_dbt_profile)
# we set the iceberg retries to the same number of parallelism to make sure that the retries are working
profile["num_iceberg_retries"] = PARALLELISM
profile["num_iceberg_retries"] = PARALLELISM * 2
return profile

@pytest.fixture(scope="class")
Expand Down

0 comments on commit b05d1df

Please sign in to comment.