From 469beef0a05a59564d4066e3d36f958329476671 Mon Sep 17 00:00:00 2001 From: nevcohen <73249829+nevcohen@users.noreply.github.com> Date: Mon, 15 Jul 2024 00:31:32 +0300 Subject: [PATCH] Add `kubernetes_application_id` to `SparkSubmitHook` (#40753) --- .../providers/apache/spark/hooks/spark_submit.py | 14 ++++++++++---- .../apache/spark/hooks/test_spark_submit.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index e8f53f89568fa..a12636c86a16d 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -199,6 +199,7 @@ def __init__( self._submit_sp: Any | None = None self._yarn_application_id: str | None = None self._kubernetes_driver_pod: str | None = None + self._kubernetes_application_id: str | None = None self.spark_binary = spark_binary self._properties_file = properties_file self._yarn_queue = yarn_queue @@ -546,16 +547,21 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: match = re.search("application[0-9_]+", line) if match: self._yarn_application_id = match.group(0) - self.log.info("Identified spark driver id: %s", self._yarn_application_id) + self.log.info("Identified spark application id: %s", self._yarn_application_id) # If we run Kubernetes cluster mode, we want to extract the driver pod id # from the logs so we can kill the application when we stop it unexpectedly elif self._is_kubernetes: - match = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver$)", line) - if match: - self._kubernetes_driver_pod = match.group(1) + match_driver_pod = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver$)", line) + if match_driver_pod: + self._kubernetes_driver_pod = match_driver_pod.group(1) self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) + match_application_id = re.search(r"\s*spark-app-selector -> (spark-([a-z0-9]+)), ", line) + if match_application_id: + self._kubernetes_application_id = match_application_id.group(1) + self.log.info("Identified spark application id: %s", self._kubernetes_application_id) + # Store the Spark Exit code match_exit_code = re.search(r"\s*[eE]xit code: (\d+)", line) if match_exit_code: diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py index 0a4c6032e10d4..b526e25b6f952 100644 --- a/tests/providers/apache/spark/hooks/test_spark_submit.py +++ b/tests/providers/apache/spark/hooks/test_spark_submit.py @@ -652,6 +652,7 @@ def test_process_spark_submit_log_k8s(self, pod_name): # Then assert hook._kubernetes_driver_pod == pod_name + assert hook._kubernetes_application_id == "spark-465b868ada474bda82ccb84ab2747fcd" assert hook._spark_exit_code == 999 def test_process_spark_submit_log_k8s_spark_3(self):