diff --git a/providers/src/airflow/providers/apache/beam/hooks/beam.py b/providers/src/airflow/providers/apache/beam/hooks/beam.py
index 9b5d216c3dc0f..20b83df4631b7 100644
--- a/providers/src/airflow/providers/apache/beam/hooks/beam.py
+++ b/providers/src/airflow/providers/apache/beam/hooks/beam.py
@@ -451,6 +451,7 @@ async def start_python_pipeline_async(
         py_interpreter: str = "python3",
         py_requirements: list[str] | None = None,
         py_system_site_packages: bool = False,
+        process_line_callback: Callable[[str], None] | None = None,
     ):
         """
         Start Apache Beam python pipeline.
@@ -470,6 +471,8 @@ async def start_python_pipeline_async(
         :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
             See virtualenv documentation for more information.
             This option is only relevant if the ``py_requirements`` parameter is not None.
+        :param process_line_callback: Optional callback which can be used to process
+            stdout and stderr to detect job id
         """
         py_options = py_options or []
         if "labels" in variables:
@@ -518,16 +521,25 @@
         return_code = await self.start_pipeline_async(
             variables=variables,
             command_prefix=command_prefix,
+            process_line_callback=process_line_callback,
         )
         return return_code

-    async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None):
+    async def start_java_pipeline_async(
+        self,
+        variables: dict,
+        jar: str,
+        job_class: str | None = None,
+        process_line_callback: Callable[[str], None] | None = None,
+    ):
         """
         Start Apache Beam Java pipeline.

         :param variables: Variables passed to the job.
         :param jar: Name of the jar for the pipeline.
         :param job_class: Name of the java class for the pipeline.
+        :param process_line_callback: Optional callback which can be used to process
+            stdout and stderr to detect job id
         :return: Beam command execution return code.
         """
         if "labels" in variables:
@@ -537,6 +549,7 @@ async def start_java_pipeline_async(self, variables: dict, jar: str, job_class:
         return_code = await self.start_pipeline_async(
             variables=variables,
             command_prefix=command_prefix,
+            process_line_callback=process_line_callback,
         )
         return return_code

@@ -545,6 +558,7 @@ async def start_pipeline_async(
         variables: dict,
         command_prefix: list[str],
         working_directory: str | None = None,
+        process_line_callback: Callable[[str], None] | None = None,
     ) -> int:
         cmd = [*command_prefix, f"--runner={self.runner}"]
         if variables:
@@ -553,6 +567,7 @@ async def start_pipeline_async(
             cmd=cmd,
             working_directory=working_directory,
             log=self.log,
+            process_line_callback=process_line_callback,
         )

     async def run_beam_command_async(
@@ -560,13 +575,16 @@ async def run_beam_command_async(
         cmd: list[str],
         log: logging.Logger,
         working_directory: str | None = None,
+        process_line_callback: Callable[[str], None] | None = None,
     ) -> int:
         """
         Run pipeline command in subprocess.

         :param cmd: Parts of the command to be run in subprocess
         :param working_directory: Working directory
-        :param log: logger.
+        :param log: logger
+        :param process_line_callback: Optional callback which can be used to process
+            stdout and stderr to detect job id
         """
         cmd_str_representation = " ".join(shlex.quote(c) for c in cmd)
         log.info("Running command: %s", cmd_str_representation)
@@ -584,8 +602,8 @@ async def run_beam_command_async(
         log.info("Start waiting for Apache Beam process to complete.")

         # Creating separate threads for stdout and stderr
-        stdout_task = asyncio.create_task(self.read_logs(process.stdout))
-        stderr_task = asyncio.create_task(self.read_logs(process.stderr))
+        stdout_task = asyncio.create_task(self.read_logs(process.stdout, process_line_callback))
+        stderr_task = asyncio.create_task(self.read_logs(process.stderr, process_line_callback))

         # Waiting for the both tasks to complete
         await asyncio.gather(stdout_task, stderr_task)
@@ -598,10 +616,16 @@ async def run_beam_command_async(
             raise AirflowException(f"Apache Beam process failed with return code {return_code}")
         return return_code

-    async def read_logs(self, stream_reader):
+    async def read_logs(
+        self,
+        stream_reader,
+        process_line_callback: Callable[[str], None] | None = None,
+    ):
         while True:
             line = await stream_reader.readline()
             if not line:
                 break
             decoded_line = line.decode().strip()
+            if process_line_callback:
+                process_line_callback(decoded_line)
             self.log.info(decoded_line)
diff --git a/providers/src/airflow/providers/apache/beam/operators/beam.py b/providers/src/airflow/providers/apache/beam/operators/beam.py
index 07987fe1a9bc9..812c74cd95476 100644
--- a/providers/src/airflow/providers/apache/beam/operators/beam.py
+++ b/providers/src/airflow/providers/apache/beam/operators/beam.py
@@ -274,6 +274,17 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
             self.task_id,
             event["message"],
         )
+        self.dataflow_job_id = event["dataflow_job_id"]
+        self.project_id = event["project_id"]
+        self.location = event["location"]
+
+        DataflowJobLink.persist(
+            self,
+            context,
+            self.project_id,
+            self.location,
+            self.dataflow_job_id,
+        )
         return {"dataflow_job_id": self.dataflow_job_id}
@@ -425,13 +436,6 @@ def execute_sync(self, context: Context):

     def execute_async(self, context: Context):
         if self.is_dataflow and self.dataflow_hook:
-            DataflowJobLink.persist(
-                self,
-                context,
-                self.dataflow_config.project_id,
-                self.dataflow_config.location,
-                self.dataflow_job_id,
-            )
             with self.dataflow_hook.provide_authorized_gcloud():
                 self.defer(
                     trigger=BeamPythonPipelineTrigger(
@@ -443,6 +447,8 @@ def execute_async(self, context: Context):
                         py_system_site_packages=self.py_system_site_packages,
                         runner=self.runner,
                         gcp_conn_id=self.gcp_conn_id,
+                        project_id=self.dataflow_config.project_id,
+                        location=self.dataflow_config.location,
                     ),
                     method_name="execute_complete",
                 )
@@ -613,13 +619,6 @@ def execute_sync(self, context: Context):

     def execute_async(self, context: Context):
         if self.is_dataflow and self.dataflow_hook:
-            DataflowJobLink.persist(
-                self,
-                context,
-                self.dataflow_config.project_id,
-                self.dataflow_config.location,
-                self.dataflow_job_id,
-            )
             with self.dataflow_hook.provide_authorized_gcloud():
                 self.pipeline_options["jobName"] = self.dataflow_job_name
                 self.defer(
diff --git a/providers/src/airflow/providers/apache/beam/triggers/beam.py b/providers/src/airflow/providers/apache/beam/triggers/beam.py
index 3eb3611b6139e..1c77089b9f2bc 100644
--- a/providers/src/airflow/providers/apache/beam/triggers/beam.py
+++ b/providers/src/airflow/providers/apache/beam/triggers/beam.py
@@ -19,12 +19,15 @@
 import asyncio
 import contextlib
 from collections.abc import AsyncIterator, Sequence
-from typing import IO, Any
+from typing import IO, Any, Callable

 from google.cloud.dataflow_v1beta3 import ListJobsRequest

-from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
-from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
+from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook, BeamRunnerType
+from airflow.providers.google.cloud.hooks.dataflow import (
+    AsyncDataflowHook,
+    process_line_and_extract_dataflow_job_id_callback,
+)
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -40,6 +43,14 @@ def _get_async_hook(*args, **kwargs) -> BeamAsyncHook:
     def _get_sync_dataflow_hook(**kwargs) -> AsyncDataflowHook:
         return AsyncDataflowHook(**kwargs)

+    def _get_dataflow_process_callback(self) -> Callable[[str], None]:
+        def set_current_dataflow_job_id(job_id):
+            self.dataflow_job_id = job_id
+
+        return process_line_and_extract_dataflow_job_id_callback(
+            on_new_job_id_callback=set_current_dataflow_job_id
+        )
+

 class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
     """
@@ -59,6 +70,8 @@ class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
     :param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
         See virtualenv documentation for more information.
        This option is only relevant if the ``py_requirements`` parameter is not None.
+    :param project_id: Optional, the Google Cloud project ID in which to start a job.
+    :param location: Optional, Job location.
     :param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
         Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
         See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
@@ -74,6 +87,8 @@ def __init__(
         py_interpreter: str = "python3",
         py_requirements: list[str] | None = None,
         py_system_site_packages: bool = False,
+        project_id: str | None = None,
+        location: str | None = None,
         runner: str = "DirectRunner",
         gcp_conn_id: str = "google_cloud_default",
     ):
@@ -84,6 +99,9 @@ def __init__(
         self.py_interpreter = py_interpreter
         self.py_requirements = py_requirements
         self.py_system_site_packages = py_system_site_packages
+        self.dataflow_job_id: str | None = None
+        self.project_id = project_id
+        self.location = location
         self.runner = runner
         self.gcp_conn_id = gcp_conn_id
@@ -98,6 +116,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "py_interpreter": self.py_interpreter,
                 "py_requirements": self.py_requirements,
                 "py_system_site_packages": self.py_system_site_packages,
+                "project_id": self.project_id,
+                "location": self.location,
                 "runner": self.runner,
                 "gcp_conn_id": self.gcp_conn_id,
             },
@@ -106,6 +126,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook(runner=self.runner)
+        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+
         try:
             # Get the current running event loop to manage I/O operations asynchronously
             loop = asyncio.get_running_loop()
@@ -130,6 +152,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                 py_interpreter=self.py_interpreter,
                 py_requirements=self.py_requirements,
                 py_system_site_packages=self.py_system_site_packages,
+                process_line_callback=self._get_dataflow_process_callback() if is_dataflow else None,
             )
         except Exception as e:
             self.log.exception("Exception occurred while checking for pipeline state")
@@ -140,6 +163,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                 {
                     "status": "success",
                     "message": "Pipeline has finished SUCCESSFULLY",
+                    "dataflow_job_id": self.dataflow_job_id,
+                    "project_id": self.project_id,
+                    "location": self.location,
                 }
             )
         else:
@@ -205,6 +231,7 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.poll_sleep = poll_sleep
         self.cancel_timeout = cancel_timeout
+        self.dataflow_job_id: str | None = None

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize BeamJavaPipelineTrigger arguments and classpath."""
@@ -229,6 +256,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
     async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
         """Get current Java pipeline status and yields a TriggerEvent."""
         hook = self._get_async_hook(runner=self.runner)
+        is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

         return_code = 0
         if self.check_if_running:
@@ -271,7 +299,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                     self.jar = tmp_gcs_file.name

             return_code = await hook.start_java_pipeline_async(
-                variables=self.variables, jar=self.jar, job_class=self.job_class
+                variables=self.variables,
+                jar=self.jar,
+                job_class=self.job_class,
+                process_line_callback=self._get_dataflow_process_callback() if is_dataflow else None,
             )
         except Exception as e:
             self.log.exception("Exception occurred while starting the Java pipeline")
@@ -282,6 +313,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:  # type: ignore[override]
                 {
                     "status": "success",
                     "message": "Pipeline has finished SUCCESSFULLY",
+                    "dataflow_job_id": self.dataflow_job_id,
+                    "project_id": self.project_id,
+                    "location": self.location,
                 }
             )
         else:
diff --git a/providers/tests/apache/beam/hooks/test_beam.py b/providers/tests/apache/beam/hooks/test_beam.py
index 9489e6ca8a155..61d2d2e4735ff 100644
--- a/providers/tests/apache/beam/hooks/test_beam.py
+++ b/providers/tests/apache/beam/hooks/test_beam.py
@@ -486,7 +486,10 @@ async def test_start_pipline_async(self, mock_runner):
         )

         mock_runner.assert_called_once_with(
-            cmd=expected_cmd, working_directory=WORKING_DIRECTORY, log=hook.log
+            cmd=expected_cmd,
+            working_directory=WORKING_DIRECTORY,
+            log=hook.log,
+            process_line_callback=None,
         )

     @pytest.mark.asyncio
@@ -516,6 +519,7 @@ async def test_start_python_pipeline(self, mock_create_dir, mock_runner, mocked_
             cmd=expected_cmd,
             working_directory=None,
             log=ANY,
+            process_line_callback=None,
         )

     @pytest.mark.asyncio
@@ -580,6 +584,7 @@ async def test_start_python_pipeline_with_custom_interpreter(
             cmd=expected_cmd,
             working_directory=None,
             log=ANY,
+            process_line_callback=None,
         )

     @pytest.mark.asyncio
@@ -630,6 +635,7 @@ async def test_start_python_pipeline_with_non_empty_py_requirements_and_without_
             cmd=expected_cmd,
             working_directory=None,
             log=ANY,
+            process_line_callback=None,
         )
         mock_virtualenv.assert_called_once_with(
             venv_directory=mock.ANY,
@@ -671,5 +677,7 @@ async def test_start_java_pipeline_async(self, mock_start_pipeline, job_class, c
         await hook.start_java_pipeline_async(variables=variables, jar=JAR_FILE, job_class=job_class)

         mock_start_pipeline.assert_called_once_with(
-            variables=BEAM_VARIABLES_JAVA_STRING_LABELS, command_prefix=command_prefix
+            variables=BEAM_VARIABLES_JAVA_STRING_LABELS,
+            command_prefix=command_prefix,
+            process_line_callback=None,
         )
diff --git a/providers/tests/apache/beam/operators/test_beam.py b/providers/tests/apache/beam/operators/test_beam.py
index a6794ade40dc0..a74e9002784fa 100644
--- a/providers/tests/apache/beam/operators/test_beam.py
+++ b/providers/tests/apache/beam/operators/test_beam.py
@@ -106,7 +106,13 @@ def test_async_execute_logging_should_execute_successfully(self, caplog):
         op = BeamBasePipelineOperator(**self.default_op_kwargs)
         op.execute_complete(
             context=mock.MagicMock(),
-            event={"status": "success", "message": "Pipeline has finished SUCCESSFULLY"},
+            event={
+                "status": "success",
+                "message": "Pipeline has finished SUCCESSFULLY",
+                "dataflow_job_id": "test_dataflow_job_id",
+                "project_id": "test_project_id",
+                "location": "test_location",
+            },
         )
         assert f"{TASK_ID} completed with response Pipeline has finished SUCCESSFULLY" in caplog.text
@@ -952,10 +958,9 @@ def test_async_execute_direct_runner(self, beam_hook_mock):
             op.execute(context=mock.MagicMock())
         beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)

-    @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
-    def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
+    def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock):
         """
         Test DataflowHook is created and the right args are passed to start_python_dataflow
         when executing Dataflow runner.
@@ -971,7 +976,6 @@ def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_
         with pytest.raises(TaskDeferred):
             op.execute(context=magic_mock)

-        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
         dataflow_hook_mock.assert_called_once_with(
             gcp_conn_id=dataflow_config.gcp_conn_id,
             poll_sleep=dataflow_config.poll_sleep,
@@ -980,22 +984,6 @@ def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_
             cancel_timeout=dataflow_config.cancel_timeout,
             wait_until_finished=dataflow_config.wait_until_finished,
         )
-        expected_options = {
-            "project": dataflow_hook_mock.return_value.project_id,
-            "job_name": job_name,
-            "staging_location": "gs://test/staging",
-            "output": "gs://test/output",
-            "labels": {"foo": "bar", "airflow-version": TEST_VERSION},
-            "region": "us-central1",
-            "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
-        }
-        persist_link_mock.assert_called_once_with(
-            op,
-            magic_mock,
-            expected_options["project"],
-            expected_options["region"],
-            op.dataflow_job_id,
-        )
         beam_hook_mock.return_value.start_python_pipeline.assert_not_called()
         dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
@@ -1076,10 +1064,9 @@ def test_async_execute_direct_runner(self, beam_hook_mock):
             op.execute(context=mock.MagicMock())
         beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)

-    @mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
     @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
     @mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
-    def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
+    def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock):
         """
         Test DataflowHook is created and the right args are passed to start_java_pipeline
         when executing Dataflow runner.
@@ -1092,7 +1079,6 @@ def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_
         with pytest.raises(TaskDeferred):
             op.execute(context=magic_mock)

-        job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
         dataflow_hook_mock.assert_called_once_with(
             gcp_conn_id=dataflow_config.gcp_conn_id,
             poll_sleep=dataflow_config.poll_sleep,
@@ -1101,22 +1087,6 @@ def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_
             cancel_timeout=dataflow_config.cancel_timeout,
             wait_until_finished=dataflow_config.wait_until_finished,
         )
-        expected_options = {
-            "project": dataflow_hook_mock.return_value.project_id,
-            "job_name": job_name,
-            "staging_location": "gs://test/staging",
-            "output": "gs://test/output",
-            "labels": {"foo": "bar"},
-            "region": "us-central1",
-            "impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
-        }
-        persist_link_mock.assert_called_once_with(
-            op,
-            magic_mock,
-            expected_options["project"],
-            expected_options["region"],
-            op.dataflow_job_id,
-        )
         beam_hook_mock.return_value.start_python_pipeline.assert_not_called()
         dataflow_hook_mock.return_value.provide_authorized_gcloud.assert_called_once_with()
diff --git a/providers/tests/apache/beam/triggers/test_beam.py b/providers/tests/apache/beam/triggers/test_beam.py
index ca9981244d8e9..c1e1f65ce6b0f 100644
--- a/providers/tests/apache/beam/triggers/test_beam.py
+++ b/providers/tests/apache/beam/triggers/test_beam.py
@@ -61,6 +61,8 @@ def python_trigger():
         py_interpreter=TEST_PY_INTERPRETER,
         py_requirements=TEST_PY_REQUIREMENTS,
         py_system_site_packages=TEST_PY_PACKAGES,
+        project_id=PROJECT_ID,
+        location=LOCATION,
         runner=TEST_RUNNER,
         gcp_conn_id=TEST_GCP_CONN_ID,
     )
@@ -99,6 +101,8 @@ def test_beam_trigger_serialization_should_execute_successfully(self, python_tri
             "py_interpreter": TEST_PY_INTERPRETER,
             "py_requirements": TEST_PY_REQUIREMENTS,
             "py_system_site_packages": TEST_PY_PACKAGES,
+            "project_id": PROJECT_ID,
+            "location": LOCATION,
             "runner": TEST_RUNNER,
             "gcp_conn_id": TEST_GCP_CONN_ID,
         }
@@ -114,7 +118,18 @@ async def test_beam_trigger_on_success_should_execute_successfully(
         mock_pipeline_status.return_value = 0
         generator = python_trigger.run()
         actual = await generator.asend(None)
-        assert TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}) == actual
+        assert (
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": "Pipeline has finished SUCCESSFULLY",
+                    "dataflow_job_id": None,
+                    "project_id": PROJECT_ID,
+                    "location": LOCATION,
+                }
+            )
+            == actual
+        )

     @pytest.mark.asyncio
     @mock.patch(HOOK_STATUS_STR_PYTHON)
@@ -189,7 +204,18 @@ async def test_beam_trigger_on_success_should_execute_successfully(
         mock_pipeline_status.return_value = 0
         generator = java_trigger.run()
         actual = await generator.asend(None)
-        assert TriggerEvent({"status": "success", "message": "Pipeline has finished SUCCESSFULLY"}) == actual
+        assert (
+            TriggerEvent(
+                {
+                    "status": "success",
+                    "message": "Pipeline has finished SUCCESSFULLY",
+                    "dataflow_job_id": None,
+                    "project_id": PROJECT_ID,
+                    "location": LOCATION,
+                }
+            )
+            == actual
+        )

     @pytest.mark.asyncio
     @mock.patch(HOOK_STATUS_STR_JAVA)
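How the pieces above fit together: the triggers build a process_line_callback via process_line_and_extract_dataflow_job_id_callback, the async hook feeds every stdout/stderr line of the Beam subprocess through it in read_logs, and the captured Dataflow job ID travels back in the TriggerEvent so execute_complete can persist the DataflowJobLink. The standalone sketch below illustrates only that callback pattern; the regex, the make_job_id_callback helper, and the sample log line are assumptions for illustration and are not the provider's actual implementation.

import re
from typing import Callable

# Hypothetical pattern for illustration; the real provider helper owns the matching logic.
_JOB_ID_PATTERN = re.compile(r"Submitted job: (?P<job>\S+)")


def make_job_id_callback(on_new_job_id: Callable[[str], None]) -> Callable[[str], None]:
    """Return a per-line callback that reports the job ID the first time it appears."""
    seen = {"reported": False}

    def process_line(line: str) -> None:
        match = _JOB_ID_PATTERN.search(line)
        if match and not seen["reported"]:
            seen["reported"] = True
            on_new_job_id(match.group("job"))

    return process_line


# Usage sketch: a trigger would store the captured ID and emit it in its TriggerEvent payload.
captured = {"dataflow_job_id": None}
callback = make_job_id_callback(lambda job_id: captured.update(dataflow_job_id=job_id))
for log_line in ["Starting pipeline ...", "Submitted job: 2024-01-01_00_00_00-1234567890"]:
    callback(log_line)
print(captured["dataflow_job_id"])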