Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(typespec,agents-api): Update metadata_filter to have object type #586

Merged
merged 9 commits into from
Oct 5, 2024
54 changes: 54 additions & 0 deletions agents-api/agents_api/dependencies/query_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Any, Callable

from fastapi import Request


def convert_value(value: str) -> Any:
"""
Attempts to convert a string value to an int or float. Returns the original string if conversion fails.
"""
for convert in (int, float):
try:
return convert(value)
except ValueError:
continue
return value


def create_filter_extractor(
prefix: str = "filter",
) -> Callable[[Request], dict[str, Any]]:
"""
Creates a dependency function to extract filter parameters with a given prefix.

Args:
prefix (str): The prefix to identify filter parameters.

Returns:
Callable[[Request], dict[str, Any]]: The dependency function.
"""

# Add a dot to the prefix to allow for nested filters
prefix += "."

def extract_filters(request: Request) -> dict[str, Any]:
"""
Extracts query parameters that start with the specified prefix and returns them as a dictionary.

Args:
request (Request): The incoming HTTP request.

Returns:
dict[str, Any]: A dictionary containing the filter parameters.
"""

filters: dict[str, Any] = {}

for key, value in request.query_params.items():
if key.startswith(prefix):
filter_key = key[len(prefix) :]
filters[filter_key] = convert_value(value)

return filters

return extract_filters
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def list_docs(
created_at,
metadata,
}},
snippets[id, snippet_data]
snippets[id, snippet_data],
{metadata_filter_str}

:limit $limit
:offset $offset
Expand All @@ -112,6 +113,5 @@ def list_docs(
"owner_type": owner_type,
"limit": limit,
"offset": offset,
"metadata_filter": metadata_filter_str,
},
)
22 changes: 9 additions & 13 deletions agents-api/agents_api/routers/agents/list_agents.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,36 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi import Depends

from ...autogen.openapi_model import Agent, ListResponse
from ...dependencies.developer_id import get_developer_id
from ...dependencies.query_filter import create_filter_extractor
from ...models.agent.list_agents import list_agents as list_agents_query
from .router import router


@router.get("/agents", tags=["agents"])
async def list_agents(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
# Expects the dot notation of object in query params
# Example:
# > ?metadata_filter.name=John&metadata_filter.age=30
metadata_filter: Annotated[
dict, Depends(create_filter_extractor("metadata_filter"))
],
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
metadata_filter: str = "{}",
) -> ListResponse[Agent]:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

agents = list_agents_query(
developer_id=x_developer_id,
limit=limit,
offset=offset,
sort_by=sort_by,
direction=direction,
metadata_filter=metadata_filter,
metadata_filter=metadata_filter or {},
)

return ListResponse[Agent](items=agents)
33 changes: 10 additions & 23 deletions agents-api/agents_api/routers/docs/list_docs.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,27 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi import Depends

from ...autogen.openapi_model import Doc, ListResponse
from ...dependencies.developer_id import get_developer_id
from ...dependencies.query_filter import create_filter_extractor
from ...models.docs.list_docs import list_docs as list_docs_query
from .router import router


@router.get("/users/{user_id}/docs", tags=["docs"])
async def list_user_docs(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
metadata_filter: Annotated[
dict, Depends(create_filter_extractor("metadata_filter"))
],
user_id: UUID,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
metadata_filter: str = "{}",
) -> ListResponse[Doc]:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

docs = list_docs_query(
developer_id=x_developer_id,
owner_type="user",
Expand All @@ -37,7 +30,7 @@ async def list_user_docs(
offset=offset,
sort_by=sort_by,
direction=direction,
metadata_filter=metadata_filter,
metadata_filter=metadata_filter or {},
)

return ListResponse[Doc](items=docs)
Expand All @@ -46,21 +39,15 @@ async def list_user_docs(
@router.get("/agents/{agent_id}/docs", tags=["docs"])
async def list_agent_docs(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
metadata_filter: Annotated[
dict, Depends(create_filter_extractor("metadata_filter"))
],
agent_id: UUID,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
metadata_filter: str = "{}",
) -> ListResponse[Doc]:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

docs = list_docs_query(
developer_id=x_developer_id,
owner_type="agent",
Expand All @@ -69,7 +56,7 @@ async def list_agent_docs(
offset=offset,
sort_by=sort_by,
direction=direction,
metadata_filter=metadata_filter,
metadata_filter=metadata_filter or {},
)

return ListResponse[Doc](items=docs)
19 changes: 6 additions & 13 deletions agents-api/agents_api/routers/sessions/list_sessions.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,33 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from fastapi import Depends

from ...autogen.openapi_model import ListResponse, Session
from ...dependencies.developer_id import get_developer_id
from ...dependencies.query_filter import create_filter_extractor
from ...models.session.list_sessions import list_sessions as list_sessions_query
from .router import router


@router.get("/sessions", tags=["sessions"])
async def list_sessions(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
metadata_filter: Annotated[
dict, Depends(create_filter_extractor("metadata_filter"))
] = {},
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
metadata_filter: str = "{}",
) -> ListResponse[Session]:
try:
metadata_filter = json.loads(metadata_filter)
except JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="metadata_filter is not a valid JSON",
)

sessions = list_sessions_query(
developer_id=x_developer_id,
limit=limit,
offset=offset,
sort_by=sort_by,
direction=direction,
metadata_filter=metadata_filter,
metadata_filter=metadata_filter or {},
)

return ListResponse[Session](items=sessions)
2 changes: 2 additions & 0 deletions typespec/common/scalars.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ namespace Common;
@format("uuid")
scalar uuid extends string;

alias concreteType = numeric | string | boolean | null;

/**
* For Unicode character safety
* See: https://unicode.org/reports/tr31/
Expand Down
5 changes: 3 additions & 2 deletions typespec/common/types.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace Common;
//

alias Metadata = Record<unknown>;
alias MetadataFilter = Record<concreteType>;

model ResourceCreatedResponse {
@doc("ID of created resource")
Expand Down Expand Up @@ -48,6 +49,6 @@ model PaginationOptions {
/** Sort direction */
@query direction: sortDirection = "asc",

/** JSON string of object that should be used to filter objects by metadata */
@query metadata_filter: string = "{}",
/** Object to filter results by metadata */
@query metadata_filter: MetadataFilter,
}
11 changes: 8 additions & 3 deletions typespec/tsp-output/@typespec/openapi3/openapi-0.4.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1317,10 +1317,15 @@ components:
name: metadata_filter
in: query
required: true
description: JSON string of object that should be used to filter objects by metadata
description: Object to filter results by metadata
schema:
type: string
default: '{}'
type: object
additionalProperties:
anyOf:
- type: number
- type: string
- type: boolean
nullable: true
explode: false
Common.PaginationOptions.offset:
name: offset
Expand Down
11 changes: 8 additions & 3 deletions typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1317,10 +1317,15 @@ components:
name: metadata_filter
in: query
required: true
description: JSON string of object that should be used to filter objects by metadata
description: Object to filter results by metadata
schema:
type: string
default: '{}'
type: object
additionalProperties:
anyOf:
- type: number
- type: string
- type: boolean
nullable: true
explode: false
Common.PaginationOptions.offset:
name: offset
Expand Down