diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index a82f712f4e191..9e4ea49e08890 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -56,7 +56,7 @@ ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.exceptions import TaskNotFound -from airflow.models import Base +from airflow.models import Base, DagRun from airflow.models.taskinstance import TaskInstance as TI from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH from airflow.ti_deps.dep_context import DepContext @@ -303,8 +303,18 @@ def get_task_instances( base_query = select(TI).join(TI.dag_run) if dag_id != "~": + dag = request.app.state.dag_bag.get_dag(dag_id) + if not dag: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: `{dag_id}` was not found") base_query = base_query.where(TI.dag_id == dag_id) + if dag_run_id != "~": + dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id)) + if not dag_run: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"DagRun with run_id: `{dag_run_id}` was not found", + ) base_query = base_query.where(TI.run_id == dag_run_id) task_instance_select, total_entries = paginated_select( diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index d5e01042215e3..56e2ee5e0e178 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -947,6 +947,15 @@ def test_should_respond_200( assert response.json()["total_entries"] == expected_ti assert len(response.json()["task_instances"]) == expected_ti + def test_not_found(self, test_client): + response = test_client.get("/public/dags/invalid/dagRuns/~/taskInstances") + assert response.status_code == 404 + assert response.json() == {"detail": "DAG with dag_id: `invalid` was not found"} + + response = test_client.get("/public/dags/~/dagRuns/invalid/taskInstances") + assert response.status_code == 404 + assert response.json() == {"detail": "DagRun with run_id: `invalid` was not found"} + @pytest.mark.xfail(reason="permissions not implemented yet.") def test_return_TI_only_from_readable_dags(self, test_client, session): task_instances = {