Skip to content

Commit

Permalink
Add endpoint_prefix to LivyHook (apache#45811)
Browse files Browse the repository at this point in the history
* revert manual changes to the Changelog file

* adding tests for livy db hook

* adding tests for async livy hook

* reformat comment, and fix endpoint_prefix type for livy trigger

* formatting changes

* remove extra default args from expectation

* adding mock patch for mock request

---------

Co-authored-by: Giridhar Pathak <[email protected]>
  • Loading branch information
gpathak128 and Giridhar Pathak authored Feb 1, 2025
1 parent 53e1723 commit 8c9e0b2
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class BatchState(Enum):
SUCCESS = "success"


def sanitize_endpoint_prefix(endpoint_prefix: str | None) -> str:
    """
    Normalize an endpoint prefix so it can be prepended to an endpoint path.

    Leading and trailing slashes are removed and a single leading slash is added back,
    so ``"livy"``, ``"/livy"``, ``"livy/"`` and ``"/livy/"`` all become ``"/livy"``.

    :param endpoint_prefix: prefix to normalize; may be ``None`` or empty.
    :return: the prefix with exactly one leading slash and no trailing slash, or ``""``
        when there is effectively no prefix (``None``, empty, or made up only of slashes).
    """
    if not endpoint_prefix:
        return ""
    stripped = endpoint_prefix.strip("/")
    # A prefix consisting only of slashes (e.g. "/") must collapse to "" rather than "/":
    # callers build f"{prefix}/batches", which would otherwise become "//batches".
    return f"/{stripped}" if stripped else ""


class LivyHook(HttpHook):
"""
Hook for Apache Livy through the REST API.
Expand Down Expand Up @@ -86,12 +91,14 @@ def __init__(
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
auth_type: Any | None = None,
endpoint_prefix: str | None = None,
) -> None:
super().__init__()
self.method = "POST"
self.http_conn_id = livy_conn_id
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
if auth_type:
self.auth_type = auth_type

Expand Down Expand Up @@ -163,7 +170,10 @@ def post_batch(self, *args: Any, **kwargs: Any) -> int:
self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url)

response = self.run_method(
method="POST", endpoint="/batches", data=batch_submit_body, headers=self.extra_headers
method="POST",
endpoint=f"{self.endpoint_prefix}/batches",
data=batch_submit_body,
headers=self.extra_headers,
)
self.log.debug("Got response: %s", response.text)

Expand Down Expand Up @@ -192,7 +202,9 @@ def get_batch(self, session_id: int | str) -> dict:
self._validate_session_id(session_id)

self.log.debug("Fetching info for batch session %s", session_id)
response = self.run_method(endpoint=f"/batches/{session_id}", headers=self.extra_headers)
response = self.run_method(
endpoint=f"{self.endpoint_prefix}/batches/{session_id}", headers=self.extra_headers
)

try:
response.raise_for_status()
Expand All @@ -217,7 +229,9 @@ def get_batch_state(self, session_id: int | str, retry_args: dict[str, Any] | No

self.log.debug("Fetching info for batch session %s", session_id)
response = self.run_method(
endpoint=f"/batches/{session_id}/state", retry_args=retry_args, headers=self.extra_headers
endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state",
retry_args=retry_args,
headers=self.extra_headers,
)

try:
Expand All @@ -244,7 +258,9 @@ def delete_batch(self, session_id: int | str) -> dict:

self.log.info("Deleting batch session %s", session_id)
response = self.run_method(
method="DELETE", endpoint=f"/batches/{session_id}", headers=self.extra_headers
method="DELETE",
endpoint=f"{self.endpoint_prefix}/batches/{session_id}",
headers=self.extra_headers,
)

try:
Expand All @@ -270,7 +286,9 @@ def get_batch_logs(self, session_id: int | str, log_start_position, log_batch_si
self._validate_session_id(session_id)
log_params = {"from": log_start_position, "size": log_batch_size}
response = self.run_method(
endpoint=f"/batches/{session_id}/log", data=log_params, headers=self.extra_headers
endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log",
data=log_params,
headers=self.extra_headers,
)
try:
response.raise_for_status()
Expand Down Expand Up @@ -490,12 +508,14 @@ def __init__(
livy_conn_id: str = default_conn_name,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
endpoint_prefix: str | None = None,
) -> None:
super().__init__()
self.method = "POST"
self.http_conn_id = livy_conn_id
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)

async def _do_api_call_async(
self,
Expand Down Expand Up @@ -624,7 +644,7 @@ async def get_batch_state(self, session_id: int | str) -> Any:
"""
self._validate_session_id(session_id)
self.log.info("Fetching info for batch session %s", session_id)
result = await self.run_method(endpoint=f"/batches/{session_id}/state")
result = await self.run_method(endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state")
if result["status"] == "error":
self.log.info(result)
return {"batch_state": "error", "response": result, "status": "error"}
Expand Down Expand Up @@ -659,7 +679,9 @@ async def get_batch_logs(
"""
self._validate_session_id(session_id)
log_params = {"from": log_start_position, "size": log_batch_size}
result = await self.run_method(endpoint=f"/batches/{session_id}/log", data=log_params)
result = await self.run_method(
endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", data=log_params
)
if result["status"] == "error":
self.log.info(result)
return {"response": result["response"], "status": "error"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
proxy_user: str | None = None,
livy_conn_id: str = "livy_default",
livy_conn_auth_type: Any | None = None,
livy_endpoint_prefix: str | None = None,
polling_interval: int = 0,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(
self.spark_params = spark_params
self._livy_conn_id = livy_conn_id
self._livy_conn_auth_type = livy_conn_auth_type
self._livy_endpoint_prefix = livy_endpoint_prefix
self._polling_interval = polling_interval
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}
Expand All @@ -139,6 +141,7 @@ def hook(self) -> LivyHook:
extra_headers=self._extra_headers,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
endpoint_prefix=self._livy_endpoint_prefix,
)

def execute(self, context: Context) -> Any:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
livy_conn_id: str = "livy_default",
livy_conn_auth_type: Any | None = None,
extra_options: dict[str, Any] | None = None,
endpoint_prefix: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -54,6 +55,7 @@ def __init__(
self._livy_conn_auth_type = livy_conn_auth_type
self._livy_hook: LivyHook | None = None
self._extra_options = extra_options or {}
self._endpoint_prefix = endpoint_prefix

def get_hook(self) -> LivyHook:
"""
Expand All @@ -66,6 +68,7 @@ def get_hook(self) -> LivyHook:
livy_conn_id=self._livy_conn_id,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
endpoint_prefix=self._endpoint_prefix,
)
return self._livy_hook

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
execution_timeout: timedelta | None = None,
endpoint_prefix: str | None = None,
):
super().__init__()
self._batch_id = batch_id
Expand All @@ -67,6 +68,7 @@ def __init__(
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
self._execution_timeout = execution_timeout
self._endpoint_prefix = endpoint_prefix

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
Expand Down Expand Up @@ -170,5 +172,6 @@ def _get_async_hook(self) -> LivyAsyncHook:
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
endpoint_prefix=self._endpoint_prefix,
)
return self._livy_hook_async
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,77 @@ def test_alternate_auth_type(self):

auth_type.assert_called_once_with("login", "secret")

@patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method")
def test_post_batch_with_endpoint_prefix(self, mock_request):
    """Posting a batch through a hook built with ``/livy`` targets ``/livy/batches``."""
    # Canned reply handed back by the patched ``run_method``.
    fake_response = mock_request.return_value
    fake_response.status_code = 201
    fake_response.json.return_value = {
        "id": BATCH_ID,
        "state": BatchState.STARTING.value,
        "log": [],
    }

    hook = LivyHook(endpoint_prefix="/livy")
    batch_id = hook.post_batch(file="sparkapp")

    expected_payload = json.dumps({"file": "sparkapp"})
    mock_request.assert_called_once_with(
        method="POST", endpoint="/livy/batches", data=expected_payload, headers={}
    )

    # The request body must have been passed as a serialized string.
    sent_kwargs = mock_request.call_args.kwargs
    assert "data" in sent_kwargs
    assert isinstance(sent_kwargs["data"], str)

    assert isinstance(batch_id, int)
    assert batch_id == BATCH_ID

def test_get_batch_with_endpoint_prefix(self, requests_mock):
    """``get_batch`` on a hook built with ``/livy`` requests the prefixed batch URL."""
    batch_url = f"{MATCH_URL}/livy/batches/{BATCH_ID}"
    requests_mock.register_uri("GET", batch_url, json={"id": BATCH_ID}, status_code=200)

    hook = LivyHook(endpoint_prefix="/livy")
    batch_info = hook.get_batch(BATCH_ID)

    assert isinstance(batch_info, dict)
    assert "id" in batch_info

def test_get_batch_state_with_endpoint_prefix(self, requests_mock):
    """``get_batch_state`` on a hook built with ``/livy`` requests the prefixed state URL."""
    expected_state = BatchState.RUNNING
    state_url = f"{MATCH_URL}/livy/batches/{BATCH_ID}/state"
    state_payload = {"id": BATCH_ID, "state": expected_state.value}
    requests_mock.register_uri("GET", state_url, json=state_payload, status_code=200)

    hook = LivyHook(endpoint_prefix="/livy")
    actual_state = hook.get_batch_state(BATCH_ID)

    assert isinstance(actual_state, BatchState)
    assert actual_state == expected_state

def test_delete_batch_with_endpoint_prefix(self, requests_mock):
    """``delete_batch`` on a hook built with ``/livy`` sends DELETE to the prefixed batch URL."""
    batch_url = f"{MATCH_URL}/livy/batches/{BATCH_ID}"
    requests_mock.register_uri("DELETE", batch_url, json={"msg": "deleted"}, status_code=200)

    hook = LivyHook(endpoint_prefix="/livy")
    outcome = hook.delete_batch(BATCH_ID)

    assert outcome == {"msg": "deleted"}

@pytest.mark.parametrize(
    "prefix",
    ["/livy/", "livy", "/livy", "livy/"],
    ids=["leading_and_trailing_slashes", "no_slashes", "leading_slash", "trailing_slash"],
)
def test_endpoint_prefix_is_sanitized_simple(self, requests_mock, prefix):
    """Every slash variant of ``livy`` must resolve to the same ``/livy/batches/<id>`` URL."""
    # Only the normalized URL is registered, so each prefix variant has to map onto it.
    normalized_url = f"{MATCH_URL}/livy/batches/{BATCH_ID}"
    requests_mock.register_uri("GET", normalized_url, json={"id": BATCH_ID}, status_code=200)

    hook = LivyHook(endpoint_prefix=prefix)
    batch_info = hook.get_batch(BATCH_ID)

    assert isinstance(batch_info, dict)
    assert "id" in batch_info

def test_endpoint_prefix_is_sanitized_multiple_path_elements(self, requests_mock):
    """A multi-segment prefix keeps its inner slashes and loses only the outer ones."""
    nested_url = f"{MATCH_URL}/livy/foo/bar/batches/{BATCH_ID}"
    requests_mock.register_uri("GET", nested_url, json={"id": BATCH_ID}, status_code=200)

    hook = LivyHook(endpoint_prefix="/livy/foo/bar/")
    batch_info = hook.get_batch(BATCH_ID)

    assert isinstance(batch_info, dict)
    assert "id" in batch_info


class TestLivyAsyncHook:
@pytest.mark.asyncio
Expand Down Expand Up @@ -815,3 +886,31 @@ def test_check_session_id_success(self, conn_id):
def test_check_session_id_failure(self, conn_id):
with pytest.raises(TypeError):
LivyAsyncHook._validate_session_id(None)

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
async def test_get_batch_state_with_endpoint_prefix(self, mock_run_method):
    """The async hook built with ``/livy`` queries ``/livy/batches/<id>/state``."""
    mock_run_method.return_value = {"status": "success", "response": {"state": BatchState.RUNNING}}
    async_hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy")

    outcome = await async_hook.get_batch_state(BATCH_ID)

    expected_outcome = {
        "batch_state": BatchState.RUNNING,
        "response": "successfully fetched the batch state.",
        "status": "success",
    }
    assert outcome == expected_outcome
    mock_run_method.assert_called_once_with(endpoint=f"/livy/batches/{BATCH_ID}/state")

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
async def test_get_batch_logs_with_endpoint_prefix(self, mock_run_method):
    """The async hook built with ``/livy`` fetches logs from ``/livy/batches/<id>/log``."""
    mock_run_method.return_value = {"status": "success", "response": {}}
    async_hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy")

    logs_result = await async_hook.get_batch_logs(BATCH_ID, 0, 100)

    assert logs_result["status"] == "success"
    expected_params = {"from": 0, "size": 100}
    mock_run_method.assert_called_once_with(
        endpoint=f"/livy/batches/{BATCH_ID}/log",
        data=expected_params,
    )

0 comments on commit 8c9e0b2

Please sign in to comment.