Skip to content

Commit

Permalink
fix: memory not working
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Jan 9, 2024
1 parent 2bcfaec commit e434278
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
17 changes: 15 additions & 2 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from jose import JWTError, jwt
from langchain_core.messages.ai import AIMessage, AIMessageChunk

from backend.config import RagConfig
from backend.database import Database
from backend.logger import get_logger
from backend.model import Message
from backend.rag_components.chat_message_history import get_conversation_buffer_memory
from backend.rag_components.rag import RAG
from backend.user_management import (
ALGORITHM,
Expand Down Expand Up @@ -89,7 +91,13 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
}
rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context)
response = rag.generate_response(message)
return StreamingResponse(stream_response(message.chat_id, response), media_type="text/event-stream")
response_stream = stream_response(
rag=rag,
chat_id=message.chat_id,
question=message.content,
response=response
)
return StreamingResponse(response_stream, media_type="text/event-stream")


@app.post("/chat/regenerate")
Expand Down Expand Up @@ -130,7 +138,7 @@ async def chat(chat_id: str, current_user: User = Depends(get_current_user)) ->
return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]}


async def stream_response(chat_id: str, response):
async def stream_response(rag: RAG, chat_id: str, question, response):
full_response = ""
response_id = str(uuid4())
try:
Expand All @@ -154,6 +162,7 @@ async def stream_response(chat_id: str, response):
yield full_response.encode("utf-8")
finally:
await log_response_to_db(chat_id, full_response)
await memorize_response(rag.config, chat_id, question, full_response)

async def log_response_to_db(chat_id: str, full_response: str):
response_id = str(uuid4())
Expand All @@ -163,6 +172,10 @@ async def log_response_to_db(chat_id: str, full_response: str):
(response_id, datetime.now().isoformat(), chat_id, "assistant", full_response),
)

async def memorize_response(rag_config: RagConfig, chat_id: str, question: str, answer: str):
memory = get_conversation_buffer_memory(rag_config, chat_id)
memory.save_context({"question": question}, {"answer": answer})


############################################
### Feedback ###
Expand Down
2 changes: 1 addition & 1 deletion docs/recipe_vector_stores_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ As we need a backend SQL database to store conversation history and other info,
[See the recipes for database configs here](recipe_databases_configs.md)

```shell
pip install pgvector
pip install psycopg2-binary pgvector
```

```yaml
Expand Down

0 comments on commit e434278

Please sign in to comment.