Skip to content

Commit

Permalink
fix(ingest/snowflake): explicit set schema if public schema is absent (
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate authored Dec 28, 2023
1 parent 2cd38a4 commit e343b69
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

logger = logging.getLogger(__name__)

PUBLIC_SCHEMA = "PUBLIC"


class SnowflakeProfiler(GenericProfiler, SnowflakeCommonMixin):
def __init__(
Expand All @@ -36,6 +38,7 @@ def __init__(
self.config: SnowflakeV2Config = config
self.report: SnowflakeV2Report = report
self.logger = logger
self.database_default_schema: Dict[str, str] = dict()

def get_workunits(
self, database: SnowflakeDatabase, db_tables: Dict[str, List[SnowflakeTable]]
Expand All @@ -47,6 +50,10 @@ def get_workunits(
"max_overflow", self.config.profiling.max_workers
)

if PUBLIC_SCHEMA not in db_tables:
# If PUBLIC schema is absent, we use any one of schemas as default schema
self.database_default_schema[database.name] = list(db_tables.keys())[0]

profile_requests = []
for schema in database.schemas:
for table in db_tables[schema.name]:
Expand Down Expand Up @@ -136,9 +143,16 @@ def get_profiler_instance(
)

def callable_for_db_connection(self, db_name: str) -> Callable:
schema_name = self.database_default_schema.get(db_name)

def get_db_connection():
conn = self.config.get_connection()
conn.cursor().execute(SnowflakeQuery.use_database(db_name))

# As mentioned here - https://docs.snowflake.com/en/sql-reference/sql/use-database#usage-notes
# no schema is selected if PUBLIC schema is absent. We need to explicitly call `USE SCHEMA <schema>`
if schema_name:
conn.cursor().execute(SnowflakeQuery.use_schema(schema_name))
return conn

return get_db_connection
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def show_tags() -> str:
def use_database(db_name: str) -> str:
return f'use database "{db_name}"'

@staticmethod
def use_schema(schema_name: str) -> str:
return f'use schema "{schema_name}"'

@staticmethod
def get_databases(db_name: Optional[str]) -> str:
db_clause = f'"{db_name}".' if db_name is not None else ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class StatefulIngestionConfigBase(GenericModel, Generic[CustomConfig]):
)


class StatefulLineageConfigMixin:
class StatefulLineageConfigMixin(ConfigModel):
enable_stateful_lineage_ingestion: bool = Field(
default=True,
description="Enable stateful lineage ingestion."
Expand Down

0 comments on commit e343b69

Please sign in to comment.