Skip to content

Commit

Permalink
add: db storage
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Dec 19, 2023
1 parent 54cabbe commit 72d8f88
Show file tree
Hide file tree
Showing 5 changed files with 320 additions and 0 deletions.
205 changes: 205 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from datetime import datetime, timedelta
from typing import List
from uuid import uuid4

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt

import backend.document_store as document_store
from database.database import Database
from backend.document_store import StorageBackend
from backend.model import Doc, Message
from backend.user_management import (
ALGORITHM,
SECRET_KEY,
User,
authenticate_user,
create_access_token,
create_user,
get_user,
user_exists,
)

app = FastAPI()


############################################
### Authentication ###
############################################

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
"""Get the current user by decoding the JWT token."""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
email: str = payload.get("email") # 'sub' is commonly used to store user identity
if email is None:
raise credentials_exception
# Here you should fetch the user from the database by user_id
user = get_user(email)
if user is None:
raise credentials_exception
return user
except JWTError:
raise credentials_exception


@app.post("/user/signup")
async def signup(user: User) -> dict:
"""Sign up a new user."""
if user_exists(user.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered"
)

create_user(user)
return {"email": user.email}


@app.delete("/user/")
async def delete_user(current_user: User = Depends(get_current_user)) -> dict:
"""Delete an existing user."""
email = current_user.email
try:
user = get_user(email)
if user is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found"
)
delete_user(email)
return {"detail": f"User {email} deleted"}
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error"
)


@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
"""Log in a user and return an access token."""
user = authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=60)
access_token = create_access_token(data=user.model_dump(), expires_delta=access_token_expires)
return {"access_token": access_token, "token_type": "bearer"}


@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
"""Get the current user's profile."""
return current_user


############################################
### Chat ###
############################################

@app.post("/chat/new")
async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
chat_id = str(uuid4())
timestamp = datetime.now().isoformat()
user_id = current_user.email
with Database() as connection:
connection.query(
"INSERT INTO chat (id, timestamp, user_id) VALUES (?, ?, ?)",
(chat_id, timestamp, user_id),
)
return {"chat_id": chat_id}

@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)

model_response = Message(
id=str(uuid4()),
timestamp=datetime.now().isoformat(),
chat_id=message.chat_id,
sender="assistant",
content=f"Unique response: {uuid4()}",
)

with Database() as connection:
connection.query(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(model_response.id, model_response.timestamp, model_response.chat_id, model_response.sender, model_response.content),
)
return {"message": model_response}


@app.post("/chat/regenerate")
async def chat_regenerate(current_user: User = Depends(get_current_user)) -> dict:
"""Regenerate a chat session for the current user."""
pass


@app.get("/chat/list")
async def chat_list(current_user: User = Depends(get_current_user)) -> List[dict]:
"""Get a list of chat sessions for the current user."""
pass


@app.get("/chat/{chat_id}")
async def chat(chat_id: str, current_user: User = Depends(get_current_user)) -> dict:
"""Get details of a specific chat session."""
pass


############################################
### Feedback ###
############################################


@app.post("/feedback/{message_id}/thumbs_up")
async def feedback_thumbs_up(
message_id: str, current_user: User = Depends(get_current_user)
) -> None:
with Database() as connection:
connection.query(
"INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
(str(uuid4()), message_id, "thumbs_up"),
)


@app.post("/feedback/{message_id}/thumbs_down")
async def feedback_thumbs_down(
message_id: str, current_user: User = Depends(get_current_user)
) -> None:
with Database() as connection:
connection.query(
"INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
(str(uuid4()), message_id, "thumbs_down"),
)


############################################
### Other ###
############################################


@app.post("/index/documents")
async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
"""Index documents in a specified storage backend."""
document_store.store_documents(chunks, bucket, storage_backend)


if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
21 changes: 21 additions & 0 deletions backend/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datetime import datetime
from uuid import uuid4
from langchain.docstore.document import Document
from pydantic import BaseModel

class Message(BaseModel):
id: str
timestamp: str
chat_id: str
sender: str
content: str

class Doc(BaseModel):
"""Represents a document with content and associated metadata."""

content: str
metadata: dict

def to_langchain_document(self) -> Document:
"""Converts the current Doc instance into a langchain Document."""
return Document(page_content=self.content, metadata=self.metadata)
Binary file added database/database.sqlite
Binary file not shown.
34 changes: 34 additions & 0 deletions database/database_init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
-- Go to https://dbdiagram.io/d/RAGAAS-63dbdcc6296d97641d7e07c8
-- Make your changes
-- Export > Export to PostgresSQL (or other)
-- Translate to SQLite (works with a cmd+k in Cursor, or https://www.rebasedata.com/convert-postgresql-to-sqlite-online)
-- Paste here
-- Replace "CREATE TABLE" with "CREATE TABLE IF NOT EXISTS"

CREATE TABLE IF NOT EXISTS "user" (
"email" TEXT PRIMARY KEY,
"password" TEXT
);

CREATE TABLE IF NOT EXISTS "chat" (
"id" TEXT PRIMARY KEY,
"timestamp" TEXT,
"user_id" TEXT,
FOREIGN KEY ("user_id") REFERENCES "user" ("email")
);

CREATE TABLE IF NOT EXISTS "message" (
"id" TEXT PRIMARY KEY,
"timestamp" TEXT,
"chat_id" TEXT,
"sender" TEXT,
"content" TEXT,
FOREIGN KEY ("chat_id") REFERENCES "chat" ("id")
);

CREATE TABLE IF NOT EXISTS "feedback" (
"id" TEXT PRIMARY KEY,
"message_id" TEXT,
"feedback" TEXT,
FOREIGN KEY ("message_id") REFERENCES "message" ("id")
);
60 changes: 60 additions & 0 deletions frontend/lib/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from uuid import uuid4
from datetime import datetime

import streamlit as st

from dataclasses import dataclass, asdict
from streamlit_feedback import streamlit_feedback

@dataclass
class Message:
sender: str
content: str
chat_id: str
id: str = None
timestamp: str = None

def __post_init__(self):
self.id = str(uuid4()) if self.id is None else self.id
self.timestamp = datetime.now().isoformat() if self.timestamp is None else self.timestamp

def chat():
prompt = st.chat_input("Say something")

if prompt:
if len(st.session_state.get("messages", [])) == 0:
chat_id = new_chat()
else:
chat_id = st.session_state.get("chat_id")

st.session_state.get("messages").append(Message("user", prompt, chat_id))
response = send_prompt(st.session_state.get("messages")[-1])
st.session_state.get("messages").append(Message(**response))

with st.container(border=True):
for message in st.session_state.get("messages", []):
with st.chat_message(message.sender):
st.write(message.content)
if len(st.session_state.get("messages", [])) > 0 and len(st.session_state.get("messages")) % 2 == 0:
streamlit_feedback(key=str(len(st.session_state.get("messages"))), feedback_type="thumbs", on_submit=lambda feedback: send_feedback(st.session_state.get("messages")[-1].id, feedback))


def new_chat():
session = st.session_state.get("session")
response = session.post("/chat/new")
st.session_state["chat_id"] = response.json()["chat_id"]
st.session_state["messages"] = []
return response.json()["chat_id"]

def send_prompt(message: Message):
session = st.session_state.get("session")
response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message))
print(response.headers)
print(response.text)
return response.json()["message"]

def send_feedback(message_id: str, feedback: str):
feedback = "thumbs_up" if feedback["score"] == "👍" else "thumbs_down"
session = st.session_state.get("session")
response = session.post(f"/feedback/{message_id}/{feedback}")
return response.text

0 comments on commit 72d8f88

Please sign in to comment.