Deprecate SnowflakeSqlApiOperatorAsync #1447

Merged
merged 20 commits into from
Jan 29, 2024
Changes from 5 commits
198 changes: 15 additions & 183 deletions astronomer/providers/snowflake/operators/snowflake.py
@@ -2,18 +2,17 @@

import logging
import typing
import warnings
from contextlib import closing
from datetime import timedelta
from typing import Any, Callable, List

import requests
from airflow.exceptions import AirflowException

from snowflake.connector import SnowflakeConnection
from snowflake.connector.constants import QueryStatus

try:
-    from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
+    from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator, SnowflakeSqlApiOperator

Contributor (review comment): I think we'll need to move this SnowflakeSqlApiOperator import outside of this block.

Contributor (review comment): I think we can remove SnowflakeOperator here. We've updated our snowflake provider version.

except ImportError:  # pragma: no cover
    # For apache-airflow-providers-snowflake > 3.3.0
    # currently added type: ignore[no-redef, attr-defined] and pragma: no cover because this import
@@ -26,11 +25,7 @@
    SnowflakeHookAsync,
    fetch_all_snowflake_handler,
)
-from astronomer.providers.snowflake.hooks.snowflake_sql_api import (
-    SnowflakeSqlApiHookAsync,
-)
from astronomer.providers.snowflake.triggers.snowflake_trigger import (
-    SnowflakeSqlApiTrigger,
    SnowflakeTrigger,
    get_db_hook,
)
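
To make the first reviewer's suggestion concrete, here is a minimal sketch (an assumption-laden illustration, not the merged code): it assumes the pinned apache-airflow-providers-snowflake version always ships SnowflakeSqlApiOperator, so only the legacy import needs the guard.

    # Assumed always importable on the pinned provider version
    from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator

    try:
        # Legacy import, kept only while older provider versions are supported
        from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
    except ImportError:  # pragma: no cover
        ...  # fallback for apache-airflow-providers-snowflake > 3.3.0, elided here
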
@@ -224,183 +219,20 @@ def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
            raise AirflowException("Did not receive valid event from the triggerer")


-class SnowflakeSqlApiOperatorAsync(SnowflakeOperator):
+class SnowflakeSqlApiOperatorAsync(SnowflakeSqlApiOperator):
+    """
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator` instead.
+    """
-    """
-    Implemented Async Snowflake SQL API Operator to support multiple SQL statements sequentially,
-    which is the behavior of the SnowflakeOperator; the Snowflake SQL API allows submitting
-    multiple SQL statements in a single request. In combination with aiohttp, it makes a POST request to submit SQL
-    statements for execution and polls to check the status of the execution of each statement. Query results
-    are fetched concurrently.
-    This operator currently uses key pair authentication, so you need to provide the private key raw content or
-    a private key file path in the Snowflake connection along with other details
-
-    .. seealso::
-
-        `Snowflake SQL API key pair Authentication <https://docs.snowflake.com/en/developer-guide/sql-api/authenticating.html#label-sql-api-authenticating-key-pair>`_
-
-    Where can this operator fit in?
-        - To execute multiple SQL statements in a single request
-        - To execute the SQL statement asynchronously and to execute standard queries and most DDL and DML statements
-        - To develop custom applications and integrations that perform queries
-        - To provision users and roles, create tables, etc.
-
-    The following commands are not supported:
-        - The PUT command (in Snowflake SQL)
-        - The GET command (in Snowflake SQL)
-        - The CALL command with stored procedures that return a table (stored procedures with the RETURNS TABLE clause)
-
-    .. seealso::
-
-        - `Snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#introduction-to-the-sql-api>`_
-        - `API Reference <https://docs.snowflake.com/en/developer-guide/sql-api/reference.html#snowflake-sql-api-reference>`_
-        - `Limitations of the Snowflake SQL API <https://docs.snowflake.com/en/developer-guide/sql-api/intro.html#limitations-of-the-sql-api>`_
-
-    :param snowflake_conn_id: Reference to Snowflake connection id
-    :param sql: the SQL code to be executed (templated)
-    :param autocommit: if True, each command is automatically committed
-        (default value: True)
-    :param parameters: (optional) the parameters to render the SQL query with
-    :param warehouse: name of warehouse (will overwrite any warehouse
-        defined in the connection's extra JSON)
-    :param database: name of database (will overwrite database defined
-        in connection)
-    :param schema: name of schema (will overwrite schema defined in
-        connection)
-    :param role: name of role (will overwrite any role defined in
-        connection's extra JSON)
-    :param authenticator: authenticator for Snowflake.
-        'snowflake' (default) to use the internal Snowflake authenticator
-        'externalbrowser' to authenticate using your web browser and
-        Okta, ADFS or any other SAML 2.0-compliant identity provider
-        (IdP) that has been defined for your account
-        'https://<your_okta_account_name>.okta.com' to authenticate
-        through native Okta.
-    :param session_parameters: You can set session-level parameters at
-        the time you connect to Snowflake
-    :param poll_interval: the interval in seconds to poll the query
-    :param statement_count: number of SQL statements to be executed
-    :param token_life_time: lifetime of the JWT token
-    :param token_renewal_delta: renewal time of the JWT token
-    :param bindings: (optional) values of bind variables in the SQL statement.
-        When executing the statement, Snowflake replaces placeholders (? and :name) in
-        the statement with these specified values.
-    """  # noqa

-    LIFETIME = timedelta(minutes=59)  # The tokens will have a 59-minute lifetime
-    RENEWAL_DELTA = timedelta(minutes=54)  # Tokens will be renewed after 54 minutes
-
-    def __init__(
-        self,
-        *,
-        snowflake_conn_id: str = "snowflake_default",
-        warehouse: str | None = None,
-        database: str | None = None,
-        role: str | None = None,
-        schema: str | None = None,
-        authenticator: str | None = None,
-        session_parameters: dict[str, Any] | None = None,
-        poll_interval: int = 5,
-        statement_count: int = 0,
-        token_life_time: timedelta = LIFETIME,
-        token_renewal_delta: timedelta = RENEWAL_DELTA,
-        bindings: dict[str, Any] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        self.warehouse = warehouse
-        self.database = database
-        self.role = role
-        self.schema = schema
-        self.authenticator = authenticator
-        self.session_parameters = session_parameters
-        self.snowflake_conn_id = snowflake_conn_id
-        self.poll_interval = poll_interval
-        self.statement_count = statement_count
-        self.token_life_time = token_life_time
-        self.token_renewal_delta = token_renewal_delta
-        self.bindings = bindings
-        self.execute_async = False
-        if self.__class__.__base__.__name__ != "SnowflakeOperator":  # type: ignore[union-attr]
-            # It's better to do a string check of the parent class name because SnowflakeOperator
-            # is currently deprecated, and in the future the OSS SnowflakeOperator may be removed
-            if any(
-                [warehouse, database, role, schema, authenticator, session_parameters]
-            ):  # pragma: no cover
-                hook_params = kwargs.pop("hook_params", {})  # pragma: no cover
-                kwargs["hook_params"] = {
-                    "warehouse": warehouse,
-                    "database": database,
-                    "role": role,
-                    "schema": schema,
-                    "authenticator": authenticator,
-                    "session_parameters": session_parameters,
-                    **hook_params,
-                }
-            super().__init__(conn_id=snowflake_conn_id, **kwargs)  # pragma: no cover
-        else:
-            super().__init__(**kwargs)
-
-    def execute(self, context: Context) -> None:
-        """
-        Make a POST API request to Snowflake via the SQL API to execute the queries and get the query ids,
-        then defer to the SnowflakeSqlApiTrigger class, passing along the query ids.
-        """
-        self.log.info("Executing: %s", self.sql)
-        hook = SnowflakeSqlApiHookAsync(
-            snowflake_conn_id=self.snowflake_conn_id,
-            token_life_time=self.token_life_time,
-            token_renewal_delta=self.token_renewal_delta,
-        )
-        hook.execute_query(self.sql, statement_count=self.statement_count, bindings=self.bindings)
-        self.query_ids = hook.query_ids
-        self.log.info("List of query ids %s", self.query_ids)
-
-        if self.do_xcom_push:
-            context["ti"].xcom_push(key="query_ids", value=self.query_ids)
-
-        succeeded_query_ids = []
-        for query_id in self.query_ids:
-            self.log.info("Retrieving status for query id %s", query_id)
-            header, params, url = hook.get_request_url_header_params(query_id)
-            with requests.session() as session:
-                session.headers = header
-                with session.get(url, params=params) as resp:
-                    event = hook.process_query_status_response(resp.json(), resp.status_code)
-                    if resp.status_code == 202:
-                        break
-                    elif resp.status_code == 200:
-                        succeeded_query_ids.append(query_id)
-                    else:
-                        raise AirflowException(f"{event['status']}: {event['message']}")
-
-        if len(self.query_ids) == len(succeeded_query_ids):
-            self.log.info("%s completed successfully.", self.task_id)
-            return
-
-        self.defer(
-            timeout=self.execution_timeout,
-            trigger=SnowflakeSqlApiTrigger(
-                poll_interval=self.poll_interval,
-                query_ids=self.query_ids,
-                snowflake_conn_id=self.snowflake_conn_id,
-                token_life_time=self.token_life_time,
-                token_renewal_delta=self.token_renewal_delta,
-            ),
-            method_name="execute_complete",
-        )
-
-    def execute_complete(self, context: Context, event: dict[str, str | list[str]] | None = None) -> None:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on the trigger to throw an exception; otherwise it assumes execution was
-        successful.
-        """
-        if event:
-            if "status" in event and event["status"] == "error":
-                raise AirflowException(f"{event['status']}: {event['message']}")
-            elif "status" in event and event["status"] == "success":
-                hook = SnowflakeSqlApiHookAsync(snowflake_conn_id=self.snowflake_conn_id)
-                query_ids = typing.cast(List[str], event["statement_query_ids"])
-                hook.check_query_output(query_ids)
-                self.log.info("%s completed successfully.", self.task_id)
-        else:
-            self.log.info("%s completed successfully.", self.task_id)
+    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+        warnings.warn(
+            (
+                "This class is deprecated. "
+                "Use `airflow.providers.snowflake.operators.snowflake.SnowflakeSqlApiOperator` "
+                "and set `deferrable` param to `True` instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, deferrable=True, **kwargs)
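
Functionally, the diff reduces the operator to a thin deprecation shim that forwards to the upstream operator in deferrable mode, so user migration is mechanical. A minimal sketch (the DAG id, task id, connection id, and SQL are illustrative, not taken from this PR):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperator

    with DAG(dag_id="snowflake_sql_api_migration", start_date=datetime(2024, 1, 1), schedule=None):
        # Before: SnowflakeSqlApiOperatorAsync(task_id="run_sql", sql=..., statement_count=2)
        run_sql = SnowflakeSqlApiOperator(
            task_id="run_sql",
            snowflake_conn_id="snowflake_default",
            sql="create table if not exists t (i int); insert into t values (1);",
            statement_count=2,  # number of statements submitted in the single request
            deferrable=True,    # what the deprecated shim now sets for you
        )

Because super().__init__ is called with deferrable=True, existing SnowflakeSqlApiOperatorAsync tasks keep their async (deferred) behavior until the class is removed in 2.0.0.
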
19 changes: 11 additions & 8 deletions astronomer/providers/snowflake/triggers/snowflake_trigger.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
+import warnings
from datetime import timedelta
from typing import Any, AsyncIterator

@@ -81,14 +82,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

class SnowflakeSqlApiTrigger(BaseTrigger):
    """
-    SnowflakeSqlApi Trigger inherits from the BaseTrigger; it is fired as a
-    deferred class with params to run the task in the trigger worker and
-    fetch the status for the query ids passed.
-
-    :param task_id: Reference to task id of the Dag
-    :param poll_interval: polling period in seconds to check for the status
-    :param query_ids: List of Query ids to run and poll for the status
-    :param snowflake_conn_id: Reference to Snowflake connection id
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger` instead.
    """

    def __init__(
@@ -99,6 +94,14 @@ def __init__(
        token_life_time: timedelta,
        token_renewal_delta: timedelta,
    ):
+        warnings.warn(
+            (
+                "This class is deprecated and will be removed in 2.0.0. "
+                "Use :class:`~airflow.providers.snowflake.triggers.snowflake_trigger.SnowflakeSqlApiTrigger` instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
        super().__init__()
        self.poll_interval = poll_interval
        self.query_ids = query_ids
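
As a quick sanity check of the new deprecation path, a hedged test-style sketch (operator arguments are illustrative):

    import warnings

    from astronomer.providers.snowflake.operators.snowflake import SnowflakeSqlApiOperatorAsync

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        SnowflakeSqlApiOperatorAsync(task_id="legacy_task", sql="select 1;", statement_count=1)

    # The shim should emit exactly the DeprecationWarning added in this PR
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

Because both shims warn with stacklevel=2, the warning points at the DAG line that instantiates the class rather than at the library internals.
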