From 1489cf7a0372898ab5f905fa7b56f3b1327d2cfe Mon Sep 17 00:00:00 2001
From: Maksim
Date: Tue, 14 May 2024 07:53:13 -0700
Subject: [PATCH] Fix deferrable mode for BeamRunJavaPipelineOperator (#39371)

---
 .../providers/apache/beam/operators/beam.py   | 21 +++------------------
 .../providers/apache/beam/triggers/beam.py    | 22 ++++++++++++++++++++--
 .../apache/beam/operators/test_beam.py        | 10 ++--------
 .../apache/beam/triggers/test_beam.py         | 13 +++++++++++++
 4 files changed, 38 insertions(+), 28 deletions(-)

diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index 62f650f19a4b1..af338cdc6d390 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -546,7 +546,7 @@ def execute(self, context: Context):
         if not self.beam_hook:
             raise AirflowException("Beam hook is not defined.")
         if self.deferrable:
-            asyncio.run(self.execute_async(context))
+            self.execute_async(context)
         else:
             return self.execute_sync(context)
 
@@ -605,23 +605,7 @@ def execute_sync(self, context: Context):
                 process_line_callback=self.process_line_callback,
             )
 
-    async def execute_async(self, context: Context):
-        # Creating a new event loop to manage I/O operations asynchronously
-        loop = asyncio.get_event_loop()
-        if self.jar.lower().startswith("gs://"):
-            gcs_hook = GCSHook(self.gcp_conn_id)
-            # Running synchronous `enter_context()` method in a separate
-            # thread using the default executor `None`. The `run_in_executor()` function returns the
-            # file object, which is created using gcs function `provide_file()`, asynchronously.
-            # This means we can perform asynchronous operations with this file.
-            create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
-            tmp_gcs_file: IO[str] = await loop.run_in_executor(
-                None,
-                contextlib.ExitStack().enter_context,  # type: ignore[arg-type]
-                create_tmp_file_call,
-            )
-            self.jar = tmp_gcs_file.name
-
+    def execute_async(self, context: Context):
         if self.is_dataflow and self.dataflow_hook:
             DataflowJobLink.persist(
                 self,
@@ -657,6 +641,7 @@ async def execute_async(self, context: Context):
                     job_class=self.job_class,
                     runner=self.runner,
                     check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
+                    gcp_conn_id=self.gcp_conn_id,
                 ),
                 method_name="execute_complete",
             )
diff --git a/airflow/providers/apache/beam/triggers/beam.py b/airflow/providers/apache/beam/triggers/beam.py
index 5b1f7a99d5a0a..b160218f737e0 100644
--- a/airflow/providers/apache/beam/triggers/beam.py
+++ b/airflow/providers/apache/beam/triggers/beam.py
@@ -17,7 +17,8 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Any, AsyncIterator, Sequence
+import contextlib
+from typing import IO, Any, AsyncIterator, Sequence
 
 from deprecated import deprecated
 from google.cloud.dataflow_v1beta3 import ListJobsRequest
@@ -25,6 +26,7 @@
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
 from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
@@ -166,7 +168,7 @@ def __init__(
         project_id: str | None = None,
         location: str | None = None,
         job_name: str | None = None,
-        gcp_conn_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         poll_sleep: int = 10,
         cancel_timeout: int | None = None,
@@ -233,6 +235,22 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                 if is_running:
                     await asyncio.sleep(self.poll_sleep)
         try:
+            # Get the current running event loop to manage I/O operations asynchronously
+            loop = asyncio.get_running_loop()
+            if self.jar.lower().startswith("gs://"):
+                gcs_hook = GCSHook(self.gcp_conn_id)
+                # Run the synchronous `enter_context()` method in a separate
+                # thread using the default executor `None`. `run_in_executor()` returns the
+                # file object created by the GCS hook's `provide_file()` without blocking
+                # the event loop, so the downloaded jar can be used below.
+                create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
+                tmp_gcs_file: IO[str] = await loop.run_in_executor(
+                    None,
+                    contextlib.ExitStack().enter_context,  # type: ignore[arg-type]
+                    create_tmp_file_call,
+                )
+                self.jar = tmp_gcs_file.name
+
             return_code = await hook.start_java_pipeline_async(
                 variables=self.variables, jar=self.jar, job_class=self.job_class
             )
diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py
index a6a4c31c77a5c..15d5c9778af93 100644
--- a/tests/providers/apache/beam/operators/test_beam.py
+++ b/tests/providers/apache/beam/operators/test_beam.py
@@ -1013,24 +1013,20 @@ def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_moc
         ), "Trigger is not a BeamPJavaPipelineTrigger"
 
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
-    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
-    def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
+    def test_async_execute_direct_runner(self, beam_hook_mock):
         """
         Test BeamHook is created and the right args are passed to
         start_java_pipeline when executing direct runner.
         """
-        gcs_provide_file = gcs_hook.return_value.provide_file
         op = BeamRunJavaPipelineOperator(**self.default_op_kwargs)
         with pytest.raises(TaskDeferred):
             op.execute(context=mock.MagicMock())
         beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
-        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
 
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
-    @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
-    def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
+    def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
         """
         Test DataflowHook is created and the right args are passed to start_java_pipeline
         when executing Dataflow runner.
@@ -1039,7 +1035,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
         op = BeamRunJavaPipelineOperator(
             runner="DataflowRunner", dataflow_config=dataflow_config, **self.default_op_kwargs
         )
-        gcs_provide_file = gcs_hook.return_value.provide_file
         magic_mock = mock.MagicMock()
         with pytest.raises(TaskDeferred):
             op.execute(context=magic_mock)
@@ -1062,7 +1057,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
             "region": "us-central1",
             "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
         }
-        gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
         persist_link_mock.assert_called_once_with(
             op,
             magic_mock,
diff --git a/tests/providers/apache/beam/triggers/test_beam.py b/tests/providers/apache/beam/triggers/test_beam.py
index 6bd1b4bc6647e..972e90161a22d 100644
--- a/tests/providers/apache/beam/triggers/test_beam.py
+++ b/tests/providers/apache/beam/triggers/test_beam.py
@@ -43,6 +43,7 @@
 TEST_PY_PACKAGES = False
 TEST_RUNNER = "DirectRunner"
 TEST_JAR_FILE = "example.jar"
+TEST_GCS_JAR_FILE = "gs://my-bucket/example/test.jar"
 TEST_JOB_CLASS = "TestClass"
 TEST_CHECK_IF_RUNNING = False
 TEST_JOB_NAME = "test_job_name"
@@ -215,3 +216,15 @@ async def test_beam_trigger_exception_list_jobs_should_execute_successfully(
         generator = java_trigger.run()
         actual = await generator.asend(None)
         assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
+    async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, java_trigger):
+        """
+        Test that BeamJavaPipelineTrigger fetches the jar file from GCS via provide_file.
+        """
+        gcs_provide_file = gcs_hook.return_value.provide_file
+        java_trigger.jar = TEST_GCS_JAR_FILE
+        generator = java_trigger.run()
+        await generator.asend(None)
+        gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
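
Note: with this patch, a BeamRunJavaPipelineOperator running with deferrable=True no longer
touches GCS in the worker; BeamJavaPipelineTrigger.run() fetches a gs:// jar on the triggerer
using the gcp_conn_id that the operator now forwards. A minimal sketch of a DAG that exercises
this path (the DAG id is hypothetical; the jar, job class, and connection id reuse the test
constants and default value from the diff above):

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.apache.beam.operators.beam import BeamRunJavaPipelineOperator

    with DAG(dag_id="example_beam_java_deferrable", start_date=datetime(2024, 5, 1), schedule=None):
        # deferrable=True makes the operator defer immediately; the gs:// jar is then
        # downloaded inside the trigger before start_java_pipeline_async is called.
        run_java_pipeline = BeamRunJavaPipelineOperator(
            task_id="run_java_pipeline",
            jar="gs://my-bucket/example/test.jar",
            job_class="TestClass",
            runner="DirectRunner",
            deferrable=True,
            gcp_conn_id="google_cloud_default",
        )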