diff --git a/airflow/providers/docker/operators/docker_swarm.py b/airflow/providers/docker/operators/docker_swarm.py index b9fc6c89a77f2..a05bfdc897864 100644 --- a/airflow/providers/docker/operators/docker_swarm.py +++ b/airflow/providers/docker/operators/docker_swarm.py @@ -19,6 +19,7 @@ from __future__ import annotations import re +import shlex from datetime import datetime from time import sleep from typing import TYPE_CHECKING @@ -58,6 +59,7 @@ class DockerSwarmOperator(DockerOperator): container's process exits. The default is False. :param command: Command to be run in the container. (templated) + :param args: Arguments to the command. :param docker_url: URL of the host running the docker daemon. Default is the value of the ``DOCKER_HOST`` environment variable or unix://var/run/docker.sock if it is unset. @@ -106,6 +108,7 @@ def __init__( self, *, image: str, + args: str | list[str] | None = None, enable_logging: bool = True, configs: list[types.ConfigReference] | None = None, secrets: list[types.SecretReference] | None = None, @@ -116,6 +119,7 @@ def __init__( **kwargs, ) -> None: super().__init__(image=image, **kwargs) + self.args = args self.enable_logging = enable_logging self.service = None self.configs = configs @@ -136,6 +140,7 @@ def _run_service(self) -> None: container_spec=types.ContainerSpec( image=self.image, command=self.format_command(self.command), + args=self.format_args(self.args), mounts=self.mounts, env=self.environment, user=self.user, @@ -225,6 +230,20 @@ def stream_new_logs(last_line_logged, since=0): sleep(2) last_line_logged, last_timestamp = stream_new_logs(last_line_logged, since=last_timestamp) + @staticmethod + def format_args(args: list[str] | str | None) -> list[str] | None: + """Retrieve args. + + The args string is parsed to a list. + + :param args: args to the docker service + + :return: the args as list + """ + if isinstance(args, str): + return shlex.split(args) + return args + def on_kill(self) -> None: if self.hook.client_created and self.service is not None: self.log.info("Removing docker service: %s", self.service["ID"]) diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py index 5576eec0837c9..29661123d518b 100644 --- a/tests/providers/docker/operators/test_docker_swarm.py +++ b/tests/providers/docker/operators/test_docker_swarm.py @@ -84,6 +84,7 @@ def _client_service_logs_effect(): types_mock.ContainerSpec.assert_called_once_with( image="ubuntu:latest", command="env", + args=None, user="unittest", mounts=[types.Mount(source="/host/path", target="/container/path", type="bind")], tty=True, @@ -254,3 +255,79 @@ def test_container_resources(self, types_mock, docker_api_client_patcher): placement=None, ) types_mock.Resources.assert_not_called() + + @mock.patch("airflow.providers.docker.operators.docker_swarm.types") + def test_service_args_str(self, types_mock, docker_api_client_patcher): + mock_obj = mock.Mock() + + client_mock = mock.Mock(spec=APIClient) + client_mock.create_service.return_value = {"ID": "some_id"} + client_mock.images.return_value = [] + client_mock.pull.return_value = [b'{"status":"pull log"}'] + client_mock.tasks.return_value = [{"Status": {"State": "complete"}}] + types_mock.TaskTemplate.return_value = mock_obj + types_mock.ContainerSpec.return_value = mock_obj + types_mock.RestartPolicy.return_value = mock_obj + types_mock.Resources.return_value = mock_obj + + docker_api_client_patcher.return_value = client_mock + + operator = DockerSwarmOperator( + image="ubuntu:latest", + command="env", + args="--show", + task_id="unittest", + auto_remove="success", + enable_logging=False, + ) + operator.execute(None) + + types_mock.ContainerSpec.assert_called_once_with( + image="ubuntu:latest", + command="env", + args=["--show"], + user=None, + mounts=[], + tty=False, + env={"AIRFLOW_TMP_DIR": "/tmp/airflow"}, + configs=None, + secrets=None, + ) + + @mock.patch("airflow.providers.docker.operators.docker_swarm.types") + def test_service_args_list(self, types_mock, docker_api_client_patcher): + mock_obj = mock.Mock() + + client_mock = mock.Mock(spec=APIClient) + client_mock.create_service.return_value = {"ID": "some_id"} + client_mock.images.return_value = [] + client_mock.pull.return_value = [b'{"status":"pull log"}'] + client_mock.tasks.return_value = [{"Status": {"State": "complete"}}] + types_mock.TaskTemplate.return_value = mock_obj + types_mock.ContainerSpec.return_value = mock_obj + types_mock.RestartPolicy.return_value = mock_obj + types_mock.Resources.return_value = mock_obj + + docker_api_client_patcher.return_value = client_mock + + operator = DockerSwarmOperator( + image="ubuntu:latest", + command="env", + args=["--show"], + task_id="unittest", + auto_remove="success", + enable_logging=False, + ) + operator.execute(None) + + types_mock.ContainerSpec.assert_called_once_with( + image="ubuntu:latest", + command="env", + args=["--show"], + user=None, + mounts=[], + tty=False, + env={"AIRFLOW_TMP_DIR": "/tmp/airflow"}, + configs=None, + secrets=None, + )