diff --git a/astronomer/providers/amazon/aws/hooks/batch_client.py b/astronomer/providers/amazon/aws/hooks/batch_client.py index 8e7e37b31..a2ae392b5 100644 --- a/astronomer/providers/amazon/aws/hooks/batch_client.py +++ b/astronomer/providers/amazon/aws/hooks/batch_client.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import asyncio +import warnings from random import sample -from typing import Any, Dict, List, Optional, Union +from typing import Any import botocore.exceptions from airflow.exceptions import AirflowException @@ -13,6 +16,9 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync): """ Async client for AWS Batch services. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook` instead + :param max_retries: exponential back-off retries, 4200 = 48 hours; polling is only used when waiters is None :param status_retries: number of HTTP retries to get job status, 10; @@ -41,12 +47,20 @@ class BatchClientHookAsync(BatchClientHook, AwsBaseHookAsync): - `Exponential Backoff And Jitter `_ """ - def __init__(self, job_id: Optional[str], waiters: Any = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook`" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__(*args, **kwargs) self.job_id = job_id self.waiters = waiters - async def monitor_job(self) -> Union[Dict[str, str], None]: + async def monitor_job(self) -> dict[str, str] | None: """ Monitor an AWS Batch job monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout @@ -92,7 +106,7 @@ async def check_job_success(self, job_id: str) -> bool: # type: ignore[override raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}") @staticmethod - async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[override] + async def delay(delay: int | (float | None) = None) -> None: # type: ignore[override] """ Pause execution for ``delay`` seconds. @@ -116,7 +130,7 @@ async def delay(delay: Union[int, float, None] = None) -> None: # type: ignore[ delay = BatchClientHookAsync.add_jitter(delay) await asyncio.sleep(delay) - async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None: # type: ignore[override] + async def wait_for_job(self, job_id: str, delay: int | (float | None) = None) -> None: # type: ignore[override] """ Wait for Batch job to complete @@ -131,7 +145,7 @@ async def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) self.log.info("AWS Batch job (%s) has completed", job_id) async def poll_for_job_complete( # type: ignore[override] - self, job_id: str, delay: Union[int, float, None] = None + self, job_id: str, delay: int | (float | None) = None ) -> None: """ Poll for job completion. The status that indicates job completion @@ -150,7 +164,7 @@ async def poll_for_job_complete( # type: ignore[override] await self.poll_job_status(job_id, complete_status) async def poll_for_job_running( # type: ignore[override] - self, job_id: str, delay: Union[int, float, None] = None + self, job_id: str, delay: int | (float | None) = None ) -> None: """ Poll for job running. The status that indicates a job is running or @@ -172,7 +186,7 @@ async def poll_for_job_running( # type: ignore[override] running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE] await self.poll_job_status(job_id, running_status) - async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ignore[override] + async def get_job_description(self, job_id: str) -> dict[str, str]: # type: ignore[override] """ Get job description (using status_retries). @@ -210,7 +224,7 @@ async def get_job_description(self, job_id: str) -> Dict[str, str]: # type: ign ) await self.delay(pause) - async def poll_job_status(self, job_id: str, match_status: List[str]) -> bool: # type: ignore[override] + async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: # type: ignore[override] """ Poll for job status using an exponential back-off strategy (with max_retries). The Batch job status polled are: diff --git a/astronomer/providers/amazon/aws/hooks/redshift_data.py b/astronomer/providers/amazon/aws/hooks/redshift_data.py index d7b880c1f..0ae56bba1 100644 --- a/astronomer/providers/amazon/aws/hooks/redshift_data.py +++ b/astronomer/providers/amazon/aws/hooks/redshift_data.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import warnings from typing import Any, Iterable import botocore.exceptions @@ -18,6 +19,9 @@ class RedshiftDataHook(AwsBaseHook): RedshiftDataHook inherits from AwsBaseHook to connect with AWS redshift by using boto3 client_type as redshift-data we can interact with redshift cluster database and execute the query + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook` instead + :param aws_conn_id: The Airflow connection used for AWS credentials. If this is None or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or @@ -34,6 +38,15 @@ class RedshiftDataHook(AwsBaseHook): """ def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`" + ), + DeprecationWarning, + stacklevel=2, + ) + aws_connection_type: str = "redshift-data" try: # for apache-airflow-providers-amazon>=3.0.0 diff --git a/astronomer/providers/amazon/aws/hooks/s3.py b/astronomer/providers/amazon/aws/hooks/s3.py index 1af72b360..57adb22aa 100644 --- a/astronomer/providers/amazon/aws/hooks/s3.py +++ b/astronomer/providers/amazon/aws/hooks/s3.py @@ -4,6 +4,7 @@ import fnmatch import os import re +import warnings from datetime import datetime from functools import wraps from inspect import signature @@ -43,12 +44,24 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: class S3HookAsync(AwsBaseHookAsync): - """Interact with AWS S3, using the aiobotocore library.""" + """Interact with AWS S3, using the aiobotocore library. + + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.hooks.s3.S3Hook` instead + """ conn_type = "s3" hook_name = "S3" def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.hooks.s3.S3Hook`" + ), + DeprecationWarning, + stacklevel=2, + ) kwargs["client_type"] = "s3" kwargs["resource_type"] = "s3" super().__init__(*args, **kwargs) diff --git a/astronomer/providers/amazon/aws/hooks/sagemaker.py b/astronomer/providers/amazon/aws/hooks/sagemaker.py index cb518b9c9..1396d250a 100644 --- a/astronomer/providers/amazon/aws/hooks/sagemaker.py +++ b/astronomer/providers/amazon/aws/hooks/sagemaker.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import time -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +import warnings +from typing import Any, AsyncGenerator from airflow.providers.amazon.aws.hooks.sagemaker import ( LogState, @@ -21,56 +24,67 @@ class SageMakerHookAsync(AwsBaseHookAsync): Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHookAsync. + + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook` instead """ NON_TERMINAL_STATES = ("InProgress", "Stopping", "Stopped") def __init__(self, *args: Any, **kwargs: Any): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook`" + ), + DeprecationWarning, + stacklevel=2, + ) kwargs["client_type"] = "sagemaker" super().__init__(*args, **kwargs) self.s3_hook = S3HookAsync(aws_conn_id=self.aws_conn_id) self.logs_hook_async = AwsLogsHookAsync(aws_conn_id=self.aws_conn_id) - async def describe_transform_job_async(self, job_name: str) -> Dict[str, Any]: + async def describe_transform_job_async(self, job_name: str) -> dict[str, Any]: """ Return the transform job info associated with the name :param job_name: the name of the transform job """ async with await self.get_client_async() as client: - response: Dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name) + response: dict[str, Any] = await client.describe_transform_job(TransformJobName=job_name) return response - async def describe_processing_job_async(self, job_name: str) -> Dict[str, Any]: + async def describe_processing_job_async(self, job_name: str) -> dict[str, Any]: """ Return the processing job info associated with the name :param job_name: the name of the processing job """ async with await self.get_client_async() as client: - response: Dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name) + response: dict[str, Any] = await client.describe_processing_job(ProcessingJobName=job_name) return response - async def describe_training_job_async(self, job_name: str) -> Dict[str, Any]: + async def describe_training_job_async(self, job_name: str) -> dict[str, Any]: """ Return the training job info associated with the name :param job_name: the name of the training job """ async with await self.get_client_async() as client: - response: Dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name) + response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name) return response async def describe_training_job_with_log( self, job_name: str, - positions: Dict[str, Any], - stream_names: List[str], + positions: dict[str, Any], + stream_names: list[str], instance_count: int, state: int, - last_description: Dict[str, Any], + last_description: dict[str, Any], last_describe_job_call: float, - ) -> Tuple[int, Dict[str, Any], float]: + ) -> tuple[int, dict[str, Any], float]: """ Return the training job info associated with job_name and print CloudWatch logs @@ -127,8 +141,8 @@ async def describe_training_job_with_log( return state, last_description, last_describe_job_call async def get_multi_stream( - self, log_group: str, streams: List[str], positions: Dict[str, Any] - ) -> AsyncGenerator[Any, Tuple[int, Optional[Any]]]: + self, log_group: str, streams: list[str], positions: dict[str, Any] + ) -> AsyncGenerator[Any, tuple[int, Any | None]]: """ Iterate over the available events coming from a set of log streams in a single log group interleaving the events from each stream so they're yielded in timestamp order. @@ -140,7 +154,7 @@ async def get_multi_stream( read from each stream. """ positions = positions or {s: Position(timestamp=0, skip=0) for s in streams} - events: list[Optional[Any]] = [] + events: list[Any | None] = [] event_iters = [ self.logs_hook_async.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip) diff --git a/astronomer/providers/amazon/aws/operators/batch.py b/astronomer/providers/amazon/aws/operators/batch.py index 2ec421998..5370ab75e 100644 --- a/astronomer/providers/amazon/aws/operators/batch.py +++ b/astronomer/providers/amazon/aws/operators/batch.py @@ -7,95 +7,29 @@ - `Batch `_ - `Welcome `_ """ -from typing import Any, Dict +from __future__ import annotations -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.operators.batch import BatchOperator +import warnings +from typing import Any -from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger -from astronomer.providers.utils.typing_compat import Context +from airflow.providers.amazon.aws.operators.batch import BatchOperator class BatchOperatorAsync(BatchOperator): """ - Execute a job asynchronously on AWS Batch - - .. see also:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:BatchOperator` - - :param job_name: the name for the job that will run on AWS Batch (templated) - :param job_definition: the job definition name on AWS Batch - :param job_queue: the queue name on AWS Batch - :param overrides: Removed in apache-airflow-providers-amazon release 8.0.0, use container_overrides instead with the - same value. - :param container_overrides: the `containerOverrides` parameter for boto3 (templated) - :param array_properties: the `arrayProperties` parameter for boto3 - :param parameters: the `parameters` for boto3 (templated) - :param job_id: the job ID, usually unknown (None) until the - submit_job operation gets the jobId defined by AWS Batch - :param waiters: an :py:class:`.BatchWaiters` object (see note below); - if None, polling is used with max_retries and status_retries. - :param max_retries: exponential back-off retries, 4200 = 48 hours; - polling is only used when waiters is None - :param status_retries: number of HTTP retries to get job status, 10; - polling is only used when waiters is None - :param aws_conn_id: connection id of AWS credentials / region name. If None, - credential boto3 strategy will be used. - :param region_name: region name to use in AWS Hook. - Override the region_name in connection (if provided) - :param tags: collection of tags to apply to the AWS Batch job submission - if None, no tags are submitted - - .. note:: - Any custom waiters must return a waiter for these calls: - - | ``waiter = waiters.get_waiter("JobExists")`` - | ``waiter = waiters.get_waiter("JobRunning")`` - | ``waiter = waiters.get_waiter("JobComplete")`` + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.operators.batch.BatchOperator` + and set `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> None: - """ - Airflow runs this method on the worker and defers using the trigger. - Submit the job and get the job_id using which we defer and poll in trigger - """ - self.submit_job(context) - if not self.job_id: - raise AirflowException("AWS Batch job - job_id was not found") - job = self.hook.get_job_description(self.job_id) - job_status = job.get("status") - - if job_status == self.hook.SUCCESS_STATE: - self.log.info(f"{self.job_id} was completed successfully") - return - - if job_status == self.hook.FAILURE_STATE: - raise AirflowException(f"{self.job_id} failed") - - if job_status in self.hook.INTERMEDIATE_STATES: - self.defer( - timeout=self.execution_timeout, - trigger=BatchOperatorTrigger( - job_id=self.job_id, - waiters=self.waiters, - max_retries=self.hook.max_retries, - aws_conn_id=self.hook.aws_conn_id, - region_name=self.hook.region_name, - ), - method_name="execute_complete", - ) - - raise AirflowException(f"Unexpected status: {job_status}") - - # Ignoring the override type check because the parent class specifies "context: Any" but specifying it as - # "context: Context" is accurate as it's more specific. - def execute_complete(self, context: Context, event: Dict[str, Any]) -> None: # type: ignore[override] - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if "status" in event and event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info(event["message"]) + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated." + "Please use `airflow.providers.amazon.aws.operators.batch.BatchOperator`" + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, + ) + super().__init__(deferrable=True, **kwargs) diff --git a/astronomer/providers/amazon/aws/operators/redshift_data.py b/astronomer/providers/amazon/aws/operators/redshift_data.py index 6dd05d2fa..9b6ebcb5d 100644 --- a/astronomer/providers/amazon/aws/operators/redshift_data.py +++ b/astronomer/providers/amazon/aws/operators/redshift_data.py @@ -1,75 +1,27 @@ +import warnings from typing import Any -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator -from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook -from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger -from astronomer.providers.utils.typing_compat import Context - class RedshiftDataOperatorAsync(RedshiftDataOperator): """ - Executes SQL Statements against an Amazon Redshift cluster. - If there are multiple queries as part of the SQL, and one of them fails to reach a successful completion state, - the operator returns the relevant error for the failed query. - - :param sql: the SQL code to be executed as a single string, or - a list of str (sql statements), or a reference to a template file. - Template references are recognized by str ending in '.sql' - :param aws_conn_id: AWS connection ID - :param parameters: (optional) the parameters to render the SQL query with. - :param autocommit: if True, each command is automatically committed. - (default value: False) + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator` + and set `deferrable` param to `True` instead. """ def __init__( self, - *, - poll_interval: int = 5, **kwargs: Any, ) -> None: - self.poll_interval = poll_interval - super().__init__(**kwargs) - - def execute(self, context: Context) -> None: - """ - Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and - defers trigger to poll for the status for the queries executed. - """ - redshift_data_hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id) - query_ids, response = redshift_data_hook.execute_query(sql=self.sql, params=self.params) - self.log.info("Query IDs %s", query_ids) - if response.get("status") == "error": - self.execute_complete(context, event=response) - context["ti"].xcom_push(key="return_value", value=query_ids) - - if redshift_data_hook.queries_are_completed(query_ids, context): - self.log.info("%s completed successfully.", self.task_id) - return - - self.defer( - timeout=self.execution_timeout, - trigger=RedshiftDataTrigger( - task_id=self.task_id, - poll_interval=self.poll_interval, - aws_conn_id=self.aws_conn_id, - query_ids=query_ids, + warnings.warn( + ( + "This module is deprecated." + "Please use `airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`" + "and set `deferrable` param to `True` instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - def execute_complete(self, context: Context, event: Any = None) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if "status" in event and event["status"] == "error": - msg = "context: {}, error message: {}".format(context, event["message"]) - raise AirflowException(msg) - elif "status" in event and event["status"] == "success": - self.log.info("%s completed successfully.", self.task_id) - else: - raise AirflowException("Did not receive valid event from the trigerrer") + super().__init__(deferrable=True, **kwargs) diff --git a/astronomer/providers/amazon/aws/operators/sagemaker.py b/astronomer/providers/amazon/aws/operators/sagemaker.py index 038afd8fa..ed8a7edf4 100644 --- a/astronomer/providers/amazon/aws/operators/sagemaker.py +++ b/astronomer/providers/amazon/aws/operators/sagemaker.py @@ -2,6 +2,7 @@ import json import time +import warnings from typing import Any from airflow.exceptions import AirflowException @@ -17,7 +18,6 @@ from airflow.utils.json import AirflowJsonEncoder from astronomer.providers.amazon.aws.triggers.sagemaker import ( - SagemakerProcessingTrigger, SagemakerTrainingWithLogTrigger, SagemakerTrigger, ) @@ -31,216 +31,42 @@ def serialize(result: dict[str, Any]) -> str: class SageMakerProcessingOperatorAsync(SageMakerProcessingOperator): """ - SageMakerProcessingOperatorAsync is used to analyze data and evaluate machine learning - models on Amazon SageMaker. With SageMakerProcessingOperatorAsync, you can use a simplified, managed - experience on SageMaker to run your data processing workloads, such as feature - engineering, data validation, model evaluation, and model interpretation. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:SageMakerProcessingOperator` - - :param config: The configuration necessary to start a processing job (templated). - For details of the configuration parameter see - :ref:``SageMaker.Client.create_processing_job`` - :param aws_conn_id: The AWS connection ID to use. - :param wait_for_completion: Even if wait is set to False, in async we will defer and - the operation waits to check the status of the processing job. - :param print_log: if the operator should print the cloudwatch log during processing - :param check_interval: if wait is set to be true, this is the time interval - in seconds which the operator will check the status of the processing job - :param max_ingestion_time: The operation fails if the processing job - doesn't finish within max_ingestion_time seconds. If you set this parameter to None, - the operation does not timeout. - :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" - (default) and "fail". + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator` + and set `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> dict[str, str] | None: # type: ignore[override] - """ - Creates processing job via sync hook `create_processing_job` and pass the - control to trigger and polls for the status of the processing job in async - """ - self.preprocess_config() - processing_job_name = self.config["ProcessingJobName"] - try: - if self.hook.count_processing_jobs_by_name(processing_job_name): - raise AirflowException( - f"A SageMaker processing job with name {processing_job_name} already exists." - ) - except AttributeError: # pragma: no cover - # For apache-airflow-providers-amazon<8.0.0 - if self.hook.find_processing_job_by_name(processing_job_name): - raise AirflowException( - f"A SageMaker processing job with name {processing_job_name} already exists." - ) - self.log.info("Creating SageMaker processing job %s.", self.config["ProcessingJobName"]) - response = self.hook.create_processing_job( - self.config, - # we do not wait for completion here but we create the processing job - # and poll for it in trigger - wait_for_completion=False, - ) - if response["ResponseMetadata"]["HTTPStatusCode"] != 200: - raise AirflowException(f"Sagemaker Processing Job creation failed: {response}") - - response = self.hook.describe_processing_job(processing_job_name) - status = response["ProcessingJobStatus"] - if status in self.hook.failed_states: - raise AirflowException(f"SageMaker job failed because {response['FailureReason']}") - elif status == "Completed": - self.log.info(f"{self.task_id} completed successfully.") - return {"Processing": serialize(response)} - - end_time: float | None = None - if self.max_ingestion_time is not None: - end_time = time.time() + self.max_ingestion_time - self.defer( - timeout=self.execution_timeout, - trigger=SagemakerProcessingTrigger( - poll_interval=self.check_interval, - aws_conn_id=self.aws_conn_id, - job_name=self.config["ProcessingJobName"], - end_time=end_time, + def __init__(self, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated." + "Please use `airflow.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperator`" + "and set `deferrable` param to `True` instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - # for bypassing mypy missing return error - return None # pragma: no cover - - def execute_complete(self, context: Context, event: Any = None) -> dict[str, Any]: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event and event["status"] == "success": - self.log.info("%s completed successfully.", self.task_id) - return {"Processing": serialize(event["message"])} - if event and event["status"] == "error": - raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + super().__init__(deferrable=True, **kwargs) class SageMakerTransformOperatorAsync(SageMakerTransformOperator): """ - SageMakerTransformOperatorAsync starts a transform job and polls for the status asynchronously. - A transform job uses a trained model to get inferences on a dataset and saves these results to an Amazon - S3 location that you specify. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:``howto/operator:SageMakerTransformOperator`` - - :param config: The configuration necessary to start a transform job (templated). - - If you need to create a SageMaker transform job based on an existed SageMaker model:: - - config = transform_config - - If you need to create both SageMaker model and SageMaker Transform job:: - - config = { - 'Model': model_config, - 'Transform': transform_config - } - - For details of the configuration parameter of transform_config see - :ref:``SageMaker.Client.create_transform_job`` - - For details of the configuration parameter of model_config, See: - :ref:``SageMaker.Client.create_model`` - - :param aws_conn_id: The AWS connection ID to use. - :param check_interval: If wait is set to True, the time interval, in seconds, - that this operation waits to check the status of the transform job. - :param max_ingestion_time: The operation fails if the transform job doesn't finish - within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. - :param check_if_job_exists: If set to true, then the operator will check whether a transform job - already exists for the name in the config. - :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" - (default) and "fail". - This is only relevant if check_if_job_exists is True. + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator` + and set `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> dict[str, Any] | None: # type: ignore[override] - """ - Creates transform job via sync hook `create_transform_job` and pass the - control to trigger and polls for the status of the transform job in async - """ - self.preprocess_config() - model_config = self.config.get("Model") - transform_config = self.config.get("Transform", self.config) - if self.check_if_job_exists: # pragma: no cover - try: - # for apache-airflow-providers-amazon<=7.2.1 - self._check_if_transform_job_exists() # type: ignore[attr-defined] - except AttributeError: - # for apache-airflow-providers-amazon>=7.3.0 - transform_config["TransformJobName"] = self._get_unique_job_name( - transform_config["TransformJobName"], - self.action_if_job_exists == "fail", - self.hook.describe_transform_job, - ) - if model_config: - self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"]) - self.hook.create_model(model_config) - self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"]) - response = self.hook.create_transform_job( - transform_config, - wait_for_completion=False, - ) - if response["ResponseMetadata"]["HTTPStatusCode"] != 200: - raise AirflowException(f"Sagemaker transform Job creation failed: {response}") - - response = self.hook.describe_transform_job(transform_config["TransformJobName"]) - status = response["TransformJobStatus"] - if status in self.hook.failed_states: - raise AirflowException(f"SageMaker job failed because {response['FailureReason']}") - - if status == "Completed": - self.log.info(f"{self.task_id} completed successfully.") - return { - "Model": serialize(self.hook.describe_model(transform_config["ModelName"])), - "Transform": serialize(response), - } - - end_time: float | None = None - if self.max_ingestion_time is not None: - end_time = time.time() + self.max_ingestion_time - self.defer( - timeout=self.execution_timeout, - trigger=SagemakerTrigger( - poke_interval=self.check_interval, - end_time=end_time, - aws_conn_id=self.aws_conn_id, - job_name=transform_config["TransformJobName"], - job_type="Transform", - response_key="TransformJobStatus", + def __init__(self, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated." + "Please use `airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator`" + "and set `deferrable` param to `True` instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - # for bypassing mypy missing return error - return None # pragma: no cover - - def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]: # type: ignore[override] - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event and event["status"] == "success": - self.log.info("%s completed successfully.", self.task_id) - transform_config = self.config.get("Transform", self.config) - return { - "Model": serialize(self.hook.describe_model(transform_config["ModelName"])), - "Transform": serialize(event["message"]), - } - if event and event["status"] == "error": - raise AirflowException(event["message"]) - raise AirflowException("No event received in trigger callback") + super().__init__(deferrable=True, **kwargs) class SageMakerTrainingOperatorAsync(SageMakerTrainingOperator): diff --git a/astronomer/providers/amazon/aws/sensors/redshift_cluster.py b/astronomer/providers/amazon/aws/sensors/redshift_cluster.py index 48f138b71..6f919bf22 100644 --- a/astronomer/providers/amazon/aws/sensors/redshift_cluster.py +++ b/astronomer/providers/amazon/aws/sensors/redshift_cluster.py @@ -1,71 +1,36 @@ import warnings -from datetime import timedelta -from typing import Any, Dict, Optional +from typing import Any from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor -from astronomer.providers.amazon.aws.triggers.redshift_cluster import ( - RedshiftClusterSensorTrigger, -) -from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception -from astronomer.providers.utils.typing_compat import Context - class RedshiftClusterSensorAsync(RedshiftClusterSensor): """ - Waits for a Redshift cluster to reach a specific status. - - :param cluster_identifier: The identifier for the cluster being pinged.\ - :param target_status: The cluster status desired. + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor` + and set `deferrable` param to `True` instead. """ def __init__( self, - *, - poll_interval: float = 5, **kwargs: Any, ): # TODO: Remove once deprecated - if poll_interval: - self.poke_interval = poll_interval + if kwargs.get("poll_interval"): + kwargs["poke_interval"] = kwargs["poll_interval"] warnings.warn( "Argument `poll_interval` is deprecated and will be removed " "in a future release. Please use `poke_interval` instead.", DeprecationWarning, stacklevel=2, ) - super().__init__(**kwargs) - - def execute(self, context: Context) -> None: - """Check for the target_status and defers using the trigger""" - if not poke(self, context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=RedshiftClusterSensorTrigger( - task_id=self.task_id, - aws_conn_id=self.aws_conn_id, - cluster_identifier=self.cluster_identifier, - target_status=self.target_status, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if "status" in event and event["status"] == "error": - msg = "{}: {}".format(event["status"], event["message"]) - raise_error_or_skip_exception(self.soft_fail, msg) - if "status" in event and event["status"] == "success": - self.log.info("%s completed successfully.", self.task_id) - self.log.info( - "Cluster Identifier %s is in %s state", self.cluster_identifier, self.target_status - ) - return None - self.log.info("%s completed successfully.", self.task_id) - return None + warnings.warn( + ( + "This module is deprecated." + "Please use `airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor`" + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, + ) + super().__init__(deferrable=True, **kwargs) diff --git a/astronomer/providers/amazon/aws/sensors/s3.py b/astronomer/providers/amazon/aws/sensors/s3.py index a4e902060..ab0858c7e 100644 --- a/astronomer/providers/amazon/aws/sensors/s3.py +++ b/astronomer/providers/amazon/aws/sensors/s3.py @@ -1,131 +1,30 @@ from __future__ import annotations import warnings -from datetime import timedelta -from typing import Any, Callable, Sequence, cast +from typing import Any, Callable -from airflow.exceptions import AirflowSkipException from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor from airflow.sensors.base import BaseSensorOperator -from astronomer.providers.amazon.aws.triggers.s3 import ( - S3KeyTrigger, -) -from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception -from astronomer.providers.utils.typing_compat import Context - class S3KeySensorAsync(S3KeySensor): """ - Waits for one or multiple keys (a file-like instance on S3) to be present in a S3 bucket. - S3 being a key/value it does not support folders. The path is just a key - a resource. - - :param bucket_key: The key(s) being waited on. Supports full s3:// style url - or relative path from root level. When it's specified as a full s3:// - url, please leave bucket_name as `None` - :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` - is not provided as a full s3:// url. When specified, all the keys passed to ``bucket_key`` - refers to this bucket - :param wildcard_match: whether the bucket_key should be interpreted as a - Unix wildcard pattern - :param use_regex: whether to use regex to check bucket - :param check_fn: Function that receives the list of the S3 objects, - and returns a boolean: - - ``True``: the criteria is met - - ``False``: the criteria isn't met - **Example**: Wait for any S3 object size more than 1 megabyte :: - - def check_fn(files: List) -> bool: - return any(f.get('Size', 0) > 1048576 for f in files) - :param aws_conn_id: a reference to the s3 connection - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. SSL will still be used - (unless use_ssl is False), but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. - - .. seealso:: - `For more information on how to use this sensor, take a look at the guide: - :ref:`howto/sensor:S3KeySensor`. + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.sensors.s3.S3KeySensor` + and set `deferrable` param to `True` instead. """ - template_fields: Sequence[str] = ("bucket_key", "bucket_name") - - def __init__( - self, - *, - bucket_key: str | list[str], - bucket_name: str | None = None, - wildcard_match: bool = False, - use_regex: bool = False, - check_fn: Callable[..., bool] | None = None, - aws_conn_id: str = "aws_default", - verify: str | bool | None = None, - **kwargs: Any, - ): - self.bucket_key: list[str] = [bucket_key] if isinstance(bucket_key, str) else bucket_key - self.use_regex = use_regex - super().__init__( - bucket_name=bucket_name, - bucket_key=self.bucket_key, - wildcard_match=wildcard_match, - check_fn=check_fn, - aws_conn_id=aws_conn_id, - verify=verify, - **kwargs, - ) - self.check_fn = check_fn - self.should_check_fn = True if check_fn else False - - def execute(self, context: Context) -> None: - """Check for a keys in s3 and defers using the trigger""" - try: - poke = self.poke(context) - except Exception as e: - if self.soft_fail: - raise AirflowSkipException(str(e)) - else: - raise e - if not poke: - self._defer() - - def _defer(self) -> None: - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=S3KeyTrigger( - bucket_name=cast(str, self.bucket_name), - bucket_key=self.bucket_key, - wildcard_match=self.wildcard_match, - use_regex=self.use_regex, - aws_conn_id=self.aws_conn_id, - verify=self.verify, - poke_interval=self.poke_interval, - soft_fail=self.soft_fail, - should_check_fn=self.should_check_fn, + def __init__(self, **kwargs: Any) -> None: + warnings.warn( + ( + "This module is deprecated." + "Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`" + "and set `deferrable` param to `True` instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - def execute_complete(self, context: Context, event: Any = None) -> bool | None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "running": - if self.check_fn(event["files"]): # type: ignore[misc] - return None - else: - self._defer() - if event["status"] == "error": - raise_error_or_skip_exception(self.soft_fail, event["message"]) - return None + super().__init__(deferrable=True, **kwargs) class S3KeySizeSensorAsync(S3KeySensorAsync): diff --git a/astronomer/providers/amazon/aws/triggers/batch.py b/astronomer/providers/amazon/aws/triggers/batch.py index 01b1be301..9c55e7620 100644 --- a/astronomer/providers/amazon/aws/triggers/batch.py +++ b/astronomer/providers/amazon/aws/triggers/batch.py @@ -1,4 +1,7 @@ -from typing import Any, AsyncIterator, Dict, Optional, Tuple +from __future__ import annotations + +import warnings +from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -10,6 +13,10 @@ class BatchOperatorTrigger(BaseTrigger): Checks for the state of a previously submitted job to AWS Batch. BatchOperatorTrigger is fired as deferred class with params to poll the job state in Triggerer + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger` instead + + :param job_id: the job ID, usually unknown (None) until the submit_job operation gets the jobId defined by AWS Batch :param waiters: a :class:`.BatchWaiters` object (see note below); @@ -24,12 +31,20 @@ class BatchOperatorTrigger(BaseTrigger): def __init__( self, - job_id: Optional[str], + job_id: str | None, waiters: Any, max_retries: int, - region_name: Optional[str], - aws_conn_id: Optional[str] = "aws_default", + region_name: str | None, + aws_conn_id: str | None = "aws_default", ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.job_id = job_id self.waiters = waiters @@ -37,7 +52,7 @@ def __init__( self.aws_conn_id = aws_conn_id self.region_name = region_name - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes BatchOperatorTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.batch.BatchOperatorTrigger", @@ -50,7 +65,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Make async connection using aiobotocore library to AWS Batch, periodically poll for the job status on the Triggerer diff --git a/astronomer/providers/amazon/aws/triggers/redshift_cluster.py b/astronomer/providers/amazon/aws/triggers/redshift_cluster.py index 9b0f3e75c..f5a032774 100644 --- a/astronomer/providers/amazon/aws/triggers/redshift_cluster.py +++ b/astronomer/providers/amazon/aws/triggers/redshift_cluster.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import asyncio import warnings -from typing import Any, AsyncIterator, Dict, Optional, Tuple +from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -29,7 +31,7 @@ def __init__( operation_type: str, polling_period_seconds: float = 5.0, skip_final_cluster_snapshot: bool = True, - final_cluster_snapshot_identifier: Optional[str] = None, + final_cluster_snapshot_identifier: str | None = None, ): warnings.warn( ( @@ -48,7 +50,7 @@ def __init__( self.skip_final_cluster_snapshot = skip_final_cluster_snapshot self.final_cluster_snapshot_identifier = final_cluster_snapshot_identifier - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes RedshiftClusterTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger", @@ -63,7 +65,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Make async connection to redshift, based on the operation type call the RedshiftHookAsync functions @@ -112,6 +114,9 @@ class RedshiftClusterSensorTrigger(BaseTrigger): """ RedshiftClusterSensorTrigger is fired as deferred class with params to run the task in trigger worker + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger` instead + :param task_id: Reference to task id of the Dag :param aws_conn_id: Reference to AWS connection id for redshift :param cluster_identifier: unique identifier of a cluster @@ -127,6 +132,14 @@ def __init__( target_status: str, poke_interval: float, ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.task_id = task_id self.aws_conn_id = aws_conn_id @@ -134,7 +147,7 @@ def __init__( self.target_status = target_status self.poke_interval = poke_interval - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes RedshiftClusterSensorTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterSensorTrigger", @@ -147,7 +160,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """Simple async function run until the cluster status match the target status.""" try: hook = RedshiftHookAsync(aws_conn_id=self.aws_conn_id) diff --git a/astronomer/providers/amazon/aws/triggers/redshift_data.py b/astronomer/providers/amazon/aws/triggers/redshift_data.py index c3e6790c9..540135fd7 100644 --- a/astronomer/providers/amazon/aws/triggers/redshift_data.py +++ b/astronomer/providers/amazon/aws/triggers/redshift_data.py @@ -1,4 +1,7 @@ -from typing import Any, AsyncIterator, Dict, List, Tuple +from __future__ import annotations + +import warnings +from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -9,6 +12,9 @@ class RedshiftDataTrigger(BaseTrigger): """ RedshiftDataTrigger is fired as deferred class with params to run the task in triggerer. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger` instead + :param task_id: task ID of the Dag :param poll_interval: polling period in seconds to check for the status :param aws_conn_id: AWS connection ID for redshift @@ -19,16 +25,25 @@ def __init__( self, task_id: str, poll_interval: int, - query_ids: List[str], + query_ids: list[str], aws_conn_id: str = "aws_default", ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) + super().__init__() self.task_id = task_id self.poll_interval = poll_interval self.aws_conn_id = aws_conn_id self.query_ids = query_ids - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes RedshiftDataTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger", @@ -40,7 +55,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Makes async connection and gets status for a list of queries submitted by the operator. Even if one of the queries has a non-successful state, the hook returns a failure event and the error diff --git a/astronomer/providers/amazon/aws/triggers/s3.py b/astronomer/providers/amazon/aws/triggers/s3.py index 20f8157c0..32b77031b 100644 --- a/astronomer/providers/amazon/aws/triggers/s3.py +++ b/astronomer/providers/amazon/aws/triggers/s3.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import warnings from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -12,6 +13,9 @@ class S3KeyTrigger(BaseTrigger): """ S3KeyTrigger is fired as deferred class with params to run the task in trigger worker + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger` instead + :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` is not provided as a full s3:// url. :param bucket_key: The key being waited on. Supports full s3:// style url @@ -37,6 +41,14 @@ def __init__( should_check_fn: bool = False, **hook_params: Any, ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.bucket_name = bucket_name self.bucket_key = bucket_key diff --git a/astronomer/providers/amazon/aws/triggers/sagemaker.py b/astronomer/providers/amazon/aws/triggers/sagemaker.py index f3aa41e06..f311d3a42 100644 --- a/astronomer/providers/amazon/aws/triggers/sagemaker.py +++ b/astronomer/providers/amazon/aws/triggers/sagemaker.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import asyncio import time -from typing import Any, AsyncIterator, Dict, List, Optional, Tuple +import warnings +from typing import Any, AsyncIterator from airflow.providers.amazon.aws.hooks.sagemaker import LogState from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -12,6 +15,9 @@ class SagemakerProcessingTrigger(BaseTrigger): """ SagemakerProcessingTrigger is fired as deferred class with params to run the task in triggerer. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger` instead + :param job_name: name of the job to check status :param poll_interval: polling period in seconds to check for the status :param aws_conn_id: AWS connection ID for sagemaker @@ -26,16 +32,24 @@ def __init__( self, job_name: str, poll_interval: float, - end_time: Optional[float], + end_time: float | None, aws_conn_id: str = "aws_default", ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.job_name = job_name self.poll_interval = poll_interval self.aws_conn_id = aws_conn_id self.end_time = end_time - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes SagemakerProcessingTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerProcessingTrigger", @@ -47,7 +61,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator. Trigger returns a failure event if any error and success in state return the success event. @@ -80,6 +94,9 @@ class SagemakerTrigger(BaseTrigger): SagemakerTrigger is common trigger for both transform and training sagemaker job and it is fired as deferred class with params to run the task in triggerer. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger` instead + :param job_name: name of the job to check status :param job_type: Type of the sagemaker job whether it is Transform or Training :param response_key: The key which needs to be look in the response. @@ -97,9 +114,17 @@ def __init__( job_type: str, response_key: str, poke_interval: float, - end_time: Optional[float] = None, + end_time: float | None = None, aws_conn_id: str = "aws_default", ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.hooks.sagemaker.SageMakerTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.job_name = job_name self.job_type = job_type @@ -108,7 +133,7 @@ def __init__( self.end_time = end_time self.aws_conn_id = aws_conn_id - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes SagemakerTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrigger", @@ -122,7 +147,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator. Trigger returns a failure event if any error and success in state return the success event. @@ -148,7 +173,7 @@ def _get_async_hook(self) -> SageMakerHookAsync: return SageMakerHookAsync(aws_conn_id=self.aws_conn_id) @staticmethod - async def get_job_status(hook: SageMakerHookAsync, job_name: str, job_type: str) -> Dict[str, Any]: + async def get_job_status(hook: SageMakerHookAsync, job_name: str, job_type: str) -> dict[str, Any]: """ Based on the job type the SageMakerHookAsync connect to sagemaker related function and get the response of the job and return it @@ -181,7 +206,7 @@ def __init__( instance_count: int, status: str, poke_interval: float, - end_time: Optional[float] = None, + end_time: float | None = None, aws_conn_id: str = "aws_default", ): super().__init__() @@ -192,7 +217,7 @@ def __init__( self.end_time = end_time self.aws_conn_id = aws_conn_id - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes SagemakerTrainingWithLogTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.sagemaker.SagemakerTrainingWithLogTrigger", @@ -206,7 +231,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Makes async connection to sagemaker async hook and gets job status for a job submitted by the operator. Trigger returns a failure event if any error and success in state return the success event. @@ -214,8 +239,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: hook = self._get_async_hook() last_description = await hook.describe_training_job_async(self.job_name) - stream_names: List[str] = [] # The list of log streams - positions: Dict[str, Any] = {} # The current position in each stream, map of stream name -> position + stream_names: list[str] = [] # The list of log streams + positions: dict[str, Any] = {} # The current position in each stream, map of stream name -> position job_already_completed = self.status not in self.NON_TERMINAL_STATES diff --git a/setup.cfg b/setup.cfg index 1e68e119c..493e2e03d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ zip_safe = false [options.extras_require] amazon = - apache-airflow-providers-amazon>=8.16.0 + apache-airflow-providers-amazon>=8.17.0 aiobotocore>=2.1.1 apache.hive = apache-airflow-providers-apache-hive>=6.1.5 @@ -118,7 +118,7 @@ mypy = # All extras from above except 'mypy', 'docs' and 'tests' all = aiobotocore>=2.1.1 - apache-airflow-providers-amazon>=8.16.0 + apache-airflow-providers-amazon>=8.17.0 apache-airflow-providers-apache-hive>=6.1.5 apache-airflow-providers-apache-livy>=3.7.1 apache-airflow-providers-cncf-kubernetes>=4 diff --git a/tests/amazon/aws/operators/test_batch.py b/tests/amazon/aws/operators/test_batch.py index e8e798cc7..5c3370716 100644 --- a/tests/amazon/aws/operators/test_batch.py +++ b/tests/amazon/aws/operators/test_batch.py @@ -1,161 +1,17 @@ -from unittest import mock - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook +from airflow.providers.amazon.aws.operators.batch import BatchOperator from astronomer.providers.amazon.aws.operators.batch import BatchOperatorAsync -from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger -from tests.utils.airflow_util import create_context class TestBatchOperatorAsync: - JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3" - JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19" - MAX_RETRIES = 2 - STATUS_RETRIES = 3 - RESPONSE_WITHOUT_FAILURES = { - "jobName": JOB_NAME, - "jobId": JOB_ID, - } - - @mock.patch("astronomer.providers.amazon.aws.operators.batch.BatchOperatorAsync.defer") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") - def test_batch_op_async_succeeded_before_defer(self, get_client_type_mock, get_job_description, defer): - get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES - get_job_description.return_value = {"status": BatchClientHook.SUCCESS_STATE} - task = BatchOperatorAsync( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="aws_default", - region_name="eu-west-1", - tags={}, - ) - context = create_context(task) - task.execute(context) - assert not defer.called - - @pytest.mark.parametrize("status", (BatchClientHook.FAILURE_STATE, "Unexpected status")) - @mock.patch("astronomer.providers.amazon.aws.operators.batch.BatchOperatorAsync.defer") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") - def test_batch_op_async_failed_before_defer( - self, get_client_type_mock, get_job_description, defer, status - ): - get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES - get_job_description.return_value = {"status": status} - task = BatchOperatorAsync( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="aws_default", - region_name="eu-west-1", - tags={}, - ) - context = create_context(task) - with pytest.raises(AirflowException): - task.execute(context) - assert not defer.called - - @pytest.mark.parametrize("status", BatchClientHook.INTERMEDIATE_STATES) - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientHook.get_job_description") - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") - def test_batch_op_async(self, get_client_type_mock, get_job_description, status): - get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES - get_job_description.return_value = {"status": status} - task = BatchOperatorAsync( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="aws_default", - region_name="eu-west-1", - tags={}, - ) - context = create_context(task) - with pytest.raises(TaskDeferred) as exc: - task.execute(context) - assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger" - - def test_batch_op_async_execute_failure(self, context): - """Tests that an AirflowException is raised in case of error event""" - - task = BatchOperatorAsync( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="aws_default", - region_name="eu-west-1", - tags={}, - ) - with pytest.raises(AirflowException) as exc_info: - task.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) - - assert str(exc_info.value) == "test failure message" - - @pytest.mark.parametrize( - "event", - [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}], - ) - def test_batch_op_async_execute_complete(self, caplog, event): - """Tests that execute_complete method returns None and that it prints expected log""" - task = BatchOperatorAsync( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="aws_default", - region_name="eu-west-1", - tags={}, - ) - with mock.patch.object(task.log, "info") as mock_log_info: - assert task.execute_complete(context=None, event=event) is None - - mock_log_info.assert_called_with(f"AWS Batch job ({self.JOB_ID}) succeeded") - - @mock.patch("astronomer.providers.amazon.aws.operators.batch.BatchOperatorAsync.submit_job") - def test_batch_op_raises_exception_before_deferral_if_job_id_unset(self, mock_submit_job): - """ - Test that an AirflowException is raised if job_id is not set before deferral by mocking the submit_job - method which sets the job_id attribute of the instance. - """ + def test_init(self): task = BatchOperatorAsync( task_id="task", - job_name=self.JOB_NAME, + job_name="51455483-c62c-48ac-9b88-53a6a725baa3", job_queue="queue", job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, + max_retries=2, + status_retries=3, parameters=None, overrides={}, array_properties=None, @@ -163,7 +19,5 @@ def test_batch_op_raises_exception_before_deferral_if_job_id_unset(self, mock_su region_name="eu-west-1", tags={}, ) - context = create_context(task) - with pytest.raises(AirflowException) as exc: - task.execute(context) - assert "AWS Batch job - job_id was not found" in str(exc.value) + assert isinstance(task, BatchOperator) + assert task.deferrable is True diff --git a/tests/amazon/aws/operators/test_redshift_data.py b/tests/amazon/aws/operators/test_redshift_data.py index 5cff1b4f8..8c2860695 100644 --- a/tests/amazon/aws/operators/test_redshift_data.py +++ b/tests/amazon/aws/operators/test_redshift_data.py @@ -1,101 +1,16 @@ -from unittest import mock - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator from astronomer.providers.amazon.aws.operators.redshift_data import ( RedshiftDataOperatorAsync, ) -from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger -from tests.utils.airflow_util import create_context class TestRedshiftDataOperatorAsync: - DATABASE_NAME = "TEST_DATABASE" - TASK_ID = "fetch_data" - SQL_QUERY = "select * from any" - TASK = RedshiftDataOperatorAsync( - task_id=TASK_ID, - sql=SQL_QUERY, - database=DATABASE_NAME, - ) - - @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_finished_before_deferred(self, mock_execute, mock_conn, mock_defer): - mock_execute.return_value = ["test_query_id"], {} - mock_conn.describe_statement.return_value = { - "Status": "FINISHED", - } - self.TASK.execute(create_context(self.TASK)) - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_aborted_before_deferred(self, mock_execute, mock_conn, mock_defer): - mock_execute.return_value = ["test_query_id"], {} - mock_conn.describe_statement.return_value = {"Status": "ABORTED"} - - with pytest.raises(AirflowException): - self.TASK.execute(create_context(self.TASK)) - - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_failed_before_deferred(self, mock_execute, mock_conn, mock_defer): - mock_execute.return_value = ["test_query_id"], {} - mock_conn.describe_statement.return_value = { - "Status": "FAILED", - "QueryString": "test query", - "Error": "test error", - } - - with pytest.raises(AirflowException): - self.TASK.execute(create_context(self.TASK)) - - assert not mock_defer.called - - @pytest.mark.parametrize("status", ("SUBMITTED", "PICKED", "STARTED")) - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async(self, mock_execute, mock_conn, status): - mock_execute.return_value = ["test_query_id"], {} - mock_conn.describe_statement.return_value = {"Status": status} - - with pytest.raises(TaskDeferred) as exc: - self.TASK.execute(create_context(self.TASK)) - assert isinstance(exc.value.trigger, RedshiftDataTrigger), "Trigger is not a RedshiftDataTrigger" - - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_execute_query_error(self, mock_execute, context): - mock_execute.return_value = [], {"status": "error", "message": "Test exception"} - with pytest.raises(AirflowException): - self.TASK.execute(context) - - def test_redshift_data_op_async_execute_failure(self, context): - """Tests that an AirflowException is raised in case of error event""" - - with pytest.raises(AirflowException): - self.TASK.execute_complete( - context=None, event={"status": "error", "message": "test failure message"} - ) - - @pytest.mark.parametrize( - "event", - [None, {"status": "success", "message": "Job completed"}], - ) - def test_redshift_data_op_async_execute_complete(self, event): - """Asserts that logging occurs as expected""" - - if not event: - with pytest.raises(AirflowException) as exception_info: - self.TASK.execute_complete(context=None, event=None) - assert exception_info.value.args[0] == "Did not receive valid event from the trigerrer" - else: - with mock.patch.object(self.TASK.log, "info") as mock_log_info: - self.TASK.execute_complete(context=None, event=event) - mock_log_info.assert_called_with("%s completed successfully.", self.TASK_ID) + def test_init(self): + task = RedshiftDataOperatorAsync( + task_id="fetch_data", + sql="select * from any", + database="TEST_DATABASE", + ) + assert isinstance(task, RedshiftDataOperator) + assert task.deferrable is True diff --git a/tests/amazon/aws/operators/test_sagemaker.py b/tests/amazon/aws/operators/test_sagemaker.py index 8b83019ed..8f25c50e4 100644 --- a/tests/amazon/aws/operators/test_sagemaker.py +++ b/tests/amazon/aws/operators/test_sagemaker.py @@ -4,6 +4,10 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook from airflow.providers.amazon.aws.operators import sagemaker +from airflow.providers.amazon.aws.operators.sagemaker import ( + SageMakerProcessingOperator, + SageMakerTransformOperator, +) from airflow.utils.timezone import datetime from astronomer.providers.amazon.aws.operators.sagemaker import ( @@ -12,7 +16,6 @@ SageMakerTransformOperatorAsync, ) from astronomer.providers.amazon.aws.triggers.sagemaker import ( - SagemakerProcessingTrigger, SagemakerTrainingWithLogTrigger, SagemakerTrigger, ) @@ -136,158 +139,15 @@ class TestSagemakerProcessingOperatorAsync: CHECK_INTERVAL = 5 MAX_INGESTION_TIME = 60 * 60 * 24 * 7 - @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperatorAsync.defer") - @mock.patch.object( - SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "Completed"} - ) - @mock.patch.object( - SageMakerHook, - "create_processing_job", - return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False) - def test_sagemakerprocessing_op_async_complete_before_defer( - self, mock_hook, mock_processing, mock_describe, mock_defer, context - ): - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - task.execute(context) - - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperatorAsync.defer") - @mock.patch.object(SageMakerHook, "describe_processing_job", return_value=Exception) - @mock.patch.object( - SageMakerHook, - "create_processing_job", - return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False) - def test_sagemakerprocessing_op_async_encounter_exception_before_defer( - self, mock_hook, mock_processing, mock_describe, mock_defer, context - ): - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - - with pytest.raises(Exception): - task.execute(context) - - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerProcessingOperatorAsync.defer") - @mock.patch.object( - SageMakerHook, - "describe_processing_job", - return_value={"ProcessingJobStatus": "Failed", "FailureReason": "it just failed"}, - ) - @mock.patch.object( - SageMakerHook, - "create_processing_job", - return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False) - def test_sagemakerprocessing_op_async_failed_before_defer( - self, mock_hook, mock_processing, mock_describe, mock_defer, context - ): - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - with pytest.raises(AirflowException): - task.execute(context) - - assert not mock_defer.called - - @mock.patch.object( - SageMakerHook, "describe_processing_job", return_value={"ProcessingJobStatus": "InProgress"} - ) - @mock.patch.object( - SageMakerHook, - "create_processing_job", - return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False) - def test_sagemakerprocessing_op_async(self, mock_hook, mock_processing, mock_describe, context): - """Assert SageMakerProcessingOperatorAsync deferred properly""" + def test_init(self): task = SageMakerProcessingOperatorAsync( config=CREATE_PROCESSING_PARAMS, task_id=self.TASK_ID, check_interval=self.CHECK_INTERVAL, max_ingestion_time=self.MAX_INGESTION_TIME, ) - with pytest.raises(TaskDeferred) as exc: - task.execute(context) - assert isinstance( - exc.value.trigger, SagemakerProcessingTrigger - ), "Trigger is not a SagemakerProcessingTrigger" - - @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=True) - def test_sagemakerprocessing_op_async_duplicate_failure(self, mock_hook, context): - """Tests that an AirflowException is raised in case of error event from find_processing_job_name""" - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - ) - with pytest.raises(AirflowException): - task.execute(context) - - @mock.patch.object( - SageMakerHook, - "create_processing_job", - return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}, - ) - @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=False) - def test_sagemakerprocessing_op_async_failure(self, mock_hook, mock_processing_job, context): - """Tests that an AirflowException is raised in case of error event from create_processing_job""" - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - ) - with pytest.raises(AirflowException): - task.execute(context) - - @pytest.mark.parametrize( - "event", - [{"status": "error", "message": "test failure message"}, None], - ) - def test_sagemakerprocessing_op_async_execute_failure(self, event): - """Tests that an AirflowException is raised in case of error event""" - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - with pytest.raises(AirflowException): - task.execute_complete(context=None, event=event) - - @pytest.mark.parametrize( - "event", - [{"status": "success", "message": "Job completed"}], - ) - def test_sagemakerprocessing_op_async_execute_complete(self, event): - """Asserts that logging occurs as expected""" - task = SageMakerProcessingOperatorAsync( - config=CREATE_PROCESSING_PARAMS, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - with mock.patch.object(task.log, "info") as mock_log_info: - task.execute_complete(context=None, event=event) - mock_log_info.assert_called_with("%s completed successfully.", "test_sagemaker_processing_operator") + assert isinstance(task, SageMakerProcessingOperator) + assert task.deferrable is True class TestSagemakerTransformOperatorAsync: @@ -295,63 +155,7 @@ class TestSagemakerTransformOperatorAsync: CHECK_INTERVAL = 5 MAX_INGESTION_TIME = 60 * 60 * 24 * 7 - @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperatorAsync.defer") - @mock.patch.object(SageMakerHook, "describe_model", return_value={}) - @mock.patch.object( - SageMakerHook, - "describe_transform_job", - return_value={"TransformJobStatus": "Failed", "FailureReason": "it just failed"}, - ) - @mock.patch.object( - SageMakerHook, - "create_transform_job", - return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[]) - @mock.patch.object(SageMakerHook, "create_model", return_value=None) - def test_sagemaker_transform_op_async_failed_before_defer( - self, - mock_create_model, - mock_list_transform_jobs, - mock_transform_job, - mock_describe_transform_job, - mock_describe_model, - mock_defer, - context, - ): - task = SageMakerTransformOperatorAsync( - config=CONFIG, - task_id=self.TASK_ID, - check_if_job_exists=False, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - with pytest.raises(AirflowException): - task.execute(context) - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperatorAsync.defer") - @mock.patch.object(SageMakerHook, "describe_model", return_value={}) - @mock.patch.object( - SageMakerHook, "describe_transform_job", return_value={"TransformJobStatus": "Completed"} - ) - @mock.patch.object( - SageMakerHook, - "create_transform_job", - return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[]) - @mock.patch.object(SageMakerHook, "create_model", return_value=None) - def test_sagemaker_transform_op_async_complete_before_defer( - self, - mock_create_model, - mock_list_transform_jobs, - mock_transform_job, - mock_describe_transform_job, - mock_describe_model, - mock_defer, - context, - ): + def test_init(self): task = SageMakerTransformOperatorAsync( config=CONFIG, task_id=self.TASK_ID, @@ -359,87 +163,8 @@ def test_sagemaker_transform_op_async_complete_before_defer( check_interval=self.CHECK_INTERVAL, max_ingestion_time=self.MAX_INGESTION_TIME, ) - task.execute(context) - assert not mock_defer.called - - @mock.patch.object( - SageMakerHook, "describe_transform_job", return_value={"TransformJobStatus": "InProgress"} - ) - @mock.patch.object( - SageMakerHook, - "create_transform_job", - return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, - ) - @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[]) - @mock.patch.object(SageMakerHook, "create_model", return_value=None) - def test_sagemaker_transform_op_async( - self, - mock_create_model, - mock_list_transform_jobs, - mock_transform_job, - mock_describe_transform_job, - context, - ): - """Assert SageMakerTransformOperatorAsync deferred properly""" - task = SageMakerTransformOperatorAsync( - config=CONFIG, - task_id=self.TASK_ID, - check_if_job_exists=False, - check_interval=self.CHECK_INTERVAL, - max_ingestion_time=self.MAX_INGESTION_TIME, - ) - with pytest.raises(TaskDeferred) as exc: - task.execute(context) - assert isinstance(exc.value.trigger, SagemakerTrigger), "Trigger is not a SagemakerTrigger" - - @mock.patch.object( - SageMakerHook, - "create_transform_job", - return_value={"TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}, - ) - @mock.patch.object(SageMakerHook, "list_transform_jobs", return_value=[]) - @mock.patch.object(SageMakerHook, "create_model", return_value=None) - def test_sagemaker_transform_op_async_execute_failure(self, mock_hook, mock_transform_job, context): - """Tests that an AirflowException is raised in case of error event from create_transform_job""" - task = SageMakerTransformOperatorAsync( - config=CONFIG, - task_id=self.TASK_ID, - check_if_job_exists=False, - check_interval=self.CHECK_INTERVAL, - ) - with pytest.raises(AirflowException): - task.execute(context) - - @pytest.mark.parametrize( - "mock_event", - [{"status": "error", "message": "test failure message"}, None], - ) - def test_sagemaker_transform_op_async_execute_complete_failure(self, mock_event): - """Tests that an AirflowException is raised in case of error event""" - task = SageMakerTransformOperatorAsync( - config=CONFIG, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - ) - with pytest.raises(AirflowException): - task.execute_complete(context=None, event=mock_event) - - @pytest.mark.parametrize( - "mock_event", - [{"status": "success", "message": "Job completed"}], - ) - @mock.patch.object(SageMakerHook, "describe_model") - def test_sagemaker_transform_op_async_execute_complete(self, mock_model_output, mock_event): - """Asserts that logging occurs as expected""" - task = SageMakerTransformOperatorAsync( - config=CONFIG, - task_id=self.TASK_ID, - check_interval=self.CHECK_INTERVAL, - ) - mock_model_output.return_value = {"test": "test"} - with mock.patch.object(task.log, "info") as mock_log_info: - task.execute_complete(context=None, event=mock_event) - mock_log_info.assert_called_with("%s completed successfully.", "test_sagemaker_transform_operator") + assert isinstance(task, SageMakerTransformOperator) + assert task.deferrable is True class TestSagemakerTrainingOperatorAsync: diff --git a/tests/amazon/aws/sensors/test_redshift_sensor.py b/tests/amazon/aws/sensors/test_redshift_sensor.py index f3023ae5d..09b0eb238 100644 --- a/tests/amazon/aws/sensors/test_redshift_sensor.py +++ b/tests/amazon/aws/sensors/test_redshift_sensor.py @@ -1,72 +1,20 @@ -from unittest import mock - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.amazon.aws.sensors.redshift_cluster import ( + RedshiftClusterSensor, +) from astronomer.providers.amazon.aws.sensors.redshift_cluster import ( RedshiftClusterSensorAsync, ) -from astronomer.providers.amazon.aws.triggers.redshift_cluster import ( - RedshiftClusterSensorTrigger, -) TASK_ID = "redshift_sensor_check" -POLLING_PERIOD_SECONDS = 1.0 - -MODULE = "astronomer.providers.amazon.aws.sensors.redshift_cluster" class TestRedshiftClusterSensorAsync: - TASK = RedshiftClusterSensorAsync( - task_id=TASK_ID, - cluster_identifier="astro-redshift-cluster-1", - target_status="available", - ) - - @mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.defer") - @mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.poke", return_value=True) - def test_redshift_cluster_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - self.TASK.execute(context) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.poke", return_value=False) - def test_redshift_cluster_sensor_async(self, context): - """Test RedshiftClusterSensorAsync that a task with wildcard=True - is deferred and an RedshiftClusterSensorTrigger will be fired when executed method is called""" - - with pytest.raises(TaskDeferred) as exc: - self.TASK.execute(context) - assert isinstance( - exc.value.trigger, RedshiftClusterSensorTrigger - ), "Trigger is not a RedshiftClusterSensorTrigger" - - def test_redshift_sensor_async_execute_failure(self, context): - """Test RedshiftClusterSensorAsync with an AirflowException is raised in case of error event""" - - with pytest.raises(AirflowException): - self.TASK.execute_complete( - context=None, event={"status": "error", "message": "test failure message"} - ) - - def test_redshift_sensor_async_execute_complete(self): - """Asserts that logging occurs as expected""" - - with mock.patch.object(self.TASK.log, "info") as mock_log_info: - self.TASK.execute_complete( - context=None, event={"status": "success", "cluster_state": "available"} - ) - mock_log_info.assert_called_with( - "Cluster Identifier %s is in %s state", "astro-redshift-cluster-1", "available" + def test_init(self): + task = RedshiftClusterSensorAsync( + task_id=TASK_ID, + cluster_identifier="astro-redshift-cluster-1", + target_status="available", ) - - def test_poll_interval_deprecation_warning(self): - """Test DeprecationWarning for RedshiftClusterSensorAsync by setting param poll_interval""" - # TODO: Remove once deprecated - with pytest.warns(expected_warning=DeprecationWarning): - RedshiftClusterSensorAsync( - task_id=TASK_ID, - cluster_identifier="astro-redshift-cluster-1", - target_status="available", - poll_interval=5.0, - ) + assert isinstance(task, RedshiftClusterSensor) + assert task.deferrable is True diff --git a/tests/amazon/aws/sensors/test_s3_sensors.py b/tests/amazon/aws/sensors/test_s3_sensors.py index 4dfae42d6..bff9e7b2c 100644 --- a/tests/amazon/aws/sensors/test_s3_sensors.py +++ b/tests/amazon/aws/sensors/test_s3_sensors.py @@ -1,15 +1,7 @@ import unittest -from datetime import timedelta -from typing import Any, List -from unittest import mock import pytest -from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred -from airflow.models import DAG, DagRun, TaskInstance -from airflow.models.variable import Variable -from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor -from airflow.utils import timezone -from parameterized import parameterized +from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor from astronomer.providers.amazon.aws.sensors.s3 import ( S3KeySensorAsync, @@ -17,279 +9,15 @@ S3KeysUnchangedSensorAsync, S3PrefixSensorAsync, ) -from astronomer.providers.amazon.aws.triggers.s3 import ( - S3KeyTrigger, -) MODULE = "astronomer.providers.amazon.aws.sensors.s3" class TestS3KeySensorAsync: - @mock.patch(f"{MODULE}.S3KeySensorAsync.defer") - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=True) - def test_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - sensor = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket") - sensor.execute(context) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - def test_bucket_name_none_and_bucket_key_as_relative_path(self, mock_poke, context): - """ - Test if exception is raised when bucket_name is None - and bucket_key is provided with one of the two keys as relative path rather than s3:// url. - """ - sensor = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket") - with pytest.raises(TaskDeferred): - sensor.execute(context) - - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("astronomer.providers.amazon.aws.hooks.s3.S3HookAsync.get_head_object") - def test_bucket_name_none_and_bucket_key_is_list_and_contain_relative_path( - self, mock_head_object, mock_poke, context - ): - """ - Test if exception is raised when bucket_name is None - and bucket_key is provided with one of the two keys as relative path rather than s3:// url. - :return: - """ - mock_head_object.return_value = {"ContentLength": 0} - sensor = S3KeySensorAsync( - task_id="s3_key_sensor", bucket_key=["s3://test_bucket/file", "file_in_bucket"] - ) - with pytest.raises(TaskDeferred): - sensor.execute(context) - - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - def test_bucket_name_provided_and_bucket_key_is_s3_url(self, mock_poke, context): - """ - Test if exception is raised when bucket_name is provided - while bucket_key is provided as a full s3:// url. - :return: - """ - op = S3KeySensorAsync( - task_id="s3_key_sensor", bucket_key="s3://test_bucket/file", bucket_name="test_bucket" - ) - with pytest.raises(TaskDeferred): - op.execute(context) - - @parameterized.expand( - [ - ["s3://bucket/key", None], - ["key", "bucket"], - ] - ) - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_key_sensor_async(self, key, bucket, mock_hook, mock_poke): - """ - Asserts that a task is deferred and an S3KeyTrigger will be fired - when the S3KeySensorAsync is executed. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", - bucket_key=key, - bucket_name=bucket, - ) - - with pytest.raises(TaskDeferred) as exc: - sensor.execute(context=None) - - assert isinstance(exc.value.trigger, S3KeyTrigger), "Trigger is not a S3KeyTrigger" - - @parameterized.expand( - [ - ["s3://bucket/key", None], - ["key", "bucket"], - ] - ) - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_key_sensor_execute_complete_success(self, key, bucket, mock_poke, mock_hook): - """ - Asserts that a task is completed with success status. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", - bucket_key=key, - bucket_name=bucket, - ) - assert sensor.execute_complete(context={}, event={"status": "success"}) is None - - @parameterized.expand( - [ - ["key", "bucket"], - ] - ) - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - def test_s3_key_sensor_execute_complete_success_with_keys(self, key, bucket, mock_poke): - """ - Asserts that a task is completed with success status and check function - """ - - def check_fn(files: List[Any]) -> bool: - return all(f.get("Size", 0) > 0 for f in files) - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", - bucket_key=key, - bucket_name=bucket, - check_fn=check_fn, - ) - assert ( - sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) is None - ) - - @mock.patch(f"{MODULE}.S3KeySensorAsync._defer") - def test_s3_key_sensor_re_defer(self, mock_defer): - def check_fn(files: List[Any]) -> bool: - return False - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", - bucket_key="key", - bucket_name="bucket", - check_fn=check_fn, - ) - sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) - - mock_defer.assert_called_once() - - @parameterized.expand( - [ - ["s3://bucket/key", None], - ["key", "bucket"], - ] - ) - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_key_sensor_execute_complete_error(self, key, bucket, mock_hook, mock_poke): - """ - Asserts that a task is completed with error status. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", - bucket_key=key, - bucket_name=bucket, - ) - with pytest.raises(AirflowException): - sensor.execute_complete( - context={}, event={"status": "error", "message": "mocked error", "soft_fail": False} - ) - - @parameterized.expand( - [ - ["s3://bucket/key", None], - ["key", "bucket"], - ] - ) - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - @mock.patch.object(S3KeySensorAsync, "defer") - @mock.patch("astronomer.providers.amazon.aws.sensors.s3.S3KeyTrigger") - def test_s3_key_sensor_async_with_mock_defer( - self, key, bucket, mock_trigger, mock_defer, mock_hook, mock_poke - ): - """ - Asserts that a task is deferred and an S3KeyTrigger will be fired - when the S3KeySensorAsync is executed. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", - bucket_key=key, - bucket_name=bucket, - ) - - sensor.execute(context=None) - - mock_defer.assert_called() - mock_defer.assert_called_once_with( - timeout=timedelta(days=7), trigger=mock_trigger.return_value, method_name="execute_complete" - ) - - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook.check_for_key") - def test_parse_bucket_key_from_jinja(self, mock_check, mock_poke): - mock_check.return_value = False - - Variable.set("test_bucket_key", "s3://bucket/key") - - execution_date = timezone.datetime(2020, 1, 1) - - dag = DAG("test_s3_key", start_date=execution_date) - op = S3KeySensorAsync( - task_id="s3_key_sensor", - bucket_key="s3://bucket/key", - bucket_name=None, - dag=dag, - ) - - dag_run = DagRun(dag_id=dag.dag_id, execution_date=execution_date, run_id="test") - ti = TaskInstance(task=op) - ti.dag_run = dag_run - context = ti.get_template_context() - ti.render_templates(context) - - assert op.bucket_key == ["s3://bucket/key"] - assert op.bucket_name is None - - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke", return_value=False) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_s3_key_sensor_with_wildcard_async(self, mock_hook, mock_poke, context): - """ - Asserts that a task with wildcard=True is deferred and an S3KeyTrigger will be fired - when the S3KeySensorAsync is executed. - """ - mock_hook.check_for_key.return_value = False - - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", bucket_key="s3://test_bucket/file", wildcard_match=True - ) - - with pytest.raises(TaskDeferred) as exc: - sensor.execute(context) - - assert isinstance(exc.value.trigger, S3KeyTrigger), "Trigger is not a S3KeyTrigger" - - def test_soft_fail(self): - """Raise AirflowSkipException in case soft_fail is true""" - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", bucket_key="key", bucket_name="bucket", soft_fail=True - ) - with pytest.raises(AirflowSkipException): - sensor.execute_complete( - context={}, event={"status": "error", "message": "mocked error", "soft_fail": True} - ) - - @pytest.mark.parametrize( - "soft_fail,exception", - [ - (True, AirflowSkipException), - (False, Exception), - ], - ) - @mock.patch(f"{MODULE}.S3KeySensorAsync.poke") - def test_execute_handle_exception(self, mock_poke, soft_fail, exception): - mock_poke.side_effect = Exception() - sensor = S3KeySensorAsync( - task_id="s3_key_sensor_async", bucket_key="key", bucket_name="bucket", soft_fail=soft_fail - ) - with pytest.raises(exception): - sensor.execute(context={}) - - def test_soft_fail_enable(self, context): - """Sensor should raise AirflowSkipException if soft_fail is True and error occur""" - sensor = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket", soft_fail=True) - with pytest.raises(AirflowSkipException): - sensor.execute(context) + def test_init(self): + task = S3KeySensorAsync(task_id="s3_key_sensor", bucket_key="file_in_bucket") + assert isinstance(task, S3KeySensor) + assert task.deferrable is True class TestS3KeysUnchangedSensorAsync: