diff --git a/astronomer/providers/apache/livy/hooks/livy.py b/astronomer/providers/apache/livy/hooks/livy.py index bd41f6c99..cbe2e2a4b 100644 --- a/astronomer/providers/apache/livy/hooks/livy.py +++ b/astronomer/providers/apache/livy/hooks/livy.py @@ -1,6 +1,7 @@ """This module contains the Apache Livy hook async.""" import asyncio import re +import warnings from typing import Any, Dict, List, Optional, Sequence, Union import aiohttp @@ -15,16 +16,8 @@ class LivyHookAsync(HttpHookAsync, LoggingMixin): """ - Hook for Apache Livy through the REST API using LivyHookAsync - - :param livy_conn_id: reference to a pre-defined Livy Connection. - :param extra_options: Additional option can be passed when creating a request. - For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)`` - :param extra_headers: A dictionary of headers passed to the HTTP request to livy. - - .. seealso:: - For more details refer to the Apache Livy API reference: - `Apache Livy API reference `_ + This class is deprecated and will be removed in 2.0.0. + Use :class:`~airflow.providers.apache.livy.hooks.livy.LivyHook` instead. """ TERMINAL_STATES = { @@ -47,6 +40,10 @@ def __init__( extra_options: Optional[Dict[str, Any]] = None, extra_headers: Optional[Dict[str, Any]] = None, ) -> None: + warnings.warn( + "This class is deprecated and will be removed in 2.0.0. " + "Use `airflow.providers.apache.livy.hooks.livy.LivyHook` instead." 
+ ) super().__init__(http_conn_id=livy_conn_id) self.extra_headers = extra_headers or {} self.extra_options = extra_options or {} diff --git a/astronomer/providers/apache/livy/operators/livy.py b/astronomer/providers/apache/livy/operators/livy.py index 049ef33a9..94d9c0eea 100644 --- a/astronomer/providers/apache/livy/operators/livy.py +++ b/astronomer/providers/apache/livy/operators/livy.py @@ -1,95 +1,23 @@ """This module contains the Apache Livy operator async.""" -from typing import Any, Dict +from __future__ import annotations -from airflow.exceptions import AirflowException -from airflow.providers.apache.livy.operators.livy import BatchState, LivyOperator +import warnings +from typing import Any -from astronomer.providers.apache.livy.triggers.livy import LivyTrigger -from astronomer.providers.utils.typing_compat import Context +from airflow.providers.apache.livy.operators.livy import LivyOperator class LivyOperatorAsync(LivyOperator): """ - This operator wraps the Apache Livy batch REST API, allowing to submit a Spark - application to the underlying cluster asynchronously. - - :param file: path of the file containing the application to execute (required). - :param class_name: name of the application Java/Spark main class. - :param args: application command line arguments. - :param jars: jars to be used in this sessions. - :param py_files: python files to be used in this session. - :param files: files to be used in this session. - :param driver_memory: amount of memory to use for the driver process. - :param driver_cores: number of cores to use for the driver process. - :param executor_memory: amount of memory to use per executor process. - :param executor_cores: number of cores to use for each executor. - :param num_executors: number of executors to launch for this session. - :param archives: archives to be used in this session. - :param queue: name of the YARN queue to which the application is submitted. - :param name: name of this session. 
- :param conf: Spark configuration properties. - :param proxy_user: user to impersonate when running the job. - :param livy_conn_id: reference to a pre-defined Livy Connection. - :param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that case - return the batch_id and if polling_interval > 0, poll the livy job for termination in the polling interval - defined. - :param extra_options: Additional option can be passed when creating a request. - For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)`` - :param extra_headers: A dictionary of headers passed to the HTTP request to livy. - :param retry_args: Arguments which define the retry behaviour. - See Tenacity documentation at https://github.com/jd/tenacity + This class is deprecated. + Use :class:`~airflow.providers.apache.livy.operators.livy.LivyOperator` + and set the `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> Any: - """ - Airflow runs this method on the worker and defers using the trigger. 
- Submit the job and get the job_id using which we defer and poll in trigger - """ - self._batch_id = self.get_hook().post_batch(**self.spark_params) - self.log.info("Generated batch-id is %s", self._batch_id) - - hook = self.get_hook() - state = hook.get_batch_state(self._batch_id, retry_args=self.retry_args) - self.log.debug("Batch with id %s is in state: %s", self._batch_id, state.value) - if state not in hook.TERMINAL_STATES: - self.defer( - timeout=self.execution_timeout, - trigger=LivyTrigger( - batch_id=self._batch_id, - spark_params=self.spark_params, - livy_conn_id=self._livy_conn_id, - polling_interval=self._polling_interval, - extra_options=self._extra_options, - extra_headers=self._extra_headers, - ), - method_name="execute_complete", - ) - else: - self.log.info("Batch with id %s terminated with state: %s", self._batch_id, state.value) - hook.dump_batch_logs(self._batch_id) - if state != BatchState.SUCCESS: - raise AirflowException(f"Batch {self._batch_id} did not succeed") - - context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"]) - return self._batch_id - - def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - # dump the logs from livy to worker through triggerer. - if event.get("log_lines", None) is not None: - for log_line in event["log_lines"]: - self.log.info(log_line) - - if event["status"] == "error": - raise AirflowException(event["response"]) - self.log.info( - "%s completed with response %s", - self.task_id, - event["response"], + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "This class is deprecated. 
" + "Use `airflow.providers.apache.livy.operators.livy.LivyOperator` " + "and set `deferrable` param to `True` instead.", ) - context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(event["batch_id"])["appId"]) - return event["batch_id"] + super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/apache/livy/triggers/livy.py b/astronomer/providers/apache/livy/triggers/livy.py index 0d965fc27..7d48cc6aa 100644 --- a/astronomer/providers/apache/livy/triggers/livy.py +++ b/astronomer/providers/apache/livy/triggers/livy.py @@ -1,5 +1,6 @@ """This module contains the Apache Livy Trigger.""" import asyncio +import warnings from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -9,21 +10,8 @@ class LivyTrigger(BaseTrigger): """ - Check for the state of a previously submitted job with batch_id - - :param batch_id: Batch job id - :param spark_params: Spark parameters; for example, - spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi", - "args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],"jars": "command-runner.jar", - "driver_cores": 1, "executor_cores": 4,"num_executors": 1} - :param livy_conn_id: reference to a pre-defined Livy Connection. - :param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that case - return the batch_id and if polling_interval > 0, poll the livy job for termination in the polling interval - defined. - :param extra_options: A dictionary of options, where key is string and value - depends on the option that's being modified. - :param extra_headers: A dictionary of headers passed to the HTTP request to livy. - :param livy_hook_async: LivyHookAsync object + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.apache.livy.triggers.livy.LivyTrigger` instead. 
""" def __init__( @@ -36,6 +24,10 @@ def __init__( extra_headers: Optional[Dict[str, Any]] = None, livy_hook_async: Optional[LivyHookAsync] = None, ): + warnings.warn( + "This class is deprecated. " + "Use `airflow.providers.apache.livy.triggers.livy.LivyTrigger` instead.", + ) super().__init__() self._batch_id = batch_id self.spark_params = spark_params diff --git a/setup.cfg b/setup.cfg index ab161f7d5..1e68e119c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ apache.hive = apache-airflow-providers-apache-hive>=6.1.5 impyla apache.livy = - apache-airflow-providers-apache-livy + apache-airflow-providers-apache-livy>=3.7.1 paramiko cncf.kubernetes = apache-airflow-providers-cncf-kubernetes>=4 @@ -120,7 +120,7 @@ all = aiobotocore>=2.1.1 apache-airflow-providers-amazon>=8.16.0 apache-airflow-providers-apache-hive>=6.1.5 - apache-airflow-providers-apache-livy + apache-airflow-providers-apache-livy>=3.7.1 apache-airflow-providers-cncf-kubernetes>=4 apache-airflow-providers-databricks>=2.2.0 apache-airflow-providers-google>=8.1.0 diff --git a/tests/apache/livy/operators/test_livy.py b/tests/apache/livy/operators/test_livy.py index 25c808276..d596cfe2d 100644 --- a/tests/apache/livy/operators/test_livy.py +++ b/tests/apache/livy/operators/test_livy.py @@ -1,174 +1,18 @@ -from unittest.mock import MagicMock, patch +from __future__ import annotations -import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.apache.livy.hooks.livy import BatchState -from airflow.utils import timezone +from airflow.providers.apache.livy.operators.livy import LivyOperator from astronomer.providers.apache.livy.operators.livy import LivyOperatorAsync -from astronomer.providers.apache.livy.triggers.livy import LivyTrigger - -DEFAULT_DATE = timezone.datetime(2017, 1, 1) -mock_livy_client = MagicMock() - -BATCH_ID = 100 -LOG_RESPONSE = {"total": 3, "log": ["first_line", "second_line", "third_line"]} class TestLivyOperatorAsync: - 
@pytest.fixture() - @patch( - "astronomer.providers.apache.livy.hooks.livy.LivyHookAsync.dump_batch_logs", - return_value=None, - ) - @patch("astronomer.providers.apache.livy.hooks.livy.LivyHookAsync.get_batch_state") - async def test_poll_for_termination(self, mock_livy, mock_dump_logs, dag): - state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS] - - def side_effect(_, retry_args): - if state_list: - return state_list.pop(0) - # fail if does not stop right before - raise AssertionError() - - mock_livy.side_effect = side_effect - - task = LivyOperatorAsync(file="sparkapp", polling_interval=1, dag=dag, task_id="livy_example") - task._livy_hook = task.get_hook() - task.poll_for_termination(BATCH_ID) - - mock_livy.assert_called_with(BATCH_ID, retry_args=None) - mock_dump_logs.assert_called_with(BATCH_ID) - assert mock_livy.call_count == 3 - - @pytest.mark.parametrize( - "mock_state", - ( - BatchState.NOT_STARTED, - BatchState.STARTING, - BatchState.RUNNING, - BatchState.IDLE, - BatchState.SHUTTING_DOWN, - ), - ) - @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) - @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state") - def test_livy_operator_async(self, mock_get_batch_state, mock_post, mock_state, dag): - mock_get_batch_state.retun_value = mock_state - task = LivyOperatorAsync( - livy_conn_id="livyunittest", - file="sparkapp", - polling_interval=1, - dag=dag, - task_id="livy_example", - ) - - with pytest.raises(TaskDeferred) as exc: - task.execute({}) - - assert isinstance(exc.value.trigger, LivyTrigger), "Trigger is not a LivyTrigger" - - @patch( - "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs", - return_value=None, - ) - @patch("astronomer.providers.apache.livy.operators.livy.LivyOperatorAsync.defer") - @patch( - "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value={"appId": BATCH_ID} - ) - 
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) - @patch( - "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state", - return_value=BatchState.SUCCESS, - ) - def test_livy_operator_async_finish_before_deferred_success( - self, mock_get_batch_state, mock_post, mock_get, mock_defer, mock_dump_logs, dag - ): + def test_init(self): task = LivyOperatorAsync( livy_conn_id="livyunittest", file="sparkapp", polling_interval=1, - dag=dag, task_id="livy_example", ) - assert task.execute(context={"ti": MagicMock()}) == BATCH_ID - assert not mock_defer.called - - @pytest.mark.parametrize( - "mock_state", - ( - BatchState.ERROR, - BatchState.DEAD, - BatchState.KILLED, - ), - ) - @patch( - "airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs", - return_value=None, - ) - @patch("astronomer.providers.apache.livy.operators.livy.LivyOperatorAsync.defer") - @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) - @patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state") - def test_livy_operator_async_finish_before_deferred_not_success( - self, mock_get_batch_state, mock_post, mock_defer, mock_dump_logs, mock_state, dag - ): - mock_get_batch_state.return_value = mock_state - task = LivyOperatorAsync( - livy_conn_id="livyunittest", - file="sparkapp", - polling_interval=1, - dag=dag, - task_id="livy_example", - ) - with pytest.raises(AirflowException): - task.execute({}) - assert not mock_defer.called - - @patch( - "airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value={"appId": BATCH_ID} - ) - @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) - def test_livy_operator_async_execute_complete_success(self, mock_post, mock_get, dag): - """Asserts that a task is completed with success status.""" - task = LivyOperatorAsync( - livy_conn_id="livyunittest", - file="sparkapp", - 
polling_interval=1, - dag=dag, - task_id="livy_example", - ) - assert ( - task.execute_complete( - context={"ti": MagicMock()}, - event={ - "status": "success", - "log_lines": None, - "batch_id": BATCH_ID, - "response": "mock success", - }, - ) - is BATCH_ID - ) - - @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID) - def test_livy_operator_async_execute_complete_error(self, mock_post, dag): - """Asserts that a task is completed with success status.""" - - task = LivyOperatorAsync( - livy_conn_id="livyunittest", - file="sparkapp", - polling_interval=1, - dag=dag, - task_id="livy_example", - ) - with pytest.raises(AirflowException): - task.execute_complete( - context={}, - event={ - "status": "error", - "log_lines": ["mock log"], - "batch_id": BATCH_ID, - "response": "mock error", - }, - ) + assert isinstance(task, LivyOperator) + assert task.deferrable is True