diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 934985befbc91..f185ccdfbe33b 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -52,6 +52,11 @@ class BatchState(Enum): SUCCESS = "success" +def sanitize_endpoint_prefix(endpoint_prefix: str | None) -> str: + """Ensure that the endpoint prefix is prefixed with a slash.""" + return f"/{endpoint_prefix.strip('/')}" if endpoint_prefix else "" + + class LivyHook(HttpHook): """ Hook for Apache Livy through the REST API. @@ -86,12 +91,14 @@ def __init__( extra_options: dict[str, Any] | None = None, extra_headers: dict[str, Any] | None = None, auth_type: Any | None = None, + endpoint_prefix: str | None = None, ) -> None: super().__init__() self.method = "POST" self.http_conn_id = livy_conn_id self.extra_headers = extra_headers or {} self.extra_options = extra_options or {} + self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix) if auth_type: self.auth_type = auth_type @@ -163,7 +170,10 @@ def post_batch(self, *args: Any, **kwargs: Any) -> int: self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url) response = self.run_method( - method="POST", endpoint="/batches", data=batch_submit_body, headers=self.extra_headers + method="POST", + endpoint=f"{self.endpoint_prefix}/batches", + data=batch_submit_body, + headers=self.extra_headers, ) self.log.debug("Got response: %s", response.text) @@ -192,7 +202,9 @@ def get_batch(self, session_id: int | str) -> dict: self._validate_session_id(session_id) self.log.debug("Fetching info for batch session %s", session_id) - response = self.run_method(endpoint=f"/batches/{session_id}", headers=self.extra_headers) + response = self.run_method( + endpoint=f"{self.endpoint_prefix}/batches/{session_id}", headers=self.extra_headers + ) try: response.raise_for_status() @@ -217,7 +229,9 @@ def get_batch_state(self, session_id: int | str, retry_args: dict[str, Any] | No self.log.debug("Fetching info for batch session %s", session_id) response = self.run_method( - endpoint=f"/batches/{session_id}/state", retry_args=retry_args, headers=self.extra_headers + endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state", + retry_args=retry_args, + headers=self.extra_headers, ) try: @@ -244,7 +258,9 @@ def delete_batch(self, session_id: int | str) -> dict: self.log.info("Deleting batch session %s", session_id) response = self.run_method( - method="DELETE", endpoint=f"/batches/{session_id}", headers=self.extra_headers + method="DELETE", + endpoint=f"{self.endpoint_prefix}/batches/{session_id}", + headers=self.extra_headers, ) try: @@ -270,7 +286,9 @@ def get_batch_logs(self, session_id: int | str, log_start_position, log_batch_si self._validate_session_id(session_id) log_params = {"from": log_start_position, "size": log_batch_size} response = self.run_method( - endpoint=f"/batches/{session_id}/log", data=log_params, headers=self.extra_headers + endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", + data=log_params, + headers=self.extra_headers, ) try: response.raise_for_status() @@ -490,12 +508,14 @@ def __init__( livy_conn_id: str = default_conn_name, extra_options: dict[str, Any] | None = None, extra_headers: dict[str, Any] | None = None, + endpoint_prefix: str | None = None, ) -> None: super().__init__() self.method = "POST" self.http_conn_id = livy_conn_id self.extra_headers = extra_headers or {} self.extra_options = extra_options or {} + self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix) async def _do_api_call_async( self, @@ -624,7 +644,7 @@ async def get_batch_state(self, session_id: int | str) -> Any: """ self._validate_session_id(session_id) self.log.info("Fetching info for batch session %s", session_id) - result = await self.run_method(endpoint=f"/batches/{session_id}/state") + result = await self.run_method(endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state") if result["status"] == "error": self.log.info(result) return {"batch_state": "error", "response": result, "status": "error"} @@ -659,7 +679,9 @@ async def get_batch_logs( """ self._validate_session_id(session_id) log_params = {"from": log_start_position, "size": log_batch_size} - result = await self.run_method(endpoint=f"/batches/{session_id}/log", data=log_params) + result = await self.run_method( + endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", data=log_params + ) if result["status"] == "error": self.log.info(result) return {"response": result["response"], "status": "error"} diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py index b74e52d5e6109..746ea55cceeaf 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py @@ -88,6 +88,7 @@ def __init__( proxy_user: str | None = None, livy_conn_id: str = "livy_default", livy_conn_auth_type: Any | None = None, + livy_endpoint_prefix: str | None = None, polling_interval: int = 0, extra_options: dict[str, Any] | None = None, extra_headers: dict[str, Any] | None = None, @@ -119,6 +120,7 @@ def __init__( self.spark_params = spark_params self._livy_conn_id = livy_conn_id self._livy_conn_auth_type = livy_conn_auth_type + self._livy_endpoint_prefix = livy_endpoint_prefix self._polling_interval = polling_interval self._extra_options = extra_options or {} self._extra_headers = extra_headers or {} @@ -139,6 +141,7 @@ def hook(self) -> LivyHook: extra_headers=self._extra_headers, extra_options=self._extra_options, auth_type=self._livy_conn_auth_type, + endpoint_prefix=self._livy_endpoint_prefix, ) def execute(self, context: Context) -> Any: diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py index 3c1a50255ad7a..d0b011e251827 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py @@ -46,6 +46,7 @@ def __init__( livy_conn_id: str = "livy_default", livy_conn_auth_type: Any | None = None, extra_options: dict[str, Any] | None = None, + endpoint_prefix: str | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -54,6 +55,7 @@ def __init__( self._livy_conn_auth_type = livy_conn_auth_type self._livy_hook: LivyHook | None = None self._extra_options = extra_options or {} + self._endpoint_prefix = endpoint_prefix def get_hook(self) -> LivyHook: """ @@ -66,6 +68,7 @@ def get_hook(self) -> LivyHook: livy_conn_id=self._livy_conn_id, extra_options=self._extra_options, auth_type=self._livy_conn_auth_type, + endpoint_prefix=self._endpoint_prefix, ) return self._livy_hook diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py index 2e40b26113ed1..9c47706b73718 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py @@ -57,6 +57,7 @@ def __init__( extra_headers: dict[str, Any] | None = None, livy_hook_async: LivyAsyncHook | None = None, execution_timeout: timedelta | None = None, + endpoint_prefix: str | None = None, ): super().__init__() self._batch_id = batch_id @@ -67,6 +68,7 @@ def __init__( self._extra_headers = extra_headers self._livy_hook_async = livy_hook_async self._execution_timeout = execution_timeout + self._endpoint_prefix = endpoint_prefix def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize LivyTrigger arguments and classpath.""" @@ -170,5 +172,6 @@ def _get_async_hook(self) -> LivyAsyncHook: livy_conn_id=self._livy_conn_id, extra_headers=self._extra_headers, extra_options=self._extra_options, + endpoint_prefix=self._endpoint_prefix, ) return self._livy_hook_async diff --git a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py index 6cb6dd911600c..2a1101ab44e68 100644 --- a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py +++ b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py @@ -428,6 +428,77 @@ def test_alternate_auth_type(self): auth_type.assert_called_once_with("login", "secret") + @patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method") + def test_post_batch_with_endpoint_prefix(self, mock_request): + mock_request.return_value.status_code = 201 + mock_request.return_value.json.return_value = { + "id": BATCH_ID, + "state": BatchState.STARTING.value, + "log": [], + } + + resp = LivyHook(endpoint_prefix="/livy").post_batch(file="sparkapp") + + mock_request.assert_called_once_with( + method="POST", endpoint="/livy/batches", data=json.dumps({"file": "sparkapp"}), headers={} + ) + + request_args = mock_request.call_args.kwargs + assert "data" in request_args + assert isinstance(request_args["data"], str) + + assert isinstance(resp, int) + assert resp == BATCH_ID + + def test_get_batch_with_endpoint_prefix(self, requests_mock): + requests_mock.register_uri( + "GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200 + ) + resp = LivyHook(endpoint_prefix="/livy").get_batch(BATCH_ID) + assert isinstance(resp, dict) + assert "id" in resp + + def test_get_batch_state_with_endpoint_prefix(self, requests_mock): + running = BatchState.RUNNING + + requests_mock.register_uri( + "GET", + f"{MATCH_URL}/livy/batches/{BATCH_ID}/state", + json={"id": BATCH_ID, "state": running.value}, + status_code=200, + ) + + state = LivyHook(endpoint_prefix="/livy").get_batch_state(BATCH_ID) + assert isinstance(state, BatchState) + assert state == running + + def test_delete_batch_with_endpoint_prefix(self, requests_mock): + requests_mock.register_uri( + "DELETE", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"msg": "deleted"}, status_code=200 + ) + assert LivyHook(endpoint_prefix="/livy").delete_batch(BATCH_ID) == {"msg": "deleted"} + + @pytest.mark.parametrize( + "prefix", + ["/livy/", "livy", "/livy", "livy/"], + ids=["leading_and_trailing_slashes", "no_slashes", "leading_slash", "trailing_slash"], + ) + def test_endpoint_prefix_is_sanitized_simple(self, requests_mock, prefix): + requests_mock.register_uri( + "GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200 + ) + resp = LivyHook(endpoint_prefix=prefix).get_batch(BATCH_ID) + assert isinstance(resp, dict) + assert "id" in resp + + def test_endpoint_prefix_is_sanitized_multiple_path_elements(self, requests_mock): + requests_mock.register_uri( + "GET", f"{MATCH_URL}/livy/foo/bar/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200 + ) + resp = LivyHook(endpoint_prefix="/livy/foo/bar/").get_batch(BATCH_ID) + assert isinstance(resp, dict) + assert "id" in resp + class TestLivyAsyncHook: @pytest.mark.asyncio @@ -815,3 +886,31 @@ def test_check_session_id_success(self, conn_id): def test_check_session_id_failure(self, conn_id): with pytest.raises(TypeError): LivyAsyncHook._validate_session_id(None) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_state_with_endpoint_prefix(self, mock_run_method): + mock_run_method.return_value = {"status": "success", "response": {"state": BatchState.RUNNING}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy") + state = await hook.get_batch_state(BATCH_ID) + assert state == { + "batch_state": BatchState.RUNNING, + "response": "successfully fetched the batch state.", + "status": "success", + } + mock_run_method.assert_called_once_with( + endpoint=f"/livy/batches/{BATCH_ID}/state", + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method") + async def test_get_batch_logs_with_endpoint_prefix(self, mock_run_method): + mock_run_method.return_value = {"status": "success", "response": {}} + hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy") + state = await hook.get_batch_logs(BATCH_ID, 0, 100) + assert state["status"] == "success" + + mock_run_method.assert_called_once_with( + endpoint=f"/livy/batches/{BATCH_ID}/log", + data={"from": 0, "size": 100}, + )