Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Make chat route tests pass #454

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions agents-api/agents_api/autogen/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .Common import LogitBias
from .Docs import DocReference
from .Entries import ChatMLMessage, InputChatMLMessage
from .Entries import InputChatMLMessage
from .Tools import FunctionTool, NamedToolChoice


Expand All @@ -23,7 +23,7 @@ class BaseChatOutput(BaseModel):
"""
The reason the model stopped generating tokens
"""
logprobs: Annotated[LogProbResponse | None, Field(...)]
logprobs: LogProbResponse | None = None
"""
The log probabilities of tokens
"""
Expand All @@ -33,7 +33,7 @@ class BaseChatResponse(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
usage: Annotated[CompetionUsage | None, Field(...)]
usage: CompetionUsage | None = None
"""
Usage statistics for the completion request
"""
Expand Down Expand Up @@ -61,7 +61,7 @@ class BaseTokenLogProb(BaseModel):
"""
The log probability of the token
"""
bytes: Annotated[list[int] | None, Field(...)]
bytes: list[int] | None = None


class ChatInputData(BaseModel):
Expand Down Expand Up @@ -90,7 +90,7 @@ class ChatOutputChunk(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
delta: ChatMLMessage
delta: InputChatMLMessage
"""
The message generated by the model
"""
Expand Down Expand Up @@ -166,7 +166,7 @@ class MultipleChatOutput(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
messages: list[ChatMLMessage]
messages: list[InputChatMLMessage]


class OpenAISettings(BaseModel):
Expand Down Expand Up @@ -199,7 +199,7 @@ class SingleChatOutput(BaseChatOutput):
model_config = ConfigDict(
populate_by_name=True,
)
message: ChatMLMessage
message: InputChatMLMessage


class TokenLogProb(BaseTokenLogProb):
Expand Down
41 changes: 2 additions & 39 deletions agents-api/agents_api/autogen/Entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class BaseEntry(BaseModel):
)
role: Literal[
"user",
"agent",
"assistant",
"system",
"function",
"function_response",
Expand Down Expand Up @@ -67,43 +67,6 @@ class ChatMLImageContentPart(BaseModel):
"""


class ChatMLMessage(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
role: Literal[
"user",
"agent",
"system",
"function",
"function_response",
"function_call",
"auto",
]
"""
The role of the message
"""
content: str | list[str] | list[ChatMLTextContentPart | ChatMLImageContentPart]
"""
The content parts of the message
"""
name: str | None = None
"""
Name
"""
tool_calls: Annotated[
list[ChosenToolCall], Field([], json_schema_extra={"readOnly": True})
]
"""
Tool calls generated by the model.
"""
created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
"""
When this resource was created as UTC date-time
"""
id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]


class ChatMLTextContentPart(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand Down Expand Up @@ -159,7 +122,7 @@ class InputChatMLMessage(BaseModel):
)
role: Literal[
"user",
"agent",
"assistant",
"system",
"function",
"function_response",
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_active_agent(self) -> Agent:
"""
Get the active agent from the session data.
"""
requested_agent: UUID | None = self.settings.agent
requested_agent: UUID | None = self.settings and self.settings.agent

if requested_agent:
assert requested_agent in [agent.id for agent in self.agents], (
Expand All @@ -67,15 +67,15 @@ def get_active_agent(self) -> Agent:
return self.agents[0]

def merge_settings(self, chat_input: ChatInput) -> ChatSettings:
request_settings = ChatSettings.model_validate(chat_input)
request_settings = chat_input.model_dump(exclude_unset=True)
active_agent = self.get_active_agent()
default_settings = active_agent.default_settings

self.settings = settings = ChatSettings(
**{
"model": active_agent.model,
**default_settings.model_dump(),
**request_settings.model_dump(exclude_unset=True),
**request_settings,
}
)

Expand Down
39 changes: 34 additions & 5 deletions agents-api/agents_api/common/utils/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ async def render_template_string(
return rendered


async def render_template_chatml(
    messages: list[dict], variables: dict, check: bool = False
) -> list[dict]:
    """Render the `content` field of each ChatML message as a Jinja template.

    Args:
        messages: ChatML-style dicts; each must have a string `content` key.
        variables: Values substituted into every template.
        check: When True, validate `variables` against each template's
            inferred schema before rendering.

    Returns:
        New message dicts with the same keys but rendered `content`.
    """
    # FIXME: content may also be a list of ChatMLTextContentPart — should
    # that case be handled here? (carried over from original review note)
    compiled = [jinja_env.from_string(message["content"]) for message in messages]

    # Optionally make sure the supplied variables satisfy every template's
    # inferred variable schema before any rendering happens.
    if check:
        for tpl in compiled:
            validate(instance=variables, schema=to_json_schema(infer(tpl)))

    # Build shallow copies of the messages with rendered content.
    results: list[dict] = []
    for tpl, message in zip(compiled, messages):
        rendered_content = await tpl.render_async(**variables)
        results.append({**message, "content": rendered_content})

    return results


async def render_template_parts(
template_strings: list[dict], variables: dict, check: bool = False
) -> list[dict]:
Expand Down Expand Up @@ -73,7 +95,7 @@ async def render_template_parts(


async def render_template(
template_string: str | list[dict],
input: str | list[dict],
variables: dict,
check: bool = False,
skip_vars: list[str] | None = None,
Expand All @@ -83,8 +105,15 @@ async def render_template(
for name, val in variables.items()
if not (skip_vars is not None and isinstance(name, str) and name in skip_vars)
}
if isinstance(template_string, str):
return await render_template_string(template_string, variables, check)

elif isinstance(template_string, list):
return await render_template_parts(template_string, variables, check)
match input:
case str():
future = render_template_string(input, variables, check)

case [{"content": str()}, *_]:
future = render_template_chatml(input, variables, check)

case _:
future = render_template_parts(input, variables, check)

return await future
3 changes: 3 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def dbsf_normalize(scores: list[float]) -> list[float]:
Scores scaled using minmax scaler with our custom feature range
(extremes indicated as 3 standard deviations from the mean)
"""
if len(scores) < 2:
return scores

sd = stdev(scores)
if sd == 0:
return scores
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/docs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: F401
from .create_doc import create_agent_doc, create_user_doc
from .delete_doc import delete_agent_doc, delete_user_doc
from .embed import embed
from .get_doc import get_doc
from .list_docs import list_agent_docs, list_user_docs
from .router import router
Expand Down
28 changes: 28 additions & 0 deletions agents-api/agents_api/routers/docs/embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Annotated

from fastapi import Depends
from pydantic import UUID4

import agents_api.clients.embed as embedder

from ...autogen.openapi_model import (
EmbedQueryRequest,
EmbedQueryResponse,
)
from ...dependencies.developer_id import get_developer_id
from .router import router


@router.post("/embed", tags=["docs"])
async def embed(
    x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
    data: EmbedQueryRequest,
) -> EmbedQueryResponse:
    """Embed query text and return the resulting vectors.

    Args:
        x_developer_id: Developer identity resolved by the dependency.
        data: Request whose `text` is a single string or a list of strings.

    Returns:
        EmbedQueryResponse carrying one vector per input text.
    """
    # Normalize to a list: the request may carry a single string or a list.
    # (Originally this was two conflicting annotated assignments of the same
    # name, which type checkers flag as a redefinition.)
    text_to_embed: list[str] = (
        [data.text] if isinstance(data.text, str) else data.text
    )

    vectors = await embedder.embed(inputs=text_to_embed)

    return EmbedQueryResponse(vectors=vectors)
Loading
Loading