diff --git a/README.md b/README.md index e68d0c7..1b43b74 100644 --- a/README.md +++ b/README.md @@ -2,75 +2,4 @@ # skaff-rag-accelerator -[![CI status](https://github.com/artefactory/skaff-rag-accelerator/actions/workflows/ci.yaml/badge.svg)](https://github.com/artefactory/skaff-rag-accelerator/actions/workflows/ci.yaml?query=branch%3Amain) -[![Python Version](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10-blue.svg)]() - -[![Linting , formatting, imports sorting: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) -[![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit) -[![Pre-commit](https://img.shields.io/badge/pre--commit-enabled-informational?logo=pre-commit&logoColor=white)](https://github.com/artefactory/skaff-rag-accelerator/blob/main/.pre-commit-config.yaml) - - -TODO: if not done already, check out the [Skaff documentation](https://artefact.roadie.so/catalog/default/component/repo-builder-ds/docs/) for more information about the generated repository. - -Deploy RAGs quickly - -## Table of Contents - -- [skaff-rag-accelerator](#skaff-rag-accelerator) - - [Table of Contents](#table-of-contents) - - [Installation](#installation) - - [Usage](#usage) - - [Documentation](#documentation) - - [Repository Structure](#repository-structure) - -## Installation - -To install the required packages in a virtual environment, run the following command: - -```bash -make install -``` - -TODO: Choose between conda and venv if necessary or let the Makefile as is and copy/paste the [MORE INFO installation section](MORE_INFO.md#eased-installation) to explain how to choose between conda and venv. - -A complete list of available commands can be found using the following command: - -```bash -make help -``` - -## Usage - -TODO: Add usage instructions here - -## Documentation - -TODO: Github pages is not enabled by default, you need to enable it in the repository settings: Settings > Pages > Source: "Deploy from a branch" / Branch: "gh-pages" / Folder: "/(root)" - -A detailed documentation of this project is available [here](https://artefactory.github.io/skaff-rag-accelerator/) - -To serve the documentation locally, run the following command: - -```bash -mkdocs serve -``` - -To build it and deploy it to GitHub pages, run the following command: - -```bash -make deploy_docs -``` - -## Repository Structure - -``` -. -├── .github <- GitHub Actions workflows and PR template -├── bin <- Bash files -├── config <- Configuration files -├── docs <- Documentation files (mkdocs) -├── lib <- Python modules -├── notebooks <- Jupyter notebooks -├── secrets <- Secret files (ignored by git) -└── tests <- Unit tests -``` + \ No newline at end of file diff --git a/authentication.py b/authentication.py new file mode 100644 index 0000000..e796a22 --- /dev/null +++ b/authentication.py @@ -0,0 +1,41 @@ +from datetime import timedelta, datetime +import os +from pydantic import BaseModel +from jose import jwt + + +from database.database import DatabaseConnection + +SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key") +ALGORITHM = "HS256" + +class User(BaseModel): + email: str = None + password: str = None + +def create_user(user: User): + with DatabaseConnection() as connection: + connection.query("INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)) + +def get_user(email: str): + with DatabaseConnection() as connection: + user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0] + for row in user_row: + return User(**row) + raise Exception("User not found") + +def authenticate_user(username: str, password: str): + user = get_user(username) + if not user or not password == user.password: + return False + return user + +def create_access_token(*, data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt diff --git a/client/main.py b/client/main.py new file mode 100644 index 0000000..e69de29 diff --git a/database/database.py b/database/database.py new file mode 100644 index 0000000..0390c9a --- /dev/null +++ b/database/database.py @@ -0,0 +1,30 @@ +from pathlib import Path +import sqlite3 +from typing import List + +class DatabaseConnection: + def __enter__(self): + self.conn = sqlite3.connect(Path(__file__).parent / "database.sqlite") + self.conn.row_factory = sqlite3.Row + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.conn.commit() + self.conn.close() + + def query(self, query, params=None) -> List[List[sqlite3.Row]]: + cursor = self.conn.cursor() + results = [] + commands = filter(None, query.split(";")) + for command in commands: + cursor.execute(command, params or ()) + results.append(cursor.fetchall()) + return results + + def query_from_file(self, file_path): + with open(file_path, 'r') as file: + query = file.read() + self.query(query) + +with DatabaseConnection() as connection: + connection.query_from_file(Path(__file__).parent / "database_init.sql") \ No newline at end of file diff --git a/database/database_init.sql b/database/database_init.sql new file mode 100644 index 0000000..801befb --- /dev/null +++ b/database/database_init.sql @@ -0,0 +1,33 @@ +-- Go to https://dbdiagram.io/d/RAGAAS-63dbdcc6296d97641d7e07c8 +-- Make your changes +-- Export > Export to PostgresSQL (or other) +-- Translate to SQLite (works with a cmd+k in Cursor, or https://www.rebasedata.com/convert-postgresql-to-sqlite-online) +-- Paste here +-- Replace "CREATE TABLE" with "CREATE TABLE IF NOT EXISTS" + +CREATE TABLE IF NOT EXISTS "user" ( + "email" TEXT PRIMARY KEY, + "password" TEXT +); + +CREATE TABLE IF NOT EXISTS "chat" ( + "id" TEXT PRIMARY KEY, + "user_id" TEXT, + FOREIGN KEY ("user_id") REFERENCES "user" ("email") +); + +CREATE TABLE IF NOT EXISTS "message" ( + "id" TEXT PRIMARY KEY, + "timestamp" TEXT, + "chat_id" TEXT, + "sender" TEXT, + "content" TEXT, + FOREIGN KEY ("chat_id") REFERENCES "chat" ("id") +); + +CREATE TABLE IF NOT EXISTS "feedback" ( + "id" TEXT PRIMARY KEY, + "message_id" TEXT, + "feedback" TEXT, + FOREIGN KEY ("message_id") REFERENCES "message" ("id") +); diff --git a/document_store.py b/document_store.py index 327cb1c..035bda1 100644 --- a/document_store.py +++ b/document_store.py @@ -24,11 +24,11 @@ def persist_to_bucket(bucket_path: str, store: Chroma): def store_documents(docs: List[Document], bucket_path: str, storage_backend: StorageBackend): - lagnchain_documents = [doc.to_langchain_document() for doc in docs] + langchain_documents = [doc.to_langchain_document() for doc in docs] embeddings_model = OpenAIEmbeddings() persistent_client = chromadb.PersistentClient() collection = persistent_client.get_or_create_collection(get_storage_root_path(bucket_path, storage_backend)) - collection.add(documents=lagnchain_documents) + collection.add(documents=langchain_documents) langchain_chroma = Chroma( client=persistent_client, collection_name=bucket_path, diff --git a/main.py b/main.py index f857d95..2eabb6b 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,123 @@ -from fastapi import FastAPI, HTTPException, status, Body +from datetime import timedelta from typing import List -from langchain.docstore.document import Document -from document_store import StorageBackend + +from fastapi import FastAPI, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from jose import jwt, JWTError + import document_store -from model import ChatMessage +from authentication import (authenticate_user, create_access_token, create_user, + get_user, User, SECRET_KEY, ALGORITHM) +from document_store import StorageBackend from model import Doc + app = FastAPI() + +############################################ +### Authentication ### +############################################ + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") +async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + email: str = payload.get("email") # 'sub' is commonly used to store user identity + if email is None: + raise credentials_exception + # Here you should fetch the user from the database by user_id + user = get_user(email) + if user is None: + raise credentials_exception + return user + except JWTError: + raise credentials_exception + +@app.post("/user/signup") +async def signup(user: User): + try: + user = get_user(user.email) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"User {user.email} already registered" + ) + except Exception as e: + create_user(user) + return {"email": user.email} + +@app.post("/user/login") +async def login(form_data: OAuth2PasswordRequestForm = Depends()): + user = authenticate_user(form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=60) + access_token = create_access_token( + data=user.model_dump(), expires_delta=access_token_expires + ) + return {"access_token": access_token, "token_type": "bearer"} + +@app.get("/user/me") +async def user_me(current_user: User = Depends(get_current_user)): + return current_user + + +############################################ +### Chat ### +############################################ +# P1 +@app.post("/chat/new") +async def chat_new(current_user: User = Depends(get_current_user)): + pass + +# P1 +@app.post("/chat/user_message") +async def chat_prompt(current_user: User = Depends(get_current_user)): + pass + +@app.get("/chat/list") +async def chat_list(current_user: User = Depends(get_current_user)): + pass + +@app.get("/chat/{chat_id}") +async def chat(chat_id: str, current_user: User = Depends(get_current_user)): + pass + + +############################################ +### Feedback ### +############################################ + +@app.post("/feedback/thumbs_up") +async def feedback_thumbs_up(current_user: User = Depends(get_current_user)): + pass + +@app.post("/feedback/thumbs_down") +async def feedback_thumbs_down(current_user: User = Depends(get_current_user)): + pass + +@app.post("/feedback/regenerate") +async def feedback_regenerate(current_user: User = Depends(get_current_user)): + pass + + +############################################ +### Other ### +############################################ + @app.post("/index/documents") async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend): document_store.store_documents(chunks, bucket, storage_backend) -@app.post("/chat") -async def chat(chat_message: ChatMessage): - pass - if __name__ == "__main__": import uvicorn diff --git a/model.py b/model.py index c1c3740..dadba97 100644 --- a/model.py +++ b/model.py @@ -3,6 +3,7 @@ class ChatMessage(BaseModel): message: str + message_id: str session_id: str class Doc(BaseModel): diff --git a/requirements.txt b/requirements.txt index 746b00d..ec94ae1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,6 @@ universal_pathlib chromadb langchain langchainhub -gpt4all \ No newline at end of file +gpt4all +python-multipart +httpx \ No newline at end of file diff --git a/sandbox_alexis/main.py b/sandbox_alexis/main.py index f959c9b..2fd2166 100644 --- a/sandbox_alexis/main.py +++ b/sandbox_alexis/main.py @@ -14,5 +14,5 @@ split_documents = load_and_split_document(text=data) root_path = get_storage_root_path("dbt-server-alexis3-36fe-rag", StorageBackend.GCS) -vector_store = Chroma(persist_directory=str(root_path / "chromadb"), embedding_function=GPT4AllEmbeddings()) +vector_store = Chroma(persist_directory=root_path / "chromadb", embedding_function=GPT4AllEmbeddings()) db = vector_store.add_documents(split_documents) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..ca663d0 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,21 @@ +from fastapi.testclient import TestClient +from main import app + +client = TestClient(app) + +def test_signup(): + response = client.post("/user/signup", json={"email": "test@example.com", "password": "testpassword"}) + assert response.status_code == 200 + assert response.json()["email"] == "test@example.com" + +def test_login(): + response = client.post("/user/login", data={"username": "test@example.com", "password": "testpassword"}) + assert response.status_code == 200 + assert "access_token" in response.json() + +def test_user_me(): + login_response = client.post("/user/login", data={"username": "test@example.com", "password": "testpassword"}) + token = login_response.json()["access_token"] + response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == 200 + assert response.json()["email"] == "test@example.com"