From f691adf7105b687b6ba2885c8977607065856fd3 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Fri, 9 Feb 2024 12:17:33 +0400 Subject: [PATCH] Fix rendering `SparkKubernetesOperator.template_body` (#37271) --- .../kubernetes/operators/spark_kubernetes.py | 6 +++++- .../operators/test_spark_kubernetes.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 83d6f484b39dc..0c177510eb584 100644 --- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -107,7 +107,6 @@ def __init__( self.get_logs = get_logs self.log_events_on_failure = log_events_on_failure self.success_run_history_limit = success_run_history_limit - self.template_body = self.manage_template_specs() def _render_nested_template_fields( self, @@ -193,6 +192,11 @@ def pod_manager(self) -> PodManager: def _try_numbers_match(context, pod) -> bool: return pod.metadata.labels["try_number"] == context["ti"].try_number + @property + def template_body(self): + """Templated body for CustomObjectLauncher.""" + return self.manage_template_specs() + def find_spark_job(self, context): labels = self.create_labels_for_pod(context, include_try_number=False) label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver" diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py index d900ef78b43dc..5d164069262be 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -25,6 +25,7 @@ from unittest.mock import patch import pendulum +import pytest import yaml from kubernetes.client import models as k8s @@ -488,3 +489,18 @@ def test_toleration( assert op.launcher.body["spec"]["driver"]["tolerations"] == [toleration] assert op.launcher.body["spec"]["executor"]["tolerations"] == [toleration] + + +@pytest.mark.db_test +def test_template_body_templating(create_task_instance_of_operator): + ti = create_task_instance_of_operator( + SparkKubernetesOperator, + template_spec={"foo": "{{ ds }}", "bar": "{{ dag_run.dag_id }}"}, + kubernetes_conn_id="kubernetes_default_kube_config", + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: SparkKubernetesOperator = ti.task + assert task.template_body == {"spark": {"foo": "2024-02-01", "bar": "test_template_body_templating_dag"}}