Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate LivyOperatorAsync #1454

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions astronomer/providers/apache/livy/hooks/livy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""This module contains the Apache Livy hook async."""
import asyncio
import re
import warnings
from typing import Any, Dict, List, Optional, Sequence, Union

import aiohttp
Expand All @@ -15,16 +16,8 @@

class LivyHookAsync(HttpHookAsync, LoggingMixin):
"""
Hook for Apache Livy through the REST API using LivyHookAsync

:param livy_conn_id: reference to a pre-defined Livy Connection.
:param extra_options: Additional option can be passed when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.

.. seealso::
For more details refer to the Apache Livy API reference:
`Apache Livy API reference <https://livy.apache.org/docs/latest/rest-api.html>`_
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.apache.livy.hooks.livy.LivyHook` instead.
"""

TERMINAL_STATES = {
Expand All @@ -47,6 +40,10 @@ def __init__(
extra_options: Optional[Dict[str, Any]] = None,
extra_headers: Optional[Dict[str, Any]] = None,
) -> None:
warnings.warn(
"This class is deprecated and will be removed in 2.0.0. "
"Use `airflow.providers.apache.livy.hooks.livy.LivyHook` instead."
)
super().__init__(http_conn_id=livy_conn_id)
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
Expand Down
98 changes: 13 additions & 85 deletions astronomer/providers/apache/livy/operators/livy.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,23 @@
"""This module contains the Apache Livy operator async."""
from typing import Any, Dict
from __future__ import annotations

from airflow.exceptions import AirflowException
from airflow.providers.apache.livy.operators.livy import BatchState, LivyOperator
import warnings
from typing import Any

from astronomer.providers.apache.livy.triggers.livy import LivyTrigger
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.apache.livy.operators.livy import LivyOperator


class LivyOperatorAsync(LivyOperator):
"""
This operator wraps the Apache Livy batch REST API, allowing to submit a Spark
application to the underlying cluster asynchronously.

:param file: path of the file containing the application to execute (required).
:param class_name: name of the application Java/Spark main class.
:param args: application command line arguments.
:param jars: jars to be used in this sessions.
:param py_files: python files to be used in this session.
:param files: files to be used in this session.
:param driver_memory: amount of memory to use for the driver process.
:param driver_cores: number of cores to use for the driver process.
:param executor_memory: amount of memory to use per executor process.
:param executor_cores: number of cores to use for each executor.
:param num_executors: number of executors to launch for this session.
:param archives: archives to be used in this session.
:param queue: name of the YARN queue to which the application is submitted.
:param name: name of this session.
:param conf: Spark configuration properties.
:param proxy_user: user to impersonate when running the job.
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that case
return the batch_id and if polling_interval > 0, poll the livy job for termination in the polling interval
defined.
:param extra_options: Additional option can be passed when creating a request.
For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
This class is deprecated.
Use :class:`~airflow.providers.apache.livy.operators.livy.LivyOperator`
and set the `deferrable` param to `True` instead.
"""

def execute(self, context: Context) -> Any:
"""
Airflow runs this method on the worker and defers using the trigger.
Submit the job and get the job_id using which we defer and poll in trigger
"""
self._batch_id = self.get_hook().post_batch(**self.spark_params)
self.log.info("Generated batch-id is %s", self._batch_id)

hook = self.get_hook()
state = hook.get_batch_state(self._batch_id, retry_args=self.retry_args)
self.log.debug("Batch with id %s is in state: %s", self._batch_id, state.value)
if state not in hook.TERMINAL_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=LivyTrigger(
batch_id=self._batch_id,
spark_params=self.spark_params,
livy_conn_id=self._livy_conn_id,
polling_interval=self._polling_interval,
extra_options=self._extra_options,
extra_headers=self._extra_headers,
),
method_name="execute_complete",
)
else:
self.log.info("Batch with id %s terminated with state: %s", self._batch_id, state.value)
hook.dump_batch_logs(self._batch_id)
if state != BatchState.SUCCESS:
raise AirflowException(f"Batch {self._batch_id} did not succeed")

context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
return self._batch_id

def execute_complete(self, context: Context, event: Dict[str, Any]) -> Any:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
# dump the logs from livy to worker through triggerer.
if event.get("log_lines", None) is not None:
for log_line in event["log_lines"]:
self.log.info(log_line)

if event["status"] == "error":
raise AirflowException(event["response"])
self.log.info(
"%s completed with response %s",
self.task_id,
event["response"],
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
"This class is deprecated. "
"Use `airflow.providers.apache.livy.operators.livy.LivyOperator` "
"and set `deferrable` param to `True` instead.",
)
context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(event["batch_id"])["appId"])
return event["batch_id"]
super().__init__(*args, deferrable=True, **kwargs)
22 changes: 7 additions & 15 deletions astronomer/providers/apache/livy/triggers/livy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module contains the Apache Livy Trigger."""
import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand All @@ -9,21 +10,8 @@

class LivyTrigger(BaseTrigger):
"""
Check for the state of a previously submitted job with batch_id

:param batch_id: Batch job id
:param spark_params: Spark parameters; for example,
spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi",
"args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],"jars": "command-runner.jar",
"driver_cores": 1, "executor_cores": 4,"num_executors": 1}
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that case
return the batch_id and if polling_interval > 0, poll the livy job for termination in the polling interval
defined.
:param extra_options: A dictionary of options, where key is string and value
depends on the option that's being modified.
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
:param livy_hook_async: LivyHookAsync object
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.apache.livy.triggers.livy.LivyTrigger` instead.
"""

def __init__(
Expand All @@ -36,6 +24,10 @@ def __init__(
extra_headers: Optional[Dict[str, Any]] = None,
livy_hook_async: Optional[LivyHookAsync] = None,
):
warnings.warn(
"This class is deprecated. "
"Use `airflow.providers.apache.livy.triggers.livy.LivyTrigger` instead.",
)
super().__init__()
self._batch_id = batch_id
self.spark_params = spark_params
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ apache.hive =
apache-airflow-providers-apache-hive>=6.1.5
impyla
apache.livy =
apache-airflow-providers-apache-livy
apache-airflow-providers-apache-livy>=3.7.1
paramiko
cncf.kubernetes =
apache-airflow-providers-cncf-kubernetes>=4
Expand Down Expand Up @@ -120,7 +120,7 @@ all =
aiobotocore>=2.1.1
apache-airflow-providers-amazon>=8.16.0
apache-airflow-providers-apache-hive>=6.1.5
apache-airflow-providers-apache-livy
apache-airflow-providers-apache-livy>=3.7.1
apache-airflow-providers-cncf-kubernetes>=4
apache-airflow-providers-databricks>=2.2.0
apache-airflow-providers-google>=8.1.0
Expand Down
166 changes: 5 additions & 161 deletions tests/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
@@ -1,174 +1,18 @@
from unittest.mock import MagicMock, patch
from __future__ import annotations

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.apache.livy.hooks.livy import BatchState
from airflow.utils import timezone
from airflow.providers.apache.livy.operators.livy import LivyOperator

from astronomer.providers.apache.livy.operators.livy import LivyOperatorAsync
from astronomer.providers.apache.livy.triggers.livy import LivyTrigger

DEFAULT_DATE = timezone.datetime(2017, 1, 1)
mock_livy_client = MagicMock()

BATCH_ID = 100
LOG_RESPONSE = {"total": 3, "log": ["first_line", "second_line", "third_line"]}


class TestLivyOperatorAsync:
@pytest.fixture()
@patch(
"astronomer.providers.apache.livy.hooks.livy.LivyHookAsync.dump_batch_logs",
return_value=None,
)
@patch("astronomer.providers.apache.livy.hooks.livy.LivyHookAsync.get_batch_state")
async def test_poll_for_termination(self, mock_livy, mock_dump_logs, dag):
state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS]

def side_effect(_, retry_args):
if state_list:
return state_list.pop(0)
# fail if does not stop right before
raise AssertionError()

mock_livy.side_effect = side_effect

task = LivyOperatorAsync(file="sparkapp", polling_interval=1, dag=dag, task_id="livy_example")
task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)

mock_livy.assert_called_with(BATCH_ID, retry_args=None)
mock_dump_logs.assert_called_with(BATCH_ID)
assert mock_livy.call_count == 3

@pytest.mark.parametrize(
"mock_state",
(
BatchState.NOT_STARTED,
BatchState.STARTING,
BatchState.RUNNING,
BatchState.IDLE,
BatchState.SHUTTING_DOWN,
),
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
def test_livy_operator_async(self, mock_get_batch_state, mock_post, mock_state, dag):
mock_get_batch_state.return_value = mock_state
task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)

with pytest.raises(TaskDeferred) as exc:
task.execute({})

assert isinstance(exc.value.trigger, LivyTrigger), "Trigger is not a LivyTrigger"

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
return_value=None,
)
@patch("astronomer.providers.apache.livy.operators.livy.LivyOperatorAsync.defer")
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value={"appId": BATCH_ID}
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
return_value=BatchState.SUCCESS,
)
def test_livy_operator_async_finish_before_deferred_success(
self, mock_get_batch_state, mock_post, mock_get, mock_defer, mock_dump_logs, dag
):
def test_init(self):
task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
assert task.execute(context={"ti": MagicMock()}) == BATCH_ID
assert not mock_defer.called

@pytest.mark.parametrize(
"mock_state",
(
BatchState.ERROR,
BatchState.DEAD,
BatchState.KILLED,
),
)
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.dump_batch_logs",
return_value=None,
)
@patch("astronomer.providers.apache.livy.operators.livy.LivyOperatorAsync.defer")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state")
def test_livy_operator_async_finish_before_deferred_not_success(
self, mock_get_batch_state, mock_post, mock_defer, mock_dump_logs, mock_state, dag
):
mock_get_batch_state.return_value = mock_state

task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
with pytest.raises(AirflowException):
task.execute({})
assert not mock_defer.called

@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value={"appId": BATCH_ID}
)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_livy_operator_async_execute_complete_success(self, mock_post, mock_get, dag):
"""Asserts that a task is completed with success status."""
task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
assert (
task.execute_complete(
context={"ti": MagicMock()},
event={
"status": "success",
"log_lines": None,
"batch_id": BATCH_ID,
"response": "mock success",
},
)
is BATCH_ID
)

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
def test_livy_operator_async_execute_complete_error(self, mock_post, dag):
"""Asserts that a task is completed with success status."""

task = LivyOperatorAsync(
livy_conn_id="livyunittest",
file="sparkapp",
polling_interval=1,
dag=dag,
task_id="livy_example",
)
with pytest.raises(AirflowException):
task.execute_complete(
context={},
event={
"status": "error",
"log_lines": ["mock log"],
"batch_id": BATCH_ID,
"response": "mock error",
},
)
assert isinstance(task, LivyOperator)
assert task.deferrable is True
Loading