Skip to content

Commit

Permalink
Refactor DataprocCreateBatchOperator and Dataproc system tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
moiseenkov committed Aug 16, 2024
1 parent 393978d commit 0f25e9a
Show file tree
Hide file tree
Showing 24 changed files with 148 additions and 105 deletions.
154 changes: 68 additions & 86 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from deprecated import deprecated
Expand Down Expand Up @@ -2985,10 +2986,10 @@ class DataprocCreateBatchOperator(GoogleCloudBaseOperator):
def __init__(
self,
*,
region: str | None = None,
region: str,
project_id: str = PROVIDE_PROJECT_ID,
batch: dict | Batch,
batch_id: str,
batch_id: str | None = None,
request_id: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
Expand Down Expand Up @@ -3021,20 +3022,20 @@ def __init__(
self.polling_interval_seconds = polling_interval_seconds

def execute(self, context: Context):
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# batch_id might not be set and will be generated
if self.batch_id:
link = DATAPROC_BATCH_LINK.format(
region=self.region, project_id=self.project_id, batch_id=self.batch_id
if self.asynchronous and self.deferrable:
raise AirflowException(
"Both asynchronous and deferrable parameters were passed. Please, provide only one."
)
self.log.info("Creating batch %s", self.batch_id)
self.log.info("Once started, the batch job will be available at %s", link)

batch_id: str = ""
if self.batch_id:
batch_id = self.batch_id
self.log.info("Starting batch %s", batch_id)
else:
self.log.info("Starting batch job. The batch ID will be generated since it was not provided.")
if self.region is None:
raise AirflowException("Region should be set here")
self.log.info("Starting batch. The batch ID will be generated since it was not provided.")

try:
self.operation = hook.create_batch(
self.operation = self.hook.create_batch(
region=self.region,
project_id=self.project_id,
batch=self.batch,
Expand All @@ -3044,85 +3045,62 @@ def execute(self, context: Context):
timeout=self.timeout,
metadata=self.metadata,
)
if self.operation is None:
raise RuntimeError("The operation should be set here!")

if not self.deferrable:
if not self.asynchronous:
result = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.result_retry, operation=self.operation
)
self.log.info("Batch %s created", self.batch_id)

else:
DataprocBatchLink.persist(
context=context,
operator=self,
project_id=self.project_id,
region=self.region,
batch_id=self.batch_id,
)
return self.operation.operation.name

else:
# processing ends in execute_complete
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)

except AlreadyExists:
self.log.info("Batch with given id already exists")
# This is only likely to happen if batch_id was provided
# Could be running if Airflow was restarted after task started
# poll until a final state is reached

self.log.info("Attaching to the job %s if it is still running.", self.batch_id)
self.log.info("Batch with given id already exists.")
self.log.info("Attaching to the job %s if it is still running.", batch_id)
else:
batch_id = self.operation.metadata.batch.split("/")[-1]
self.log.info("The batch %s was created.", batch_id)

# deferrable handling of a batch_id that already exists - processing ends in execute_complete
if self.deferrable:
self.defer(
trigger=DataprocBatchTrigger(
batch_id=self.batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
DataprocBatchLink.persist(
context=context,
operator=self,
project_id=self.project_id,
region=self.region,
batch_id=batch_id,
)

# non-deferrable handling of a batch_id that already exists
result = hook.wait_for_batch(
batch_id=self.batch_id,
if self.asynchronous:
batch = self.hook.get_batch(
batch_id=batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
wait_check_interval=self.polling_interval_seconds,
)
batch_id = self.batch_id or result.name.split("/")[-1]
self.log.info("The batch %s was created asynchronously. Exiting.", batch_id)
return Batch.to_dict(batch)

self.handle_batch_status(context, result.state, batch_id)
project_id = self.project_id or hook.project_id
if project_id:
DataprocBatchLink.persist(
context=context,
operator=self,
project_id=project_id,
region=self.region,
batch_id=batch_id,
if self.deferrable:
self.defer(
trigger=DataprocBatchTrigger(
batch_id=batch_id,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
return Batch.to_dict(result)

self.log.info("Waiting for the completion of batch job %s", batch_id)
batch = self.hook.wait_for_batch(
batch_id=batch_id,
region=self.region,
project_id=self.project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
)

self.handle_batch_status(context, batch.state, batch_id, batch.state_message)
return Batch.to_dict(batch)

@cached_property
def hook(self) -> DataprocHook:
return DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

def execute_complete(self, context, event=None) -> None:
"""
Expand All @@ -3135,23 +3113,27 @@ def execute_complete(self, context, event=None) -> None:
raise AirflowException("Batch failed.")
state = event["batch_state"]
batch_id = event["batch_id"]
self.handle_batch_status(context, state, batch_id)
self.handle_batch_status(context, state, batch_id, state_message=event["batch_state_message"])

def on_kill(self):
if self.operation:
self.operation.cancel()

def handle_batch_status(self, context: Context, state: Batch.State, batch_id: str) -> None:
def handle_batch_status(
self, context: Context, state: Batch.State, batch_id: str, state_message: str | None = None
) -> None:
# The existing batch may be in a number of states other than 'SUCCEEDED'.
# wait_for_operation doesn't fail if the job is cancelled, so we will check for it here which also
# finds a cancelling|canceled|unspecified job from wait_for_batch or the deferred trigger
link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
if state == Batch.State.FAILED:
raise AirflowException("Batch job %s failed. Driver Logs: %s", batch_id, link)
raise AirflowException(
f"Batch job {batch_id} failed with error: {state_message}\nDriver Logs: {link}"
)
if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
raise AirflowException("Batch job %s was cancelled. Driver logs: %s", batch_id, link)
raise AirflowException(f"Batch job {batch_id} was cancelled. Driver logs: {link}")
if state == Batch.State.STATE_UNSPECIFIED:
raise AirflowException("Batch job %s unspecified. Driver logs: %s", batch_id, link)
raise AirflowException(f"Batch job {batch_id} unspecified. Driver logs: {link}")
self.log.info("Batch job %s completed. Driver logs: %s", batch_id, link)


Expand Down
5 changes: 4 additions & 1 deletion airflow/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ async def run(self):
self.log.info("Current state is %s", state)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"batch_id": self.batch_id, "batch_state": state})

yield TriggerEvent(
{"batch_id": self.batch_id, "batch_state": state, "batch_state_message": batch.state_message}
)


class DataprocDeleteClusterTrigger(DataprocBaseTrigger):
Expand Down
3 changes: 1 addition & 2 deletions scripts/ci/pre_commit/check_system_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
errors: list[str] = []

WATCHER_APPEND_INSTRUCTION = "list(dag.tasks) >> watcher()"
WATCHER_APPEND_INSTRUCTION_SHORT = " >> watcher()"

PYTEST_FUNCTION = """
from tests.system.utils import get_test_run # noqa: E402
Expand All @@ -53,7 +52,7 @@
def _check_file(file: Path):
content = file.read_text()
if "from tests.system.utils.watcher import watcher" in content:
index = content.find(WATCHER_APPEND_INSTRUCTION_SHORT)
index = content.find(WATCHER_APPEND_INSTRUCTION)
if index == -1:
errors.append(
f"[red]The example {file} imports tests.system.utils.watcher "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
region=REGION,
batch=BATCH_CONFIG,
batch_id=BATCH_ID,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)

create_batch_2 = DataprocCreateBatchOperator(
Expand All @@ -87,6 +88,7 @@
batch=BATCH_CONFIG,
batch_id=BATCH_ID_3,
asynchronous=True,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)
# [END how_to_cloud_dataproc_create_batch_operator]

Expand Down Expand Up @@ -128,18 +130,10 @@
task_id="cancel_operation",
project_id=PROJECT_ID,
region=REGION,
operation_name="{{ task_instance.xcom_pull('create_batch_4') }}",
operation_name="{{ task_instance.xcom_pull('create_batch_4')['operation'] }}",
)
# [END how_to_cloud_dataproc_cancel_operation_operator]

batch_cancelled_sensor = DataprocBatchSensor(
task_id="batch_cancelled_sensor",
region=REGION,
project_id=PROJECT_ID,
batch_id=BATCH_ID_4,
poke_interval=10,
)

# [START how_to_cloud_dataproc_delete_batch_operator]
delete_batch = DataprocDeleteBatchOperator(
task_id="delete_batch", project_id=PROJECT_ID, region=REGION, batch_id=BATCH_ID
Expand All @@ -161,27 +155,27 @@

(
# TEST SETUP
[create_batch, create_batch_2, create_batch_3]
create_batch
>> create_batch_2
>> create_batch_3
# TEST BODY
>> batch_async_sensor
>> get_batch
>> list_batches
>> create_batch_4
>> cancel_operation
# TEST TEARDOWN
>> [delete_batch, delete_batch_2, delete_batch_3]
>> batch_cancelled_sensor
>> delete_batch
>> delete_batch_2
>> delete_batch_3
>> delete_batch_4
)

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "teardown" task with trigger rule is part of the DAG

# Excluding sensor because we expect it to fail due to cancelled operation
[task for task in dag.tasks if task.task_id != "batch_cancelled_sensor"] >> watcher()

list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateBatchOperator,
Expand Down Expand Up @@ -62,6 +64,7 @@
batch=BATCH_CONFIG,
batch_id=BATCH_ID,
deferrable=True,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)
# [END how_to_cloud_dataproc_create_batch_operator_async]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataproc import (
ClusterGenerator,
Expand Down Expand Up @@ -89,6 +91,7 @@
cluster_config=CLUSTER_GENERATOR_CONFIG_FOR_PHS,
region=REGION,
cluster_name=CLUSTER_NAME,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)
# [END how_to_cloud_dataproc_create_cluster_for_persistent_history_server]

Expand All @@ -99,6 +102,7 @@
region=REGION,
batch=BATCH_CONFIG_WITH_PHS,
batch_id=BATCH_ID,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)
# [END how_to_cloud_dataproc_create_batch_operator_with_persistent_history_server]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import os
from datetime import datetime

from google.api_core.retry import Retry

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateClusterOperator,
Expand Down Expand Up @@ -69,13 +71,15 @@
region=REGION,
cluster_name=CLUSTER_NAME,
use_if_exists=True,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)

start_cluster = DataprocStartClusterOperator(
task_id="start_cluster",
project_id=PROJECT_ID,
region=REGION,
cluster_name=CLUSTER_NAME,
result_retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0),
)

stop_cluster = DataprocStopClusterOperator(
Expand Down
Loading

0 comments on commit 0f25e9a

Please sign in to comment.