Skip to content

Commit

Permalink
AIP-84 Get Batch Task Instances (apache#44051)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierrejeambrun authored Nov 15, 2024
1 parent eed59f1 commit aa72f0f
Show file tree
Hide file tree
Showing 11 changed files with 981 additions and 26 deletions.
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def get_task_instances(
)


@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE)
@provide_session
def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
Expand Down
82 changes: 59 additions & 23 deletions airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def depends(self, *args: Any, **kwargs: Any) -> Self:
pass


class _LimitFilter(BaseParam[int]):
class LimitFilter(BaseParam[int]):
"""Filter on the limit."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -75,19 +75,19 @@ def to_orm(self, select: Select) -> Select:

return select.limit(self.value)

def depends(self, limit: int = 100) -> _LimitFilter:
def depends(self, limit: int = 100) -> LimitFilter:
return self.set_value(limit)


class _OffsetFilter(BaseParam[int]):
class OffsetFilter(BaseParam[int]):
"""Filter on offset."""

def to_orm(self, select: Select) -> Select:
if self.value is None and self.skip_none:
return select
return select.offset(self.value)

def depends(self, offset: int = 0) -> _OffsetFilter:
def depends(self, offset: int = 0) -> OffsetFilter:
return self.set_value(offset)


Expand Down Expand Up @@ -115,18 +115,54 @@ def depends(self, only_active: bool = True) -> _OnlyActiveFilter:
return self.set_value(only_active)


class _DagIdsFilter(BaseParam[list[str]]):
"""Filter on multi-valued dag_ids param for DagRun."""
class DagIdsFilter(BaseParam[list[str]]):
"""Filter on dag ids."""

def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
super().__init__(value, skip_none)
self.model = model

def to_orm(self, select: Select) -> Select:
if self.value and self.skip_none:
return select.where(DagRun.dag_id.in_(self.value))
return select.where(self.model.dag_id.in_(self.value))
return select

def depends(self, dag_ids: list[str] = Query(None)) -> _DagIdsFilter:
def depends(self, dag_ids: list[str] = Query(None)) -> DagIdsFilter:
return self.set_value(dag_ids)


class DagRunIdsFilter(BaseParam[list[str]]):
    """Restrict a query to rows whose ``run_id`` is one of the given dag run ids."""

    def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
        """
        Bind the filter to the ORM model it will be applied to.

        :param model: ORM model whose ``run_id`` attribute is used in the ``IN`` clause.
        :param value: Initial list of dag run ids, forwarded to ``BaseParam``.
        :param skip_none: Forwarded to ``BaseParam``; also gates ``to_orm`` below.
        """
        super().__init__(value, skip_none)
        self.model = model

    def to_orm(self, select: Select) -> Select:
        """
        Add ``run_id IN (...)`` to ``select`` when the filter is active.

        The statement is returned untouched unless a non-empty list is set
        *and* ``skip_none`` is truthy.
        """
        # NOTE(review): with ``skip_none=False`` the clause is never added,
        # even when values are present -- confirm this is intended.
        if not (self.value and self.skip_none):
            return select

        run_id_column = self.model.run_id
        return select.where(run_id_column.in_(self.value))

    def depends(self, dag_run_ids: list[str] = Query(None)) -> DagRunIdsFilter:
        """Record the ``dag_run_ids`` query values via ``set_value`` and return its result."""
        return self.set_value(dag_run_ids)


class TaskIdsFilter(BaseParam[list[str]]):
    """Restrict a query to rows whose ``task_id`` is one of the given task ids."""

    def __init__(self, model: Base, value: list[str] | None = None, skip_none: bool = True) -> None:
        """
        Bind the filter to the ORM model it will be applied to.

        :param model: ORM model whose ``task_id`` attribute is used in the ``IN`` clause.
        :param value: Initial list of task ids, forwarded to ``BaseParam``.
        :param skip_none: Forwarded to ``BaseParam``; also gates ``to_orm`` below.
        """
        super().__init__(value, skip_none)
        self.model = model

    def to_orm(self, select: Select) -> Select:
        """
        Add ``task_id IN (...)`` to ``select`` when the filter is active.

        The statement is returned untouched unless a non-empty list is set
        *and* ``skip_none`` is truthy.
        """
        # NOTE(review): with ``skip_none=False`` the clause is never added,
        # even when values are present -- confirm this is intended.
        if not (self.value and self.skip_none):
            return select

        task_id_column = self.model.task_id
        return select.where(task_id_column.in_(self.value))

    def depends(self, task_ids: list[str] = Query(None)) -> TaskIdsFilter:
        """Record the ``task_ids`` query values via ``set_value`` and return its result."""
        return self.set_value(task_ids)


class _SearchParam(BaseParam[str]):
"""Search on attribute."""

Expand Down Expand Up @@ -273,7 +309,7 @@ def depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFil
return self.set_value(owners)


class _TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
class TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
"""Filter on task instance state."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -286,12 +322,12 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.state == state for state in self.value]
return select.where(or_(*conditions))

def depends(self, state: list[str] = Query(default_factory=list)) -> _TIStateFilter:
def depends(self, state: list[str] = Query(default_factory=list)) -> TIStateFilter:
states = _convert_ti_states(state)
return self.set_value(states)


class _TIPoolFilter(BaseParam[List[str]]):
class TIPoolFilter(BaseParam[List[str]]):
"""Filter on task instance pool."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -304,11 +340,11 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.pool == pool for pool in self.value]
return select.where(or_(*conditions))

def depends(self, pool: list[str] = Query(default_factory=list)) -> _TIPoolFilter:
def depends(self, pool: list[str] = Query(default_factory=list)) -> TIPoolFilter:
return self.set_value(pool)


class _TIQueueFilter(BaseParam[List[str]]):
class TIQueueFilter(BaseParam[List[str]]):
"""Filter on task instance queue."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -321,11 +357,11 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.queue == queue for queue in self.value]
return select.where(or_(*conditions))

def depends(self, queue: list[str] = Query(default_factory=list)) -> _TIQueueFilter:
def depends(self, queue: list[str] = Query(default_factory=list)) -> TIQueueFilter:
return self.set_value(queue)


class _TIExecutorFilter(BaseParam[List[str]]):
class TIExecutorFilter(BaseParam[List[str]]):
"""Filter on task instance executor."""

def to_orm(self, select: Select) -> Select:
Expand All @@ -338,7 +374,7 @@ def to_orm(self, select: Select) -> Select:
conditions = [TaskInstance.executor == executor for executor in self.value]
return select.where(or_(*conditions))

def depends(self, executor: list[str] = Query(default_factory=list)) -> _TIExecutorFilter:
def depends(self, executor: list[str] = Query(default_factory=list)) -> TIExecutorFilter:
return self.set_value(executor)


Expand Down Expand Up @@ -581,8 +617,8 @@ def depends_float(
DateTimeQuery = Annotated[str, AfterValidator(_safe_parse_datetime)]

# DAG
QueryLimit = Annotated[_LimitFilter, Depends(_LimitFilter().depends)]
QueryOffset = Annotated[_OffsetFilter, Depends(_OffsetFilter().depends)]
QueryLimit = Annotated[LimitFilter, Depends(LimitFilter().depends)]
QueryOffset = Annotated[OffsetFilter, Depends(OffsetFilter().depends)]
QueryPausedFilter = Annotated[_PausedFilter, Depends(_PausedFilter().depends)]
QueryOnlyActiveFilter = Annotated[_OnlyActiveFilter, Depends(_OnlyActiveFilter().depends)]
QueryDagIdPatternSearch = Annotated[_DagIdPatternSearch, Depends(_DagIdPatternSearch().depends)]
Expand All @@ -597,7 +633,7 @@ def depends_float(

# DagRun
QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)]
QueryDagIdsFilter = Annotated[_DagIdsFilter, Depends(_DagIdsFilter().depends)]
QueryDagIdsFilter = Annotated[DagIdsFilter, Depends(DagIdsFilter(DagRun).depends)]

# DAGWarning
QueryDagIdInDagWarningFilter = Annotated[_DagIdFilter, Depends(_DagIdFilter(DagWarning.dag_id).depends)]
Expand All @@ -607,10 +643,10 @@ def depends_float(
QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch, Depends(_DagTagNamePatternSearch().depends)]

# TI
QueryTIStateFilter = Annotated[_TIStateFilter, Depends(_TIStateFilter().depends)]
QueryTIPoolFilter = Annotated[_TIPoolFilter, Depends(_TIPoolFilter().depends)]
QueryTIQueueFilter = Annotated[_TIQueueFilter, Depends(_TIQueueFilter().depends)]
QueryTIExecutorFilter = Annotated[_TIExecutorFilter, Depends(_TIExecutorFilter().depends)]
QueryTIStateFilter = Annotated[TIStateFilter, Depends(TIStateFilter().depends)]
QueryTIPoolFilter = Annotated[TIPoolFilter, Depends(TIPoolFilter().depends)]
QueryTIQueueFilter = Annotated[TIQueueFilter, Depends(TIQueueFilter().depends)]
QueryTIExecutorFilter = Annotated[TIExecutorFilter, Depends(TIExecutorFilter().depends)]

# Assets
QueryUriPatternSearch = Annotated[_UriPatternSearch, Depends(_UriPatternSearch().depends)]
Expand Down
33 changes: 32 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@
from datetime import datetime
from typing import Annotated

from pydantic import AliasPath, BaseModel, BeforeValidator, ConfigDict, Field
from pydantic import (
AliasPath,
AwareDatetime,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
NonNegativeInt,
)

from airflow.api_fastapi.core_api.datamodels.job import JobResponse
from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse
Expand Down Expand Up @@ -83,3 +91,26 @@ class TaskDependencyCollectionResponse(BaseModel):
"""Task scheduling dependencies collection serializer for responses."""

dependencies: list[TaskDependencyResponse]


class TaskInstancesBatchBody(BaseModel):
    """Task Instance body for get batch."""

    # Every filter field below is optional and defaults to ``None``.
    # NOTE(review): how ``None`` is interpreted (presumably "do not filter on
    # this attribute") is decided by the consuming endpoint, which is not
    # visible here -- confirm against the route implementation.

    # Identifier filters: lists of ids to match.
    dag_ids: list[str] | None = None
    dag_run_ids: list[str] | None = None
    task_ids: list[str] | None = None

    # Task instance states to match; the list itself may contain ``None``.
    state: list[TaskInstanceState | None] | None = None

    # Date range bounds. ``AwareDatetime`` makes pydantic reject values that
    # carry no timezone information.
    # NOTE(review): ``_gte`` / ``_lte`` presumably mean inclusive lower / upper
    # bounds -- confirm against the endpoint that consumes this body.
    logical_date_gte: AwareDatetime | None = None
    logical_date_lte: AwareDatetime | None = None
    start_date_gte: AwareDatetime | None = None
    start_date_lte: AwareDatetime | None = None
    end_date_gte: AwareDatetime | None = None
    end_date_lte: AwareDatetime | None = None

    # Duration range bounds (unit not stated here -- TODO confirm).
    duration_gte: float | None = None
    duration_lte: float | None = None

    # Execution-environment filters: lists of pool / queue / executor names.
    pool: list[str] | None = None
    queue: list[str] | None = None
    executor: list[str] | None = None

    # Pagination: both must be >= 0 (``NonNegativeInt``); defaults are
    # offset 0 and limit 100. No upper bound is enforced by this model.
    page_offset: NonNegativeInt = 0
    page_limit: NonNegativeInt = 100

    # Optional ordering key; accepted values are not validated by this model.
    order_by: str | None = None
Loading

0 comments on commit aa72f0f

Please sign in to comment.