diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py index fb3e2552c9130..a91f16c4aa537 100644 --- a/airflow/kubernetes/pod_generator.py +++ b/airflow/kubernetes/pod_generator.py @@ -134,6 +134,8 @@ class PodGenerator: :type pod_template_file: Optional[str] :param extract_xcom: Whether to bring up a container for xcom :type extract_xcom: bool + :param priority_class_name: priority class name for the launched Pod + :type priority_class_name: str """ def __init__( # pylint: disable=too-many-arguments,too-many-locals self, @@ -165,6 +167,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals pod: Optional[k8s.V1Pod] = None, pod_template_file: Optional[str] = None, extract_xcom: bool = False, + priority_class_name: Optional[str] = None, ): self.validate_pod_generator_args(locals()) @@ -228,6 +231,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self.spec.volumes = volumes or [] self.spec.node_selector = node_selectors self.spec.restart_policy = restart_policy + self.spec.priority_class_name = priority_class_name self.spec.image_pull_secrets = [] diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 1b64412b14f1b..b54979ab55937 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -136,6 +136,8 @@ class KubernetesPodOperator(BaseOperator): # pylint: disable=too-many-instance- :type do_xcom_push: bool :param pod_template_file: path to pod template file :type pod_template_file: str + :param priority_class_name: priority class name for the launched Pod + :type priority_class_name: str """ template_fields = ('cmds', 'arguments', 'env_vars', 'config_file', 'pod_template_file') @@ -177,6 +179,7 @@ def __init__(self, # pylint: disable=too-many-arguments,too-many-locals log_events_on_failure: bool = False, do_xcom_push: bool = False, pod_template_file: Optional[str] = None, + priority_class_name: Optional[str] = None, *args, **kwargs): if kwargs.get('xcom_push') is not None: @@ -218,6 +221,7 @@ def __init__(self, # pylint: disable=too-many-arguments,too-many-locals self.full_pod_spec = full_pod_spec self.init_containers = init_containers or [] self.log_events_on_failure = log_events_on_failure + self.priority_class_name = priority_class_name self.pod_template_file = pod_template_file self.name = self._set_name(name) @@ -263,8 +267,9 @@ def execute(self, context): schedulername=self.schedulername, init_containers=self.init_containers, restart_policy='Never', + priority_class_name=self.priority_class_name, pod_template_file=self.pod_template_file, - pod=self.full_pod_spec + pod=self.full_pod_spec, ).gen_pod() pod = append_to_pod( diff --git a/docs/howto/operator/kubernetes.rst b/docs/howto/operator/kubernetes.rst index 35f3cae388ad7..c821acaa954eb 100644 --- a/docs/howto/operator/kubernetes.rst +++ b/docs/howto/operator/kubernetes.rst @@ -161,4 +161,5 @@ The :class:`airflow.providers.cncf.kubernetes.operators.kubernetes_pod.Kubernete tolerations=tolerations, configmaps=configmaps, init_containers=[init_container], + priority_class_name="medium", ) diff --git a/tests/runtime/kubernetes/test_kubernetes_pod_operator.py b/tests/runtime/kubernetes/test_kubernetes_pod_operator.py index 21566de99a103..5045f4e53c250 100644 --- a/tests/runtime/kubernetes/test_kubernetes_pod_operator.py +++ b/tests/runtime/kubernetes/test_kubernetes_pod_operator.py @@ -787,6 +787,34 @@ def test_pod_template_file(self, mock_client, launcher_mock): } }, actual_pod) + @mock.patch("airflow.kubernetes.pod_launcher.PodLauncher.run_pod") + @mock.patch("airflow.kubernetes.kube_client.get_kube_client") + def test_pod_priority_class_name(self, mock_client, launcher_mock): + """Test ability to assign priorityClassName to pod + + """ + from airflow.utils.state import State + + priority_class_name = "medium-test" + k = KubernetesPodOperator( + namespace='default', + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + in_cluster=False, + do_xcom_push=False, + priority_class_name=priority_class_name, + ) + + launcher_mock.return_value = (State.SUCCESS, None) + k.execute(None) + actual_pod = self.api_client.sanitize_for_serialization(k.pod) + self.expected_pod['spec']['priorityClassName'] = priority_class_name + self.assertEqual(self.expected_pod, actual_pod) + # pylint: enable=unused-argument if __name__ == '__main__':