From 3d2231b193773c01023764bfeb8006fef044a81e Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 9 May 2024 14:43:41 +0300 Subject: [PATCH] feat: Add doc IDs to the session chat response --- .../agents_api/autogen/openapi_model.py | 56 ++++++++++--------- .../models/entry/proc_mem_context.py | 26 +++++---- .../agents_api/routers/sessions/routers.py | 5 +- .../agents_api/routers/sessions/session.py | 22 ++++++-- openapi.yaml | 17 ++++++ 5 files changed, 84 insertions(+), 42 deletions(-) diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index f9aa80737..786aba3d3 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: openapi.yaml -# timestamp: 2024-04-30T17:38:56+00:00 +# timestamp: 2024-05-09T11:41:12+00:00 from __future__ import annotations @@ -362,30 +362,6 @@ class Response(BaseModel): items: ChatMLMessage | None = None -class ChatResponse(BaseModel): - """ - Represents a chat completion response returned by model, based on the provided input. - """ - - id: UUID - """ - A unique identifier for the chat completion. - """ - finish_reason: FinishReason - """ - The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. - """ - response: List[List[ChatMLMessage] | Response] - """ - A list of chat completion messages produced as a response. - """ - usage: CompletionUsage - jobs: Set[UUID] | None = None - """ - IDs (if any) of jobs created as part of this request - """ - - class Memory(BaseModel): agent_id: UUID """ @@ -840,6 +816,11 @@ class PartialFunctionDef(BaseModel): """ +class DocIds(BaseModel): + agent_doc_ids: List[str] + user_doc_ids: List[str] + + class Agent(BaseModel): name: str """ @@ -987,6 +968,31 @@ class ChatInputData(BaseModel): """ +class ChatResponse(BaseModel): + """ + Represents a chat completion response returned by model, based on the provided input. + """ + + id: UUID + """ + A unique identifier for the chat completion. + """ + finish_reason: FinishReason + """ + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + """ + response: List[List[ChatMLMessage] | Response] + """ + A list of chat completion messages produced as a response. + """ + usage: CompletionUsage + jobs: Set[UUID] | None = None + """ + IDs (if any) of jobs created as part of this request + """ + doc_ids: DocIds + + class ChatInput(ChatInputData, ChatSettings, MemoryAccessOptions): pass diff --git a/agents-api/agents_api/models/entry/proc_mem_context.py b/agents-api/agents_api/models/entry/proc_mem_context.py index c6712bd51..ab1ecc35a 100644 --- a/agents-api/agents_api/models/entry/proc_mem_context.py +++ b/agents-api/agents_api/models/entry/proc_mem_context.py @@ -175,11 +175,11 @@ def proc_mem_context_query( # Collect docs # Search for agent docs - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, agent_doc_id] := *_input{{agent_id, doc_query}}, *agent_docs {{ agent_id, - doc_id, + doc_id: agent_doc_id, created_at, }}, ~information_snippets:embedding_space {{ @@ -201,11 +201,11 @@ def proc_mem_context_query( index = 5 + (snippet_idx * 0.01) # Search for user docs - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, user_doc_id] := *_input{{user_id, doc_query}}, *user_docs {{ user_id, - doc_id, + doc_id: user_doc_id, created_at, }}, ~information_snippets:embedding_space {{ @@ -234,6 +234,8 @@ def proc_mem_context_query( token_count: Int, created_at: Float, index: Float, + agent_doc_id: Uuid default null, + user_doc_id: Uuid default null, }} }} {{ # Collect all entries related to the session. @@ -269,13 +271,14 @@ def proc_mem_context_query( }} {{ # Combine all collected data into a structured format. # Combine all - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] := *_preamble{{ role, name, content, token_count, created_at, index, }}, + agent_doc_id = null, user_doc_id = null, # Now let's get instructions - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] := *_input{{agent_id}}, *agents{{ agent_id, @@ -288,21 +291,24 @@ def proc_mem_context_query( content = instruction, token_count = round(length(instruction) / 3.5), instruction in instructions, + agent_doc_id = null, user_doc_id = null, - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] := *_tools{{ role, name, content, token_count, created_at, index }}, + agent_doc_id = null, user_doc_id = null, - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] := *_docs {{ - role, name, content, token_count, created_at, index + role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id }}, - ?[role, name, content, token_count, created_at, index] := + ?[role, name, content, token_count, created_at, index, agent_doc_id, user_doc_id] := *_entries{{ role, name, content, token_count, created_at, index }}, + agent_doc_id = null, user_doc_id = null, :sort index, created_at }} diff --git a/agents-api/agents_api/routers/sessions/routers.py b/agents-api/agents_api/routers/sessions/routers.py index 6831d96f5..caa9fcefc 100644 --- a/agents-api/agents_api/routers/sessions/routers.py +++ b/agents-api/agents_api/routers/sessions/routers.py @@ -292,7 +292,9 @@ async def session_chat( min_p=request.min_p, preset=request.preset, ) - response, new_entry, bg_task = await session.run(request.messages, settings) + response, new_entry, bg_task, doc_ids = await session.run( + request.messages, settings + ) jobs = None if bg_task: @@ -308,4 +310,5 @@ async def session_chat( response=[resp], usage=CompletionUsage(**response.usage.model_dump()), jobs=jobs, + doc_ids=doc_ids, ) diff --git a/agents-api/agents_api/routers/sessions/session.py b/agents-api/agents_api/routers/sessions/session.py index e71d53871..b85a8503a 100644 --- a/agents-api/agents_api/routers/sessions/session.py +++ b/agents-api/agents_api/routers/sessions/session.py @@ -11,7 +11,7 @@ import litellm from litellm import acompletion -from ...autogen.openapi_model import InputChatMLMessage, Tool +from ...autogen.openapi_model import InputChatMLMessage, Tool, DocIds from ...clients.embed import embed from ...clients.temporal import run_summarization_task from ...clients.worker.types import ChatML @@ -125,7 +125,7 @@ def rm_user_assistant(m): async def run( self, new_input, settings: Settings - ) -> tuple[ChatCompletion, Entry, Callable | None]: + ) -> tuple[ChatCompletion, Entry, Callable | None, DocIds]: # TODO: implement locking at some point # Get session data @@ -134,7 +134,7 @@ async def run( raise SessionNotFoundError(self.developer_id, self.session_id) # Assemble context - init_context, final_settings = await self.forward( + init_context, final_settings, doc_ids = await self.forward( session_data, new_input, settings ) @@ -180,14 +180,14 @@ async def run( new_input, total_tokens, new_entry, final_settings ) - return response, new_entry, backward_pass + return response, new_entry, backward_pass, doc_ids async def forward( self, session_data: SessionData | None, new_input: list[Entry], settings: Settings, - ) -> tuple[list[ChatML], Settings]: + ) -> tuple[list[ChatML], Settings, DocIds]: # role, name, content, token_count, created_at string_to_embed = "\n".join( [f"{msg.name or msg.role}: {msg.content}" for msg in new_input] @@ -214,12 +214,22 @@ async def forward( first_instruction_idx = -1 first_instruction_created_at = 0 tools = [] + doc_ids = DocIds(agent_doc_ids=[], user_doc_ids=[]) for idx, row in proc_mem_context_query( session_id=self.session_id, tool_query_embedding=tool_query_embedding, doc_query_embedding=doc_query_embedding, ).iterrows(): + agent_doc_id = row.get("agent_doc_id") + user_doc_id = row.get("user_doc_id") + + if agent_doc_id is not None: + doc_ids.agent_doc_ids.append(agent_doc_id) + + if user_doc_id is not None: + doc_ids.user_doc_ids.append(user_doc_id) + # If a `functions` message is encountered, extract into tools list if row["name"] == "functions": # FIXME: This might also break if {role: system, name: functions, content} but content not valid json object @@ -321,7 +331,7 @@ async def forward( settings.tools = settings.tools or [] settings.tools.extend(tools) - return messages, settings + return messages, settings, doc_ids async def generate( self, init_context: list[ChatML], settings: Settings diff --git a/openapi.yaml b/openapi.yaml index 7f368452c..358ed03fa 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1699,11 +1699,14 @@ components: format: uuid uniqueItems: true description: IDs (if any) of jobs created as part of this request + doc_ids: + $ref: '#/components/schemas/DocIds' required: - usage - response - finish_reason - id + - doc_ids Memory: $schema: http://json-schema.org/draft-04/schema# type: object @@ -2326,6 +2329,20 @@ components: parameters: $ref: '#/components/schemas/FunctionParameters' description: Parameters accepeted by this function + DocIds: + type: object + properties: + agent_doc_ids: + type: array + items: + type: string + user_doc_ids: + type: array + items: + type: string + required: + - agent_doc_ids + - user_doc_ids securitySchemes: api-key: type: apiKey