From 9c4e333f5b7cc6f950f6791500ecd4bad41ba2f9 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Mon, 25 Mar 2024 11:50:53 +0100 Subject: [PATCH] fix: disabled_for_operators now stops whole event emission (#38033) Signed-off-by: Kacper Muda --- .../providers/dbt/cloud/utils/openlineage.py | 6 +- airflow/providers/openlineage/conf.py | 98 +++++ .../providers/openlineage/extractors/base.py | 30 -- .../providers/openlineage/extractors/bash.py | 8 +- .../openlineage/extractors/manager.py | 30 +- .../openlineage/extractors/python.py | 8 +- .../providers/openlineage/plugins/adapter.py | 28 +- .../providers/openlineage/plugins/listener.py | 23 +- .../providers/openlineage/plugins/macros.py | 5 +- .../openlineage/plugins/openlineage.py | 19 +- airflow/providers/openlineage/utils/utils.py | 22 +- .../openlineage/extractors/test_base.py | 33 +- .../openlineage/extractors/test_bash.py | 103 +---- .../openlineage/extractors/test_python.py | 117 +----- .../openlineage/plugins/test_listener.py | 69 +++- .../openlineage/plugins/test_macros.py | 4 +- .../openlineage/plugins/test_openlineage.py | 65 ++- .../plugins/test_openlineage_adapter.py | 181 +++++++- .../openlineage/plugins/test_utils.py | 30 ++ tests/providers/openlineage/test_conf.py | 389 ++++++++++++++++++ 20 files changed, 913 insertions(+), 355 deletions(-) create mode 100644 airflow/providers/openlineage/conf.py create mode 100644 tests/providers/openlineage/test_conf.py diff --git a/airflow/providers/dbt/cloud/utils/openlineage.py b/airflow/providers/dbt/cloud/utils/openlineage.py index f86c77a6897e7..358382c39bbe1 100644 --- a/airflow/providers/dbt/cloud/utils/openlineage.py +++ b/airflow/providers/dbt/cloud/utils/openlineage.py @@ -47,9 +47,9 @@ def generate_openlineage_events_from_dbt_cloud_run( """ from openlineage.common.provider.dbt import DbtCloudArtifactProcessor, ParentRunMetadata + from airflow.providers.openlineage.conf import namespace from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.plugins.adapter import ( - _DAG_NAMESPACE, _PRODUCER, OpenLineageAdapter, ) @@ -110,7 +110,7 @@ async def get_artifacts_for_steps(steps, artifacts): processor = DbtCloudArtifactProcessor( producer=_PRODUCER, - job_namespace=_DAG_NAMESPACE, + job_namespace=namespace(), skip_errors=False, logger=operator.log, manifest=manifest, @@ -130,7 +130,7 @@ async def get_artifacts_for_steps(steps, artifacts): parent_job = ParentRunMetadata( run_id=parent_run_id, job_name=f"{task_instance.dag_id}.{task_instance.task_id}", - job_namespace=_DAG_NAMESPACE, + job_namespace=namespace(), ) processor.dbt_run_metadata = parent_job diff --git a/airflow/providers/openlineage/conf.py b/airflow/providers/openlineage/conf.py new file mode 100644 index 0000000000000..ba8ce913c7191 --- /dev/null +++ b/airflow/providers/openlineage/conf.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import os +from typing import Any + +from airflow.compat.functools import cache +from airflow.configuration import conf + +_CONFIG_SECTION = "openlineage" + + +@cache +def config_path(check_legacy_env_var: bool = True) -> str: + """[openlineage] config_path.""" + option = conf.get(_CONFIG_SECTION, "config_path", fallback="") + if check_legacy_env_var and not option: + option = os.getenv("OPENLINEAGE_CONFIG", "") + return option + + +@cache +def is_source_enabled() -> bool: + """[openlineage] disable_source_code.""" + option = conf.get(_CONFIG_SECTION, "disable_source_code", fallback="") + if not option: + option = os.getenv("OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE", "") + return option.lower() not in ("true", "1", "t") + + +@cache +def disabled_operators() -> set[str]: + """[openlineage] disabled_for_operators.""" + option = conf.get(_CONFIG_SECTION, "disabled_for_operators", fallback="") + return set(operator.strip() for operator in option.split(";") if operator.strip()) + + +@cache +def custom_extractors() -> set[str]: + """[openlineage] extractors.""" + option = conf.get(_CONFIG_SECTION, "extractors", fallback="") + if not option: + option = os.getenv("OPENLINEAGE_EXTRACTORS", "") + return set(extractor.strip() for extractor in option.split(";") if extractor.strip()) + + +@cache +def namespace() -> str: + """[openlineage] namespace.""" + option = conf.get(_CONFIG_SECTION, "namespace", fallback="") + if not option: + option = os.getenv("OPENLINEAGE_NAMESPACE", "default") + return option + + +@cache +def transport() -> dict[str, Any]: + """[openlineage] transport.""" + option = conf.getjson(_CONFIG_SECTION, "transport", fallback={}) + if not isinstance(option, dict): + raise ValueError(f"OpenLineage transport `{option}` is not a dict") + return option + + +@cache +def is_disabled() -> bool: + """[openlineage] disabled + some extra checks.""" + + def _is_true(val): + return str(val).lower().strip() in ("true", "1", "t") + + option = conf.get(_CONFIG_SECTION, "disabled", fallback="") + if _is_true(option): + return True + + option = os.getenv("OPENLINEAGE_DISABLED", "") + if _is_true(option): + return True + + # Check if both 'transport' and 'config_path' are not present and also + # if legacy 'OPENLINEAGE_URL' environment variables is not set + return transport() == {} and config_path(True) == "" and os.getenv("OPENLINEAGE_URL", "") == "" diff --git a/airflow/providers/openlineage/extractors/base.py b/airflow/providers/openlineage/extractors/base.py index f5cb27027db74..2f5f957b5bc32 100644 --- a/airflow/providers/openlineage/extractors/base.py +++ b/airflow/providers/openlineage/extractors/base.py @@ -18,12 +18,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from functools import cached_property from typing import TYPE_CHECKING from attrs import Factory, define -from airflow.configuration import conf from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState @@ -64,31 +62,10 @@ def get_operator_classnames(cls) -> list[str]: """ raise NotImplementedError() - @cached_property - def disabled_operators(self) -> set[str]: - return set( - operator.strip() - for operator in conf.get("openlineage", "disabled_for_operators", fallback="").split(";") - ) - - @cached_property - def _is_operator_disabled(self) -> bool: - fully_qualified_class_name = ( - self.operator.__class__.__module__ + "." + self.operator.__class__.__name__ - ) - return fully_qualified_class_name in self.disabled_operators - @abstractmethod def _execute_extraction(self) -> OperatorLineage | None: ... def extract(self) -> OperatorLineage | None: - if self._is_operator_disabled: - self.log.debug( - "Skipping extraction for operator %s " - "due to its presence in [openlineage] openlineage_disabled_for_operators.", - self.operator.task_type, - ) - return None return self._execute_extraction() def extract_on_complete(self, task_instance) -> OperatorLineage | None: @@ -125,13 +102,6 @@ def _execute_extraction(self) -> OperatorLineage | None: return None def extract_on_complete(self, task_instance) -> OperatorLineage | None: - if self._is_operator_disabled: - self.log.debug( - "Skipping extraction for operator %s " - "due to its presence in [openlineage] openlineage_disabled_for_operators.", - self.operator.task_type, - ) - return None if task_instance.state == TaskInstanceState.FAILED: on_failed = getattr(self.operator, "get_openlineage_facets_on_failure", None) if on_failed and callable(on_failed): diff --git a/airflow/providers/openlineage/extractors/bash.py b/airflow/providers/openlineage/extractors/bash.py index d5fec894c658f..6d06ab6387350 100644 --- a/airflow/providers/openlineage/extractors/bash.py +++ b/airflow/providers/openlineage/extractors/bash.py @@ -19,15 +19,13 @@ from openlineage.client.facet import SourceCodeJobFacet +from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors.base import BaseExtractor, OperatorLineage from airflow.providers.openlineage.plugins.facets import ( UnknownOperatorAttributeRunFacet, UnknownOperatorInstance, ) -from airflow.providers.openlineage.utils.utils import ( - get_filtered_unknown_operator_keys, - is_source_enabled, -) +from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys """ :meta private: @@ -51,7 +49,7 @@ def get_operator_classnames(cls) -> list[str]: def _execute_extraction(self) -> OperatorLineage | None: job_facets: dict = {} - if is_source_enabled(): + if conf.is_source_enabled(): job_facets = { "sourceCode": SourceCodeJobFacet( language="bash", diff --git a/airflow/providers/openlineage/extractors/manager.py b/airflow/providers/openlineage/extractors/manager.py index 405db3d8e5a4c..d8de9419fbf05 100644 --- a/airflow/providers/openlineage/extractors/manager.py +++ b/airflow/providers/openlineage/extractors/manager.py @@ -16,11 +16,10 @@ # under the License. from __future__ import annotations -import os from contextlib import suppress from typing import TYPE_CHECKING, Iterator -from airflow.configuration import conf +from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors import BaseExtractor, OperatorLineage from airflow.providers.openlineage.extractors.base import DefaultExtractor from airflow.providers.openlineage.extractors.bash import BashExtractor @@ -65,22 +64,17 @@ def __init__(self): for operator_class in extractor.get_operator_classnames(): self.extractors[operator_class] = extractor - # Semicolon-separated extractors in Airflow configuration or OPENLINEAGE_EXTRACTORS variable. - # Extractors should implement BaseExtractor - env_extractors = conf.get("openlineage", "extractors", fallback=os.getenv("OPENLINEAGE_EXTRACTORS")) - # skip either when it's empty string or None - if env_extractors: - for extractor in env_extractors.split(";"): - extractor: type[BaseExtractor] = try_import_from_string(extractor.strip()) - for operator_class in extractor.get_operator_classnames(): - if operator_class in self.extractors: - self.log.debug( - "Duplicate extractor found for `%s`. `%s` will be used instead of `%s`", - operator_class, - extractor, - self.extractors[operator_class], - ) - self.extractors[operator_class] = extractor + for extractor_path in conf.custom_extractors(): + extractor: type[BaseExtractor] = try_import_from_string(extractor_path) + for operator_class in extractor.get_operator_classnames(): + if operator_class in self.extractors: + self.log.debug( + "Duplicate extractor found for `%s`. `%s` will be used instead of `%s`", + operator_class, + extractor_path, + self.extractors[operator_class], + ) + self.extractors[operator_class] = extractor def add_extractor(self, operator_class: str, extractor: type[BaseExtractor]): self.extractors[operator_class] = extractor diff --git a/airflow/providers/openlineage/extractors/python.py b/airflow/providers/openlineage/extractors/python.py index 56e4e19f1f25d..da85453efdc94 100644 --- a/airflow/providers/openlineage/extractors/python.py +++ b/airflow/providers/openlineage/extractors/python.py @@ -22,15 +22,13 @@ from openlineage.client.facet import SourceCodeJobFacet +from airflow.providers.openlineage import conf from airflow.providers.openlineage.extractors.base import BaseExtractor, OperatorLineage from airflow.providers.openlineage.plugins.facets import ( UnknownOperatorAttributeRunFacet, UnknownOperatorInstance, ) -from airflow.providers.openlineage.utils.utils import ( - get_filtered_unknown_operator_keys, - is_source_enabled, -) +from airflow.providers.openlineage.utils.utils import get_filtered_unknown_operator_keys """ :meta private: @@ -55,7 +53,7 @@ def get_operator_classnames(cls) -> list[str]: def _execute_extraction(self) -> OperatorLineage | None: source_code = self.get_source_code(self.operator.python_callable) job_facet: dict = {} - if is_source_enabled() and source_code: + if conf.is_source_enabled() and source_code: job_facet = { "sourceCode": SourceCodeJobFacet( language="python", diff --git a/airflow/providers/openlineage/plugins/adapter.py b/airflow/providers/openlineage/plugins/adapter.py index f6e3c257e6d1c..7ee82b58932f6 100644 --- a/airflow/providers/openlineage/plugins/adapter.py +++ b/airflow/providers/openlineage/plugins/adapter.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import os import uuid from contextlib import ExitStack from typing import TYPE_CHECKING @@ -37,8 +36,7 @@ ) from openlineage.client.run import Job, Run, RunEvent, RunState -from airflow.configuration import conf -from airflow.providers.openlineage import __version__ as OPENLINEAGE_PROVIDER_VERSION +from airflow.providers.openlineage import __version__ as OPENLINEAGE_PROVIDER_VERSION, conf from airflow.providers.openlineage.utils.utils import OpenLineageRedactor from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin @@ -48,12 +46,6 @@ from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.log.secrets_masker import SecretsMasker -_DAG_DEFAULT_NAMESPACE = "default" - -_DAG_NAMESPACE = conf.get( - "openlineage", "namespace", fallback=os.getenv("OPENLINEAGE_NAMESPACE", _DAG_DEFAULT_NAMESPACE) -) - _PRODUCER = f"https://github.com/apache/airflow/tree/providers-openlineage/{OPENLINEAGE_PROVIDER_VERSION}" set_producer(_PRODUCER) @@ -88,18 +80,16 @@ def get_or_create_openlineage_client(self) -> OpenLineageClient: def get_openlineage_config(self) -> dict | None: # First, try to read from YAML file - openlineage_config_path = conf.get("openlineage", "config_path", fallback="") + openlineage_config_path = conf.config_path(check_legacy_env_var=False) if openlineage_config_path: config = self._read_yaml_config(openlineage_config_path) if config: return config.get("transport", None) # Second, try to get transport config - transport = conf.getjson("openlineage", "transport", fallback="") - if not transport: + transport_config = conf.transport() + if not transport_config: return None - elif not isinstance(transport, dict): - raise ValueError(f"{transport} is not a dict") - return transport + return transport_config def _read_yaml_config(self, path: str) -> dict | None: with open(path) as config_file: @@ -107,14 +97,14 @@ def _read_yaml_config(self, path: str) -> dict | None: @staticmethod def build_dag_run_id(dag_id, dag_run_id): - return str(uuid.uuid3(uuid.NAMESPACE_URL, f"{_DAG_NAMESPACE}.{dag_id}.{dag_run_id}")) + return str(uuid.uuid3(uuid.NAMESPACE_URL, f"{conf.namespace()}.{dag_id}.{dag_run_id}")) @staticmethod def build_task_instance_run_id(dag_id, task_id, execution_date, try_number): return str( uuid.uuid3( uuid.NAMESPACE_URL, - f"{_DAG_NAMESPACE}.{dag_id}.{task_id}.{execution_date}.{try_number}", + f"{conf.namespace()}.{dag_id}.{task_id}.{execution_date}.{try_number}", ) ) @@ -353,7 +343,7 @@ def _build_run( if parent_run_id: parent_run_facet = ParentRunFacet.create( runId=parent_run_id, - namespace=_DAG_NAMESPACE, + namespace=conf.namespace(), name=parent_job_name or job_name, ) facets.update( @@ -396,4 +386,4 @@ def _build_job( facets.update({"jobType": job_type}) - return Job(_DAG_NAMESPACE, job_name, facets) + return Job(conf.namespace(), job_name, facets) diff --git a/airflow/providers/openlineage/plugins/listener.py b/airflow/providers/openlineage/plugins/listener.py index b7f767af30ecb..0d6b487f22a98 100644 --- a/airflow/providers/openlineage/plugins/listener.py +++ b/airflow/providers/openlineage/plugins/listener.py @@ -30,6 +30,7 @@ get_airflow_run_facet, get_custom_facets, get_job_name, + is_operator_disabled, print_warning, ) from airflow.stats import Stats @@ -74,6 +75,13 @@ def on_task_instance_running( if TYPE_CHECKING: assert task dag = task.dag + if is_operator_disabled(task): + self.log.debug( + "Skipping OpenLineage event emission for operator %s " + "due to its presence in [openlineage] disabled_for_operators.", + task.task_type, + ) + return None @print_warning(self.log) def on_running(): @@ -134,6 +142,13 @@ def on_task_instance_success(self, previous_state, task_instance: TaskInstance, if TYPE_CHECKING: assert task dag = task.dag + if is_operator_disabled(task): + self.log.debug( + "Skipping OpenLineage event emission for operator %s " + "due to its presence in [openlineage] disabled_for_operators.", + task.task_type, + ) + return None @print_warning(self.log) def on_success(): @@ -179,6 +194,13 @@ def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, s if TYPE_CHECKING: assert task dag = task.dag + if is_operator_disabled(task): + self.log.debug( + "Skipping OpenLineage event emission for operator %s " + "due to its presence in [openlineage] disabled_for_operators.", + task.task_type, + ) + return None @print_warning(self.log) def on_failure(): @@ -228,7 +250,6 @@ def on_starting(self, component): @hookimpl def before_stopping(self, component): self.log.debug("before_stopping: %s", component.__class__.__name__) - # TODO: configure this with Airflow config with timeout(30): self.executor.shutdown(wait=True) diff --git a/airflow/providers/openlineage/plugins/macros.py b/airflow/providers/openlineage/plugins/macros.py index e3cd8cade13fa..391b29495fe37 100644 --- a/airflow/providers/openlineage/plugins/macros.py +++ b/airflow/providers/openlineage/plugins/macros.py @@ -18,7 +18,8 @@ from typing import TYPE_CHECKING -from airflow.providers.openlineage.plugins.adapter import _DAG_NAMESPACE, OpenLineageAdapter +from airflow.providers.openlineage import conf +from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter from airflow.providers.openlineage.utils.utils import get_job_name if TYPE_CHECKING: @@ -60,4 +61,4 @@ def lineage_parent_id(task_instance: TaskInstance): """ job_name = get_job_name(task_instance.task) run_id = lineage_run_id(task_instance) - return f"{_DAG_NAMESPACE}/{job_name}/{run_id}" + return f"{conf.namespace()}/{job_name}/{run_id}" diff --git a/airflow/providers/openlineage/plugins/openlineage.py b/airflow/providers/openlineage/plugins/openlineage.py index 0511618ac8c7f..a0be47a499168 100644 --- a/airflow/providers/openlineage/plugins/openlineage.py +++ b/airflow/providers/openlineage/plugins/openlineage.py @@ -16,27 +16,12 @@ # under the License. from __future__ import annotations -import os - -from airflow.configuration import conf from airflow.plugins_manager import AirflowPlugin +from airflow.providers.openlineage import conf from airflow.providers.openlineage.plugins.listener import get_openlineage_listener from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id -def _is_disabled() -> bool: - return ( - conf.getboolean("openlineage", "disabled", fallback=False) - or os.getenv("OPENLINEAGE_DISABLED", "false").lower() == "true" - or ( - conf.get("openlineage", "transport", fallback="") == "" - and conf.get("openlineage", "config_path", fallback="") == "" - and os.getenv("OPENLINEAGE_URL", "") == "" - and os.getenv("OPENLINEAGE_CONFIG", "") == "" - ) - ) - - class OpenLineageProviderPlugin(AirflowPlugin): """ Listener that emits numerous Events. @@ -46,6 +31,6 @@ class OpenLineageProviderPlugin(AirflowPlugin): """ name = "OpenLineageProviderPlugin" - if not _is_disabled(): + if not conf.is_disabled(): macros = [lineage_run_id, lineage_parent_id] listeners = [get_openlineage_listener()] diff --git a/airflow/providers/openlineage/utils/utils.py b/airflow/providers/openlineage/utils/utils.py index 4f8cfbff71afb..ef18933a75991 100644 --- a/airflow/providers/openlineage/utils/utils.py +++ b/airflow/providers/openlineage/utils/utils.py @@ -20,7 +20,6 @@ import datetime import json import logging -import os from contextlib import suppress from functools import wraps from typing import TYPE_CHECKING, Any, Iterable @@ -31,8 +30,7 @@ # TODO: move this maybe to Airflow's logic? from openlineage.client.utils import RedactMixin -from airflow.compat.functools import cache -from airflow.configuration import conf +from airflow.providers.openlineage import conf from airflow.providers.openlineage.plugins.facets import ( AirflowMappedTaskRunFacet, AirflowRunFacet, @@ -41,7 +39,7 @@ from airflow.utils.log.secrets_masker import Redactable, Redacted, SecretsMasker, should_hide_value_for_key if TYPE_CHECKING: - from airflow.models import DAG, BaseOperator, DagRun, TaskInstance + from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, TaskInstance log = logging.getLogger(__name__) @@ -67,6 +65,14 @@ def get_custom_facets(task_instance: TaskInstance | None = None) -> dict[str, An return custom_facets +def get_fully_qualified_class_name(operator: BaseOperator | MappedOperator) -> str: + return operator.__class__.__module__ + "." + operator.__class__.__name__ + + +def is_operator_disabled(operator: BaseOperator | MappedOperator) -> bool: + return get_fully_qualified_class_name(operator) in conf.disabled_operators() + + class InfoJsonEncodable(dict): """ Airflow objects might not be json-encodable overall. @@ -329,14 +335,6 @@ def wrapper(*args, **kwargs): return decorator -@cache -def is_source_enabled() -> bool: - source_var = conf.get( - "openlineage", "disable_source_code", fallback=os.getenv("OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE") - ) - return isinstance(source_var, str) and source_var.lower() not in ("true", "1", "t") - - def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict: not_required_keys = {"dag", "task_group"} return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys} diff --git a/tests/providers/openlineage/extractors/test_base.py b/tests/providers/openlineage/extractors/test_base.py index 35d51ee2937af..d812106051674 100644 --- a/tests/providers/openlineage/extractors/test_base.py +++ b/tests/providers/openlineage/extractors/test_base.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import os from typing import Any from unittest import mock @@ -34,7 +33,6 @@ ) from airflow.providers.openlineage.extractors.manager import ExtractorManager from airflow.providers.openlineage.extractors.python import PythonExtractor -from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -230,20 +228,9 @@ def test_extraction_without_on_start(): ) -@mock.patch.dict( - os.environ, - {"OPENLINEAGE_EXTRACTORS": "tests.providers.openlineage.extractors.test_base.ExampleExtractor"}, -) -def test_extractors_env_var(): - extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="example")) - assert extractor is ExampleExtractor - - -@mock.patch.dict(os.environ, {"OPENLINEAGE_EXTRACTORS": "no.such.extractor"}) -@conf_vars( - {("openlineage", "extractors"): "tests.providers.openlineage.extractors.test_base.ExampleExtractor"} -) -def test_config_has_precedence_over_env_var(): +@mock.patch("airflow.providers.openlineage.conf.custom_extractors") +def test_extractors_env_var(custom_extractors): + custom_extractors.return_value = {"tests.providers.openlineage.extractors.test_base.ExampleExtractor"} extractor = ExtractorManager().get_extractor_class(ExampleOperator(task_id="example")) assert extractor is ExampleExtractor @@ -285,17 +272,3 @@ def test_default_extractor_uses_wrong_operatorlineage_class(): assert ( ExtractorManager().extract_metadata(mock.MagicMock(), operator, complete=False) == OperatorLineage() ) - - -@mock.patch.dict( - os.environ, - { - "AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "tests.providers.openlineage.extractors.test_base.ExampleOperator" - }, -) -def test_default_extraction_disabled_operator(): - extractor = DefaultExtractor(ExampleOperator(task_id="test")) - metadata = extractor.extract() - assert metadata is None - metadata = extractor.extract_on_complete(None) - assert metadata is None diff --git a/tests/providers/openlineage/extractors/test_bash.py b/tests/providers/openlineage/extractors/test_bash.py index 01140c23492c4..c96468ac86b9f 100644 --- a/tests/providers/openlineage/extractors/test_bash.py +++ b/tests/providers/openlineage/extractors/test_bash.py @@ -17,7 +17,6 @@ from __future__ import annotations -import os from datetime import datetime from unittest.mock import patch @@ -27,8 +26,7 @@ from airflow import DAG from airflow.operators.bash import BashOperator from airflow.providers.openlineage.extractors.bash import BashExtractor -from airflow.providers.openlineage.utils.utils import is_source_enabled -from tests.test_utils.config import conf_vars +from airflow.providers.openlineage.plugins.facets import UnknownOperatorAttributeRunFacet pytestmark = pytest.mark.db_test @@ -43,90 +41,23 @@ bash_task = BashOperator(task_id="bash-task", bash_command="ls -halt && exit 0", dag=dag) -@pytest.fixture(autouse=True) -def clear_cache(): - is_source_enabled.cache_clear() - try: - yield - finally: - is_source_enabled.cache_clear() - - -def test_extract_operator_bash_command_disables_without_env(): - operator = BashOperator(task_id="taskid", bash_command="exit 0") - extractor = BashExtractor(operator) - assert "sourceCode" not in extractor.extract().job_facets - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"}) -def test_extract_operator_bash_command_enables_on_true_env(): +@patch("airflow.providers.openlineage.conf.is_source_enabled") +def test_extract_operator_bash_command_disabled(mocked_source_enabled): + mocked_source_enabled.return_value = False operator = BashOperator(task_id="taskid", bash_command="exit 0") - extractor = BashExtractor(operator) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "exit 0") + result = BashExtractor(operator).extract() + assert "sourceCode" not in result.job_facets + assert "unknownSourceAttribute" in result.run_facets -@conf_vars({("openlineage", "disable_source_code"): "False"}) -def test_extract_operator_bash_command_enables_on_true_conf(): +@patch("airflow.providers.openlineage.conf.is_source_enabled") +def test_extract_operator_bash_command_enabled(mocked_source_enabled): + mocked_source_enabled.return_value = True operator = BashOperator(task_id="taskid", bash_command="exit 0") - extractor = BashExtractor(operator) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "exit 0") - - -@patch.dict( - os.environ, - {k: v for k, v in os.environ.items() if k != "OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE"}, - clear=True, -) -def test_extract_dag_bash_command_disabled_without_env(): - extractor = BashExtractor(bash_task) - assert "sourceCode" not in extractor.extract().job_facets - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"}) -def test_extract_dag_bash_command_enables_on_true_env(): - extractor = BashExtractor(bash_task) - print(extractor.extract().job_facets) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "ls -halt && exit 0") - - -@conf_vars({("openlineage", "disable_source_code"): "False"}) -def test_extract_dag_bash_command_enables_on_true_conf(): - extractor = BashExtractor(bash_task) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "ls -halt && exit 0") - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "True"}) -def test_extract_dag_bash_command_env_disables_on_true(): - extractor = BashExtractor(bash_task) - assert "sourceCode" not in extractor.extract().job_facets - - -@conf_vars({("openlineage", "disable_source_code"): "true"}) -def test_extract_dag_bash_command_conf_disables_on_true(): - extractor = BashExtractor(bash_task) - assert "sourceCode" not in extractor.extract().job_facets - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "asdftgeragdsfgawef"}) -def test_extract_dag_bash_command_env_does_not_disable_on_random_string(): - extractor = BashExtractor(bash_task) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "ls -halt && exit 0") - - -@conf_vars({("openlineage", "disable_source_code"): "asdftgeragdsfgawef"}) -def test_extract_dag_bash_command_conf_does_not_disable_on_random_string(): - extractor = BashExtractor(bash_task) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("bash", "ls -halt && exit 0") - - -@patch.dict( - os.environ, - {"AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "airflow.operators.bash.BashOperator"}, -) -def test_bash_extraction_disabled_operator(): - operator = BashOperator(task_id="taskid", bash_command="echo 1;") - extractor = BashExtractor(operator) - metadata = extractor.extract() - assert metadata is None - metadata = extractor.extract_on_complete(None) - assert metadata is None + result = BashExtractor(operator).extract() + assert result.job_facets["sourceCode"] == SourceCodeJobFacet("bash", "exit 0") + assert "unknownSourceAttribute" in result.run_facets + unknown_operator_facet = result.run_facets["unknownSourceAttribute"] + assert isinstance(unknown_operator_facet, UnknownOperatorAttributeRunFacet) + assert len(unknown_operator_facet.unknownItems) == 1 + assert unknown_operator_facet.unknownItems[0].name == "BashOperator" diff --git a/tests/providers/openlineage/extractors/test_python.py b/tests/providers/openlineage/extractors/test_python.py index 0e0e23b3eda82..328c9f24d29ca 100644 --- a/tests/providers/openlineage/extractors/test_python.py +++ b/tests/providers/openlineage/extractors/test_python.py @@ -29,8 +29,7 @@ from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator from airflow.providers.openlineage.extractors.python import PythonExtractor -from airflow.providers.openlineage.utils.utils import is_source_enabled -from tests.test_utils.config import conf_vars +from airflow.providers.openlineage.plugins.facets import UnknownOperatorAttributeRunFacet pytestmark = pytest.mark.db_test @@ -58,110 +57,28 @@ def callable(): CODE = "def callable():\n print(10)\n" -@pytest.fixture(autouse=True) -def clear_cache(): - is_source_enabled.cache_clear() - try: - yield - finally: - is_source_enabled.cache_clear() - - def test_extract_source_code(): code = inspect.getsource(callable) assert code == CODE -def test_extract_operator_code_disables_on_no_env(): +@patch("airflow.providers.openlineage.conf.is_source_enabled") +def test_extract_operator_code_disabled(mocked_source_enabled): + mocked_source_enabled.return_value = False operator = PythonOperator(task_id="taskid", python_callable=callable) - extractor = PythonExtractor(operator) - assert "sourceCode" not in extractor.extract().job_facets + result = PythonExtractor(operator).extract() + assert "sourceCode" not in result.job_facets + assert "unknownSourceAttribute" in result.run_facets -@patch.dict( - os.environ, - {"AIRFLOW__OPENLINEAGE__DISABLED_FOR_OPERATORS": "airflow.operators.python.PythonOperator"}, -) -def test_python_extraction_disabled_operator(): +@patch("airflow.providers.openlineage.conf.is_source_enabled") +def test_extract_operator_code_enabled(mocked_source_enabled): + mocked_source_enabled.return_value = True operator = PythonOperator(task_id="taskid", python_callable=callable) - extractor = PythonExtractor(operator) - metadata = extractor.extract() - assert metadata is None - metadata = extractor.extract_on_complete(None) - assert metadata is None - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"}) -def test_extract_operator_code_enables_on_false_env(): - operator = PythonOperator(task_id="taskid", python_callable=callable) - extractor = PythonExtractor(operator) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("python", CODE) - - -@conf_vars({("openlineage", "disable_source_code"): "False"}) -def test_extract_operator_code_enables_on_false_conf(): - operator = PythonOperator(task_id="taskid", python_callable=callable) - extractor = PythonExtractor(operator) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet("python", CODE) - - -def test_extract_dag_code_disables_on_no_env(): - extractor = PythonExtractor(python_task_getcwd) - assert "sourceCode" not in extractor.extract().job_facets - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"}) -def test_extract_dag_code_enables_on_true_env(): - extractor = PythonExtractor(python_task_getcwd) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet( - "python", "" - ) - - -@conf_vars({("openlineage", "disable_source_code"): "False"}) -def test_extract_dag_code_enables_on_true_conf(): - extractor = PythonExtractor(python_task_getcwd) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet( - "python", "" - ) - - -@conf_vars({("openlineage", "disable_source_code"): "False"}) -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "False"}) -def test_extract_dag_code_conf_precedence(): - extractor = PythonExtractor(python_task_getcwd) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet( - "python", "" - ) - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "True"}) -def test_extract_dag_code_env_disables_on_true(): - extractor = PythonExtractor(python_task_getcwd) - metadata = extractor.extract() - assert metadata is not None - assert "sourceCode" not in metadata.job_facets - - -@conf_vars({("openlineage", "disable_source_code"): "True"}) -def test_extract_dag_code_conf_disables_on_true(): - extractor = PythonExtractor(python_task_getcwd) - metadata = extractor.extract() - assert metadata is not None - assert "sourceCode" not in metadata.job_facets - - -@patch.dict(os.environ, {"OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE": "asdftgeragdsfgawef"}) -def test_extract_dag_code_env_does_not_disable_on_random_string(): - extractor = PythonExtractor(python_task_getcwd) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet( - "python", "" - ) - - -@conf_vars({("openlineage", "disable_source_code"): "asdftgeragdsfgawef"}) -def test_extract_dag_code_conf_does_not_disable_on_random_string(): - extractor = PythonExtractor(python_task_getcwd) - assert extractor.extract().job_facets["sourceCode"] == SourceCodeJobFacet( - "python", "" - ) + result = PythonExtractor(operator).extract() + assert result.job_facets["sourceCode"] == SourceCodeJobFacet("python", CODE) + assert "unknownSourceAttribute" in result.run_facets + unknown_operator_facet = result.run_facets["unknownSourceAttribute"] + assert isinstance(unknown_operator_facet, UnknownOperatorAttributeRunFacet) + assert len(unknown_operator_facet.unknownItems) == 1 + assert unknown_operator_facet.unknownItems[0].name == "PythonOperator" diff --git a/tests/providers/openlineage/plugins/test_listener.py b/tests/providers/openlineage/plugins/test_listener.py index 827c17c9f750c..69bfdabe91dce 100644 --- a/tests/providers/openlineage/plugins/test_listener.py +++ b/tests/providers/openlineage/plugins/test_listener.py @@ -194,11 +194,12 @@ def mock_task_id(dag_id, task_id, execution_date, try_number): return listener, task_instance +@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") @mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") @mock.patch("airflow.providers.openlineage.plugins.listener.get_custom_facets") @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") def test_adapter_start_task_is_called_with_proper_arguments( - mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet + mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet, mock_disabled ): """Tests that the 'start_task' method of the OpenLineageAdapter is invoked with the correct arguments. @@ -212,6 +213,7 @@ def test_adapter_start_task_is_called_with_proper_arguments( mock_get_job_name.return_value = "job_name" mock_get_custom_facets.return_value = {"custom_facet": 2} mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} + mock_disabled.return_value = False listener.on_task_instance_running(None, task_instance, None) listener.adapter.start_task.assert_called_once_with( @@ -233,9 +235,10 @@ def test_adapter_start_task_is_called_with_proper_arguments( ) +@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") -def test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mocked_adapter): +def test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mocked_adapter, mock_disabled): """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked with the correct arguments. This test ensures that the job name is accurately retrieved and included, along with the generated @@ -251,6 +254,7 @@ def mock_task_id(dag_id, task_id, execution_date, try_number): mock_get_job_name.return_value = "job_name" mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id mocked_adapter.build_dag_run_id.side_effect = lambda x, y: f"{x}.{y}" + mock_disabled.return_value = False listener.on_task_instance_failed(None, task_instance, None) listener.adapter.fail_task.assert_called_once_with( @@ -263,9 +267,12 @@ def mock_task_id(dag_id, task_id, execution_date, try_number): ) +@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") @mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") -def test_adapter_complete_task_is_called_with_proper_arguments(mock_get_job_name, mocked_adapter): +def test_adapter_complete_task_is_called_with_proper_arguments( + mock_get_job_name, mocked_adapter, mock_disabled +): """Tests that the 'complete_task' method of the OpenLineageAdapter is called with the correct arguments. It checks that the job name is correctly retrieved and passed, @@ -282,6 +289,7 @@ def mock_task_id(dag_id, task_id, execution_date, try_number): mock_get_job_name.return_value = "job_name" mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id mocked_adapter.build_dag_run_id.side_effect = lambda x, y: f"{x}.{y}" + mock_disabled.return_value = False listener.on_task_instance_success(None, task_instance, None) # This run_id will be different as we did NOT simulate increase of the try_number attribute, @@ -455,3 +463,58 @@ def success_callable(**kwargs): # try_number after task has been executed assert task_instance.try_number == 2 + + +@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") +@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet") +@mock.patch("airflow.providers.openlineage.plugins.listener.get_custom_facets") +@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") +def test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator( + mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet, mock_disabled +): + listener, task_instance = _create_listener_and_task_instance() + mock_get_job_name.return_value = "job_name" + mock_get_custom_facets.return_value = {"custom_facet": 2} + mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3} + mock_disabled.return_value = True + + listener.on_task_instance_running(None, task_instance, None) + mock_disabled.assert_called_once_with(task_instance.task) + listener.adapter.build_dag_run_id.assert_not_called() + listener.adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.start_task.assert_not_called() + + +@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") +@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") +def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator( + mock_get_job_name, mocked_adapter, mock_disabled +): + listener, task_instance = _create_listener_and_task_instance() + mock_disabled.return_value = True + + listener.on_task_instance_failed(None, task_instance, None) + mock_disabled.assert_called_once_with(task_instance.task) + mocked_adapter.build_dag_run_id.assert_not_called() + mocked_adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.fail_task.assert_not_called() + + +@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled") +@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter") +@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name") +def test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator( + mock_get_job_name, mocked_adapter, mock_disabled +): + listener, task_instance = _create_listener_and_task_instance() + mock_disabled.return_value = True + + listener.on_task_instance_success(None, task_instance, None) + mock_disabled.assert_called_once_with(task_instance.task) + mocked_adapter.build_dag_run_id.assert_not_called() + mocked_adapter.build_task_instance_run_id.assert_not_called() + listener.extractor_manager.extract_metadata.assert_not_called() + listener.adapter.complete_task.assert_not_called() diff --git a/tests/providers/openlineage/plugins/test_macros.py b/tests/providers/openlineage/plugins/test_macros.py index 415cea36e4f63..9e2160aa196f9 100644 --- a/tests/providers/openlineage/plugins/test_macros.py +++ b/tests/providers/openlineage/plugins/test_macros.py @@ -19,9 +19,11 @@ import uuid from unittest import mock -from airflow.providers.openlineage.plugins.adapter import _DAG_NAMESPACE +from airflow.providers.openlineage.conf import namespace from airflow.providers.openlineage.plugins.macros import lineage_parent_id, lineage_run_id +_DAG_NAMESPACE = namespace() + def test_lineage_run_id(): task = mock.MagicMock( diff --git a/tests/providers/openlineage/plugins/test_openlineage.py b/tests/providers/openlineage/plugins/test_openlineage.py index fa41bc6aa7635..409c8461e8e02 100644 --- a/tests/providers/openlineage/plugins/test_openlineage.py +++ b/tests/providers/openlineage/plugins/test_openlineage.py @@ -23,14 +23,21 @@ import pytest +from airflow.providers.openlineage.conf import config_path, is_disabled, transport from tests.test_utils.config import conf_vars class TestOpenLineageProviderPlugin: def setup_method(self): + is_disabled.cache_clear() + transport.cache_clear() + config_path.cache_clear() self.old_modules = dict(sys.modules) def teardown_method(self): + is_disabled.cache_clear() + transport.cache_clear() + config_path.cache_clear() # Remove any new modules imported during the test run. This lets us # import the same source files for more than one test. for mod in [m for m in sys.modules if m not in self.old_modules]: @@ -39,37 +46,81 @@ def teardown_method(self): @pytest.mark.parametrize( "mocks, expected", [ + # 0: not disabled but no configuration found + ([], 0), + # 0: env_var disabled = true ([patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "true"})], 0), + # 0: env_var disabled = false but no configuration found + ([patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "false"})], 0), + # 0: conf disabled = true + ([conf_vars({("openlineage", "disabled"): "True"})], 0), + # 0: conf disabled = false but no configuration found + ([conf_vars({("openlineage", "disabled"): "False"})], 0), + # 0: env_var disabled = true and conf disabled = false + ( + [ + conf_vars({("openlineage", "disabled"): "F"}), + patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "True"}), + ], + 0, + ), + # 0: env_var disabled = false and conf disabled = true + ( + [ + conf_vars({("openlineage", "disabled"): "T"}), + patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "False"}), + ], + 0, + ), + # 0: env_var disabled = true and some config present ( [ conf_vars( {("openlineage", "transport"): '{"type": "http", "url": "http://localhost:5000"}'} ), - patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "true"}), + patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "1"}), ], 0, ), - ([patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "false"})], 0), + # 0: conf disabled = true and some config present ( [ conf_vars( { - ("openlineage", "disabled"): "False", + ("openlineage", "transport"): '{"type": "http", "url": "http://localhost:5000"}', + ("openlineage", "disabled"): "true", + } + ) + ], + 0, + ), + # 1: conf disabled = false and some config present + ( + [ + conf_vars( + { + ("openlineage", "disabled"): "0", ("openlineage", "transport"): '{"type": "http", "url": "http://localhost:5000"}', } ) ], 1, ), + # 1: env_var disabled = false and some config present ( [ - conf_vars({("openlineage", "disabled"): "False"}), - patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "true"}), + conf_vars( + { + ("openlineage", "transport"): '{"type": "http", "url": "http://localhost:5000"}', + } + ), + patch.dict(os.environ, {"OPENLINEAGE_DISABLED": "false"}), ], - 0, + 1, ), - ([], 0), + # 1: not explicitly disabled and url present ([patch.dict(os.environ, {"OPENLINEAGE_URL": "http://localhost:8080"})], 1), + # 1: not explicitly disabled and transport present ( [ conf_vars( diff --git a/tests/providers/openlineage/plugins/test_openlineage_adapter.py b/tests/providers/openlineage/plugins/test_openlineage_adapter.py index ec1dfc6eb5741..8169cedd5f25b 100644 --- a/tests/providers/openlineage/plugins/test_openlineage_adapter.py +++ b/tests/providers/openlineage/plugins/test_openlineage_adapter.py @@ -38,13 +38,43 @@ ) from openlineage.client.run import Dataset, Job, Run, RunEvent, RunState +from airflow.providers.openlineage.conf import ( + config_path, + custom_extractors, + disabled_operators, + is_disabled, + is_source_enabled, + namespace, + transport, +) from airflow.providers.openlineage.extractors import OperatorLineage -from airflow.providers.openlineage.plugins.adapter import _DAG_NAMESPACE, _PRODUCER, OpenLineageAdapter +from airflow.providers.openlineage.plugins.adapter import _PRODUCER, OpenLineageAdapter from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test +@pytest.fixture(autouse=True) +def clear_cache(): + config_path.cache_clear() + is_source_enabled.cache_clear() + disabled_operators.cache_clear() + custom_extractors.cache_clear() + namespace.cache_clear() + transport.cache_clear() + is_disabled.cache_clear() + try: + yield + finally: + config_path.cache_clear() + is_source_enabled.cache_clear() + disabled_operators.cache_clear() + custom_extractors.cache_clear() + namespace.cache_clear() + transport.cache_clear() + is_disabled.cache_clear() + + @patch.dict( os.environ, {"OPENLINEAGE_URL": "http://ol-api:5000", "OPENLINEAGE_API_KEY": "api-key"}, @@ -115,6 +145,9 @@ def test_create_client_overrides_env_vars(): assert client.transport.kind == "http" assert client.transport.url == "http://localhost:5050" + transport.cache_clear() + config_path.cache_clear() + with conf_vars({("openlineage", "transport"): '{"type": "console"}'}): client = OpenLineageAdapter().get_or_create_openlineage_client() @@ -162,7 +195,7 @@ def test_emit_start_event(mock_stats_incr, mock_stats_timer): }, ), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="job", facets={ "documentation": DocumentationJobFacet(description="description"), @@ -228,18 +261,18 @@ def test_emit_start_event_with_additional_information(mock_stats_incr, mock_stat ), "parent": ParentRunFacet( run={"runId": "parent_run_id"}, - job={"namespace": _DAG_NAMESPACE, "name": "parent_job_name"}, + job={"namespace": namespace(), "name": "parent_job_name"}, ), "parentRun": ParentRunFacet( run={"runId": "parent_run_id"}, - job={"namespace": _DAG_NAMESPACE, "name": "parent_job_name"}, + job={"namespace": namespace(), "name": "parent_job_name"}, ), "externalQuery1": ExternalQueryRunFacet(externalQueryId="123", source="source"), "externalQuery2": ExternalQueryRunFacet(externalQueryId="999", source="source"), }, ), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="job", facets={ "documentation": DocumentationJobFacet(description="description"), @@ -294,7 +327,7 @@ def test_emit_complete_event(mock_stats_incr, mock_stats_timer): eventTime=event_time, run=Run(runId=run_id, facets={}), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="job", facets={ "jobType": JobTypeJobFacet( @@ -346,11 +379,11 @@ def test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s facets={ "parent": ParentRunFacet( run={"runId": "parent_run_id"}, - job={"namespace": _DAG_NAMESPACE, "name": "parent_job_name"}, + job={"namespace": namespace(), "name": "parent_job_name"}, ), "parentRun": ParentRunFacet( run={"runId": "parent_run_id"}, - job={"namespace": _DAG_NAMESPACE, "name": "parent_job_name"}, + job={"namespace": namespace(), "name": "parent_job_name"}, ), "externalQuery": ExternalQueryRunFacet(externalQueryId="123", source="source"), }, @@ -404,7 +437,7 @@ def test_emit_failed_event(mock_stats_incr, mock_stats_timer): eventTime=event_time, run=Run(runId=run_id, facets={}), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="job", facets={ "jobType": JobTypeJobFacet( @@ -456,11 +489,11 @@ def test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta facets={ "parent": ParentRunFacet( run={"runId": "parent_run_id"}, - job={"namespace": _DAG_NAMESPACE, "name": "parent_job_name"}, + job={"namespace": namespace(), "name": "parent_job_name"}, ), "parentRun": ParentRunFacet( run={"runId": "parent_run_id"}, - job={"namespace": _DAG_NAMESPACE, "name": "parent_job_name"}, + job={"namespace": namespace(), "name": "parent_job_name"}, ), "externalQuery": ExternalQueryRunFacet(externalQueryId="123", source="source"), }, @@ -529,7 +562,7 @@ def test_emit_dag_started_event(mock_stats_incr, mock_stats_timer, uuid): }, ), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="dag_id", facets={ "jobType": JobTypeJobFacet( @@ -579,7 +612,7 @@ def test_emit_dag_complete_event(mock_stats_incr, mock_stats_timer, uuid): eventTime=event_time.isoformat(), run=Run(runId=random_uuid, facets={}), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="dag_id", facets={ "jobType": JobTypeJobFacet( @@ -636,7 +669,7 @@ def test_emit_dag_failed_event(mock_stats_incr, mock_stats_timer, uuid): }, ), job=Job( - namespace=_DAG_NAMESPACE, + namespace=namespace(), name="dag_id", facets={ "jobType": JobTypeJobFacet( @@ -688,7 +721,7 @@ def test_build_dag_run_id_different_inputs_give_different_results(): def test_build_dag_run_id_uses_correct_methods_underneath(): dag_id = "test_dag" dag_run_id = "run_1" - expected = str(uuid.uuid3(uuid.NAMESPACE_URL, f"{_DAG_NAMESPACE}.{dag_id}.{dag_run_id}")) + expected = str(uuid.uuid3(uuid.NAMESPACE_URL, f"{namespace()}.{dag_id}.{dag_run_id}")) actual = OpenLineageAdapter.build_dag_run_id(dag_id, dag_run_id) assert actual == expected @@ -710,7 +743,123 @@ def test_build_task_instance_run_id_uses_correct_methods_underneath(): execution_date = "2023-01-01" try_number = 1 expected = str( - uuid.uuid3(uuid.NAMESPACE_URL, f"{_DAG_NAMESPACE}.{dag_id}.{task_id}.{execution_date}.{try_number}") + uuid.uuid3(uuid.NAMESPACE_URL, f"{namespace()}.{dag_id}.{task_id}.{execution_date}.{try_number}") ) actual = OpenLineageAdapter.build_task_instance_run_id(dag_id, task_id, execution_date, try_number) assert actual == expected + + +def test_configuration_precedence_when_creating_ol_client(): + _section_name = "openlineage" + current_folder = pathlib.Path(__file__).parent.resolve() + yaml_config = str((current_folder / "openlineage_configs" / "http.yaml").resolve()) + + # First, check config_path in Airflow configuration (airflow.cfg or env variable) + with patch.dict( + os.environ, + { + "OPENLINEAGE_URL": "http://wrong.com", + "OPENLINEAGE_API_KEY": "wrong_api_key", + "OPENLINEAGE_CONFIG": "some/config.yml", + }, + clear=True, + ): + with conf_vars( + { + (_section_name, "transport"): '{"type": "kafka", "topic": "test", "config": {"acks": "all"}}', + (_section_name, "config_path"): yaml_config, + }, + ): + client = OpenLineageAdapter().get_or_create_openlineage_client() + assert client.transport.kind == "http" + assert client.transport.config.url == "http://localhost:5050" + assert client.transport.config.endpoint == "api/v1/lineage" + assert client.transport.config.auth.api_key == "random_token" + + config_path.cache_clear() + transport.cache_clear() + + # Second, check transport in Airflow configuration (airflow.cfg or env variable) + with patch.dict( + os.environ, + { + "OPENLINEAGE_URL": "http://wrong.com", + "OPENLINEAGE_API_KEY": "wrong_api_key", + "OPENLINEAGE_CONFIG": "some/config.yml", + }, + clear=True, + ): + with conf_vars( + { + (_section_name, "transport"): '{"type": "kafka", "topic": "test", "config": {"acks": "all"}}', + (_section_name, "config_path"): "", + }, + ): + client = OpenLineageAdapter().get_or_create_openlineage_client() + assert client.transport.kind == "kafka" + assert client.transport.kafka_config.topic == "test" + assert client.transport.kafka_config.config == {"acks": "all"} + + config_path.cache_clear() + transport.cache_clear() + + # Third, check legacy OPENLINEAGE_CONFIG env variable + with patch.dict( + os.environ, + { + "OPENLINEAGE_URL": "http://wrong.com", + "OPENLINEAGE_API_KEY": "wrong_api_key", + "OPENLINEAGE_CONFIG": yaml_config, + }, + clear=True, + ): + with conf_vars( + { + (_section_name, "transport"): "", + (_section_name, "config_path"): "", + }, + ): + client = OpenLineageAdapter().get_or_create_openlineage_client() + assert client.transport.kind == "http" + assert client.transport.config.url == "http://localhost:5050" + assert client.transport.config.endpoint == "api/v1/lineage" + assert client.transport.config.auth.api_key == "random_token" + + config_path.cache_clear() + transport.cache_clear() + + # Fourth, check legacy OPENLINEAGE_URL env variable + with patch.dict( + os.environ, + { + "OPENLINEAGE_URL": "http://test.com", + "OPENLINEAGE_API_KEY": "test_api_key", + "OPENLINEAGE_CONFIG": "", + }, + clear=True, + ): + with conf_vars( + { + (_section_name, "transport"): "", + (_section_name, "config_path"): "", + }, + ): + client = OpenLineageAdapter().get_or_create_openlineage_client() + assert client.transport.kind == "http" + assert client.transport.config.url == "http://test.com" + assert client.transport.config.endpoint == "api/v1/lineage" + assert client.transport.config.auth.api_key == "test_api_key" + + config_path.cache_clear() + transport.cache_clear() + + # If all else fails, use console transport + with patch.dict(os.environ, {}, clear=True): + with conf_vars( + { + (_section_name, "transport"): "", + (_section_name, "config_path"): "", + }, + ): + client = OpenLineageAdapter().get_or_create_openlineage_client() + assert client.transport.kind == "console" diff --git a/tests/providers/openlineage/plugins/test_utils.py b/tests/providers/openlineage/plugins/test_utils.py index 9984007083f0d..4a9b681733a90 100644 --- a/tests/providers/openlineage/plugins/test_utils.py +++ b/tests/providers/openlineage/plugins/test_utils.py @@ -21,6 +21,7 @@ import uuid from json import JSONEncoder from typing import Any +from unittest.mock import patch import pytest from attrs import define @@ -28,10 +29,13 @@ from pkg_resources import parse_version from airflow.models import DAG as AIRFLOW_DAG, DagModel +from airflow.operators.bash import BashOperator from airflow.providers.openlineage.utils.utils import ( InfoJsonEncodable, OpenLineageRedactor, _is_name_redactable, + get_fully_qualified_class_name, + is_operator_disabled, ) from airflow.utils import timezone from airflow.utils.log.secrets_masker import _secrets_masker @@ -170,3 +174,29 @@ class NestedMixined(RedactMixin): assert redactor.redact({"password": "passwd"}) == {"password": "***"} redacted_nested = redactor.redact(NestedMixined("passwd", NestedMixined("passwd", None))) assert redacted_nested == NestedMixined("***", NestedMixined("passwd", None)) + + +def test_get_fully_qualified_class_name(): + from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter + + result = get_fully_qualified_class_name(BashOperator(task_id="test", bash_command="exit 0;")) + assert result == "airflow.operators.bash.BashOperator" + + result = get_fully_qualified_class_name(OpenLineageAdapter()) + assert result == "airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter" + + +@patch("airflow.providers.openlineage.conf.disabled_operators") +def test_is_operator_disabled(mock_disabled_operators): + mock_disabled_operators.return_value = {} + op = BashOperator(task_id="test", bash_command="exit 0;") + assert is_operator_disabled(op) is False + + mock_disabled_operators.return_value = {"random_string"} + assert is_operator_disabled(op) is False + + mock_disabled_operators.return_value = { + "airflow.operators.bash.BashOperator", + "airflow.operators.python.PythonOperator", + } + assert is_operator_disabled(op) is True diff --git a/tests/providers/openlineage/test_conf.py b/tests/providers/openlineage/test_conf.py new file mode 100644 index 0000000000000..95a7cde3e5b0a --- /dev/null +++ b/tests/providers/openlineage/test_conf.py @@ -0,0 +1,389 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from unittest import mock + +import pytest + +from airflow.providers.openlineage.conf import ( + config_path, + custom_extractors, + disabled_operators, + is_disabled, + is_source_enabled, + namespace, + transport, +) +from tests.test_utils.config import conf_vars, env_vars + +_CONFIG_SECTION = "openlineage" +_VAR_CONFIG_PATH = "OPENLINEAGE_CONFIG" +_CONFIG_OPTION_CONFIG_PATH = "config_path" +_VAR_DISABLE_SOURCE_CODE = "OPENLINEAGE_AIRFLOW_DISABLE_SOURCE_CODE" +_CONFIG_OPTION_DISABLE_SOURCE_CODE = "disable_source_code" +_CONFIG_OPTION_DISABLED_FOR_OPERATORS = "disabled_for_operators" +_VAR_EXTRACTORS = "OPENLINEAGE_EXTRACTORS" +_CONFIG_OPTION_EXTRACTORS = "extractors" +_VAR_NAMESPACE = "OPENLINEAGE_NAMESPACE" +_CONFIG_OPTION_NAMESPACE = "namespace" +_CONFIG_OPTION_TRANSPORT = "transport" +_VAR_DISABLED = "OPENLINEAGE_DISABLED" +_CONFIG_OPTION_DISABLED = "disabled" +_VAR_URL = "OPENLINEAGE_URL" + + +@pytest.fixture(autouse=True) +def clear_cache(): + config_path.cache_clear() + is_source_enabled.cache_clear() + disabled_operators.cache_clear() + custom_extractors.cache_clear() + namespace.cache_clear() + transport.cache_clear() + is_disabled.cache_clear() + try: + yield + finally: + config_path.cache_clear() + is_source_enabled.cache_clear() + disabled_operators.cache_clear() + custom_extractors.cache_clear() + namespace.cache_clear() + transport.cache_clear() + is_disabled.cache_clear() + + +@env_vars({_VAR_CONFIG_PATH: "env_var_path"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): None}) +def test_config_path_legacy_env_var_is_used_when_no_conf_option_set(): + assert config_path() == "env_var_path" + + +@env_vars({_VAR_CONFIG_PATH: "env_var_path"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "config_path"}) +def test_config_path_conf_option_has_precedence_over_legacy_env_var(): + assert config_path() == "config_path" + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): ""}) +def test_config_path_empty_conf_option(): + assert config_path() == "" + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): None}) +def test_config_path_do_not_fail_if_conf_option_missing(): + assert config_path() == "" + + +@env_vars({_VAR_DISABLE_SOURCE_CODE: "true"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): None}) +def test_disable_source_code_legacy_env_var_is_used_when_no_conf_option_set(): + assert is_source_enabled() is False + + +@env_vars({_VAR_DISABLE_SOURCE_CODE: "false"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): "true"}) +def test_disable_source_code_conf_option_has_precedence_over_legacy_env_var(): + assert is_source_enabled() is False + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): "asdadawlaksnd"}) +def test_disable_source_code_conf_option_not_working_for_random_string(): + assert is_source_enabled() is True + + +@env_vars({_VAR_DISABLE_SOURCE_CODE: "asdadawlaksnd"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): None}) +def test_disable_source_code_legacy_env_var_not_working_for_random_string(): + assert is_source_enabled() is True + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): ""}) +def test_disable_source_code_empty_conf_option(): + assert is_source_enabled() is True + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLE_SOURCE_CODE): None}) +def test_disable_source_code_do_not_fail_if_conf_option_missing(): + assert is_source_enabled() is True + + +@pytest.mark.parametrize( + ("var_string", "expected"), + ( + (" ", {}), + (" ; ", {}), + (";", {}), + ("path.to.Operator ;", {"path.to.Operator"}), + (" ; path.to.Operator ;", {"path.to.Operator"}), + ("path.to.Operator", {"path.to.Operator"}), + ("path.to.Operator ; path.to.second.Operator ; ", {"path.to.Operator", "path.to.second.Operator"}), + ("path.to.Operator;path.to.second.Operator", {"path.to.Operator", "path.to.second.Operator"}), + ), +) +def test_disabled_for_operators(var_string, expected): + with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED_FOR_OPERATORS): var_string}): + result = disabled_operators() + assert isinstance(result, set) + assert sorted(result) == sorted(expected) + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED_FOR_OPERATORS): ""}) +def test_disabled_for_operators_empty_conf_option(): + assert disabled_operators() == set() + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED_FOR_OPERATORS): None}) +def test_disabled_for_operators_do_not_fail_if_conf_option_missing(): + assert disabled_operators() == set() + + +@env_vars({_VAR_EXTRACTORS: "path.Extractor"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXTRACTORS): None}) +def test_extractors_legacy_legacy_env_var_is_used_when_no_conf_option_set(): + assert os.getenv(_VAR_EXTRACTORS) == "path.Extractor" + assert custom_extractors() == {"path.Extractor"} + + +@env_vars({_VAR_EXTRACTORS: "env.Extractor"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXTRACTORS): "conf.Extractor"}) +def test_extractors_conf_option_has_precedence_over_legacy_env_var(): + assert os.getenv(_VAR_EXTRACTORS) == "env.Extractor" + assert custom_extractors() == {"conf.Extractor"} + + +@pytest.mark.parametrize( + ("var_string", "expected"), + ( + ("path.to.Extractor ;", {"path.to.Extractor"}), + (" ; path.to.Extractor ;", {"path.to.Extractor"}), + ("path.to.Extractor", {"path.to.Extractor"}), + ( + "path.to.Extractor ; path.to.second.Extractor ; ", + {"path.to.Extractor", "path.to.second.Extractor"}, + ), + ("path.to.Extractor;path.to.second.Extractor", {"path.to.Extractor", "path.to.second.Extractor"}), + ), +) +def test_extractors(var_string, expected): + with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXTRACTORS): var_string}): + result = custom_extractors() + assert isinstance(result, set) + assert sorted(result) == sorted(expected) + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXTRACTORS): ""}) +def test_extractors_empty_conf_option(): + assert custom_extractors() == set() + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_EXTRACTORS): None}) +def test_extractors_do_not_fail_if_conf_option_missing(): + assert custom_extractors() == set() + + +@env_vars({_VAR_NAMESPACE: "my_custom_namespace"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_NAMESPACE): None}) +def test_namespace_legacy_env_var_is_used_when_no_conf_option_set(): + assert os.getenv(_VAR_NAMESPACE) == "my_custom_namespace" + assert namespace() == "my_custom_namespace" + + +@env_vars({_VAR_NAMESPACE: "env_namespace"}) +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_NAMESPACE): "my_custom_namespace"}) +def test_namespace_conf_option_has_precedence_over_legacy_env_var(): + assert os.getenv(_VAR_NAMESPACE) == "env_namespace" + assert namespace() == "my_custom_namespace" + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_NAMESPACE): ""}) +def test_namespace_empty_conf_option(): + assert namespace() == "default" + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_NAMESPACE): None}) +def test_namespace_do_not_fail_if_conf_option_missing(): + assert namespace() == "default" + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): '{"valid": "json"}'}) +def test_transport_valid(): + assert transport() == {"valid": "json"} + + +@pytest.mark.parametrize("transport_value", ('["a", "b"]', "[]", '[{"a": "b"}]')) +def test_transport_not_valid(transport_value): + with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): transport_value}): + with pytest.raises(ValueError): + transport() + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): ""}) +def test_transport_empty_conf_option(): + assert transport() == {} + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): None}) +def test_transport_do_not_fail_if_conf_option_missing(): + assert transport() == {} + + +@pytest.mark.parametrize("disabled", ("1", "t", "T", "true", "TRUE", "True")) +@mock.patch.dict(os.environ, {_VAR_URL: ""}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + } +) +def test_is_disabled_possible_values_for_disabling(disabled): + with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): disabled}): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: "https://test.com"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "asdadawlaksnd", + } +) +def test_is_disabled_is_not_disabled_by_random_string(): + assert is_disabled() is False + + +@mock.patch.dict(os.environ, {_VAR_URL: "https://test.com"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "", + } +) +def test_is_disabled_is_false_when_not_explicitly_disabled_and_url_set(): + assert is_disabled() is False + + +@mock.patch.dict(os.environ, {_VAR_URL: ""}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): '{"valid": "transport"}', + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "", + } +) +def test_is_disabled_is_false_when_not_explicitly_disabled_and_transport_set(): + assert is_disabled() is False + + +@mock.patch.dict(os.environ, {_VAR_URL: ""}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "", + } +) +def test_is_disabled_is_false_when_not_explicitly_disabled_and_config_path_set(): + assert is_disabled() is False + + +@mock.patch.dict(os.environ, {_VAR_URL: "https://test.com"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): '{"valid": "transport"}', + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "true", + } +) +def test_is_disabled_conf_option_is_enough_to_disable(): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: "https://test.com", _VAR_DISABLED: "true"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): '{"valid": "transport"}', + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "", + } +) +def test_is_disabled_legacy_env_var_is_enough_to_disable(): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: "", _VAR_DISABLED: "true"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): None, + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): None, + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): None, + } +) +def test_is_disabled_legacy_env_var_is_used_when_no_config(): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: "", _VAR_DISABLED: "false"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "true", + } +) +def test_is_disabled_conf_true_has_precedence_over_env_var_false(): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: "", _VAR_DISABLED: "true"}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "some/path.yml", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "false", + } +) +def test_is_disabled_env_var_true_has_precedence_over_conf_false(): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: ""}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): "", + } +) +def test_is_disabled_empty_conf_option(): + assert is_disabled() is True + + +@mock.patch.dict(os.environ, {_VAR_URL: ""}, clear=True) +@conf_vars( + { + (_CONFIG_SECTION, _CONFIG_OPTION_CONFIG_PATH): "", + (_CONFIG_SECTION, _CONFIG_OPTION_TRANSPORT): "", + (_CONFIG_SECTION, _CONFIG_OPTION_DISABLED): None, + } +) +def test_is_disabled_do_not_fail_if_conf_option_missing(): + assert is_disabled() is True