diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 451068733667c..4a9639a998c46 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -221,7 +221,12 @@ def _is_authorized( user = self.get_user() if not user: return False - role_str = user.get_role().upper() + + user_role = user.get_role() + if not user_role: + return False + + role_str = user_role.upper() role = SimpleAuthManagerRole[role_str] if role == SimpleAuthManagerRole.ADMIN: return True diff --git a/airflow/auth/managers/simple/user.py b/airflow/auth/managers/simple/user.py index fa032f596ee44..f4591b0b1c751 100644 --- a/airflow/auth/managers/simple/user.py +++ b/airflow/auth/managers/simple/user.py @@ -24,10 +24,10 @@ class SimpleAuthManagerUser(BaseUser): User model for users managed by the simple auth manager. :param username: The username - :param role: The role associated to the user + :param role: The role associated to the user. If not provided, the user has no permission """ - def __init__(self, *, username: str, role: str) -> None: + def __init__(self, *, username: str, role: str | None) -> None: self.username = username self.role = role @@ -37,5 +37,5 @@ def get_id(self) -> str: def get_name(self) -> str: return self.username - def get_role(self): + def get_role(self) -> str | None: return self.role diff --git a/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py new file mode 100644 index 0000000000000..321a1e2bbafa8 --- /dev/null +++ b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Update dag_run_note.user_id and task_instance_note.user_id columns to String. + +Revision ID: 44eabb1904b4 +Revises: 16cbcb1c8c36 +Create Date: 2024-09-27 09:57:29.830521 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "44eabb1904b4" +down_revision = "16cbcb1c8c36" +branch_labels = None +depends_on = None +airflow_version = "3.0.0" + + +def upgrade(): + with op.batch_alter_table("dag_run_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.String(length=128)) + with op.batch_alter_table("task_instance_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.String(length=128)) + + +def downgrade(): + with op.batch_alter_table("dag_run_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer") + with op.batch_alter_table("task_instance_note") as batch_op: + batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer") diff --git a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py similarity index 98% rename from airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py rename to airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py index 5c8aec69e9be9..6016dd9658908 100644 --- a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py +++ b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py @@ -30,7 +30,7 @@ also rename the one on DatasetAliasModel here for consistency. Revision ID: 0d9e73a75ee4 -Revises: 16cbcb1c8c36 +Revises: 44eabb1904b4 Create Date: 2024-08-13 09:45:32.213222 """ @@ -42,7 +42,7 @@ # revision identifiers, used by Alembic. revision = "0d9e73a75ee4" -down_revision = "16cbcb1c8c36" +down_revision = "44eabb1904b4" branch_labels = None depends_on = None airflow_version = "3.0.0" diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 5d53e51763dff..4928c7fcbd8f7 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1687,7 +1687,7 @@ class DagRunNote(Base): __tablename__ = "dag_run_note" - user_id = Column(Integer, nullable=True) + user_id = Column(String(128), nullable=True) dag_run_id = Column(Integer, primary_key=True, nullable=False) content = Column(String(1000).with_variant(Text(1000), "mysql")) created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b19e65486307d..333a4cad91cbe 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -4002,7 +4002,7 @@ class TaskInstanceNote(TaskInstanceDependencies): __tablename__ = "task_instance_note" - user_id = Column(Integer, nullable=True) + user_id = Column(String(128), nullable=True) task_id = Column(StringID(), primary_key=True, nullable=False) dag_id = Column(StringID(), primary_key=True, nullable=False) run_id = Column(StringID(), primary_key=True, nullable=False) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py index 3a0328fe9962c..7b50338733453 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py @@ -62,9 +62,7 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - if auth_current_user() is not None or current_app.appbuilder.get_app.config.get( - "AUTH_ROLE_PUBLIC", None - ): + if auth_current_user() is not None or current_app.config.get("AUTH_ROLE_PUBLIC", None): return function(*args, **kwargs) else: return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py index d8d5a95ee676b..f2038b27597c1 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py @@ -124,7 +124,7 @@ def requires_authentication(function: T, find_user: Callable[[str], BaseUser] | @wraps(function) def decorated(*args, **kwargs): - if current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None): + if current_app.config.get("AUTH_ROLE_PUBLIC", None): response = function(*args, **kwargs) return make_response(response) diff --git a/airflow/providers/fab/auth_manager/models/anonymous_user.py b/airflow/providers/fab/auth_manager/models/anonymous_user.py index 2f294fd9e5d0e..9afb2cdff635f 100644 --- a/airflow/providers/fab/auth_manager/models/anonymous_user.py +++ b/airflow/providers/fab/auth_manager/models/anonymous_user.py @@ -35,7 +35,7 @@ class AnonymousUser(AnonymousUserMixin, BaseUser): @property def roles(self): if not self._roles: - public_role = current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) + public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None) self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set() return self._roles diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256 index e4a952da1b9fd..bca068fde6749 100644 --- a/docs/apache-airflow/img/airflow_erd.sha256 +++ b/docs/apache-airflow/img/airflow_erd.sha256 @@ -1 +1 @@ -c33e9a583a5b29eb748ebd50e117643e11bcb2a9b61ec017efd690621e22769b \ No newline at end of file +64dfad12dfd49f033c4723c2f3bb3bac58dd956136fb24a87a2e5a6ae176ec1a \ No newline at end of file diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg index 76fbd8f841f25..4eb6c2ee70917 100644 --- a/docs/apache-airflow/img/airflow_erd.svg +++ b/docs/apache-airflow/img/airflow_erd.svg @@ -1394,7 +1394,7 @@ user_id - [INTEGER] + [VARCHAR(100)] @@ -1813,7 +1813,7 @@ user_id - [INTEGER] + [VARCHAR(100)] diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst index a547d03d75be6..e4fb2dfa332eb 100644 --- a/docs/apache-airflow/migrations-ref.rst +++ b/docs/apache-airflow/migrations-ref.rst @@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``0d9e73a75ee4`` (head) | ``16cbcb1c8c36`` | ``3.0.0`` | Add name and group fields to DatasetModel. | +| ``0d9e73a75ee4`` (head) | ``44eabb1904b4`` | ``3.0.0`` | Add name and group fields to DatasetModel. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``44eabb1904b4`` | ``16cbcb1c8c36`` | ``3.0.0`` | Update dag_run_note.user_id and task_instance_note.user_id | +| | | | columns to String. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``16cbcb1c8c36`` | ``522625f6d606`` | ``3.0.0`` | Remove redundant index. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index 38e7b58cb5981..6a23b2cf11d93 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -36,9 +36,16 @@ def minimal_app_for_api(): ] ) def factory(): - with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): + with conf_vars( + { + ("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend", + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore - _app.config["AUTH_ROLE_PUBLIC"] = None return _app return factory() diff --git a/tests/api_connexion/endpoints/test_backfill_endpoint.py b/tests/api_connexion/endpoints/test_backfill_endpoint.py index 51a4faf40055c..07b2a3fd56c2d 100644 --- a/tests/api_connexion/endpoints/test_backfill_endpoint.py +++ b/tests/api_connexion/endpoints/test_backfill_endpoint.py @@ -29,7 +29,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import create_user, delete_user @@ -50,25 +49,11 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) with DAG( DAG_ID, @@ -93,9 +78,8 @@ def configured_app(minimal_app_for_api): yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBackfillEndpoint: @@ -178,7 +162,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -240,7 +223,6 @@ def test_no_exist(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -268,7 +250,6 @@ class TestCreateBackfill(TestBackfillEndpoint): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -347,7 +328,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), @@ -409,7 +389,6 @@ def test_should_respond_200(self, session): @pytest.mark.parametrize( "user, expected", [ - ("test_granular_permissions", 200), ("test_no_permissions", 403), ("test", 200), (None, 401), diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index 475753a4a902e..bd88c491c952b 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -21,7 +21,6 @@ import pytest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -54,18 +53,17 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) with conf_vars({("webserver", "expose_config"): "True"}): yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetConfig: diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index a19b046aa2747..a140046656e31 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -24,7 +24,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Connection from airflow.secrets.environment_variables import CONN_ENV_PREFIX -from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -38,22 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestConnectionEndpoint: diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 9905b4e27ab2c..6d4ffc2d06d2c 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -28,7 +28,6 @@ from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -56,33 +55,11 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) with DAG( DAG_ID, @@ -107,9 +84,8 @@ def configured_app(minimal_app_for_api): yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagEndpoint: @@ -258,13 +234,6 @@ def test_should_respond_200_with_schedule_none(self, session): "pickle_id": None, } == response.json - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(1) - response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - def test_should_respond_404(self): response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"}) assert response.status_code == 404 @@ -282,13 +251,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_should_respond_403_with_granular_access_for_different_dag(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 403 - @pytest.mark.parametrize( "fields", [ @@ -961,15 +923,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): assert expected_dag_ids == dag_ids - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - @pytest.mark.parametrize( "url, expected_dag_ids", [ @@ -1252,18 +1205,6 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload ) - def test_should_respond_200_on_patch_with_granular_dag_access(self, session): - self._create_dag_models(1) - response = self.client.patch( - "/api/v1/dags/TEST_DAG_1", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) - def test_should_respond_400_on_invalid_request(self): patch_body = { "is_paused": True, @@ -1279,24 +1220,6 @@ def test_should_respond_400_on_invalid_request(self): "type": EXCEPTIONS_LINK_MAP[400], } - def test_validation_error_raises_400(self): - patch_body = { - "ispaused": True, - } - dag_model = self._create_dag_model() - response = self.client.patch( - f"/api/v1/dags/{dag_model.dag_id}", - json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 400 - assert response.json == { - "detail": "{'ispaused': ['Unknown field.']}", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } - def test_non_existing_dag_raises_not_found(self): patch_body = { "is_paused": True, @@ -1820,19 +1743,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids): assert expected_dag_ids == dag_ids - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.patch( - "api/v1/dags?dag_id_pattern=~", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - @pytest.mark.parametrize( "url, expected_dag_ids", [ diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py index 521d8d9e8cd99..ae42a565dd052 100644 --- a/tests/api_connexion/endpoints/test_dag_parsing.py +++ b/tests/api_connexion/endpoints/test_dag_parsing.py @@ -24,7 +24,6 @@ from airflow.models import DagBag from airflow.models.dagbag import DagPriorityParsingRequest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user from tests.test_utils.db import clear_db_dag_parsing_requests @@ -45,21 +44,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], # type: ignore + role_name="admin", ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_EDIT]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagParsingRequest: diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index f3921da7b9c29..73c75b98a43b1 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -30,12 +30,11 @@ from airflow.models.dagrun import DagRun from airflow.models.param import Param from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -52,79 +51,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_no_dag_run_create_permission", - role_name="TestNoDagRunCreatePermission", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_dag_view_only", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, # type: ignore - username="test_view_dags", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], + role_name="admin", ) - create_user( - app, # type: ignore - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID", - access_control={ - "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, - "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, - }, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_view_only") # type: ignore - delete_user(app, username="test_view_dags") # type: ignore - delete_user(app, username="test_granular_permissions") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_no_dag_run_create_permission") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagRunEndpoint: @@ -499,16 +435,6 @@ def test_should_return_all_with_tilde_as_dag_id_and_all_dag_permissions(self): dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] assert dag_run_ids == expected_dag_run_ids - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] - response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] - assert dag_run_ids == expected_dag_run_ids - def test_should_raises_401_unauthenticated(self): self._create_test_dag_run() @@ -907,57 +833,6 @@ def test_order_by_raises_for_invalid_attr(self): msg = "Ordering with 'dag_ru' is disallowed or the attribute does not exist on the model" assert response.json["detail"] == msg - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_response_json_1 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_1", - "end_date": None, - "state": "running", - "execution_date": self.default_time, - "logical_date": self.default_time, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - expected_response_json_2 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_2", - "end_date": None, - "state": "running", - "execution_date": self.default_time_2, - "logical_date": self.default_time_2, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": None, - "data_interval_start": None, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - - response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert response.json == { - "dag_runs": [ - expected_response_json_1, - expected_response_json_2, - ], - "total_entries": 2, - } - @pytest.mark.parametrize( "payload, error", [ @@ -1328,15 +1203,6 @@ def test_raises_validation_error_for_invalid_params(self): assert response.status_code == 400 assert "Invalid input for param" in response.json["detail"] - def test_dagrun_trigger_with_dag_level_permissions(self): - self._create_dag("TEST_DAG_ID") - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={"conf": {"validated_number": 1}}, - environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, - ) - assert response.status_code == 200 - @mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app") def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): self._create_dag("TEST_DAG_ID") @@ -1627,11 +1493,7 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) - @pytest.mark.parametrize( - "username", - ["test_dag_view_only", "test_view_dags", "test_granular_permissions", "test_no_permissions"], - ) - def test_should_raises_403_unauthorized(self, username): + def test_should_raises_403_unauthorized(self): self._create_dag("TEST_DAG_ID") response = self.client.post( "api/v1/dags/TEST_DAG_ID/dagRuns", @@ -1639,7 +1501,7 @@ def test_should_raises_403_unauthorized(self, username): "dag_run_id": "TEST_DAG_RUN_ID_1", "execution_date": self.default_time, }, - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index a8d1224e034c3..f4df56ba629ae 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -23,7 +23,6 @@ import pytest from airflow.models import DagBag -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags @@ -44,29 +43,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore + role_name="admin", ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - EXAMPLE_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - TEST_MULTIPLE_DAGS_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetSource: @@ -123,18 +109,6 @@ def test_should_respond_200_json(self, url_safe_serializer): assert dag_docstring in response.json["content"] assert "application/json" == response.headers["Content-Type"] - def test_should_respond_406(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - test_dag: DAG = dagbag.dags[TEST_DAG_ID] - - url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" - response = self.client.get( - url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} - ) - - assert 406 == response.status_code - def test_should_respond_404(self): wrong_fileloc = "abcd1234" url = f"/api/v1/dagSources/{wrong_fileloc}" @@ -167,38 +141,3 @@ def test_should_raise_403_forbidden(self, url_safe_serializer): environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 - - def test_should_respond_403_not_readable(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] - - response = self.client.get( - f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - read_dag = self.client.get( - f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 403 - - def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) - dagbag.sync_to_db() - dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] - - response = self.client.get( - f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - - read_dag = self.client.get( - f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 200 diff --git a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py index 36fc54d3a5b17..9ab5b49765931 100644 --- a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py @@ -22,7 +22,6 @@ from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import DagRunState @@ -38,21 +37,17 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDagStatsEndpoint: diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 3e7c805173b39..f156d8921c0e6 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -22,7 +22,6 @@ from airflow.models.dag import DagModel from airflow.models.dagwarning import DagWarning -from airflow.security import permissions from airflow.utils.session import create_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags @@ -34,30 +33,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - ], # type: ignore - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type:ignore - username="test_with_dag2_read", - role_name="TestWithDag2Read", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), - ], # type: ignore + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield minimal_app_for_api - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_with_dag2_read") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseDagWarning: @@ -162,11 +147,3 @@ def test_should_raise_403_forbidden(self): "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 - - def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): - response = self.client.get( - "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, - ) - assert response.status_code == 403 diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index 5caec0ac2a131..76c164654c9d8 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -33,7 +33,6 @@ TaskOutletAssetReference, ) from airflow.models.dagrun import DagRun -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType @@ -50,31 +49,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ASSET), - ], - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type: ignore - username="test_queued_event", - role_name="TestQueuedEvent", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), - ], + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test_queued_event") # type: ignore - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestDatasetEndpoint: @@ -768,43 +752,6 @@ def _create_dataset_dag_run_queues(self, dag_id, dataset_id, session): class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - - def test_should_respond_404(self): - dag_id = "not_exists" - dataset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" @@ -826,47 +773,6 @@ def test_should_raise_403_forbidden(self, session): class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_uri = "s3://bucket/key" - dataset_id = self._create_dataset(session).id - - adrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) - session.add(adrq) - session.commit() - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 1 - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log( - session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None - ) - - def test_should_respond_404(self): - dag_id = "not_exists" - dataset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self, session): dag_id = "dummy" dataset_uri = "dummy" @@ -884,46 +790,6 @@ def test_should_raise_403_forbidden(self, session): class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dag_id = "dummy" @@ -943,22 +809,6 @@ def test_should_raise_403_forbidden(self): class TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint): - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/datasets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dag_id = "dummy" @@ -978,47 +828,6 @@ def test_should_raise_403_forbidden(self): class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dataset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" @@ -1038,39 +847,6 @@ def test_should_raise_403_forbidden(self): class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - dataset_id = self._create_dataset(session).id - self._create_dataset_dag_run_queues(dag_id, dataset_id, session) - dataset_uri = "s3://bucket/key" - - response = self.client.delete( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) - - def test_should_respond_404(self): - dataset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/datasets/queuedEvent/{dataset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } == response.json - def test_should_raises_401_unauthenticated(self): dataset_uri = "not_exists" diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 0fdef1a3af2b6..e5ca3d301765a 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -20,7 +20,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Log -from airflow.security import permissions from airflow.utils import timezone from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -33,32 +32,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore + role_name="admin", ) - create_user( - app, # type:ignore - username="test_granular", - role_name="TestGranular", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID_1", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "TEST_DAG_ID_2", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_granular") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") @pytest.fixture @@ -274,33 +257,6 @@ def test_should_raises_401_unauthenticated(self, log_model): assert_401(response) - def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): - eventlog1 = create_log_model( - event="TEST_EVENT_1", - dag_id="TEST_DAG_ID_1", - task_id="TEST_TASK_ID_1", - owner="TEST_OWNER_1", - when=self.default_time, - ) - eventlog2 = create_log_model( - event="TEST_EVENT_2", - dag_id="TEST_DAG_ID_2", - task_id="TEST_TASK_ID_2", - owner="TEST_OWNER_2", - when=self.default_time_2, - ) - session.add_all([eventlog1, eventlog2]) - session.commit() - for attr in ["dag_id", "task_id", "owner", "event"]: - attr_value = f"TEST_{attr}_1".upper() - response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value - def test_should_filter_eventlogs_by_when(self, create_log_model, session): eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time) eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2) @@ -339,32 +295,6 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session): assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]}) - def test_should_filter_eventlogs_by_included_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 2 - assert response_data["total_entries"] == 2 - assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} - - def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 1 - assert response_data["total_entries"] == 1 - assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} - class TestGetEventLogPagination(TestEventLogEndpoint): @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py index 1e9226ede9847..2c3eacdc91dc0 100644 --- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py +++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py @@ -26,7 +26,6 @@ from airflow.models.dagbag import DagBag from airflow.models.xcom import XCom from airflow.plugins_manager import AirflowPlugin -from airflow.security import permissions from airflow.timetables.base import DataInterval from airflow.utils import timezone from airflow.utils.state import DagRunState @@ -48,21 +47,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestGetExtraLinks: @@ -78,8 +72,8 @@ def setup_attrs(self, configured_app, session) -> None: self.dag = self._create_dag() self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.app.dag_bag.dags = {self.dag.dag_id: self.dag} + self.app.dag_bag.sync_to_db() triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} self.dag.create_dagrun( diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py index 635e159bb292c..af2b83ebb1eed 100644 --- a/tests/api_connexion/endpoints/test_import_error_endpoint.py +++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py @@ -21,15 +21,12 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.models.dag import DagModel -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.compat import ParseImportError from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_dags, clear_db_import_errors -from tests.test_utils.permissions import _resource_name pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -40,42 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), - ], # type: ignore - ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore - create_user( - app, # type:ignore - username="test_single_dag", - role_name="TestSingleDAG", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore - ) - # For some reason, DAG level permissions are not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestSingleDAG", - "perms": [ - ( - permissions.ACTION_CAN_READ, - _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), - ) - ], - } - ] + role_name="admin", ) + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_single_dag") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseImportError: @@ -152,72 +123,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_should_raise_403_forbidden_without_dag_read(self, session): - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 403 - - def test_should_return_200_with_single_dag_read(self, session): - dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - } == response_data - - def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert { - "filename": "Lorem_ipsum.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - } == response_data - class TestGetImportErrorsEndpoint(TestBaseImportError): def test_get_import_errors(self, session): @@ -328,71 +233,6 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) - def test_get_import_errors_single_dag(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = f"/tmp/{dag_id}.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - importerror = ParseImportError( - filename=fake_filename, - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert { - "import_errors": [ - { - "filename": "/tmp/test_dag.py", - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } == response_data - - def test_get_import_errors_single_dag_in_dagfile(self, session): - for dag_id in TEST_DAG_IDS: - fake_filename = "/tmp/all_in_one.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) - session.add(dag_model) - - importerror = ParseImportError( - filename="/tmp/all_in_one.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert { - "import_errors": [ - { - "filename": "/tmp/all_in_one.py", - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } == response_data - class TestGetImportErrorsEndpointPagination(TestBaseImportError): @pytest.mark.parametrize( diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py index 420d2dd65f89c..2b112e3221843 100644 --- a/tests/api_connexion/endpoints/test_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_log_endpoint.py @@ -30,7 +30,6 @@ from airflow.decorators import task from airflow.models.dag import DAG from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.types import DagRunType from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -46,13 +45,9 @@ def configured_app(minimal_app_for_api): create_user( app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") + create_user(app, username="test_no_permissions", role_name=None) yield app diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 72cdccdee68df..fc53b8952f4aa 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -28,12 +28,11 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.dagbag import DagBag from airflow.models.taskmap import TaskMap -from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.timezone import datetime -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.mock_operators import MockOperator @@ -50,24 +49,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestMappedTaskInstanceEndpoint: @@ -133,8 +124,8 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None): session.add(ti) self.app.dag_bag = DagBag(os.devnull, include_examples=False) - self.app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore - self.app.dag_bag.sync_to_db() # type: ignore + self.app.dag_bag.dags = {dag_id: dag_maker.dag} + self.app.dag_bag.sync_to_db() session.flush() mapped.expand_mapped_task(dr.run_id, session=session) diff --git a/tests/api_connexion/endpoints/test_plugin_endpoint.py b/tests/api_connexion/endpoints/test_plugin_endpoint.py index edf925cf0fa73..0cd630375a282 100644 --- a/tests/api_connexion/endpoints/test_plugin_endpoint.py +++ b/tests/api_connexion/endpoints/test_plugin_endpoint.py @@ -24,7 +24,6 @@ from airflow.hooks.base import BaseHook from airflow.plugins_manager import AirflowPlugin -from airflow.security import permissions from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.timetables.base import Timetable from airflow.utils.module_loading import qualname @@ -105,17 +104,16 @@ class MockPlugin(AirflowPlugin): def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestPluginsEndpoint: diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index 87439a5811945..2cc095d077aa9 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -20,7 +20,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models.pool import Pool -from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars @@ -35,22 +34,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBasePoolEndpoints: diff --git a/tests/api_connexion/endpoints/test_provider_endpoint.py b/tests/api_connexion/endpoints/test_provider_endpoint.py index 16e5989cc56db..b4cf8f10a92ae 100644 --- a/tests/api_connexion/endpoints/test_provider_endpoint.py +++ b/tests/api_connexion/endpoints/test_provider_endpoint.py @@ -21,7 +21,6 @@ import pytest from airflow.providers_manager import ProviderInfo -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -54,17 +53,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestBaseProviderEndpoint: diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index d0a4fb903c8b8..b2e068bd507fe 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -27,7 +27,6 @@ from airflow.models.expandinput import EXPAND_INPUT_EMPTY from airflow.models.serialized_dag import SerializedDagModel from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags @@ -38,21 +37,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestTaskEndpoint: diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 25ded6c814b72..b5b3163e988d0 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -25,19 +25,17 @@ from sqlalchemy import select from sqlalchemy.orm import contains_eager -from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.jobs.job import Job from airflow.jobs.triggerer_job_runner import TriggererJobRunner from airflow.models import DagRun, SlaMiss, TaskInstance, Trigger from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.models.taskinstancehistory import TaskInstanceHistory -from airflow.security import permissions from airflow.utils.platform import getuser from airflow.utils.session import provide_session from airflow.utils.state import State from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user +from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields from tests.test_utils.www import _check_last_log @@ -55,69 +53,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_dag_read_only", - role_name="TestDagReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_task_read_only", - role_name="TestTaskReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, # type: ignore - username="test_read_only_one_dag", - role_name="TestReadOnlyOneDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestReadOnlyOneDag", - "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], - } - ] + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_dag_read_only") # type: ignore - delete_user(app, username="test_task_read_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore - delete_user(app, username="test_read_only_one_dag") # type: ignore - delete_roles(app) + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestTaskInstanceEndpoint: @@ -219,9 +164,8 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - @pytest.mark.parametrize("username", ["test", "test_dag_read_only", "test_task_read_only"]) @provide_session - def test_should_respond_200(self, username, session): + def test_should_respond_200(self, session): self.create_task_instances(session) # Update ti and set operator to None to # test that operator field is nullable. @@ -232,7 +176,7 @@ def test_should_respond_200(self, username, session): session.commit() response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert response.json == { @@ -723,36 +667,11 @@ def test_should_respond_200(self, task_instances, update_extras, url, expected_t assert response.json["total_entries"] == expected_ti assert len(response.json["task_instances"]) == expected_ti - @pytest.mark.parametrize( - "task_instances, user, expected_ti", - [ - pytest.param( - { - "example_python_operator": 2, - "example_skip_dag": 1, - }, - "test_read_only_one_dag", - 2, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test_read_only_one_dag", - 1, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test", - 3, - ), - ], - ) - def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + def test_return_TI_only_from_readable_dags(self, session): + task_instances = { + "example_python_operator": 1, + "example_skip_dag": 2, + } for dag_id in task_instances: self.create_task_instances( session, @@ -763,11 +682,11 @@ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ dag_id=dag_id, ) response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": "test"} ) assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti + assert response.json["total_entries"] == 3 + assert len(response.json["task_instances"]) == 3 def test_should_respond_200_for_dag_id_filter(self, session): self.create_task_instances(session) @@ -898,44 +817,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): "test", id="test executor filter", ), - pytest.param( - [ - {"pool": "test_pool_1"}, - {"pool": "test_pool_2"}, - {"pool": "test_pool_3"}, - ], - True, - {"pool": ["test_pool_1", "test_pool_2"]}, - 2, - "test_dag_read_only", - id="test pool filter", - ), - pytest.param( - [ - {"state": State.RUNNING}, - {"state": State.QUEUED}, - {"state": State.SUCCESS}, - {"state": State.NONE}, - ], - False, - {"state": ["running", "queued", "none"]}, - 3, - "test_task_read_only", - id="test state filter", - ), - pytest.param( - [ - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - ], - False, - {}, - 4, - "test_task_read_only", - id="test dag with null states", - ), pytest.param( [ {"duration": 100}, @@ -948,36 +829,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): "test", id="test duration filter", ), - pytest.param( - [ - {"end_date": DEFAULT_DATETIME_1}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "end_date_gte": DEFAULT_DATETIME_STR_1, - "end_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_task_read_only", - id="test end date filter", - ), - pytest.param( - [ - {"start_date": DEFAULT_DATETIME_1}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "start_date_gte": DEFAULT_DATETIME_STR_1, - "start_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_dag_read_only", - id="test start date filter", - ), pytest.param( [ {"execution_date": DEFAULT_DATETIME_1}, @@ -1162,24 +1013,6 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 - def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): - self.create_task_instances(session=session) - self.create_task_instances(session=session, dag_id="example_skip_dag") - payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} - - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, - json=payload, - ) - assert response.status_code == 403 - assert response.json == { - "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", - "status": 403, - "title": "Forbidden", - "type": EXCEPTIONS_LINK_MAP[403], - } - def test_should_raise_400_for_no_json(self): response = self.client.post( "/api/v1/dags/~/dagRuns/~/taskInstances/list", @@ -1794,11 +1627,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username: str): + def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": False, "reset_dag_runs": True, @@ -2043,11 +1875,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): + def test_should_raise_403_forbidden(self): response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": True, "task_id": "print_the_context", @@ -2386,11 +2217,10 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) - @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): + def test_should_raise_403_forbidden(self): response = self.client.patch( self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test_no_permissions"}, json={ "dry_run": True, "new_state": "failed", @@ -2748,14 +2578,13 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - @pytest.mark.parametrize("username", ["test", "test_dag_read_only", "test_task_read_only"]) @provide_session - def test_should_respond_200(self, username, session): + def test_should_respond_200(self, session): self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) response = self.client.get( "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", - environ_overrides={"REMOTE_USER": username}, + environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 assert response.json == { diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 81405df08b045..aa5f7c99674f8 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -22,7 +22,6 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Variable -from airflow.security import permissions from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_variables @@ -36,40 +35,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, # type: ignore - username="test_read_only", - role_name="TestReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, # type: ignore - username="test_delete_only", - role_name="TestDeleteOnly", - permissions=[ - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_read_only") # type: ignore - delete_user(app, username="test_delete_only") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestVariableEndpoint: @@ -131,8 +106,6 @@ class TestGetVariable(TestVariableEndpoint): "user, expected_status_code", [ ("test", 200), - ("test_read_only", 200), - ("test_delete_only", 403), ("test_no_permissions", 403), ], ) diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py index 7a51714c5b299..809e537f9f88d 100644 --- a/tests/api_connexion/endpoints/test_xcom_endpoint.py +++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py @@ -26,7 +26,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend from airflow.operators.empty import EmptyOperator -from airflow.security import permissions from airflow.utils.dates import parse_execution_date from airflow.utils.session import create_session from airflow.utils.timezone import utcnow @@ -52,32 +51,16 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type: ignore + app, username="test", - role_name="Test", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - create_user( - app, # type: ignore - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - app.appbuilder.sm.sync_perm_for_dag( # type: ignore - "test-dag-id-1", - access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + role_name="admin", ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name=None) yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") def _compare_xcom_collections(collection1: dict, collection_2: dict): @@ -435,53 +418,6 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self): }, ) - def test_should_respond_200_with_tilde_and_granular_dag_access(self): - dag_id_1 = "test-dag-id-1" - task_id_1 = "test-task-id-1" - execution_date = "2005-04-02T00:00:00+00:00" - execution_date_parsed = parse_execution_date(execution_date) - dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) - self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) - - dag_id_2 = "test-dag-id-2" - task_id_2 = "test-task-id-2" - run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) - self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) - self._create_invalid_xcom_entries(execution_date_parsed) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - - assert 200 == response.status_code - response_data = response.json - for xcom_entry in response_data["xcom_entries"]: - xcom_entry["timestamp"] = "TIMESTAMP" - _compare_xcom_collections( - response_data, - { - "xcom_entries": [ - { - "dag_id": dag_id_1, - "execution_date": execution_date, - "key": "test-xcom-key-1", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - { - "dag_id": dag_id_1, - "execution_date": execution_date, - "key": "test-xcom-key-2", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - ], - "total_entries": 2, - }, - ) - def test_should_respond_200_with_map_index(self): dag_id = "test-dag-id" task_id = "test-task-id" diff --git a/tests/api_connexion/test_auth.py b/tests/api_connexion/test_auth.py index 7d1dcc088273c..54e5632ad84d1 100644 --- a/tests/api_connexion/test_auth.py +++ b/tests/api_connexion/test_auth.py @@ -16,15 +16,15 @@ # under the License. from __future__ import annotations -from base64 import b64encode +from unittest.mock import patch import pytest -from flask_login import current_user +from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools -from tests.test_utils.www import client_with_login pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -34,101 +34,6 @@ class BaseTestAuth: def set_attrs(self, minimal_app_for_api): self.app = minimal_app_for_api - sm = self.app.appbuilder.sm - tester = sm.find_user(username="test") - if not tester: - role_admin = sm.find_role("Admin") - sm.add_user( - username="test", - first_name="test", - last_name="test", - email="test@fab.org", - role=role_admin, - password="test", - ) - - -class TestBasicAuth(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_api, "api_auth") - - try: - with conf_vars( - {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} - ): - init_api_auth(minimal_app_for_api) - yield - finally: - setattr(minimal_app_for_api, "api_auth", old_auth) - - def test_success(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert current_user.email == "test@fab.org" - - assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - @pytest.mark.parametrize( - "token", - [ - "basic", - "basic ", - "bearer", - "test:test", - b64encode(b"test:test").decode(), - "bearer ", - "basic: ", - "basic 123", - ], - ) - def test_malformed_headers(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - - @pytest.mark.parametrize( - "token", - [ - "basic " + b64encode(b"test").decode(), - "basic " + b64encode(b"test:").decode(), - "basic " + b64encode(b"test:123").decode(), - "basic " + b64encode(b"test test").decode(), - ], - ) - def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - class TestSessionAuth(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") @@ -144,74 +49,37 @@ def with_session_backend(self, minimal_app_for_api): finally: setattr(minimal_app_for_api, "api_auth", old_auth) - def test_success(self): + @patch.object(SimpleAuthManager, "is_logged_in", return_value=True) + @patch.object( + SimpleAuthManager, "get_user", return_value=SimpleAuthManagerUser(username="test", role="admin") + ) + def test_success(self, *args): clear_db_pools() - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - def test_failure(self): with self.app.test_client() as test_client: response = test_client.get("/api/v1/pools") - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert_401(response) - - -class TestSessionWithBasicAuthFallback(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" - } - ): - init_api_auth(minimal_app_for_api) - yield - finally: - setattr(minimal_app_for_api, "api_auth", old_auth) - - def test_basic_auth_fallback(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - # request uses session - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - - # request uses basic auth - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) assert response.status_code == 200 + assert response.json == { + "pools": [ + { + "name": "default_pool", + "slots": 128, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "deferred_slots": 0, + "open_slots": 128, + "description": "Default pool", + "include_deferred": False, + }, + ], + "total_entries": 1, + } - # request without session or basic auth header + def test_failure(self): with self.app.test_client() as test_client: response = test_client.get("/api/v1/pools") assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert_401(response) diff --git a/tests/api_connexion/test_security.py b/tests/api_connexion/test_security.py index 13a5dd4e25af1..c6a112b1a1bb9 100644 --- a/tests/api_connexion/test_security.py +++ b/tests/api_connexion/test_security.py @@ -18,7 +18,6 @@ import pytest -from airflow.security import permissions from tests.test_utils.api_connexion_utils import create_user, delete_user pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] @@ -28,15 +27,14 @@ def configured_app(minimal_app_for_api): app = minimal_app_for_api create_user( - app, # type:ignore + app, username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore + role_name="admin", ) yield minimal_app_for_api - delete_user(app, username="test") # type: ignore + delete_user(app, username="test") class TestSession: diff --git a/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py new file mode 100644 index 0000000000000..61d923d5ff125 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from contextlib import contextmanager + +from tests.test_utils.compat import ignore_provider_compatibility_error + +with ignore_provider_compatibility_error("2.9.0+", __file__): + from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES + + +@contextmanager +def create_test_client(app, user_name, role_name, permissions): + """ + Helper function to create a client with a temporary user which will be deleted once done + """ + client = app.test_client() + with create_user_scope(app, username=user_name, role_name=role_name, permissions=permissions) as _: + resp = client.post("/login/", data={"username": user_name, "password": user_name}) + assert resp.status_code == 302 + yield client + + +@contextmanager +def create_user_scope(app, username, **kwargs): + """ + Helper function designed to be used with pytest fixture mainly. + It will create a user and provide it for the fixture via YIELD (generator) + then will tidy up once test is complete + """ + test_user = create_user(app, username, **kwargs) + + try: + yield test_user + finally: + delete_user(app, username) + + +def create_user(app, username, role_name=None, email=None, permissions=None): + appbuilder = app.appbuilder + + # Removes user and role so each test has isolated test data. + delete_user(app, username) + role = None + if role_name: + delete_role(app, role_name) + role = create_role(app, role_name, permissions) + else: + role = [] + + return appbuilder.sm.add_user( + username=username, + first_name=username, + last_name=username, + email=email or f"{username}@example.org", + role=role, + password=username, + ) + + +def create_role(app, name, permissions=None): + appbuilder = app.appbuilder + role = appbuilder.sm.find_role(name) + if not role: + role = appbuilder.sm.add_role(name) + if not permissions: + permissions = [] + for permission in permissions: + perm_object = appbuilder.sm.get_permission(*permission) + appbuilder.sm.add_permission_to_role(role, perm_object) + return role + + +def set_user_single_role(app, user, role_name): + role = create_role(app, role_name) + if role not in user.roles: + user.roles = [role] + app.appbuilder.sm.update_user(user) + user._perms = None + + +def delete_role(app, name): + if name not in EXISTING_ROLES: + if app.appbuilder.sm.find_role(name): + app.appbuilder.sm.delete_role(name) + + +def delete_roles(app): + for role in app.appbuilder.sm.get_all_roles(): + delete_role(app, role.name) + + +def delete_user(app, username): + appbuilder = app.appbuilder + for user in appbuilder.sm.get_all_users(): + if user.username == username: + _ = [ + delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES + ] + appbuilder.sm.del_register_user(user) + break diff --git a/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py new file mode 100644 index 0000000000000..b7714e5192e6a --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default authentication backend - everything is allowed""" + +from __future__ import annotations + +import logging +from functools import wraps +from typing import TYPE_CHECKING, Callable, TypeVar, cast + +from flask import Response, request +from flask_login import login_user + +from airflow.utils.airflow_flask_app import get_airflow_app + +if TYPE_CHECKING: + from requests.auth import AuthBase + +log = logging.getLogger(__name__) + +CLIENT_AUTH: tuple[str, str] | AuthBase | None = None + + +def init_app(_): + """Initializes authentication backend""" + + +T = TypeVar("T", bound=Callable) + + +def _lookup_user(user_email_or_username: str): + security_manager = get_airflow_app().appbuilder.sm + user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( + username=user_email_or_username + ) + if not user: + return None + + if not user.is_active: + return None + + return user + + +def requires_authentication(function: T): + """Decorator for functions that require authentication""" + + @wraps(function) + def decorated(*args, **kwargs): + user_id = request.remote_user + if not user_id: + log.debug("Missing REMOTE_USER.") + return Response("Forbidden", 403) + + log.debug("Looking for user: %s", user_id) + + user = _lookup_user(user_id) + if not user: + return Response("Forbidden", 403) + + log.debug("Found user: %s", user) + + login_user(user, remember=False) + return function(*args, **kwargs) + + return cast(T, decorated) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_auth.py b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py new file mode 100644 index 0000000000000..d3012e2f1b43e --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from base64 import b64encode + +import pytest +from flask_login import current_user + +from tests.test_utils.api_connexion_utils import assert_401 +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_pools +from tests.test_utils.www import client_with_login + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class BaseTestAuth: + @pytest.fixture(autouse=True) + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api + + sm = self.app.appbuilder.sm + tester = sm.find_user(username="test") + if not tester: + role_admin = sm.find_role("Admin") + sm.add_user( + username="test", + first_name="test", + last_name="test", + email="test@fab.org", + role=role_admin, + password="test", + ) + + +class TestBasicAuth(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_success(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert current_user.email == "test@fab.org" + + assert response.status_code == 200 + assert response.json == { + "pools": [ + { + "name": "default_pool", + "slots": 128, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "deferred_slots": 0, + "open_slots": 128, + "description": "Default pool", + "include_deferred": False, + }, + ], + "total_entries": 1, + } + + @pytest.mark.parametrize( + "token", + [ + "basic", + "basic ", + "bearer", + "test:test", + b64encode(b"test:test").decode(), + "bearer ", + "basic: ", + "basic 123", + ], + ) + def test_malformed_headers(self, token): + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert response.headers["WWW-Authenticate"] == "Basic" + assert_401(response) + + @pytest.mark.parametrize( + "token", + [ + "basic " + b64encode(b"test").decode(), + "basic " + b64encode(b"test:").decode(), + "basic " + b64encode(b"test:123").decode(), + "basic " + b64encode(b"test test").decode(), + ], + ) + def test_invalid_auth_header(self, token): + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert response.headers["WWW-Authenticate"] == "Basic" + assert_401(response) + + +class TestSessionWithBasicAuthFallback(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + { + ( + "api", + "auth_backends", + ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" + } + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_basic_auth_fallback(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + # request uses session + admin_user = client_with_login(self.app, username="test", password="test") + response = admin_user.get("/api/v1/pools") + assert response.status_code == 200 + + # request uses basic auth + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 200 + + # request without session or basic auth header + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools") + assert response.status_code == 401 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py new file mode 100644 index 0000000000000..56f135d457e9c --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime +from unittest import mock +from urllib.parse import urlencode + +import pendulum +import pytest + +from airflow.models import DagBag, DagModel +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.models.backfill import Backfill +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from airflow.models.dag import DAG +from airflow.models.serialized_dag import SerializedDagModel +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import provide_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestBackfillEndpoint: + @staticmethod + def clean_db(): + clear_db_backfills() + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, *, count=1, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + dags = [] + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + is_active=True, + timetable_summary="0 0 * * *", + is_paused=is_paused, + ) + session.add(dag_model) + dags.append(dag_model) + return dags + + @provide_session + def _create_deactivated_dag(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + schedule_interval="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestListBackfills(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + b = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + + session.add(b) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.get("/api/v1/backfills?dag_id=TEST_DAG_1", **kwargs) + assert response.status_code == 200 + + +class TestGetBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.get(f"/api/v1/backfills/{backfill.id}", **kwargs) + assert response.status_code == 200 + + +class TestCreateBackfill(TestBackfillEndpoint): + def test_create_backfill(self, session, dag_maker): + with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag: + EmptyOperator(task_id="mytask") + session.add(SerializedDagModel(dag)) + session.commit() + session.query(DagModel).all() + from_date = pendulum.parse("2024-01-01") + from_date_iso = from_date.isoformat() + to_date = pendulum.parse("2024-02-01") + to_date_iso = to_date.isoformat() + max_active_runs = 5 + query = urlencode( + query={ + "dag_id": dag.dag_id, + "from_date": f"{from_date_iso}", + "to_date": f"{to_date_iso}", + "max_active_runs": max_active_runs, + "reverse": False, + } + ) + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + + response = self.client.post( + f"/api/v1/backfills?{query}", + **kwargs, + ) + assert response.status_code == 200 + assert response.json == { + "completed_at": mock.ANY, + "created_at": mock.ANY, + "dag_id": "TEST_DAG_1", + "dag_run_conf": None, + "from_date": from_date_iso, + "id": mock.ANY, + "is_paused": False, + "max_active_runs": 5, + "to_date": to_date_iso, + "updated_at": mock.ANY, + } + + +class TestPauseBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.post(f"/api/v1/backfills/{backfill.id}/pause", **kwargs) + assert response.status_code == 200 + + +class TestCancelBackfill(TestBackfillEndpoint): + def test_should_respond_200_with_granular_dag_access(self, session): + (dag,) = self._create_dag_models() + from_date = timezone.utcnow() + to_date = timezone.utcnow() + backfill = Backfill( + dag_id=dag.dag_id, + from_date=from_date, + to_date=to_date, + ) + session.add(backfill) + session.commit() + kwargs = {} + kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"}) + response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs) + assert response.status_code == 200 + # now it is marked as completed + assert pendulum.parse(response.json["completed_at"]) + + # get conflict when canceling already-canceled backfill + response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs) + assert response.status_code == 409 diff --git a/tests/api_connexion/test_cors.py b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py similarity index 81% rename from tests/api_connexion/test_cors.py rename to tests/providers/fab/auth_manager/api_endpoints/test_cors.py index a2b7f0ebca743..b44eab8820ec6 100644 --- a/tests/api_connexion/test_cors.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py @@ -20,16 +20,21 @@ import pytest +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_pools -pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] class BaseTestAuth: @pytest.fixture(autouse=True) - def set_attrs(self, minimal_app_for_api): - self.app = minimal_app_for_api + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api sm = self.app.appbuilder.sm tester = sm.find_user(username="test") @@ -47,19 +52,19 @@ def set_attrs(self, minimal_app_for_api): class TestEmptyCors(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_empty_cors_headers(self): token = "Basic " + b64encode(b"test:test").decode() @@ -75,10 +80,10 @@ def test_empty_cors_headers(self): class TestCorsOrigin(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( @@ -90,10 +95,10 @@ def with_basic_auth_backend(self, minimal_app_for_api): ("api", "access_control_allow_origins"): "http://apache.org http://example.com", } ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() @@ -119,10 +124,10 @@ def test_cors_origin_reflection(self): class TestCorsWildcard(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_api): + def with_basic_auth_backend(self, minimal_app_for_auth_api): from airflow.www.extensions.init_security import init_api_auth - old_auth = getattr(minimal_app_for_api, "api_auth") + old_auth = getattr(minimal_app_for_auth_api, "api_auth") try: with conf_vars( @@ -134,10 +139,10 @@ def with_basic_auth_backend(self, minimal_app_for_api): ("api", "access_control_allow_origins"): "*", } ): - init_api_auth(minimal_app_for_api) + init_api_auth(minimal_app_for_auth_api) yield finally: - setattr(minimal_app_for_api, "api_auth", old_auth) + setattr(minimal_app_for_auth_api, "api_auth", old_auth) def test_cors_origin_reflection(self): token = "Basic " + b64encode(b"test:test").decode() diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py new file mode 100644 index 0000000000000..b78ac58e442e0 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +import pendulum +import pytest + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagBag, DagModel +from airflow.models.dag import DAG +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.session import provide_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags +from tests.test_utils.www import _check_last_log + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture +def current_file_token(url_safe_serializer) -> str: + return url_safe_serializer.dumps(__file__) + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestDagEndpoint: + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=is_paused, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint_with_dataset_expression(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + dataset_expression={ + "any": [ + "s3://dag1/output_1.txt", + {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, + ] + }, + ) + session.add(dag_model) + + @provide_session + def _create_deactivated_dag(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + timetable_summary="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestGetDag(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(1) + response = self.client.get( + "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + + def test_should_respond_403_with_granular_access_for_different_dag(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 403 + + +class TestGetDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + + +class TestPatchDag(TestDagEndpoint): + @provide_session + def _create_dag_model(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True + ) + session.add(dag_model) + return dag_model + + def test_should_respond_200_on_patch_with_granular_dag_access(self, session): + self._create_dag_models(1) + response = self.client.patch( + "/api/v1/dags/TEST_DAG_1", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) + + def test_validation_error_raises_400(self): + patch_body = { + "ispaused": True, + } + dag_model = self._create_dag_model() + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", + json=patch_body, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'ispaused': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + + +class TestPatchDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py new file mode 100644 index 0000000000000..a58ea08ff31cf --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import timedelta + +import pytest + +from airflow.models.dag import DAG, DagModel +from airflow.models.dagrun import DagRun +from airflow.models.param import Param +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import create_session +from airflow.utils.state import DagRunState +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.utils.types import DagRunTriggeredByType, DagRunType +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_no_dag_run_create_permission", + role_name="TestNoDagRunCreatePermission", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_dag_view_only", + role_name="TestViewDags", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_view_dags", + role_name="TestViewDags", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), + ], + ) + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID", + access_control={ + "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, + "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, + }, + ) + + yield app + + delete_user(app, username="test_dag_view_only") + delete_user(app, username="test_view_dags") + delete_user(app, username="test_granular_permissions") + delete_user(app, username="test_no_dag_run_create_permission") + delete_roles(app) + + +class TestDagRunEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + default_time_2 = "2020-06-12T18:00:00+00:00" + default_time_3 = "2020-06-13T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_db_serialized_dags() + clear_db_dags() + + def teardown_method(self) -> None: + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + def _create_dag(self, dag_id): + dag_instance = DagModel(dag_id=dag_id) + dag_instance.is_active = True + with create_session() as session: + session.add(dag_instance) + dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)}) + self.app.dag_bag.bag_dag(dag) + return dag_instance + + def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): + dag_runs = [] + dags = [] + triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {} + + for i in range(idx_start, idx_start + 2): + if i == 1: + dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True)) + dagrun_model = DagRun( + dag_id="TEST_DAG_ID", + run_id=f"TEST_DAG_RUN_ID_{i}", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time) + timedelta(days=i - 1), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state=state, + **triggered_by_kwargs, + ) + dagrun_model.updated_at = timezone.parse(self.default_time) + dag_runs.append(dagrun_model) + + if extra_dag: + for i in range(idx_start + 2, idx_start + 4): + dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}")) + dag_runs.append( + DagRun( + dag_id=f"TEST_DAG_ID_{i}", + run_id=f"TEST_DAG_RUN_ID_{i}", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time_2), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state=state, + ) + ) + if commit: + with create_session() as session: + session.add_all(dag_runs) + session.add_all(dags) + return dag_runs + + +class TestGetDagRuns(TestDagRunEndpoint): + def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): + self._create_test_dag_run(extra_dag=True) + expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] + response = self.client.get( + "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] + assert dag_run_ids == expected_dag_run_ids + + +class TestGetDagRunBatch(TestDagRunEndpoint): + def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): + self._create_test_dag_run(extra_dag=True) + expected_response_json_1 = { + "dag_id": "TEST_DAG_ID", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "end_date": None, + "state": "running", + "execution_date": self.default_time, + "logical_date": self.default_time, + "external_trigger": True, + "start_date": self.default_time, + "conf": {}, + "data_interval_end": None, + "data_interval_start": None, + "last_scheduling_decision": None, + "run_type": "manual", + "note": None, + } + expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) + expected_response_json_2 = { + "dag_id": "TEST_DAG_ID", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "end_date": None, + "state": "running", + "execution_date": self.default_time_2, + "logical_date": self.default_time_2, + "external_trigger": True, + "start_date": self.default_time, + "conf": {}, + "data_interval_end": None, + "data_interval_start": None, + "last_scheduling_decision": None, + "run_type": "manual", + "note": None, + } + expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) + + response = self.client.post( + "api/v1/dags/~/dagRuns/list", + json={"dag_ids": []}, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert response.json == { + "dag_runs": [ + expected_response_json_1, + expected_response_json_2, + ], + "total_entries": 2, + } + + +class TestPostDagRun(TestDagRunEndpoint): + def test_dagrun_trigger_with_dag_level_permissions(self): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={"conf": {"validated_number": 1}}, + environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, + ) + assert response.status_code == 200 + + @pytest.mark.parametrize( + "username", + ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], + ) + def test_should_raises_403_unauthorized(self, username): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={ + "dag_run_id": "TEST_DAG_RUN_ID_1", + "execution_date": self.default_time, + }, + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 403 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py new file mode 100644 index 0000000000000..f0d9b0da298c6 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import ast +import os +from typing import TYPE_CHECKING + +import pytest + +from airflow.models import DagBag +from airflow.security import permissions +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +if TYPE_CHECKING: + from airflow.models.dag import DAG + +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") +EXAMPLE_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" +NOT_READABLE_DAG_ID = "latest_only_with_trigger" +TEST_MULTIPLE_DAGS_ID = "asset_produces_1" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test", + role_name="Test", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + EXAMPLE_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_MULTIPLE_DAGS_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test") + + +class TestGetSource: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.clear_db() + + def teardown_method(self) -> None: + self.clear_db() + + @staticmethod + def clear_db(): + clear_db_dags() + clear_db_serialized_dags() + clear_db_dag_code() + + @staticmethod + def _get_dag_file_docstring(fileloc: str) -> str | None: + with open(fileloc) as f: + file_contents = f.read() + module = ast.parse(file_contents) + docstring = ast.get_docstring(module) + return docstring + + def test_should_respond_406(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[TEST_DAG_ID] + + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" + response = self.client.get( + url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} + ) + + assert 406 == response.status_code + + def test_should_respond_403_not_readable(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] + + response = self.client.get( + f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + read_dag = self.client.get( + f"/api/v1/dags/{NOT_READABLE_DAG_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 403 + + def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] + + response = self.client.get( + f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + + read_dag = self.client.get( + f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py new file mode 100644 index 0000000000000..adfde1cc5b3eb --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models.dag import DagModel +from airflow.models.dagwarning import DagWarning +from airflow.security import permissions +from airflow.utils.session import create_session +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, # type:ignore + username="test_with_dag2_read", + role_name="TestWithDag2Read", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), + ], + ) + + yield app + + delete_user(app, username="test_with_dag2_read") + + +class TestBaseDagWarning: + timestamp = "2020-06-10T12:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + + def teardown_method(self) -> None: + clear_db_dag_warnings() + clear_db_dags() + + +class TestGetDagWarningEndpoint(TestBaseDagWarning): + def setup_class(self): + clear_db_dag_warnings() + clear_db_dags() + + def setup_method(self): + with create_session() as session: + session.add(DagModel(dag_id="dag1")) + session.add(DagWarning("dag1", "non-existent pool", "test message")) + session.commit() + + def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): + response = self.client.get( + "/api/v1/dagWarnings", + environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, + query_string={"dag_id": "dag1"}, + ) + assert response.status_code == 403 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py new file mode 100644 index 0000000000000..4d302722223d8 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Generator + +import pytest +import time_machine + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS + +try: + from airflow.models.asset import AssetDagRunQueue, AssetModel +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.db import clear_db_assets, clear_db_runs +from tests.test_utils.www import _check_last_log + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_queued_event", + role_name="TestQueuedEvent", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), + ], + ) + + yield app + + delete_user(app, username="test_queued_event") + + +class TestAssetEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() + clear_db_assets() + clear_db_runs() + + def teardown_method(self) -> None: + clear_db_assets() + clear_db_runs() + + def _create_asset(self, session): + asset_model = AssetModel( + id=1, + uri="s3://bucket/key", + extra={"foo": "bar"}, + created_at=timezone.parse(self.default_time), + updated_at=timezone.parse(self.default_time), + ) + session.add(asset_model) + session.commit() + return asset_model + + +class TestQueuedEventEndpoint(TestAssetEndpoint): + @pytest.fixture + def time_freezer(self) -> Generator: + freezer = time_machine.travel(self.default_time, tick=False) + freezer.start() + + yield + + freezer.stop() + + def _create_asset_dag_run_queues(self, dag_id, dataset_id, session): + ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + return ddrq + + +class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + + def test_should_respond_404(self): + dag_id = "not_exists" + dataset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvent(TestAssetEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_uri = "s3://bucket/key" + dataset_id = self._create_asset(session).id + + ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 1 + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log( + session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None + ) + + def test_should_respond_404(self): + dag_id = "not_exists" + dataset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint): + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dataset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None) + + def test_should_respond_404(self): + dataset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py new file mode 100644 index 0000000000000..acf3ca62684a1 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models import Log +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_logs + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_granular", + role_name="TestGranular", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID_1", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_ID_2", + access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular") + + +@pytest.fixture +def task_instance(session, create_task_instance, request): + return create_task_instance( + session=session, + dag_id="TEST_DAG_ID", + task_id="TEST_TASK_ID", + run_id="TEST_RUN_ID", + execution_date=request.instance.default_time, + ) + + +@pytest.fixture +def create_log_model(create_task_instance, task_instance, session, request): + def maker(event, when, **kwargs): + log_model = Log( + event=event, + task_instance=task_instance, + **kwargs, + ) + log_model.dttm = when + + session.add(log_model) + session.flush() + return log_model + + return maker + + +class TestEventLogEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_logs() + self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") + self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") + + def teardown_method(self) -> None: + clear_db_logs() + + +class TestGetEventLogs(TestEventLogEndpoint): + def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): + eventlog1 = create_log_model( + event="TEST_EVENT_1", + dag_id="TEST_DAG_ID_1", + task_id="TEST_TASK_ID_1", + owner="TEST_OWNER_1", + when=self.default_time, + ) + eventlog2 = create_log_model( + event="TEST_EVENT_2", + dag_id="TEST_DAG_ID_2", + task_id="TEST_TASK_ID_2", + owner="TEST_OWNER_2", + when=self.default_time_2, + ) + session.add_all([eventlog1, eventlog2]) + session.commit() + for attr in ["dag_id", "task_id", "owner", "event"]: + attr_value = f"TEST_{attr}_1".upper() + response = self.client.get( + f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == 1 + assert len(response.json["event_logs"]) == 1 + assert response.json["event_logs"][0][attr] == attr_value + + def test_should_filter_eventlogs_by_included_events(self, create_log_model): + for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: + create_log_model(event=event, when=self.default_time) + response = self.client.get( + "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", + environ_overrides={"REMOTE_USER": "test_granular"}, + ) + assert response.status_code == 200 + response_data = response.json + assert len(response_data["event_logs"]) == 2 + assert response_data["total_entries"] == 2 + assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} + + def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): + for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: + create_log_model(event=event, when=self.default_time) + response = self.client.get( + "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", + environ_overrides={"REMOTE_USER": "test_granular"}, + ) + assert response.status_code == 200 + response_data = response.json + assert len(response_data["event_logs"]) == 1 + assert response_data["total_entries"] == 1 + assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py new file mode 100644 index 0000000000000..a2fa1d028a3f2 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models.dag import DagModel +from airflow.security import permissions +from airflow.utils import timezone +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, ParseImportError +from tests.test_utils.db import clear_db_dags, clear_db_import_errors +from tests.test_utils.permissions import _resource_name + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +TEST_DAG_IDS = ["test_dag", "test_dag2"] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_single_dag", + role_name="TestSingleDAG", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], + ) + # For some reason, DAG level permissions are not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestSingleDAG", + "perms": [ + ( + permissions.ACTION_CAN_READ, + _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), + ) + ], + } + ] + ) + + yield app + + delete_user(app, username="test_single_dag") + + +class TestBaseImportError: + timestamp = "2020-06-10T12:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + + clear_db_import_errors() + clear_db_dags() + + def teardown_method(self) -> None: + clear_db_import_errors() + clear_db_dags() + + @staticmethod + def _normalize_import_errors(import_errors): + for i, import_error in enumerate(import_errors, 1): + import_error["import_error_id"] = i + + +class TestGetImportErrorEndpoint(TestBaseImportError): + def test_should_raise_403_forbidden_without_dag_read(self, session): + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 403 + + def test_should_return_200_with_single_dag_read(self, session): + dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py") + session.add(dag_model) + import_error = ParseImportError( + filename="Lorem_ipsum.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(import_error) + session.commit() + + response = self.client.get( + f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + response_data["import_error_id"] = 1 + assert { + "filename": "Lorem_ipsum.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + } == response_data + + +class TestGetImportErrorsEndpoint(TestBaseImportError): + def test_get_import_errors_single_dag(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = f"/tmp/{dag_id}.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + importerror = ParseImportError( + filename=fake_filename, + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/test_dag.py", + "import_error_id": 1, + "stack_trace": "Lorem ipsum", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data + + def test_get_import_errors_single_dag_in_dagfile(self, session): + for dag_id in TEST_DAG_IDS: + fake_filename = "/tmp/all_in_one.py" + dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename) + session.add(dag_model) + + importerror = ParseImportError( + filename="/tmp/all_in_one.py", + stacktrace="Lorem ipsum", + timestamp=timezone.parse(self.timestamp, timezone="UTC"), + ) + session.add(importerror) + session.commit() + + response = self.client.get( + "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} + ) + + assert response.status_code == 200 + response_data = response.json + self._normalize_import_errors(response_data["import_errors"]) + assert { + "import_errors": [ + { + "filename": "/tmp/all_in_one.py", + "import_error_id": 1, + "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", + "timestamp": "2020-06-10T12:00:00+00:00", + }, + ], + "total_entries": 1, + } == response_data diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index 30cfaeb227903..413a49a9d86a1 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -19,6 +19,13 @@ import pytest from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_role, + create_user, + delete_role, + delete_user, +) +from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): @@ -27,13 +34,6 @@ from airflow.security import permissions -from tests.test_utils.api_connexion_utils import ( - assert_401, - create_role, - create_user, - delete_role, - delete_user, -) pytestmark = pytest.mark.db_test @@ -42,7 +42,7 @@ def configured_app(minimal_app_for_auth_api): app = minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ -53,11 +53,11 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestRoleEndpoint: diff --git a/tests/api_connexion/schemas/test_role_and_permission_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py similarity index 85% rename from tests/api_connexion/schemas/test_role_and_permission_schema.py rename to tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py index f2967d519794c..4a2f0068e5e4a 100644 --- a/tests/api_connexion/schemas/test_role_and_permission_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py @@ -31,19 +31,19 @@ class TestRoleCollectionItemSchema: @pytest.fixture(scope="class") - def role(self, minimal_app_for_api): + def role(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test") + delete_role(minimal_app_for_auth_api, "Test") @pytest.fixture(autouse=True) - def _set_attrs(self, minimal_app_for_api, role): - self.app = minimal_app_for_api + def _set_attrs(self, minimal_app_for_auth_api, role): + self.app = minimal_app_for_auth_api self.role = role def test_serialize(self): @@ -67,26 +67,26 @@ def test_deserialize(self): class TestRoleCollectionSchema: @pytest.fixture(scope="class") - def role1(self, minimal_app_for_api): + def role1(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test1", permissions=[ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), ], ) - delete_role(minimal_app_for_api, "Test1") + delete_role(minimal_app_for_auth_api, "Test1") @pytest.fixture(scope="class") - def role2(self, minimal_app_for_api): + def role2(self, minimal_app_for_auth_api): yield create_role( - minimal_app_for_api, # type: ignore + minimal_app_for_auth_api, # type: ignore name="Test2", permissions=[ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), ], ) - delete_role(minimal_app_for_api, "Test2") + delete_role(minimal_app_for_auth_api, "Test2") def test_serialize(self, role1, role2): instance = RoleCollection([role1, role2], total_entries=2) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py new file mode 100644 index 0000000000000..69b3c221eae93 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py @@ -0,0 +1,427 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import datetime as dt +import urllib + +import pytest + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagRun, TaskInstance +from airflow.security import permissions +from airflow.utils.session import provide_session +from airflow.utils.state import State +from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +DEFAULT_DATETIME_1 = datetime(2020, 1, 1) +DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" +DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00" + +QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1) +QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2) + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_dag_read_only", + role_name="TestDagReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_task_read_only", + role_name="TestTaskReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_read_only_one_dag", + role_name="TestReadOnlyOneDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestReadOnlyOneDag", + "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], + } + ] + ) + + yield app + + delete_user(app, username="test_dag_read_only") + delete_user(app, username="test_task_read_only") + delete_user(app, username="test_read_only_one_dag") + delete_roles(app) + + +class TestTaskInstanceEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app, dagbag) -> None: + self.default_time = DEFAULT_DATETIME_1 + self.ti_init = { + "execution_date": self.default_time, + "state": State.RUNNING, + } + self.ti_extras = { + "start_date": self.default_time + dt.timedelta(days=1), + "end_date": self.default_time + dt.timedelta(days=2), + "pid": 100, + "duration": 10000, + "pool": "default_pool", + "queue": "default_queue", + "job_id": 0, + } + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_db_sla_miss() + clear_rendered_ti_fields() + self.dagbag = dagbag + + def create_task_instances( + self, + session, + dag_id: str = "example_python_operator", + update_extras: bool = True, + task_instances=None, + dag_run_state=State.RUNNING, + with_ti_history=False, + ): + """Method to create task instances using kwargs and default arguments""" + + dag = self.dagbag.get_dag(dag_id) + tasks = dag.tasks + counter = len(tasks) + if task_instances is not None: + counter = min(len(task_instances), counter) + + run_id = "TEST_DAG_RUN_ID" + execution_date = self.ti_init.pop("execution_date", self.default_time) + dr = None + + tis = [] + for i in range(counter): + if task_instances is None: + pass + elif update_extras: + self.ti_extras.update(task_instances[i]) + else: + self.ti_init.update(task_instances[i]) + + if "execution_date" in self.ti_init: + run_id = f"TEST_DAG_RUN_ID_{i}" + execution_date = self.ti_init.pop("execution_date") + dr = None + + if not dr: + dr = DagRun( + run_id=run_id, + dag_id=dag_id, + execution_date=execution_date, + run_type=DagRunType.MANUAL, + state=dag_run_state, + ) + session.add(dr) + ti = TaskInstance(task=tasks[i], **self.ti_init) + session.add(ti) + ti.dag_run = dr + ti.note = "placeholder-note" + + for key, value in self.ti_extras.items(): + setattr(ti, key, value) + tis.append(ti) + + session.commit() + if with_ti_history: + for ti in tis: + ti.try_number = 1 + session.merge(ti) + session.commit() + dag.clear() + for ti in tis: + ti.try_number = 2 + ti.queue = "default_queue" + session.merge(ti) + session.commit() + return tis + + +class TestGetTaskInstance(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session) + # Update ti and set operator to None to + # test that operator field is nullable. + # This prevents issue when users upgrade to 2.0+ + # from 1.10.x + # https://github.com/apache/airflow/issues/14421 + session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") + session.commit() + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 + + +class TestGetTaskInstances(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, user, expected_ti", + [ + pytest.param( + { + "example_python_operator": 2, + "example_skip_dag": 1, + }, + "test_read_only_one_dag", + 2, + ), + pytest.param( + { + "example_python_operator": 1, + "example_skip_dag": 2, + }, + "test_read_only_one_dag", + 1, + ), + ], + ) + def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + for dag_id in task_instances: + self.create_task_instances( + session, + task_instances=[ + {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} + for i in range(task_instances[dag_id]) + ], + dag_id=dag_id, + ) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == expected_ti + assert len(response.json["task_instances"]) == expected_ti + + +class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, update_extras, payload, expected_ti_count, username", + [ + pytest.param( + [ + {"pool": "test_pool_1"}, + {"pool": "test_pool_2"}, + {"pool": "test_pool_3"}, + ], + True, + {"pool": ["test_pool_1", "test_pool_2"]}, + 2, + "test_dag_read_only", + id="test pool filter", + ), + pytest.param( + [ + {"state": State.RUNNING}, + {"state": State.QUEUED}, + {"state": State.SUCCESS}, + {"state": State.NONE}, + ], + False, + {"state": ["running", "queued", "none"]}, + 3, + "test_task_read_only", + id="test state filter", + ), + pytest.param( + [ + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + ], + False, + {}, + 4, + "test_task_read_only", + id="test dag with null states", + ), + pytest.param( + [ + {"end_date": DEFAULT_DATETIME_1}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "end_date_gte": DEFAULT_DATETIME_STR_1, + "end_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_task_read_only", + id="test end date filter", + ), + pytest.param( + [ + {"start_date": DEFAULT_DATETIME_1}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "start_date_gte": DEFAULT_DATETIME_STR_1, + "start_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_dag_read_only", + id="test start date filter", + ), + ], + ) + def test_should_respond_200( + self, task_instances, update_extras, payload, expected_ti_count, username, session + ): + self.create_task_instances( + session, + update_extras=update_extras, + task_instances=task_instances, + ) + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": username}, + json=payload, + ) + assert response.status_code == 200, response.json + assert expected_ti_count == response.json["total_entries"] + assert expected_ti_count == len(response.json["task_instances"]) + + def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): + self.create_task_instances(session=session) + self.create_task_instances(session=session, dag_id="example_skip_dag") + payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} + + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, + json=payload, + ) + assert response.status_code == 403 + assert response.json == { + "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", + "status": 403, + "title": "Forbidden", + "type": EXCEPTIONS_LINK_MAP[403], + } + + +class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.post( + "/api/v1/dags/example_python_operator/updateTaskInstancesState", + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "task_id": "print_the_context", + "execution_date": DEFAULT_DATETIME_1.isoformat(), + "include_upstream": True, + "include_downstream": True, + "include_future": True, + "include_past": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestPatchTaskInstance(TestTaskInstanceEndpoint): + ENDPOINT_URL = ( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.patch( + self.ENDPOINT_URL, + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestGetTaskInstanceTry(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) + + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py index bc400c8a43fad..7f2c885bab52c 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -30,7 +30,12 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import User -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_role, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_role, + delete_user, +) +from tests.test_utils.api_connexion_utils import assert_401 from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -43,7 +48,7 @@ def configured_app(minimal_app_for_auth_api): app = minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ -53,12 +58,12 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") delete_role(app, name="TestNoPermissions") diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py index 265407622e269..f3399de6a9775 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py @@ -18,6 +18,7 @@ import pytest +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_role, delete_role from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): @@ -30,8 +31,6 @@ DEFAULT_TIME = "2021-01-09T13:59:56.336000+00:00" -from tests.test_utils.api_connexion_utils import create_role, delete_role # noqa: E402 - pytestmark = pytest.mark.db_test diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py new file mode 100644 index 0000000000000..a8e71e1a82466 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.models import Variable +from airflow.security import permissions +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_variables + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_read_only", + role_name="TestReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), + ], + ) + create_user( + app, + username="test_delete_only", + role_name="TestDeleteOnly", + permissions=[ + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), + ], + ) + + yield app + + delete_user(app, username="test_read_only") + delete_user(app, username="test_delete_only") + + +class TestVariableEndpoint: + @pytest.fixture(autouse=True) + def setup_method(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_variables() + + def teardown_method(self) -> None: + clear_db_variables() + + +class TestGetVariable(TestVariableEndpoint): + @pytest.mark.parametrize( + "user, expected_status_code", + [ + ("test_read_only", 200), + ("test_delete_only", 403), + ], + ) + def test_read_variable(self, user, expected_status_code): + expected_value = '{"foo": 1}' + Variable.set("TEST_VARIABLE_KEY", expected_value) + response = self.client.get( + "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == expected_status_code + if expected_status_code == 200: + assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py new file mode 100644 index 0000000000000..01336f9957c6d --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import timedelta + +import pytest + +from airflow.models.dag import DagModel +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import BaseXCom, XCom +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.dates import parse_execution_date +from airflow.utils.session import create_session +from airflow.utils.types import DagRunType +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, xcom: XCom): + return f"real deserialized {super().deserialize_value(xcom)}" + + def orm_deserialize_value(self): + return f"orm deserialized {super().orm_deserialize_value()}" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), + ], + ) + app.appbuilder.sm.sync_perm_for_dag( + "test-dag-id-1", + access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular_permissions") + + +def _compare_xcom_collections(collection1: dict, collection_2: dict): + assert collection1.get("total_entries") == collection_2.get("total_entries") + + def sort_key(record): + return ( + record.get("dag_id"), + record.get("task_id"), + record.get("execution_date"), + record.get("map_index"), + record.get("key"), + ) + + assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted( + collection_2.get("xcom_entries", []), key=sort_key + ) + + +class TestXComEndpoint: + @staticmethod + def clean_db(): + clear_db_dags() + clear_db_runs() + clear_db_xcom() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + """ + Setup For XCom endpoint TC + """ + self.app = configured_app + self.client = self.app.test_client() # type:ignore + # clear existing xcoms + self.clean_db() + + def teardown_method(self) -> None: + """ + Clear Hanging XComs + """ + self.clean_db() + + +class TestGetXComEntries(TestXComEndpoint): + def test_should_respond_200_with_tilde_and_granular_dag_access(self): + dag_id_1 = "test-dag-id-1" + task_id_1 = "test-task-id-1" + execution_date = "2005-04-02T00:00:00+00:00" + execution_date_parsed = parse_execution_date(execution_date) + dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) + + dag_id_2 = "test-dag-id-2" + task_id_2 = "test-task-id-2" + run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) + self._create_invalid_xcom_entries(execution_date_parsed) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + + assert 200 == response.status_code + response_data = response.json + for xcom_entry in response_data["xcom_entries"]: + xcom_entry["timestamp"] = "TIMESTAMP" + _compare_xcom_collections( + response_data, + { + "xcom_entries": [ + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-1", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-2", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + ], + "total_entries": 2, + }, + ) + + def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False): + with create_session() as session: + dag = DagModel(dag_id=dag_id) + session.add(dag) + dagrun = DagRun( + dag_id=dag_id, + run_id=run_id, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + if mapped_ti: + for i in [0, 1]: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) + ti.dag_id = dag_id + session.add(ti) + else: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) + ti.dag_id = dag_id + session.add(ti) + + for i in [1, 2]: + if mapped_ti: + key = "test-xcom-key" + map_index = i - 1 + else: + key = f"test-xcom-key-{i}" + map_index = -1 + + XCom.set( + key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index + ) + + def _create_invalid_xcom_entries(self, execution_date): + """ + Invalid XCom entries to test join query + """ + with create_session() as session: + dag = DagModel(dag_id="invalid_dag") + session.add(dag) + dagrun = DagRun( + dag_id="invalid_dag", + run_id="invalid_run_id", + execution_date=execution_date + timedelta(days=1), + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + dagrun1 = DagRun( + dag_id="invalid_dag", + run_id="not_this_run_id", + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun1) + ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id") + ti.dag_id = "invalid_dag" + session.add(ti) + for i in [1, 2]: + XCom.set( + key=f"invalid-xcom-key-{i}", + value="TEST", + run_id="not_this_run_id", + task_id="invalid_task", + dag_id="invalid_dag", + ) diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py index 22c29dd229fa1..a8fbe5fbdaaae 100644 --- a/tests/providers/fab/auth_manager/conftest.py +++ b/tests/providers/fab/auth_manager/conftest.py @@ -30,7 +30,10 @@ def minimal_app_for_auth_api(): "init_appbuilder", "init_api_auth", "init_api_auth_provider", + "init_api_connexion", "init_api_error_handlers", + "init_airflow_session_interface", + "init_appbuilder_views", ] ) def factory(): @@ -39,7 +42,11 @@ def factory(): ( "api", "auth_backends", - ): "tests.test_utils.remote_user_api_auth_backend,airflow.api.auth.backend.session" + ): "tests.providers.fab.auth_manager.api_endpoints.remote_user_api_auth_backend,airflow.api.auth.backend.session", + ( + "core", + "auth_manager", + ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", } ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore @@ -58,3 +65,11 @@ def set_auth_role_public(request): yield app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + + +@pytest.fixture(scope="module") +def dagbag(): + from airflow.models import DagBag + + DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index 156b5cf626271..bebb52c256fc8 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -49,7 +49,7 @@ from airflow.www.auth import get_access_denied_message from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.utils import CustomSQLAInterface -from tests.test_utils.api_connexion_utils import ( +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, create_user_scope, delete_role, diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py index 0b1073df287fa..f24d9b738343b 100644 --- a/tests/providers/fab/auth_manager/views/test_permissions.py +++ b/tests/providers/fab/auth_manager/views/test_permissions.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py index 156f07df41209..8de63ad5ba88a 100644 --- a/tests/providers/fab/auth_manager/views/test_roles_list.py +++ b/tests/providers/fab/auth_manager/views/test_roles_list.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py index 6660ab926d886..62b03a99e7c2c 100644 --- a/tests/providers/fab/auth_manager/views/test_user.py +++ b/tests/providers/fab/auth_manager/views/test_user.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py index 65937b6f83d33..8099f67948183 100644 --- a/tests/providers/fab/auth_manager/views/test_user_edit.py +++ b/tests/providers/fab/auth_manager/views/test_user_edit.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py b/tests/providers/fab/auth_manager/views/test_user_stats.py index 8cb260fcf1ec4..ae09cf92252c6 100644 --- a/tests/providers/fab/auth_manager/views/test_user_stats.py +++ b/tests/providers/fab/auth_manager/views/test_user_stats.py @@ -21,7 +21,7 @@ from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS from tests.test_utils.www import client_with_login diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py index af746b2d55468..48869ee48078d 100644 --- a/tests/test_utils/api_connexion_utils.py +++ b/tests/test_utils/api_connexion_utils.py @@ -17,6 +17,7 @@ from __future__ import annotations from contextlib import contextmanager +from typing import TYPE_CHECKING from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from tests.test_utils.compat import ignore_provider_compatibility_error @@ -24,6 +25,9 @@ with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES +if TYPE_CHECKING: + from flask import Flask + @contextmanager def create_test_client(app, user_name, role_name, permissions): @@ -44,7 +48,11 @@ def create_user_scope(app, username, **kwargs): It will create a user and provide it for the fixture via YIELD (generator) then will tidy up once test is complete """ - test_user = create_user(app, username, **kwargs) + from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user as create_user_fab, + ) + + test_user = create_user_fab(app, username, **kwargs) try: yield test_user @@ -52,27 +60,20 @@ def create_user_scope(app, username, **kwargs): delete_user(app, username) -def create_user(app, username, role_name=None, email=None, permissions=None): - appbuilder = app.appbuilder - +def create_user(app: Flask, username: str, role_name: str | None): # Removes user and role so each test has isolated test data. delete_user(app, username) - role = None - if role_name: - delete_role(app, role_name) - role = create_role(app, role_name, permissions) - else: - role = [] - - return appbuilder.sm.add_user( - username=username, - first_name=username, - last_name=username, - email=email or f"{username}@example.org", - role=role, - password=username, + + users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", []) + users.append( + { + "username": username, + "role": role_name, + } ) + app.config["SIMPLE_AUTH_MANAGER_USERS"] = users + def create_role(app, name, permissions=None): appbuilder = app.appbuilder @@ -87,14 +88,6 @@ def create_role(app, name, permissions=None): return role -def set_user_single_role(app, user, role_name): - role = create_role(app, role_name) - if role not in user.roles: - user.roles = [role] - app.appbuilder.sm.update_user(user) - user._perms = None - - def delete_role(app, name): if name not in EXISTING_ROLES: if app.appbuilder.sm.find_role(name): @@ -106,20 +99,11 @@ def delete_roles(app): delete_role(app, role.name) -def delete_user(app, username): - appbuilder = app.appbuilder - for user in appbuilder.sm.get_all_users(): - if user.username == username: - _ = [ - delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES - ] - appbuilder.sm.del_register_user(user) - break - - -def delete_users(app): - for user in app.appbuilder.sm.get_all_users(): - delete_user(app, user.username) +def delete_user(app: Flask, username): + users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", []) + users = [user for user in users if user["username"] != username] + + app.config["SIMPLE_AUTH_MANAGER_USERS"] = users def assert_401(response): diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py index b7714e5192e6a..59df201e530e4 100644 --- a/tests/test_utils/remote_user_api_auth_backend.py +++ b/tests/test_utils/remote_user_api_auth_backend.py @@ -15,17 +15,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Default authentication backend - everything is allowed""" - from __future__ import annotations import logging from functools import wraps from typing import TYPE_CHECKING, Callable, TypeVar, cast -from flask import Response, request -from flask_login import login_user +from flask import Response, request, session +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.utils.airflow_flask_app import get_airflow_app if TYPE_CHECKING: @@ -36,25 +34,15 @@ CLIENT_AUTH: tuple[str, str] | AuthBase | None = None -def init_app(_): - """Initializes authentication backend""" +def init_app(_): ... T = TypeVar("T", bound=Callable) -def _lookup_user(user_email_or_username: str): - security_manager = get_airflow_app().appbuilder.sm - user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( - username=user_email_or_username - ) - if not user: - return None - - if not user.is_active: - return None - - return user +def _lookup_user(username: str): + users = get_airflow_app().config.get("SIMPLE_AUTH_MANAGER_USERS", []) + return next((user for user in users if user["username"] == username), None) def requires_authentication(function: T): @@ -69,13 +57,13 @@ def decorated(*args, **kwargs): log.debug("Looking for user: %s", user_id) - user = _lookup_user(user_id) - if not user: + user_dict = _lookup_user(user_id) + if not user_dict: return Response("Forbidden", 403) - log.debug("Found user: %s", user) + log.debug("Found user: %s", user_dict) + session["user"] = SimpleAuthManagerUser(username=user_dict["username"], role=user_dict["role"]) - login_user(user, remember=False) return function(*args, **kwargs) return cast(T, decorated) diff --git a/tests/www/views/test_views_custom_user_views.py b/tests/www/views/test_views_custom_user_views.py index ae6d0132827c2..84947a8e5f36f 100644 --- a/tests/www/views/test_views_custom_user_views.py +++ b/tests/www/views/test_views_custom_user_views.py @@ -27,7 +27,10 @@ from airflow import settings from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_role +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user as create_user, + delete_role, +) from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login pytestmark = pytest.mark.db_test diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py index 39c17d086f379..d95955246ac78 100644 --- a/tests/www/views/test_views_dagrun.py +++ b/tests/www/views/test_views_dagrun.py @@ -24,7 +24,11 @@ from airflow.utils import timezone from airflow.utils.session import create_session from airflow.www.views import DagRunModelView -from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login from tests.www.views.test_views_tasks import _get_appbuilder_pk_string diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py index 5393115041392..ddec0c0bcfed3 100644 --- a/tests/www/views/test_views_home.py +++ b/tests/www/views/test_views_home.py @@ -27,7 +27,7 @@ from airflow.utils.state import State from airflow.www.utils import UIAlert from airflow.www.views import FILTER_LASTRUN_COOKIE, FILTER_STATUS_COOKIE, FILTER_TAGS_COOKIE -from tests.test_utils.api_connexion_utils import create_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user from tests.test_utils.db import clear_db_dags, clear_db_import_errors, clear_db_serialized_dags from tests.test_utils.permissions import _resource_name from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py index f5cc011fb6f0e..7b65051724c27 100644 --- a/tests/www/views/test_views_tasks.py +++ b/tests/www/views/test_views_tasks.py @@ -44,7 +44,11 @@ from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunType from airflow.www.views import TaskInstanceModelView, _safe_parse_datetime -from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS from tests.test_utils.config import conf_vars from tests.test_utils.db import clear_db_runs, clear_db_xcom diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py index a91a12ddc470b..b7fa8b37c52c8 100644 --- a/tests/www/views/test_views_variable.py +++ b/tests/www/views/test_views_variable.py @@ -25,7 +25,7 @@ from airflow.models import Variable from airflow.security import permissions from airflow.utils.session import create_session -from tests.test_utils.api_connexion_utils import create_user +from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user from tests.test_utils.www import ( _check_last_log, check_content_in_response,