Skip to content

Commit

Permalink
chore: upgrade to dbt 1.8 (#614)
Browse files Browse the repository at this point in the history
Co-authored-by: nicor88 <[email protected]>
  • Loading branch information
Jrmyy and nicor88 authored Apr 15, 2024
1 parent b3f85b9 commit e458b59
Show file tree
Hide file tree
Showing 22 changed files with 182 additions and 392 deletions.
9 changes: 3 additions & 6 deletions dbt/adapters/athena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import dbt
from dbt.adapters.athena.connections import AthenaConnectionManager, AthenaCredentials
from dbt.adapters.athena.impl import AthenaAdapter
from dbt.adapters.athena.query_headers import _QueryComment
from dbt.adapters.base import AdapterPlugin
from dbt.include import athena

Plugin = AdapterPlugin(adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH)

# overwrite _QueryComment to add leading "--" to query comment
dbt.adapters.base.query_headers._QueryComment = _QueryComment
Plugin: AdapterPlugin = AdapterPlugin(
adapter=AthenaAdapter, credentials=AthenaCredentials, include_path=athena.PACKAGE_PATH
)

__all__ = [
"AthenaConnectionManager",
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/athena/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "1.7.2"
version = "1.8.0b1"
3 changes: 2 additions & 1 deletion dbt/adapters/athena/column.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import ClassVar, Dict

from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.athena.relation import TableType
from dbt.adapters.base.column import Column
from dbt.exceptions import DbtRuntimeError


@dataclass
Expand Down
16 changes: 12 additions & 4 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Any, ContextManager, Dict, List, Optional, Tuple

import tenacity
from dbt_common.exceptions import ConnectionError, DbtRuntimeError
from dbt_common.utils import md5
from pyathena.connection import Connection as AthenaConnection
from pyathena.cursor import Cursor
from pyathena.error import OperationalError, ProgrammingError
Expand All @@ -29,12 +31,15 @@

from dbt.adapters.athena.config import get_boto3_config
from dbt.adapters.athena.constants import LOGGER
from dbt.adapters.athena.query_headers import AthenaMacroQueryStringSetter
from dbt.adapters.athena.session import get_boto3_session
from dbt.adapters.base import Credentials
from dbt.adapters.contracts.connection import (
AdapterResponse,
Connection,
ConnectionState,
Credentials,
)
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection, ConnectionState
from dbt.exceptions import ConnectionError, DbtRuntimeError
from dbt.utils import md5


@dataclass
Expand Down Expand Up @@ -201,6 +206,9 @@ def inner() -> AthenaCursor:
class AthenaConnectionManager(SQLConnectionManager):
TYPE = "athena"

def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
self.query_header = AthenaMacroQueryStringSetter(self.profile, query_header_context)

@classmethod
def data_type_code_to_name(cls, type_code: str) -> str:
"""
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/athena/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dbt.events import AdapterLogger
from dbt.adapters.events.logging import AdapterLogger

DEFAULT_THREAD_COUNT = 4
DEFAULT_RETRY_ATTEMPTS = 3
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/athena/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dbt.exceptions import CompilationError, DbtRuntimeError
from dbt_common.exceptions import CompilationError, DbtRuntimeError


class SnapshotMigrationRequired(CompilationError):
Expand Down
89 changes: 26 additions & 63 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
from dataclasses import dataclass
from datetime import date, datetime
from functools import lru_cache
from itertools import chain
from textwrap import dedent
from threading import Lock
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type
from urllib.parse import urlparse
from uuid import uuid4

import agate
import mmh3
from botocore.exceptions import ClientError
from dbt_common.clients.agate_helper import table_from_rows
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.exceptions import DbtRuntimeError
from mypy_boto3_athena import AthenaClient
from mypy_boto3_athena.type_defs import DataCatalogTypeDef, GetWorkGroupOutputTypeDef
from mypy_boto3_glue.type_defs import (
Expand Down Expand Up @@ -65,13 +67,9 @@
from dbt.adapters.base import ConstraintSupport, PythonJobHelper, available
from dbt.adapters.base.impl import AdapterConfig
from dbt.adapters.base.relation import BaseRelation, InformationSchema
from dbt.adapters.contracts.connection import AdapterResponse
from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.sql import SQLAdapter
from dbt.clients.agate_helper import table_from_rows
from dbt.config.runtime import RuntimeConfig
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import CompiledNode, ConstraintType
from dbt.exceptions import DbtRuntimeError

boto3_client_lock = Lock()

Expand Down Expand Up @@ -538,29 +536,6 @@ def _s3_path_exists(self, s3_bucket: str, s3_prefix: str) -> bool:
response = s3_client.list_objects_v2(Bucket=s3_bucket, Prefix=s3_prefix)
return True if "Contents" in response else False

def _join_catalog_table_owners(self, table: agate.Table, manifest: Manifest) -> agate.Table:
owners = []
# Get the owner for each model from the manifest
for node in manifest.nodes.values():
if node.resource_type == "model":
owners.append(
{
"table_database": node.database,
"table_schema": node.schema,
"table_name": node.alias,
"table_owner": node.config.meta.get("owner"),
}
)
owners_table = agate.Table.from_object(owners)

# Join owners with the results from catalog
join_keys = ["table_database", "table_schema", "table_name"]
return table.join(
right_table=owners_table,
left_key=join_keys,
right_key=join_keys,
)

def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List[Dict[str, Any]]:
table_catalog = {
"table_database": database,
Expand Down Expand Up @@ -608,13 +583,13 @@ def _get_one_table_for_non_glue_catalog(
def _get_one_catalog(
self,
information_schema: InformationSchema,
schemas: Dict[str, Optional[Set[str]]],
manifest: Manifest,
schemas: Set[str],
used_schemas: FrozenSet[Tuple[str, str]],
) -> agate.Table:
"""
This function is invoked by Adapter.get_catalog for each schema.
"""
data_catalog = self._get_data_catalog(information_schema.path.database)
data_catalog = self._get_data_catalog(information_schema.database)
data_catalog_type = get_catalog_type(data_catalog)

conn = self.connections.get_thread_connection()
Expand All @@ -630,7 +605,7 @@ def _get_one_catalog(

catalog = []
paginator = glue_client.get_paginator("get_tables")
for schema, relations in schemas.items():
for schema in schemas:
kwargs = {
"DatabaseName": schema,
"MaxResults": 100,
Expand All @@ -643,8 +618,7 @@ def _get_one_catalog(

for page in paginator.paginate(**kwargs):
for table in page["TableList"]:
if relations and table["Name"] in relations:
catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
catalog.extend(self._get_one_table_for_catalog(table, information_schema.database))
table = agate.Table.from_object(catalog)
else:
with boto3_client_lock:
Expand All @@ -656,36 +630,28 @@ def _get_one_catalog(

catalog = []
paginator = athena_client.get_paginator("list_table_metadata")
for schema, relations in schemas.items():
for schema in schemas:
for page in paginator.paginate(
CatalogName=information_schema.path.database,
CatalogName=information_schema.database,
DatabaseName=schema,
MaxResults=50, # Limit supported by this operation
):
for table in page["TableMetadataList"]:
if relations and table["Name"].lower() in relations:
catalog.extend(
self._get_one_table_for_non_glue_catalog(
table, schema, information_schema.path.database
)
)
catalog.extend(
self._get_one_table_for_non_glue_catalog(table, schema, information_schema.database)
)
table = agate.Table.from_object(catalog)

filtered_table = self._catalog_filter_table(table, manifest)
return self._join_catalog_table_owners(filtered_table, manifest)
return self._catalog_filter_table(table, used_schemas)

def _get_catalog_schemas(self, manifest: Manifest) -> AthenaSchemaSearchMap:
def _get_catalog_schemas(self, relation_configs: Iterable[RelationConfig]) -> AthenaSchemaSearchMap:
"""
Get the schemas from the catalog.
It's called by the `get_catalog` method.
"""
info_schema_name_map = AthenaSchemaSearchMap()
nodes: Iterator[CompiledNode] = chain(
[node for node in manifest.nodes.values() if (node.is_relational and not node.is_ephemeral_model)],
manifest.sources.values(),
)
for node in nodes:
relation = self.Relation.create_from(self.config, node)
for relation_config in relation_configs:
relation = self.Relation.create_from(quoting=self.config, relation_config=relation_config)
info_schema_name_map.add(relation)
return info_schema_name_map

Expand Down Expand Up @@ -775,9 +741,9 @@ def list_relations_without_caching(self, schema_relation: AthenaRelation) -> Lis
def _get_one_catalog_by_relations(
self,
information_schema: InformationSchema,
relations: List[BaseRelation],
manifest: Manifest,
) -> agate.Table:
relations: List[AthenaRelation],
used_schemas: FrozenSet[Tuple[str, str]],
) -> "agate.Table":
"""
Overwrite of _get_one_catalog_by_relations for Athena, in order to use glue apis.
This function is invoked by Adapter.get_catalog_by_relations.
Expand All @@ -790,12 +756,11 @@ def _get_one_catalog_by_relations(
_table_definitions.extend(_table_definition)
table = agate.Table.from_object(_table_definitions)
# picked from _catalog_filter_table, force database + schema to be strings
table_casted = table_from_rows(
return table_from_rows(
table.rows,
table.column_names,
text_only_columns=["table_database", "table_schema", "table_name"],
)
return self._join_catalog_table_owners(table_casted, manifest)

@available
def swap_table(self, src_relation: AthenaRelation, target_relation: AthenaRelation) -> None:
Expand Down Expand Up @@ -1012,11 +977,9 @@ def persist_docs_to_glue(
# Add some of dbt model config fields as table meta
meta["unique_id"] = model.get("unique_id")
meta["materialized"] = model.get("config", {}).get("materialized")
# Get dbt runtime config to be able to get dbt project metadata
runtime_config: RuntimeConfig = self.config
# Add dbt project metadata to table meta
meta["dbt_project_name"] = runtime_config.project_name
meta["dbt_project_version"] = runtime_config.version
meta["dbt_project_name"] = self.config.project_name
meta["dbt_project_version"] = self.config.version
# Prepare meta values for table properties and check if update is required
for meta_key, meta_value_raw in meta.items():
if is_valid_table_parameter_key(meta_key):
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/athena/lakeformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Dict, List, Optional, Sequence, Set, Union

from dbt_common.exceptions import DbtRuntimeError
from mypy_boto3_lakeformation import LakeFormationClient
from mypy_boto3_lakeformation.type_defs import (
AddLFTagsToResourceResponseTypeDef,
Expand All @@ -16,8 +17,7 @@
from pydantic import BaseModel

from dbt.adapters.athena.relation import AthenaRelation
from dbt.events import AdapterLogger
from dbt.exceptions import DbtRuntimeError
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("AthenaLakeFormation")

Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/athena/python_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from typing import Any, Dict

import botocore
from dbt_common.exceptions import DbtRuntimeError

from dbt.adapters.athena.config import AthenaSparkSessionConfig
from dbt.adapters.athena.connections import AthenaCredentials
from dbt.adapters.athena.constants import LOGGER
from dbt.adapters.athena.session import AthenaSparkSessionManager
from dbt.adapters.base import PythonJobHelper
from dbt.exceptions import DbtRuntimeError

SUBMISSION_LANGUAGE = "python"

Expand Down
13 changes: 11 additions & 2 deletions dbt/adapters/athena/query_headers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import dbt.adapters.base.query_headers
from typing import Any, Dict

from dbt.adapters.base.query_headers import MacroQueryStringSetter, _QueryComment
from dbt.adapters.contracts.connection import AdapterRequiredConfig

class _QueryComment(dbt.adapters.base.query_headers._QueryComment):

class AthenaMacroQueryStringSetter(MacroQueryStringSetter):
def __init__(self, config: AdapterRequiredConfig, query_header_context: Dict[str, Any]):
super().__init__(config, query_header_context)
self.comment = _AthenaQueryComment(None)


class _AthenaQueryComment(_QueryComment):
"""
Athena DDL does not always respect /* ... */ block quotations.
This function is the same as _QueryComment.add except that
Expand Down
6 changes: 3 additions & 3 deletions dbt/adapters/athena/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

import boto3
import boto3.session
from dbt_common.exceptions import DbtRuntimeError
from dbt_common.invocation import get_invocation_id

from dbt.adapters.athena.config import get_boto3_config
from dbt.adapters.athena.constants import (
DEFAULT_THREAD_COUNT,
LOGGER,
SESSION_IDLE_TIMEOUT_MIN,
)
from dbt.contracts.connection import Connection
from dbt.events.functions import get_invocation_id
from dbt.exceptions import DbtRuntimeError
from dbt.adapters.contracts.connection import Connection

invocation_id = get_invocation_id()
spark_session_list: Dict[UUID, str] = {}
Expand Down
1 change: 1 addition & 0 deletions dbt/include/athena/macros/utils/safe_cast.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-- TODO: make safe_cast supports complex structures
{% macro athena__safe_cast(field, type) -%}
try_cast({{field}} as {{type}})
{%- endmacro %}
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
autoflake~=2.3
black~=24.4
boto3-stubs[s3]~=1.34
dbt-tests-adapter~=1.7.11
dbt-tests-adapter~=1.8.0b1
flake8~=7.0
Flake8-pyproject~=1.2
isort~=5.13
Expand Down
Loading

0 comments on commit e458b59

Please sign in to comment.