Skip to content

Commit

Permalink
Migrate public endpoint Get Tasks to FastAPI (apache#43980)
Browse files Browse the repository at this point in the history
* Migrate public endpoint Get Tasks to FastAPI

* Re-run static checks

* Add migration marker

* Remove 401 and 403, which are now added by default

* Re-run static checks
  • Loading branch information
omkar-foss authored Nov 15, 2024
1 parent 6e59137 commit f66459b
Show file tree
Hide file tree
Showing 13 changed files with 597 additions and 44 deletions.
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/task_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse:
return task_schema.dump(task)


@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.TASK)
def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse:
"""Get tasks for DAG."""
Expand Down
28 changes: 0 additions & 28 deletions airflow/api_fastapi/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,11 @@
# under the License.
from __future__ import annotations

import inspect
from datetime import timedelta
from typing import Annotated

from pydantic import AfterValidator, AliasGenerator, AwareDatetime, BaseModel, BeforeValidator, ConfigDict

from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils import timezone

UtcDateTime = Annotated[AwareDatetime, AfterValidator(lambda d: d.astimezone(timezone.utc))]
Expand Down Expand Up @@ -59,28 +56,3 @@ class TimeDelta(BaseModel):


TimeDeltaWithValidation = Annotated[TimeDelta, BeforeValidator(_validate_timedelta_field)]


def get_class_ref(obj) -> dict[str, str | None]:
"""Return the class_ref dict for obj."""
is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator))

module_path = None
if is_mapped_or_serialized:
module_path = obj._task_module
else:
module_type = inspect.getmodule(obj)
module_path = module_type.__name__ if module_type else None

class_name = None
if is_mapped_or_serialized:
class_name = obj._task_type
elif obj.__class__ is type:
class_name = obj.__name__
else:
class_name = type(obj).__name__

return {
"module_path": module_path,
"class_name": class_name,
}
47 changes: 45 additions & 2 deletions airflow/api_fastapi/core_api/datamodels/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,44 @@

from __future__ import annotations

import inspect
from collections import abc
from datetime import datetime
from typing import Any

from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator, model_validator

from airflow.api_fastapi.common.types import TimeDeltaWithValidation
from airflow.serialization.serialized_objects import encode_priority_weight_strategy
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator, encode_priority_weight_strategy
from airflow.task.priority_strategy import PriorityWeightStrategy


def _get_class_ref(obj) -> dict[str, str | None]:
"""Return the class_ref dict for obj."""
is_mapped_or_serialized = isinstance(obj, (MappedOperator, SerializedBaseOperator))

module_path = None
if is_mapped_or_serialized:
module_path = obj._task_module
else:
module_type = inspect.getmodule(obj)
module_path = module_type.__name__ if module_type else None

class_name = None
if is_mapped_or_serialized:
class_name = obj._task_type
elif obj.__class__ is type:
class_name = obj.__name__
else:
class_name = type(obj).__name__

return {
"module_path": module_path,
"class_name": class_name,
}


class TaskResponse(BaseModel):
"""Task serializer for responses."""

Expand Down Expand Up @@ -57,6 +85,14 @@ class TaskResponse(BaseModel):
class_ref: dict | None
is_mapped: bool | None

@model_validator(mode="before")
@classmethod
def validate_model(cls, task: Any) -> Any:
task.__dict__.update(
{"class_ref": _get_class_ref(task), "is_mapped": isinstance(task, MappedOperator)}
)
return task

@field_validator("weight_rule", mode="before")
@classmethod
def validate_weight_rule(cls, wr: str | PriorityWeightStrategy | None) -> str | None:
Expand All @@ -81,3 +117,10 @@ def get_params(cls, params: abc.MutableMapping | None) -> dict | None:
def extra_links(self) -> list[str]:
"""Extract and return extra_links."""
return getattr(self, "operator_extra_links", [])


class TaskCollectionResponse(BaseModel):
"""Task collection serializer for responses."""

tasks: list[TaskResponse]
total_entries: int
74 changes: 74 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3204,6 +3204,64 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/tasks/:
get:
tags:
- Task
summary: Get Tasks
description: Get tasks for DAG.
operationId: get_tasks
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: order_by
in: query
required: false
schema:
type: string
default: task_id
title: Order By
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/TaskCollectionResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/tasks/{task_id}:
get:
tags:
Expand Down Expand Up @@ -5546,6 +5604,22 @@ components:
- latest_scheduler_heartbeat
title: SchedulerInfoSchema
description: Schema for Scheduler info.
TaskCollectionResponse:
properties:
tasks:
items:
$ref: '#/components/schemas/TaskResponse'
type: array
title: Tasks
total_entries:
type: integer
title: Total Entries
type: object
required:
- tasks
- total_entries
title: TaskCollectionResponse
description: Task collection serializer for responses.
TaskDependencyCollectionResponse:
properties:
dependencies:
Expand Down
39 changes: 31 additions & 8 deletions airflow/api_fastapi/core_api/routes/public/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,52 @@

from __future__ import annotations

from operator import attrgetter

from fastapi import HTTPException, Request, status

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.common.types import get_class_ref
from airflow.api_fastapi.core_api.datamodels.tasks import TaskResponse
from airflow.api_fastapi.core_api.datamodels.tasks import TaskCollectionResponse, TaskResponse
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.exceptions import TaskNotFound
from airflow.models import DAG
from airflow.models.mappedoperator import MappedOperator

tasks_router = AirflowRouter(tags=["Task"], prefix="/dags/{dag_id}/tasks")


@tasks_router.get(
"/",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
status.HTTP_404_NOT_FOUND,
]
),
)
def get_tasks(
dag_id: str,
request: Request,
order_by: str = "task_id",
) -> TaskCollectionResponse:
"""Get tasks for DAG."""
dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
try:
tasks = sorted(dag.tasks, key=attrgetter(order_by.lstrip("-")), reverse=(order_by[0:1] == "-"))
except AttributeError as err:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(err))
return TaskCollectionResponse(
tasks=[TaskResponse.model_validate(task, from_attributes=True) for task in tasks],
total_entries=(len(tasks)),
)


@tasks_router.get(
"/{task_id}",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
status.HTTP_401_UNAUTHORIZED,
status.HTTP_403_FORBIDDEN,
status.HTTP_404_NOT_FOUND,
]
),
Expand All @@ -48,9 +74,6 @@ def get_task(dag_id: str, task_id, request: Request) -> TaskResponse:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")
try:
task = dag.get_task(task_id=task_id)
task.__dict__.update(
{"class_ref": get_class_ref(task), "is_mapped": isinstance(task, MappedOperator)}
)
except TaskNotFound:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task with id {task_id} was not found")
return TaskResponse.model_validate(task, from_attributes=True)
18 changes: 18 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,24 @@ export const UseTaskInstanceServiceGetTaskInstancesKeyFn = (
},
]),
];
export type TaskServiceGetTasksDefaultResponse = Awaited<
ReturnType<typeof TaskService.getTasks>
>;
export type TaskServiceGetTasksQueryResult<
TData = TaskServiceGetTasksDefaultResponse,
TError = unknown,
> = UseQueryResult<TData, TError>;
export const useTaskServiceGetTasksKey = "TaskServiceGetTasks";
export const UseTaskServiceGetTasksKeyFn = (
{
dagId,
orderBy,
}: {
dagId: string;
orderBy?: string;
},
queryKey?: Array<unknown>,
) => [useTaskServiceGetTasksKey, ...(queryKey ?? [{ dagId, orderBy }])];
export type TaskServiceGetTaskDefaultResponse = Awaited<
ReturnType<typeof TaskService.getTask>
>;
Expand Down
23 changes: 23 additions & 0 deletions airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,29 @@ export const prefetchUseTaskInstanceServiceGetTaskInstances = (
updatedAtLte,
}),
});
/**
* Get Tasks
* Get tasks for DAG.
* @param data The data for the request.
* @param data.dagId
* @param data.orderBy
* @returns TaskCollectionResponse Successful Response
* @throws ApiError
*/
export const prefetchUseTaskServiceGetTasks = (
queryClient: QueryClient,
{
dagId,
orderBy,
}: {
dagId: string;
orderBy?: string;
},
) =>
queryClient.prefetchQuery({
queryKey: Common.UseTaskServiceGetTasksKeyFn({ dagId, orderBy }),
queryFn: () => TaskService.getTasks({ dagId, orderBy }),
});
/**
* Get Task
* Get simplified representation of a task.
Expand Down
29 changes: 29 additions & 0 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,35 @@ export const useTaskInstanceServiceGetTaskInstances = <
}) as TData,
...options,
});
/**
* Get Tasks
* Get tasks for DAG.
* @param data The data for the request.
* @param data.dagId
* @param data.orderBy
* @returns TaskCollectionResponse Successful Response
* @throws ApiError
*/
export const useTaskServiceGetTasks = <
TData = Common.TaskServiceGetTasksDefaultResponse,
TError = unknown,
TQueryKey extends Array<unknown> = unknown[],
>(
{
dagId,
orderBy,
}: {
dagId: string;
orderBy?: string;
},
queryKey?: TQueryKey,
options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
) =>
useQuery<TData, TError>({
queryKey: Common.UseTaskServiceGetTasksKeyFn({ dagId, orderBy }, queryKey),
queryFn: () => TaskService.getTasks({ dagId, orderBy }) as TData,
...options,
});
/**
* Get Task
* Get simplified representation of a task.
Expand Down
Loading

0 comments on commit f66459b

Please sign in to comment.