Support session reuse in RedshiftDataOperator (apache#42218)
borismo authored Sep 24, 2024
1 parent 9ec8753 commit 8580e6d
Showing 12 changed files with 468 additions and 66 deletions.
27 changes: 27 additions & 0 deletions airflow/providers/amazon/CHANGELOG.rst
@@ -26,6 +26,33 @@
Changelog
---------

Main
......

Breaking changes
~~~~~~~~~~~~~~~~

.. warning::
In order to support session reuse in RedshiftData operators, the following breaking changes were introduced:

The ``database`` argument is now optional and was therefore moved after the positional ``sql`` argument.
Update your DAGs accordingly if they rely on argument order. Applies to:
* ``RedshiftDataHook``'s ``execute_query`` method
* ``RedshiftDataOperator``

``RedshiftDataHook``'s ``execute_query`` method now returns a ``QueryExecutionOutput`` object instead of just the
statement ID as a string.

``RedshiftDataHook``'s ``parse_statement_resposne`` method was renamed to ``parse_statement_response``.

``S3ToRedshiftOperator``'s ``schema`` argument is now optional and was moved after the ``s3_key`` positional argument.
Update your DAGs accordingly if they rely on argument order.

Features
~~~~~~~~

* ``Support session reuse in RedshiftDataOperator, RedshiftToS3Operator and S3ToRedshiftOperator (#42218)``

8.29.0
......

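For illustration only, a minimal sketch of what a hook call looks like after the argument reorder and the new return type; the connection id, cluster identifier, and database name are placeholders, not values from this commit:

```python
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

# Before this change, calls could rely on `database` being the first positional
# argument, e.g. hook.execute_query("dev", "SELECT 1", cluster_identifier=...).
hook = RedshiftDataHook(aws_conn_id="aws_default")  # placeholder connection id

# After this change, `sql` comes first and `database` is optional, so passing
# both by keyword keeps the call unambiguous.
output = hook.execute_query(
    sql="SELECT 1",
    database="dev",                   # placeholder database name
    cluster_identifier="my-cluster",  # placeholder cluster identifier
)

# execute_query now returns a QueryExecutionOutput instead of a bare statement id.
statement_id = output.statement_id
session_id = output.session_id  # None unless a session was started or reused
```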
66 changes: 52 additions & 14 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -18,8 +18,12 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from pprint import pformat
from typing import TYPE_CHECKING, Any, Iterable
from uuid import UUID

from pendulum import duration

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values
@@ -35,6 +39,14 @@
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}


@dataclass
class QueryExecutionOutput:
"""Describes the output of a query execution."""

statement_id: str
session_id: str | None


class RedshiftDataQueryFailedError(ValueError):
"""Raise an error that redshift data query failed."""

@@ -65,8 +77,8 @@ def __init__(self, *args, **kwargs) -> None:

def execute_query(
self,
database: str,
sql: str | list[str],
database: str | None = None,
cluster_identifier: str | None = None,
db_user: str | None = None,
parameters: Iterable | None = None,
@@ -76,23 +88,28 @@ def execute_query(
wait_for_completion: bool = True,
poll_interval: int = 10,
workgroup_name: str | None = None,
) -> str:
session_id: str | None = None,
session_keep_alive_seconds: int | None = None,
) -> QueryExecutionOutput:
"""
Execute a statement against Amazon Redshift.
:param database: the name of the database
:param sql: the SQL statement or list of SQL statements to run
:param database: the name of the database
:param cluster_identifier: unique identifier of a cluster
:param db_user: the database username
:param parameters: the parameters for the SQL statement
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
:param with_event: whether to send an event to EventBridge
:param wait_for_completion: whether to wait for a result
:param poll_interval: how often in seconds to check the query status
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param session_id: the session identifier of the query
:param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
finishes. The maximum time a session can be kept alive is 24 hours
:returns: a QueryExecutionOutput with the statement ID and, if a session was started or reused, the session ID
"""
@@ -105,7 +122,28 @@
"SecretArn": secret_arn,
"StatementName": statement_name,
"WorkgroupName": workgroup_name,
"SessionId": session_id,
"SessionKeepAliveSeconds": session_keep_alive_seconds,
}

if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1:
raise ValueError(
"Exactly one of cluster_identifier, workgroup_name, or session_id must be provided"
)

if session_id is not None:
msg = "session_id must be a valid UUID4"
try:
if UUID(session_id).version != 4:
raise ValueError(msg)
except ValueError:
raise ValueError(msg)

if session_keep_alive_seconds is not None and (
session_keep_alive_seconds < 0 or duration(seconds=session_keep_alive_seconds).hours > 24
):
raise ValueError("Session keep alive duration must be between 0 and 86400 seconds.")

if isinstance(sql, list):
kwargs["Sqls"] = sql
resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
@@ -115,13 +153,10 @@

statement_id = resp["Id"]

if bool(cluster_identifier) is bool(workgroup_name):
raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.")

if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)

return statement_id
return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId"))

def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
while True:
@@ -135,9 +170,9 @@ def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
def check_query_is_finished(self, statement_id: str) -> bool:
"""Check whether query finished, raise exception is failed."""
resp = self.conn.describe_statement(Id=statement_id)
return self.parse_statement_resposne(resp)
return self.parse_statement_response(resp)

def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool:
def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
"""Parse the response of describe_statement."""
status = resp["Status"]
if status == FINISHED_STATE:
@@ -179,8 +214,10 @@ def get_table_primary_key(
:param table: Name of the target table
:param database: the name of the database
:param schema: Name of the target schema, public by default
:param sql: the SQL statement or list of SQL statement to run
:param cluster_identifier: unique identifier of a cluster
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param db_user: the database username
:param secret_arn: the name or ARN of the secret that enables db access
:param statement_name: the name of the SQL statement
@@ -212,7 +249,8 @@ def get_table_primary_key(
with_event=with_event,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
)
).statement_id

pk_columns = []
token = ""
while True:
@@ -251,4 +289,4 @@ async def check_query_is_finished_async(self, statement_id: str) -> bool:
"""
async with self.async_conn as client:
resp = await client.describe_statement(Id=statement_id)
return self.parse_statement_resposne(resp)
return self.parse_statement_response(resp)
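Taken together, the hook changes above let one statement open a session and later statements attach to it. A hedged sketch of that flow, assuming a Redshift Serverless workgroup; the workgroup and database names are placeholders, and exactly one of ``cluster_identifier``, ``workgroup_name``, or ``session_id`` may be passed per call:

```python
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook()

# The first statement opens a session and keeps it alive for 10 minutes after
# it finishes (the value must stay within the 86400-second limit checked above).
first = hook.execute_query(
    sql="CREATE TEMPORARY TABLE tmp_orders AS SELECT 1 AS id",
    database="dev",                  # placeholder database name
    workgroup_name="my-workgroup",   # placeholder Serverless workgroup
    session_keep_alive_seconds=600,
)

# The returned QueryExecutionOutput carries the session id alongside the
# statement id once a session has been requested.
print(first.statement_id, first.session_id)

# Follow-up statements pass only `session_id` (a UUID4 string); adding
# `cluster_identifier` or `workgroup_name` as well would fail the validation.
second = hook.execute_query(
    sql="SELECT * FROM tmp_orders",  # the temporary table is visible in-session
    session_id=first.session_id,
)
print(second.statement_id)
```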
21 changes: 18 additions & 3 deletions airflow/providers/amazon/aws/operators/redshift_data.py
@@ -56,13 +56,16 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:param session_id: the session identifier of the query
:param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
finishes. The maximum time a session can be kept alive is 24 hours
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
:param verify: Whether to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
@@ -77,15 +80,16 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
"parameters",
"statement_name",
"workgroup_name",
"session_id",
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
statement_id: str | None

def __init__(
self,
database: str,
sql: str | list,
database: str | None = None,
cluster_identifier: str | None = None,
db_user: str | None = None,
parameters: list | None = None,
@@ -97,6 +101,8 @@ def __init__(
return_sql_result: bool = False,
workgroup_name: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
session_id: str | None = None,
session_keep_alive_seconds: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -120,6 +126,8 @@ def __init__(
self.return_sql_result = return_sql_result
self.statement_id: str | None = None
self.deferrable = deferrable
self.session_id = session_id
self.session_keep_alive_seconds = session_keep_alive_seconds

def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
"""Execute a statement against Amazon Redshift."""
@@ -130,7 +138,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
if self.deferrable:
wait_for_completion = False

self.statement_id = self.hook.execute_query(
query_execution_output = self.hook.execute_query(
database=self.database,
sql=self.sql,
cluster_identifier=self.cluster_identifier,
Expand All @@ -142,8 +150,15 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
with_event=self.with_event,
wait_for_completion=wait_for_completion,
poll_interval=self.poll_interval,
session_id=self.session_id,
session_keep_alive_seconds=self.session_keep_alive_seconds,
)

self.statement_id = query_execution_output.statement_id

if query_execution_output.session_id:
self.xcom_push(context, key="session_id", value=query_execution_output.session_id)

if self.deferrable and self.wait_for_completion:
is_finished = self.hook.check_query_is_finished(self.statement_id)
if not is_finished:
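At the operator level, the ``session_id`` value pushed to XCom and the templated ``session_id`` field combine into a pattern like the following sketch; the DAG id, task ids, workgroup, and database are illustrative and not part of this commit:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG(
    dag_id="redshift_session_reuse_example",  # illustrative DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    # Opens a session and pushes its id to XCom under key="session_id".
    create_tmp_table = RedshiftDataOperator(
        task_id="create_tmp_table",
        sql="CREATE TEMPORARY TABLE tmp_orders AS SELECT 1 AS id",
        database="dev",                 # placeholder database name
        workgroup_name="my-workgroup",  # placeholder Serverless workgroup
        session_keep_alive_seconds=600,
    )

    # `session_id` is now a templated field, so it can be rendered at runtime
    # from the upstream task's XCom.
    query_tmp_table = RedshiftDataOperator(
        task_id="query_tmp_table",
        sql="SELECT * FROM tmp_orders",
        session_id="{{ ti.xcom_pull(task_ids='create_tmp_table', key='session_id') }}",
    )

    create_tmp_table >> query_tmp_table
```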
25 changes: 15 additions & 10 deletions airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -45,7 +45,8 @@ class RedshiftToS3Operator(BaseOperator):
:param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
to False, this param must include the desired file name
:param schema: reference to a specific schema in redshift database,
used when ``table`` param provided and ``select_query`` param not provided
used when ``table`` param provided and ``select_query`` param not provided.
Do not provide when unloading a temporary table
:param table: reference to a specific table in redshift database,
used when ``schema`` param provided and ``select_query`` param not provided
:param select_query: custom select query to fetch data from redshift database,
Expand All @@ -55,8 +56,8 @@ class RedshiftToS3Operator(BaseOperator):
If the AWS connection contains 'aws_iam_role' in ``extras``
the operator will use AWS STS credentials with a token
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
:param verify: Whether to verify SSL certificates for S3 connection.
By default, SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
@@ -67,7 +68,7 @@
CA cert bundle than the one used by botocore.
:param unload_options: reference to a list of UNLOAD options
:param autocommit: If set to True it will automatically commit the UNLOAD statement.
Otherwise it will be committed right before the redshift connection gets closed.
Otherwise, it will be committed right before the redshift connection gets closed.
:param include_header: If set to True the s3 file contains the header columns.
:param parameters: (optional) the parameters to render the SQL query with.
:param table_as_file_name: If set to True, the s3 file will be named as the table.
@@ -141,9 +142,15 @@ def _build_unload_query(

@property
def default_select_query(self) -> str | None:
if self.schema and self.table:
return f"SELECT * FROM {self.schema}.{self.table}"
return None
if not self.table:
return None

if self.schema:
table = f"{self.schema}.{self.table}"
else:
# Relevant when unloading a temporary table
table = self.table
return f"SELECT * FROM {table}"

def execute(self, context: Context) -> None:
if self.table and self.table_as_file_name:
Expand All @@ -152,9 +159,7 @@ def execute(self, context: Context) -> None:
self.select_query = self.select_query or self.default_select_query

if self.select_query is None:
raise ValueError(
"Please provide both `schema` and `table` params or `select_query` to fetch the data."
)
raise ValueError("Please specify either a table or `select_query` to fetch the data.")

if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = [*self.unload_options, "HEADER"]
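With ``schema`` now optional, a temporary table created earlier in a session can be unloaded directly. A sketch of that usage, assuming the operator's existing ``redshift_data_api_kwargs`` passthrough to the Data API hook; the bucket, key, database, and session id are placeholders:

```python
from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

# No `schema` is given, so default_select_query falls back to
# "SELECT * FROM tmp_orders", which only resolves inside the reused session.
unload_tmp_table = RedshiftToS3Operator(
    task_id="unload_tmp_table",
    s3_bucket="my-bucket",        # placeholder bucket
    s3_key="exports/tmp_orders",  # placeholder key prefix
    table="tmp_orders",
    redshift_data_api_kwargs={    # assumed passthrough to RedshiftDataHook.execute_query
        "database": "dev",        # placeholder database name
        # Placeholder UUID4: in practice, reuse the id of the session that
        # created the temporary table (for example, pulled from XCom).
        "session_id": "00000000-0000-4000-8000-000000000000",
    },
)
```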
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -28,7 +28,6 @@
if TYPE_CHECKING:
from airflow.utils.context import Context


AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"]


@@ -40,17 +39,18 @@ class S3ToRedshiftOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:S3ToRedshiftOperator`
:param schema: reference to a specific schema in redshift database
:param table: reference to a specific table in redshift database
:param s3_bucket: reference to a specific S3 bucket
:param s3_key: key prefix that selects single or multiple objects from S3
:param schema: reference to a specific schema in redshift database.
Do not provide when copying into a temporary table
:param redshift_conn_id: reference to a specific redshift database OR a redshift data-api connection
:param aws_conn_id: reference to a specific S3 connection
If the AWS connection contains 'aws_iam_role' in ``extras``
the operator will use AWS STS credentials with a token
https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
:param verify: Whether to verify SSL certificates for S3 connection.
By default, SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
@@ -87,10 +87,10 @@ class S3ToRedshiftOperator(BaseOperator):
def __init__(
self,
*,
schema: str,
table: str,
s3_bucket: str,
s3_key: str,
schema: str | None = None,
redshift_conn_id: str = "redshift_default",
aws_conn_id: str | None = "aws_default",
verify: bool | str | None = None,
@@ -160,7 +160,7 @@ def execute(self, context: Context) -> None:
credentials_block = build_credentials_block(credentials)

copy_options = "\n\t\t\t".join(self.copy_options)
destination = f"{self.schema}.{self.table}"
destination = f"{self.schema}.{self.table}" if self.schema else self.table
copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination

copy_statement = self._build_copy_query(
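Symmetrically, ``S3ToRedshiftOperator`` can now copy into a temporary table by omitting ``schema``. A sketch under the same assumptions; names and the session id are placeholders:

```python
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

# With no `schema`, the COPY destination is the bare table name, which is how a
# temporary table in the reused session is addressed.
copy_into_tmp_table = S3ToRedshiftOperator(
    task_id="copy_into_tmp_table",
    table="tmp_orders",
    s3_bucket="my-bucket",            # placeholder bucket
    s3_key="incoming/orders.csv",     # placeholder key
    redshift_data_api_kwargs={        # assumed passthrough to RedshiftDataHook.execute_query
        "database": "dev",            # placeholder database name
        # Placeholder UUID4: reuse the id of the session that created the table.
        "session_id": "00000000-0000-4000-8000-000000000000",
    },
)
```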
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/utils/openlineage.py
@@ -86,7 +86,9 @@ def get_facets_from_redshift_table(
]
)
else:
statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs)
statement_id = redshift_hook.execute_query(
sql=sql, poll_interval=1, **redshift_data_api_kwargs
).statement_id
response = redshift_hook.conn.get_statement_result(Id=statement_id)

table_schema = SchemaDatasetFacet(