Deprecate RedshiftDataOperatorAsync, RedshiftClusterSensorAsync, S3KeySensorAsync, BatchOperatorAsync, SageMakerProcessingOperatorAsync, SageMakerTransformOperatorAsync #1455

Merged Jan 29, 2024, with 21 commits. Changes shown are from all commits.
Commits:
cd2a71f feat(providers/amazon): deprecate BatchOperatorAsync (Lee-W, Jan 23, 2024)
1cc9a7c feat(providers/amazon): deprecate SageMakerProcessingOperatorAsync (Lee-W, Jan 23, 2024)
31a9449 feat(providers/amazon): deprecate SageMakerTransformOperatorAsync (Lee-W, Jan 23, 2024)
534e172 feat(providers/amazon): deprecate S3KeySensorAsync (Lee-W, Jan 23, 2024)
a086057 feat(providers/amazon): deprecate RedshiftClusterSensorAsync (Lee-W, Jan 23, 2024)
a15b96e build(setup.cfg): upgrade min amazon provider version to 8.17.0rc1 (Lee-W, Jan 23, 2024)
bb61993 test: fix sagamaker test cases (Lee-W, Jan 24, 2024)
19576cd feat(providers/amazon): deprecate RedshiftDataOperatorAsync (Lee-W, Jan 24, 2024)
3c9313c fix(amazon): remove unnecessay poll_interval in RedshiftDataOperatorA… (Lee-W, Jan 25, 2024)
ae03f35 feat(amazon): change how poll_interval is loaded in redshift cluster … (Lee-W, Jan 25, 2024)
4520e8c build: pin amazon provider version to >= 8.17.0 (Lee-W, Jan 29, 2024)
807cd55 Update astronomer/providers/amazon/aws/hooks/sagemaker.py (Lee-W, Jan 29, 2024)
13ea59b Update astronomer/providers/amazon/aws/hooks/sagemaker.py (Lee-W, Jan 29, 2024)
af7a2eb Update astronomer/providers/amazon/aws/operators/sagemaker.py (Lee-W, Jan 29, 2024)
7b1e18c Update astronomer/providers/amazon/aws/sensors/redshift_cluster.py (Lee-W, Jan 29, 2024)
8fe4e0f Update astronomer/providers/amazon/aws/operators/sagemaker.py (Lee-W, Jan 29, 2024)
996fdef Update astronomer/providers/amazon/aws/operators/batch.py (Lee-W, Jan 29, 2024)
0c4590b Update astronomer/providers/amazon/aws/operators/redshift_data.py (Lee-W, Jan 29, 2024)
772bb36 Update astronomer/providers/amazon/aws/operators/sagemaker.py (Lee-W, Jan 29, 2024)
52b1caf Update astronomer/providers/amazon/aws/sensors/s3.py (Lee-W, Jan 29, 2024)
0fe7821 Update astronomer/providers/amazon/aws/operators/sagemaker.py (Lee-W, Jan 29, 2024)
32 changes: 23 additions & 9 deletions astronomer/providers/amazon/aws/hooks/batch_client.py
@@ -1,6 +1,9 @@
from __future__ import annotations

import asyncio
import warnings
from random import sample
from typing import Any, Dict, List, Optional, Union
from typing import Any

import botocore.exceptions
from airflow.exceptions import AirflowException
@@ -13,6 +16,9 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync):
"""
Async client for AWS Batch services.

This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook` instead

:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
:param status_retries: number of HTTP retries to get job status, 10;
@@ -41,12 +47,20 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync):
- `Exponential Backoff And Jitter <https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/>`_
"""

def __init__(self, job_id: Optional[str], waiters: Any = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
self.job_id = job_id
self.waiters = waiters

async def monitor_job(self) -> Union[Dict[str, str], None]:
async def monitor_job(self) -> dict[str, str] | None:
"""
Monitor an AWS Batch job
monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout
@@ -92,7 +106,7 @@ async def check_job_success(self, job_id: str) -> bool:  # type: ignore[override]
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")

@staticmethod
async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[override]
async def delay(delay: int | (float | None) = None) -> None: # type: ignore[override]
"""
Pause execution for ``delay`` seconds.

@@ -116,7 +130,7 @@ async def delay(delay: Union[int, float, None] = None) -> None:  # type: ignore[override]
delay = BatchClientHookAsync.add_jitter(delay)
await asyncio.sleep(delay)

async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: # type: ignore[override]
async def wait_for_job(self, job_id: str, delay: int | (float | None) = None) -> None: # type: ignore[override]
"""
Wait for Batch job to complete

@@ -131,7 +145,7 @@ async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None:  # type: ignore[override]
self.log.info("AWS Batch job (%s) has completed", job_id)

async def poll_for_job_complete( # type: ignore[override]
self, job_id: str, delay: Union[int, float, None] = None
self, job_id: str, delay: int | (float | None) = None
) -> None:
"""
Poll for job completion. The status that indicates job completion
@@ -150,7 +164,7 @@ async def poll_for_job_complete(  # type: ignore[override]
await self.poll_job_status(job_id, complete_status)

async def poll_for_job_running( # type: ignore[override]
self, job_id: str, delay: Union[int, float, None] = None
self, job_id: str, delay: int | (float | None) = None
) -> None:
"""
Poll for job running. The status that indicates a job is running or
@@ -172,7 +186,7 @@ async def poll_for_job_running(  # type: ignore[override]
running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
await self.poll_job_status(job_id, running_status)

async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ignore[override]
async def get_job_description(self, job_id: str) -> dict[str, str]: # type: ignore[override]
"""
Get job description (using status_retries).

@@ -210,7 +224,7 @@ async def get_job_description(self, job_id: str) -> Dict[str, str]:  # type: ignore[override]
)
await self.delay(pause)

async def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: # type: ignore[override]
async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: # type: ignore[override]
"""
Poll for job status using an exponential back-off strategy (with max_retries).
The Batch job status polled are:
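The constructor change above is the pattern this PR applies to every deprecated class: emit a DeprecationWarning before delegating to the superclass, passing stacklevel=2 so the warning is attributed to the caller's line rather than to the shim's own __init__. A minimal, self-contained sketch of the same pattern (the class names here are illustrative, not from the PR):

import warnings
from typing import Any


class ReplacementHook:
    """Stand-in for the upstream class users should migrate to."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        pass


class LegacyHookAsync(ReplacementHook):
    """Deprecated shim: warns at construction, then behaves like its parent."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        warnings.warn(
            "LegacyHookAsync is deprecated and will be removed in 2.0.0. "
            "Please use ReplacementHook instead.",
            DeprecationWarning,
            stacklevel=2,  # report the caller's line, not this frame
        )
        super().__init__(*args, **kwargs)


hook = LegacyHookAsync()  # emits the DeprecationWarning at the call site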
13 changes: 13 additions & 0 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Iterable

import botocore.exceptions
@@ -18,6 +19,9 @@ class RedshiftDataHook(AwsBaseHook):
RedshiftDataHook inherits from AwsBaseHook to connect with AWS redshift
by using boto3 client_type as redshift-data we can interact with redshift cluster database and execute the query

This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook` instead

:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -34,6 +38,15 @@ class RedshiftDataHook(AwsBaseHook):
"""

def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`"
),
DeprecationWarning,
stacklevel=2,
)

aws_connection_type: str = "redshift-data"
try:
# for apache-airflow-providers-amazon>=3.0.0
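The docstring above points migrators at the upstream hook of the same name. A hedged sketch of the replacement call path, assuming the upstream RedshiftDataHook from apache-airflow-providers-amazon >= 8.17.0 (the connection id, database, cluster, and SQL are placeholders):

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")  # placeholder connection
# The hook wraps a boto3 "redshift-data" client; execute_statement is the
# underlying Redshift Data API call.
response = hook.conn.execute_statement(
    Database="dev",                  # placeholder
    ClusterIdentifier="my-cluster",  # placeholder
    Sql="SELECT 1;",
)
print(response["Id"])  # statement id, usable with describe_statement for polling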
15 changes: 14 additions & 1 deletion astronomer/providers/amazon/aws/hooks/s3.py
@@ -4,6 +4,7 @@
import fnmatch
import os
import re
import warnings
from datetime import datetime
from functools import wraps
from inspect import signature
@@ -43,12 +44,24 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:


class S3HookAsync(AwsBaseHookAsync):
"""Interact with AWS S3, using the aiobotocore library."""
"""Interact with AWS S3, using the aiobotocore library.

This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.s3.S3Hook` instead
"""

conn_type = "s3"
hook_name = "S3"

def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.s3.S3Hook`"
),
DeprecationWarning,
stacklevel=2,
)
kwargs["client_type"] = "s3"
kwargs["resource_type"] = "s3"
super().__init__(*args, **kwargs)
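One of the commits in this PR fixes test cases to account for these warnings; a hedged sketch of how a deprecation like this is typically asserted with pytest (the test name and match pattern are illustrative, not taken from the PR's test suite):

import pytest

from astronomer.providers.amazon.aws.hooks.s3 import S3HookAsync


def test_s3_hook_async_warns_on_construction():
    # pytest.warns fails the test unless a matching DeprecationWarning is raised.
    with pytest.warns(DeprecationWarning, match="will be removed in 2.0.0"):
        S3HookAsync(aws_conn_id="aws_default")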
42 changes: 28 additions & 14 deletions astronomer/providers/amazon/aws/hooks/sagemaker.py
@@ -1,5 +1,8 @@
from __future__ import annotations

import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import warnings
from typing import Any, AsyncGenerator

from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState,
@@ -21,56 +24,67 @@ class SageMakerHookAsync(AwsBaseHookAsync):

Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHookAsync.

This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook` instead
"""

NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")

def __init__(self, *args: Any, **kwargs: Any):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook`"
),
DeprecationWarning,
stacklevel=2,
)
kwargs["client_type"] = "sagemaker"
super().__init__(*args, **kwargs)
self.s3_hook = S3HookAsync(aws_conn_id=self.aws_conn_id)
self.logs_hook_async = AwsLogsHookAsync(aws_conn_id=self.aws_conn_id)

async def describe_transform_job_async(self, job_name: str) -> Dict[str, Any]:
async def describe_transform_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the transform job info associated with the name

:param job_name: the name of the transform job
"""
async with await self.get_client_async() as client:
response: Dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name)
response: dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name)
return response

async def describe_processing_job_async(self, job_name: str) -> Dict[str, Any]:
async def describe_processing_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the processing job info associated with the name

:param job_name: the name of the processing job
"""
async with await self.get_client_async() as client:
response: Dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
response: dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
return response

async def describe_training_job_async(self, job_name: str) -> Dict[str, Any]:
async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the training job info associated with the name

:param job_name: the name of the training job
"""
async with await self.get_client_async() as client:
response: Dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
return response

async def describe_training_job_with_log(
self,
job_name: str,
positions: Dict[str, Any],
stream_names: List[str],
positions: dict[str, Any],
stream_names: list[str],
instance_count: int,
state: int,
last_description: Dict[str, Any],
last_description: dict[str, Any],
last_describe_job_call: float,
) -> Tuple[int, Dict[str, Any], float]:
) -> tuple[int, dict[str, Any], float]:
"""
Return the training job info associated with job_name and print CloudWatch logs

@@ -127,8 +141,8 @@ async def describe_training_job_with_log(
return state, last_description, last_describe_job_call

async def get_multi_stream(
self, log_group: str, streams: List[str], positions: Dict[str, Any]
) -> AsyncGenerator[Any, Tuple[int, Optional[Any]]]:
self, log_group: str, streams: list[str], positions: dict[str, Any]
) -> AsyncGenerator[Any, tuple[int, Any | None]]:
"""
Iterate over the available events coming from a set of log streams in a single log group
interleaving the events from each stream so they're yielded in timestamp order.
@@ -140,7 +154,7 @@ async def get_multi_stream(
read from each stream.
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
events: list[Optional[Any]] = []
events: list[Any | None] = []

event_iters = [
self.logs_hook_async.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
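The advertised replacement here is the synchronous upstream SageMakerHook, whose blocking describe_* methods mirror the awaited ones above. A hedged migration sketch (the connection id and job name are placeholders):

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")  # placeholder connection
# Blocking counterpart of describe_transform_job_async in the shim above.
response = hook.describe_transform_job("my-transform-job")  # placeholder name
print(response["TransformJobStatus"])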
102 changes: 18 additions & 84 deletions astronomer/providers/amazon/aws/operators/batch.py
@@ -7,95 +7,29 @@
- `Batch <http://boto3.readthedocs.io/en/latest/reference/services/batch.html>`_
- `Welcome <https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html>`_
"""
from typing import Any, Dict
from __future__ import annotations

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.batch import BatchOperator
import warnings
from typing import Any

from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.amazon.aws.operators.batch import BatchOperator


class BatchOperatorAsync(BatchOperator):
"""
Execute a job asynchronously on AWS Batch

.. see also::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BatchOperator`

:param job_name: the name for the job that will run on AWS Batch (templated)
:param job_definition: the job definition name on AWS Batch
:param job_queue: the queue name on AWS Batch
:param overrides: Removed in apache-airflow-providers-amazon release 8.0.0, use container_overrides instead with the
same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param array_properties: the `arrayProperties` parameter for boto3
:param parameters: the `parameters` for boto3 (templated)
:param job_id: the job ID, usually unknown (None) until the
submit_job operation gets the jobId defined by AWS Batch
:param waiters: an :py:class:`.BatchWaiters` object (see note below);
if None, polling is used with max_retries and status_retries.
:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
:param status_retries: number of HTTP retries to get job status, 10;
polling is only used when waiters is None
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used.
:param region_name: region name to use in AWS Hook.
Override the region_name in connection (if provided)
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted

.. note::
Any custom waiters must return a waiter for these calls:

| ``waiter = waiters.get_waiter("JobExists")``
| ``waiter = waiters.get_waiter("JobRunning")``
| ``waiter = waiters.get_waiter("JobComplete")``
This class is deprecated.
Please use :class: `~airflow.providers.amazon.aws.operators.batch.BatchOperator`
and set `deferrable` param to `True` instead.
"""

def execute(self, context: Context) -> None:
"""
Airflow runs this method on the worker and defers using the trigger.
Submit the job and get the job_id using which we defer and poll in trigger
"""
self.submit_job(context)
if not self.job_id:
raise AirflowException("AWS Batch job - job_id was not found")
job = self.hook.get_job_description(self.job_id)
job_status = job.get("status")

if job_status == self.hook.SUCCESS_STATE:
self.log.info(f"{self.job_id} was completed successfully")
return

if job_status == self.hook.FAILURE_STATE:
raise AirflowException(f"{self.job_id} failed")

if job_status in self.hook.INTERMEDIATE_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=BatchOperatorTrigger(
job_id=self.job_id,
waiters=self.waiters,
max_retries=self.hook.max_retries,
aws_conn_id=self.hook.aws_conn_id,
region_name=self.hook.region_name,
),
method_name="execute_complete",
)

raise AirflowException(f"Unexpected status: {job_status}")

# Ignoring the override type check because the parent class specifies "context: Any" but specifying it as
# "context: Context" is accurate as it's more specific.
def execute_complete(self, context: Context, event: Dict[str, Any]) -> None: # type: ignore[override]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if "status" in event and event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated."
"Please use `airflow.providers.amazon.aws.operators.batch.BatchOperator`"
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(deferrable=True, **kwargs)
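Per the rewritten docstring, migration is a one-line change: use the upstream BatchOperator with deferrable=True, which hands polling to the triggerer much as the Async variant did. A hedged usage sketch (the DAG id, job names, and queue are placeholders):

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.batch import BatchOperator

with DAG(
    dag_id="batch_deferrable_example",  # placeholder
    start_date=datetime(2024, 1, 1),
    schedule=None,
):
    submit_job = BatchOperator(
        task_id="submit_batch_job",
        job_name="example-job",            # placeholder
        job_definition="example-job-def",  # placeholder
        job_queue="example-queue",         # placeholder
        deferrable=True,  # poll job status from the triggerer, freeing the worker slot
    )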