diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index 6bfbfadf4180c..7554ee88450bb 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -33,7 +33,7 @@ overload, ) -from fastapi import Depends, HTTPException, Query +from fastapi import Depends, HTTPException, Query, status from pendulum.parsing.exceptions import ParserError from pydantic import AfterValidator, BaseModel, NonNegativeInt from sqlalchemy import Column, case, or_ @@ -337,9 +337,15 @@ def to_orm(self, select: Select) -> Select: @staticmethod def _convert_dag_run_states(states: Iterable[str] | None) -> list[DagRunState | None] | None: - if not states: - return None - return [None if s in ("none", None) else DagRunState(s) for s in states] + try: + if not states: + return None + return [None if s in ("none", None) else DagRunState(s) for s in states] + except ValueError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid value for state. Valid values are {', '.join(DagRunState)}", + ) def depends(self, state: list[str] = Query(default_factory=list)) -> DagRunStateFilter: states = self._convert_dag_run_states(state) @@ -360,7 +366,13 @@ def to_orm(self, select: Select) -> Select: return select.where(or_(*conditions)) def depends(self, state: list[str] = Query(default_factory=list)) -> TIStateFilter: - states = _convert_ti_states(state) + try: + states = _convert_ti_states(state) + except ValueError: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Invalid value for state. Valid values are {', '.join(TaskInstanceState)}", + ) return self.set_value(states) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 2ac22a02e31aa..b3ce267bf5222 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -425,8 +425,11 @@ def test_bad_filters(self, test_client): assert body["detail"] == expected_detail def test_invalid_state(self, test_client): - with pytest.raises(ValueError, match="'invalid' is not a valid DagRunState"): - test_client.get(f"/public/dags/{DAG1_ID}/dagRuns", params={"state": "invalid"}) + response = test_client.get(f"/public/dags/{DAG1_ID}/dagRuns", params={"state": ["invalid"]}) + assert response.status_code == 422 + assert ( + response.json()["detail"] == f"Invalid value for state. Valid values are {', '.join(DagRunState)}" + ) class TestPatchDagRun: diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index f8e75600171b3..b3b4d0ffb1b50 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -956,6 +956,14 @@ def test_not_found(self, test_client): assert response.status_code == 404 assert response.json() == {"detail": "DagRun with run_id: `invalid` was not found"} + def test_bad_state(self, test_client): + response = test_client.get("/public/dags/~/dagRuns/~/taskInstances", params={"state": "invalid"}) + assert response.status_code == 422 + assert ( + response.json()["detail"] + == f"Invalid value for state. Valid values are {', '.join(TaskInstanceState)}" + ) + @pytest.mark.xfail(reason="permissions not implemented yet.") def test_return_TI_only_from_readable_dags(self, test_client, session): task_instances = {