From cddc1c8794e9e0b2f12faf383ddc5ddfff9238c0 Mon Sep 17 00:00:00 2001
From: nicor88 <6278547+nicor88@users.noreply.github.com>
Date: Mon, 6 Feb 2023 14:37:04 +0100
Subject: [PATCH] chore: get boto3 client in function and cover some edge
 cases (#143)

---
 dbt/adapters/athena/impl.py | 19 +++++++++++--------
 tests/unit/test_adapter.py  | 10 ++++++++++
 tests/unit/utils.py         |  2 +-
 3 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index fc099742..3f469e3c 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -1,7 +1,7 @@
 import posixpath as path
 from itertools import chain
 from threading import Lock
-from typing import Any, Dict, Iterator, List, Optional, Set, Tuple
+from typing import Dict, Iterator, List, Optional, Set, Tuple
 from urllib.parse import urlparse
 from uuid import uuid4
 
@@ -144,7 +144,9 @@ def clean_up_partitions(self, database_name: str, table_name: str, where_conditi
 
     def clean_up_table(self, database_name: str, table_name: str):
         table_location = self.get_table_location(database_name, table_name)
-        if table_location is not None:
+        # this check avoids issues when the table location is an empty string
+        # or when the table does not exist and the table location is None
+        if table_location:
             self.delete_from_s3(table_location)
 
     @available
@@ -161,7 +163,7 @@ def delete_from_s3(self, s3_path: str):
         conn = self.connections.get_thread_connection()
         client = conn.handle
         bucket_name, prefix = self._parse_s3_path(s3_path)
-        if self._s3_path_exists(client, bucket_name, prefix):
+        if self._s3_path_exists(bucket_name, prefix):
             s3_resource = client.session.resource("s3", region_name=client.region_name, config=get_boto3_config())
             s3_bucket = s3_resource.Bucket(bucket_name)
             logger.debug(f"Deleting table data: path='{s3_path}', bucket='{bucket_name}', prefix='{prefix}'")
@@ -195,12 +197,13 @@ def _parse_s3_path(s3_path: str) -> Tuple[str, str]:
         prefix = o.path.lstrip("/").rstrip("/") + "/"
         return bucket_name, prefix
 
-    @staticmethod
-    def _s3_path_exists(client: Any, s3_bucket: str, s3_prefix: str) -> bool:
+    def _s3_path_exists(self, s3_bucket: str, s3_prefix: str) -> bool:
         """Checks whether a given s3 path exists."""
-        response = client.session.client(
-            "s3", region_name=client.region_name, config=get_boto3_config()
-        ).list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)
+        conn = self.connections.get_thread_connection()
+        client = conn.handle
+        with boto3_client_lock:
+            s3_client = client.session.client("s3", region_name=client.region_name, config=get_boto3_config())
+        response = s3_client.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)
         return True if "Contents" in response else False
 
     def _join_catalog_table_owners(self, table: agate.Table, manifest: Manifest) -> agate.Table:
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index a1c00ef7..58faa1ac 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -322,6 +322,16 @@ def test_clean_up_table_table_does_not_exist(self, dbt_debug_caplog, aws_credent
         assert result is None
         assert "Table 'table' does not exists - Ignoring" in dbt_debug_caplog.getvalue()
 
+    @mock_glue
+    @mock_athena
+    def test_clean_up_table_view(self, dbt_debug_caplog, aws_credentials):
+        self.mock_aws_service.create_data_catalog()
+        self.mock_aws_service.create_database()
+        self.adapter.acquire_connection("dummy")
+        self.mock_aws_service.create_view("test_view")
+        result = self.adapter.clean_up_table(DATABASE_NAME, "test_view")
+        assert result is None
+
     @mock_glue
     @mock_s3
     @mock_athena
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index 6f5c916a..9a44f62e 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -161,7 +161,7 @@ def create_view(self, view_name: str):
                             "Type": "date",
                         },
                     ],
-                    "Location": f"s3://{BUCKET}/tables/{view_name}",
+                    "Location": "",
                 },
                 "TableType": "VIRTUAL_VIEW",
             },
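
Note: for reference, below is a minimal standalone sketch of the two patterns this patch relies on: creating the boto3 S3 client inside the helper while holding a lock (boto3 client creation is not thread-safe), and skipping cleanup when the table location is None or an empty string (as a Glue view's location now is in the mocked setup). The module-level boto3 call, the default region, and the print placeholder are illustrative assumptions only; in the adapter the session and region come from the dbt connection handle, and deletion goes through an S3 resource bucket.

# sketch.py - illustrative only, assuming boto3 is installed and AWS credentials
# are available; names and defaults below are hypothetical, not part of the patch
from threading import Lock
from typing import Optional

import boto3
from botocore.config import Config

# boto3 client creation is not thread-safe, so guard it with a lock,
# mirroring the adapter's boto3_client_lock
boto3_client_lock = Lock()


def s3_path_exists(bucket_name: str, prefix: str, region_name: str = "eu-west-1") -> bool:
    """Return True if at least one object exists under s3://bucket_name/prefix."""
    with boto3_client_lock:
        s3_client = boto3.client("s3", region_name=region_name, config=Config())
    response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    return "Contents" in response


def clean_up_table(table_location: Optional[str]) -> None:
    """Skip cleanup for both None and "" locations (e.g. a Glue view has no location)."""
    if table_location:
        # the real adapter parses the location into bucket/prefix and deletes the objects under it
        print(f"would delete objects under {table_location}")

The truthiness check is the key behavioral change: `if table_location is not None` let an empty-string location through and then failed while parsing the S3 path, whereas `if table_location` skips both missing tables and location-less views.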