diff --git a/astronomer/providers/amazon/aws/hooks/batch_client.py b/astronomer/providers/amazon/aws/hooks/batch_client.py
index 8e7e37b31..a2ae392b5 100644
--- a/astronomer/providers/amazon/aws/hooks/batch_client.py
+++ b/astronomer/providers/amazon/aws/hooks/batch_client.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import asyncio
+import warnings
from random import sample
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
import botocore.exceptions
from airflow.exceptions import AirflowException
@@ -13,6 +16,9 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync):
"""
Async client for AWS Batch services.
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook` instead
+
:param max_retries: exponential back-off retries, 4200 = 48 hours;
polling is only used when waiters is None
:param status_retries: number of HTTP retries to get job status, 10;
@@ -41,12 +47,20 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync):
- `Exponential Backoff And Jitter `_
"""
- def __init__(self, job_id: Optional[str], waiters: Any = None, *args: Any, **kwargs: Any) -> None:
+ def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__(*args, **kwargs)
self.job_id = job_id
self.waiters = waiters
- async def monitor_job(self) -> Union[Dict[str, str], None]:
+ async def monitor_job(self) -> dict[str, str] | None:
"""
Monitor an AWS Batch job
monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout
@@ -92,7 +106,7 @@ async def check_job_success(self, job_id: str) -> bool: # type: ignore[override
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")
@staticmethod
- async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[override]
+ async def delay(delay: int | (float | None) = None) -> None: # type: ignore[override]
"""
Pause execution for ``delay`` seconds.
@@ -116,7 +130,7 @@ async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[
delay = BatchClientHookAsync.add_jitter(delay)
await asyncio.sleep(delay)
- async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: # type: ignore[override]
+ async def wait_for_job(self, job_id: str, delay: int | (float | None) = None) -> None: # type: ignore[override]
"""
Wait for Batch job to complete
@@ -131,7 +145,7 @@ async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None)
self.log.info("AWS Batch job (%s) has completed", job_id)
async def poll_for_job_complete( # type: ignore[override]
- self, job_id: str, delay: Union[int, float, None] = None
+ self, job_id: str, delay: int | (float | None) = None
) -> None:
"""
Poll for job completion. The status that indicates job completion
@@ -150,7 +164,7 @@ async def poll_for_job_complete( # type: ignore[override]
await self.poll_job_status(job_id, complete_status)
async def poll_for_job_running( # type: ignore[override]
- self, job_id: str, delay: Union[int, float, None] = None
+ self, job_id: str, delay: int | (float | None) = None
) -> None:
"""
Poll for job running. The status that indicates a job is running or
@@ -172,7 +186,7 @@ async def poll_for_job_running( # type: ignore[override]
running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
await self.poll_job_status(job_id, running_status)
- async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ignore[override]
+ async def get_job_description(self, job_id: str) -> dict[str, str]: # type: ignore[override]
"""
Get job description (using status_retries).
@@ -210,7 +224,7 @@ async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ign
)
await self.delay(pause)
- async def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: # type: ignore[override]
+ async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: # type: ignore[override]
"""
Poll for job status using an exponential back-off strategy (with max_retries).
The Batch job status polled are:
diff --git a/astronomer/providers/amazon/aws/hooks/redshift_data.py b/astronomer/providers/amazon/aws/hooks/redshift_data.py
index d7b880c1f..0ae56bba1 100644
--- a/astronomer/providers/amazon/aws/hooks/redshift_data.py
+++ b/astronomer/providers/amazon/aws/hooks/redshift_data.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
+import warnings
from typing import Any, Iterable
import botocore.exceptions
@@ -18,6 +19,9 @@ class RedshiftDataHook(AwsBaseHook):
RedshiftDataHook inherits from AwsBaseHook to connect with AWS redshift
by using boto3 client_type as redshift-data we can interact with redshift cluster database and execute the query
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook` instead
+
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
@@ -34,6 +38,15 @@ class RedshiftDataHook(AwsBaseHook):
"""
def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
aws_connection_type: str = "redshift-data"
try:
# for apache-airflow-providers-amazon>=3.0.0
diff --git a/astronomer/providers/amazon/aws/hooks/s3.py b/astronomer/providers/amazon/aws/hooks/s3.py
index 1af72b360..57adb22aa 100644
--- a/astronomer/providers/amazon/aws/hooks/s3.py
+++ b/astronomer/providers/amazon/aws/hooks/s3.py
@@ -4,6 +4,7 @@
import fnmatch
import os
import re
+import warnings
from datetime import datetime
from functools import wraps
from inspect import signature
@@ -43,12 +44,24 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
class S3HookAsync(AwsBaseHookAsync):
- """Interact with AWS S3, using the aiobotocore library."""
+ """Interact with AWS S3, using the aiobotocore library.
+
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.hooks.s3.S3Hook` instead
+ """
conn_type = "s3"
hook_name = "S3"
def __init__(self, *args: Any, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.hooks.s3.S3Hook`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
kwargs["client_type"] = "s3"
kwargs["resource_type"] = "s3"
super().__init__(*args, **kwargs)
diff --git a/astronomer/providers/amazon/aws/hooks/sagemaker.py b/astronomer/providers/amazon/aws/hooks/sagemaker.py
index cb518b9c9..1396d250a 100644
--- a/astronomer/providers/amazon/aws/hooks/sagemaker.py
+++ b/astronomer/providers/amazon/aws/hooks/sagemaker.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import time
-from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
+import warnings
+from typing import Any, AsyncGenerator
from airflow.providers.amazon.aws.hooks.sagemaker import (
LogState,
@@ -21,56 +24,67 @@ class SageMakerHookAsync(AwsBaseHookAsync):
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHookAsync.
+
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook` instead
"""
NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped")
def __init__(self, *args: Any, **kwargs: Any):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
kwargs["client_type"] = "sagemaker"
super().__init__(*args, **kwargs)
self.s3_hook = S3HookAsync(aws_conn_id=self.aws_conn_id)
self.logs_hook_async = AwsLogsHookAsync(aws_conn_id=self.aws_conn_id)
- async def describe_transform_job_async(self, job_name: str) -> Dict[str, Any]:
+ async def describe_transform_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the transform job info associated with the name
:param job_name: the name of the transform job
"""
async with await self.get_client_async() as client:
- response: Dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name)
+ response: dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name)
return response
- async def describe_processing_job_async(self, job_name: str) -> Dict[str, Any]:
+ async def describe_processing_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the processing job info associated with the name
:param job_name: the name of the processing job
"""
async with await self.get_client_async() as client:
- response: Dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
+ response: dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name)
return response
- async def describe_training_job_async(self, job_name: str) -> Dict[str, Any]:
+ async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
"""
Return the training job info associated with the name
:param job_name: the name of the training job
"""
async with await self.get_client_async() as client:
- response: Dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
+ response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
return response
async def describe_training_job_with_log(
self,
job_name: str,
- positions: Dict[str, Any],
- stream_names: List[str],
+ positions: dict[str, Any],
+ stream_names: list[str],
instance_count: int,
state: int,
- last_description: Dict[str, Any],
+ last_description: dict[str, Any],
last_describe_job_call: float,
- ) -> Tuple[int, Dict[str, Any], float]:
+ ) -> tuple[int, dict[str, Any], float]:
"""
Return the training job info associated with job_name and print CloudWatch logs
@@ -127,8 +141,8 @@ async def describe_training_job_with_log(
return state, last_description, last_describe_job_call
async def get_multi_stream(
- self, log_group: str, streams: List[str], positions: Dict[str, Any]
- ) -> AsyncGenerator[Any, Tuple[int, Optional[Any]]]:
+ self, log_group: str, streams: list[str], positions: dict[str, Any]
+ ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
"""
Iterate over the available events coming from a set of log streams in a single log group
interleaving the events from each stream so they're yielded in timestamp order.
@@ -140,7 +154,7 @@ async def get_multi_stream(
read from each stream.
"""
positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
- events: list[Optional[Any]] = []
+ events: list[Any | None] = []
event_iters = [
self.logs_hook_async.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
diff --git a/astronomer/providers/amazon/aws/operators/batch.py b/astronomer/providers/amazon/aws/operators/batch.py
index 2ec421998..5370ab75e 100644
--- a/astronomer/providers/amazon/aws/operators/batch.py
+++ b/astronomer/providers/amazon/aws/operators/batch.py
@@ -7,95 +7,29 @@
- `Batch `_
- `Welcome `_
"""
-from typing import Any, Dict
+from __future__ import annotations
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.operators.batch import BatchOperator
+import warnings
+from typing import Any
-from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
-from astronomer.providers.utils.typing_compat import Context
+from airflow.providers.amazon.aws.operators.batch import BatchOperator
class BatchOperatorAsync(BatchOperator):
"""
- Execute a job asynchronously on AWS Batch
-
- .. see also::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:BatchOperator`
-
- :param job_name: the name for the job that will run on AWS Batch (templated)
- :param job_definition: the job definition name on AWS Batch
- :param job_queue: the queue name on AWS Batch
- :param overrides: Removed in apache-airflow-providers-amazon release 8.0.0, use container_overrides instead with the
- same value.
- :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
- :param array_properties: the `arrayProperties` parameter for boto3
- :param parameters: the `parameters` for boto3 (templated)
- :param job_id: the job ID, usually unknown (None) until the
- submit_job operation gets the jobId defined by AWS Batch
- :param waiters: an :py:class:`.BatchWaiters` object (see note below);
- if None, polling is used with max_retries and status_retries.
- :param max_retries: exponential back-off retries, 4200 = 48 hours;
- polling is only used when waiters is None
- :param status_retries: number of HTTP retries to get job status, 10;
- polling is only used when waiters is None
- :param aws_conn_id: connection id of AWS credentials / region name. If None,
- credential boto3 strategy will be used.
- :param region_name: region name to use in AWS Hook.
- Override the region_name in connection (if provided)
- :param tags: collection of tags to apply to the AWS Batch job submission
- if None, no tags are submitted
-
- .. note::
- Any custom waiters must return a waiter for these calls:
-
- | ``waiter = waiters.get_waiter("JobExists")``
- | ``waiter = waiters.get_waiter("JobRunning")``
- | ``waiter = waiters.get_waiter("JobComplete")``
+ This class is deprecated.
+ Please use :class: `~airflow.providers.amazon.aws.operators.batch.BatchOperator`
+ and set `deferrable` param to `True` instead.
"""
- def execute(self, context: Context) -> None:
- """
- Airflow runs this method on the worker and defers using the trigger.
- Submit the job and get the job_id using which we defer and poll in trigger
- """
- self.submit_job(context)
- if not self.job_id:
- raise AirflowException("AWS Batch job - job_id was not found")
- job = self.hook.get_job_description(self.job_id)
- job_status = job.get("status")
-
- if job_status == self.hook.SUCCESS_STATE:
- self.log.info(f"{self.job_id} was completed successfully")
- return
-
- if job_status == self.hook.FAILURE_STATE:
- raise AirflowException(f"{self.job_id} failed")
-
- if job_status in self.hook.INTERMEDIATE_STATES:
- self.defer(
- timeout=self.execution_timeout,
- trigger=BatchOperatorTrigger(
- job_id=self.job_id,
- waiters=self.waiters,
- max_retries=self.hook.max_retries,
- aws_conn_id=self.hook.aws_conn_id,
- region_name=self.hook.region_name,
- ),
- method_name="execute_complete",
- )
-
- raise AirflowException(f"Unexpected status: {job_status}")
-
- # Ignoring the override type check because the parent class specifies "context: Any" but specifying it as
- # "context: Context" is accurate as it's more specific.
- def execute_complete(self, context: Context, event: Dict[str, Any]) -> None: # type: ignore[override]
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if "status" in event and event["status"] == "error":
- raise AirflowException(event["message"])
- self.log.info(event["message"])
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated."
+ "Please use `airflow.providers.amazon.aws.operators.batch.BatchOperator`"
+ "and set `deferrable` param to `True` instead."
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(deferrable=True, **kwargs)
diff --git a/astronomer/providers/amazon/aws/operators/redshift_data.py b/astronomer/providers/amazon/aws/operators/redshift_data.py
index 6dd05d2fa..9b6ebcb5d 100644
--- a/astronomer/providers/amazon/aws/operators/redshift_data.py
+++ b/astronomer/providers/amazon/aws/operators/redshift_data.py
@@ -1,75 +1,27 @@
+import warnings
from typing import Any
-from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
-from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
-from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
-from astronomer.providers.utils.typing_compat import Context
-
class RedshiftDataOperatorAsync(RedshiftDataOperator):
"""
- Executes SQL Statements against an Amazon Redshift cluster.
- If there are multiple queries as part of the SQL, and one of them fails to reach a successful completion state,
- the operator returns the relevant error for the failed query.
-
- :param sql: the SQL code to be executed as a single string, or
- a list of str (sql statements), or a reference to a template file.
- Template references are recognized by str ending in '.sql'
- :param aws_conn_id: AWS connection ID
- :param parameters: (optional) the parameters to render the SQL query with.
- :param autocommit: if True, each command is automatically committed.
- (default value: False)
+ This class is deprecated.
+ Please use :class: `~airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`
+ and set `deferrable` param to `True` instead.
"""
def __init__(
self,
- *,
- poll_interval: int = 5,
**kwargs: Any,
) -> None:
- self.poll_interval = poll_interval
- super().__init__(**kwargs)
-
- def execute(self, context: Context) -> None:
- """
- Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and
- defers trigger to poll for the status for the queries executed.
- """
- redshift_data_hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id)
- query_ids, response = redshift_data_hook.execute_query(sql=self.sql, params=self.params)
- self.log.info("Query IDs %s", query_ids)
- if response.get("status") == "error":
- self.execute_complete(context, event=response)
- context["ti"].xcom_push(key="return_value", value=query_ids)
-
- if redshift_data_hook.queries_are_completed(query_ids, context):
- self.log.info("%s completed successfully.", self.task_id)
- return
-
- self.defer(
- timeout=self.execution_timeout,
- trigger=RedshiftDataTrigger(
- task_id=self.task_id,
- poll_interval=self.poll_interval,
- aws_conn_id=self.aws_conn_id,
- query_ids=query_ids,
+ warnings.warn(
+ (
+ "This module is deprecated."
+ "Please use `airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`"
+ "and set `deferrable` param to `True` instead."
),
- method_name="execute_complete",
+ DeprecationWarning,
+ stacklevel=2,
)
-
- def execute_complete(self, context: Context, event: Any = None) -> None:
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if event:
- if "status" in event and event["status"] == "error":
- msg = "context: {}, error message: {}".format(context, event["message"])
- raise AirflowException(msg)
- elif "status" in event and event["status"] == "success":
- self.log.info("%s completed successfully.", self.task_id)
- else:
- raise AirflowException("Did not receive valid event from the trigerrer")
+ super().__init__(deferrable=True, **kwargs)
diff --git a/astronomer/providers/amazon/aws/operators/sagemaker.py b/astronomer/providers/amazon/aws/operators/sagemaker.py
index 038afd8fa..ed8a7edf4 100644
--- a/astronomer/providers/amazon/aws/operators/sagemaker.py
+++ b/astronomer/providers/amazon/aws/operators/sagemaker.py
@@ -2,6 +2,7 @@
import json
import time
+import warnings
from typing import Any
from airflow.exceptions import AirflowException
@@ -17,7 +18,6 @@
from airflow.utils.json import AirflowJsonEncoder
from astronomer.providers.amazon.aws.triggers.sagemaker import (
- SagemakerProcessingTrigger,
SagemakerTrainingWithLogTrigger,
SagemakerTrigger,
)
@@ -31,216 +31,42 @@ def serialize(result: dict[str, Any]) -> str:
class SageMakerProcessingOperatorAsync(SageMakerProcessingOperator):
"""
- SageMakerProcessingOperatorAsync is used to analyze data and evaluate machine learning
- models on Amazon SageMaker. With SageMakerProcessingOperatorAsync, you can use a simplified, managed
- experience on SageMaker to run your data processing workloads, such as feature
- engineering, data validation, model evaluation, and model interpretation.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:SageMakerProcessingOperator`
-
- :param config: The configuration necessary to start a processing job (templated).
- For details of the configuration parameter see
- :ref:``SageMaker.Client.create_processing_job``
- :param aws_conn_id: The AWS connection ID to use.
- :param wait_for_completion: Even if wait is set to False, in async we will defer and
- the operation waits to check the status of the processing job.
- :param print_log: if the operator should print the cloudwatch log during processing
- :param check_interval: if wait is set to be true, this is the time interval
- in seconds which the operator will check the status of the processing job
- :param max_ingestion_time: The operation fails if the processing job
- doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
- the operation does not timeout.
- :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
- (default) and "fail".
+ This class is deprecated.
+ Please use :class: `~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`
+ and set `deferrable` param to `True` instead.
"""
- def execute(self, context: Context) -> dict[str, str] | None: # type: ignore[override]
- """
- Creates processing job via sync hook `create_processing_job` and pass the
- control to trigger and polls for the status of the processing job in async
- """
- self.preprocess_config()
- processing_job_name = self.config["ProcessingJobName"]
- try:
- if self.hook.count_processing_jobs_by_name(processing_job_name):
- raise AirflowException(
- f"A SageMaker processing job with name {processing_job_name} already exists."
- )
- except AttributeError: # pragma: no cover
- # For apache-airflow-providers-amazon<8.0.0
- if self.hook.find_processing_job_by_name(processing_job_name):
- raise AirflowException(
- f"A SageMaker processing job with name {processing_job_name} already exists."
- )
- self.log.info("Creating SageMaker processing job %s.", self.config["ProcessingJobName"])
- response = self.hook.create_processing_job(
- self.config,
- # we do not wait for completion here but we create the processing job
- # and poll for it in trigger
- wait_for_completion=False,
- )
- if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
- raise AirflowException(f"Sagemaker Processing Job creation failed: {response}")
-
- response = self.hook.describe_processing_job(processing_job_name)
- status = response["ProcessingJobStatus"]
- if status in self.hook.failed_states:
- raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
- elif status == "Completed":
- self.log.info(f"{self.task_id} completed successfully.")
- return {"Processing": serialize(response)}
-
- end_time: float | None = None
- if self.max_ingestion_time is not None:
- end_time = time.time() + self.max_ingestion_time
- self.defer(
- timeout=self.execution_timeout,
- trigger=SagemakerProcessingTrigger(
- poll_interval=self.check_interval,
- aws_conn_id=self.aws_conn_id,
- job_name=self.config["ProcessingJobName"],
- end_time=end_time,
+ def __init__(self, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated."
+ "Please use `airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`"
+ "and set `deferrable` param to `True` instead."
),
- method_name="execute_complete",
+ DeprecationWarning,
+ stacklevel=2,
)
-
- # for bypassing mypy missing return error
- return None # pragma: no cover
-
- def execute_complete(self, context: Context, event: Any = None) -> dict[str, Any]:
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if event and event["status"] == "success":
- self.log.info("%s completed successfully.", self.task_id)
- return {"Processing": serialize(event["message"])}
- if event and event["status"] == "error":
- raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+ super().__init__(deferrable=True, **kwargs)
class SageMakerTransformOperatorAsync(SageMakerTransformOperator):
"""
- SageMakerTransformOperatorAsync starts a transform job and polls for the status asynchronously.
- A transform job uses a trained model to get inferences on a dataset and saves these results to an Amazon
- S3 location that you specify.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:``howto/operator:SageMakerTransformOperator``
-
- :param config: The configuration necessary to start a transform job (templated).
-
- If you need to create a SageMaker transform job based on an existed SageMaker model::
-
- config = transform_config
-
- If you need to create both SageMaker model and SageMaker Transform job::
-
- config = {
- 'Model': model_config,
- 'Transform': transform_config
- }
-
- For details of the configuration parameter of transform_config see
- :ref:``SageMaker.Client.create_transform_job``
-
- For details of the configuration parameter of model_config, See:
- :ref:``SageMaker.Client.create_model``
-
- :param aws_conn_id: The AWS connection ID to use.
- :param check_interval: If wait is set to True, the time interval, in seconds,
- that this operation waits to check the status of the transform job.
- :param max_ingestion_time: The operation fails if the transform job doesn't finish
- within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout.
- :param check_if_job_exists: If set to true, then the operator will check whether a transform job
- already exists for the name in the config.
- :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
- (default) and "fail".
- This is only relevant if check_if_job_exists is True.
+ This class is deprecated.
+ Please use :class: `~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`
+ and set `deferrable` param to `True` instead.
"""
- def execute(self, context: Context) -> dict[str, Any] | None: # type: ignore[override]
- """
- Creates transform job via sync hook `create_transform_job` and pass the
- control to trigger and polls for the status of the transform job in async
- """
- self.preprocess_config()
- model_config = self.config.get("Model")
- transform_config = self.config.get("Transform", self.config)
- if self.check_if_job_exists: # pragma: no cover
- try:
- # for apache-airflow-providers-amazon<=7.2.1
- self._check_if_transform_job_exists() # type: ignore[attr-defined]
- except AttributeError:
- # for apache-airflow-providers-amazon>=7.3.0
- transform_config["TransformJobName"] = self._get_unique_job_name(
- transform_config["TransformJobName"],
- self.action_if_job_exists == "fail",
- self.hook.describe_transform_job,
- )
- if model_config:
- self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
- self.hook.create_model(model_config)
- self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"])
- response = self.hook.create_transform_job(
- transform_config,
- wait_for_completion=False,
- )
- if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
- raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
-
- response = self.hook.describe_transform_job(transform_config["TransformJobName"])
- status = response["TransformJobStatus"]
- if status in self.hook.failed_states:
- raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
-
- if status == "Completed":
- self.log.info(f"{self.task_id} completed successfully.")
- return {
- "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform": serialize(response),
- }
-
- end_time: float | None = None
- if self.max_ingestion_time is not None:
- end_time = time.time() + self.max_ingestion_time
- self.defer(
- timeout=self.execution_timeout,
- trigger=SagemakerTrigger(
- poke_interval=self.check_interval,
- end_time=end_time,
- aws_conn_id=self.aws_conn_id,
- job_name=transform_config["TransformJobName"],
- job_type="Transform",
- response_key="TransformJobStatus",
+ def __init__(self, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated."
+ "Please use `airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`"
+ "and set `deferrable` param to `True` instead."
),
- method_name="execute_complete",
+ DeprecationWarning,
+ stacklevel=2,
)
-
- # for bypassing mypy missing return error
- return None # pragma: no cover
-
- def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]: # type: ignore[override]
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if event and event["status"] == "success":
- self.log.info("%s completed successfully.", self.task_id)
- transform_config = self.config.get("Transform", self.config)
- return {
- "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
- "Transform": serialize(event["message"]),
- }
- if event and event["status"] == "error":
- raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+ super().__init__(deferrable=True, **kwargs)
class SageMakerTrainingOperatorAsync(SageMakerTrainingOperator):
diff --git a/astronomer/providers/amazon/aws/sensors/redshift_cluster.py b/astronomer/providers/amazon/aws/sensors/redshift_cluster.py
index 48f138b71..6f919bf22 100644
--- a/astronomer/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/astronomer/providers/amazon/aws/sensors/redshift_cluster.py
@@ -1,71 +1,36 @@
import warnings
-from datetime import timedelta
-from typing import Any, Dict, Optional
+from typing import Any
from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
- RedshiftClusterSensorTrigger,
-)
-from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
-from astronomer.providers.utils.typing_compat import Context
-
class RedshiftClusterSensorAsync(RedshiftClusterSensor):
"""
- Waits for a Redshift cluster to reach a specific status.
-
- :param cluster_identifier: The identifier for the cluster being pinged.\
- :param target_status: The cluster status desired.
+ This class is deprecated.
+ Please use :class: `~airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor`
+ and set `deferrable` param to `True` instead.
"""
def __init__(
self,
- *,
- poll_interval: float = 5,
**kwargs: Any,
):
# TODO: Remove once deprecated
- if poll_interval:
- self.poke_interval = poll_interval
+ if kwargs.get("poll_interval"):
+ kwargs["poke_interval"] = kwargs["poll_interval"]
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)
- super().__init__(**kwargs)
-
- def execute(self, context: Context) -> None:
- """Check for the target_status and defers using the trigger"""
- if not poke(self, context):
- self.defer(
- timeout=timedelta(seconds=self.timeout),
- trigger=RedshiftClusterSensorTrigger(
- task_id=self.task_id,
- aws_conn_id=self.aws_conn_id,
- cluster_identifier=self.cluster_identifier,
- target_status=self.target_status,
- poke_interval=self.poke_interval,
- ),
- method_name="execute_complete",
- )
-
- def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if event:
- if "status" in event and event["status"] == "error":
- msg = "{}: {}".format(event["status"], event["message"])
- raise_error_or_skip_exception(self.soft_fail, msg)
- if "status" in event and event["status"] == "success":
- self.log.info("%s completed successfully.", self.task_id)
- self.log.info(
- "Cluster Identifier %s is in %s state", self.cluster_identifier, self.target_status
- )
- return None
- self.log.info("%s completed successfully.", self.task_id)
- return None
+ warnings.warn(
+ (
+ "This module is deprecated."
+ "Please use `airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor`"
+ "and set `deferrable` param to `True` instead."
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(deferrable=True, **kwargs)
diff --git a/astronomer/providers/amazon/aws/sensors/s3.py b/astronomer/providers/amazon/aws/sensors/s3.py
index a4e902060..ab0858c7e 100644
--- a/astronomer/providers/amazon/aws/sensors/s3.py
+++ b/astronomer/providers/amazon/aws/sensors/s3.py
@@ -1,131 +1,30 @@
from __future__ import annotations
import warnings
-from datetime import timedelta
-from typing import Any, Callable, Sequence, cast
+from typing import Any, Callable
-from airflow.exceptions import AirflowSkipException
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
from airflow.sensors.base import BaseSensorOperator
-from astronomer.providers.amazon.aws.triggers.s3 import (
- S3KeyTrigger,
-)
-from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception
-from astronomer.providers.utils.typing_compat import Context
-
class S3KeySensorAsync(S3KeySensor):
"""
- Waits for one or multiple keys (a file-like instance on S3) to be present in a S3 bucket.
- S3 being a key/value it does not support folders. The path is just a key
- a resource.
-
- :param bucket_key: The key(s) being waited on. Supports full s3:// style url
- or relative path from root level. When it's specified as a full s3://
- url, please leave bucket_name as `None`
- :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
- is not provided as a full s3:// url. When specified, all the keys passed to ``bucket_key``
- refers to this bucket
- :param wildcard_match: whether the bucket_key should be interpreted as a
- Unix wildcard pattern
- :param use_regex: whether to use regex to check bucket
- :param check_fn: Function that receives the list of the S3 objects,
- and returns a boolean:
- - ``True``: the criteria is met
- - ``False``: the criteria isn't met
- **Example**: Wait for any S3 object size more than 1 megabyte ::
-
- def check_fn(files: List) -> bool:
- return any(f.get('Size', 0) > 1048576 for f in files)
- :param aws_conn_id: a reference to the s3 connection
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
- You can provide the following values:
-
- - ``False``: do not validate SSL certificates. SSL will still be used
- (unless use_ssl is False), but SSL certificates will not be
- verified.
- - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
-
- .. seealso::
- `For more information on how to use this sensor, take a look at the guide:
- :ref:`howto/sensor:S3KeySensor`.
+ This class is deprecated.
+ Please use :class: `~airflow.providers.amazon.aws.sensors.s3.S3KeySensor`
+ and set `deferrable` param to `True` instead.
"""
- template_fields: Sequence[str] = ("bucket_key", "bucket_name")
-
- def __init__(
- self,
- *,
- bucket_key: str | list[str],
- bucket_name: str | None = None,
- wildcard_match: bool = False,
- use_regex: bool = False,
- check_fn: Callable[..., bool] | None = None,
- aws_conn_id: str = "aws_default",
- verify: str | bool | None = None,
- **kwargs: Any,
- ):
- self.bucket_key: list[str] = [bucket_key] if isinstance(bucket_key, str) else bucket_key
- self.use_regex = use_regex
- super().__init__(
- bucket_name=bucket_name,
- bucket_key=self.bucket_key,
- wildcard_match=wildcard_match,
- check_fn=check_fn,
- aws_conn_id=aws_conn_id,
- verify=verify,
- **kwargs,
- )
- self.check_fn = check_fn
- self.should_check_fn = True if check_fn else False
-
- def execute(self, context: Context) -> None:
- """Check for a keys in s3 and defers using the trigger"""
- try:
- poke = self.poke(context)
- except Exception as e:
- if self.soft_fail:
- raise AirflowSkipException(str(e))
- else:
- raise e
- if not poke:
- self._defer()
-
- def _defer(self) -> None:
- self.defer(
- timeout=timedelta(seconds=self.timeout),
- trigger=S3KeyTrigger(
- bucket_name=cast(str, self.bucket_name),
- bucket_key=self.bucket_key,
- wildcard_match=self.wildcard_match,
- use_regex=self.use_regex,
- aws_conn_id=self.aws_conn_id,
- verify=self.verify,
- poke_interval=self.poke_interval,
- soft_fail=self.soft_fail,
- should_check_fn=self.should_check_fn,
+ def __init__(self, **kwargs: Any) -> None:
+ warnings.warn(
+ (
+ "This module is deprecated."
+ "Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`"
+ "and set `deferrable` param to `True` instead."
),
- method_name="execute_complete",
+ DeprecationWarning,
+ stacklevel=2,
)
-
- def execute_complete(self, context: Context, event: Any = None) -> bool | None:
- """
- Callback for when the trigger fires - returns immediately.
- Relies on trigger to throw an exception, otherwise it assumes execution was
- successful.
- """
- if event["status"] == "running":
- if self.check_fn(event["files"]): # type: ignore[misc]
- return None
- else:
- self._defer()
- if event["status"] == "error":
- raise_error_or_skip_exception(self.soft_fail, event["message"])
- return None
+ super().__init__(deferrable=True, **kwargs)
class S3KeySizeSensorAsync(S3KeySensorAsync):
diff --git a/astronomer/providers/amazon/aws/triggers/batch.py b/astronomer/providers/amazon/aws/triggers/batch.py
index 01b1be301..9c55e7620 100644
--- a/astronomer/providers/amazon/aws/triggers/batch.py
+++ b/astronomer/providers/amazon/aws/triggers/batch.py
@@ -1,4 +1,7 @@
-from typing import Any, AsyncIterator, Dict, Optional, Tuple
+from __future__ import annotations
+
+import warnings
+from typing import Any, AsyncIterator
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -10,6 +13,10 @@ class BatchOperatorTrigger(BaseTrigger):
Checks for the state of a previously submitted job to AWS Batch.
BatchOperatorTrigger is fired as deferred class with params to poll the job state in Triggerer
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger` instead
+
+
:param job_id: the job ID, usually unknown (None) until the
submit_job operation gets the jobId defined by AWS Batch
:param waiters: a :class:`.BatchWaiters` object (see note below);
@@ -24,12 +31,20 @@ class BatchOperatorTrigger(BaseTrigger):
def __init__(
self,
- job_id: Optional[str],
+ job_id: str | None,
waiters: Any,
max_retries: int,
- region_name: Optional[str],
- aws_conn_id: Optional[str] = "aws_default",
+ region_name: str | None,
+ aws_conn_id: str | None = "aws_default",
):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.job_id = job_id
self.waiters = waiters
@@ -37,7 +52,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BatchOperatorTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
@@ -50,7 +65,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Make async connection using aiobotocore library to AWS Batch,
periodically poll for the job status on the Triggerer
diff --git a/astronomer/providers/amazon/aws/triggers/redshift_cluster.py b/astronomer/providers/amazon/aws/triggers/redshift_cluster.py
index 9b0f3e75c..f5a032774 100644
--- a/astronomer/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/astronomer/providers/amazon/aws/triggers/redshift_cluster.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import asyncio
import warnings
-from typing import Any, AsyncIterator, Dict, Optional, Tuple
+from typing import Any, AsyncIterator
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -29,7 +31,7 @@ def __init__(
operation_type: str,
polling_period_seconds: float = 5.0,
skip_final_cluster_snapshot: bool = True,
- final_cluster_snapshot_identifier: Optional[str] = None,
+ final_cluster_snapshot_identifier: str | None = None,
):
warnings.warn(
(
@@ -48,7 +50,7 @@ def __init__(
self.skip_final_cluster_snapshot = skip_final_cluster_snapshot
self.final_cluster_snapshot_identifier = final_cluster_snapshot_identifier
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftClusterTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
@@ -63,7 +65,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Make async connection to redshift, based on the operation type call
the RedshiftHookAsync functions
@@ -112,6 +114,9 @@ class RedshiftClusterSensorTrigger(BaseTrigger):
"""
RedshiftClusterSensorTrigger is fired as deferred class with params to run the task in trigger worker
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger` instead
+
:param task_id: Reference to task id of the Dag
:param aws_conn_id: Reference to AWS connection id for redshift
:param cluster_identifier: unique identifier of a cluster
@@ -127,6 +132,14 @@ def __init__(
target_status: str,
poke_interval: float,
):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.task_id = task_id
self.aws_conn_id = aws_conn_id
@@ -134,7 +147,7 @@ def __init__(
self.target_status = target_status
self.poke_interval = poke_interval
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftClusterSensorTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterSensorTrigger",
@@ -147,7 +160,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""Simple async function run until the cluster status match the target status."""
try:
hook = RedshiftHookAsync(aws_conn_id=self.aws_conn_id)
diff --git a/astronomer/providers/amazon/aws/triggers/redshift_data.py b/astronomer/providers/amazon/aws/triggers/redshift_data.py
index c3e6790c9..540135fd7 100644
--- a/astronomer/providers/amazon/aws/triggers/redshift_data.py
+++ b/astronomer/providers/amazon/aws/triggers/redshift_data.py
@@ -1,4 +1,7 @@
-from typing import Any, AsyncIterator, Dict, List, Tuple
+from __future__ import annotations
+
+import warnings
+from typing import Any, AsyncIterator
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -9,6 +12,9 @@ class RedshiftDataTrigger(BaseTrigger):
"""
RedshiftDataTrigger is fired as deferred class with params to run the task in triggerer.
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger` instead
+
:param task_id: task ID of the Dag
:param poll_interval: polling period in seconds to check for the status
:param aws_conn_id: AWS connection ID for redshift
@@ -19,16 +25,25 @@ def __init__(
self,
task_id: str,
poll_interval: int,
- query_ids: List[str],
+ query_ids: list[str],
aws_conn_id: str = "aws_default",
):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
super().__init__()
self.task_id = task_id
self.poll_interval = poll_interval
self.aws_conn_id = aws_conn_id
self.query_ids = query_ids
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftDataTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
@@ -40,7 +55,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Makes async connection and gets status for a list of queries submitted by the operator.
Even if one of the queries has a non-successful state, the hook returns a failure event and the error
diff --git a/astronomer/providers/amazon/aws/triggers/s3.py b/astronomer/providers/amazon/aws/triggers/s3.py
index 20f8157c0..32b77031b 100644
--- a/astronomer/providers/amazon/aws/triggers/s3.py
+++ b/astronomer/providers/amazon/aws/triggers/s3.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
+import warnings
from typing import Any, AsyncIterator
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -12,6 +13,9 @@ class S3KeyTrigger(BaseTrigger):
"""
S3KeyTrigger is fired as deferred class with params to run the task in trigger worker
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger` instead
+
:param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key``
is not provided as a full s3:// url.
:param bucket_key: The key being waited on. Supports full s3:// style url
@@ -37,6 +41,14 @@ def __init__(
should_check_fn: bool = False,
**hook_params: Any,
):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.bucket_name = bucket_name
self.bucket_key = bucket_key
diff --git a/astronomer/providers/amazon/aws/triggers/sagemaker.py b/astronomer/providers/amazon/aws/triggers/sagemaker.py
index f3aa41e06..f311d3a42 100644
--- a/astronomer/providers/amazon/aws/triggers/sagemaker.py
+++ b/astronomer/providers/amazon/aws/triggers/sagemaker.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import asyncio
import time
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
+import warnings
+from typing import Any, AsyncIterator
from airflow.providers.amazon.aws.hooks.sagemaker import LogState
from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -12,6 +15,9 @@ class SagemakerProcessingTrigger(BaseTrigger):
"""
SagemakerProcessingTrigger is fired as deferred class with params to run the task in triggerer.
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger` instead
+
:param job_name: name of the job to check status
:param poll_interval: polling period in seconds to check for the status
:param aws_conn_id: AWS connection ID for sagemaker
@@ -26,16 +32,24 @@ def __init__(
self,
job_name: str,
poll_interval: float,
- end_time: Optional[float],
+ end_time: float | None,
aws_conn_id: str = "aws_default",
):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerTrigger`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.job_name = job_name
self.poll_interval = poll_interval
self.aws_conn_id = aws_conn_id
self.end_time = end_time
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes SagemakerProcessingTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerProcessingTrigger",
@@ -47,7 +61,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator.
Trigger returns a failure event if any error and success in state return the success event.
@@ -80,6 +94,9 @@ class SagemakerTrigger(BaseTrigger):
SagemakerTrigger is common trigger for both transform and training sagemaker job and it is
fired as deferred class with params to run the task in triggerer.
+ This class is deprecated and will be removed in 2.0.0.
+ Use :class: `~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger` instead
+
:param job_name: name of the job to check status
:param job_type: Type of the sagemaker job whether it is Transform or Training
:param response_key: The key which needs to be look in the response.
@@ -97,9 +114,17 @@ def __init__(
job_type: str,
response_key: str,
poke_interval: float,
- end_time: Optional[float] = None,
+ end_time: float | None = None,
aws_conn_id: str = "aws_default",
):
+ warnings.warn(
+ (
+ "This module is deprecated and will be removed in 2.0.0."
+ "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerTrigger`"
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
super().__init__()
self.job_name = job_name
self.job_type = job_type
@@ -108,7 +133,7 @@ def __init__(
self.end_time = end_time
self.aws_conn_id = aws_conn_id
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes SagemakerTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrigger",
@@ -122,7 +147,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator.
Trigger returns a failure event if any error and success in state return the success event.
@@ -148,7 +173,7 @@ def _get_async_hook(self) -> SageMakerHookAsync:
return SageMakerHookAsync(aws_conn_id=self.aws_conn_id)
@staticmethod
- async def get_job_status(hook: SageMakerHookAsync, job_name: str, job_type: str) -> Dict[str, Any]:
+ async def get_job_status(hook: SageMakerHookAsync, job_name: str, job_type: str) -> dict[str, Any]:
"""
Based on the job type the SageMakerHookAsync connect to sagemaker related function
and get the response of the job and return it
@@ -181,7 +206,7 @@ def __init__(
instance_count: int,
status: str,
poke_interval: float,
- end_time: Optional[float] = None,
+ end_time: float | None = None,
aws_conn_id: str = "aws_default",
):
super().__init__()
@@ -192,7 +217,7 @@ def __init__(
self.end_time = end_time
self.aws_conn_id = aws_conn_id
- def serialize(self) -> Tuple[str, Dict[str, Any]]:
+ def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes SagemakerTrainingWithLogTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrainingWithLogTrigger",
@@ -206,7 +231,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)
- async def run(self) -> AsyncIterator["TriggerEvent"]:
+ async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator.
Trigger returns a failure event if any error and success in state return the success event.
@@ -214,8 +239,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
hook = self._get_async_hook()
last_description = await hook.describe_training_job_async(self.job_name)
- stream_names: List[str] = [] # The list of log streams
- positions: Dict[str, Any] = {} # The current position in each stream, map of stream name -> position
+ stream_names: list[str] = [] # The list of log streams
+ positions: dict[str, Any] = {} # The current position in each stream, map of stream name -> position
job_already_completed = self.status not in self.NON_TERMINAL_STATES
diff --git a/setup.cfg b/setup.cfg
index e291bf37d..f9ba982d5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,7 +44,7 @@ zip_safe = false
[options.extras_require]
amazon =
- apache-airflow-providers-amazon>=8.16.0
+ apache-airflow-providers-amazon>=8.17.0
aiobotocore>=2.1.1
apache.hive =
apache-airflow-providers-apache-hive>=6.1.5
@@ -118,7 +118,7 @@ mypy =
# All extras from above except 'mypy', 'docs' and 'tests'
all =
aiobotocore>=2.1.1
- apache-airflow-providers-amazon>=8.16.0
+ apache-airflow-providers-amazon>=8.17.0
apache-airflow-providers-apache-hive>=6.1.5
apache-airflow-providers-apache-livy>=3.7.1
apache-airflow-providers-cncf-kubernetes>=4
diff --git a/tests/amazon/aws/operators/test_batch.py b/tests/amazon/aws/operators/test_batch.py
index e8e798cc7..5c3370716 100644
--- a/tests/amazon/aws/operators/test_batch.py
+++ b/tests/amazon/aws/operators/test_batch.py
@@ -1,161 +1,17 @@
-from unittest import mock
-
-import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
-from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.operators.batch import BatchOperator
from astronomer.providers.amazon.aws.operators.batch import BatchOperatorAsync
-from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
-from tests.utils.airflow_util import create_context
class TestBatchOperatorAsync:
- JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
- JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
- MAX_RETRIES = 2
- STATUS_RETRIES = 3
- RESPONSE_WITHOUT_FAILURES = {
- "jobName": JOB_NAME,
- "jobId": JOB_ID,
- }
-
- @mock.patch("astronomer.providers.amazon.aws.operators.batch.BatchOperatorAsync.defer")
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
- def test_batch_op_async_succeeded_before_defer(self, get_client_type_mock, get_job_description, defer):
- get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES
- get_job_description.return_value = {"status": BatchClientHook.SUCCESS_STATE}
- task = BatchOperatorAsync(
- task_id="task",
- job_name=self.JOB_NAME,
- job_queue="queue",
- job_definition="hello-world",
- max_retries=self.MAX_RETRIES,
- status_retries=self.STATUS_RETRIES,
- parameters=None,
- overrides={},
- array_properties=None,
- aws_conn_id="aws_default",
- region_name="eu-west-1",
- tags={},
- )
- context = create_context(task)
- task.execute(context)
- assert not defer.called
-
- @pytest.mark.parametrize("status", (BatchClientHook.FAILURE_STATE, "Unexpected status"))
- @mock.patch("astronomer.providers.amazon.aws.operators.batch.BatchOperatorAsync.defer")
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
- def test_batch_op_async_failed_before_defer(
- self, get_client_type_mock, get_job_description, defer, status
- ):
- get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES
- get_job_description.return_value = {"status": status}
- task = BatchOperatorAsync(
- task_id="task",
- job_name=self.JOB_NAME,
- job_queue="queue",
- job_definition="hello-world",
- max_retries=self.MAX_RETRIES,
- status_retries=self.STATUS_RETRIES,
- parameters=None,
- overrides={},
- array_properties=None,
- aws_conn_id="aws_default",
- region_name="eu-west-1",
- tags={},
- )
- context = create_context(task)
- with pytest.raises(AirflowException):
- task.execute(context)
- assert not defer.called
-
- @pytest.mark.parametrize("status", BatchClientHook.INTERMEDIATE_STATES)
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description")
- @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
- def test_batch_op_async(self, get_client_type_mock, get_job_description, status):
- get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES
- get_job_description.return_value = {"status": status}
- task = BatchOperatorAsync(
- task_id="task",
- job_name=self.JOB_NAME,
- job_queue="queue",
- job_definition="hello-world",
- max_retries=self.MAX_RETRIES,
- status_retries=self.STATUS_RETRIES,
- parameters=None,
- overrides={},
- array_properties=None,
- aws_conn_id="aws_default",
- region_name="eu-west-1",
- tags={},
- )
- context = create_context(task)
- with pytest.raises(TaskDeferred) as exc:
- task.execute(context)
- assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger"
-
- def test_batch_op_async_execute_failure(self, context):
- """Tests that an AirflowException is raised in case of error event"""
-
- task = BatchOperatorAsync(
- task_id="task",
- job_name=self.JOB_NAME,
- job_queue="queue",
- job_definition="hello-world",
- max_retries=self.MAX_RETRIES,
- status_retries=self.STATUS_RETRIES,
- parameters=None,
- overrides={},
- array_properties=None,
- aws_conn_id="aws_default",
- region_name="eu-west-1",
- tags={},
- )
- with pytest.raises(AirflowException) as exc_info:
- task.execute_complete(context=None, event={"status": "error", "message": "test failure message"})
-
- assert str(exc_info.value) == "test failure message"
-
- @pytest.mark.parametrize(
- "event",
- [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
- )
- def test_batch_op_async_execute_complete(self, caplog, event):
- """Tests that execute_complete method returns None and that it prints expected log"""
- task = BatchOperatorAsync(
- task_id="task",
- job_name=self.JOB_NAME,
- job_queue="queue",
- job_definition="hello-world",
- max_retries=self.MAX_RETRIES,
- status_retries=self.STATUS_RETRIES,
- parameters=None,
- overrides={},
- array_properties=None,
- aws_conn_id="aws_default",
- region_name="eu-west-1",
- tags={},
- )
- with mock.patch.object(task.log, "info") as mock_log_info:
- assert task.execute_complete(context=None, event=event) is None
-
- mock_log_info.assert_called_with(f"AWS Batch job ({self.JOB_ID}) succeeded")
-
- @mock.patch("astronomer.providers.amazon.aws.operators.batch.BatchOperatorAsync.submit_job")
- def test_batch_op_raises_exception_before_deferral_if_job_id_unset(self, mock_submit_job):
- """
- Test that an AirflowException is raised if job_id is not set before deferral by mocking the submit_job
- method which sets the job_id attribute of the instance.
- """
+ def test_init(self):
task = BatchOperatorAsync(
task_id="task",
- job_name=self.JOB_NAME,
+ job_name="51455483-c62c-48ac-9b88-53a6a725baa3",
job_queue="queue",
job_definition="hello-world",
- max_retries=self.MAX_RETRIES,
- status_retries=self.STATUS_RETRIES,
+ max_retries=2,
+ status_retries=3,
parameters=None,
overrides={},
array_properties=None,
@@ -163,7 +19,5 @@ def test_batch_op_raises_exception_before_deferral_if_job_id_unset(self, mock_su
region_name="eu-west-1",
tags={},
)
- context = create_context(task)
- with pytest.raises(AirflowException) as exc:
- task.execute(context)
- assert "AWS Batch job - job_id was not found" in str(exc.value)
+ assert isinstance(task, BatchOperator)
+ assert task.deferrable is True
diff --git a/tests/amazon/aws/operators/test_redshift_data.py b/tests/amazon/aws/operators/test_redshift_data.py
index 5cff1b4f8..8c2860695 100644
--- a/tests/amazon/aws/operators/test_redshift_data.py
+++ b/tests/amazon/aws/operators/test_redshift_data.py
@@ -1,101 +1,16 @@
-from unittest import mock
-
-import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
from astronomer.providers.amazon.aws.operators.redshift_data import (
RedshiftDataOperatorAsync,
)
-from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
-from tests.utils.airflow_util import create_context
class TestRedshiftDataOperatorAsync:
- DATABASE_NAME = "TEST_DATABASE"
- TASK_ID = "fetch_data"
- SQL_QUERY = "select * from any"
- TASK = RedshiftDataOperatorAsync(
- task_id=TASK_ID,
- sql=SQL_QUERY,
- database=DATABASE_NAME,
- )
-
- @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
- def test_redshift_data_op_async_finished_before_deferred(self, mock_execute, mock_conn, mock_defer):
- mock_execute.return_value = ["test_query_id"], {}
- mock_conn.describe_statement.return_value = {
- "Status": "FINISHED",
- }
- self.TASK.execute(create_context(self.TASK))
- assert not mock_defer.called
-
- @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
- def test_redshift_data_op_async_aborted_before_deferred(self, mock_execute, mock_conn, mock_defer):
- mock_execute.return_value = ["test_query_id"], {}
- mock_conn.describe_statement.return_value = {"Status": "ABORTED"}
-
- with pytest.raises(AirflowException):
- self.TASK.execute(create_context(self.TASK))
-
- assert not mock_defer.called
-
- @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
- def test_redshift_data_op_async_failed_before_deferred(self, mock_execute, mock_conn, mock_defer):
- mock_execute.return_value = ["test_query_id"], {}
- mock_conn.describe_statement.return_value = {
- "Status": "FAILED",
- "QueryString": "test query",
- "Error": "test error",
- }
-
- with pytest.raises(AirflowException):
- self.TASK.execute(create_context(self.TASK))
-
- assert not mock_defer.called
-
- @pytest.mark.parametrize("status", ("SUBMITTED", "PICKED", "STARTED"))
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
- def test_redshift_data_op_async(self, mock_execute, mock_conn, status):
- mock_execute.return_value = ["test_query_id"], {}
- mock_conn.describe_statement.return_value = {"Status": status}
-
- with pytest.raises(TaskDeferred) as exc:
- self.TASK.execute(create_context(self.TASK))
- assert isinstance(exc.value.trigger, RedshiftDataTrigger), "Trigger is not a RedshiftDataTrigger"
-
- @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
- def test_redshift_data_op_async_execute_query_error(self, mock_execute, context):
- mock_execute.return_value = [], {"status": "error", "message": "Test exception"}
- with pytest.raises(AirflowException):
- self.TASK.execute(context)
-
- def test_redshift_data_op_async_execute_failure(self, context):
- """Tests that an AirflowException is raised in case of error event"""
-
- with pytest.raises(AirflowException):
- self.TASK.execute_complete(
- context=None, event={"status": "error", "message": "test failure message"}
- )
-
- @pytest.mark.parametrize(
- "event",
- [None, {"status": "success", "message": "Job completed"}],
- )
- def test_redshift_data_op_async_execute_complete(self, event):
- """Asserts that logging occurs as expected"""
-
- if not event:
- with pytest.raises(AirflowException) as exception_info:
- self.TASK.execute_complete(context=None, event=None)
- assert exception_info.value.args[0] == "Did not receive valid event from the trigerrer"
- else:
- with mock.patch.object(self.TASK.log, "info") as mock_log_info:
- self.TASK.execute_complete(context=None, event=event)
- mock_log_info.assert_called_with("%s completed successfully.", self.TASK_ID)
+ def test_init(self):
+ task = RedshiftDataOperatorAsync(
+ task_id="fetch_data",
+ sql="select * from any",
+ database="TEST_DATABASE",
+ )
+ assert isinstance(task, RedshiftDataOperator)
+ assert task.deferrable is True
diff --git a/tests/amazon/aws/operators/test_sagemaker.py b/tests/amazon/aws/operators/test_sagemaker.py
index 8b83019ed..8f25c50e4 100644
--- a/tests/amazon/aws/operators/test_sagemaker.py
+++ b/tests/amazon/aws/operators/test_sagemaker.py
@@ -4,6 +4,10 @@
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
+from airflow.providers.amazon.aws.operators.sagemaker import (
+ SageMakerProcessingOperator,
+ SageMakerTransformOperator,
+)
from airflow.utils.timezone import datetime
from astronomer.providers.amazon.aws.operators.sagemaker import (
@@ -12,7 +16,6 @@
SageMakerTransformOperatorAsync,
)
from astronomer.providers.amazon.aws.triggers.sagemaker import (
- SagemakerProcessingTrigger,
SagemakerTrainingWithLogTrigger,
SagemakerTrigger,
)
@@ -136,158 +139,15 @@ class TestSagemakerProcessingOperatorAsync:
CHECK_INTERVAL = 5
MAX_INGESTION_TIME = 60 * 60 * 24 * 7
- @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperatorAsync.defer")
- @mock.patch.object(
- SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "Completed"}
- )
- @mock.patch.object(
- SageMakerHook,
- "create_processing_job",
- return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False)
- def test_sagemakerprocessing_op_async_complete_before_defer(
- self, mock_hook, mock_processing, mock_describe, mock_defer, context
- ):
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
- task.execute(context)
-
- assert not mock_defer.called
-
- @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperatorAsync.defer")
- @mock.patch.object(SageMakerHook, "describe_processing_job", return_value=Exception)
- @mock.patch.object(
- SageMakerHook,
- "create_processing_job",
- return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False)
- def test_sagemakerprocessing_op_async_encounter_exception_before_defer(
- self, mock_hook, mock_processing, mock_describe, mock_defer, context
- ):
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
-
- with pytest.raises(Exception):
- task.execute(context)
-
- assert not mock_defer.called
-
- @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperatorAsync.defer")
- @mock.patch.object(
- SageMakerHook,
- "describe_processing_job",
- return_value={"ProcessingJobStatus": "Failed", "FailureReason": "it just failed"},
- )
- @mock.patch.object(
- SageMakerHook,
- "create_processing_job",
- return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False)
- def test_sagemakerprocessing_op_async_failed_before_defer(
- self, mock_hook, mock_processing, mock_describe, mock_defer, context
- ):
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
- with pytest.raises(AirflowException):
- task.execute(context)
-
- assert not mock_defer.called
-
- @mock.patch.object(
- SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "InProgress"}
- )
- @mock.patch.object(
- SageMakerHook,
- "create_processing_job",
- return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False)
- def test_sagemakerprocessing_op_async(self, mock_hook, mock_processing, mock_describe, context):
- """Assert SageMakerProcessingOperatorAsync deferred properly"""
+ def test_init(self):
task = SageMakerProcessingOperatorAsync(
config=CREATE_PROCESSING_PARAMS,
task_id=self.TASK_ID,
check_interval=self.CHECK_INTERVAL,
max_ingestion_time=self.MAX_INGESTION_TIME,
)
- with pytest.raises(TaskDeferred) as exc:
- task.execute(context)
- assert isinstance(
- exc.value.trigger, SagemakerProcessingTrigger
- ), "Trigger is not a SagemakerProcessingTrigger"
-
- @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=True)
- def test_sagemakerprocessing_op_async_duplicate_failure(self, mock_hook, context):
- """Tests that an AirflowException is raised in case of error event from find_processing_job_name"""
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- )
- with pytest.raises(AirflowException):
- task.execute(context)
-
- @mock.patch.object(
- SageMakerHook,
- "create_processing_job",
- return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}},
- )
- @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False)
- def test_sagemakerprocessing_op_async_failure(self, mock_hook, mock_processing_job, context):
- """Tests that an AirflowException is raised in case of error event from create_processing_job"""
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- )
- with pytest.raises(AirflowException):
- task.execute(context)
-
- @pytest.mark.parametrize(
- "event",
- [{"status": "error", "message": "test failure message"}, None],
- )
- def test_sagemakerprocessing_op_async_execute_failure(self, event):
- """Tests that an AirflowException is raised in case of error event"""
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
- with pytest.raises(AirflowException):
- task.execute_complete(context=None, event=event)
-
- @pytest.mark.parametrize(
- "event",
- [{"status": "success", "message": "Job completed"}],
- )
- def test_sagemakerprocessing_op_async_execute_complete(self, event):
- """Asserts that logging occurs as expected"""
- task = SageMakerProcessingOperatorAsync(
- config=CREATE_PROCESSING_PARAMS,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
- with mock.patch.object(task.log, "info") as mock_log_info:
- task.execute_complete(context=None, event=event)
- mock_log_info.assert_called_with("%s completed successfully.", "test_sagemaker_processing_operator")
+ assert isinstance(task, SageMakerProcessingOperator)
+ assert task.deferrable is True
class TestSagemakerTransformOperatorAsync:
@@ -295,63 +155,7 @@ class TestSagemakerTransformOperatorAsync:
CHECK_INTERVAL = 5
MAX_INGESTION_TIME = 60 * 60 * 24 * 7
- @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperatorAsync.defer")
- @mock.patch.object(SageMakerHook, "describe_model", return_value={})
- @mock.patch.object(
- SageMakerHook,
- "describe_transform_job",
- return_value={"TransformJobStatus": "Failed", "FailureReason": "it just failed"},
- )
- @mock.patch.object(
- SageMakerHook,
- "create_transform_job",
- return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[])
- @mock.patch.object(SageMakerHook, "create_model", return_value=None)
- def test_sagemaker_transform_op_async_failed_before_defer(
- self,
- mock_create_model,
- mock_list_transform_jobs,
- mock_transform_job,
- mock_describe_transform_job,
- mock_describe_model,
- mock_defer,
- context,
- ):
- task = SageMakerTransformOperatorAsync(
- config=CONFIG,
- task_id=self.TASK_ID,
- check_if_job_exists=False,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
- with pytest.raises(AirflowException):
- task.execute(context)
- assert not mock_defer.called
-
- @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperatorAsync.defer")
- @mock.patch.object(SageMakerHook, "describe_model", return_value={})
- @mock.patch.object(
- SageMakerHook, "describe_transform_job", return_value={"TransformJobStatus": "Completed"}
- )
- @mock.patch.object(
- SageMakerHook,
- "create_transform_job",
- return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[])
- @mock.patch.object(SageMakerHook, "create_model", return_value=None)
- def test_sagemaker_transform_op_async_complete_before_defer(
- self,
- mock_create_model,
- mock_list_transform_jobs,
- mock_transform_job,
- mock_describe_transform_job,
- mock_describe_model,
- mock_defer,
- context,
- ):
+ def test_init(self):
task = SageMakerTransformOperatorAsync(
config=CONFIG,
task_id=self.TASK_ID,
@@ -359,87 +163,8 @@ def test_sagemaker_transform_op_async_complete_before_defer(
check_interval=self.CHECK_INTERVAL,
max_ingestion_time=self.MAX_INGESTION_TIME,
)
- task.execute(context)
- assert not mock_defer.called
-
- @mock.patch.object(
- SageMakerHook, "describe_transform_job", return_value={"TransformJobStatus": "InProgress"}
- )
- @mock.patch.object(
- SageMakerHook,
- "create_transform_job",
- return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}},
- )
- @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[])
- @mock.patch.object(SageMakerHook, "create_model", return_value=None)
- def test_sagemaker_transform_op_async(
- self,
- mock_create_model,
- mock_list_transform_jobs,
- mock_transform_job,
- mock_describe_transform_job,
- context,
- ):
- """Assert SageMakerTransformOperatorAsync deferred properly"""
- task = SageMakerTransformOperatorAsync(
- config=CONFIG,
- task_id=self.TASK_ID,
- check_if_job_exists=False,
- check_interval=self.CHECK_INTERVAL,
- max_ingestion_time=self.MAX_INGESTION_TIME,
- )
- with pytest.raises(TaskDeferred) as exc:
- task.execute(context)
- assert isinstance(exc.value.trigger, SagemakerTrigger), "Trigger is not a SagemakerTrigger"
-
- @mock.patch.object(
- SageMakerHook,
- "create_transform_job",
- return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}},
- )
- @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[])
- @mock.patch.object(SageMakerHook, "create_model", return_value=None)
- def test_sagemaker_transform_op_async_execute_failure(self, mock_hook, mock_transform_job, context):
- """Tests that an AirflowException is raised in case of error event from create_transform_job"""
- task = SageMakerTransformOperatorAsync(
- config=CONFIG,
- task_id=self.TASK_ID,
- check_if_job_exists=False,
- check_interval=self.CHECK_INTERVAL,
- )
- with pytest.raises(AirflowException):
- task.execute(context)
-
- @pytest.mark.parametrize(
- "mock_event",
- [{"status": "error", "message": "test failure message"}, None],
- )
- def test_sagemaker_transform_op_async_execute_complete_failure(self, mock_event):
- """Tests that an AirflowException is raised in case of error event"""
- task = SageMakerTransformOperatorAsync(
- config=CONFIG,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- )
- with pytest.raises(AirflowException):
- task.execute_complete(context=None, event=mock_event)
-
- @pytest.mark.parametrize(
- "mock_event",
- [{"status": "success", "message": "Job completed"}],
- )
- @mock.patch.object(SageMakerHook, "describe_model")
- def test_sagemaker_transform_op_async_execute_complete(self, mock_model_output, mock_event):
- """Asserts that logging occurs as expected"""
- task = SageMakerTransformOperatorAsync(
- config=CONFIG,
- task_id=self.TASK_ID,
- check_interval=self.CHECK_INTERVAL,
- )
- mock_model_output.return_value = {"test": "test"}
- with mock.patch.object(task.log, "info") as mock_log_info:
- task.execute_complete(context=None, event=mock_event)
- mock_log_info.assert_called_with("%s completed successfully.", "test_sagemaker_transform_operator")
+ assert isinstance(task, SageMakerTransformOperator)
+ assert task.deferrable is True
class TestSagemakerTrainingOperatorAsync:
diff --git a/tests/amazon/aws/sensors/test_redshift_sensor.py b/tests/amazon/aws/sensors/test_redshift_sensor.py
index f3023ae5d..09b0eb238 100644
--- a/tests/amazon/aws/sensors/test_redshift_sensor.py
+++ b/tests/amazon/aws/sensors/test_redshift_sensor.py
@@ -1,72 +1,20 @@
-from unittest import mock
-
-import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.amazon.aws.sensors.redshift_cluster import (
+ RedshiftClusterSensor,
+)
from astronomer.providers.amazon.aws.sensors.redshift_cluster import (
RedshiftClusterSensorAsync,
)
-from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
- RedshiftClusterSensorTrigger,
-)
TASK_ID = "redshift_sensor_check"
-POLLING_PERIOD_SECONDS = 1.0
-
-MODULE = "astronomer.providers.amazon.aws.sensors.redshift_cluster"
class TestRedshiftClusterSensorAsync:
- TASK = RedshiftClusterSensorAsync(
- task_id=TASK_ID,
- cluster_identifier="astro-redshift-cluster-1",
- target_status="available",
- )
-
- @mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.defer")
- @mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.poke", return_value=True)
- def test_redshift_cluster_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
- """Assert task is not deferred when it receives a finish status before deferring"""
- self.TASK.execute(context)
- assert not mock_defer.called
-
- @mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.poke", return_value=False)
- def test_redshift_cluster_sensor_async(self, context):
- """Test RedshiftClusterSensorAsync that a task with wildcard=True
- is deferred and an RedshiftClusterSensorTrigger will be fired when executed method is called"""
-
- with pytest.raises(TaskDeferred) as exc:
- self.TASK.execute(context)
- assert isinstance(
- exc.value.trigger, RedshiftClusterSensorTrigger
- ), "Trigger is not a RedshiftClusterSensorTrigger"
-
- def test_redshift_sensor_async_execute_failure(self, context):
- """Test RedshiftClusterSensorAsync with an AirflowException is raised in case of error event"""
-
- with pytest.raises(AirflowException):
- self.TASK.execute_complete(
- context=None, event={"status": "error", "message": "test failure message"}
- )
-
- def test_redshift_sensor_async_execute_complete(self):
- """Asserts that logging occurs as expected"""
-
- with mock.patch.object(self.TASK.log, "info") as mock_log_info:
- self.TASK.execute_complete(
- context=None, event={"status": "success", "cluster_state": "available"}
- )
- mock_log_info.assert_called_with(
- "Cluster Identifier %s is in %s state", "astro-redshift-cluster-1", "available"
+ def test_init(self):
+ task = RedshiftClusterSensorAsync(
+ task_id=TASK_ID,
+ cluster_identifier="astro-redshift-cluster-1",
+ target_status="available",
)
-
- def test_poll_interval_deprecation_warning(self):
- """Test DeprecationWarning for RedshiftClusterSensorAsync by setting param poll_interval"""
- # TODO: Remove once deprecated
- with pytest.warns(expected_warning=DeprecationWarning):
- RedshiftClusterSensorAsync(
- task_id=TASK_ID,
- cluster_identifier="astro-redshift-cluster-1",
- target_status="available",
- poll_interval=5.0,
- )
+ assert isinstance(task, RedshiftClusterSensor)
+ assert task.deferrable is True
diff --git a/tests/amazon/aws/sensors/test_s3_sensors.py b/tests/amazon/aws/sensors/test_s3_sensors.py
index 4dfae42d6..bff9e7b2c 100644
--- a/tests/amazon/aws/sensors/test_s3_sensors.py
+++ b/tests/amazon/aws/sensors/test_s3_sensors.py
@@ -1,15 +1,7 @@
import unittest
-from datetime import timedelta
-from typing import Any, List
-from unittest import mock
import pytest
-from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
-from airflow.models import DAG, DagRun, TaskInstance
-from airflow.models.variable import Variable
-from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor
-from airflow.utils import timezone
-from parameterized import parameterized
+from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
from astronomer.providers.amazon.aws.sensors.s3 import (
S3KeySensorAsync,
@@ -17,279 +9,15 @@
S3KeysUnchangedSensorAsync,
S3PrefixSensorAsync,
)
-from astronomer.providers.amazon.aws.triggers.s3 import (
- S3KeyTrigger,
-)
MODULE = "astronomer.providers.amazon.aws.sensors.s3"
class TestS3KeySensorAsync:
- @mock.patch(f"{MODULE}.S3KeySensorAsync.defer")
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=True)
- def test_finish_before_deferred(self, mock_poke, mock_defer, context):
- """Assert task is not deferred when it receives a finish status before deferring"""
- sensor = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket")
- sensor.execute(context)
- assert not mock_defer.called
-
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- def test_bucket_name_none_and_bucket_key_as_relative_path(self, mock_poke, context):
- """
- Test if exception is raised when bucket_name is None
- and bucket_key is provided with one of the two keys as relative path rather than s3:// url.
- """
- sensor = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket")
- with pytest.raises(TaskDeferred):
- sensor.execute(context)
-
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("astronomer.providers.amazon.aws.hooks.s3.S3HookAsync.get_head_object")
- def test_bucket_name_none_and_bucket_key_is_list_and_contain_relative_path(
- self, mock_head_object, mock_poke, context
- ):
- """
- Test if exception is raised when bucket_name is None
- and bucket_key is provided with one of the two keys as relative path rather than s3:// url.
- :return:
- """
- mock_head_object.return_value = {"ContentLength": 0}
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor", bucket_key=["s3://test_bucket/file", "file_in_bucket"]
- )
- with pytest.raises(TaskDeferred):
- sensor.execute(context)
-
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- def test_bucket_name_provided_and_bucket_key_is_s3_url(self, mock_poke, context):
- """
- Test if exception is raised when bucket_name is provided
- while bucket_key is provided as a full s3:// url.
- :return:
- """
- op = S3KeySensorAsync(
- task_id="s3_key_sensor", bucket_key="s3://test_bucket/file", bucket_name="test_bucket"
- )
- with pytest.raises(TaskDeferred):
- op.execute(context)
-
- @parameterized.expand(
- [
- ["s3://bucket/key", None],
- ["key", "bucket"],
- ]
- )
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
- def test_s3_key_sensor_async(self, key, bucket, mock_hook, mock_poke):
- """
- Asserts that a task is deferred and an S3KeyTrigger will be fired
- when the S3KeySensorAsync is executed.
- """
- mock_hook.check_for_key.return_value = False
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async",
- bucket_key=key,
- bucket_name=bucket,
- )
-
- with pytest.raises(TaskDeferred) as exc:
- sensor.execute(context=None)
-
- assert isinstance(exc.value.trigger, S3KeyTrigger), "Trigger is not a S3KeyTrigger"
-
- @parameterized.expand(
- [
- ["s3://bucket/key", None],
- ["key", "bucket"],
- ]
- )
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
- def test_s3_key_sensor_execute_complete_success(self, key, bucket, mock_poke, mock_hook):
- """
- Asserts that a task is completed with success status.
- """
- mock_hook.check_for_key.return_value = False
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async",
- bucket_key=key,
- bucket_name=bucket,
- )
- assert sensor.execute_complete(context={}, event={"status": "success"}) is None
-
- @parameterized.expand(
- [
- ["key", "bucket"],
- ]
- )
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- def test_s3_key_sensor_execute_complete_success_with_keys(self, key, bucket, mock_poke):
- """
- Asserts that a task is completed with success status and check function
- """
-
- def check_fn(files: List[Any]) -> bool:
- return all(f.get("Size", 0) > 0 for f in files)
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async",
- bucket_key=key,
- bucket_name=bucket,
- check_fn=check_fn,
- )
- assert (
- sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) is None
- )
-
- @mock.patch(f"{MODULE}.S3KeySensorAsync._defer")
- def test_s3_key_sensor_re_defer(self, mock_defer):
- def check_fn(files: List[Any]) -> bool:
- return False
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async",
- bucket_key="key",
- bucket_name="bucket",
- check_fn=check_fn,
- )
- sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]})
-
- mock_defer.assert_called_once()
-
- @parameterized.expand(
- [
- ["s3://bucket/key", None],
- ["key", "bucket"],
- ]
- )
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
- def test_s3_key_sensor_execute_complete_error(self, key, bucket, mock_hook, mock_poke):
- """
- Asserts that a task is completed with error status.
- """
- mock_hook.check_for_key.return_value = False
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async",
- bucket_key=key,
- bucket_name=bucket,
- )
- with pytest.raises(AirflowException):
- sensor.execute_complete(
- context={}, event={"status": "error", "message": "mocked error", "soft_fail": False}
- )
-
- @parameterized.expand(
- [
- ["s3://bucket/key", None],
- ["key", "bucket"],
- ]
- )
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
- @mock.patch.object(S3KeySensorAsync, "defer")
- @mock.patch("astronomer.providers.amazon.aws.sensors.s3.S3KeyTrigger")
- def test_s3_key_sensor_async_with_mock_defer(
- self, key, bucket, mock_trigger, mock_defer, mock_hook, mock_poke
- ):
- """
- Asserts that a task is deferred and an S3KeyTrigger will be fired
- when the S3KeySensorAsync is executed.
- """
- mock_hook.check_for_key.return_value = False
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async",
- bucket_key=key,
- bucket_name=bucket,
- )
-
- sensor.execute(context=None)
-
- mock_defer.assert_called()
- mock_defer.assert_called_once_with(
- timeout=timedelta(days=7), trigger=mock_trigger.return_value, method_name="execute_complete"
- )
-
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key")
- def test_parse_bucket_key_from_jinja(self, mock_check, mock_poke):
- mock_check.return_value = False
-
- Variable.set("test_bucket_key", "s3://bucket/key")
-
- execution_date = timezone.datetime(2020, 1, 1)
-
- dag = DAG("test_s3_key", start_date=execution_date)
- op = S3KeySensorAsync(
- task_id="s3_key_sensor",
- bucket_key="s3://bucket/key",
- bucket_name=None,
- dag=dag,
- )
-
- dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date, run_id="test")
- ti = TaskInstance(task=op)
- ti.dag_run = dag_run
- context = ti.get_template_context()
- ti.render_templates(context)
-
- assert op.bucket_key == ["s3://bucket/key"]
- assert op.bucket_name is None
-
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
- def test_s3_key_sensor_with_wildcard_async(self, mock_hook, mock_poke, context):
- """
- Asserts that a task with wildcard=True is deferred and an S3KeyTrigger will be fired
- when the S3KeySensorAsync is executed.
- """
- mock_hook.check_for_key.return_value = False
-
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async", bucket_key="s3://test_bucket/file", wildcard_match=True
- )
-
- with pytest.raises(TaskDeferred) as exc:
- sensor.execute(context)
-
- assert isinstance(exc.value.trigger, S3KeyTrigger), "Trigger is not a S3KeyTrigger"
-
- def test_soft_fail(self):
- """Raise AirflowSkipException in case soft_fail is true"""
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async", bucket_key="key", bucket_name="bucket", soft_fail=True
- )
- with pytest.raises(AirflowSkipException):
- sensor.execute_complete(
- context={}, event={"status": "error", "message": "mocked error", "soft_fail": True}
- )
-
- @pytest.mark.parametrize(
- "soft_fail,exception",
- [
- (True, AirflowSkipException),
- (False, Exception),
- ],
- )
- @mock.patch(f"{MODULE}.S3KeySensorAsync.poke")
- def test_execute_handle_exception(self, mock_poke, soft_fail, exception):
- mock_poke.side_effect = Exception()
- sensor = S3KeySensorAsync(
- task_id="s3_key_sensor_async", bucket_key="key", bucket_name="bucket", soft_fail=soft_fail
- )
- with pytest.raises(exception):
- sensor.execute(context={})
-
- def test_soft_fail_enable(self, context):
- """Sensor should raise AirflowSkipException if soft_fail is True and error occur"""
- sensor = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket", soft_fail=True)
- with pytest.raises(AirflowSkipException):
- sensor.execute(context)
+ def test_init(self):
+ task = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket")
+ assert isinstance(task, S3KeySensor)
+ assert task.deferrable is True
class TestS3KeysUnchangedSensorAsync: