Skip to content

Commit

Permalink
feat: Add doc IDs to the session chat response
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 10, 2024
1 parent cbae3a3 commit 3d2231b
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 42 deletions.
56 changes: 31 additions & 25 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-04-30T17:38:56+00:00
# timestamp: 2024-05-09T11:41:12+00:00

from __future__ import annotations

Expand Down Expand Up @@ -362,30 +362,6 @@ class Response(BaseModel):
items: ChatMLMessage | None = None


class ChatResponse(BaseModel):
"""
Represents a chat completion response returned by model, based on the provided input.
"""

id: UUID
"""
A unique identifier for the chat completion.
"""
finish_reason: FinishReason
"""
The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.
"""
response: List[List[ChatMLMessage] | Response]
"""
A list of chat completion messages produced as a response.
"""
usage: CompletionUsage
jobs: Set[UUID] | None = None
"""
IDs (if any) of jobs created as part of this request
"""


class Memory(BaseModel):
agent_id: UUID
"""
Expand Down Expand Up @@ -840,6 +816,11 @@ class PartialFunctionDef(BaseModel):
"""


class DocIds(BaseModel):
agent_doc_ids: List[str]
user_doc_ids: List[str]


class Agent(BaseModel):
name: str
"""
Expand Down Expand Up @@ -987,6 +968,31 @@ class ChatInputData(BaseModel):
"""


class ChatResponse(BaseModel):
"""
Represents a chat completion response returned by model, based on the provided input.
"""

id: UUID
"""
A unique identifier for the chat completion.
"""
finish_reason: FinishReason
"""
The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.
"""
response: List[List[ChatMLMessage] | Response]
"""
A list of chat completion messages produced as a response.
"""
usage: CompletionUsage
jobs: Set[UUID] | None = None
"""
IDs (if any) of jobs created as part of this request
"""
doc_ids: DocIds


class ChatInput(ChatInputData, ChatSettings, MemoryAccessOptions):
pass

Expand Down
26 changes: 16 additions & 10 deletions agents-api/agents_api/models/entry/proc_mem_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ def proc_mem_context_query(
# Collect docs
# Search for agent docs
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, agent_doc_id] :=
*_input{{agent_id, doc_query}},
*agent_docs {{
agent_id,
doc_id,
doc_id: agent_doc_id,
created_at,
}},
~information_snippets:embedding_space {{
Expand All @@ -201,11 +201,11 @@ def proc_mem_context_query(
index = 5 + (snippet_idx * 0.01)
# Search for user docs
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, user_doc_id] :=
*_input{{user_id, doc_query}},
*user_docs {{
user_id,
doc_id,
doc_id: user_doc_id,
created_at,
}},
~information_snippets:embedding_space {{
Expand Down Expand Up @@ -234,6 +234,8 @@ def proc_mem_context_query(
token_count: Int,
created_at: Float,
index: Float,
agent_doc_id: Uuid default null,
user_doc_id: Uuid default null,
}}
}} {{
# Collect all entries related to the session.
Expand Down Expand Up @@ -269,13 +271,14 @@ def proc_mem_context_query(
}} {{
# Combine all collected data into a structured format.
# Combine all
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] :=
*_preamble{{
role, name, content, token_count, created_at, index,
}},
agent_doc_id = null, user_doc_id = null,
# Now let's get instructions
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] :=
*_input{{agent_id}},
*agents{{
agent_id,
Expand All @@ -288,21 +291,24 @@ def proc_mem_context_query(
content = instruction,
token_count = round(length(instruction) / 3.5),
instruction in instructions,
agent_doc_id = null, user_doc_id = null,
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] :=
*_tools{{
role, name, content, token_count, created_at, index
}},
agent_doc_id = null, user_doc_id = null,
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] :=
*_docs {{
role, name, content, token_count, created_at, index
role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id
}},
?[role, name, content, token_count, created_at, index] :=
?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] :=
*_entries{{
role, name, content, token_count, created_at, index
}},
agent_doc_id = null, user_doc_id = null,
:sort index, created_at
}}
Expand Down
5 changes: 4 additions & 1 deletion agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ async def session_chat(
min_p=request.min_p,
preset=request.preset,
)
response, new_entry, bg_task = await session.run(request.messages, settings)
response, new_entry, bg_task, doc_ids = await session.run(
request.messages, settings
)

jobs = None
if bg_task:
Expand All @@ -308,4 +310,5 @@ async def session_chat(
response=[resp],
usage=CompletionUsage(**response.usage.model_dump()),
jobs=jobs,
doc_ids=doc_ids,
)
22 changes: 16 additions & 6 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import litellm
from litellm import acompletion

from ...autogen.openapi_model import InputChatMLMessage, Tool
from ...autogen.openapi_model import InputChatMLMessage, Tool, DocIds
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
from ...clients.worker.types import ChatML
Expand Down Expand Up @@ -125,7 +125,7 @@ def rm_user_assistant(m):

async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable | None]:
) -> tuple[ChatCompletion, Entry, Callable | None, DocIds]:
# TODO: implement locking at some point

# Get session data
Expand All @@ -134,7 +134,7 @@ async def run(
raise SessionNotFoundError(self.developer_id, self.session_id)

# Assemble context
init_context, final_settings = await self.forward(
init_context, final_settings, doc_ids = await self.forward(
session_data, new_input, settings
)

Expand Down Expand Up @@ -180,14 +180,14 @@ async def run(
new_input, total_tokens, new_entry, final_settings
)

return response, new_entry, backward_pass
return response, new_entry, backward_pass, doc_ids

async def forward(
self,
session_data: SessionData | None,
new_input: list[Entry],
settings: Settings,
) -> tuple[list[ChatML], Settings]:
) -> tuple[list[ChatML], Settings, DocIds]:
# role, name, content, token_count, created_at
string_to_embed = "\n".join(
[f"{msg.name or msg.role}: {msg.content}" for msg in new_input]
Expand All @@ -214,12 +214,22 @@ async def forward(
first_instruction_idx = -1
first_instruction_created_at = 0
tools = []
doc_ids = DocIds(agent_doc_ids=[], user_doc_ids=[])

for idx, row in proc_mem_context_query(
session_id=self.session_id,
tool_query_embedding=tool_query_embedding,
doc_query_embedding=doc_query_embedding,
).iterrows():
agent_doc_id = row.get("agent_doc_id")
user_doc_id = row.get("user_doc_id")

if agent_doc_id is not None:
doc_ids.agent_doc_ids.append(agent_doc_id)

if user_doc_id is not None:
doc_ids.user_doc_ids.append(user_doc_id)

# If a `functions` message is encountered, extract into tools list
if row["name"] == "functions":
# FIXME: This might also break if {role: system, name: functions, content} but content not valid json object
Expand Down Expand Up @@ -321,7 +331,7 @@ async def forward(
settings.tools = settings.tools or []
settings.tools.extend(tools)

return messages, settings
return messages, settings, doc_ids

async def generate(
self, init_context: list[ChatML], settings: Settings
Expand Down
17 changes: 17 additions & 0 deletions openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1699,11 +1699,14 @@ components:
format: uuid
uniqueItems: true
description: IDs (if any) of jobs created as part of this request
doc_ids:
$ref: '#/components/schemas/DocIds'
required:
- usage
- response
- finish_reason
- id
- doc_ids
Memory:
$schema: http://json-schema.org/draft-04/schema#
type: object
Expand Down Expand Up @@ -2326,6 +2329,20 @@ components:
parameters:
$ref: '#/components/schemas/FunctionParameters'
description: Parameters accepeted by this function
DocIds:
type: object
properties:
agent_doc_ids:
type: array
items:
type: string
user_doc_ids:
type: array
items:
type: string
required:
- agent_doc_ids
- user_doc_ids
securitySchemes:
api-key:
type: apiKey
Expand Down

0 comments on commit 3d2231b

Please sign in to comment.