diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py
index 451068733667c..4a9639a998c46 100644
--- a/airflow/auth/managers/simple/simple_auth_manager.py
+++ b/airflow/auth/managers/simple/simple_auth_manager.py
@@ -221,7 +221,12 @@ def _is_authorized(
user = self.get_user()
if not user:
return False
- role_str = user.get_role().upper()
+
+ user_role = user.get_role()
+ if not user_role:
+ return False
+
+ role_str = user_role.upper()
role = SimpleAuthManagerRole[role_str]
if role == SimpleAuthManagerRole.ADMIN:
return True
diff --git a/airflow/auth/managers/simple/user.py b/airflow/auth/managers/simple/user.py
index fa032f596ee44..f4591b0b1c751 100644
--- a/airflow/auth/managers/simple/user.py
+++ b/airflow/auth/managers/simple/user.py
@@ -24,10 +24,10 @@ class SimpleAuthManagerUser(BaseUser):
User model for users managed by the simple auth manager.
:param username: The username
- :param role: The role associated to the user
+ :param role: The role associated to the user. If not provided, the user has no permission
"""
- def __init__(self, *, username: str, role: str) -> None:
+ def __init__(self, *, username: str, role: str | None) -> None:
self.username = username
self.role = role
@@ -37,5 +37,5 @@ def get_id(self) -> str:
def get_name(self) -> str:
return self.username
- def get_role(self):
+ def get_role(self) -> str | None:
return self.role
diff --git a/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py
new file mode 100644
index 0000000000000..321a1e2bbafa8
--- /dev/null
+++ b/airflow/migrations/versions/0034_3_0_0_update_user_id_type.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Update dag_run_note.user_id and task_instance_note.user_id columns to String.
+
+Revision ID: 44eabb1904b4
+Revises: 16cbcb1c8c36
+Create Date: 2024-09-27 09:57:29.830521
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "44eabb1904b4"
+down_revision = "16cbcb1c8c36"
+branch_labels = None
+depends_on = None
+airflow_version = "3.0.0"
+
+
+def upgrade():
+ with op.batch_alter_table("dag_run_note") as batch_op:
+ batch_op.alter_column("user_id", type_=sa.String(length=128))
+ with op.batch_alter_table("task_instance_note") as batch_op:
+ batch_op.alter_column("user_id", type_=sa.String(length=128))
+
+
+def downgrade():
+ with op.batch_alter_table("dag_run_note") as batch_op:
+ batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer")
+ with op.batch_alter_table("task_instance_note") as batch_op:
+ batch_op.alter_column("user_id", type_=sa.Integer(), postgresql_using="user_id::integer")
diff --git a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py
similarity index 98%
rename from airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py
rename to airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py
index 5c8aec69e9be9..6016dd9658908 100644
--- a/airflow/migrations/versions/0034_3_0_0_add_name_field_to_dataset_model.py
+++ b/airflow/migrations/versions/0035_3_0_0_add_name_field_to_dataset_model.py
@@ -30,7 +30,7 @@
also rename the one on DatasetAliasModel here for consistency.
Revision ID: 0d9e73a75ee4
-Revises: 16cbcb1c8c36
+Revises: 44eabb1904b4
Create Date: 2024-08-13 09:45:32.213222
"""
@@ -42,7 +42,7 @@
# revision identifiers, used by Alembic.
revision = "0d9e73a75ee4"
-down_revision = "16cbcb1c8c36"
+down_revision = "44eabb1904b4"
branch_labels = None
depends_on = None
airflow_version = "3.0.0"
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 5d53e51763dff..4928c7fcbd8f7 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -1687,7 +1687,7 @@ class DagRunNote(Base):
__tablename__ = "dag_run_note"
- user_id = Column(Integer, nullable=True)
+ user_id = Column(String(128), nullable=True)
dag_run_id = Column(Integer, primary_key=True, nullable=False)
content = Column(String(1000).with_variant(Text(1000), "mysql"))
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b19e65486307d..333a4cad91cbe 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -4002,7 +4002,7 @@ class TaskInstanceNote(TaskInstanceDependencies):
__tablename__ = "task_instance_note"
- user_id = Column(Integer, nullable=True)
+ user_id = Column(String(128), nullable=True)
task_id = Column(StringID(), primary_key=True, nullable=False)
dag_id = Column(StringID(), primary_key=True, nullable=False)
run_id = Column(StringID(), primary_key=True, nullable=False)
diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py
index 3a0328fe9962c..7b50338733453 100644
--- a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py
+++ b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py
@@ -62,9 +62,7 @@ def requires_authentication(function: T):
@wraps(function)
def decorated(*args, **kwargs):
- if auth_current_user() is not None or current_app.appbuilder.get_app.config.get(
- "AUTH_ROLE_PUBLIC", None
- ):
+ if auth_current_user() is not None or current_app.config.get("AUTH_ROLE_PUBLIC", None):
return function(*args, **kwargs)
else:
return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"})
diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py
index d8d5a95ee676b..f2038b27597c1 100644
--- a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py
+++ b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py
@@ -124,7 +124,7 @@ def requires_authentication(function: T, find_user: Callable[[str], BaseUser] |
@wraps(function)
def decorated(*args, **kwargs):
- if current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None):
+ if current_app.config.get("AUTH_ROLE_PUBLIC", None):
response = function(*args, **kwargs)
return make_response(response)
diff --git a/airflow/providers/fab/auth_manager/models/anonymous_user.py b/airflow/providers/fab/auth_manager/models/anonymous_user.py
index 2f294fd9e5d0e..9afb2cdff635f 100644
--- a/airflow/providers/fab/auth_manager/models/anonymous_user.py
+++ b/airflow/providers/fab/auth_manager/models/anonymous_user.py
@@ -35,7 +35,7 @@ class AnonymousUser(AnonymousUserMixin, BaseUser):
@property
def roles(self):
if not self._roles:
- public_role = current_app.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None)
+ public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None)
self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set()
return self._roles
diff --git a/docs/apache-airflow/img/airflow_erd.sha256 b/docs/apache-airflow/img/airflow_erd.sha256
index e4a952da1b9fd..bca068fde6749 100644
--- a/docs/apache-airflow/img/airflow_erd.sha256
+++ b/docs/apache-airflow/img/airflow_erd.sha256
@@ -1 +1 @@
-c33e9a583a5b29eb748ebd50e117643e11bcb2a9b61ec017efd690621e22769b
\ No newline at end of file
+64dfad12dfd49f033c4723c2f3bb3bac58dd956136fb24a87a2e5a6ae176ec1a
\ No newline at end of file
diff --git a/docs/apache-airflow/img/airflow_erd.svg b/docs/apache-airflow/img/airflow_erd.svg
index 76fbd8f841f25..4eb6c2ee70917 100644
--- a/docs/apache-airflow/img/airflow_erd.svg
+++ b/docs/apache-airflow/img/airflow_erd.svg
@@ -1394,7 +1394,7 @@
user_id
- [INTEGER]
+ [VARCHAR(100)]
@@ -1813,7 +1813,7 @@
user_id
- [INTEGER]
+ [VARCHAR(100)]
diff --git a/docs/apache-airflow/migrations-ref.rst b/docs/apache-airflow/migrations-ref.rst
index a547d03d75be6..e4fb2dfa332eb 100644
--- a/docs/apache-airflow/migrations-ref.rst
+++ b/docs/apache-airflow/migrations-ref.rst
@@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
-| ``0d9e73a75ee4`` (head) | ``16cbcb1c8c36`` | ``3.0.0`` | Add name and group fields to DatasetModel. |
+| ``0d9e73a75ee4`` (head) | ``44eabb1904b4`` | ``3.0.0`` | Add name and group fields to DatasetModel. |
++-------------------------+------------------+-------------------+--------------------------------------------------------------+
+| ``44eabb1904b4`` | ``16cbcb1c8c36`` | ``3.0.0`` | Update dag_run_note.user_id and task_instance_note.user_id |
+| | | | columns to String. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``16cbcb1c8c36`` | ``522625f6d606`` | ``3.0.0`` | Remove redundant index. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py
index 38e7b58cb5981..6a23b2cf11d93 100644
--- a/tests/api_connexion/conftest.py
+++ b/tests/api_connexion/conftest.py
@@ -36,9 +36,16 @@ def minimal_app_for_api():
]
)
def factory():
- with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}):
+ with conf_vars(
+ {
+ ("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend",
+ (
+ "core",
+ "auth_manager",
+ ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager",
+ }
+ ):
_app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore
- _app.config["AUTH_ROLE_PUBLIC"] = None
return _app
return factory()
diff --git a/tests/api_connexion/endpoints/test_backfill_endpoint.py b/tests/api_connexion/endpoints/test_backfill_endpoint.py
index 51a4faf40055c..07b2a3fd56c2d 100644
--- a/tests/api_connexion/endpoints/test_backfill_endpoint.py
+++ b/tests/api_connexion/endpoints/test_backfill_endpoint.py
@@ -29,7 +29,6 @@
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import create_user, delete_user
@@ -50,25 +49,11 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG),
- ],
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
- create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "TEST_DAG_1",
- access_control={
- "TestGranularDag": {
- permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
- },
- },
+ role_name="admin",
)
+ create_user(app, username="test_no_permissions", role_name=None)
with DAG(
DAG_ID,
@@ -93,9 +78,8 @@ def configured_app(minimal_app_for_api):
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_user(app, username="test_granular_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestBackfillEndpoint:
@@ -178,7 +162,6 @@ def test_should_respond_200(self, session):
@pytest.mark.parametrize(
"user, expected",
[
- ("test_granular_permissions", 200),
("test_no_permissions", 403),
("test", 200),
(None, 401),
@@ -240,7 +223,6 @@ def test_no_exist(self, session):
@pytest.mark.parametrize(
"user, expected",
[
- ("test_granular_permissions", 200),
("test_no_permissions", 403),
("test", 200),
(None, 401),
@@ -268,7 +250,6 @@ class TestCreateBackfill(TestBackfillEndpoint):
@pytest.mark.parametrize(
"user, expected",
[
- ("test_granular_permissions", 200),
("test_no_permissions", 403),
("test", 200),
(None, 401),
@@ -347,7 +328,6 @@ def test_should_respond_200(self, session):
@pytest.mark.parametrize(
"user, expected",
[
- ("test_granular_permissions", 200),
("test_no_permissions", 403),
("test", 200),
(None, 401),
@@ -409,7 +389,6 @@ def test_should_respond_200(self, session):
@pytest.mark.parametrize(
"user, expected",
[
- ("test_granular_permissions", 200),
("test_no_permissions", 403),
("test", 200),
(None, 401),
diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py
index 475753a4a902e..bd88c491c952b 100644
--- a/tests/api_connexion/endpoints/test_config_endpoint.py
+++ b/tests/api_connexion/endpoints/test_config_endpoint.py
@@ -21,7 +21,6 @@
import pytest
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
@@ -54,18 +53,17 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
with conf_vars({("webserver", "expose_config"): "True"}):
yield minimal_app_for_api
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestGetConfig:
diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py
index a19b046aa2747..a140046656e31 100644
--- a/tests/api_connexion/endpoints/test_connection_endpoint.py
+++ b/tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -24,7 +24,6 @@
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
-from airflow.security import permissions
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
@@ -38,22 +37,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestConnectionEndpoint:
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 9905b4e27ab2c..6d4ffc2d06d2c 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -28,7 +28,6 @@
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
@@ -56,33 +55,11 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG),
- ],
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
- create_user(app, username="test_granular_permissions", role_name="TestGranularDag") # type: ignore
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "TEST_DAG_1",
- access_control={
- "TestGranularDag": {
- permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
- },
- },
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "TEST_DAG_1",
- access_control={
- "TestGranularDag": {
- permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
- },
- },
+ role_name="admin",
)
+ create_user(app, username="test_no_permissions", role_name=None)
with DAG(
DAG_ID,
@@ -107,9 +84,8 @@ def configured_app(minimal_app_for_api):
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_user(app, username="test_granular_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestDagEndpoint:
@@ -258,13 +234,6 @@ def test_should_respond_200_with_schedule_none(self, session):
"pickle_id": None,
} == response.json
- def test_should_respond_200_with_granular_dag_access(self):
- self._create_dag_models(1)
- response = self.client.get(
- "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 200
-
def test_should_respond_404(self):
response = self.client.get("/api/v1/dags/INVALID_DAG", environ_overrides={"REMOTE_USER": "test"})
assert response.status_code == 404
@@ -282,13 +251,6 @@ def test_should_raise_403_forbidden(self):
)
assert response.status_code == 403
- def test_should_respond_403_with_granular_access_for_different_dag(self):
- self._create_dag_models(3)
- response = self.client.get(
- "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 403
-
@pytest.mark.parametrize(
"fields",
[
@@ -961,15 +923,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids):
assert expected_dag_ids == dag_ids
- def test_should_respond_200_with_granular_dag_access(self):
- self._create_dag_models(3)
- response = self.client.get(
- "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 200
- assert len(response.json["dags"]) == 1
- assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"
-
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
@@ -1252,18 +1205,6 @@ def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, sessio
session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None, expected_extra=payload
)
- def test_should_respond_200_on_patch_with_granular_dag_access(self, session):
- self._create_dag_models(1)
- response = self.client.patch(
- "/api/v1/dags/TEST_DAG_1",
- json={
- "is_paused": False,
- },
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 200
- _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None)
-
def test_should_respond_400_on_invalid_request(self):
patch_body = {
"is_paused": True,
@@ -1279,24 +1220,6 @@ def test_should_respond_400_on_invalid_request(self):
"type": EXCEPTIONS_LINK_MAP[400],
}
- def test_validation_error_raises_400(self):
- patch_body = {
- "ispaused": True,
- }
- dag_model = self._create_dag_model()
- response = self.client.patch(
- f"/api/v1/dags/{dag_model.dag_id}",
- json=patch_body,
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 400
- assert response.json == {
- "detail": "{'ispaused': ['Unknown field.']}",
- "status": 400,
- "title": "Bad Request",
- "type": EXCEPTIONS_LINK_MAP[400],
- }
-
def test_non_existing_dag_raises_not_found(self):
patch_body = {
"is_paused": True,
@@ -1820,19 +1743,6 @@ def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids):
assert expected_dag_ids == dag_ids
- def test_should_respond_200_with_granular_dag_access(self):
- self._create_dag_models(3)
- response = self.client.patch(
- "api/v1/dags?dag_id_pattern=~",
- json={
- "is_paused": False,
- },
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 200
- assert len(response.json["dags"]) == 1
- assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"
-
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
diff --git a/tests/api_connexion/endpoints/test_dag_parsing.py b/tests/api_connexion/endpoints/test_dag_parsing.py
index 521d8d9e8cd99..ae42a565dd052 100644
--- a/tests/api_connexion/endpoints/test_dag_parsing.py
+++ b/tests/api_connexion/endpoints/test_dag_parsing.py
@@ -24,7 +24,6 @@
from airflow.models import DagBag
from airflow.models.dagbag import DagPriorityParsingRequest
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import create_user, delete_user
from tests.test_utils.db import clear_db_dag_parsing_requests
@@ -45,21 +44,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)], # type: ignore
+ role_name="admin",
)
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- TEST_DAG_ID,
- access_control={"Test": [permissions.ACTION_CAN_EDIT]},
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestDagParsingRequest:
diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
index f3921da7b9c29..73c75b98a43b1 100644
--- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -30,12 +30,11 @@
from airflow.models.dagrun import DagRun
from airflow.models.param import Param
from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
-from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user
+from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
@@ -52,79 +51,16 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_no_dag_run_create_permission",
- role_name="TestNoDagRunCreatePermission",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_dag_view_only",
- role_name="TestViewDags",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_view_dags",
- role_name="TestViewDags",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
- ],
+ role_name="admin",
)
- create_user(
- app, # type: ignore
- username="test_granular_permissions",
- role_name="TestGranularDag",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)],
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "TEST_DAG_ID",
- access_control={
- "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ},
- "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}},
- },
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_dag_view_only") # type: ignore
- delete_user(app, username="test_view_dags") # type: ignore
- delete_user(app, username="test_granular_permissions") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_user(app, username="test_no_dag_run_create_permission") # type: ignore
- delete_roles(app)
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestDagRunEndpoint:
@@ -499,16 +435,6 @@ def test_should_return_all_with_tilde_as_dag_id_and_all_dag_permissions(self):
dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]]
assert dag_run_ids == expected_dag_run_ids
- def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self):
- self._create_test_dag_run(extra_dag=True)
- expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"]
- response = self.client.get(
- "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
- )
- assert response.status_code == 200
- dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]]
- assert dag_run_ids == expected_dag_run_ids
-
def test_should_raises_401_unauthenticated(self):
self._create_test_dag_run()
@@ -907,57 +833,6 @@ def test_order_by_raises_for_invalid_attr(self):
msg = "Ordering with 'dag_ru' is disallowed or the attribute does not exist on the model"
assert response.json["detail"] == msg
- def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self):
- self._create_test_dag_run(extra_dag=True)
- expected_response_json_1 = {
- "dag_id": "TEST_DAG_ID",
- "dag_run_id": "TEST_DAG_RUN_ID_1",
- "end_date": None,
- "state": "running",
- "execution_date": self.default_time,
- "logical_date": self.default_time,
- "external_trigger": True,
- "start_date": self.default_time,
- "conf": {},
- "data_interval_end": None,
- "data_interval_start": None,
- "last_scheduling_decision": None,
- "run_type": "manual",
- "note": None,
- }
- expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {})
- expected_response_json_2 = {
- "dag_id": "TEST_DAG_ID",
- "dag_run_id": "TEST_DAG_RUN_ID_2",
- "end_date": None,
- "state": "running",
- "execution_date": self.default_time_2,
- "logical_date": self.default_time_2,
- "external_trigger": True,
- "start_date": self.default_time,
- "conf": {},
- "data_interval_end": None,
- "data_interval_start": None,
- "last_scheduling_decision": None,
- "run_type": "manual",
- "note": None,
- }
- expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {})
-
- response = self.client.post(
- "api/v1/dags/~/dagRuns/list",
- json={"dag_ids": []},
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
- assert response.status_code == 200
- assert response.json == {
- "dag_runs": [
- expected_response_json_1,
- expected_response_json_2,
- ],
- "total_entries": 2,
- }
-
@pytest.mark.parametrize(
"payload, error",
[
@@ -1328,15 +1203,6 @@ def test_raises_validation_error_for_invalid_params(self):
assert response.status_code == 400
assert "Invalid input for param" in response.json["detail"]
- def test_dagrun_trigger_with_dag_level_permissions(self):
- self._create_dag("TEST_DAG_ID")
- response = self.client.post(
- "api/v1/dags/TEST_DAG_ID/dagRuns",
- json={"conf": {"validated_number": 1}},
- environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"},
- )
- assert response.status_code == 200
-
@mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app")
def test_dagrun_creation_exception_is_handled(self, mock_get_app, session):
self._create_dag("TEST_DAG_ID")
@@ -1627,11 +1493,7 @@ def test_should_raises_401_unauthenticated(self):
assert_401(response)
- @pytest.mark.parametrize(
- "username",
- ["test_dag_view_only", "test_view_dags", "test_granular_permissions", "test_no_permissions"],
- )
- def test_should_raises_403_unauthorized(self, username):
+ def test_should_raises_403_unauthorized(self):
self._create_dag("TEST_DAG_ID")
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns",
@@ -1639,7 +1501,7 @@ def test_should_raises_403_unauthorized(self, username):
"dag_run_id": "TEST_DAG_RUN_ID_1",
"execution_date": self.default_time,
},
- environ_overrides={"REMOTE_USER": username},
+ environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403
diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py
index a8d1224e034c3..f4df56ba629ae 100644
--- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py
@@ -23,7 +23,6 @@
import pytest
from airflow.models import DagBag
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags
@@ -44,29 +43,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], # type: ignore
+ role_name="admin",
)
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- TEST_DAG_ID,
- access_control={"Test": [permissions.ACTION_CAN_READ]},
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- EXAMPLE_DAG_ID,
- access_control={"Test": [permissions.ACTION_CAN_READ]},
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- TEST_MULTIPLE_DAGS_ID,
- access_control={"Test": [permissions.ACTION_CAN_READ]},
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestGetSource:
@@ -123,18 +109,6 @@ def test_should_respond_200_json(self, url_safe_serializer):
assert dag_docstring in response.json["content"]
assert "application/json" == response.headers["Content-Type"]
- def test_should_respond_406(self, url_safe_serializer):
- dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
- dagbag.sync_to_db()
- test_dag: DAG = dagbag.dags[TEST_DAG_ID]
-
- url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}"
- response = self.client.get(
- url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"}
- )
-
- assert 406 == response.status_code
-
def test_should_respond_404(self):
wrong_fileloc = "abcd1234"
url = f"/api/v1/dagSources/{wrong_fileloc}"
@@ -167,38 +141,3 @@ def test_should_raise_403_forbidden(self, url_safe_serializer):
environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403
-
- def test_should_respond_403_not_readable(self, url_safe_serializer):
- dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
- dagbag.sync_to_db()
- dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID]
-
- response = self.client.get(
- f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}",
- headers={"Accept": "text/plain"},
- environ_overrides={"REMOTE_USER": "test"},
- )
- read_dag = self.client.get(
- f"/api/v1/dags/{NOT_READABLE_DAG_ID}",
- environ_overrides={"REMOTE_USER": "test"},
- )
- assert response.status_code == 403
- assert read_dag.status_code == 403
-
- def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer):
- dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
- dagbag.sync_to_db()
- dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID]
-
- response = self.client.get(
- f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}",
- headers={"Accept": "text/plain"},
- environ_overrides={"REMOTE_USER": "test"},
- )
-
- read_dag = self.client.get(
- f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}",
- environ_overrides={"REMOTE_USER": "test"},
- )
- assert response.status_code == 403
- assert read_dag.status_code == 200
diff --git a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py
index 36fc54d3a5b17..9ab5b49765931 100644
--- a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py
@@ -22,7 +22,6 @@
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState
@@ -38,21 +37,17 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestDagStatsEndpoint:
diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
index 3e7c805173b39..f156d8921c0e6 100644
--- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py
@@ -22,7 +22,6 @@
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning
-from airflow.security import permissions
from airflow.utils.session import create_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags
@@ -34,30 +33,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- ], # type: ignore
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
- create_user(
- app, # type:ignore
- username="test_with_dag2_read",
- role_name="TestWithDag2Read",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
- (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"),
- ], # type: ignore
+ role_name="admin",
)
+ create_user(app, username="test_no_permissions", role_name=None)
yield minimal_app_for_api
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_user(app, username="test_with_dag2_read") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestBaseDagWarning:
@@ -162,11 +147,3 @@ def test_should_raise_403_forbidden(self):
"/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"}
)
assert response.status_code == 403
-
- def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self):
- response = self.client.get(
- "/api/v1/dagWarnings",
- environ_overrides={"REMOTE_USER": "test_with_dag2_read"},
- query_string={"dag_id": "dag1"},
- )
- assert response.status_code == 403
diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py
index 5caec0ac2a131..76c164654c9d8 100644
--- a/tests/api_connexion/endpoints/test_dataset_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py
@@ -33,7 +33,6 @@
TaskOutletAssetReference,
)
from airflow.models.dagrun import DagRun
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.types import DagRunType
@@ -50,31 +49,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ASSET),
- ],
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
- create_user(
- app, # type: ignore
- username="test_queued_event",
- role_name="TestQueuedEvent",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET),
- ],
+ role_name="admin",
)
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test_queued_event") # type: ignore
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestDatasetEndpoint:
@@ -768,43 +752,6 @@ def _create_dataset_dag_run_queues(self, dag_id, dataset_id, session):
class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint):
- @pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- dataset_id = self._create_dataset(session).id
- self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
- dataset_uri = "s3://bucket/key"
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 200
- assert response.json == {
- "created_at": self.default_time,
- "uri": "s3://bucket/key",
- "dag_id": "dag",
- }
-
- def test_should_respond_404(self):
- dag_id = "not_exists"
- dataset_uri = "not_exists"
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert {
- "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- } == response.json
-
def test_should_raises_401_unauthenticated(self, session):
dag_id = "dummy"
dataset_uri = "dummy"
@@ -826,47 +773,6 @@ def test_should_raise_403_forbidden(self, session):
class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint):
- def test_delete_should_respond_204(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- dataset_uri = "s3://bucket/key"
- dataset_id = self._create_dataset(session).id
-
- adrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id)
- session.add(adrq)
- session.commit()
- conn = session.query(AssetDagRunQueue).all()
- assert len(conn) == 1
-
- response = self.client.delete(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 204
- conn = session.query(AssetDagRunQueue).all()
- assert len(conn) == 0
- _check_last_log(
- session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None
- )
-
- def test_should_respond_404(self):
- dag_id = "not_exists"
- dataset_uri = "not_exists"
-
- response = self.client.delete(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert {
- "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- } == response.json
-
def test_should_raises_401_unauthenticated(self, session):
dag_id = "dummy"
dataset_uri = "dummy"
@@ -884,46 +790,6 @@ def test_should_raise_403_forbidden(self, session):
class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint):
- @pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- dataset_id = self._create_dataset(session).id
- self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 200
- assert response.json == {
- "queued_events": [
- {
- "created_at": self.default_time,
- "uri": "s3://bucket/key",
- "dag_id": "dag",
- }
- ],
- "total_entries": 1,
- }
-
- def test_should_respond_404(self):
- dag_id = "not_exists"
-
- response = self.client.get(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert {
- "detail": "Queue event with dag_id: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- } == response.json
-
def test_should_raises_401_unauthenticated(self):
dag_id = "dummy"
@@ -943,22 +809,6 @@ def test_should_raise_403_forbidden(self):
class TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint):
- def test_should_respond_404(self):
- dag_id = "not_exists"
-
- response = self.client.delete(
- f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert {
- "detail": "Queue event with dag_id: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- } == response.json
-
def test_should_raises_401_unauthenticated(self):
dag_id = "dummy"
@@ -978,47 +828,6 @@ def test_should_raise_403_forbidden(self):
class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint):
- @pytest.mark.usefixtures("time_freezer")
- def test_should_respond_200(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- dataset_id = self._create_dataset(session).id
- self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
- dataset_uri = "s3://bucket/key"
-
- response = self.client.get(
- f"/api/v1/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 200
- assert response.json == {
- "queued_events": [
- {
- "created_at": self.default_time,
- "uri": "s3://bucket/key",
- "dag_id": "dag",
- }
- ],
- "total_entries": 1,
- }
-
- def test_should_respond_404(self):
- dataset_uri = "not_exists"
-
- response = self.client.get(
- f"/api/v1/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert {
- "detail": "Queue event with asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- } == response.json
-
def test_should_raises_401_unauthenticated(self):
dataset_uri = "not_exists"
@@ -1038,39 +847,6 @@ def test_should_raise_403_forbidden(self):
class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint):
- def test_delete_should_respond_204(self, session, create_dummy_dag):
- dag, _ = create_dummy_dag()
- dag_id = dag.dag_id
- dataset_id = self._create_dataset(session).id
- self._create_dataset_dag_run_queues(dag_id, dataset_id, session)
- dataset_uri = "s3://bucket/key"
-
- response = self.client.delete(
- f"/api/v1/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 204
- conn = session.query(AssetDagRunQueue).all()
- assert len(conn) == 0
- _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None)
-
- def test_should_respond_404(self):
- dataset_uri = "not_exists"
-
- response = self.client.delete(
- f"/api/v1/datasets/queuedEvent/{dataset_uri}",
- environ_overrides={"REMOTE_USER": "test_queued_event"},
- )
-
- assert response.status_code == 404
- assert {
- "detail": "Queue event with asset uri: `not_exists` was not found",
- "status": 404,
- "title": "Queue event not found",
- "type": EXCEPTIONS_LINK_MAP[404],
- } == response.json
-
def test_should_raises_401_unauthenticated(self):
dataset_uri = "not_exists"
diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py
index 0fdef1a3af2b6..e5ca3d301765a 100644
--- a/tests/api_connexion/endpoints/test_event_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py
@@ -20,7 +20,6 @@
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import Log
-from airflow.security import permissions
from airflow.utils import timezone
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
@@ -33,32 +32,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore
+ role_name="admin",
)
- create_user(
- app, # type:ignore
- username="test_granular",
- role_name="TestGranular",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], # type: ignore
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "TEST_DAG_ID_1",
- access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "TEST_DAG_ID_2",
- access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_granular") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
@pytest.fixture
@@ -274,33 +257,6 @@ def test_should_raises_401_unauthenticated(self, log_model):
assert_401(response)
- def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session):
- eventlog1 = create_log_model(
- event="TEST_EVENT_1",
- dag_id="TEST_DAG_ID_1",
- task_id="TEST_TASK_ID_1",
- owner="TEST_OWNER_1",
- when=self.default_time,
- )
- eventlog2 = create_log_model(
- event="TEST_EVENT_2",
- dag_id="TEST_DAG_ID_2",
- task_id="TEST_TASK_ID_2",
- owner="TEST_OWNER_2",
- when=self.default_time_2,
- )
- session.add_all([eventlog1, eventlog2])
- session.commit()
- for attr in ["dag_id", "task_id", "owner", "event"]:
- attr_value = f"TEST_{attr}_1".upper()
- response = self.client.get(
- f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"}
- )
- assert response.status_code == 200
- assert response.json["total_entries"] == 1
- assert len(response.json["event_logs"]) == 1
- assert response.json["event_logs"][0][attr] == attr_value
-
def test_should_filter_eventlogs_by_when(self, create_log_model, session):
eventlog1 = create_log_model(event="TEST_EVENT_1", when=self.default_time)
eventlog2 = create_log_model(event="TEST_EVENT_2", when=self.default_time_2)
@@ -339,32 +295,6 @@ def test_should_filter_eventlogs_by_run_id(self, create_log_model, session):
assert {eventlog["event"] for eventlog in response.json["event_logs"]} == expected_eventlogs
assert all({eventlog["run_id"] == run_id for eventlog in response.json["event_logs"]})
- def test_should_filter_eventlogs_by_included_events(self, create_log_model):
- for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
- create_log_model(event=event, when=self.default_time)
- response = self.client.get(
- "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2",
- environ_overrides={"REMOTE_USER": "test_granular"},
- )
- assert response.status_code == 200
- response_data = response.json
- assert len(response_data["event_logs"]) == 2
- assert response_data["total_entries"] == 2
- assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]}
-
- def test_should_filter_eventlogs_by_excluded_events(self, create_log_model):
- for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
- create_log_model(event=event, when=self.default_time)
- response = self.client.get(
- "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2",
- environ_overrides={"REMOTE_USER": "test_granular"},
- )
- assert response.status_code == 200
- response_data = response.json
- assert len(response_data["event_logs"]) == 1
- assert response_data["total_entries"] == 1
- assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]}
-
class TestGetEventLogPagination(TestEventLogEndpoint):
@pytest.mark.parametrize(
diff --git a/tests/api_connexion/endpoints/test_extra_link_endpoint.py b/tests/api_connexion/endpoints/test_extra_link_endpoint.py
index 1e9226ede9847..2c3eacdc91dc0 100644
--- a/tests/api_connexion/endpoints/test_extra_link_endpoint.py
+++ b/tests/api_connexion/endpoints/test_extra_link_endpoint.py
@@ -26,7 +26,6 @@
from airflow.models.dagbag import DagBag
from airflow.models.xcom import XCom
from airflow.plugins_manager import AirflowPlugin
-from airflow.security import permissions
from airflow.timetables.base import DataInterval
from airflow.utils import timezone
from airflow.utils.state import DagRunState
@@ -48,21 +47,16 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestGetExtraLinks:
@@ -78,8 +72,8 @@ def setup_attrs(self, configured_app, session) -> None:
self.dag = self._create_dag()
self.app.dag_bag = DagBag(os.devnull, include_examples=False)
- self.app.dag_bag.dags = {self.dag.dag_id: self.dag} # type: ignore
- self.app.dag_bag.sync_to_db() # type: ignore
+ self.app.dag_bag.dags = {self.dag.dag_id: self.dag}
+ self.app.dag_bag.sync_to_db()
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
self.dag.create_dagrun(
diff --git a/tests/api_connexion/endpoints/test_import_error_endpoint.py b/tests/api_connexion/endpoints/test_import_error_endpoint.py
index 635e159bb292c..af2b83ebb1eed 100644
--- a/tests/api_connexion/endpoints/test_import_error_endpoint.py
+++ b/tests/api_connexion/endpoints/test_import_error_endpoint.py
@@ -21,15 +21,12 @@
import pytest
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
-from airflow.models.dag import DagModel
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.compat import ParseImportError
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_import_errors
-from tests.test_utils.permissions import _resource_name
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
@@ -40,42 +37,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR),
- ], # type: ignore
- )
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
- create_user(
- app, # type:ignore
- username="test_single_dag",
- role_name="TestSingleDAG",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore
- )
- # For some reason, DAG level permissions are not synced when in the above list of perms,
- # so do it manually here:
- app.appbuilder.sm.bulk_sync_roles(
- [
- {
- "role": "TestSingleDAG",
- "perms": [
- (
- permissions.ACTION_CAN_READ,
- _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG),
- )
- ],
- }
- ]
+ role_name="admin",
)
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_user(app, username="test_single_dag") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestBaseImportError:
@@ -152,72 +123,6 @@ def test_should_raise_403_forbidden(self):
)
assert response.status_code == 403
- def test_should_raise_403_forbidden_without_dag_read(self, session):
- import_error = ParseImportError(
- filename="Lorem_ipsum.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(import_error)
- session.commit()
-
- response = self.client.get(
- f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 403
-
- def test_should_return_200_with_single_dag_read(self, session):
- dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py")
- session.add(dag_model)
- import_error = ParseImportError(
- filename="Lorem_ipsum.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(import_error)
- session.commit()
-
- response = self.client.get(
- f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- response_data["import_error_id"] = 1
- assert {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- } == response_data
-
- def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
- for dag_id in TEST_DAG_IDS:
- dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py")
- session.add(dag_model)
- import_error = ParseImportError(
- filename="Lorem_ipsum.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(import_error)
- session.commit()
-
- response = self.client.get(
- f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- response_data["import_error_id"] = 1
- assert {
- "filename": "Lorem_ipsum.py",
- "import_error_id": 1,
- "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
- "timestamp": "2020-06-10T12:00:00+00:00",
- } == response_data
-
class TestGetImportErrorsEndpoint(TestBaseImportError):
def test_get_import_errors(self, session):
@@ -328,71 +233,6 @@ def test_should_raises_401_unauthenticated(self, session):
assert_401(response)
- def test_get_import_errors_single_dag(self, session):
- for dag_id in TEST_DAG_IDS:
- fake_filename = f"/tmp/{dag_id}.py"
- dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
- session.add(dag_model)
- importerror = ParseImportError(
- filename=fake_filename,
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(importerror)
- session.commit()
-
- response = self.client.get(
- "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- self._normalize_import_errors(response_data["import_errors"])
- assert {
- "import_errors": [
- {
- "filename": "/tmp/test_dag.py",
- "import_error_id": 1,
- "stack_trace": "Lorem ipsum",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- ],
- "total_entries": 1,
- } == response_data
-
- def test_get_import_errors_single_dag_in_dagfile(self, session):
- for dag_id in TEST_DAG_IDS:
- fake_filename = "/tmp/all_in_one.py"
- dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
- session.add(dag_model)
-
- importerror = ParseImportError(
- filename="/tmp/all_in_one.py",
- stacktrace="Lorem ipsum",
- timestamp=timezone.parse(self.timestamp, timezone="UTC"),
- )
- session.add(importerror)
- session.commit()
-
- response = self.client.get(
- "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
- )
-
- assert response.status_code == 200
- response_data = response.json
- self._normalize_import_errors(response_data["import_errors"])
- assert {
- "import_errors": [
- {
- "filename": "/tmp/all_in_one.py",
- "import_error_id": 1,
- "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
- "timestamp": "2020-06-10T12:00:00+00:00",
- },
- ],
- "total_entries": 1,
- } == response_data
-
class TestGetImportErrorsEndpointPagination(TestBaseImportError):
@pytest.mark.parametrize(
diff --git a/tests/api_connexion/endpoints/test_log_endpoint.py b/tests/api_connexion/endpoints/test_log_endpoint.py
index 420d2dd65f89c..2b112e3221843 100644
--- a/tests/api_connexion/endpoints/test_log_endpoint.py
+++ b/tests/api_connexion/endpoints/test_log_endpoint.py
@@ -30,7 +30,6 @@
from airflow.decorators import task
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
@@ -46,13 +45,9 @@ def configured_app(minimal_app_for_api):
create_user(
app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions")
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
index 72cdccdee68df..fc53b8952f4aa 100644
--- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
@@ -28,12 +28,11 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.taskmap import TaskMap
-from airflow.security import permissions
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import datetime
-from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user
+from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields
from tests.test_utils.mock_operators import MockOperator
@@ -50,24 +49,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_roles(app)
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestMappedTaskInstanceEndpoint:
@@ -133,8 +124,8 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None):
session.add(ti)
self.app.dag_bag = DagBag(os.devnull, include_examples=False)
- self.app.dag_bag.dags = {dag_id: dag_maker.dag} # type: ignore
- self.app.dag_bag.sync_to_db() # type: ignore
+ self.app.dag_bag.dags = {dag_id: dag_maker.dag}
+ self.app.dag_bag.sync_to_db()
session.flush()
mapped.expand_mapped_task(dr.run_id, session=session)
diff --git a/tests/api_connexion/endpoints/test_plugin_endpoint.py b/tests/api_connexion/endpoints/test_plugin_endpoint.py
index edf925cf0fa73..0cd630375a282 100644
--- a/tests/api_connexion/endpoints/test_plugin_endpoint.py
+++ b/tests/api_connexion/endpoints/test_plugin_endpoint.py
@@ -24,7 +24,6 @@
from airflow.hooks.base import BaseHook
from airflow.plugins_manager import AirflowPlugin
-from airflow.security import permissions
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.timetables.base import Timetable
from airflow.utils.module_loading import qualname
@@ -105,17 +104,16 @@ class MockPlugin(AirflowPlugin):
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestPluginsEndpoint:
diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py
index 87439a5811945..2cc095d077aa9 100644
--- a/tests/api_connexion/endpoints/test_pool_endpoint.py
+++ b/tests/api_connexion/endpoints/test_pool_endpoint.py
@@ -20,7 +20,6 @@
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models.pool import Pool
-from airflow.security import permissions
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
@@ -35,22 +34,16 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestBasePoolEndpoints:
diff --git a/tests/api_connexion/endpoints/test_provider_endpoint.py b/tests/api_connexion/endpoints/test_provider_endpoint.py
index 16e5989cc56db..b4cf8f10a92ae 100644
--- a/tests/api_connexion/endpoints/test_provider_endpoint.py
+++ b/tests/api_connexion/endpoints/test_provider_endpoint.py
@@ -21,7 +21,6 @@
import pytest
from airflow.providers_manager import ProviderInfo
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import create_user, delete_user
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
@@ -54,17 +53,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestBaseProviderEndpoint:
diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py
index d0a4fb903c8b8..b2e068bd507fe 100644
--- a/tests/api_connexion/endpoints/test_task_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_endpoint.py
@@ -27,7 +27,6 @@
from airflow.models.expandinput import EXPAND_INPUT_EMPTY
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
@@ -38,21 +37,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestTaskEndpoint:
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 25ded6c814b72..b5b3163e988d0 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -25,19 +25,17 @@
from sqlalchemy import select
from sqlalchemy.orm import contains_eager
-from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import DagRun, SlaMiss, TaskInstance, Trigger
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.taskinstancehistory import TaskInstanceHistory
-from airflow.security import permissions
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
-from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_roles, delete_user
+from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields
from tests.test_utils.www import _check_last_log
@@ -55,69 +53,16 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_dag_read_only",
- role_name="TestDagReadOnly",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_task_read_only",
- role_name="TestTaskReadOnly",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_read_only_one_dag",
- role_name="TestReadOnlyOneDag",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
- ],
- )
- # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms,
- # so do it manually here:
- app.appbuilder.sm.bulk_sync_roles(
- [
- {
- "role": "TestReadOnlyOneDag",
- "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")],
- }
- ]
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_dag_read_only") # type: ignore
- delete_user(app, username="test_task_read_only") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
- delete_user(app, username="test_read_only_one_dag") # type: ignore
- delete_roles(app)
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestTaskInstanceEndpoint:
@@ -219,9 +164,8 @@ def setup_method(self):
def teardown_method(self):
clear_db_runs()
- @pytest.mark.parametrize("username", ["test", "test_dag_read_only", "test_task_read_only"])
@provide_session
- def test_should_respond_200(self, username, session):
+ def test_should_respond_200(self, session):
self.create_task_instances(session)
# Update ti and set operator to None to
# test that operator field is nullable.
@@ -232,7 +176,7 @@ def test_should_respond_200(self, username, session):
session.commit()
response = self.client.get(
"/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
- environ_overrides={"REMOTE_USER": username},
+ environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
assert response.json == {
@@ -723,36 +667,11 @@ def test_should_respond_200(self, task_instances, update_extras, url, expected_t
assert response.json["total_entries"] == expected_ti
assert len(response.json["task_instances"]) == expected_ti
- @pytest.mark.parametrize(
- "task_instances, user, expected_ti",
- [
- pytest.param(
- {
- "example_python_operator": 2,
- "example_skip_dag": 1,
- },
- "test_read_only_one_dag",
- 2,
- ),
- pytest.param(
- {
- "example_python_operator": 1,
- "example_skip_dag": 2,
- },
- "test_read_only_one_dag",
- 1,
- ),
- pytest.param(
- {
- "example_python_operator": 1,
- "example_skip_dag": 2,
- },
- "test",
- 3,
- ),
- ],
- )
- def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session):
+ def test_return_TI_only_from_readable_dags(self, session):
+ task_instances = {
+ "example_python_operator": 1,
+ "example_skip_dag": 2,
+ }
for dag_id in task_instances:
self.create_task_instances(
session,
@@ -763,11 +682,11 @@ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_
dag_id=dag_id,
)
response = self.client.get(
- "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user}
+ "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": "test"}
)
assert response.status_code == 200
- assert response.json["total_entries"] == expected_ti
- assert len(response.json["task_instances"]) == expected_ti
+ assert response.json["total_entries"] == 3
+ assert len(response.json["task_instances"]) == 3
def test_should_respond_200_for_dag_id_filter(self, session):
self.create_task_instances(session)
@@ -898,44 +817,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
"test",
id="test executor filter",
),
- pytest.param(
- [
- {"pool": "test_pool_1"},
- {"pool": "test_pool_2"},
- {"pool": "test_pool_3"},
- ],
- True,
- {"pool": ["test_pool_1", "test_pool_2"]},
- 2,
- "test_dag_read_only",
- id="test pool filter",
- ),
- pytest.param(
- [
- {"state": State.RUNNING},
- {"state": State.QUEUED},
- {"state": State.SUCCESS},
- {"state": State.NONE},
- ],
- False,
- {"state": ["running", "queued", "none"]},
- 3,
- "test_task_read_only",
- id="test state filter",
- ),
- pytest.param(
- [
- {"state": State.NONE},
- {"state": State.NONE},
- {"state": State.NONE},
- {"state": State.NONE},
- ],
- False,
- {},
- 4,
- "test_task_read_only",
- id="test dag with null states",
- ),
pytest.param(
[
{"duration": 100},
@@ -948,36 +829,6 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
"test",
id="test duration filter",
),
- pytest.param(
- [
- {"end_date": DEFAULT_DATETIME_1},
- {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
- {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
- ],
- True,
- {
- "end_date_gte": DEFAULT_DATETIME_STR_1,
- "end_date_lte": DEFAULT_DATETIME_STR_2,
- },
- 2,
- "test_task_read_only",
- id="test end date filter",
- ),
- pytest.param(
- [
- {"start_date": DEFAULT_DATETIME_1},
- {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
- {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
- ],
- True,
- {
- "start_date_gte": DEFAULT_DATETIME_STR_1,
- "start_date_lte": DEFAULT_DATETIME_STR_2,
- },
- 2,
- "test_dag_read_only",
- id="test start date filter",
- ),
pytest.param(
[
{"execution_date": DEFAULT_DATETIME_1},
@@ -1162,24 +1013,6 @@ def test_should_raise_403_forbidden(self):
)
assert response.status_code == 403
- def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session):
- self.create_task_instances(session=session)
- self.create_task_instances(session=session, dag_id="example_skip_dag")
- payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]}
-
- response = self.client.post(
- "/api/v1/dags/~/dagRuns/~/taskInstances/list",
- environ_overrides={"REMOTE_USER": "test_read_only_one_dag"},
- json=payload,
- )
- assert response.status_code == 403
- assert response.json == {
- "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']",
- "status": 403,
- "title": "Forbidden",
- "type": EXCEPTIONS_LINK_MAP[403],
- }
-
def test_should_raise_400_for_no_json(self):
response = self.client.post(
"/api/v1/dags/~/dagRuns/~/taskInstances/list",
@@ -1794,11 +1627,10 @@ def test_should_raises_401_unauthenticated(self):
)
assert_401(response)
- @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"])
- def test_should_raise_403_forbidden(self, username: str):
+ def test_should_raise_403_forbidden(self):
response = self.client.post(
"/api/v1/dags/example_python_operator/clearTaskInstances",
- environ_overrides={"REMOTE_USER": username},
+ environ_overrides={"REMOTE_USER": "test_no_permissions"},
json={
"dry_run": False,
"reset_dag_runs": True,
@@ -2043,11 +1875,10 @@ def test_should_raises_401_unauthenticated(self):
)
assert_401(response)
- @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"])
- def test_should_raise_403_forbidden(self, username):
+ def test_should_raise_403_forbidden(self):
response = self.client.post(
"/api/v1/dags/example_python_operator/updateTaskInstancesState",
- environ_overrides={"REMOTE_USER": username},
+ environ_overrides={"REMOTE_USER": "test_no_permissions"},
json={
"dry_run": True,
"task_id": "print_the_context",
@@ -2386,11 +2217,10 @@ def test_should_raises_401_unauthenticated(self):
)
assert_401(response)
- @pytest.mark.parametrize("username", ["test_no_permissions", "test_dag_read_only", "test_task_read_only"])
- def test_should_raise_403_forbidden(self, username):
+ def test_should_raise_403_forbidden(self):
response = self.client.patch(
self.ENDPOINT_URL,
- environ_overrides={"REMOTE_USER": username},
+ environ_overrides={"REMOTE_USER": "test_no_permissions"},
json={
"dry_run": True,
"new_state": "failed",
@@ -2748,14 +2578,13 @@ def setup_method(self):
def teardown_method(self):
clear_db_runs()
- @pytest.mark.parametrize("username", ["test", "test_dag_read_only", "test_task_read_only"])
@provide_session
- def test_should_respond_200(self, username, session):
+ def test_should_respond_200(self, session):
self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True)
response = self.client.get(
"/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1",
- environ_overrides={"REMOTE_USER": username},
+ environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
assert response.json == {
diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py
index 81405df08b045..aa5f7c99674f8 100644
--- a/tests/api_connexion/endpoints/test_variable_endpoint.py
+++ b/tests/api_connexion/endpoints/test_variable_endpoint.py
@@ -22,7 +22,6 @@
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import Variable
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_variables
@@ -36,40 +35,16 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE),
- (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE),
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_read_only",
- role_name="TestReadOnly",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_delete_only",
- role_name="TestDeleteOnly",
- permissions=[
- (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE),
- ],
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_read_only") # type: ignore
- delete_user(app, username="test_delete_only") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestVariableEndpoint:
@@ -131,8 +106,6 @@ class TestGetVariable(TestVariableEndpoint):
"user, expected_status_code",
[
("test", 200),
- ("test_read_only", 200),
- ("test_delete_only", 403),
("test_no_permissions", 403),
],
)
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index 7a51714c5b299..809e537f9f88d 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -26,7 +26,6 @@
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import BaseXCom, XCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
-from airflow.security import permissions
from airflow.utils.dates import parse_execution_date
from airflow.utils.session import create_session
from airflow.utils.timezone import utcnow
@@ -52,32 +51,16 @@ def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type: ignore
+ app,
username="test",
- role_name="Test",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM),
- ],
- )
- create_user(
- app, # type: ignore
- username="test_granular_permissions",
- role_name="TestGranularDag",
- permissions=[
- (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM),
- ],
- )
- app.appbuilder.sm.sync_perm_for_dag( # type: ignore
- "test-dag-id-1",
- access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
+ role_name="admin",
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name=None)
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
def _compare_xcom_collections(collection1: dict, collection_2: dict):
@@ -435,53 +418,6 @@ def test_should_respond_200_with_tilde_and_access_to_all_dags(self):
},
)
- def test_should_respond_200_with_tilde_and_granular_dag_access(self):
- dag_id_1 = "test-dag-id-1"
- task_id_1 = "test-task-id-1"
- execution_date = "2005-04-02T00:00:00+00:00"
- execution_date_parsed = parse_execution_date(execution_date)
- dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
- self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1)
-
- dag_id_2 = "test-dag-id-2"
- task_id_2 = "test-task-id-2"
- run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
- self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2)
- self._create_invalid_xcom_entries(execution_date_parsed)
- response = self.client.get(
- "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
- environ_overrides={"REMOTE_USER": "test_granular_permissions"},
- )
-
- assert 200 == response.status_code
- response_data = response.json
- for xcom_entry in response_data["xcom_entries"]:
- xcom_entry["timestamp"] = "TIMESTAMP"
- _compare_xcom_collections(
- response_data,
- {
- "xcom_entries": [
- {
- "dag_id": dag_id_1,
- "execution_date": execution_date,
- "key": "test-xcom-key-1",
- "task_id": task_id_1,
- "timestamp": "TIMESTAMP",
- "map_index": -1,
- },
- {
- "dag_id": dag_id_1,
- "execution_date": execution_date,
- "key": "test-xcom-key-2",
- "task_id": task_id_1,
- "timestamp": "TIMESTAMP",
- "map_index": -1,
- },
- ],
- "total_entries": 2,
- },
- )
-
def test_should_respond_200_with_map_index(self):
dag_id = "test-dag-id"
task_id = "test-task-id"
diff --git a/tests/api_connexion/test_auth.py b/tests/api_connexion/test_auth.py
index 7d1dcc088273c..54e5632ad84d1 100644
--- a/tests/api_connexion/test_auth.py
+++ b/tests/api_connexion/test_auth.py
@@ -16,15 +16,15 @@
# under the License.
from __future__ import annotations
-from base64 import b64encode
+from unittest.mock import patch
import pytest
-from flask_login import current_user
+from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager
+from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from tests.test_utils.api_connexion_utils import assert_401
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_pools
-from tests.test_utils.www import client_with_login
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
@@ -34,101 +34,6 @@ class BaseTestAuth:
def set_attrs(self, minimal_app_for_api):
self.app = minimal_app_for_api
- sm = self.app.appbuilder.sm
- tester = sm.find_user(username="test")
- if not tester:
- role_admin = sm.find_role("Admin")
- sm.add_user(
- username="test",
- first_name="test",
- last_name="test",
- email="test@fab.org",
- role=role_admin,
- password="test",
- )
-
-
-class TestBasicAuth(BaseTestAuth):
- @pytest.fixture(autouse=True, scope="class")
- def with_basic_auth_backend(self, minimal_app_for_api):
- from airflow.www.extensions.init_security import init_api_auth
-
- old_auth = getattr(minimal_app_for_api, "api_auth")
-
- try:
- with conf_vars(
- {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"}
- ):
- init_api_auth(minimal_app_for_api)
- yield
- finally:
- setattr(minimal_app_for_api, "api_auth", old_auth)
-
- def test_success(self):
- token = "Basic " + b64encode(b"test:test").decode()
- clear_db_pools()
-
- with self.app.test_client() as test_client:
- response = test_client.get("/api/v1/pools", headers={"Authorization": token})
- assert current_user.email == "test@fab.org"
-
- assert response.status_code == 200
- assert response.json == {
- "pools": [
- {
- "name": "default_pool",
- "slots": 128,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "scheduled_slots": 0,
- "deferred_slots": 0,
- "open_slots": 128,
- "description": "Default pool",
- "include_deferred": False,
- },
- ],
- "total_entries": 1,
- }
-
- @pytest.mark.parametrize(
- "token",
- [
- "basic",
- "basic ",
- "bearer",
- "test:test",
- b64encode(b"test:test").decode(),
- "bearer ",
- "basic: ",
- "basic 123",
- ],
- )
- def test_malformed_headers(self, token):
- with self.app.test_client() as test_client:
- response = test_client.get("/api/v1/pools", headers={"Authorization": token})
- assert response.status_code == 401
- assert response.headers["Content-Type"] == "application/problem+json"
- assert response.headers["WWW-Authenticate"] == "Basic"
- assert_401(response)
-
- @pytest.mark.parametrize(
- "token",
- [
- "basic " + b64encode(b"test").decode(),
- "basic " + b64encode(b"test:").decode(),
- "basic " + b64encode(b"test:123").decode(),
- "basic " + b64encode(b"test test").decode(),
- ],
- )
- def test_invalid_auth_header(self, token):
- with self.app.test_client() as test_client:
- response = test_client.get("/api/v1/pools", headers={"Authorization": token})
- assert response.status_code == 401
- assert response.headers["Content-Type"] == "application/problem+json"
- assert response.headers["WWW-Authenticate"] == "Basic"
- assert_401(response)
-
class TestSessionAuth(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
@@ -144,74 +49,37 @@ def with_session_backend(self, minimal_app_for_api):
finally:
setattr(minimal_app_for_api, "api_auth", old_auth)
- def test_success(self):
+ @patch.object(SimpleAuthManager, "is_logged_in", return_value=True)
+ @patch.object(
+ SimpleAuthManager, "get_user", return_value=SimpleAuthManagerUser(username="test", role="admin")
+ )
+ def test_success(self, *args):
clear_db_pools()
- admin_user = client_with_login(self.app, username="test", password="test")
- response = admin_user.get("/api/v1/pools")
- assert response.status_code == 200
- assert response.json == {
- "pools": [
- {
- "name": "default_pool",
- "slots": 128,
- "occupied_slots": 0,
- "running_slots": 0,
- "queued_slots": 0,
- "scheduled_slots": 0,
- "deferred_slots": 0,
- "open_slots": 128,
- "description": "Default pool",
- "include_deferred": False,
- },
- ],
- "total_entries": 1,
- }
-
- def test_failure(self):
with self.app.test_client() as test_client:
response = test_client.get("/api/v1/pools")
- assert response.status_code == 401
- assert response.headers["Content-Type"] == "application/problem+json"
- assert_401(response)
-
-
-class TestSessionWithBasicAuthFallback(BaseTestAuth):
- @pytest.fixture(autouse=True, scope="class")
- def with_basic_auth_backend(self, minimal_app_for_api):
- from airflow.www.extensions.init_security import init_api_auth
-
- old_auth = getattr(minimal_app_for_api, "api_auth")
-
- try:
- with conf_vars(
- {
- (
- "api",
- "auth_backends",
- ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"
- }
- ):
- init_api_auth(minimal_app_for_api)
- yield
- finally:
- setattr(minimal_app_for_api, "api_auth", old_auth)
-
- def test_basic_auth_fallback(self):
- token = "Basic " + b64encode(b"test:test").decode()
- clear_db_pools()
-
- # request uses session
- admin_user = client_with_login(self.app, username="test", password="test")
- response = admin_user.get("/api/v1/pools")
- assert response.status_code == 200
-
- # request uses basic auth
- with self.app.test_client() as test_client:
- response = test_client.get("/api/v1/pools", headers={"Authorization": token})
assert response.status_code == 200
+ assert response.json == {
+ "pools": [
+ {
+ "name": "default_pool",
+ "slots": 128,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "scheduled_slots": 0,
+ "deferred_slots": 0,
+ "open_slots": 128,
+ "description": "Default pool",
+ "include_deferred": False,
+ },
+ ],
+ "total_entries": 1,
+ }
- # request without session or basic auth header
+ def test_failure(self):
with self.app.test_client() as test_client:
response = test_client.get("/api/v1/pools")
assert response.status_code == 401
+ assert response.headers["Content-Type"] == "application/problem+json"
+ assert_401(response)
diff --git a/tests/api_connexion/test_security.py b/tests/api_connexion/test_security.py
index 13a5dd4e25af1..c6a112b1a1bb9 100644
--- a/tests/api_connexion/test_security.py
+++ b/tests/api_connexion/test_security.py
@@ -18,7 +18,6 @@
import pytest
-from airflow.security import permissions
from tests.test_utils.api_connexion_utils import create_user, delete_user
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
@@ -28,15 +27,14 @@
def configured_app(minimal_app_for_api):
app = minimal_app_for_api
create_user(
- app, # type:ignore
+ app,
username="test",
- role_name="Test",
- permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)], # type: ignore
+ role_name="admin",
)
yield minimal_app_for_api
- delete_user(app, username="test") # type: ignore
+ delete_user(app, username="test")
class TestSession:
diff --git a/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py
new file mode 100644
index 0000000000000..61d923d5ff125
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from contextlib import contextmanager
+
+from tests.test_utils.compat import ignore_provider_compatibility_error
+
+with ignore_provider_compatibility_error("2.9.0+", __file__):
+ from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES
+
+
+@contextmanager
+def create_test_client(app, user_name, role_name, permissions):
+ """
+ Helper function to create a client with a temporary user which will be deleted once done
+ """
+ client = app.test_client()
+ with create_user_scope(app, username=user_name, role_name=role_name, permissions=permissions) as _:
+ resp = client.post("/login/", data={"username": user_name, "password": user_name})
+ assert resp.status_code == 302
+ yield client
+
+
+@contextmanager
+def create_user_scope(app, username, **kwargs):
+ """
+ Helper function designed to be used with pytest fixture mainly.
+ It will create a user and provide it for the fixture via YIELD (generator)
+ then will tidy up once test is complete
+ """
+ test_user = create_user(app, username, **kwargs)
+
+ try:
+ yield test_user
+ finally:
+ delete_user(app, username)
+
+
+def create_user(app, username, role_name=None, email=None, permissions=None):
+ appbuilder = app.appbuilder
+
+ # Removes user and role so each test has isolated test data.
+ delete_user(app, username)
+ role = None
+ if role_name:
+ delete_role(app, role_name)
+ role = create_role(app, role_name, permissions)
+ else:
+ role = []
+
+ return appbuilder.sm.add_user(
+ username=username,
+ first_name=username,
+ last_name=username,
+ email=email or f"{username}@example.org",
+ role=role,
+ password=username,
+ )
+
+
+def create_role(app, name, permissions=None):
+ appbuilder = app.appbuilder
+ role = appbuilder.sm.find_role(name)
+ if not role:
+ role = appbuilder.sm.add_role(name)
+ if not permissions:
+ permissions = []
+ for permission in permissions:
+ perm_object = appbuilder.sm.get_permission(*permission)
+ appbuilder.sm.add_permission_to_role(role, perm_object)
+ return role
+
+
+def set_user_single_role(app, user, role_name):
+ role = create_role(app, role_name)
+ if role not in user.roles:
+ user.roles = [role]
+ app.appbuilder.sm.update_user(user)
+ user._perms = None
+
+
+def delete_role(app, name):
+ if name not in EXISTING_ROLES:
+ if app.appbuilder.sm.find_role(name):
+ app.appbuilder.sm.delete_role(name)
+
+
+def delete_roles(app):
+ for role in app.appbuilder.sm.get_all_roles():
+ delete_role(app, role.name)
+
+
+def delete_user(app, username):
+ appbuilder = app.appbuilder
+ for user in appbuilder.sm.get_all_users():
+ if user.username == username:
+ _ = [
+ delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES
+ ]
+ appbuilder.sm.del_register_user(user)
+ break
diff --git a/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py
new file mode 100644
index 0000000000000..b7714e5192e6a
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Default authentication backend - everything is allowed"""
+
+from __future__ import annotations
+
+import logging
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, TypeVar, cast
+
+from flask import Response, request
+from flask_login import login_user
+
+from airflow.utils.airflow_flask_app import get_airflow_app
+
+if TYPE_CHECKING:
+ from requests.auth import AuthBase
+
+log = logging.getLogger(__name__)
+
+CLIENT_AUTH: tuple[str, str] | AuthBase | None = None
+
+
+def init_app(_):
+ """Initializes authentication backend"""
+
+
+T = TypeVar("T", bound=Callable)
+
+
+def _lookup_user(user_email_or_username: str):
+ security_manager = get_airflow_app().appbuilder.sm
+ user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user(
+ username=user_email_or_username
+ )
+ if not user:
+ return None
+
+ if not user.is_active:
+ return None
+
+ return user
+
+
+def requires_authentication(function: T):
+ """Decorator for functions that require authentication"""
+
+ @wraps(function)
+ def decorated(*args, **kwargs):
+ user_id = request.remote_user
+ if not user_id:
+ log.debug("Missing REMOTE_USER.")
+ return Response("Forbidden", 403)
+
+ log.debug("Looking for user: %s", user_id)
+
+ user = _lookup_user(user_id)
+ if not user:
+ return Response("Forbidden", 403)
+
+ log.debug("Found user: %s", user)
+
+ login_user(user, remember=False)
+ return function(*args, **kwargs)
+
+ return cast(T, decorated)
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_auth.py b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py
new file mode 100644
index 0000000000000..d3012e2f1b43e
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from base64 import b64encode
+
+import pytest
+from flask_login import current_user
+
+from tests.test_utils.api_connexion_utils import assert_401
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_pools
+from tests.test_utils.www import client_with_login
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+class BaseTestAuth:
+ @pytest.fixture(autouse=True)
+ def set_attrs(self, minimal_app_for_auth_api):
+ self.app = minimal_app_for_auth_api
+
+ sm = self.app.appbuilder.sm
+ tester = sm.find_user(username="test")
+ if not tester:
+ role_admin = sm.find_role("Admin")
+ sm.add_user(
+ username="test",
+ first_name="test",
+ last_name="test",
+ email="test@fab.org",
+ role=role_admin,
+ password="test",
+ )
+
+
+class TestBasicAuth(BaseTestAuth):
+ @pytest.fixture(autouse=True, scope="class")
+ def with_basic_auth_backend(self, minimal_app_for_auth_api):
+ from airflow.www.extensions.init_security import init_api_auth
+
+ old_auth = getattr(minimal_app_for_auth_api, "api_auth")
+
+ try:
+ with conf_vars(
+ {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"}
+ ):
+ init_api_auth(minimal_app_for_auth_api)
+ yield
+ finally:
+ setattr(minimal_app_for_auth_api, "api_auth", old_auth)
+
+ def test_success(self):
+ token = "Basic " + b64encode(b"test:test").decode()
+ clear_db_pools()
+
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+ assert current_user.email == "test@fab.org"
+
+ assert response.status_code == 200
+ assert response.json == {
+ "pools": [
+ {
+ "name": "default_pool",
+ "slots": 128,
+ "occupied_slots": 0,
+ "running_slots": 0,
+ "queued_slots": 0,
+ "scheduled_slots": 0,
+ "deferred_slots": 0,
+ "open_slots": 128,
+ "description": "Default pool",
+ "include_deferred": False,
+ },
+ ],
+ "total_entries": 1,
+ }
+
+ @pytest.mark.parametrize(
+ "token",
+ [
+ "basic",
+ "basic ",
+ "bearer",
+ "test:test",
+ b64encode(b"test:test").decode(),
+ "bearer ",
+ "basic: ",
+ "basic 123",
+ ],
+ )
+ def test_malformed_headers(self, token):
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+ assert response.status_code == 401
+ assert response.headers["Content-Type"] == "application/problem+json"
+ assert response.headers["WWW-Authenticate"] == "Basic"
+ assert_401(response)
+
+ @pytest.mark.parametrize(
+ "token",
+ [
+ "basic " + b64encode(b"test").decode(),
+ "basic " + b64encode(b"test:").decode(),
+ "basic " + b64encode(b"test:123").decode(),
+ "basic " + b64encode(b"test test").decode(),
+ ],
+ )
+ def test_invalid_auth_header(self, token):
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+ assert response.status_code == 401
+ assert response.headers["Content-Type"] == "application/problem+json"
+ assert response.headers["WWW-Authenticate"] == "Basic"
+ assert_401(response)
+
+
+class TestSessionWithBasicAuthFallback(BaseTestAuth):
+ @pytest.fixture(autouse=True, scope="class")
+ def with_basic_auth_backend(self, minimal_app_for_auth_api):
+ from airflow.www.extensions.init_security import init_api_auth
+
+ old_auth = getattr(minimal_app_for_auth_api, "api_auth")
+
+ try:
+ with conf_vars(
+ {
+ (
+ "api",
+ "auth_backends",
+ ): "airflow.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"
+ }
+ ):
+ init_api_auth(minimal_app_for_auth_api)
+ yield
+ finally:
+ setattr(minimal_app_for_auth_api, "api_auth", old_auth)
+
+ def test_basic_auth_fallback(self):
+ token = "Basic " + b64encode(b"test:test").decode()
+ clear_db_pools()
+
+ # request uses session
+ admin_user = client_with_login(self.app, username="test", password="test")
+ response = admin_user.get("/api/v1/pools")
+ assert response.status_code == 200
+
+ # request uses basic auth
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools", headers={"Authorization": token})
+ assert response.status_code == 200
+
+ # request without session or basic auth header
+ with self.app.test_client() as test_client:
+ response = test_client.get("/api/v1/pools")
+ assert response.status_code == 401
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py
new file mode 100644
index 0000000000000..56f135d457e9c
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_backfill_endpoint.py
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from unittest import mock
+from urllib.parse import urlencode
+
+import pendulum
+import pytest
+
+from airflow.models import DagBag, DagModel
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
+try:
+ from airflow.models.backfill import Backfill
+except ImportError:
+ if AIRFLOW_V_3_0_PLUS:
+ raise
+ else:
+ pass
+from airflow.models.dag import DAG
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.operators.empty import EmptyOperator
+from airflow.security import permissions
+from airflow.utils import timezone
+from airflow.utils.session import provide_session
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.db import clear_db_backfills, clear_db_dags, clear_db_runs, clear_db_serialized_dags
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+DAG_ID = "test_dag"
+TASK_ID = "op1"
+DAG2_ID = "test_dag2"
+DAG3_ID = "test_dag3"
+UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')"
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+
+ create_user(app, username="test_granular_permissions", role_name="TestGranularDag")
+ app.appbuilder.sm.sync_perm_for_dag(
+ "TEST_DAG_1",
+ access_control={
+ "TestGranularDag": {
+ permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
+ },
+ },
+ )
+
+ with DAG(
+ DAG_ID,
+ schedule=None,
+ start_date=datetime(2020, 6, 15),
+ doc_md="details",
+ params={"foo": 1},
+ tags=["example"],
+ ) as dag:
+ EmptyOperator(task_id=TASK_ID)
+
+ with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md
+ EmptyOperator(task_id=TASK_ID)
+
+ with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None
+ EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12))
+
+ dag_bag = DagBag(os.devnull, include_examples=False)
+ dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3}
+
+ app.dag_bag = dag_bag
+
+ yield app
+
+ delete_user(app, username="test_granular_permissions")
+
+
+class TestBackfillEndpoint:
+ @staticmethod
+ def clean_db():
+ clear_db_backfills()
+ clear_db_runs()
+ clear_db_dags()
+ clear_db_serialized_dags()
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.clean_db()
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ self.dag_id = DAG_ID
+ self.dag2_id = DAG2_ID
+ self.dag3_id = DAG3_ID
+
+ def teardown_method(self) -> None:
+ self.clean_db()
+
+ @provide_session
+ def _create_dag_models(self, *, count=1, dag_id_prefix="TEST_DAG", is_paused=False, session=None):
+ dags = []
+ for num in range(1, count + 1):
+ dag_model = DagModel(
+ dag_id=f"{dag_id_prefix}_{num}",
+ fileloc=f"/tmp/dag_{num}.py",
+ is_active=True,
+ timetable_summary="0 0 * * *",
+ is_paused=is_paused,
+ )
+ session.add(dag_model)
+ dags.append(dag_model)
+ return dags
+
+ @provide_session
+ def _create_deactivated_dag(self, session=None):
+ dag_model = DagModel(
+ dag_id="TEST_DAG_DELETED_1",
+ fileloc="/tmp/dag_del_1.py",
+ schedule_interval="2 2 * * *",
+ is_active=False,
+ )
+ session.add(dag_model)
+
+
+class TestListBackfills(TestBackfillEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self, session):
+ (dag,) = self._create_dag_models()
+ from_date = timezone.utcnow()
+ to_date = timezone.utcnow()
+ b = Backfill(
+ dag_id=dag.dag_id,
+ from_date=from_date,
+ to_date=to_date,
+ )
+
+ session.add(b)
+ session.commit()
+ kwargs = {}
+ kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"})
+ response = self.client.get("/api/v1/backfills?dag_id=TEST_DAG_1", **kwargs)
+ assert response.status_code == 200
+
+
+class TestGetBackfill(TestBackfillEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self, session):
+ (dag,) = self._create_dag_models()
+ from_date = timezone.utcnow()
+ to_date = timezone.utcnow()
+ backfill = Backfill(
+ dag_id=dag.dag_id,
+ from_date=from_date,
+ to_date=to_date,
+ )
+ session.add(backfill)
+ session.commit()
+ kwargs = {}
+ kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"})
+ response = self.client.get(f"/api/v1/backfills/{backfill.id}", **kwargs)
+ assert response.status_code == 200
+
+
+class TestCreateBackfill(TestBackfillEndpoint):
+ def test_create_backfill(self, session, dag_maker):
+ with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
+ EmptyOperator(task_id="mytask")
+ session.add(SerializedDagModel(dag))
+ session.commit()
+ session.query(DagModel).all()
+ from_date = pendulum.parse("2024-01-01")
+ from_date_iso = from_date.isoformat()
+ to_date = pendulum.parse("2024-02-01")
+ to_date_iso = to_date.isoformat()
+ max_active_runs = 5
+ query = urlencode(
+ query={
+ "dag_id": dag.dag_id,
+ "from_date": f"{from_date_iso}",
+ "to_date": f"{to_date_iso}",
+ "max_active_runs": max_active_runs,
+ "reverse": False,
+ }
+ )
+ kwargs = {}
+ kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"})
+
+ response = self.client.post(
+ f"/api/v1/backfills?{query}",
+ **kwargs,
+ )
+ assert response.status_code == 200
+ assert response.json == {
+ "completed_at": mock.ANY,
+ "created_at": mock.ANY,
+ "dag_id": "TEST_DAG_1",
+ "dag_run_conf": None,
+ "from_date": from_date_iso,
+ "id": mock.ANY,
+ "is_paused": False,
+ "max_active_runs": 5,
+ "to_date": to_date_iso,
+ "updated_at": mock.ANY,
+ }
+
+
+class TestPauseBackfill(TestBackfillEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self, session):
+ (dag,) = self._create_dag_models()
+ from_date = timezone.utcnow()
+ to_date = timezone.utcnow()
+ backfill = Backfill(
+ dag_id=dag.dag_id,
+ from_date=from_date,
+ to_date=to_date,
+ )
+ session.add(backfill)
+ session.commit()
+ kwargs = {}
+ kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"})
+ response = self.client.post(f"/api/v1/backfills/{backfill.id}/pause", **kwargs)
+ assert response.status_code == 200
+
+
+class TestCancelBackfill(TestBackfillEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self, session):
+ (dag,) = self._create_dag_models()
+ from_date = timezone.utcnow()
+ to_date = timezone.utcnow()
+ backfill = Backfill(
+ dag_id=dag.dag_id,
+ from_date=from_date,
+ to_date=to_date,
+ )
+ session.add(backfill)
+ session.commit()
+ kwargs = {}
+ kwargs.update(environ_overrides={"REMOTE_USER": "test_granular_permissions"})
+ response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs)
+ assert response.status_code == 200
+ # now it is marked as completed
+ assert pendulum.parse(response.json["completed_at"])
+
+ # get conflict when canceling already-canceled backfill
+ response = self.client.post(f"/api/v1/backfills/{backfill.id}/cancel", **kwargs)
+ assert response.status_code == 409
diff --git a/tests/api_connexion/test_cors.py b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py
similarity index 81%
rename from tests/api_connexion/test_cors.py
rename to tests/providers/fab/auth_manager/api_endpoints/test_cors.py
index a2b7f0ebca743..b44eab8820ec6 100644
--- a/tests/api_connexion/test_cors.py
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py
@@ -20,16 +20,21 @@
import pytest
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_pools
-pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
class BaseTestAuth:
@pytest.fixture(autouse=True)
- def set_attrs(self, minimal_app_for_api):
- self.app = minimal_app_for_api
+ def set_attrs(self, minimal_app_for_auth_api):
+ self.app = minimal_app_for_auth_api
sm = self.app.appbuilder.sm
tester = sm.find_user(username="test")
@@ -47,19 +52,19 @@ def set_attrs(self, minimal_app_for_api):
class TestEmptyCors(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
- def with_basic_auth_backend(self, minimal_app_for_api):
+ def with_basic_auth_backend(self, minimal_app_for_auth_api):
from airflow.www.extensions.init_security import init_api_auth
- old_auth = getattr(minimal_app_for_api, "api_auth")
+ old_auth = getattr(minimal_app_for_auth_api, "api_auth")
try:
with conf_vars(
{("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"}
):
- init_api_auth(minimal_app_for_api)
+ init_api_auth(minimal_app_for_auth_api)
yield
finally:
- setattr(minimal_app_for_api, "api_auth", old_auth)
+ setattr(minimal_app_for_auth_api, "api_auth", old_auth)
def test_empty_cors_headers(self):
token = "Basic " + b64encode(b"test:test").decode()
@@ -75,10 +80,10 @@ def test_empty_cors_headers(self):
class TestCorsOrigin(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
- def with_basic_auth_backend(self, minimal_app_for_api):
+ def with_basic_auth_backend(self, minimal_app_for_auth_api):
from airflow.www.extensions.init_security import init_api_auth
- old_auth = getattr(minimal_app_for_api, "api_auth")
+ old_auth = getattr(minimal_app_for_auth_api, "api_auth")
try:
with conf_vars(
@@ -90,10 +95,10 @@ def with_basic_auth_backend(self, minimal_app_for_api):
("api", "access_control_allow_origins"): "http://apache.org http://example.com",
}
):
- init_api_auth(minimal_app_for_api)
+ init_api_auth(minimal_app_for_auth_api)
yield
finally:
- setattr(minimal_app_for_api, "api_auth", old_auth)
+ setattr(minimal_app_for_auth_api, "api_auth", old_auth)
def test_cors_origin_reflection(self):
token = "Basic " + b64encode(b"test:test").decode()
@@ -119,10 +124,10 @@ def test_cors_origin_reflection(self):
class TestCorsWildcard(BaseTestAuth):
@pytest.fixture(autouse=True, scope="class")
- def with_basic_auth_backend(self, minimal_app_for_api):
+ def with_basic_auth_backend(self, minimal_app_for_auth_api):
from airflow.www.extensions.init_security import init_api_auth
- old_auth = getattr(minimal_app_for_api, "api_auth")
+ old_auth = getattr(minimal_app_for_auth_api, "api_auth")
try:
with conf_vars(
@@ -134,10 +139,10 @@ def with_basic_auth_backend(self, minimal_app_for_api):
("api", "access_control_allow_origins"): "*",
}
):
- init_api_auth(minimal_app_for_api)
+ init_api_auth(minimal_app_for_auth_api)
yield
finally:
- setattr(minimal_app_for_api, "api_auth", old_auth)
+ setattr(minimal_app_for_auth_api, "api_auth", old_auth)
def test_cors_origin_reflection(self):
token = "Basic " + b64encode(b"test:test").decode()
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py
new file mode 100644
index 0000000000000..b78ac58e442e0
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+import pendulum
+import pytest
+
+from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from airflow.models import DagBag, DagModel
+from airflow.models.dag import DAG
+from airflow.operators.empty import EmptyOperator
+from airflow.security import permissions
+from airflow.utils.session import provide_session
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
+from tests.test_utils.www import _check_last_log
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+@pytest.fixture
+def current_file_token(url_safe_serializer) -> str:
+ return url_safe_serializer.dumps(__file__)
+
+
+DAG_ID = "test_dag"
+TASK_ID = "op1"
+DAG2_ID = "test_dag2"
+DAG3_ID = "test_dag3"
+UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')"
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+
+ create_user(app, username="test_granular_permissions", role_name="TestGranularDag")
+ app.appbuilder.sm.sync_perm_for_dag(
+ "TEST_DAG_1",
+ access_control={
+ "TestGranularDag": {
+ permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
+ },
+ },
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ "TEST_DAG_1",
+ access_control={
+ "TestGranularDag": {
+ permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}
+ },
+ },
+ )
+
+ with DAG(
+ DAG_ID,
+ schedule=None,
+ start_date=datetime(2020, 6, 15),
+ doc_md="details",
+ params={"foo": 1},
+ tags=["example"],
+ ) as dag:
+ EmptyOperator(task_id=TASK_ID)
+
+ with DAG(DAG2_ID, schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md
+ EmptyOperator(task_id=TASK_ID)
+
+ with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None
+ EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12))
+
+ dag_bag = DagBag(os.devnull, include_examples=False)
+ dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3}
+
+ app.dag_bag = dag_bag
+
+ yield app
+
+ delete_user(app, username="test_granular_permissions")
+
+
+class TestDagEndpoint:
+ @staticmethod
+ def clean_db():
+ clear_db_runs()
+ clear_db_dags()
+ clear_db_serialized_dags()
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.clean_db()
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ self.dag_id = DAG_ID
+ self.dag2_id = DAG2_ID
+ self.dag3_id = DAG3_ID
+
+ def teardown_method(self) -> None:
+ self.clean_db()
+
+ @provide_session
+ def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None):
+ for num in range(1, count + 1):
+ dag_model = DagModel(
+ dag_id=f"{dag_id_prefix}_{num}",
+ fileloc=f"/tmp/dag_{num}.py",
+ timetable_summary="2 2 * * *",
+ is_active=True,
+ is_paused=is_paused,
+ )
+ session.add(dag_model)
+
+ @provide_session
+ def _create_dag_model_for_details_endpoint(self, dag_id, session=None):
+ dag_model = DagModel(
+ dag_id=dag_id,
+ fileloc="/tmp/dag.py",
+ timetable_summary="2 2 * * *",
+ is_active=True,
+ is_paused=False,
+ )
+ session.add(dag_model)
+
+ @provide_session
+ def _create_dag_model_for_details_endpoint_with_dataset_expression(self, dag_id, session=None):
+ dag_model = DagModel(
+ dag_id=dag_id,
+ fileloc="/tmp/dag.py",
+ timetable_summary="2 2 * * *",
+ is_active=True,
+ is_paused=False,
+ dataset_expression={
+ "any": [
+ "s3://dag1/output_1.txt",
+ {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]},
+ ]
+ },
+ )
+ session.add(dag_model)
+
+ @provide_session
+ def _create_deactivated_dag(self, session=None):
+ dag_model = DagModel(
+ dag_id="TEST_DAG_DELETED_1",
+ fileloc="/tmp/dag_del_1.py",
+ timetable_summary="2 2 * * *",
+ is_active=False,
+ )
+ session.add(dag_model)
+
+
+class TestGetDag(TestDagEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self):
+ self._create_dag_models(1)
+ response = self.client.get(
+ "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
+ )
+ assert response.status_code == 200
+
+ def test_should_respond_403_with_granular_access_for_different_dag(self):
+ self._create_dag_models(3)
+ response = self.client.get(
+ "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
+ )
+ assert response.status_code == 403
+
+
+class TestGetDags(TestDagEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self):
+ self._create_dag_models(3)
+ response = self.client.get(
+ "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
+ )
+ assert response.status_code == 200
+ assert len(response.json["dags"]) == 1
+ assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"
+
+
+class TestPatchDag(TestDagEndpoint):
+ @provide_session
+ def _create_dag_model(self, session=None):
+ dag_model = DagModel(
+ dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True
+ )
+ session.add(dag_model)
+ return dag_model
+
+ def test_should_respond_200_on_patch_with_granular_dag_access(self, session):
+ self._create_dag_models(1)
+ response = self.client.patch(
+ "/api/v1/dags/TEST_DAG_1",
+ json={
+ "is_paused": False,
+ },
+ environ_overrides={"REMOTE_USER": "test_granular_permissions"},
+ )
+ assert response.status_code == 200
+ _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None)
+
+ def test_validation_error_raises_400(self):
+ patch_body = {
+ "ispaused": True,
+ }
+ dag_model = self._create_dag_model()
+ response = self.client.patch(
+ f"/api/v1/dags/{dag_model.dag_id}",
+ json=patch_body,
+ environ_overrides={"REMOTE_USER": "test_granular_permissions"},
+ )
+ assert response.status_code == 400
+ assert response.json == {
+ "detail": "{'ispaused': ['Unknown field.']}",
+ "status": 400,
+ "title": "Bad Request",
+ "type": EXCEPTIONS_LINK_MAP[400],
+ }
+
+
+class TestPatchDags(TestDagEndpoint):
+ def test_should_respond_200_with_granular_dag_access(self):
+ self._create_dag_models(3)
+ response = self.client.patch(
+ "api/v1/dags?dag_id_pattern=~",
+ json={
+ "is_paused": False,
+ },
+ environ_overrides={"REMOTE_USER": "test_granular_permissions"},
+ )
+ assert response.status_code == 200
+ assert len(response.json["dags"]) == 1
+ assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
new file mode 100644
index 0000000000000..a58ea08ff31cf
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py
@@ -0,0 +1,273 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import timedelta
+
+import pytest
+
+from airflow.models.dag import DAG, DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.param import Param
+from airflow.security import permissions
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import DagRunState
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
+try:
+ from airflow.utils.types import DagRunTriggeredByType, DagRunType
+except ImportError:
+ if AIRFLOW_V_3_0_PLUS:
+ raise
+ else:
+ pass
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user,
+ delete_roles,
+ delete_user,
+)
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+
+ create_user(
+ app,
+ username="test_no_dag_run_create_permission",
+ role_name="TestNoDagRunCreatePermission",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
+ ],
+ )
+ create_user(
+ app,
+ username="test_dag_view_only",
+ role_name="TestViewDags",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
+ ],
+ )
+ create_user(
+ app,
+ username="test_view_dags",
+ role_name="TestViewDags",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
+ ],
+ )
+ create_user(
+ app,
+ username="test_granular_permissions",
+ role_name="TestGranularDag",
+ permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)],
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ "TEST_DAG_ID",
+ access_control={
+ "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ},
+ "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}},
+ },
+ )
+
+ yield app
+
+ delete_user(app, username="test_dag_view_only")
+ delete_user(app, username="test_view_dags")
+ delete_user(app, username="test_granular_permissions")
+ delete_user(app, username="test_no_dag_run_create_permission")
+ delete_roles(app)
+
+
+class TestDagRunEndpoint:
+ default_time = "2020-06-11T18:00:00+00:00"
+ default_time_2 = "2020-06-12T18:00:00+00:00"
+ default_time_3 = "2020-06-13T18:00:00+00:00"
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ clear_db_runs()
+ clear_db_serialized_dags()
+ clear_db_dags()
+
+ def teardown_method(self) -> None:
+ clear_db_runs()
+ clear_db_dags()
+ clear_db_serialized_dags()
+
+ def _create_dag(self, dag_id):
+ dag_instance = DagModel(dag_id=dag_id)
+ dag_instance.is_active = True
+ with create_session() as session:
+ session.add(dag_instance)
+ dag = DAG(dag_id=dag_id, schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)})
+ self.app.dag_bag.bag_dag(dag)
+ return dag_instance
+
+ def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1):
+ dag_runs = []
+ dags = []
+ triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
+
+ for i in range(idx_start, idx_start + 2):
+ if i == 1:
+ dags.append(DagModel(dag_id="TEST_DAG_ID", is_active=True))
+ dagrun_model = DagRun(
+ dag_id="TEST_DAG_ID",
+ run_id=f"TEST_DAG_RUN_ID_{i}",
+ run_type=DagRunType.MANUAL,
+ execution_date=timezone.parse(self.default_time) + timedelta(days=i - 1),
+ start_date=timezone.parse(self.default_time),
+ external_trigger=True,
+ state=state,
+ **triggered_by_kwargs,
+ )
+ dagrun_model.updated_at = timezone.parse(self.default_time)
+ dag_runs.append(dagrun_model)
+
+ if extra_dag:
+ for i in range(idx_start + 2, idx_start + 4):
+ dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}"))
+ dag_runs.append(
+ DagRun(
+ dag_id=f"TEST_DAG_ID_{i}",
+ run_id=f"TEST_DAG_RUN_ID_{i}",
+ run_type=DagRunType.MANUAL,
+ execution_date=timezone.parse(self.default_time_2),
+ start_date=timezone.parse(self.default_time),
+ external_trigger=True,
+ state=state,
+ )
+ )
+ if commit:
+ with create_session() as session:
+ session.add_all(dag_runs)
+ session.add_all(dags)
+ return dag_runs
+
+
+class TestGetDagRuns(TestDagRunEndpoint):
+ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self):
+ self._create_test_dag_run(extra_dag=True)
+ expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"]
+ response = self.client.get(
+ "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"}
+ )
+ assert response.status_code == 200
+ dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]]
+ assert dag_run_ids == expected_dag_run_ids
+
+
+class TestGetDagRunBatch(TestDagRunEndpoint):
+ def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self):
+ self._create_test_dag_run(extra_dag=True)
+ expected_response_json_1 = {
+ "dag_id": "TEST_DAG_ID",
+ "dag_run_id": "TEST_DAG_RUN_ID_1",
+ "end_date": None,
+ "state": "running",
+ "execution_date": self.default_time,
+ "logical_date": self.default_time,
+ "external_trigger": True,
+ "start_date": self.default_time,
+ "conf": {},
+ "data_interval_end": None,
+ "data_interval_start": None,
+ "last_scheduling_decision": None,
+ "run_type": "manual",
+ "note": None,
+ }
+ expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {})
+ expected_response_json_2 = {
+ "dag_id": "TEST_DAG_ID",
+ "dag_run_id": "TEST_DAG_RUN_ID_2",
+ "end_date": None,
+ "state": "running",
+ "execution_date": self.default_time_2,
+ "logical_date": self.default_time_2,
+ "external_trigger": True,
+ "start_date": self.default_time,
+ "conf": {},
+ "data_interval_end": None,
+ "data_interval_start": None,
+ "last_scheduling_decision": None,
+ "run_type": "manual",
+ "note": None,
+ }
+ expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {})
+
+ response = self.client.post(
+ "api/v1/dags/~/dagRuns/list",
+ json={"dag_ids": []},
+ environ_overrides={"REMOTE_USER": "test_granular_permissions"},
+ )
+ assert response.status_code == 200
+ assert response.json == {
+ "dag_runs": [
+ expected_response_json_1,
+ expected_response_json_2,
+ ],
+ "total_entries": 2,
+ }
+
+
+class TestPostDagRun(TestDagRunEndpoint):
+ def test_dagrun_trigger_with_dag_level_permissions(self):
+ self._create_dag("TEST_DAG_ID")
+ response = self.client.post(
+ "api/v1/dags/TEST_DAG_ID/dagRuns",
+ json={"conf": {"validated_number": 1}},
+ environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"},
+ )
+ assert response.status_code == 200
+
+ @pytest.mark.parametrize(
+ "username",
+ ["test_dag_view_only", "test_view_dags", "test_granular_permissions"],
+ )
+ def test_should_raises_403_unauthorized(self, username):
+ self._create_dag("TEST_DAG_ID")
+ response = self.client.post(
+ "api/v1/dags/TEST_DAG_ID/dagRuns",
+ json={
+ "dag_run_id": "TEST_DAG_RUN_ID_1",
+ "execution_date": self.default_time,
+ },
+ environ_overrides={"REMOTE_USER": username},
+ )
+ assert response.status_code == 403
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py
new file mode 100644
index 0000000000000..f0d9b0da298c6
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import os
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.models import DagBag
+from airflow.security import permissions
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+if TYPE_CHECKING:
+ from airflow.models.dag import DAG
+
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
+EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py")
+EXAMPLE_DAG_ID = "example_bash_operator"
+TEST_DAG_ID = "latest_only"
+NOT_READABLE_DAG_ID = "latest_only_with_trigger"
+TEST_MULTIPLE_DAGS_ID = "asset_produces_1"
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+ create_user(
+ app,
+ username="test",
+ role_name="Test",
+ permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)],
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ TEST_DAG_ID,
+ access_control={"Test": [permissions.ACTION_CAN_READ]},
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ EXAMPLE_DAG_ID,
+ access_control={"Test": [permissions.ACTION_CAN_READ]},
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ TEST_MULTIPLE_DAGS_ID,
+ access_control={"Test": [permissions.ACTION_CAN_READ]},
+ )
+
+ yield app
+
+ delete_user(app, username="test")
+
+
+class TestGetSource:
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ self.clear_db()
+
+ def teardown_method(self) -> None:
+ self.clear_db()
+
+ @staticmethod
+ def clear_db():
+ clear_db_dags()
+ clear_db_serialized_dags()
+ clear_db_dag_code()
+
+ @staticmethod
+ def _get_dag_file_docstring(fileloc: str) -> str | None:
+ with open(fileloc) as f:
+ file_contents = f.read()
+ module = ast.parse(file_contents)
+ docstring = ast.get_docstring(module)
+ return docstring
+
+ def test_should_respond_406(self, url_safe_serializer):
+ dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
+ dagbag.sync_to_db()
+ test_dag: DAG = dagbag.dags[TEST_DAG_ID]
+
+ url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}"
+ response = self.client.get(
+ url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"}
+ )
+
+ assert 406 == response.status_code
+
+ def test_should_respond_403_not_readable(self, url_safe_serializer):
+ dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
+ dagbag.sync_to_db()
+ dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID]
+
+ response = self.client.get(
+ f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}",
+ headers={"Accept": "text/plain"},
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ read_dag = self.client.get(
+ f"/api/v1/dags/{NOT_READABLE_DAG_ID}",
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ assert response.status_code == 403
+ assert read_dag.status_code == 403
+
+ def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer):
+ dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE)
+ dagbag.sync_to_db()
+ dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID]
+
+ response = self.client.get(
+ f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}",
+ headers={"Accept": "text/plain"},
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+
+ read_dag = self.client.get(
+ f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}",
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ assert response.status_code == 403
+ assert read_dag.status_code == 200
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py
new file mode 100644
index 0000000000000..adfde1cc5b3eb
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.dag import DagModel
+from airflow.models.dagwarning import DagWarning
+from airflow.security import permissions
+from airflow.utils.session import create_session
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_dag_warnings, clear_db_dags
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+ create_user(
+ app, # type:ignore
+ username="test_with_dag2_read",
+ role_name="TestWithDag2Read",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
+ (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"),
+ ],
+ )
+
+ yield app
+
+ delete_user(app, username="test_with_dag2_read")
+
+
+class TestBaseDagWarning:
+ timestamp = "2020-06-10T12:00"
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+
+ def teardown_method(self) -> None:
+ clear_db_dag_warnings()
+ clear_db_dags()
+
+
+class TestGetDagWarningEndpoint(TestBaseDagWarning):
+ def setup_class(self):
+ clear_db_dag_warnings()
+ clear_db_dags()
+
+ def setup_method(self):
+ with create_session() as session:
+ session.add(DagModel(dag_id="dag1"))
+ session.add(DagWarning("dag1", "non-existent pool", "test message"))
+ session.commit()
+
+ def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self):
+ response = self.client.get(
+ "/api/v1/dagWarnings",
+ environ_overrides={"REMOTE_USER": "test_with_dag2_read"},
+ query_string={"dag_id": "dag1"},
+ )
+ assert response.status_code == 403
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py
new file mode 100644
index 0000000000000..4d302722223d8
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_dataset_endpoint.py
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Generator
+
+import pytest
+import time_machine
+
+from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
+try:
+ from airflow.models.asset import AssetDagRunQueue, AssetModel
+except ImportError:
+ if AIRFLOW_V_3_0_PLUS:
+ raise
+ else:
+ pass
+from airflow.security import permissions
+from airflow.utils import timezone
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.db import clear_db_assets, clear_db_runs
+from tests.test_utils.www import _check_last_log
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+ create_user(
+ app,
+ username="test_queued_event",
+ role_name="TestQueuedEvent",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET),
+ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET),
+ ],
+ )
+
+ yield app
+
+ delete_user(app, username="test_queued_event")
+
+
+class TestAssetEndpoint:
+ default_time = "2020-06-11T18:00:00+00:00"
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client()
+ clear_db_assets()
+ clear_db_runs()
+
+ def teardown_method(self) -> None:
+ clear_db_assets()
+ clear_db_runs()
+
+ def _create_asset(self, session):
+ asset_model = AssetModel(
+ id=1,
+ uri="s3://bucket/key",
+ extra={"foo": "bar"},
+ created_at=timezone.parse(self.default_time),
+ updated_at=timezone.parse(self.default_time),
+ )
+ session.add(asset_model)
+ session.commit()
+ return asset_model
+
+
+class TestQueuedEventEndpoint(TestAssetEndpoint):
+ @pytest.fixture
+ def time_freezer(self) -> Generator:
+ freezer = time_machine.travel(self.default_time, tick=False)
+ freezer.start()
+
+ yield
+
+ freezer.stop()
+
+ def _create_asset_dag_run_queues(self, dag_id, dataset_id, session):
+ ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id)
+ session.add(ddrq)
+ session.commit()
+ return ddrq
+
+
+class TestGetDagDatasetQueuedEvent(TestQueuedEventEndpoint):
+ @pytest.mark.usefixtures("time_freezer")
+ def test_should_respond_200(self, session, create_dummy_dag):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_asset(session).id
+ self._create_asset_dag_run_queues(dag_id, dataset_id, session)
+ dataset_uri = "s3://bucket/key"
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 200
+ assert response.json == {
+ "created_at": self.default_time,
+ "uri": "s3://bucket/key",
+ "dag_id": "dag",
+ }
+
+ def test_should_respond_404(self):
+ dag_id = "not_exists"
+ dataset_uri = "not_exists"
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 404
+ assert {
+ "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found",
+ "status": 404,
+ "title": "Queue event not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
+
+
+class TestDeleteDagDatasetQueuedEvent(TestAssetEndpoint):
+ def test_delete_should_respond_204(self, session, create_dummy_dag):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_uri = "s3://bucket/key"
+ dataset_id = self._create_asset(session).id
+
+ ddrq = AssetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id)
+ session.add(ddrq)
+ session.commit()
+ conn = session.query(AssetDagRunQueue).all()
+ assert len(conn) == 1
+
+ response = self.client.delete(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 204
+ conn = session.query(AssetDagRunQueue).all()
+ assert len(conn) == 0
+ _check_last_log(
+ session, dag_id=dag_id, event="api.delete_dag_dataset_queued_event", execution_date=None
+ )
+
+ def test_should_respond_404(self):
+ dag_id = "not_exists"
+ dataset_uri = "not_exists"
+
+ response = self.client.delete(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 404
+ assert {
+ "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found",
+ "status": 404,
+ "title": "Queue event not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
+
+
+class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint):
+ @pytest.mark.usefixtures("time_freezer")
+ def test_should_respond_200(self, session, create_dummy_dag):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_asset(session).id
+ self._create_asset_dag_run_queues(dag_id, dataset_id, session)
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 200
+ assert response.json == {
+ "queued_events": [
+ {
+ "created_at": self.default_time,
+ "uri": "s3://bucket/key",
+ "dag_id": "dag",
+ }
+ ],
+ "total_entries": 1,
+ }
+
+ def test_should_respond_404(self):
+ dag_id = "not_exists"
+
+ response = self.client.get(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 404
+ assert {
+ "detail": "Queue event with dag_id: `not_exists` was not found",
+ "status": 404,
+ "title": "Queue event not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
+
+
+class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint):
+ def test_should_respond_404(self):
+ dag_id = "not_exists"
+
+ response = self.client.delete(
+ f"/api/v1/dags/{dag_id}/datasets/queuedEvent",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 404
+ assert {
+ "detail": "Queue event with dag_id: `not_exists` was not found",
+ "status": 404,
+ "title": "Queue event not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
+
+
+class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint):
+ @pytest.mark.usefixtures("time_freezer")
+ def test_should_respond_200(self, session, create_dummy_dag):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_asset(session).id
+ self._create_asset_dag_run_queues(dag_id, dataset_id, session)
+ dataset_uri = "s3://bucket/key"
+
+ response = self.client.get(
+ f"/api/v1/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 200
+ assert response.json == {
+ "queued_events": [
+ {
+ "created_at": self.default_time,
+ "uri": "s3://bucket/key",
+ "dag_id": "dag",
+ }
+ ],
+ "total_entries": 1,
+ }
+
+ def test_should_respond_404(self):
+ dataset_uri = "not_exists"
+
+ response = self.client.get(
+ f"/api/v1/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 404
+ assert {
+ "detail": "Queue event with asset uri: `not_exists` was not found",
+ "status": 404,
+ "title": "Queue event not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
+
+
+class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint):
+ def test_delete_should_respond_204(self, session, create_dummy_dag):
+ dag, _ = create_dummy_dag()
+ dag_id = dag.dag_id
+ dataset_id = self._create_asset(session).id
+ self._create_asset_dag_run_queues(dag_id, dataset_id, session)
+ dataset_uri = "s3://bucket/key"
+
+ response = self.client.delete(
+ f"/api/v1/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 204
+ conn = session.query(AssetDagRunQueue).all()
+ assert len(conn) == 0
+ _check_last_log(session, dag_id=None, event="api.delete_dataset_queued_events", execution_date=None)
+
+ def test_should_respond_404(self):
+ dataset_uri = "not_exists"
+
+ response = self.client.delete(
+ f"/api/v1/datasets/queuedEvent/{dataset_uri}",
+ environ_overrides={"REMOTE_USER": "test_queued_event"},
+ )
+
+ assert response.status_code == 404
+ assert {
+ "detail": "Queue event with asset uri: `not_exists` was not found",
+ "status": 404,
+ "title": "Queue event not found",
+ "type": EXCEPTIONS_LINK_MAP[404],
+ } == response.json
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py
new file mode 100644
index 0000000000000..acf3ca62684a1
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_event_log_endpoint.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models import Log
+from airflow.security import permissions
+from airflow.utils import timezone
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_logs
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+ create_user(
+ app,
+ username="test_granular",
+ role_name="TestGranular",
+ permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)],
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ "TEST_DAG_ID_1",
+ access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ "TEST_DAG_ID_2",
+ access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
+ )
+
+ yield app
+
+ delete_user(app, username="test_granular")
+
+
+@pytest.fixture
+def task_instance(session, create_task_instance, request):
+ return create_task_instance(
+ session=session,
+ dag_id="TEST_DAG_ID",
+ task_id="TEST_TASK_ID",
+ run_id="TEST_RUN_ID",
+ execution_date=request.instance.default_time,
+ )
+
+
+@pytest.fixture
+def create_log_model(create_task_instance, task_instance, session, request):
+ def maker(event, when, **kwargs):
+ log_model = Log(
+ event=event,
+ task_instance=task_instance,
+ **kwargs,
+ )
+ log_model.dttm = when
+
+ session.add(log_model)
+ session.flush()
+ return log_model
+
+ return maker
+
+
+class TestEventLogEndpoint:
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ clear_db_logs()
+ self.default_time = timezone.parse("2020-06-10T20:00:00+00:00")
+ self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00")
+
+ def teardown_method(self) -> None:
+ clear_db_logs()
+
+
+class TestGetEventLogs(TestEventLogEndpoint):
+ def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session):
+ eventlog1 = create_log_model(
+ event="TEST_EVENT_1",
+ dag_id="TEST_DAG_ID_1",
+ task_id="TEST_TASK_ID_1",
+ owner="TEST_OWNER_1",
+ when=self.default_time,
+ )
+ eventlog2 = create_log_model(
+ event="TEST_EVENT_2",
+ dag_id="TEST_DAG_ID_2",
+ task_id="TEST_TASK_ID_2",
+ owner="TEST_OWNER_2",
+ when=self.default_time_2,
+ )
+ session.add_all([eventlog1, eventlog2])
+ session.commit()
+ for attr in ["dag_id", "task_id", "owner", "event"]:
+ attr_value = f"TEST_{attr}_1".upper()
+ response = self.client.get(
+ f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"}
+ )
+ assert response.status_code == 200
+ assert response.json["total_entries"] == 1
+ assert len(response.json["event_logs"]) == 1
+ assert response.json["event_logs"][0][attr] == attr_value
+
+ def test_should_filter_eventlogs_by_included_events(self, create_log_model):
+ for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
+ create_log_model(event=event, when=self.default_time)
+ response = self.client.get(
+ "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2",
+ environ_overrides={"REMOTE_USER": "test_granular"},
+ )
+ assert response.status_code == 200
+ response_data = response.json
+ assert len(response_data["event_logs"]) == 2
+ assert response_data["total_entries"] == 2
+ assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]}
+
+ def test_should_filter_eventlogs_by_excluded_events(self, create_log_model):
+ for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
+ create_log_model(event=event, when=self.default_time)
+ response = self.client.get(
+ "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2",
+ environ_overrides={"REMOTE_USER": "test_granular"},
+ )
+ assert response.status_code == 200
+ response_data = response.json
+ assert len(response_data["event_logs"]) == 1
+ assert response_data["total_entries"] == 1
+ assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]}
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py
new file mode 100644
index 0000000000000..a2fa1d028a3f2
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_import_error_endpoint.py
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.dag import DagModel
+from airflow.security import permissions
+from airflow.utils import timezone
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, ParseImportError
+from tests.test_utils.db import clear_db_dags, clear_db_import_errors
+from tests.test_utils.permissions import _resource_name
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+TEST_DAG_IDS = ["test_dag", "test_dag2"]
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+ create_user(
+ app,
+ username="test_single_dag",
+ role_name="TestSingleDAG",
+ permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)],
+ )
+ # For some reason, DAG level permissions are not synced when in the above list of perms,
+ # so do it manually here:
+ app.appbuilder.sm.bulk_sync_roles(
+ [
+ {
+ "role": "TestSingleDAG",
+ "perms": [
+ (
+ permissions.ACTION_CAN_READ,
+ _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG),
+ )
+ ],
+ }
+ ]
+ )
+
+ yield app
+
+ delete_user(app, username="test_single_dag")
+
+
+class TestBaseImportError:
+ timestamp = "2020-06-10T12:00"
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+
+ clear_db_import_errors()
+ clear_db_dags()
+
+ def teardown_method(self) -> None:
+ clear_db_import_errors()
+ clear_db_dags()
+
+ @staticmethod
+ def _normalize_import_errors(import_errors):
+ for i, import_error in enumerate(import_errors, 1):
+ import_error["import_error_id"] = i
+
+
+class TestGetImportErrorEndpoint(TestBaseImportError):
+ def test_should_raise_403_forbidden_without_dag_read(self, session):
+ import_error = ParseImportError(
+ filename="Lorem_ipsum.py",
+ stacktrace="Lorem ipsum",
+ timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+ )
+ session.add(import_error)
+ session.commit()
+
+ response = self.client.get(
+ f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
+ )
+
+ assert response.status_code == 403
+
+ def test_should_return_200_with_single_dag_read(self, session):
+ dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py")
+ session.add(dag_model)
+ import_error = ParseImportError(
+ filename="Lorem_ipsum.py",
+ stacktrace="Lorem ipsum",
+ timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+ )
+ session.add(import_error)
+ session.commit()
+
+ response = self.client.get(
+ f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
+ )
+
+ assert response.status_code == 200
+ response_data = response.json
+ response_data["import_error_id"] = 1
+ assert {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 1,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ } == response_data
+
+ def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
+ for dag_id in TEST_DAG_IDS:
+ dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py")
+ session.add(dag_model)
+ import_error = ParseImportError(
+ filename="Lorem_ipsum.py",
+ stacktrace="Lorem ipsum",
+ timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+ )
+ session.add(import_error)
+ session.commit()
+
+ response = self.client.get(
+ f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
+ )
+
+ assert response.status_code == 200
+ response_data = response.json
+ response_data["import_error_id"] = 1
+ assert {
+ "filename": "Lorem_ipsum.py",
+ "import_error_id": 1,
+ "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ } == response_data
+
+
+class TestGetImportErrorsEndpoint(TestBaseImportError):
+ def test_get_import_errors_single_dag(self, session):
+ for dag_id in TEST_DAG_IDS:
+ fake_filename = f"/tmp/{dag_id}.py"
+ dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
+ session.add(dag_model)
+ importerror = ParseImportError(
+ filename=fake_filename,
+ stacktrace="Lorem ipsum",
+ timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+ )
+ session.add(importerror)
+ session.commit()
+
+ response = self.client.get(
+ "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
+ )
+
+ assert response.status_code == 200
+ response_data = response.json
+ self._normalize_import_errors(response_data["import_errors"])
+ assert {
+ "import_errors": [
+ {
+ "filename": "/tmp/test_dag.py",
+ "import_error_id": 1,
+ "stack_trace": "Lorem ipsum",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ },
+ ],
+ "total_entries": 1,
+ } == response_data
+
+ def test_get_import_errors_single_dag_in_dagfile(self, session):
+ for dag_id in TEST_DAG_IDS:
+ fake_filename = "/tmp/all_in_one.py"
+ dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
+ session.add(dag_model)
+
+ importerror = ParseImportError(
+ filename="/tmp/all_in_one.py",
+ stacktrace="Lorem ipsum",
+ timestamp=timezone.parse(self.timestamp, timezone="UTC"),
+ )
+ session.add(importerror)
+ session.commit()
+
+ response = self.client.get(
+ "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
+ )
+
+ assert response.status_code == 200
+ response_data = response.json
+ self._normalize_import_errors(response_data["import_errors"])
+ assert {
+ "import_errors": [
+ {
+ "filename": "/tmp/all_in_one.py",
+ "import_error_id": 1,
+ "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
+ "timestamp": "2020-06-10T12:00:00+00:00",
+ },
+ ],
+ "total_entries": 1,
+ } == response_data
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py
index 30cfaeb227903..413a49a9d86a1 100644
--- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py
@@ -19,6 +19,13 @@
import pytest
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_role,
+ create_user,
+ delete_role,
+ delete_user,
+)
+from tests.test_utils.api_connexion_utils import assert_401
from tests.test_utils.compat import ignore_provider_compatibility_error
with ignore_provider_compatibility_error("2.9.0+", __file__):
@@ -27,13 +34,6 @@
from airflow.security import permissions
-from tests.test_utils.api_connexion_utils import (
- assert_401,
- create_role,
- create_user,
- delete_role,
- delete_user,
-)
pytestmark = pytest.mark.db_test
@@ -42,7 +42,7 @@
def configured_app(minimal_app_for_auth_api):
app = minimal_app_for_auth_api
create_user(
- app, # type: ignore
+ app,
username="test",
role_name="Test",
permissions=[
@@ -53,11 +53,11 @@ def configured_app(minimal_app_for_auth_api):
(permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION),
],
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name="TestNoPermissions")
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
class TestRoleEndpoint:
diff --git a/tests/api_connexion/schemas/test_role_and_permission_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py
similarity index 85%
rename from tests/api_connexion/schemas/test_role_and_permission_schema.py
rename to tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py
index f2967d519794c..4a2f0068e5e4a 100644
--- a/tests/api_connexion/schemas/test_role_and_permission_schema.py
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_schema.py
@@ -31,19 +31,19 @@
class TestRoleCollectionItemSchema:
@pytest.fixture(scope="class")
- def role(self, minimal_app_for_api):
+ def role(self, minimal_app_for_auth_api):
yield create_role(
- minimal_app_for_api, # type: ignore
+ minimal_app_for_auth_api, # type: ignore
name="Test",
permissions=[
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION),
],
)
- delete_role(minimal_app_for_api, "Test")
+ delete_role(minimal_app_for_auth_api, "Test")
@pytest.fixture(autouse=True)
- def _set_attrs(self, minimal_app_for_api, role):
- self.app = minimal_app_for_api
+ def _set_attrs(self, minimal_app_for_auth_api, role):
+ self.app = minimal_app_for_auth_api
self.role = role
def test_serialize(self):
@@ -67,26 +67,26 @@ def test_deserialize(self):
class TestRoleCollectionSchema:
@pytest.fixture(scope="class")
- def role1(self, minimal_app_for_api):
+ def role1(self, minimal_app_for_auth_api):
yield create_role(
- minimal_app_for_api, # type: ignore
+ minimal_app_for_auth_api, # type: ignore
name="Test1",
permissions=[
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION),
],
)
- delete_role(minimal_app_for_api, "Test1")
+ delete_role(minimal_app_for_auth_api, "Test1")
@pytest.fixture(scope="class")
- def role2(self, minimal_app_for_api):
+ def role2(self, minimal_app_for_auth_api):
yield create_role(
- minimal_app_for_api, # type: ignore
+ minimal_app_for_auth_api, # type: ignore
name="Test2",
permissions=[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
],
)
- delete_role(minimal_app_for_api, "Test2")
+ delete_role(minimal_app_for_auth_api, "Test2")
def test_serialize(self, role1, role2):
instance = RoleCollection([role1, role2], total_entries=2)
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py
new file mode 100644
index 0000000000000..69b3c221eae93
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py
@@ -0,0 +1,427 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime as dt
+import urllib
+
+import pytest
+
+from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
+from airflow.models import DagRun, TaskInstance
+from airflow.security import permissions
+from airflow.utils.session import provide_session
+from airflow.utils.state import State
+from airflow.utils.timezone import datetime
+from airflow.utils.types import DagRunType
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user,
+ delete_roles,
+ delete_user,
+)
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_runs, clear_db_sla_miss, clear_rendered_ti_fields
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+DEFAULT_DATETIME_1 = datetime(2020, 1, 1)
+DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00"
+DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00"
+
+QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1)
+QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2)
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+ create_user(
+ app,
+ username="test_dag_read_only",
+ role_name="TestDagReadOnly",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+ ],
+ )
+ create_user(
+ app,
+ username="test_task_read_only",
+ role_name="TestTaskReadOnly",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ],
+ )
+ create_user(
+ app,
+ username="test_read_only_one_dag",
+ role_name="TestReadOnlyOneDag",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+ ],
+ )
+ # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms,
+ # so do it manually here:
+ app.appbuilder.sm.bulk_sync_roles(
+ [
+ {
+ "role": "TestReadOnlyOneDag",
+ "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")],
+ }
+ ]
+ )
+
+ yield app
+
+ delete_user(app, username="test_dag_read_only")
+ delete_user(app, username="test_task_read_only")
+ delete_user(app, username="test_read_only_one_dag")
+ delete_roles(app)
+
+
+class TestTaskInstanceEndpoint:
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app, dagbag) -> None:
+ self.default_time = DEFAULT_DATETIME_1
+ self.ti_init = {
+ "execution_date": self.default_time,
+ "state": State.RUNNING,
+ }
+ self.ti_extras = {
+ "start_date": self.default_time + dt.timedelta(days=1),
+ "end_date": self.default_time + dt.timedelta(days=2),
+ "pid": 100,
+ "duration": 10000,
+ "pool": "default_pool",
+ "queue": "default_queue",
+ "job_id": 0,
+ }
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ clear_db_runs()
+ clear_db_sla_miss()
+ clear_rendered_ti_fields()
+ self.dagbag = dagbag
+
+ def create_task_instances(
+ self,
+ session,
+ dag_id: str = "example_python_operator",
+ update_extras: bool = True,
+ task_instances=None,
+ dag_run_state=State.RUNNING,
+ with_ti_history=False,
+ ):
+ """Method to create task instances using kwargs and default arguments"""
+
+ dag = self.dagbag.get_dag(dag_id)
+ tasks = dag.tasks
+ counter = len(tasks)
+ if task_instances is not None:
+ counter = min(len(task_instances), counter)
+
+ run_id = "TEST_DAG_RUN_ID"
+ execution_date = self.ti_init.pop("execution_date", self.default_time)
+ dr = None
+
+ tis = []
+ for i in range(counter):
+ if task_instances is None:
+ pass
+ elif update_extras:
+ self.ti_extras.update(task_instances[i])
+ else:
+ self.ti_init.update(task_instances[i])
+
+ if "execution_date" in self.ti_init:
+ run_id = f"TEST_DAG_RUN_ID_{i}"
+ execution_date = self.ti_init.pop("execution_date")
+ dr = None
+
+ if not dr:
+ dr = DagRun(
+ run_id=run_id,
+ dag_id=dag_id,
+ execution_date=execution_date,
+ run_type=DagRunType.MANUAL,
+ state=dag_run_state,
+ )
+ session.add(dr)
+ ti = TaskInstance(task=tasks[i], **self.ti_init)
+ session.add(ti)
+ ti.dag_run = dr
+ ti.note = "placeholder-note"
+
+ for key, value in self.ti_extras.items():
+ setattr(ti, key, value)
+ tis.append(ti)
+
+ session.commit()
+ if with_ti_history:
+ for ti in tis:
+ ti.try_number = 1
+ session.merge(ti)
+ session.commit()
+ dag.clear()
+ for ti in tis:
+ ti.try_number = 2
+ ti.queue = "default_queue"
+ session.merge(ti)
+ session.commit()
+ return tis
+
+
+class TestGetTaskInstance(TestTaskInstanceEndpoint):
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
+ @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
+ @provide_session
+ def test_should_respond_200(self, username, session):
+ self.create_task_instances(session)
+ # Update ti and set operator to None to
+ # test that operator field is nullable.
+ # This prevents issue when users upgrade to 2.0+
+ # from 1.10.x
+ # https://github.com/apache/airflow/issues/14421
+ session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch")
+ session.commit()
+ response = self.client.get(
+ "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
+ environ_overrides={"REMOTE_USER": username},
+ )
+ assert response.status_code == 200
+
+
+class TestGetTaskInstances(TestTaskInstanceEndpoint):
+ @pytest.mark.parametrize(
+ "task_instances, user, expected_ti",
+ [
+ pytest.param(
+ {
+ "example_python_operator": 2,
+ "example_skip_dag": 1,
+ },
+ "test_read_only_one_dag",
+ 2,
+ ),
+ pytest.param(
+ {
+ "example_python_operator": 1,
+ "example_skip_dag": 2,
+ },
+ "test_read_only_one_dag",
+ 1,
+ ),
+ ],
+ )
+ def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session):
+ for dag_id in task_instances:
+ self.create_task_instances(
+ session,
+ task_instances=[
+ {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)}
+ for i in range(task_instances[dag_id])
+ ],
+ dag_id=dag_id,
+ )
+ response = self.client.get(
+ "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user}
+ )
+ assert response.status_code == 200
+ assert response.json["total_entries"] == expected_ti
+ assert len(response.json["task_instances"]) == expected_ti
+
+
+class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
+ @pytest.mark.parametrize(
+ "task_instances, update_extras, payload, expected_ti_count, username",
+ [
+ pytest.param(
+ [
+ {"pool": "test_pool_1"},
+ {"pool": "test_pool_2"},
+ {"pool": "test_pool_3"},
+ ],
+ True,
+ {"pool": ["test_pool_1", "test_pool_2"]},
+ 2,
+ "test_dag_read_only",
+ id="test pool filter",
+ ),
+ pytest.param(
+ [
+ {"state": State.RUNNING},
+ {"state": State.QUEUED},
+ {"state": State.SUCCESS},
+ {"state": State.NONE},
+ ],
+ False,
+ {"state": ["running", "queued", "none"]},
+ 3,
+ "test_task_read_only",
+ id="test state filter",
+ ),
+ pytest.param(
+ [
+ {"state": State.NONE},
+ {"state": State.NONE},
+ {"state": State.NONE},
+ {"state": State.NONE},
+ ],
+ False,
+ {},
+ 4,
+ "test_task_read_only",
+ id="test dag with null states",
+ ),
+ pytest.param(
+ [
+ {"end_date": DEFAULT_DATETIME_1},
+ {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
+ {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
+ ],
+ True,
+ {
+ "end_date_gte": DEFAULT_DATETIME_STR_1,
+ "end_date_lte": DEFAULT_DATETIME_STR_2,
+ },
+ 2,
+ "test_task_read_only",
+ id="test end date filter",
+ ),
+ pytest.param(
+ [
+ {"start_date": DEFAULT_DATETIME_1},
+ {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)},
+ {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)},
+ ],
+ True,
+ {
+ "start_date_gte": DEFAULT_DATETIME_STR_1,
+ "start_date_lte": DEFAULT_DATETIME_STR_2,
+ },
+ 2,
+ "test_dag_read_only",
+ id="test start date filter",
+ ),
+ ],
+ )
+ def test_should_respond_200(
+ self, task_instances, update_extras, payload, expected_ti_count, username, session
+ ):
+ self.create_task_instances(
+ session,
+ update_extras=update_extras,
+ task_instances=task_instances,
+ )
+ response = self.client.post(
+ "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+ environ_overrides={"REMOTE_USER": username},
+ json=payload,
+ )
+ assert response.status_code == 200, response.json
+ assert expected_ti_count == response.json["total_entries"]
+ assert expected_ti_count == len(response.json["task_instances"])
+
+ def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session):
+ self.create_task_instances(session=session)
+ self.create_task_instances(session=session, dag_id="example_skip_dag")
+ payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]}
+
+ response = self.client.post(
+ "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+ environ_overrides={"REMOTE_USER": "test_read_only_one_dag"},
+ json=payload,
+ )
+ assert response.status_code == 403
+ assert response.json == {
+ "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']",
+ "status": 403,
+ "title": "Forbidden",
+ "type": EXCEPTIONS_LINK_MAP[403],
+ }
+
+
+class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint):
+ @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
+ def test_should_raise_403_forbidden(self, username):
+ response = self.client.post(
+ "/api/v1/dags/example_python_operator/updateTaskInstancesState",
+ environ_overrides={"REMOTE_USER": username},
+ json={
+ "dry_run": True,
+ "task_id": "print_the_context",
+ "execution_date": DEFAULT_DATETIME_1.isoformat(),
+ "include_upstream": True,
+ "include_downstream": True,
+ "include_future": True,
+ "include_past": True,
+ "new_state": "failed",
+ },
+ )
+ assert response.status_code == 403
+
+
+class TestPatchTaskInstance(TestTaskInstanceEndpoint):
+ ENDPOINT_URL = (
+ "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+ )
+
+ @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
+ def test_should_raise_403_forbidden(self, username):
+ response = self.client.patch(
+ self.ENDPOINT_URL,
+ environ_overrides={"REMOTE_USER": username},
+ json={
+ "dry_run": True,
+ "new_state": "failed",
+ },
+ )
+ assert response.status_code == 403
+
+
+class TestGetTaskInstanceTry(TestTaskInstanceEndpoint):
+ def setup_method(self):
+ clear_db_runs()
+
+ def teardown_method(self):
+ clear_db_runs()
+
+ @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"])
+ @provide_session
+ def test_should_respond_200(self, username, session):
+ self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True)
+
+ response = self.client.get(
+ "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1",
+ environ_overrides={"REMOTE_USER": username},
+ )
+ assert response.status_code == 200
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py
index bc400c8a43fad..7f2c885bab52c 100644
--- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py
@@ -30,7 +30,12 @@
with ignore_provider_compatibility_error("2.9.0+", __file__):
from airflow.providers.fab.auth_manager.models import User
-from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_role, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user,
+ delete_role,
+ delete_user,
+)
+from tests.test_utils.api_connexion_utils import assert_401
from tests.test_utils.config import conf_vars
pytestmark = pytest.mark.db_test
@@ -43,7 +48,7 @@
def configured_app(minimal_app_for_auth_api):
app = minimal_app_for_auth_api
create_user(
- app, # type: ignore
+ app,
username="test",
role_name="Test",
permissions=[
@@ -53,12 +58,12 @@ def configured_app(minimal_app_for_auth_api):
(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER),
],
)
- create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
+ create_user(app, username="test_no_permissions", role_name="TestNoPermissions")
yield app
- delete_user(app, username="test") # type: ignore
- delete_user(app, username="test_no_permissions") # type: ignore
+ delete_user(app, username="test")
+ delete_user(app, username="test_no_permissions")
delete_role(app, name="TestNoPermissions")
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py
index 265407622e269..f3399de6a9775 100644
--- a/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_schema.py
@@ -18,6 +18,7 @@
import pytest
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_role, delete_role
from tests.test_utils.compat import ignore_provider_compatibility_error
with ignore_provider_compatibility_error("2.9.0+", __file__):
@@ -30,8 +31,6 @@
DEFAULT_TIME = "2021-01-09T13:59:56.336000+00:00"
-from tests.test_utils.api_connexion_utils import create_role, delete_role # noqa: E402
-
pytestmark = pytest.mark.db_test
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py
new file mode 100644
index 0000000000000..a8e71e1a82466
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models import Variable
+from airflow.security import permissions
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_variables
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+
+ create_user(
+ app,
+ username="test_read_only",
+ role_name="TestReadOnly",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE),
+ ],
+ )
+ create_user(
+ app,
+ username="test_delete_only",
+ role_name="TestDeleteOnly",
+ permissions=[
+ (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE),
+ ],
+ )
+
+ yield app
+
+ delete_user(app, username="test_read_only")
+ delete_user(app, username="test_delete_only")
+
+
+class TestVariableEndpoint:
+ @pytest.fixture(autouse=True)
+ def setup_method(self, configured_app) -> None:
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ clear_db_variables()
+
+ def teardown_method(self) -> None:
+ clear_db_variables()
+
+
+class TestGetVariable(TestVariableEndpoint):
+ @pytest.mark.parametrize(
+ "user, expected_status_code",
+ [
+ ("test_read_only", 200),
+ ("test_delete_only", 403),
+ ],
+ )
+ def test_read_variable(self, user, expected_status_code):
+ expected_value = '{"foo": 1}'
+ Variable.set("TEST_VARIABLE_KEY", expected_value)
+ response = self.client.get(
+ "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user}
+ )
+ assert response.status_code == expected_status_code
+ if expected_status_code == 200:
+ assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None}
diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py
new file mode 100644
index 0000000000000..01336f9957c6d
--- /dev/null
+++ b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py
@@ -0,0 +1,230 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import timedelta
+
+import pytest
+
+from airflow.models.dag import DagModel
+from airflow.models.dagrun import DagRun
+from airflow.models.taskinstance import TaskInstance
+from airflow.models.xcom import BaseXCom, XCom
+from airflow.operators.empty import EmptyOperator
+from airflow.security import permissions
+from airflow.utils.dates import parse_execution_date
+from airflow.utils.session import create_session
+from airflow.utils.types import DagRunType
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom
+
+pytestmark = [
+ pytest.mark.db_test,
+ pytest.mark.skip_if_database_isolation_mode,
+ pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
+]
+
+
+class CustomXCom(BaseXCom):
+ @classmethod
+ def deserialize_value(cls, xcom: XCom):
+ return f"real deserialized {super().deserialize_value(xcom)}"
+
+ def orm_deserialize_value(self):
+ return f"orm deserialized {super().orm_deserialize_value()}"
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_auth_api):
+ app = minimal_app_for_auth_api
+
+ create_user(
+ app,
+ username="test_granular_permissions",
+ role_name="TestGranularDag",
+ permissions=[
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM),
+ ],
+ )
+ app.appbuilder.sm.sync_perm_for_dag(
+ "test-dag-id-1",
+ access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]},
+ )
+
+ yield app
+
+ delete_user(app, username="test_granular_permissions")
+
+
+def _compare_xcom_collections(collection1: dict, collection_2: dict):
+ assert collection1.get("total_entries") == collection_2.get("total_entries")
+
+ def sort_key(record):
+ return (
+ record.get("dag_id"),
+ record.get("task_id"),
+ record.get("execution_date"),
+ record.get("map_index"),
+ record.get("key"),
+ )
+
+ assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted(
+ collection_2.get("xcom_entries", []), key=sort_key
+ )
+
+
+class TestXComEndpoint:
+ @staticmethod
+ def clean_db():
+ clear_db_dags()
+ clear_db_runs()
+ clear_db_xcom()
+
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self, configured_app) -> None:
+ """
+ Setup For XCom endpoint TC
+ """
+ self.app = configured_app
+ self.client = self.app.test_client() # type:ignore
+ # clear existing xcoms
+ self.clean_db()
+
+ def teardown_method(self) -> None:
+ """
+ Clear Hanging XComs
+ """
+ self.clean_db()
+
+
+class TestGetXComEntries(TestXComEndpoint):
+ def test_should_respond_200_with_tilde_and_granular_dag_access(self):
+ dag_id_1 = "test-dag-id-1"
+ task_id_1 = "test-task-id-1"
+ execution_date = "2005-04-02T00:00:00+00:00"
+ execution_date_parsed = parse_execution_date(execution_date)
+ dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+ self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1)
+
+ dag_id_2 = "test-dag-id-2"
+ task_id_2 = "test-task-id-2"
+ run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+ self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2)
+ self._create_invalid_xcom_entries(execution_date_parsed)
+ response = self.client.get(
+ "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries",
+ environ_overrides={"REMOTE_USER": "test_granular_permissions"},
+ )
+
+ assert 200 == response.status_code
+ response_data = response.json
+ for xcom_entry in response_data["xcom_entries"]:
+ xcom_entry["timestamp"] = "TIMESTAMP"
+ _compare_xcom_collections(
+ response_data,
+ {
+ "xcom_entries": [
+ {
+ "dag_id": dag_id_1,
+ "execution_date": execution_date,
+ "key": "test-xcom-key-1",
+ "task_id": task_id_1,
+ "timestamp": "TIMESTAMP",
+ "map_index": -1,
+ },
+ {
+ "dag_id": dag_id_1,
+ "execution_date": execution_date,
+ "key": "test-xcom-key-2",
+ "task_id": task_id_1,
+ "timestamp": "TIMESTAMP",
+ "map_index": -1,
+ },
+ ],
+ "total_entries": 2,
+ },
+ )
+
+ def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False):
+ with create_session() as session:
+ dag = DagModel(dag_id=dag_id)
+ session.add(dag)
+ dagrun = DagRun(
+ dag_id=dag_id,
+ run_id=run_id,
+ execution_date=execution_date,
+ start_date=execution_date,
+ run_type=DagRunType.MANUAL,
+ )
+ session.add(dagrun)
+ if mapped_ti:
+ for i in [0, 1]:
+ ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i)
+ ti.dag_id = dag_id
+ session.add(ti)
+ else:
+ ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
+ ti.dag_id = dag_id
+ session.add(ti)
+
+ for i in [1, 2]:
+ if mapped_ti:
+ key = "test-xcom-key"
+ map_index = i - 1
+ else:
+ key = f"test-xcom-key-{i}"
+ map_index = -1
+
+ XCom.set(
+ key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index
+ )
+
+ def _create_invalid_xcom_entries(self, execution_date):
+ """
+ Invalid XCom entries to test join query
+ """
+ with create_session() as session:
+ dag = DagModel(dag_id="invalid_dag")
+ session.add(dag)
+ dagrun = DagRun(
+ dag_id="invalid_dag",
+ run_id="invalid_run_id",
+ execution_date=execution_date + timedelta(days=1),
+ start_date=execution_date,
+ run_type=DagRunType.MANUAL,
+ )
+ session.add(dagrun)
+ dagrun1 = DagRun(
+ dag_id="invalid_dag",
+ run_id="not_this_run_id",
+ execution_date=execution_date,
+ start_date=execution_date,
+ run_type=DagRunType.MANUAL,
+ )
+ session.add(dagrun1)
+ ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id")
+ ti.dag_id = "invalid_dag"
+ session.add(ti)
+ for i in [1, 2]:
+ XCom.set(
+ key=f"invalid-xcom-key-{i}",
+ value="TEST",
+ run_id="not_this_run_id",
+ task_id="invalid_task",
+ dag_id="invalid_dag",
+ )
diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py
index 22c29dd229fa1..a8fbe5fbdaaae 100644
--- a/tests/providers/fab/auth_manager/conftest.py
+++ b/tests/providers/fab/auth_manager/conftest.py
@@ -30,7 +30,10 @@ def minimal_app_for_auth_api():
"init_appbuilder",
"init_api_auth",
"init_api_auth_provider",
+ "init_api_connexion",
"init_api_error_handlers",
+ "init_airflow_session_interface",
+ "init_appbuilder_views",
]
)
def factory():
@@ -39,7 +42,11 @@ def factory():
(
"api",
"auth_backends",
- ): "tests.test_utils.remote_user_api_auth_backend,airflow.api.auth.backend.session"
+ ): "tests.providers.fab.auth_manager.api_endpoints.remote_user_api_auth_backend,airflow.api.auth.backend.session",
+ (
+ "core",
+ "auth_manager",
+ ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager",
}
):
_app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore
@@ -58,3 +65,11 @@ def set_auth_role_public(request):
yield
app.config["AUTH_ROLE_PUBLIC"] = auto_role_public
+
+
+@pytest.fixture(scope="module")
+def dagbag():
+ from airflow.models import DagBag
+
+ DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
+ return DagBag(include_examples=True, read_dags_from_db=True)
diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py
index 156b5cf626271..bebb52c256fc8 100644
--- a/tests/providers/fab/auth_manager/test_security.py
+++ b/tests/providers/fab/auth_manager/test_security.py
@@ -49,7 +49,7 @@
from airflow.www.auth import get_access_denied_message
from airflow.www.extensions.init_auth_manager import get_auth_manager
from airflow.www.utils import CustomSQLAInterface
-from tests.test_utils.api_connexion_utils import (
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
create_user,
create_user_scope,
delete_role,
diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py
index 0b1073df287fa..f24d9b738343b 100644
--- a/tests/providers/fab/auth_manager/views/test_permissions.py
+++ b/tests/providers/fab/auth_manager/views/test_permissions.py
@@ -21,7 +21,7 @@
from airflow.security import permissions
from airflow.www import app as application
-from tests.test_utils.api_connexion_utils import create_user, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS
from tests.test_utils.www import client_with_login
diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py
index 156f07df41209..8de63ad5ba88a 100644
--- a/tests/providers/fab/auth_manager/views/test_roles_list.py
+++ b/tests/providers/fab/auth_manager/views/test_roles_list.py
@@ -21,7 +21,7 @@
from airflow.security import permissions
from airflow.www import app as application
-from tests.test_utils.api_connexion_utils import create_user, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS
from tests.test_utils.www import client_with_login
diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py
index 6660ab926d886..62b03a99e7c2c 100644
--- a/tests/providers/fab/auth_manager/views/test_user.py
+++ b/tests/providers/fab/auth_manager/views/test_user.py
@@ -21,7 +21,7 @@
from airflow.security import permissions
from airflow.www import app as application
-from tests.test_utils.api_connexion_utils import create_user, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS
from tests.test_utils.www import client_with_login
diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py
index 65937b6f83d33..8099f67948183 100644
--- a/tests/providers/fab/auth_manager/views/test_user_edit.py
+++ b/tests/providers/fab/auth_manager/views/test_user_edit.py
@@ -21,7 +21,7 @@
from airflow.security import permissions
from airflow.www import app as application
-from tests.test_utils.api_connexion_utils import create_user, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS
from tests.test_utils.www import client_with_login
diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py b/tests/providers/fab/auth_manager/views/test_user_stats.py
index 8cb260fcf1ec4..ae09cf92252c6 100644
--- a/tests/providers/fab/auth_manager/views/test_user_stats.py
+++ b/tests/providers/fab/auth_manager/views/test_user_stats.py
@@ -21,7 +21,7 @@
from airflow.security import permissions
from airflow.www import app as application
-from tests.test_utils.api_connexion_utils import create_user, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS
from tests.test_utils.www import client_with_login
diff --git a/tests/test_utils/api_connexion_utils.py b/tests/test_utils/api_connexion_utils.py
index af746b2d55468..48869ee48078d 100644
--- a/tests/test_utils/api_connexion_utils.py
+++ b/tests/test_utils/api_connexion_utils.py
@@ -17,6 +17,7 @@
from __future__ import annotations
from contextlib import contextmanager
+from typing import TYPE_CHECKING
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from tests.test_utils.compat import ignore_provider_compatibility_error
@@ -24,6 +25,9 @@
with ignore_provider_compatibility_error("2.9.0+", __file__):
from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES
+if TYPE_CHECKING:
+ from flask import Flask
+
@contextmanager
def create_test_client(app, user_name, role_name, permissions):
@@ -44,7 +48,11 @@ def create_user_scope(app, username, **kwargs):
It will create a user and provide it for the fixture via YIELD (generator)
then will tidy up once test is complete
"""
- test_user = create_user(app, username, **kwargs)
+ from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user as create_user_fab,
+ )
+
+ test_user = create_user_fab(app, username, **kwargs)
try:
yield test_user
@@ -52,27 +60,20 @@ def create_user_scope(app, username, **kwargs):
delete_user(app, username)
-def create_user(app, username, role_name=None, email=None, permissions=None):
- appbuilder = app.appbuilder
-
+def create_user(app: Flask, username: str, role_name: str | None):
# Removes user and role so each test has isolated test data.
delete_user(app, username)
- role = None
- if role_name:
- delete_role(app, role_name)
- role = create_role(app, role_name, permissions)
- else:
- role = []
-
- return appbuilder.sm.add_user(
- username=username,
- first_name=username,
- last_name=username,
- email=email or f"{username}@example.org",
- role=role,
- password=username,
+
+ users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", [])
+ users.append(
+ {
+ "username": username,
+ "role": role_name,
+ }
)
+ app.config["SIMPLE_AUTH_MANAGER_USERS"] = users
+
def create_role(app, name, permissions=None):
appbuilder = app.appbuilder
@@ -87,14 +88,6 @@ def create_role(app, name, permissions=None):
return role
-def set_user_single_role(app, user, role_name):
- role = create_role(app, role_name)
- if role not in user.roles:
- user.roles = [role]
- app.appbuilder.sm.update_user(user)
- user._perms = None
-
-
def delete_role(app, name):
if name not in EXISTING_ROLES:
if app.appbuilder.sm.find_role(name):
@@ -106,20 +99,11 @@ def delete_roles(app):
delete_role(app, role.name)
-def delete_user(app, username):
- appbuilder = app.appbuilder
- for user in appbuilder.sm.get_all_users():
- if user.username == username:
- _ = [
- delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES
- ]
- appbuilder.sm.del_register_user(user)
- break
-
-
-def delete_users(app):
- for user in app.appbuilder.sm.get_all_users():
- delete_user(app, user.username)
+def delete_user(app: Flask, username):
+ users = app.config.get("SIMPLE_AUTH_MANAGER_USERS", [])
+ users = [user for user in users if user["username"] != username]
+
+ app.config["SIMPLE_AUTH_MANAGER_USERS"] = users
def assert_401(response):
diff --git a/tests/test_utils/remote_user_api_auth_backend.py b/tests/test_utils/remote_user_api_auth_backend.py
index b7714e5192e6a..59df201e530e4 100644
--- a/tests/test_utils/remote_user_api_auth_backend.py
+++ b/tests/test_utils/remote_user_api_auth_backend.py
@@ -15,17 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Default authentication backend - everything is allowed"""
-
from __future__ import annotations
import logging
from functools import wraps
from typing import TYPE_CHECKING, Callable, TypeVar, cast
-from flask import Response, request
-from flask_login import login_user
+from flask import Response, request, session
+from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.utils.airflow_flask_app import get_airflow_app
if TYPE_CHECKING:
@@ -36,25 +34,15 @@
CLIENT_AUTH: tuple[str, str] | AuthBase | None = None
-def init_app(_):
- """Initializes authentication backend"""
+def init_app(_): ...
T = TypeVar("T", bound=Callable)
-def _lookup_user(user_email_or_username: str):
- security_manager = get_airflow_app().appbuilder.sm
- user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user(
- username=user_email_or_username
- )
- if not user:
- return None
-
- if not user.is_active:
- return None
-
- return user
+def _lookup_user(username: str):
+ users = get_airflow_app().config.get("SIMPLE_AUTH_MANAGER_USERS", [])
+ return next((user for user in users if user["username"] == username), None)
def requires_authentication(function: T):
@@ -69,13 +57,13 @@ def decorated(*args, **kwargs):
log.debug("Looking for user: %s", user_id)
- user = _lookup_user(user_id)
- if not user:
+ user_dict = _lookup_user(user_id)
+ if not user_dict:
return Response("Forbidden", 403)
- log.debug("Found user: %s", user)
+ log.debug("Found user: %s", user_dict)
+ session["user"] = SimpleAuthManagerUser(username=user_dict["username"], role=user_dict["role"])
- login_user(user, remember=False)
return function(*args, **kwargs)
return cast(T, decorated)
diff --git a/tests/www/views/test_views_custom_user_views.py b/tests/www/views/test_views_custom_user_views.py
index ae6d0132827c2..84947a8e5f36f 100644
--- a/tests/www/views/test_views_custom_user_views.py
+++ b/tests/www/views/test_views_custom_user_views.py
@@ -27,7 +27,10 @@
from airflow import settings
from airflow.security import permissions
from airflow.www import app as application
-from tests.test_utils.api_connexion_utils import create_user, delete_role
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user as create_user,
+ delete_role,
+)
from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login
pytestmark = pytest.mark.db_test
diff --git a/tests/www/views/test_views_dagrun.py b/tests/www/views/test_views_dagrun.py
index 39c17d086f379..d95955246ac78 100644
--- a/tests/www/views/test_views_dagrun.py
+++ b/tests/www/views/test_views_dagrun.py
@@ -24,7 +24,11 @@
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.www.views import DagRunModelView
-from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user,
+ delete_roles,
+ delete_user,
+)
from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login
from tests.www.views.test_views_tasks import _get_appbuilder_pk_string
diff --git a/tests/www/views/test_views_home.py b/tests/www/views/test_views_home.py
index 5393115041392..ddec0c0bcfed3 100644
--- a/tests/www/views/test_views_home.py
+++ b/tests/www/views/test_views_home.py
@@ -27,7 +27,7 @@
from airflow.utils.state import State
from airflow.www.utils import UIAlert
from airflow.www.views import FILTER_LASTRUN_COOKIE, FILTER_STATUS_COOKIE, FILTER_TAGS_COOKIE
-from tests.test_utils.api_connexion_utils import create_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user
from tests.test_utils.db import clear_db_dags, clear_db_import_errors, clear_db_serialized_dags
from tests.test_utils.permissions import _resource_name
from tests.test_utils.www import check_content_in_response, check_content_not_in_response, client_with_login
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index f5cc011fb6f0e..7b65051724c27 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -44,7 +44,11 @@
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
from airflow.www.views import TaskInstanceModelView, _safe_parse_datetime
-from tests.test_utils.api_connexion_utils import create_user, delete_roles, delete_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import (
+ create_user,
+ delete_roles,
+ delete_user,
+)
from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_runs, clear_db_xcom
diff --git a/tests/www/views/test_views_variable.py b/tests/www/views/test_views_variable.py
index a91a12ddc470b..b7fa8b37c52c8 100644
--- a/tests/www/views/test_views_variable.py
+++ b/tests/www/views/test_views_variable.py
@@ -25,7 +25,7 @@
from airflow.models import Variable
from airflow.security import permissions
from airflow.utils.session import create_session
-from tests.test_utils.api_connexion_utils import create_user
+from tests.providers.fab.auth_manager.api_endpoints.api_connexion_utils import create_user
from tests.test_utils.www import (
_check_last_log,
check_content_in_response,