From e75522b0636fb5115d73da43da244d0c3832794f Mon Sep 17 00:00:00 2001 From: Kalyan Date: Wed, 7 Feb 2024 20:46:09 +0530 Subject: [PATCH] fix: PythonVirtualenvOperator crashes if any python_callable function is defined in the same source as DAG (#37165) --------- Signed-off-by: kalyanr --- airflow/models/dagbag.py | 12 ++++++---- airflow/operators/python.py | 23 ++++++++++++------- airflow/utils/file.py | 12 ++++++++++ airflow/utils/python_virtualenv_script.jinja2 | 19 +++++++++++++-- 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ca81af8105410..ce9bf5587be80 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import hashlib import importlib import importlib.machinery import importlib.util @@ -48,7 +47,12 @@ from airflow.utils import timezone from airflow.utils.dag_cycle_tester import check_cycle from airflow.utils.docs import get_docs_url -from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag +from airflow.utils.file import ( + correct_maybe_zipped, + get_unique_dag_module_name, + list_py_file_paths, + might_contain_dag, +) from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.session import NEW_SESSION, provide_session @@ -326,9 +330,7 @@ def _load_modules_from_file(self, filepath, safe_mode): return [] self.log.debug("Importing %s", filepath) - path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest() - org_mod_name = Path(filepath).stem - mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}" + mod_name = get_unique_dag_module_name(filepath) if mod_name in sys.modules: del sys.modules[mod_name] diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 1c5c9d3f6953c..1b1453cc5ed50 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -52,6 +52,7 @@ from airflow.operators.branch import BranchMixIn from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_merge +from airflow.utils.file import get_unique_dag_module_name from airflow.utils.operator_helpers import KeywordParameters from airflow.utils.process_utils import execute_in_subprocess from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script @@ -437,15 +438,21 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): self._write_args(input_path) self._write_string_args(string_args_path) + + jinja_context = { + "op_args": self.op_args, + "op_kwargs": op_kwargs, + "expect_airflow": self.expect_airflow, + "pickling_library": self.pickling_library.__name__, + "python_callable": self.python_callable.__name__, + "python_callable_source": self.get_python_source(), + } + + if inspect.getfile(self.python_callable) == self.dag.fileloc: + jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc) + write_python_script( - jinja_context={ - "op_args": self.op_args, - "op_kwargs": op_kwargs, - "expect_airflow": self.expect_airflow, - "pickling_library": self.pickling_library.__name__, - "python_callable": self.python_callable.__name__, - "python_callable_source": self.get_python_source(), - }, + jinja_context=jinja_context, filename=os.fspath(script_path), render_template_as_native_obj=self.dag.render_template_as_native_obj, ) diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 013d9ea36ab22..7e15eeb2f8d72 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -18,6 +18,7 @@ from __future__ import annotations import ast +import hashlib import logging import os import zipfile @@ -33,6 +34,8 @@ log = logging.getLogger(__name__) +MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}" + class _IgnoreRule(Protocol): """Interface for ignore rules for structural subtyping.""" @@ -379,3 +382,12 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]: for m in _find_imported_modules(parsed): if m.startswith("airflow."): yield m + + +def get_unique_dag_module_name(file_path: str) -> str: + """Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}.""" + if isinstance(file_path, str): + path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest() + org_mod_name = Path(file_path).stem + return MODIFIED_DAG_MODULE_NAME.format(path_hash=path_hash, module_name=org_mod_name) + raise ValueError("file_path should be a string to generate unique module name") diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2 index 7bbf6a953193b..4199a47130fb9 100644 --- a/airflow/utils/python_virtualenv_script.jinja2 +++ b/airflow/utils/python_virtualenv_script.jinja2 @@ -34,6 +34,22 @@ if sys.version_info >= (3,6): pass {% endif %} +# Script +{{ python_callable_source }} + +# monkey patching for the cases when python_callable is part of the dag module. +{% if modified_dag_module_name is defined %} + +import types + +{{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}") + +{{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }} + +sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}} + +{% endif%} + {% if op_args or op_kwargs %} with open(sys.argv[1], "rb") as file: arg_dict = {{ pickling_library }}.load(file) @@ -47,8 +63,7 @@ with open(sys.argv[3], "r") as file: virtualenv_string_args = list(map(lambda x: x.strip(), list(file))) {% endif %} -# Script -{{ python_callable_source }} + try: res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"]) except Exception as e: