diff --git a/airflow/providers/docker/decorators/docker.py b/airflow/providers/docker/decorators/docker.py index 30827bb6fefd6..153b55d4ab1a9 100644 --- a/airflow/providers/docker/decorators/docker.py +++ b/airflow/providers/docker/decorators/docker.py @@ -89,7 +89,7 @@ def __init__( command = "placeholder command" self.python_command = python_command self.expect_airflow = expect_airflow - self.pickling_library = dill if use_dill else pickle + self.use_dill = use_dill super().__init__( command=command, retrieve_output=True, retrieve_output_path="/tmp/script.out", **kwargs ) @@ -143,6 +143,12 @@ def get_python_source(self): res = remove_task_decorator(res, self.custom_operator_name) return res + @property + def pickling_library(self): + if self.use_dill: + return dill + return pickle + def docker_task( python_callable: Callable | None = None, diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index 26d4cb5b6c2d9..c70fd5a37960e 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -190,3 +190,46 @@ def f(): teardown_task = dag.task_group.children["f"] assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun + + @pytest.mark.parametrize("use_dill", [True, False]) + def test_deepcopy_with_python_operator(self, dag_maker, use_dill): + import copy + + from airflow.providers.docker.decorators.docker import _DockerDecoratedOperator + + @task.docker(image="python:3.9-slim", auto_remove="force", use_dill=use_dill) + def f(): + import logging + + logger = logging.getLogger("airflow.task") + logger.info("info log in docker") + + @task.python() + def g(): + import logging + + logger = logging.getLogger("airflow.task") + logger.info("info log in python") + + with dag_maker() as dag: + docker_task = f() + python_task = g() + _ = python_task >> docker_task + + docker_operator = getattr(docker_task, "operator", None) + assert isinstance(docker_operator, _DockerDecoratedOperator) + task_id = docker_operator.task_id + + assert isinstance(dag, DAG) + assert hasattr(dag, "task_dict") + assert isinstance(dag.task_dict, dict) + assert task_id in dag.task_dict + + some_task = dag.task_dict[task_id] + clone_of_docker_operator = copy.deepcopy(docker_operator) + assert isinstance(some_task, _DockerDecoratedOperator) + assert isinstance(clone_of_docker_operator, _DockerDecoratedOperator) + assert some_task.command == clone_of_docker_operator.command + assert some_task.expect_airflow == clone_of_docker_operator.expect_airflow + assert some_task.use_dill == clone_of_docker_operator.use_dill + assert some_task.pickling_library is clone_of_docker_operator.pickling_library