diff --git a/agent/chain.py b/agent/chain.py index 8d535f9..1ce7f1c 100644 --- a/agent/chain.py +++ b/agent/chain.py @@ -58,8 +58,8 @@ def think(cls, cache: Conversation, input: str): chain = thought_prompt | cls.llm def save_new_messages(ai_response): - cache.add_message("response", HumanMessage(content=input)) - cache.add_message("response", AIMessage(content=ai_response)) + cache.add_message("thought", HumanMessage(content=input)) + cache.add_message("thought", AIMessage(content=ai_response)) return Streamable(chain.astream({}, {"tags": ["thought"], "metadata": {"conversation_id": cache.conversation_id, "user_id": cache.user_id}}), save_new_messages)