# Lifetime, in minutes, of the access token issued by /user/login.
ACCESS_TOKEN_EXPIRE_MINUTES = 60

app = FastAPI()


############################################
###             Authentication           ###
############################################

# tokenUrl is the relative URL of the endpoint that issues tokens; it has to
# match the login route declared below ("/user/login").
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/login")


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
    """Resolve the bearer token of the request to a user.

    Args:
        token: JWT supplied by ``oauth2_scheme``.

    Returns:
        The user returned by ``get_user`` for the token's ``email`` claim.

    Raises:
        HTTPException: 401 when the token cannot be decoded, carries no
            ``email`` claim, or ``get_user`` finds no user for that email.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    # Only the decoding step can raise JWTError, so only that step is guarded.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        raise credentials_exception from err

    # The user identity is stored under the "email" claim (see login()).
    email = payload.get("email")
    if email is None:
        raise credentials_exception

    user = get_user(email)
    if user is None:
        raise credentials_exception
    return user


@app.post("/user/signup")
async def signup(user: User) -> dict:
    """Sign up a new user.

    Args:
        user: The user to register, taken from the request body.

    Returns:
        ``{"email": <email of the created user>}``.

    Raises:
        HTTPException: 400 when ``user_exists`` reports the email as taken.
    """
    if user_exists(user.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"User {user.email} already registered",
        )

    create_user(user)
    return {"email": user.email}


@app.delete("/user/")
async def delete_user(current_user: User = Depends(get_current_user)) -> dict:
    """Delete the account of the authenticated user.

    Args:
        current_user: The authenticated user, resolved from the bearer token.

    Returns:
        ``{"detail": "User <email> deleted"}``.

    Raises:
        HTTPException: 404 when the user no longer exists, 500 when the
            lookup or the deletion fails for any other reason.
    """
    email = current_user.email
    try:
        if get_user(email) is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found"
            )
        # Inside this function the name ``delete_user`` is this very handler,
        # so the row is removed with an explicit statement instead.
        # NOTE(review): targets the "user" table from database_init.sql; chat
        # rows referencing this user are left untouched. If
        # backend.user_management offers a delete helper, prefer it - confirm.
        with Database() as connection:
            connection.query('DELETE FROM "user" WHERE email = ?', (email,))
    except HTTPException:
        # Let the deliberate 404 through instead of masking it as a 500.
        raise
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal Server Error",
        ) from err
    return {"detail": f"User {email} deleted"}


@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
    """Log in a user and return an access token.

    Args:
        form_data: OAuth2 password form; ``username`` and ``password`` are
            passed to ``authenticate_user``.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials.
    """
    user = authenticate_user(form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # get_current_user() reads only the "email" claim. A JWT payload is signed
    # but readable by anyone holding the token, so the remaining User fields
    # (presumably including the password sent at signup - confirm against
    # backend.user_management.User) are kept out of it.
    access_token = create_access_token(
        data={"email": user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}


@app.get("/user/me")
async def user_me(current_user: User = Depends(get_current_user)) -> User:
    """Get the current user's profile."""
    return current_user


############################################
###                 Chat                 ###
############################################


def _store_message(message: Message) -> None:
    """Insert ``message`` as one row of the ``message`` table."""
    with Database() as connection:
        connection.query(
            "INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
            (message.id, message.timestamp, message.chat_id, message.sender, message.content),
        )


@app.post("/chat/new")
async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
    """Create a new chat owned by the current user.

    Returns:
        ``{"chat_id": <UUID4 string of the new chat>}``.
    """
    chat_id = str(uuid4())
    timestamp = datetime.now().isoformat()
    user_id = current_user.email  # the chat table references users by email
    with Database() as connection:
        connection.query(
            "INSERT INTO chat (id, timestamp, user_id) VALUES (?, ?, ?)",
            (chat_id, timestamp, user_id),
        )
    return {"chat_id": chat_id}


@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
    """Store a user message and answer it with an assistant message.

    Args:
        message: The user's message, taken from the request body.
        current_user: The authenticated user (used for authentication only).

    Returns:
        ``{"message": <the stored assistant Message>}``.
    """
    # NOTE(review): the {chat_id} path segment is not read; the chat is taken
    # from message.chat_id, and chat ownership is not checked.
    _store_message(message)

    # Placeholder answer: no model is called, the content is a random UUID.
    model_response = Message(
        id=str(uuid4()),
        timestamp=datetime.now().isoformat(),
        chat_id=message.chat_id,
        sender="assistant",
        content=f"Unique response: {uuid4()}",
    )
    _store_message(model_response)
    return {"message": model_response}


@app.post("/chat/regenerate")
async def chat_regenerate(current_user: User = Depends(get_current_user)) -> dict:
    """Regenerate a chat session for the current user.

    Raises:
        HTTPException: always 501; the endpoint is not implemented yet.
    """
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
    )


# Declared before "/chat/{chat_id}" so that "list" is not captured as a chat id.
@app.get("/chat/list")
async def chat_list(current_user: User = Depends(get_current_user)) -> List[dict]:
    """Get a list of chat sessions for the current user.

    Raises:
        HTTPException: always 501; the endpoint is not implemented yet.
    """
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
    )


@app.get("/chat/{chat_id}")
async def chat(chat_id: str, current_user: User = Depends(get_current_user)) -> dict:
    """Get details of a specific chat session.

    Raises:
        HTTPException: always 501; the endpoint is not implemented yet.
    """
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="Not implemented"
    )


############################################
###               Feedback               ###
############################################


def _record_feedback(message_id: str, feedback: str) -> None:
    """Insert one row into the ``feedback`` table for ``message_id``."""
    with Database() as connection:
        connection.query(
            "INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
            (str(uuid4()), message_id, feedback),
        )


@app.post("/feedback/{message_id}/thumbs_up")
async def feedback_thumbs_up(
    message_id: str, current_user: User = Depends(get_current_user)
) -> None:
    """Record a ``thumbs_up`` feedback for the given message."""
    _record_feedback(message_id, "thumbs_up")


@app.post("/feedback/{message_id}/thumbs_down")
async def feedback_thumbs_down(
    message_id: str, current_user: User = Depends(get_current_user)
) -> None:
    """Record a ``thumbs_down`` feedback for the given message."""
    _record_feedback(message_id, "thumbs_down")


############################################
###                Other                 ###
############################################


@app.post("/index/documents")
async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
    """Index documents in a specified storage backend.

    Args:
        chunks: Documents to store.
        bucket: Passed through to ``document_store.store_documents``.
        storage_backend: Passed through to ``document_store.store_documents``.
    """
    # NOTE(review): unlike the chat, feedback and user endpoints, this one
    # does not depend on get_current_user - confirm that is intended.
    document_store.store_documents(chunks, bucket, storage_backend)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
@dataclass
class Message:
    """A chat message as exchanged with the backend.

    ``sender``, ``content`` and ``chat_id`` are required. ``id`` and
    ``timestamp`` are generated on construction when they are left as None.
    """

    sender: str
    content: str
    chat_id: str
    id: str = None  # replaced by a fresh UUID4 string when None
    timestamp: str = None  # replaced by datetime.now().isoformat() when None

    def __post_init__(self):
        """Fill in ``id`` and ``timestamp`` when the caller did not pass them."""
        if self.id is None:
            self.id = str(uuid4())
        if self.timestamp is None:
            self.timestamp = datetime.now().isoformat()


def chat():
    """Render the chat input, the message history and the feedback widget.

    A submitted prompt is appended to ``st.session_state["messages"]``, sent
    to the backend, and the backend's answer is appended as well. A new chat
    is created first when there are no messages yet.
    """
    prompt = st.chat_input("Say something")

    if prompt:
        if not st.session_state.get("messages"):
            chat_id = new_chat()  # also resets st.session_state["messages"]
        else:
            chat_id = st.session_state.get("chat_id")

        history = st.session_state.get("messages")
        user_message = Message("user", prompt, chat_id)
        # The user message is recorded before the backend call.
        history.append(user_message)
        history.append(Message(**send_prompt(user_message)))

    messages = st.session_state.get("messages", [])
    with st.container(border=True):
        for message in messages:
            with st.chat_message(message.sender):
                st.write(message.content)
        # An even, non-zero count means the last message is a backend answer
        # (messages are appended in user/answer pairs above).
        if messages and len(messages) % 2 == 0:
            streamlit_feedback(
                key=str(len(messages)),
                feedback_type="thumbs",
                on_submit=lambda feedback: send_feedback(messages[-1].id, feedback),
            )


def new_chat():
    """Create a chat on the backend and reset the local conversation.

    Stores the new id under ``st.session_state["chat_id"]`` and an empty list
    under ``st.session_state["messages"]``.

    Returns:
        The ``chat_id`` value of the backend's JSON response.
    """
    session = st.session_state.get("session")
    chat_id = session.post("/chat/new").json()["chat_id"]  # parse the body once
    st.session_state["chat_id"] = chat_id
    st.session_state["messages"] = []
    return chat_id


def send_prompt(message: Message):
    """Post ``message`` to the backend and return its answer.

    Args:
        message: The message to send; it is serialised with ``asdict``.

    Returns:
        The ``"message"`` entry of the backend's JSON response.
    """
    session = st.session_state.get("session")
    response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message))
    return response.json()["message"]


def send_feedback(message_id: str, feedback: dict):
    """Send a thumbs up/down for ``message_id`` to the backend.

    Args:
        message_id: Id of the message being rated.
        feedback: Mapping whose ``"score"`` entry is compared against the
            thumbs-up emoji; any other score is sent as ``thumbs_down``.

    Returns:
        The text of the backend's response.
    """
    endpoint = "thumbs_up" if feedback["score"] == "👍" else "thumbs_down"
    session = st.session_state.get("session")
    response = session.post(f"/feedback/{message_id}/{endpoint}")
    return response.text