Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add early job_id xcom_push for google provider Beam Pipeline operators #121

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions providers/src/airflow/providers/apache/beam/operators/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,20 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.beam_hook: BeamHook
self.dataflow_hook: DataflowHook | None = None
self.dataflow_job_id: str | None = None
self._dataflow_job_id: str | None = None
self._execute_context: Context | None = None

@property
def dataflow_job_id(self):
return self._dataflow_job_id

@dataflow_job_id.setter
def dataflow_job_id(self, new_value):
if all([new_value, not self._dataflow_job_id, self._execute_context]):
# push job_id as soon as it's ready, to let Sensors work before the job finished
# and job_id pushed as returned value item.
self.xcom_push(context=self._execute_context, key="dataflow_job_id", value=new_value)
self._dataflow_job_id = new_value

def _cast_dataflow_config(self):
if isinstance(self.dataflow_config, dict):
Expand Down Expand Up @@ -346,6 +359,7 @@ def __init__(

def execute(self, context: Context):
"""Execute the Apache Beam Python Pipeline."""
self._execute_context = context
self._cast_dataflow_config()
self.pipeline_options.setdefault("labels", {}).update(
{"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
Expand Down Expand Up @@ -540,6 +554,7 @@ def __init__(

def execute(self, context: Context):
"""Execute the Apache Beam Python Pipeline."""
self._execute_context = context
self._cast_dataflow_config()
(
self.is_dataflow,
Expand Down Expand Up @@ -738,7 +753,7 @@ def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
if not exactly_one(self.go_file, self.launcher_binary):
raise ValueError("Exactly one of `go_file` and `launcher_binary` must be set")

self._execute_context = context
self._cast_dataflow_config()
if self.dataflow_config.impersonation_chain:
self.log.warning(
Expand Down
21 changes: 21 additions & 0 deletions providers/tests/apache/beam/operators/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,27 @@ def test_async_execute_logging_should_execute_successfully(self, caplog):
)
assert f"{TASK_ID} completed with response Pipeline has finished SUCCESSFULLY" in caplog.text

def test_early_dataflow_id_xcom_push(self, default_options, pipeline_options):
with mock.patch.object(BeamBasePipelineOperator, "xcom_push") as mock_xcom_push:
op = BeamBasePipelineOperator(
**self.default_op_kwargs,
default_pipeline_options=copy.deepcopy(default_options),
pipeline_options=copy.deepcopy(pipeline_options),
dataflow_config={},
)
sample_df_job_id = "sample_df_job_id_value"
op._execute_context = MagicMock()

assert op.dataflow_job_id is None

op.dataflow_job_id = sample_df_job_id
mock_xcom_push.assert_called_once_with(
context=op._execute_context, key="dataflow_job_id", value=sample_df_job_id
)
mock_xcom_push.reset_mock()
op.dataflow_job_id = "sample_df_job_same_value_id"
mock_xcom_push.assert_not_called()


class TestBeamRunPythonPipelineOperator:
@pytest.fixture(autouse=True)
Expand Down
Loading