Deprecate RedshiftDataOperatorAsync, RedshiftClusterSensorAsync, S3KeySensorAsync, BatchOperatorAsync, SageMakerProcessingOperatorAsync, SageMakerTransformOperatorAsync (#1455)

* feat(providers/amazon): deprecate BatchOperatorAsync, SageMakerProcessingOperatorAsync, SageMakerTransformOperatorAsync, S3KeySensorAsync, RedshiftClusterSensorAsync, RedshiftDataOperatorAsync
* build(setup.cfg): upgrade min amazon provider version to 8.17.0

---------

Co-authored-by: Pankaj Koti <[email protected]>
Lee-W and pankajkoti authored Jan 29, 2024
1 parent 1e94306 commit 52063a4
Showing 20 changed files with 314 additions and 1,434 deletions.
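
For DAG authors, the user-facing change is that each deprecated astronomer class now maps to its upstream amazon-provider counterpart run in deferrable mode, with the minimum provider version bumped to 8.17.0 in setup.cfg. A minimal migration sketch, not part of this diff: task ids and job values are placeholders, and the upstream `deferrable` parameter is assumed to be available in apache-airflow-providers-amazon>=8.17.0.

from airflow.providers.amazon.aws.operators.batch import BatchOperator

# Before (deprecated): BatchOperatorAsync(task_id=..., job_name=..., ...)
# After: the upstream operator with deferrable=True
submit_batch_job = BatchOperator(
    task_id="submit_batch_job",        # placeholder task id
    job_name="example-job",            # placeholder
    job_definition="example-job-def",  # placeholder
    job_queue="example-queue",         # placeholder
    deferrable=True,                   # replaces BatchOperatorAsync
)

The same substitution applies to the other five classes deprecated here, using the upstream S3KeySensor, RedshiftClusterSensor, RedshiftDataOperator, SageMakerProcessingOperator, and SageMakerTransformOperator.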
32 changes: 23 additions & 9 deletions astronomer/providers/amazon/aws/hooks/batch_client.py
@@ -1,6 +1,9 @@
from __future__ import annotations

import asyncio
import warnings
from random import sample
from typing import Any, Dict, List, Optional, Union
from typing import Any

import botocore.exceptions
from airflow.exceptions import AirflowException
@@ -13,6 +16,9 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync):
"""
Async client for AWS Batch services.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook` instead
:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
:param status_retries: number of HTTP retries to get job status, 10;
@@ -41,12 +47,20 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync):
- `Exponential Backoff And Jitter <https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/>`_
"""

def __init__(self, job_id: Optional[str], waiters: Any = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
self.job_id = job_id
self.waiters = waiters

async def monitor_job(self) -> Union[Dict[str, str], None]:
async def monitor_job(self) -> dict[str, str] | None:
"""
Monitor an AWS Batch job
monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout
@@ -92,7 +106,7 @@ async def check_job_success(self, job_id: str) -> bool: # type: ignore[override
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")

@staticmethod
async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[override]
async def delay(delay: int | (float | None) = None) -> None: # type: ignore[override]
"""
Pause execution for ``delay`` seconds.
@@ -116,7 +130,7 @@ async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[
delay = BatchClientHookAsync.add_jitter(delay)
await asyncio.sleep(delay)

async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: # type: ignore[override]
async def wait_for_job(self, job_id: str, delay: int | (float | None) = None) -> None: # type: ignore[override]
"""
Wait for Batch job to complete
@@ -131,7 +145,7 @@ async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None)
self.log.info("AWS Batch job (%s) has completed", job_id)

async def poll_for_job_complete( # type: ignore[override]
self, job_id: str, delay: Union[int, float, None] = None
self, job_id: str, delay: int | (float | None) = None
) -> None:
"""
Poll for job completion. The status that indicates job completion
@@ -150,7 +164,7 @@ async def poll_for_job_complete( # type: ignore[override]
await self.poll_job_status(job_id, complete_status)

async def poll_for_job_running( # type: ignore[override]
self, job_id: str, delay: Union[int, float, None] = None
self, job_id: str, delay: int | (float | None) = None
) -> None:
"""
Poll for job running. The status that indicates a job is running or
@@ -172,7 +186,7 @@ async def poll_for_job_running( # type: ignore[override]
running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
await self.poll_job_status(job_id, running_status)

async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ignore[override]
async def get_job_description(self, job_id: str) -> dict[str, str]: # type: ignore[override]
"""
Get job description (using status_retries).
@@ -210,7 +224,7 @@ async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ign
)
await self.delay(pause)

async def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: # type: ignore[override]
async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: # type: ignore[override]
"""
Poll for job status using an exponential back-off strategy (with max_retries).
The Batch job status polled are:
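
The hook now warns at construction time. While DAGs are being migrated, a deployment may prefer to narrow the filter to this message rather than ignore every DeprecationWarning; the filter below is an illustrative assumption, not part of this commit.

import warnings

# Silence only the astronomer-providers deprecation message emitted above.
warnings.filterwarnings(
    "ignore",
    message=r"This module is deprecated and will be removed in 2\.0\.0\.",
    category=DeprecationWarning,
)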
13 changes: 13 additions & 0 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Iterable

import botocore.exceptions
@@ -18,6 +19,9 @@ class RedshiftDataHook(AwsBaseHook):
RedshiftDataHook inherits from AwsBaseHook to connect with AWS redshift
by using boto3 client_type as redshift-data we can interact with redshift cluster database and execute the query
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook` instead
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -34,6 +38,15 @@ class RedshiftDataHook(AwsBaseHook):
"""

def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`"
),
DeprecationWarning,
stacklevel=2,
)

aws_connection_type: str = "redshift-data"
try:
# for apache-airflow-providers-amazon>=3.0.0
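
Migration here is essentially an import swap, since a hook of the same name exists upstream. A hedged sketch, assuming the upstream hook follows the usual AwsBaseHook client access; the connection id is the provider default.

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")
client = hook.conn  # boto3 "redshift-data" client exposed by the upstream hook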
15 changes: 14 additions & 1 deletion astronomer/providers/amazon/aws/hooks/s3.py
@@ -4,6 +4,7 @@
import fnmatch
import os
import re
import warnings
from datetime import datetime
from functools import wraps
from inspect import signature
@@ -43,12 +44,24 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:


class S3HookAsync(AwsBaseHookAsync):
"""Interact with AWS S3, using the aiobotocore library."""
"""Interact with AWS S3, using the aiobotocore library.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.s3.S3Hook` instead
"""

conn_type = "s3"
hook_name = "S3"

def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.s3.S3Hook`"
),
DeprecationWarning,
stacklevel=2,
)
kwargs["client_type"] = "s3"
kwargs["resource_type"] = "s3"
super().__init__(*args, **kwargs)
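
This hook backs S3KeySensorAsync, which this commit also deprecates; the upstream sensor in deferrable mode covers the async path. Sketch only: bucket and key values are placeholders, and `deferrable` is assumed available in the pinned provider version.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

wait_for_key = S3KeySensor(
    task_id="wait_for_key",
    bucket_name="example-bucket",  # placeholder
    bucket_key="path/to/object",   # placeholder
    deferrable=True,               # replaces S3KeySensorAsync
)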
42 changes: 28 additions & 14 deletions astronomer/providers/amazon/aws/hooks/sagemaker.py
@@ -1,5 +1,8 @@
from __future__ import annotations

import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import warnings
from typing import Any, AsyncGenerator

from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState,
@@ -21,56 +24,67 @@ class SageMakerHookAsync(AwsBaseHookAsync):
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHookAsync.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook` instead
"""

NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")

def __init__(self, *args: Any, **kwargs: Any):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook`"
),
DeprecationWarning,
stacklevel=2,
)
kwargs["client_type"] = "sagemaker"
super().__init__(*args, **kwargs)
self.s3_hook = S3HookAsync(aws_conn_id=self.aws_conn_id)
self.logs_hook_async = AwsLogsHookAsync(aws_conn_id=self.aws_conn_id)

async def describe_transform_job_async(self, job_name: str) -> Dict[str, Any]:
async def describe_transform_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the transform job info associated with the name
:param job_name: the name of the transform job
"""
async with await self.get_client_async() as client:
response: Dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name)
response: dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name)
return response

async def describe_processing_job_async(self, job_name: str) -> Dict[str, Any]:
async def describe_processing_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the processing job info associated with the name
:param job_name: the name of the processing job
"""
async with await self.get_client_async() as client:
response: Dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
response: dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
return response

async def describe_training_job_async(self, job_name: str) -> Dict[str, Any]:
async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the training job info associated with the name
:param job_name: the name of the training job
"""
async with await self.get_client_async() as client:
response: Dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
return response

async def describe_training_job_with_log(
self,
job_name: str,
positions: Dict[str, Any],
stream_names: List[str],
positions: dict[str, Any],
stream_names: list[str],
instance_count: int,
state: int,
last_description: Dict[str, Any],
last_description: dict[str, Any],
last_describe_job_call: float,
) -> Tuple[int, Dict[str, Any], float]:
) -> tuple[int, dict[str, Any], float]:
"""
Return the training job info associated with job_name and print CloudWatch logs
@@ -127,8 +141,8 @@ async def describe_training_job_with_log(
return state, last_description, last_describe_job_call

async def get_multi_stream(
self, log_group: str, streams: List[str], positions: Dict[str, Any]
) -> AsyncGenerator[Any, Tuple[int, Optional[Any]]]:
self, log_group: str, streams: list[str], positions: dict[str, Any]
) -> AsyncGenerator[Any, tuple[int, Any | None]]:
"""
Iterate over the available events coming from a set of log streams in a single log group
interleaving the events from each stream so they're yielded in timestamp order.
@@ -140,7 +154,7 @@ async def get_multi_stream(
read from each stream.
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
events: list[Optional[Any]] = []
events: list[Any | None] = []

event_iters = [
self.logs_hook_async.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
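
SageMakerHookAsync serves the two SageMaker operators deprecated by this change; both have upstream deferrable equivalents. A sketch, with empty config dicts standing in for real CreateProcessingJob / CreateTransformJob payloads.

from airflow.providers.amazon.aws.operators.sagemaker import (
    SageMakerProcessingOperator,
    SageMakerTransformOperator,
)

process = SageMakerProcessingOperator(
    task_id="process",
    config={},        # placeholder CreateProcessingJob request
    deferrable=True,  # replaces SageMakerProcessingOperatorAsync
)

transform = SageMakerTransformOperator(
    task_id="transform",
    config={},        # placeholder CreateTransformJob request
    deferrable=True,  # replaces SageMakerTransformOperatorAsync
)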
102 changes: 18 additions & 84 deletions astronomer/providers/amazon/aws/operators/batch.py
@@ -7,95 +7,29 @@
- `Batch <http://boto3.readthedocs.io/en/latest/reference/services/batch.html>`_
- `Welcome <https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html>`_
"""
from typing import Any, Dict
from __future__ import annotations

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.batch import BatchOperator
import warnings
from typing import Any

from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.amazon.aws.operators.batch import BatchOperator


class BatchOperatorAsync(BatchOperator):
"""
Execute a job asynchronously on AWS Batch
.. see also::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BatchOperator`
:param job_name: the name for the job that will run on AWS Batch (templated)
:param job_definition: the job definition name on AWS Batch
:param job_queue: the queue name on AWS Batch
:param overrides: Removed in apache-airflow-providers-amazon release 8.0.0, use container_overrides instead with the
same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param array_properties: the `arrayProperties` parameter for boto3
:param parameters: the `parameters` for boto3 (templated)
:param job_id: the job ID, usually unknown (None) until the
submit_job operation gets the jobId defined by AWS Batch
:param waiters: an :py:class:`.BatchWaiters` object (see note below);
if None, polling is used with max_retries and status_retries.
:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
:param status_retries: number of HTTP retries to get job status, 10;
polling is only used when waiters is None
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used.
:param region_name: region name to use in AWS Hook.
Override the region_name in connection (if provided)
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
.. note::
Any custom waiters must return a waiter for these calls:
| ``waiter = waiters.get_waiter("JobExists")``
| ``waiter = waiters.get_waiter("JobRunning")``
| ``waiter = waiters.get_waiter("JobComplete")``
This class is deprecated.
Please use :class: `~airflow.providers.amazon.aws.operators.batch.BatchOperator`
and set `deferrable` param to `True` instead.
"""

def execute(self, context: Context) -> None:
"""
Airflow runs this method on the worker and defers using the trigger.
Submit the job and get the job_id using which we defer and poll in trigger
"""
self.submit_job(context)
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")
job = self.hook.get_job_description(self.job_id)
job_status = job.get("status")

if job_status == self.hook.SUCCESS_STATE:
self.log.info(f"{self.job_id} was completed successfully")
return

if job_status == self.hook.FAILURE_STATE:
raise AirflowException(f"{self.job_id} failed")

if job_status in self.hook.INTERMEDIATE_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=BatchOperatorTrigger(
job_id=self.job_id,
waiters=self.waiters,
max_retries=self.hook.max_retries,
aws_conn_id=self.hook.aws_conn_id,
region_name=self.hook.region_name,
),
method_name="execute_complete",
)

raise AirflowException(f"Unexpected status: {job_status}")

# Ignoring the override type check because the parent class specifies "context: Any" but specifying it as
# "context: Context" is accurate as it's more specific.
def execute_complete(self, context: Context, event: Dict[str, Any]) -> None: # type: ignore[override]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if "status" in event and event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated."
"Please use `airflow.providers.amazon.aws.operators.batch.BatchOperator`"
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(deferrable=True, **kwargs)
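
After this rewrite the deprecated operator only warns and forwards to the upstream implementation with deferrable=True, so existing DAG code keeps running deferred. An illustrative check, assuming pytest is available and that the upstream operator stores the flag as an attribute; argument values are placeholders.

import pytest

from astronomer.providers.amazon.aws.operators.batch import BatchOperatorAsync


def test_batch_operator_async_warns_and_defers():
    with pytest.warns(DeprecationWarning, match="deprecated"):
        op = BatchOperatorAsync(
            task_id="submit_batch_job",
            job_name="example-job",            # placeholder
            job_definition="example-job-def",  # placeholder
            job_queue="example-queue",         # placeholder
        )
    assert op.deferrable is True  # forwarded by the delegating __init__ above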