Merge pull request #454 from julep-ai/f/chat-tests
feat(agents-api): Make chat route tests pass
whiterabbit1983 authored Aug 13, 2024
2 parents 95dfa18 + 6279937 commit 67db1fd
Showing 43 changed files with 348 additions and 436 deletions.
14 changes: 7 additions & 7 deletions agents-api/agents_api/autogen/Chat.py
```diff
@@ -10,7 +10,7 @@

 from .Common import LogitBias
 from .Docs import DocReference
-from .Entries import ChatMLMessage, InputChatMLMessage
+from .Entries import InputChatMLMessage
 from .Tools import FunctionTool, NamedToolChoice
@@ -23,7 +23,7 @@ class BaseChatOutput(BaseModel):
     """
     The reason the model stopped generating tokens
     """
-    logprobs: Annotated[LogProbResponse | None, Field(...)]
+    logprobs: LogProbResponse | None = None
     """
     The log probabilities of tokens
     """
@@ -33,7 +33,7 @@ class BaseChatResponse(BaseModel):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    usage: Annotated[CompetionUsage | None, Field(...)]
+    usage: CompetionUsage | None = None
     """
     Usage statistics for the completion request
     """
@@ -61,7 +61,7 @@ class BaseTokenLogProb(BaseModel):
     """
     The log probability of the token
     """
-    bytes: Annotated[list[int] | None, Field(...)]
+    bytes: list[int] | None = None


 class ChatInputData(BaseModel):
@@ -90,7 +90,7 @@ class ChatOutputChunk(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    delta: ChatMLMessage
+    delta: InputChatMLMessage
     """
     The message generated by the model
     """
@@ -166,7 +166,7 @@ class MultipleChatOutput(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    messages: list[ChatMLMessage]
+    messages: list[InputChatMLMessage]


 class OpenAISettings(BaseModel):
@@ -199,7 +199,7 @@ class SingleChatOutput(BaseChatOutput):
     model_config = ConfigDict(
         populate_by_name=True,
     )
-    message: ChatMLMessage
+    message: InputChatMLMessage


 class TokenLogProb(BaseTokenLogProb):
```
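Every `Field(...)` change in this file is the same fix: in Pydantic v2, `Annotated[T | None, Field(...)]` declares a field that is nullable but still required (the `...` default forces callers to pass it), while `T | None = None` makes it genuinely optional. The `ChatMLMessage` → `InputChatMLMessage` swaps follow from the removal of `ChatMLMessage` in Entries.py below. A minimal sketch of the required-vs-optional difference, with a plain `dict` standing in for `LogProbResponse`:

```python
from typing import Annotated

from pydantic import BaseModel, Field, ValidationError


class Before(BaseModel):
    # Field(...) -> nullable but required: must be passed explicitly
    logprobs: Annotated[dict | None, Field(...)]


class After(BaseModel):
    # defaulted -> genuinely optional: omitting it yields None
    logprobs: dict | None = None


try:
    Before()
except ValidationError as e:
    print(e.errors()[0]["type"])  # "missing"

print(After().logprobs)  # None
```

Presumably this is what tripped the chat route tests whenever a completion response arrived without `usage` or `logprobs` populated.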
41 changes: 2 additions & 39 deletions agents-api/agents_api/autogen/Entries.py
```diff
@@ -17,7 +17,7 @@ class BaseEntry(BaseModel):
     )
     role: Literal[
         "user",
-        "agent",
+        "assistant",
         "system",
         "function",
         "function_response",
@@ -67,43 +67,6 @@ class ChatMLImageContentPart(BaseModel):
     """


-class ChatMLMessage(BaseModel):
-    model_config = ConfigDict(
-        populate_by_name=True,
-    )
-    role: Literal[
-        "user",
-        "agent",
-        "system",
-        "function",
-        "function_response",
-        "function_call",
-        "auto",
-    ]
-    """
-    The role of the message
-    """
-    content: str | list[str] | list[ChatMLTextContentPart | ChatMLImageContentPart]
-    """
-    The content parts of the message
-    """
-    name: str | None = None
-    """
-    Name
-    """
-    tool_calls: Annotated[
-        list[ChosenToolCall], Field([], json_schema_extra={"readOnly": True})
-    ]
-    """
-    Tool calls generated by the model.
-    """
-    created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
-    """
-    When this resource was created as UTC date-time
-    """
-    id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
-
-
 class ChatMLTextContentPart(BaseModel):
     model_config = ConfigDict(
         populate_by_name=True,
@@ -159,7 +122,7 @@ class InputChatMLMessage(BaseModel):
     )
     role: Literal[
         "user",
-        "agent",
+        "assistant",
         "system",
         "function",
         "function_response",
```
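With `ChatMLMessage` deleted, `InputChatMLMessage` now serves both input and output messages, and its `role` literal uses the conventional `"assistant"` instead of `"agent"`, matching OpenAI-style chat APIs. A quick check against a trimmed-down stand-in for the autogen model (not the full class):

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class Entry(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: str


print(Entry(role="assistant", content="hi").role)  # accepted

try:
    Entry(role="agent", content="hi")  # the old role name
except ValidationError:
    print("'agent' is no longer a valid role")
```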
6 changes: 3 additions & 3 deletions agents-api/agents_api/common/protocol/sessions.py
```diff
@@ -54,7 +54,7 @@ def get_active_agent(self) -> Agent:
         """
         Get the active agent from the session data.
         """
-        requested_agent: UUID | None = self.settings.agent
+        requested_agent: UUID | None = self.settings and self.settings.agent

         if requested_agent:
             assert requested_agent in [agent.id for agent in self.agents], (
@@ -67,15 +67,15 @@ def get_active_agent(self) -> Agent:
         return self.agents[0]

     def merge_settings(self, chat_input: ChatInput) -> ChatSettings:
-        request_settings = ChatSettings.model_validate(chat_input)
+        request_settings = chat_input.model_dump(exclude_unset=True)
         active_agent = self.get_active_agent()
         default_settings = active_agent.default_settings

         self.settings = settings = ChatSettings(
             **{
                 "model": active_agent.model,
                 **default_settings.model_dump(),
-                **request_settings.model_dump(exclude_unset=True),
+                **request_settings,
             }
         )
```
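Two fixes here: `get_active_agent` no longer dereferences `self.settings.agent` while `settings` is still `None`, and `merge_settings` dumps the raw `chat_input` with `exclude_unset=True` instead of validating it into `ChatSettings` first — validating first would materialize defaults and mark every field as set, which defeats `exclude_unset`. The resulting precedence, sketched with plain dicts and hypothetical values in place of the real models:

```python
agent_model = "gpt-4o"  # hypothetical active agent's model
default_settings = {"temperature": 0.7, "max_tokens": 1024}  # agent defaults
request_settings = {"temperature": 0.2}  # only keys the caller explicitly set

# Mirrors the dict merge inside merge_settings: later keys win.
merged = {
    "model": agent_model,
    **default_settings,
    **request_settings,
}
print(merged)  # {'model': 'gpt-4o', 'temperature': 0.2, 'max_tokens': 1024}
```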
39 changes: 34 additions & 5 deletions agents-api/agents_api/common/utils/template.py
```diff
@@ -40,6 +40,28 @@ async def render_template_string(
     return rendered


+async def render_template_chatml(
+    messages: list[dict], variables: dict, check: bool = False
+) -> list[dict]:
+    # Parse template
+    # FIXME: should template_strings contain a list of ChatMLTextContentPart? Should we handle it somehow?
+    templates = [jinja_env.from_string(msg["content"]) for msg in messages]
+
+    # If check is required, get required vars from template and validate variables
+    if check:
+        for template in templates:
+            schema = to_json_schema(infer(template))
+            validate(instance=variables, schema=schema)
+
+    # Render
+    rendered = [
+        ({**msg, "content": await template.render_async(**variables)})
+        for template, msg in zip(templates, messages)
+    ]
+
+    return rendered
+
+
 async def render_template_parts(
     template_strings: list[dict], variables: dict, check: bool = False
 ) -> list[dict]:
@@ -73,7 +95,7 @@ async def render_template_parts(


 async def render_template(
-    template_string: str | list[dict],
+    input: str | list[dict],
     variables: dict,
     check: bool = False,
     skip_vars: list[str] | None = None,
@@ -83,8 +105,15 @@
         for name, val in variables.items()
         if not (skip_vars is not None and isinstance(name, str) and name in skip_vars)
     }
-    if isinstance(template_string, str):
-        return await render_template_string(template_string, variables, check)
-
-    elif isinstance(template_string, list):
-        return await render_template_parts(template_string, variables, check)
+    match input:
+        case str():
+            future = render_template_string(input, variables, check)
+
+        case [{"content": str()}, *_]:
+            future = render_template_chatml(input, variables, check)
+
+        case _:
+            future = render_template_parts(input, variables, check)
+
+    return await future
```
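`render_template` now dispatches on the shape of its input: a plain string renders as before, a non-empty list whose first element carries a string `content` is treated as ChatML and routed to the new `render_template_chatml`, and anything else (e.g. content-part lists) falls through to `render_template_parts`. A self-contained sketch of just the dispatch, returning branch names as strings for illustration:

```python
def dispatch(input: str | list[dict]) -> str:
    match input:
        case str():
            return "render_template_string"
        case [{"content": str()}, *_]:
            # non-empty list whose first element maps "content" to a str
            return "render_template_chatml"
        case _:
            return "render_template_parts"


print(dispatch("Hello {{ name }}"))                   # render_template_string
print(dispatch([{"role": "user", "content": "hi"}]))  # render_template_chatml
print(dispatch([{"content": [{"type": "text"}]}]))    # render_template_parts
```

Note that the pattern only inspects the first message, so a list mixing string and part-based content would still take the ChatML branch.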
3 changes: 3 additions & 0 deletions agents-api/agents_api/models/docs/search_docs_hybrid.py
```diff
@@ -18,6 +18,9 @@ def dbsf_normalize(scores: list[float]) -> list[float]:
     Scores scaled using minmax scaler with our custom feature range
     (extremes indicated as 3 standard deviations from the mean)
     """
+    if len(scores) < 2:
+        return scores
+
     sd = stdev(scores)
     if sd == 0:
         return scores
```
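`statistics.stdev` requires at least two data points and raises `StatisticsError` otherwise, so `dbsf_normalize` would previously crash whenever a search returned a single hit; the new guard returns such score lists unchanged. A one-line reproduction of the failure mode:

```python
from statistics import StatisticsError, stdev

try:
    stdev([0.87])  # e.g. exactly one search result
except StatisticsError as exc:
    print(exc)  # stdev requires at least two data points
```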
1 change: 1 addition & 0 deletions agents-api/agents_api/routers/docs/__init__.py
```diff
@@ -1,6 +1,7 @@
 # ruff: noqa: F401
 from .create_doc import create_agent_doc, create_user_doc
 from .delete_doc import delete_agent_doc, delete_user_doc
+from .embed import embed
 from .get_doc import get_doc
 from .list_docs import list_agent_docs, list_user_docs
 from .router import router
```
28 changes: 28 additions & 0 deletions agents-api/agents_api/routers/docs/embed.py
```diff
@@ -0,0 +1,28 @@
+from typing import Annotated
+
+from fastapi import Depends
+from pydantic import UUID4
+
+import agents_api.clients.embed as embedder
+
+from ...autogen.openapi_model import (
+    EmbedQueryRequest,
+    EmbedQueryResponse,
+)
+from ...dependencies.developer_id import get_developer_id
+from .router import router
+
+
+@router.post("/embed", tags=["docs"])
+async def embed(
+    x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
+    data: EmbedQueryRequest,
+) -> EmbedQueryResponse:
+    text_to_embed: str | list[str] = data.text
+    text_to_embed: list[str] = (
+        [text_to_embed] if isinstance(text_to_embed, str) else text_to_embed
+    )
+
+    vectors = await embedder.embed(inputs=text_to_embed)
+
+    return EmbedQueryResponse(vectors=vectors)
```
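A hedged usage sketch for the new route — the host, port, mount path, and developer-ID header below are assumptions, not taken from the diff. `EmbedQueryRequest.text` accepts a single string or a list of strings, and the handler normalizes both to a list before calling the embedding client:

```python
import httpx

resp = httpx.post(
    "http://localhost:8080/embed",  # assumed host/port and mount path
    headers={
        # assumed header consumed by the get_developer_id dependency
        "X-Developer-Id": "00000000-0000-0000-0000-000000000000",
    },
    json={"text": ["first snippet", "second snippet"]},
)
resp.raise_for_status()
print(resp.json()["vectors"])  # one embedding vector per input string
```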