Fix case if `SparkKubernetesOperator.application_file` is a templated file (apache#38035)
* Fix case if `SparkKubernetesOperator.application_file` is a templated file

* Simplify non-templated file test

* Change sequential order of checks

* Unique dag names
Taragolis authored Mar 12, 2024
1 parent 58bffa6 commit 1d3010c
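
For context, the fix targets DAGs like the one sketched below, where `application_file` names a YAML file on the DAG's `template_searchpath`. Airflow's Jinja engine renders such a file into its full text before the operator runs, so `self.application_file` may hold a rendered YAML body rather than a path; previously the operator unconditionally called `open()` on it and failed. This is a minimal sketch; the DAG id, paths, and connection id are illustrative, not taken from this commit.

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
from airflow.utils import timezone

# "spark-pi.yaml" sits on the template_searchpath, so Airflow treats it as a
# Jinja template and hands the operator the rendered file contents, not a path.
with DAG(
    dag_id="spark_pi_example",
    start_date=timezone.datetime(2024, 2, 1),
    schedule=None,
    template_searchpath="/opt/airflow/dags/templates",
):
    SparkKubernetesOperator(
        task_id="submit_spark_pi",
        application_file="spark-pi.yaml",
        kubernetes_conn_id="kubernetes_default",
    )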
Showing 2 changed files with 116 additions and 1 deletion.
airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py (11 additions, 1 deletion)
@@ -19,6 +19,7 @@

 import re
 from functools import cached_property
+from pathlib import Path
 from typing import TYPE_CHECKING, Any

 from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
@@ -124,7 +125,16 @@ def _render_nested_template_fields(

     def manage_template_specs(self):
         if self.application_file:
-            template_body = _load_body_to_dict(open(self.application_file))
+            try:
+                filepath = Path(self.application_file.rstrip()).resolve(strict=True)
+            except (FileNotFoundError, OSError, RuntimeError, ValueError):
+                application_file_body = self.application_file
+            else:
+                application_file_body = filepath.read_text()
+            template_body = _load_body_to_dict(application_file_body)
+            if not isinstance(template_body, dict):
+                msg = f"application_file body can't be transformed into a dictionary:\n{application_file_body}"
+                raise TypeError(msg)
         elif self.template_spec:
             template_body = self.template_spec
         else:
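
The heart of the change is a resolve-or-fallback check: if the (possibly rendered) string still points at an existing file, read it; otherwise treat the string itself as the YAML body. The following standalone sketch mirrors that logic; `load_application_body` is a hypothetical helper, and `yaml.safe_load` stands in for the provider's internal `_load_body_to_dict`.

from pathlib import Path

import yaml


def load_application_body(application_file: str) -> dict:
    """Load the spec whether `application_file` is a file path or a YAML body."""
    try:
        # strict=True raises FileNotFoundError for missing paths; OSError and
        # ValueError cover strings that are not valid paths at all, e.g. a
        # multi-line rendered template body.
        filepath = Path(application_file.rstrip()).resolve(strict=True)
    except (FileNotFoundError, OSError, RuntimeError, ValueError):
        body = application_file
    else:
        body = filepath.read_text()
    template_body = yaml.safe_load(body)
    if not isinstance(template_body, dict):
        raise TypeError(f"application_file body can't be transformed into a dictionary:\n{body}")
    return template_body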
tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py (105 additions)
@@ -19,6 +19,7 @@

 import copy
 import json
+from datetime import date
 from os.path import join
 from pathlib import Path
 from unittest import mock
@@ -32,6 +33,7 @@

 from airflow import DAG
 from airflow.models import Connection, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
+from airflow.template.templater import LiteralValue
 from airflow.utils import db, timezone
 from airflow.utils.types import DagRunType
@@ -504,3 +506,106 @@ def test_template_body_templating(create_task_instance_of_operator):
     ti.render_templates()
     task: SparkKubernetesOperator = ti.task
     assert task.template_body == {"spark": {"foo": "2024-02-01", "bar": "test_template_body_templating_dag"}}
+
+
+@pytest.mark.db_test
+def test_resolve_application_file_template_file(dag_maker, tmp_path):
+    execution_date = timezone.datetime(2024, 2, 1, tzinfo=timezone.utc)
+    filename = "test-application-file.yml"
+    (tmp_path / filename).write_text("foo: {{ ds }}\nbar: {{ dag_run.dag_id }}\nspam: egg")
+
+    with dag_maker(
+        dag_id="test_resolve_application_file_template_file", template_searchpath=tmp_path.as_posix()
+    ):
+        SparkKubernetesOperator(
+            application_file=filename,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id="test_template_body_templating_task",
+        )
+
+    ti = dag_maker.create_dagrun(execution_date=execution_date).task_instances[0]
+    ti.render_templates()
+    task: SparkKubernetesOperator = ti.task
+    assert task.template_body == {
+        "spark": {
+            "foo": date(2024, 2, 1),
+            "bar": "test_resolve_application_file_template_file",
+            "spam": "egg",
+        }
+    }
+
+
+@pytest.mark.db_test
+@pytest.mark.parametrize(
+    "body",
+    [
+        pytest.param(["a", "b"], id="list"),
+        pytest.param(42, id="int"),
+        pytest.param("{{ ds }}", id="jinja"),
+        pytest.param(None, id="none"),
+    ],
+)
+def test_resolve_application_file_template_non_dictionary(dag_maker, tmp_path, body):
+    execution_date = timezone.datetime(2024, 2, 1, tzinfo=timezone.utc)
+    filename = "test-application-file.yml"
+    with open((tmp_path / filename), "w") as fp:
+        yaml.safe_dump(body, fp)
+
+    with dag_maker(
+        dag_id="test_resolve_application_file_template_nondictionary", template_searchpath=tmp_path.as_posix()
+    ):
+        SparkKubernetesOperator(
+            application_file=filename,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id="test_template_body_templating_task",
+        )
+
+    ti = dag_maker.create_dagrun(execution_date=execution_date).task_instances[0]
+    ti.render_templates()
+    task: SparkKubernetesOperator = ti.task
+    with pytest.raises(TypeError, match="application_file body can't be transformed into a dictionary"):
+        _ = task.template_body
+
+
+@pytest.mark.db_test
+@pytest.mark.parametrize(
+    "use_literal_value", [pytest.param(True, id="literal-value"), pytest.param(False, id="whitespace-compat")]
+)
+def test_resolve_application_file_real_file(create_task_instance_of_operator, tmp_path, use_literal_value):
+    application_file = tmp_path / "test-application-file.yml"
+    application_file.write_text("foo: bar\nspam: egg")
+
+    application_file = application_file.resolve().as_posix()
+    if use_literal_value:
+        application_file = LiteralValue(application_file)
+    else:
+        # Prior to Airflow 2.8, the workaround was to add whitespace at the end of the filepath
+        application_file = f"{application_file} "
+
+    ti = create_task_instance_of_operator(
+        SparkKubernetesOperator,
+        application_file=application_file,
+        kubernetes_conn_id="kubernetes_default_kube_config",
+        dag_id="test_resolve_application_file_real_file",
+        task_id="test_template_body_templating_task",
+    )
+    ti.render_templates()
+    task: SparkKubernetesOperator = ti.task
+
+    assert task.template_body == {"spark": {"foo": "bar", "spam": "egg"}}
+
+
+@pytest.mark.db_test
+def test_resolve_application_file_real_file_not_exists(create_task_instance_of_operator, tmp_path):
+    application_file = (tmp_path / "test-application-file.yml").resolve().as_posix()
+    ti = create_task_instance_of_operator(
+        SparkKubernetesOperator,
+        application_file=LiteralValue(application_file),
+        kubernetes_conn_id="kubernetes_default_kube_config",
+        dag_id="test_resolve_application_file_real_file_not_exists",
+        task_id="test_template_body_templating_task",
+    )
+    ti.render_templates()
+    task: SparkKubernetesOperator = ti.task
+    with pytest.raises(TypeError, match="application_file body can't be transformed into a dictionary"):
+        _ = task.template_body
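
The last two tests exercise `airflow.template.templater.LiteralValue`, which wraps a value so the templater passes it through unrendered: since Airflow 2.8 that is the clean way to hand the operator a real on-disk path whose `.yml` suffix would otherwise make Jinja try to resolve it as a template file. A minimal usage sketch (paths and task ids are illustrative):

from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
from airflow.template.templater import LiteralValue

# LiteralValue keeps the string as-is through template rendering, so the
# operator receives a real filesystem path instead of the file's rendered body.
submit = SparkKubernetesOperator(
    task_id="submit",
    application_file=LiteralValue("/opt/airflow/specs/spark-pi.yaml"),
    kubernetes_conn_id="kubernetes_default",
)

# Before Airflow 2.8 the common workaround was a trailing space, which stops
# Jinja from treating the value as a template file name; the fix's rstrip()
# keeps that pattern working.
legacy = SparkKubernetesOperator(
    task_id="legacy_submit",
    application_file="/opt/airflow/specs/spark-pi.yaml ",
    kubernetes_conn_id="kubernetes_default",
)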
