diff --git a/.gitignore b/.gitignore
index 44d5196..082b2ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -137,6 +137,8 @@ secrets/*
# Mac OS
.DS_Store
data/
+vector_database/
*.sqlite
-*.sqlite3
\ No newline at end of file
+*.sqlite3
\ No newline at end of file
diff --git a/README.md b/README.md
index ac80bd0..2279537 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,228 @@
-
+
-# skaff-rag-accelerator
+# skaff-rag-accelerator
-```bash
-export PYTHONPATH="/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@artefact.com/Mon Drive/internal_projects/skaff-rag-accelerator/"
+
+
+This is a starter kit to deploy a modular RAG locally, on the cloud, or across multiple clouds.
+
+## Features
+
+- A configurable RAG setup based around Langchain
+- `RAG` and `RagConfig` Python classes to help you set things up
+- A REST API based on FastAPI to provide easy access to the RAG as a web backend
+- A Streamlit demo app to serve as a basic working frontend (not production grade)
+- A document loader for the RAG
+- User authentication (insecure for now, but usable for conversation history)
+- User feedback collection
+- Streamed responses
+
+## Quickstart
+
+In a fresh env:
+```shell
+pip install -r requirements.txt
```
-```bash
-python "/Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@artefact.com/Mon Drive/internal_projects/skaff-rag-accelerator/backend/main.py"
+You will need to set some env vars, either in a .env file at the project root, or just by exporting them like so:
+```shell
+export OPENAI_API_KEY="xxx" # API key used to query the LLM
+export EMBEDDING_API_KEY="xxx" # API key used to query the embedding model
+export DATABASE_URL="sqlite:///$(pwd)/database/db.sqlite3" # For local development only. You will need a real, cloud-based SQL database URL for prod.
```
-- comment mettre des docs dans le chatbot
-- comment lancer l'API
-- gestion de la config
-- écrire des helpers de co, pour envoyer des messages...
-- tester différents modèles
-- écrire des snippets de code pour éxpliquer comment charger les docs dans le RAG
+Start the backend server locally:
+```shell
+python backend/main.py
+```
-
+Start the frontend demo:
+```shell
+streamlit run frontend/app.py
+```
+
+You should then be able to log in and chat with the bot:
+![](docs/login_and_chat.gif)
+
+
+## Loading documents
+
+The easiest but least flexible way to load documents into your RAG is to use the `RAG.load_file` method. It will semi-intelligently try to pick the best Langchain loader and parameters for your file.
+
+Create `backend/load_my_docs.py`:
+```python
+from pathlib import Path
+
+from backend.rag_components.rag import RAG
+
+
+data_directory = Path("data")
+
+config_directory = Path("backend/config.yaml")
+rag = RAG(config_directory)
+
+for file in data_directory.iterdir():
+ if file.is_file():
+ rag.load_file(file)
+```
+
+If you want more flexibility, you can use the `rag.load_documents` method, which expects a list of `langchain.docstore.document.Document` objects.
+
+For example, here is a minimal sketch (the document content and `source` metadata below are just placeholders):
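+```python
+from pathlib import Path
+
+from langchain.docstore.document import Document
+
+from backend.rag_components.rag import RAG
+
+
+rag = RAG(Path("backend/config.yaml"))
+
+documents = [
+    Document(
+        page_content="The content you want the RAG to be able to retrieve.",
+        metadata={"source": "my_knowledge_base/some_document.txt"},  # hypothetical source
+    ),
+]
+
+rag.load_documents(documents)
+```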
+
+### Document indexing
+
+The document loader maintains an index of the loaded documents. You can control how it is used by setting `vector_store.insertion_mode` in your RAG configuration to `None`, `incremental`, or `full`.
+
+[Details of what that means here.](https://python.langchain.com/docs/modules/data_connection/indexing)
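+
+The configured mode can also be overridden for a single call by passing `insertion_mode` to `rag.load_documents` (reusing the `rag` and `documents` objects from the sketch above):
+```python
+# Index these documents incrementally, regardless of the configured insertion_mode
+rag.load_documents(documents, insertion_mode="incremental")
+```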
+
+## Configuring the RAG
+
+### The `RAG` object
+
+It provides a single interface to the RAG's functionalities.
+
+Out of the box, a `RAG` object is created from your configuration and used by the `/chat/{chat_id}/user_message` endpoint in [`backend/main.py`](backend/main.py).
+
+The `RAG` class initializes the key components (language model, embeddings, vector store) and generates responses to user messages using an answer chain.
+
+It also manages document loading and indexing based on configuration settings.
+
+
+Using the `RAG` class directly:
+```python
+from pathlib import Path
+from backend.rag_components.rag import RAG
+from backend.model import Message
+
+config_directory = Path("backend/config.yaml")
+rag = RAG(config_directory)
+
+message = Message(
+ id="123",
+ timestamp="2021-06-01T12:00:00",
+ chat_id="123",
+ sender="user",
+ content="Hello, how are you?",
+)
+response = rag.generate_response(message)
+print(response)
+```
+
+[Go to the code.](backend/rag_components/rag.py)
+
+### Managing the configuration (`RagConfig`)
+
+The overall config management works like this:
+![](docs/config_architecture.png)
+
+This means the best way to configure your RAG deployment is to modify the config.yaml file.
+
+This file is then loaded to instantiate a `RagConfig` object, which is used by the `RAG` class.
+
+In the default configuration template ([`backend/config.yaml`](backend/config.yaml)) you will find this:
+```yaml
+# This is the LLM configuration (&LLMConfig is a YAML anchor so this block can be referenced further down in the config file)
+LLMConfig: &LLMConfig
+
+ # By default we're using a GPT model deployed on Azure. You should be able to change this to any langchain BaseChatModel here: https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/__init__.py
+ source: "AzureChatOpenAI"
+
+ # This is a key-value map of the parameters that will be passed to the langchain chat model object when it's created. Looking at the AzureChatOpenAI source code (https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/azure_openai.py), we input the following params:
+ source_config:
+ openai_api_type: "azure"
+ openai_api_key: {{ OPENAI_API_KEY }}
+ openai_api_base: "https://poc-genai-gpt4.openai.azure.com/"
+ openai_api_version: "2023-07-01-preview"
+ deployment_name: "gpt4v"
+
+ # While the params in source_config are specific to each model, temperature is implemented by all BaseChatModel classes in langchain.
+ temperature: 0.1
+
+# ... Rest of the config ...
+```
+
+Let's say we want to use a Vertex LLM instead. [Looking at the source code of this model in langchain](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/chat_models/vertexai.py#L206C7-L206C19), we find this:
+
+```python
+class ChatVertexAI(_VertexAICommon, BaseChatModel):
+ """`Vertex AI` Chat large language models API."""
+
+ model_name: str = "chat-bison"
+ "Underlying model name."
+ examples: Optional[List[BaseMessage]] = None
+```
+
+The updated `config.yaml` could look like this:
+```yaml
+LLMConfig: &LLMConfig
+ source: "ChatVertexAI"
+ source_config:
+ model_name: gemini-pro
+ temperature: 0.1
+```
+
+
+## Architecture
+
+### The `frontend`, the `backend`, and the `database`
+
+The whole goal of this repo is to decouple the "computing and LLM querying" part from the "rendering a user interface" part. We do this with a typical 3-tier architecture.
+
+![](docs/3t_architecture.png)
+
+- The [frontend](frontend) is the end-user-facing part. It reaches out to the backend **ONLY** through the REST API. We provide a frontend demo here for convenience, but ultimately it could live in a completely different repo, and be written in a completely different language.
+- The [backend](backend) provides a REST API to abstract RAG functionalities. It handles calls to LLMs, tracks conversations and users, handles state management using a database, and much more. To get the gist of the backend, look at the API documentation at http://0.0.0.0:8000/docs once the backend is running, or at the client sketch below.
+- The [database](database) is only accessed by the backend and persists the state of the RAG application. [Explore the data model here.](https://dbdiagram.io/d/RAGAAS-63dbdcc6296d97641d7e07c8)
+
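+For example, here is a sketch of a Python client exercising the API (the endpoint paths come from [`backend/main.py`](backend/main.py); the credentials and message values are placeholders):
+```python
+import requests
+
+BASE_URL = "http://0.0.0.0:8000"
+
+# Create a user, then log in to get a bearer token
+requests.post(f"{BASE_URL}/user/signup", json={"email": "alice@example.com", "password": "secret"})
+login = requests.post(
+    f"{BASE_URL}/user/login",
+    data={"username": "alice@example.com", "password": "secret"},
+)
+headers = {"Authorization": f"Bearer {login.json()['access_token']}"}
+
+# Start a new chat and stream the assistant's answer
+chat_id = requests.post(f"{BASE_URL}/chat/new", headers=headers).json()["chat_id"]
+message = {
+    "id": "1",
+    "timestamp": "2024-01-01T12:00:00",
+    "chat_id": chat_id,
+    "sender": "user",
+    "content": "Hello, how are you?",
+}
+with requests.post(f"{BASE_URL}/chat/{chat_id}/user_message", json=message, headers=headers, stream=True) as response:
+    for chunk in response.iter_content(chunk_size=None):
+        print(chunk.decode("utf-8"), end="", flush=True)
+```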
+
+## Going further
+
+### Extending the configuration
+
+As you tune this starter kit to your needs, you may need to add specific configuration that your RAG will use.
+
+For example, let's say you want to add the `foo` configuration parameter to your vector store configuration.
+
+First, add it to `config.py` in the part relevant to the vector store:
+
+```python
+# ...
+
+@dataclass
+class VectorStoreConfig:
+ # ... rest of the VectorStoreConfig ...
+
+    foo: str = "bar"  # We add the foo param, of type str, with the default value "bar"
+
+# ...
+```
+
+This parameter will now be available in your `RAG` object configuration.
+
+```python
+from pathlib import Path
+from backend.rag_components.rag import RAG
+
+config_directory = Path("backend/config.yaml")
+rag = RAG(config_directory)
+
+print(rag.config.vector_store.foo)
+# > bar
+```
+
+If you want to override its default value, you can do that in your `config.yaml`:
+```yaml
+VectorStoreConfig: &VectorStoreConfig
+ # ... rest of the VectorStoreConfig ...
+ foo: baz
+```
+
+```python
+print(rag.config.vector_store.foo)
+# > baz
+```
+
+### Using `RagConfig` directly
+
+A minimal sketch of using `RagConfig` directly, assuming the default template at [`backend/config.yaml`](backend/config.yaml):
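+```python
+from pathlib import Path
+
+from backend.config import RagConfig
+
+config = RagConfig.from_yaml(Path("backend/config.yaml"))
+
+print(config.llm.source)                # > AzureChatOpenAI
+print(config.vector_store.source)       # > Chroma
+print(config.chat_history_window_size)  # > 5
+```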
\ No newline at end of file
diff --git a/backend/authentication.py b/backend/authentication.py
index 0db1baa..79f467b 100644
--- a/backend/authentication.py
+++ b/backend/authentication.py
@@ -5,38 +5,33 @@
from jose import jwt
from pydantic import BaseModel
-from database.database import Database
+from backend.database import Database
SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key")
ALGORITHM = "HS256"
class User(BaseModel):
- """Represents a user with an email and password."""
-
email: str = None
password: str = None
def create_user(user: User) -> None:
- """Create a new user in the database."""
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)
)
def get_user(email: str) -> User:
- """Retrieve a user from the database by email."""
with Database() as connection:
- user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0]
+ user_row = connection.execute("SELECT * FROM user WHERE email = ?", (email,))
for row in user_row:
return User(**row)
raise Exception("User not found")
def authenticate_user(username: str, password: str) -> Optional[User]:
- """Authenticate a user by their username and password."""
user = get_user(username)
if not user or not password == user.password:
return False
@@ -44,7 +39,6 @@ def authenticate_user(username: str, password: str) -> Optional[User]:
def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None) -> str:
- """Create a JWT access token with optional expiry."""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
diff --git a/backend/chatbot.py b/backend/chatbot.py
deleted file mode 100644
index 937a68b..0000000
--- a/backend/chatbot.py
+++ /dev/null
@@ -1,87 +0,0 @@
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.combine_documents.stuff import StuffDocumentsChain
-from langchain.chat_models.base import HumanMessage, SystemMessage
-from langchain.prompts import (
- ChatPromptTemplate,
- HumanMessagePromptTemplate,
- PromptTemplate,
-)
-from langchain.vectorstores import VectorStore
-
-from backend.config_renderer import get_config
-from backend.rag_components.chat_message_history import get_conversation_buffer_memory
-from backend.rag_components.embedding import get_embedding_model
-from backend.rag_components.llm import get_llm_model
-from backend.rag_components.vector_store import get_vector_store
-
-
-def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
- """Processes the given query through the answer chain and returns the formatted response."""
- return answer_chain.run(query)
-
-
-def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:
- """Returns an instance of ConversationalRetrievalChain based on the provided parameters."""
- template = """Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.
-
-Chat history :
-{chat_history}
-Question : {question}
-
-Rephrased question :
-"""
- condense_question_prompt = PromptTemplate.from_template(template)
- condense_question_chain = LLMChain(
- llm=llm,
- prompt=condense_question_prompt,
- )
-
- messages = [
- SystemMessage(
- content=(
- """As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail."""
- )
- ),
- HumanMessage(content="Respond to the question taking into account the following context."),
- HumanMessagePromptTemplate.from_template("{context}"),
- HumanMessagePromptTemplate.from_template("Question: {question}"),
- ]
- system_prompt = ChatPromptTemplate(messages=messages)
- qa_chain = LLMChain(
- llm=llm,
- prompt=system_prompt,
- )
-
- doc_prompt = PromptTemplate(
- template="Content: {page_content}\nSource: {source}",
- input_variables=["page_content", "source"],
- )
-
- final_qa_chain = StuffDocumentsChain(
- llm_chain=qa_chain,
- document_variable_name="context",
- document_prompt=doc_prompt,
- )
-
- return ConversationalRetrievalChain(
- question_generator=condense_question_chain,
- retriever=docsearch.as_retriever(search_kwargs={"k": 10}),
- memory=memory,
- combine_docs_chain=final_qa_chain,
- verbose=True,
- )
-
-
-if __name__ == "__main__":
- chat_id = "test"
- config = get_config()
- llm = get_llm_model(config)
- embeddings = get_embedding_model(config)
- vector_store = get_vector_store(embeddings, config)
- memory = get_conversation_buffer_memory(config, chat_id)
- answer_chain = get_answer_chain(llm, vector_store, memory)
-
- prompt = "Give me the top 5 bilionnaires in france based on their worth in order of decreasing net worth"
- response = get_response(answer_chain, prompt)
- print("Prompt: ", prompt)
- print("Response: ", response)
diff --git a/backend/config.py b/backend/config.py
new file mode 100644
index 0000000..7f6208e
--- /dev/null
+++ b/backend/config.py
@@ -0,0 +1,81 @@
+from dataclasses import dataclass, field, is_dataclass
+import os
+from pathlib import Path
+from dotenv import load_dotenv
+from jinja2 import Template
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.vectorstores import VectorStore
+from langchain.schema.embeddings import Embeddings
+import yaml
+
+load_dotenv()
+
+@dataclass
+class LLMConfig:
+ source: BaseChatModel | str = "AzureChatOpenAI"
+ source_config: dict = field(default_factory=lambda: {
+ "openai_api_type": "azure",
+ "openai_api_base": "https://poc-genai-gpt4.openai.azure.com/",
+ "openai_api_version": "2023-07-01-preview",
+ "openai_api_key": os.environ.get("OPENAI_API_KEY"),
+ "deployment_name": "gpt4v",
+ })
+
+ temperature: float = 0.1
+
+@dataclass
+class VectorStoreConfig:
+ source: VectorStore | str = "Chroma"
+ source_config: dict = field(default_factory=lambda: {
+ "persist_directory": "vector_database/",
+ "collection_metadata": {
+ "hnsw:space": "cosine"
+ }
+ })
+
+ retreiver_search_type: str = "similarity"
+ retreiver_config: dict = field(default_factory=lambda: {
+ "top_k": 20,
+ "score_threshold": 0.5
+ })
+
+ insertion_mode: str = "full" # "None", "full", "incremental"
+
+@dataclass
+class EmbeddingModelConfig:
+ source: Embeddings | str = "OpenAIEmbeddings"
+ source_config: dict = field(default_factory=lambda: {
+ "openai_api_type": "azure",
+ "openai_api_base": "https://poc-openai-artefact.openai.azure.com/",
+ "openai_api_key": os.environ.get("EMBEDDING_API_KEY"),
+ "deployment": "embeddings",
+ "chunk_size": 500,
+ })
+
+@dataclass
+class DatabaseConfig:
+ database_url: str = os.environ.get("DATABASE_URL")
+
+@dataclass
+class RagConfig:
+ llm: LLMConfig = field(default_factory=LLMConfig)
+ vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
+ embedding_model: EmbeddingModelConfig = field(default_factory=EmbeddingModelConfig)
+ database: DatabaseConfig = field(default_factory=DatabaseConfig)
+ chat_history_window_size: int = 5
+ max_tokens_limit: int = 3000
+
+ @classmethod
+ def from_yaml(cls, yaml_path: Path, env: dict = None):
+ if env is None:
+ env = os.environ
+ with open(yaml_path, "r") as file:
+ template = Template(file.read())
+ config_data = yaml.safe_load(template.render(env))["RagConfig"]
+
+ for field_name, field_type in cls.__annotations__.items():
+ if field_name in config_data and is_dataclass(field_type):
+ config_data[field_name] = field_type(**config_data[field_name])
+
+ return cls(**config_data)
diff --git a/backend/config.yaml b/backend/config.yaml
index 9ca450d..0f2c74c 100644
--- a/backend/config.yaml
+++ b/backend/config.yaml
@@ -1,30 +1,46 @@
-llm_provider_config:
- openai_api_type: azure
- openai_api_base: https://poc-genai-gpt4.openai.azure.com/
- openai_api_version: 2023-07-01-preview
- openai_api_key: {{ OPENAI_API_KEY }}
+# LLMConfig: &LLMConfig
+# source: ChatVertexAI
+# source_config:
+# model_name: gemini-pro
+# temperature: 0.1
-llm_model_config:
- model_source: AzureChatOpenAI
- deployment_name: gpt4v
- temperature: 0.1
- streaming: true
- verbose: true
+LLMConfig: &LLMConfig
+ source: AzureChatOpenAI
+ source_config:
+ openai_api_type: azure
+ openai_api_key: {{ OPENAI_API_KEY }}
+ openai_api_base: https://genai-ds.openai.azure.com/
+ openai_api_version: 2023-07-01-preview
+ deployment_name: gpt4
-embedding_provider_config:
- openai_api_type: azure
- openai_api_base: https://poc-openai-artefact.openai.azure.com/
- openai_api_key: {{ EMBEDDING_API_KEY }}
+VectorStoreConfig: &VectorStoreConfig
+ source: Chroma
+ source_config:
+ persist_directory: vector_database/
+ collection_metadata:
+ hnsw:space: cosine
+ retreiver_search_type: similarity
+ retreiver_config:
+ top_k: 20
+ score_threshold: 0.5
+ insertion_mode: full
-embedding_model_config:
- model_source: OpenAIEmbeddings
- deployment: embeddings
- chunk_size: 500
+EmbeddingModelConfig: &EmbeddingModelConfig
+ source: OpenAIEmbeddings
+ source_config:
+ openai_api_type: azure
+ openai_api_key: {{ EMBEDDING_API_KEY }}
+ openai_api_base: https://poc-openai-artefact.openai.azure.com/
+ deployment: embeddings
+ chunk_size: 500
-vector_store_provider:
- model_source: Chroma
- persist_directory: vector_database/
+DatabaseConfig: &DatabaseConfig
+ database_url: {{ DATABASE_URL }}
-chat_message_history_config:
- source: ChatMessageHistory
- window_size: 5
+RagConfig:
+ llm: *LLMConfig
+ vector_store: *VectorStoreConfig
+ embedding_model: *EmbeddingModelConfig
+ database: *DatabaseConfig
+ chat_history_window_size: 5
+ max_tokens_limit: 3000
\ No newline at end of file
diff --git a/backend/config_renderer.py b/backend/config_renderer.py
deleted file mode 100644
index 8a0e568..0000000
--- a/backend/config_renderer.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-from pathlib import Path
-
-import yaml
-from dotenv import load_dotenv
-from jinja2 import Environment, FileSystemLoader
-
-
-def get_config() -> dict:
- load_dotenv()
- env = Environment(loader=FileSystemLoader(Path(__file__).parent))
- template = env.get_template("config.yaml")
- config = template.render(os.environ)
- return yaml.safe_load(config)
-
-
-def load_models_config():
- with open(Path(__file__).parent / "config.yaml", "r") as file:
- return yaml.safe_load(file)
-
-
-if __name__ == "__main__":
- print(get_config())
diff --git a/backend/database.py b/backend/database.py
new file mode 100644
index 0000000..04db5a1
--- /dev/null
+++ b/backend/database.py
@@ -0,0 +1,97 @@
+from logging import Logger
+import os
+from typing import Optional, Any
+from pathlib import Path
+from dotenv import load_dotenv
+
+import sqlglot
+from dbutils.pooled_db import PooledDB
+
+from backend.logger import get_logger
+
+class Database:
+ def __init__(self, connection_string: str = None, logger: Logger = None):
+ self.connection_string = connection_string or os.getenv("DATABASE_URL")
+ self.logger = logger or get_logger()
+
+ self.logger.debug("Creating connection pool")
+ self.pool = self._create_pool()
+ self.conn = None
+
+ def __enter__(self) -> "Database":
+ self.logger.debug("Getting connection from pool")
+ self.conn = self.pool.connection()
+ return self
+
+ def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]) -> None:
+ if self.conn:
+ if exc_type:
+ self.logger.error("Transaction failed", exc_info=(exc_type, exc_value, traceback))
+ self.conn.rollback()
+ else:
+ self.conn.commit()
+ self.logger.debug("Returning connection to pool")
+ self.conn.close()
+ self.conn = None
+
+ def execute(self, query: str, params: Optional[tuple] = None) -> Any:
+ cursor = self.conn.cursor()
+ try:
+ self.logger.debug(f"Executing query: {query}")
+ cursor.execute(query, params or ())
+ return cursor
+ except Exception as e:
+ cursor.close()
+ self.logger.exception("Query execution failed", exc_info=e)
+ raise
+
+ def fetchone(self, query: str, params: Optional[tuple] = None) -> Optional[tuple]:
+ cursor = self.execute(query, params)
+ try:
+ return cursor.fetchone()
+ finally:
+ cursor.close()
+
+ def fetchall(self, query: str, params: Optional[tuple] = None) -> list:
+ cursor = self.execute(query, params)
+ try:
+ return cursor.fetchall()
+ finally:
+ cursor.close()
+
+ def initialize_schema(self):
+ try:
+ self.logger.debug("Initializing database schema")
+ sql_script = Path(__file__).parent.joinpath('db_init.sql').read_text()
+ transpiled_sql = sqlglot.transpile(sql_script, read='sqlite', write=self.connection_string.split(":")[0])
+ for statement in transpiled_sql:
+ self.execute(statement)
+ self.logger.debug(f"Database schema initialized successfully for {self.connection_string.split(':')[0]}")
+ except Exception as e:
+ self.logger.exception("Schema initialization failed", exc_info=e)
+ raise
+
+ def _create_pool(self) -> PooledDB:
+ if self.connection_string.startswith("sqlite:///"):
+ import sqlite3
+ return PooledDB(creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), maxconnections=5)
+ elif self.connection_string.startswith("postgres://"):
+ import psycopg2
+ return PooledDB(creator=psycopg2, dsn=self.connection_string.replace("postgres://", ""), maxconnections=5)
+ elif self.connection_string.startswith("mysql://"):
+ import mysql.connector
+ return PooledDB(creator=mysql.connector, database=self.connection_string.replace("mysql://", ""), maxconnections=5)
+ elif self.connection_string.startswith("sqlserver://"):
+ import pyodbc
+ return PooledDB(creator=pyodbc, dsn=self.connection_string.replace("sqlserver://", ""), maxconnections=5)
+ else:
+ raise ValueError("Unsupported database type")
+
+
+
+if __name__ == "__main__":
+ load_dotenv()
+ with Database(os.getenv("DATABASE_URL")) as db:
+ db.execute("DELETE FROM user WHERE email IN ('alexis')")
+ db.execute("DELETE FROM chat WHERE user_id IN ('alexis')")
diff --git a/database/database_init.sql b/backend/db_init.sql
similarity index 85%
rename from database/database_init.sql
rename to backend/db_init.sql
index eb9f971..8c8e472 100644
--- a/database/database_init.sql
+++ b/backend/db_init.sql
@@ -1,7 +1,6 @@
-- Go to https://dbdiagram.io/d/RAGAAS-63dbdcc6296d97641d7e07c8
-- Make your changes
-- Export > Export to PostgresSQL (or other)
--- Translate to SQLite (works with a cmd+k in Cursor, or https://www.rebasedata.com/convert-postgresql-to-sqlite-online)
-- Paste here
-- Replace "CREATE TABLE" with "CREATE TABLE IF NOT EXISTS"
@@ -12,7 +11,7 @@ CREATE TABLE IF NOT EXISTS "user" (
CREATE TABLE IF NOT EXISTS "chat" (
"id" TEXT PRIMARY KEY,
- "timestamp" TEXT,
+ "timestamp" DATETIME,
"user_id" TEXT,
FOREIGN KEY ("user_id") REFERENCES "user" ("email")
);
diff --git a/backend/document_store.py b/backend/document_store.py
deleted file mode 100644
index 106ec8e..0000000
--- a/backend/document_store.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from enum import Enum
-from pathlib import Path
-from typing import List
-
-import chromadb
-from langchain.docstore.document import Document
-from langchain.embeddings import OpenAIEmbeddings
-from langchain.vectorstores import Chroma
-
-
-class StorageBackend(Enum):
- """Enumeration of supported storage backends."""
-
- LOCAL = "local"
- MEMORY = "memory"
- GCS = "gcs"
- S3 = "s3"
- AZURE = "az"
-
-
-def get_storage_root_path(bucket_name: str, storage_backend: StorageBackend) -> Path:
- """Constructs the root path for the storage based on the bucket name and storage backend."""
- return Path(f"{storage_backend.value}://{bucket_name}")
-
-
-def persist_to_bucket(bucket_path: str, store: Chroma) -> None:
- """Persists the data in the given Chroma store to a bucket."""
- store.persist("./db/chroma")
- # TODO: Uplaod persisted file on disk to bucket_path gcs
-
-
-def store_documents(
- docs: List[Document], bucket_path: str, storage_backend: StorageBackend
-) -> None:
- """Stores a list of documents in a specified bucket using a given storage backend."""
- langchain_documents = [doc.to_langchain_document() for doc in docs]
- embeddings_model = OpenAIEmbeddings()
- persistent_client = chromadb.PersistentClient()
- collection = persistent_client.get_or_create_collection(
- get_storage_root_path(bucket_path, storage_backend)
- )
- collection.add(documents=langchain_documents)
- langchain_chroma = Chroma(
- client=persistent_client,
- collection_name=bucket_path,
- embedding_function=embeddings_model.embed_documents,
- )
- print("There are", langchain_chroma._collection.count(), "in the collection")
diff --git a/backend/logger.py b/backend/logger.py
new file mode 100644
index 0000000..e5c11c7
--- /dev/null
+++ b/backend/logger.py
@@ -0,0 +1,16 @@
+import logging
+from logging import Logger
+
+
+# Implement your custom logging logic here. Eg. send logs to a cloud's logging tool.
+_logger_instance = None
+
+def get_logger() -> Logger:
+ global _logger_instance
+ if _logger_instance is None:
+ _logger_instance = logging.getLogger(__name__)
+ _logger_instance.setLevel(logging.INFO)
+ console_handler = logging.StreamHandler()
+ console_handler.setLevel(logging.INFO)
+ _logger_instance.addHandler(console_handler)
+ return _logger_instance
diff --git a/backend/main.py b/backend/main.py
index 10d71e1..d821136 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -3,14 +3,16 @@
from typing import List
from uuid import uuid4
+
from fastapi import Depends, FastAPI, HTTPException, status
+from fastapi.responses import StreamingResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
+from backend.logger import get_logger
+
-from backend.config_renderer import get_config
-from backend.document_store import StorageBackend, store_documents
-from backend.model import Doc, Message
-from backend.rag_components.document_loader import generate_response
+from backend.model import Message
+from backend.rag_components.rag import RAG
from backend.user_management import (
ALGORITHM,
SECRET_KEY,
@@ -21,20 +23,19 @@
get_user,
user_exists,
)
-from database.database import Database
+from backend.database import Database
+
app = FastAPI()
+logger = get_logger()
-
-############################################
-### Authentication ###
-############################################
+with Database() as connection:
+ connection.initialize_schema()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
- """Get the current user by decoding the JWT token."""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
@@ -42,10 +43,10 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
- email: str = payload.get("email") # 'sub' is commonly used to store user identity
+ email: str = payload.get("email")
if email is None:
raise credentials_exception
- # Here you should fetch the user from the database by user_id
+
user = get_user(email)
if user is None:
raise credentials_exception
@@ -54,95 +55,103 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
raise credentials_exception
-@app.post("/user/signup")
-async def signup(user: User) -> dict:
- """Sign up a new user."""
- if user_exists(user.email):
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered"
- )
-
- create_user(user)
- return {"email": user.email}
-
-
-@app.delete("/user/")
-async def delete_user(current_user: User = Depends(get_current_user)) -> dict:
- """Delete an existing user."""
- email = current_user.email
- try:
- user = get_user(email)
- if user is None:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found"
- )
- delete_user(email)
- return {"detail": f"User {email} deleted"}
- except Exception:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error"
- )
-
-
-@app.post("/user/login")
-async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
- """Log in a user and return an access token."""
- user = authenticate_user(form_data.username, form_data.password)
- if not user:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Incorrect username or password",
- headers={"WWW-Authenticate": "Bearer"},
- )
- access_token_expires = timedelta(minutes=60)
- access_token = create_access_token(data=user.model_dump(), expires_delta=access_token_expires)
- return {"access_token": access_token, "token_type": "bearer"}
-
-
-@app.get("/user/me")
-async def user_me(current_user: User = Depends(get_current_user)) -> User:
- """Get the current user's profile."""
- return current_user
-
-
############################################
### Chat ###
############################################
-
@app.post("/chat/new")
async def chat_new(current_user: User = Depends(get_current_user)) -> dict:
chat_id = str(uuid4())
timestamp = datetime.now().isoformat()
user_id = current_user.email
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO chat (id, timestamp, user_id) VALUES (?, ?, ?)",
(chat_id, timestamp, user_id),
)
return {"chat_id": chat_id}
-
+
@app.post("/chat/{chat_id}/user_message")
async def chat_prompt(message: Message, current_user: User = Depends(get_current_user)) -> dict:
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(message.id, message.timestamp, message.chat_id, message.sender, message.content),
)
+
+ context = {
+ "user": current_user.email,
+ "chat_id": message.chat_id,
+ "message_id": message.id,
+ "timestamp": message.timestamp,
+ }
+ rag = RAG(config=Path(__file__).parent / "config.yaml", logger=logger, context=context)
+ response = rag.async_generate_response(message)
+
+ return StreamingResponse(async_llm_response(message.chat_id, response), media_type="text/event-stream")
+
+
+@app.post("/chat/regenerate")
+async def chat_regenerate(current_user: User = Depends(get_current_user)) -> dict:
+ """Regenerate a chat session for the current user."""
+ pass
+
+
+@app.get("/chat/list")
+async def chat_list(current_user: User = Depends(get_current_user)) -> List[dict]:
+ chats = []
+ with Database() as connection:
+ result = connection.execute(
+ "SELECT id, timestamp FROM chat WHERE user_id = ? ORDER BY timestamp DESC",
+ (current_user.email,),
+ )
+ chats = [{"id": row[0], "timestamp": row[1]} for row in result]
+ return chats
+
+
+@app.get("/chat/{chat_id}")
+async def chat(chat_id: str, current_user: User = Depends(get_current_user)) -> dict:
+ messages: List[Message] = []
+ with Database() as connection:
+ result = connection.execute(
+ "SELECT id, timestamp, chat_id, sender, content FROM message WHERE chat_id = ? ORDER BY timestamp ASC",
+ (chat_id,),
+ )
+ for row in result:
+ message = Message(
+ id=row[0],
+ timestamp=row[1],
+ chat_id=row[2],
+ sender=row[3],
+ content=row[4]
+ )
+ messages.append(message)
+ return {"chat_id": chat_id, "messages": [message.model_dump() for message in messages]}
+
- config = get_config()
+async def async_llm_response(chat_id, answer_chain):
+ full_response = ""
+ response_id = str(uuid4())
+ try:
+ async for data in answer_chain:
+ full_response += data
+ yield data.encode("utf-8")
+ except Exception as e:
+ logger.error(f"Error generating response for chat {chat_id}: {e}")
+ full_response = f"Sorry, there was an error generating a response. Please contact an administrator and tell them the following error code: {response_id}, and message: {str(e)}"
+ yield full_response.encode("utf-8")
model_response = Message(
- id=str(uuid4()),
+ id=response_id,
timestamp=datetime.now().isoformat(),
- chat_id=message.chat_id,
+ chat_id=chat_id,
sender="assistant",
- content=response,
+ content=full_response,
)
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO message (id, timestamp, chat_id, sender, content) VALUES (?, ?, ?, ?, ?)",
(
model_response.id,
@@ -152,38 +161,18 @@ async def chat_prompt(message: Message, current_user: User = Depends(get_current
model_response.content,
),
)
- return {"message": model_response}
-
-
-@app.post("/chat/regenerate")
-async def chat_regenerate(current_user: User = Depends(get_current_user)) -> dict:
- """Regenerate a chat session for the current user."""
- pass
-
-
-@app.get("/chat/list")
-async def chat_list(current_user: User = Depends(get_current_user)) -> List[dict]:
- """Get a list of chat sessions for the current user."""
- pass
-
-
-@app.get("/chat/{chat_id}")
-async def chat(chat_id: str, current_user: User = Depends(get_current_user)) -> dict:
- """Get details of a specific chat session."""
- pass
############################################
### Feedback ###
############################################
-
@app.post("/feedback/{message_id}/thumbs_up")
async def feedback_thumbs_up(
message_id: str, current_user: User = Depends(get_current_user)
) -> None:
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
(str(uuid4()), message_id, "thumbs_up"),
)
@@ -194,21 +183,62 @@ async def feedback_thumbs_down(
message_id: str, current_user: User = Depends(get_current_user)
) -> None:
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO feedback (id, message_id, feedback) VALUES (?, ?, ?)",
(str(uuid4()), message_id, "thumbs_down"),
)
############################################
-### Other ###
+### Authentication ###
############################################
-@app.post("/index/documents")
-async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
- """Index documents in a specified storage backend."""
- document_store.store_documents(chunks, bucket, storage_backend)
+@app.post("/user/signup")
+async def signup(user: User) -> dict:
+ if user_exists(user.email):
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered"
+ )
+
+ create_user(user)
+ return {"email": user.email}
+
+
+@app.delete("/user/")
+async def delete_user(current_user: User = Depends(get_current_user)) -> dict:
+ email = current_user.email
+ try:
+ user = get_user(email)
+ if user is None:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f"User {email} not found"
+ )
+ delete_user(email)
+ return {"detail": f"User {email} deleted"}
+ except Exception:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error"
+ )
+
+
+@app.post("/user/login")
+async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
+ user = authenticate_user(form_data.username, form_data.password)
+ if not user:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ access_token_expires = timedelta(minutes=60)
+ access_token = create_access_token(data=user.model_dump(), expires_delta=access_token_expires)
+ return {"access_token": access_token, "token_type": "bearer"}
+
+
+@app.get("/user/me")
+async def user_me(current_user: User = Depends(get_current_user)) -> User:
+ return current_user
if __name__ == "__main__":
diff --git a/backend/model.py b/backend/model.py
index 5e24a4f..3890f75 100644
--- a/backend/model.py
+++ b/backend/model.py
@@ -1,7 +1,3 @@
-from datetime import datetime
-from uuid import uuid4
-
-from langchain.docstore.document import Document
from pydantic import BaseModel
@@ -11,14 +7,3 @@ class Message(BaseModel):
chat_id: str
sender: str
content: str
-
-
-class Doc(BaseModel):
- """Represents a document with content and associated metadata."""
-
- content: str
- metadata: dict
-
- def to_langchain_document(self) -> Document:
- """Converts the current Doc instance into a langchain Document."""
- return Document(page_content=self.content, metadata=self.metadata)
diff --git a/backend/rag_components/chain.py b/backend/rag_components/chain.py
new file mode 100644
index 0000000..f959f8d
--- /dev/null
+++ b/backend/rag_components/chain.py
@@ -0,0 +1,74 @@
+import asyncio
+from threading import Thread
+from time import sleep
+
+from langchain.chains import ConversationalRetrievalChain, LLMChain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.prompts import (
+ ChatPromptTemplate,
+ HumanMessagePromptTemplate,
+ PromptTemplate,
+)
+from langchain.vectorstores import VectorStore
+
+from backend.config import RagConfig
+from backend.rag_components.llm import get_llm_model
+from backend.rag_components import prompts
+from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
+
+
+
+async def async_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
+ run = asyncio.create_task(chain.arun({"question": query}))
+
+ async for token in streaming_callback_handler.aiter():
+ yield token
+
+ await run
+
+
+def stream_get_response(chain: ConversationalRetrievalChain, query: str, streaming_callback_handler) -> str:
+ thread = Thread(target=lambda chain, query: chain.run({"question": query}), args=(chain, query))
+ thread.start()
+
+ while thread.is_alive() or not streaming_callback_handler.queue.empty():
+ if not streaming_callback_handler.queue.empty():
+ yield streaming_callback_handler.queue.get()
+ else:
+ sleep(0.1)
+
+ thread.join()
+
+def get_answer_chain(config: RagConfig, docsearch: VectorStore, memory, streaming_callback_handler = None, logging_callback_handler: LoggingCallbackHandler = None) -> ConversationalRetrievalChain:
+ callbacks = [logging_callback_handler] if logging_callback_handler is not None else []
+ streaming_callback = [streaming_callback_handler] if streaming_callback_handler is not None else []
+
+ condense_question_prompt = PromptTemplate.from_template(prompts.condense_history)
+ condense_question_chain = LLMChain(llm=get_llm_model(config), prompt=condense_question_prompt, callbacks=callbacks)
+
+ messages = [
+ HumanMessagePromptTemplate.from_template(prompts.respond_to_question),
+ ]
+ question_answering_prompt = ChatPromptTemplate(messages=messages)
+ streaming_llm = get_llm_model(config, callbacks=streaming_callback + callbacks)
+ question_answering_chain = LLMChain(llm=streaming_llm, prompt=question_answering_prompt, callbacks=callbacks)
+
+ context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])
+
+ stuffed_qa_chain = StuffDocumentsChain(
+ llm_chain=question_answering_chain,
+ document_variable_name="context",
+ document_prompt=context_with_docs_prompt,
+ callbacks=callbacks
+ )
+
+ chain = ConversationalRetrievalChain(
+ question_generator=condense_question_chain,
+ retriever=docsearch.as_retriever(search_type=config.vector_store.retreiver_search_type, search_kwargs=config.vector_store.retreiver_config),
+ memory=memory,
+ max_tokens_limit=config.max_tokens_limit,
+ combine_docs_chain=stuffed_qa_chain,
+ callbacks=callbacks
+ )
+
+ return chain
diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py
index 20558b1..7d57135 100644
--- a/backend/rag_components/chat_message_history.py
+++ b/backend/rag_components/chat_message_history.py
@@ -3,21 +3,22 @@
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import SQLChatMessageHistory
+from backend.config import RagConfig
+
TABLE_NAME = "message_history"
-def get_conversation_buffer_memory(config, chat_id):
+def get_conversation_buffer_memory(config: RagConfig, chat_id):
return ConversationBufferWindowMemory(
memory_key="chat_history",
- chat_memory=get_chat_message_history(chat_id),
+ chat_memory=get_chat_message_history(config, chat_id),
return_messages=True,
- k=config["chat_message_history_config"]["window_size"],
+ k=config.chat_history_window_size,
)
-
-def get_chat_message_history(chat_id):
+def get_chat_message_history(config: RagConfig, chat_id):
return SQLChatMessageHistory(
session_id=chat_id,
- connection_string=os.environ.get("DATABASE_CONNECTION_STRING"),
+ connection_string=config.database.database_url,
table_name=TABLE_NAME,
)
diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py
index da4ae23..16917cb 100644
--- a/backend/rag_components/document_loader.py
+++ b/backend/rag_components/document_loader.py
@@ -50,18 +50,3 @@ def get_loaders() -> List[str]:
if inspect.isclass(obj):
loaders.append(obj.__name__)
return loaders
-
-
-if __name__ == "__main__":
- from pathlib import Path
-
- from backend.config_renderer import get_config
- from frontend.lib.chat import Message
-
- config = get_config()
- data_to_store = Path(f"{Path(__file__).parent.parent.parent}/data/billionaires_csv.csv")
- prompt = "Quelles sont les 5 plus grandes fortunes de France ?"
- chat_id = "test"
- input_query = Message("user", prompt, chat_id)
- response = generate_response(data_to_store, config, input_query)
- print(response)
diff --git a/backend/rag_components/embedding.py b/backend/rag_components/embedding.py
index dab7715..1ce8b8e 100644
--- a/backend/rag_components/embedding.py
+++ b/backend/rag_components/embedding.py
@@ -1,10 +1,10 @@
from langchain import embeddings
+from backend.config import RagConfig
-def get_embedding_model(config):
- spec = getattr(embeddings, config["embedding_model_config"]["model_source"])
- all_config_field = {**config["embedding_model_config"], **config["embedding_provider_config"]}
+def get_embedding_model(config: RagConfig):
+ spec = getattr(embeddings, config.embedding_model.source)
kwargs = {
- key: value for key, value in all_config_field.items() if key in spec.__fields__.keys()
+ key: value for key, value in config.embedding_model.source_config.items() if key in spec.__fields__.keys()
}
return spec(**kwargs)
diff --git a/backend/rag_components/llm.py b/backend/rag_components/llm.py
index a4a5690..f767b11 100644
--- a/backend/rag_components/llm.py
+++ b/backend/rag_components/llm.py
@@ -1,10 +1,15 @@
+from typing import List
from langchain import chat_models
+from langchain.callbacks.base import BaseCallbackHandler
+from backend.config import RagConfig
-def get_llm_model(config):
- llm_spec = getattr(chat_models, config["llm_model_config"]["model_source"])
- all_config_field = {**config["llm_model_config"], **config["llm_provider_config"]}
+
+def get_llm_model(config: RagConfig, callbacks: List[BaseCallbackHandler] = []):
+ llm_spec = getattr(chat_models, config.llm.source)
kwargs = {
- key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys()
+ key: value for key, value in config.llm.source_config.items() if key in llm_spec.__fields__.keys()
}
+ kwargs["streaming"] = True
+ kwargs["callbacks"] = callbacks
return llm_spec(**kwargs)
diff --git a/backend/rag_components/logging_callback_handler.py b/backend/rag_components/logging_callback_handler.py
new file mode 100644
index 0000000..551469d
--- /dev/null
+++ b/backend/rag_components/logging_callback_handler.py
@@ -0,0 +1,72 @@
+import json
+from langchain.callbacks.base import BaseCallbackHandler
+
+class CustomJSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ try:
+ return super().default(obj)
+ except TypeError:
+ return str(obj)
+
+class LoggingCallbackHandler(BaseCallbackHandler):
+ def __init__(self, logger, context: dict = None):
+ self.logger = logger
+ self.context = context or {}
+
+ def _log_event(self, event_name, level, *args, **kwargs):
+ if self.logger is None:
+ return
+
+ log_data = {
+ "event": event_name,
+ "context": self.context,
+ "args": args,
+ **kwargs,
+ }
+ log_message = json.dumps(log_data, cls=CustomJSONEncoder)
+ if level == "info":
+ self.logger.info(log_message)
+ elif level == "error":
+ self.logger.error(log_message)
+
+ def on_llm_start(self, *args, **kwargs):
+ self._log_event("llm_start", "info", *args, **kwargs)
+
+ def on_llm_end(self, *args, **kwargs):
+ self._log_event("llm_end", "info", *args, **kwargs)
+
+ def on_retriever_error(self, *args, **kwargs):
+ self._log_event("retriever_error", "error", *args, **kwargs)
+
+ def on_retriever_end(self, *args, **kwargs):
+ self._log_event("retriever_end", "info", *args, **kwargs)
+
+ def on_llm_error(self, *args, **kwargs):
+ self._log_event("llm_error", "error", *args, **kwargs)
+
+ def on_chain_end(self, *args, **kwargs):
+ self._log_event("chain_end", "info", *args, **kwargs)
+
+ def on_chain_error(self, *args, **kwargs):
+ self._log_event("chain_error", "error", *args, **kwargs)
+
+ def on_agent_action(self, *args, **kwargs):
+ self._log_event("agent_action", "info", *args, **kwargs)
+
+ def on_agent_finish(self, *args, **kwargs):
+ self._log_event("agent_finish", "info", *args, **kwargs)
+
+ def on_tool_start(self, *args, **kwargs):
+ self._log_event("tool_start", "info", *args, **kwargs)
+
+ def on_tool_end(self, *args, **kwargs):
+ self._log_event("tool_end", "info", *args, **kwargs)
+
+ def on_tool_error(self, *args, **kwargs):
+ self._log_event("tool_error", "error", *args, **kwargs)
+
+ def on_text(self, *args, **kwargs):
+ self._log_event("text", "info", *args, **kwargs)
+
+ def on_retry(self, *args, **kwargs):
+ self._log_event("retry", "info", *args, **kwargs)
\ No newline at end of file
diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py
deleted file mode 100644
index db93076..0000000
--- a/backend/rag_components/main.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from pathlib import Path
-from typing import List
-
-from langchain.docstore.document import Document
-from langchain.vectorstores.utils import filter_complex_metadata
-
-from backend.config_renderer import get_config
-from backend.rag_components.document_loader import get_documents
-from backend.rag_components.embedding import get_embedding_model
-from backend.rag_components.llm import get_llm_model
-from backend.rag_components.vector_store import get_vector_store
-
-
-class RAG:
- def __init__(self):
- self.config = get_config()
- self.llm = get_llm_model(self.config)
- self.embeddings = get_embedding_model(self.config)
- self.vector_store = get_vector_store(self.embeddings, self.config)
-
- def generate_response():
- pass
-
- def load_documents(self, documents: List[Document]):
- # TODO améliorer la robustesse du load_document
- # TODO agent langchain qui fait le get_best_loader
- self.vector_store.add_documents(documents)
-
- def load_file(self, file_path: Path):
- documents = get_documents(file_path, self.llm)
- filtered_documents = filter_complex_metadata(documents)
- self.vector_store.add_documents(filtered_documents)
-
- # TODO pour chaque fichier -> stocker un hash en base
- # TODO avant de loader un fichier dans le vector store si le hash est dans notre db est append le doc dans le vector store que si le hash est inexistant
- # TODO éviter de dupliquer les embeddings
-
- def serve():
- pass
-
-
-if __name__ == "__main__":
- file_path = Path(__file__).parent.parent.parent / "data"
- rag = RAG()
-
- for file in file_path.iterdir():
- if file.is_file():
- rag.load_file(file)
diff --git a/backend/rag_components/prompts.py b/backend/rag_components/prompts.py
new file mode 100644
index 0000000..1d14cd4
--- /dev/null
+++ b/backend/rag_components/prompts.py
@@ -0,0 +1,32 @@
+condense_history = """
+Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.
+
+Chat history :
+{chat_history}
+Question : {question}
+
+Rephrased question :
+"""
+
+rag_system_prompt = """
+As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input.
+It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.
+"""
+
+respond_to_question = """
+As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input.
+It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.
+
+
+Respond to the question taking into account the following context.
+
+{context}
+
+Question: {question}
+"""
+
+document_context = """
+Content: {page_content}
+
+Source: {source}
+"""
\ No newline at end of file
diff --git a/backend/rag_components/rag.py b/backend/rag_components/rag.py
new file mode 100644
index 0000000..26d1809
--- /dev/null
+++ b/backend/rag_components/rag.py
@@ -0,0 +1,89 @@
+import asyncio
+from logging import Logger
+from pathlib import Path
+from typing import AsyncIterator, List, Union
+
+
+from langchain.docstore.document import Document
+from langchain.vectorstores.utils import filter_complex_metadata
+from langchain.callbacks import AsyncIteratorCallbackHandler
+from backend.rag_components.chain import get_answer_chain, async_get_response, stream_get_response
+from langchain.indexes import SQLRecordManager, index
+from langchain.chat_models.base import BaseChatModel
+from langchain.vectorstores import VectorStore
+from langchain.schema.embeddings import Embeddings
+
+
+from backend.config import RagConfig
+from backend.model import Message
+from backend.rag_components.chat_message_history import get_conversation_buffer_memory
+from backend.rag_components.document_loader import get_documents
+from backend.rag_components.embedding import get_embedding_model
+from backend.rag_components.llm import get_llm_model
+from backend.rag_components.logging_callback_handler import LoggingCallbackHandler
+from backend.rag_components.streaming_callback_handler import StreamingCallbackHandler
+from backend.rag_components.vector_store import get_vector_store
+
+
+class RAG:
+ def __init__(self, config: Union[Path, RagConfig], logger: Logger = None, context: dict = {}):
+ if isinstance(config, RagConfig):
+ self.config = config
+ else:
+ self.config = RagConfig.from_yaml(config)
+
+ self.logger = logger or Logger("RAG")
+ self.context = context
+
+ self.llm: BaseChatModel = get_llm_model(self.config)
+ self.embeddings: Embeddings = get_embedding_model(self.config)
+ self.vector_store: VectorStore = get_vector_store(self.embeddings, self.config)
+
+ def generate_response(self, message: Message) -> str:
+ loop = asyncio.get_event_loop()
+ response_stream = self.async_generate_response(message)
+ responses = loop.run_until_complete(self._collect_responses(response_stream))
+ return "".join([str(response) for response in responses])
+
+ def stream_generate_response(self, message: Message) -> AsyncIterator[str]:
+ memory = get_conversation_buffer_memory(self.config, message.chat_id)
+ streaming_callback_handler = StreamingCallbackHandler()
+ logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
+ answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
+ response_stream = stream_get_response(answer_chain, message.content, streaming_callback_handler)
+ return response_stream
+
+ def async_generate_response(self, message: Message) -> AsyncIterator[str]:
+ memory = get_conversation_buffer_memory(self.config, message.chat_id)
+ streaming_callback_handler = AsyncIteratorCallbackHandler()
+ logging_callback_handler = LoggingCallbackHandler(self.logger, context=self.context)
+ answer_chain = get_answer_chain(self.config, self.vector_store, memory, streaming_callback_handler=streaming_callback_handler, logging_callback_handler=logging_callback_handler)
+ response_stream = async_get_response(answer_chain, message.content, streaming_callback_handler)
+ return response_stream
+
+ def load_file(self, file_path: Path) -> List[Document]:
+ documents = get_documents(file_path, self.llm)
+ filtered_documents = filter_complex_metadata(documents)
+ return self.load_documents(filtered_documents)
+
+ def load_documents(self, documents: List[Document], insertion_mode: str = None):
+ insertion_mode = insertion_mode or self.config.vector_store.insertion_mode
+
+ record_manager = SQLRecordManager(
+ namespace="vector_store/my_docs", db_url=self.config.database.database_url
+ )
+ record_manager.create_schema()
+ indexing_output = index(
+ documents,
+ record_manager,
+ self.vector_store,
+ cleanup=insertion_mode,
+ source_id_key="source",
+ )
+ self.logger.info({"event": "load_documents", **indexing_output})
+
+ async def _collect_responses(self, response_stream):
+ responses = []
+ async for response in response_stream:
+ responses.append(response)
+ return responses
diff --git a/backend/rag_components/streaming_callback_handler.py b/backend/rag_components/streaming_callback_handler.py
new file mode 100644
index 0000000..69b615a
--- /dev/null
+++ b/backend/rag_components/streaming_callback_handler.py
@@ -0,0 +1,11 @@
+from multiprocessing import Queue
+from typing import AnyStr
+from langchain_core.callbacks.base import BaseCallbackHandler
+
+class StreamingCallbackHandler(BaseCallbackHandler):
+ queue = Queue()
+
+ def on_llm_new_token(self, token: str, **kwargs: AnyStr) -> None:
+ """Run on new LLM token. Only available when streaming is enabled."""
+ if token is not None and token != "":
+ self.queue.put_nowait(token)
diff --git a/backend/rag_components/vector_store.py b/backend/rag_components/vector_store.py
index 63db85a..a837c20 100644
--- a/backend/rag_components/vector_store.py
+++ b/backend/rag_components/vector_store.py
@@ -2,10 +2,11 @@
from langchain import vectorstores
+from backend.config import RagConfig
-def get_vector_store(embedding_model, config):
- vector_store_spec = getattr(vectorstores, config["vector_store_provider"]["model_source"])
- all_config_field = config["vector_store_provider"]
+
+def get_vector_store(embedding_model, config: RagConfig):
+ vector_store_spec = getattr(vectorstores, config.vector_store.source)
# the vector store class in langchain doesn't have a uniform interface to pass the embedding model
# we extract the propertiy of the class that matches the 'Embeddings' type
@@ -17,7 +18,7 @@ def get_vector_store(embedding_model, config):
(param for param in params_dict.values() if "Embeddings" in str(param.annotation)), None
)
- kwargs = {key: value for key, value in all_config_field.items() if key in parameters.keys()}
+ kwargs = {key: value for key, value in config.vector_store.source_config.items() if key in parameters.keys()}
kwargs[embedding_param.name] = embedding_model
vector_store = vector_store_spec(**kwargs)
return vector_store
diff --git a/backend/user_management.py b/backend/user_management.py
index c6f47cd..2cca3fd 100644
--- a/backend/user_management.py
+++ b/backend/user_management.py
@@ -5,51 +5,44 @@
from jose import jwt
from pydantic import BaseModel
-from database.database import Database
+from backend.database import Database
SECRET_KEY = os.environ.get("SECRET_KEY", "default_unsecure_key")
ALGORITHM = "HS256"
class User(BaseModel):
- """User model representing the user's email and password."""
-
email: str = None
password: str = None
def create_user(user: User) -> None:
- """Create a new user in the database."""
with Database() as connection:
- connection.query(
+ connection.execute(
"INSERT INTO user (email, password) VALUES (?, ?)", (user.email, user.password)
)
def user_exists(email: str) -> bool:
- """Check if a user exists in the database by email."""
with Database() as connection:
- result = connection.query("SELECT 1 FROM user WHERE email = ?", (email,))[0]
+ result = connection.fetchone("SELECT 1 FROM user WHERE email = ?", (email,))
return bool(result)
def get_user(email: str) -> Optional[User]:
- """Retrieve a user from the database by email."""
with Database() as connection:
- user_row = connection.query("SELECT * FROM user WHERE email = ?", (email,))[0]
- for row in user_row:
- return User(**row)
- raise Exception("User not found")
+ user_row = connection.fetchone("SELECT * FROM user WHERE email = ?", (email,))
+ if user_row:
+ return User(email=user_row[0], password=user_row[1])
+ return None
def delete_user(email: str) -> None:
- """Delete a user from the database by email."""
with Database() as connection:
- connection.query("DELETE FROM user WHERE email = ?", (email,))
+ connection.execute("DELETE FROM user WHERE email = ?", (email,))
def authenticate_user(username: str, password: str) -> Optional[User]:
- """Authenticate a user by their username and password."""
user = get_user(username)
if not user or not password == user.password:
return False
@@ -57,7 +50,6 @@ def authenticate_user(username: str, password: str) -> Optional[User]:
def create_access_token(*, data: dict, expires_delta: Optional[timedelta] = None) -> str:
- """Create a JWT access token for a user."""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
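
Taken together, these helpers support a create-then-authenticate flow along the following lines; the `sub` claim name and the 60-minute lifetime are example choices, not values fixed by this module:

```python
from datetime import timedelta

from backend.user_management import User, authenticate_user, create_access_token, create_user

create_user(User(email="alice@example.com", password="not-a-real-password"))

user = authenticate_user("alice@example.com", "not-a-real-password")
if user:
    token = create_access_token(
        data={"sub": user.email}, expires_delta=timedelta(minutes=60)
    )
```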
diff --git a/bin/install_with_conda.sh b/bin/install_with_conda.sh
deleted file mode 100644
index af3fd9c..0000000
--- a/bin/install_with_conda.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash -e
-
-read -p "Want to install conda env named 'skaff-rag-accelerator'? (y/n)" answer
-if [ "$answer" = "y" ]; then
- echo "Installing conda env..."
- conda create -n skaff-rag-accelerator python=3.11 -y
- source $(conda info --base)/etc/profile.d/conda.sh
- conda activate skaff-rag-accelerator
- echo "Installing requirements..."
- pip install -r requirements.txt
- python3 -m ipykernel install --user --name=skaff-rag-accelerator
- conda install -c conda-forge --name skaff-rag-accelerator notebook -y
- echo "Installing pre-commit..."
- make install_precommit
- echo "Installation complete!";
-else
- echo "Installation of conda env aborted!";
-fi
diff --git a/bin/install_with_venv.sh b/bin/install_with_venv.sh
deleted file mode 100644
index f610f2a..0000000
--- a/bin/install_with_venv.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash -e
-
-read -p "Want to install virtual env named 'venv' in this project ? (y/n)" answer
-if [ "$answer" = "y" ]; then
- echo "Installing virtual env..."
- declare VENV_DIR=$(pwd)/venv
- if ! [ -d "$VENV_DIR" ]; then
- python3 -m venv $VENV_DIR
- fi
-
- source $VENV_DIR/bin/activate
- echo "Installing requirements..."
- pip install -r requirements.txt
- python3 -m ipykernel install --user --name=venv
- echo "Installing pre-commit..."
- make install_precommit
- echo "Installation complete!";
-else
- echo "Installation of virtual env aborted!";
-fi
diff --git a/database/database.py b/database/database.py
deleted file mode 100644
index b618196..0000000
--- a/database/database.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import os
-import sqlite3
-from pathlib import Path
-from typing import List, Optional, Tuple
-
-
-class Database:
- """A simple wrapper class for SQLite database operations."""
-
- def __init__(self):
- """Initialize the Database object by setting the database path based on the environment."""
- db_name = (
- "test.sqlite" if os.getenv("TESTING", "false").lower() == "true" else "database.sqlite"
- )
- self.db_path = Path(__file__).parent / db_name
-
- def __enter__(self) -> "Database":
- """Enter the runtime context related to the database connection."""
- self.conn = sqlite3.connect(self.db_path)
- self.conn.row_factory = sqlite3.Row
- return self
-
- def __exit__(
- self, exc_type: Optional[type], exc_val: Optional[type], exc_tb: Optional[type]
- ) -> None:
- """Exit the runtime context and close the database connection properly."""
- if exc_type is not None:
- self.conn.rollback()
- self.conn.commit()
- self.conn.close()
-
- def query(self, query: str, params: Optional[Tuple] = None) -> List[List[sqlite3.Row]]:
- """Execute a query against the database."""
- cursor = self.conn.cursor()
- results = []
- commands = filter(None, query.split(";"))
- for command in commands:
- cursor.execute(command, params or ())
- results.append(cursor.fetchall())
- return results
-
- def query_from_file(self, file_path: Path) -> None:
- """Execute a query from a SQL file."""
- with Path.open(file_path, "r") as file:
- query = file.read()
- self.query(query)
-
- def delete_db(self) -> None:
- """Delete the database file from the filesystem."""
- if self.conn:
- self.conn.close()
- if self.db_path.exists():
- self.db_path.unlink(missing_ok=True)
-
-
-with Database() as connection:
- connection.query_from_file(Path(__file__).parent / "database_init.sql")
diff --git a/docs/3t_architecture.png b/docs/3t_architecture.png
new file mode 100644
index 0000000..60bbdd1
Binary files /dev/null and b/docs/3t_architecture.png differ
diff --git a/docs/config_architecture.png b/docs/config_architecture.png
new file mode 100644
index 0000000..5c5c150
Binary files /dev/null and b/docs/config_architecture.png differ
diff --git a/docs/login_and_chat.gif b/docs/login_and_chat.gif
new file mode 100644
index 0000000..b5e42f4
Binary files /dev/null and b/docs/login_and_chat.gif differ
diff --git a/frontend/app.py b/frontend/app.py
index bad3bda..6a41ba1 100644
--- a/frontend/app.py
+++ b/frontend/app.py
@@ -7,10 +7,9 @@
from frontend.lib.auth import auth
from frontend.lib.chat import chat
+from frontend.lib.sidebar import sidebar
load_dotenv()
-embedding_api_base = os.getenv("EMBEDDING_OPENAI_API_BASE")
-embedding_api_key = os.getenv("EMBEDDING_API_KEY")
FASTAPI_URL = os.getenv("FASTAPI_URL", "localhost:8000")
assets = Path(__file__).parent / "assets"
@@ -36,4 +35,5 @@
Le backend du ChatBot est APéïsé ce qui permet une meilleure scalabilité et robustesse."
)
+ sidebar()
chat()
diff --git a/frontend/lib/auth.py b/frontend/lib/auth.py
index 8880bcc..f5dc2bb 100644
--- a/frontend/lib/auth.py
+++ b/frontend/lib/auth.py
@@ -40,8 +40,9 @@ def login_form() -> tuple[bool, Optional[str]]:
session = create_session()
session = authenticate_session(session, token)
else:
- st.error("Wrong authent")
+ st.error("Failed authentication")
st.session_state["session"] = session
+ st.session_state["email"] = username
return session
@@ -59,8 +60,9 @@ def signup_form() -> tuple[bool, Optional[str]]:
session = create_session()
auth_session = authenticate_session(session, token)
else:
- st.error("Failed to signing up")
+ st.error("Failed signing up")
st.session_state["session"] = auth_session
+ st.session_state["email"] = username
return auth_session
diff --git a/frontend/lib/chat.py b/frontend/lib/chat.py
index 3d3d1d6..486151c 100644
--- a/frontend/lib/chat.py
+++ b/frontend/lib/chat.py
@@ -22,20 +22,35 @@ def __post_init__(self):
def chat():
prompt = st.chat_input("Say something")
- if prompt:
- if len(st.session_state.get("messages", [])) == 0:
- chat_id = new_chat()
- else:
- chat_id = st.session_state.get("chat_id")
-
- st.session_state.get("messages").append(Message("user", prompt, chat_id))
- response = send_prompt(st.session_state.get("messages")[-1])
- st.session_state.get("messages").append(Message(**response))
-
with st.container(border=True):
for message in st.session_state.get("messages", []):
with st.chat_message(message.sender):
st.write(message.content)
+
+ if prompt:
+ if len(st.session_state.get("messages", [])) == 0:
+ chat_id = new_chat()
+ else:
+ chat_id = st.session_state.get("chat_id")
+
+ with st.chat_message("user"):
+ st.write(prompt)
+
+ user_message = Message("user", prompt, chat_id)
+ st.session_state["messages"].append(user_message)
+
+ response = send_prompt(user_message)
+ with st.chat_message("assistant"):
+ placeholder = st.empty()
+ full_response = ''
+ for item in response:
+ full_response += item
+ placeholder.write(full_response)
+ placeholder.write(full_response)
+
+ bot_message = Message("assistant", full_response, chat_id)
+ st.session_state["messages"].append(bot_message)
+
if (
len(st.session_state.get("messages", [])) > 0
and len(st.session_state.get("messages")) % 2 == 0
@@ -43,9 +58,7 @@ def chat():
streamlit_feedback(
key=str(len(st.session_state.get("messages"))),
feedback_type="thumbs",
- on_submit=lambda feedback: send_feedback(
- st.session_state.get("messages")[-1].id, feedback
- ),
+ on_submit=lambda feedback: send_feedback(st.session_state.get("messages")[-1].id, feedback),
)
@@ -59,10 +72,10 @@ def new_chat():
def send_prompt(message: Message):
session = st.session_state.get("session")
- response = session.post(f"/chat/{message.chat_id}/user_message", json=asdict(message))
- print(response.headers)
- print(response.text)
- return response.json()["message"]
+ response = session.post(f"/chat/{message.chat_id}/user_message", stream=True, json=asdict(message))
+
+ for line in response.iter_content(chunk_size=16, decode_unicode=True):
+ yield line
def send_feedback(message_id: str, feedback: str):
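
Since `send_prompt` now streams, the same backend endpoint can also be consumed outside Streamlit. A rough sketch with plain `requests`, assuming the default `FASTAPI_URL` and a payload shaped like the `Message` dataclass; the chat id and token placeholders come from the `/chat/new` and `/user/login` endpoints:

```python
import requests

chat_id = "<chat_id>"     # from POST /chat/new
token = "<access_token>"  # from POST /user/login

with requests.post(
    f"http://localhost:8000/chat/{chat_id}/user_message",
    json={"sender": "user", "content": "Hello", "chat_id": chat_id},
    headers={"Authorization": f"Bearer {token}"},
    stream=True,
) as response:
    for chunk in response.iter_content(chunk_size=16, decode_unicode=True):
        print(chunk, end="", flush=True)
```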
diff --git a/frontend/lib/sidebar.py b/frontend/lib/sidebar.py
new file mode 100644
index 0000000..663aa34
--- /dev/null
+++ b/frontend/lib/sidebar.py
@@ -0,0 +1,43 @@
+import streamlit as st
+from datetime import datetime
+import humanize
+
+from frontend.lib.chat import Message, new_chat
+
+def sidebar():
+ with st.sidebar:
+ st.sidebar.title("RAG Indus Kit", anchor="top")
+ st.sidebar.markdown(f"Logged in as {st.session_state['email']}
", unsafe_allow_html=True)
+
+ if st.sidebar.button("New Chat", use_container_width=True, key="new_chat_button"):
+ st.session_state["messages"] = []
+
+ with st.empty():
+ chat_list = list_chats()
+ chats_by_time_ago = {}
+ for chat in chat_list:
+ chat_id, timestamp = chat["id"], chat["timestamp"]
+ time_ago = humanize.naturaltime(datetime.now() - datetime.fromisoformat(timestamp))
+ if time_ago not in chats_by_time_ago:
+ chats_by_time_ago[time_ago] = []
+ chats_by_time_ago[time_ago].append(chat)
+
+ for time_ago, chats in chats_by_time_ago.items():
+ st.sidebar.markdown(time_ago)
+ for chat in chats:
+ chat_id = chat["id"]
+ if st.sidebar.button(chat_id, key=chat_id, use_container_width=True):
+ st.session_state["chat_id"] = chat_id
+ messages = [Message(**message) for message in get_chat(chat_id)["messages"]]
+ st.session_state["messages"] = messages
+
+
+def list_chats():
+ session = st.session_state.get("session")
+ response = session.get("/chat/list")
+ return response.json()
+
+def get_chat(chat_id: str):
+ session = st.session_state.get("session")
+ response = session.get(f"/chat/{chat_id}")
+ return response.json()
diff --git a/notebooks/docs_loader.ipynb b/notebooks/docs_loader.ipynb
deleted file mode 100644
index ffe377f..0000000
--- a/notebooks/docs_loader.ipynb
+++ /dev/null
@@ -1,119 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import sys\n",
- "\n",
- "current_directory = os.getcwd()\n",
- "parent_directory = os.path.dirname(current_directory)\n",
- "sys.path.append(parent_directory)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from langchain.document_loaders import PyPDFLoader\n",
- "from langchain.text_splitter import RecursiveCharacterTextSplitter\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "loader = PyPDFLoader(f\"{parent_directory}/data/Cheat sheet entretien.pdf\", extract_images=True)\n",
- "documents = loader.load()\n",
- "text_splitter = RecursiveCharacterTextSplitter(\n",
- " separators=[\"\\n\\n\", \"\\n\", \".\"], chunk_size=1500, chunk_overlap=100\n",
- ")\n",
- "texts = text_splitter.split_documents(documents)\n",
- "\n",
- "with open('../data/local_documents.json', 'w') as f:\n",
- " for doc in texts:\n",
- " f.write(doc.json() + '\\n')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import json\n",
- "from langchain.docstore.document import Document\n",
- "\n",
- "with open('../data/local_documents.json', 'r') as f:\n",
- " json_data = [json.loads(line) for line in f]\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "documents = []\n",
- "\n",
- "for r in json_data:\n",
- " document = Document(page_content=r[\"page_content\"], metadata=r[\"metadata\"])\n",
- " documents.append(document)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "documents"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "skaff-rag-accelerator",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/notebooks/memory_tests.ipynb b/notebooks/memory_tests.ipynb
deleted file mode 100644
index 520cf4d..0000000
--- a/notebooks/memory_tests.ipynb
+++ /dev/null
@@ -1,252 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import sys\n",
- "\n",
- "current_directory = os.getcwd()\n",
- "parent_directory = os.path.dirname(current_directory)\n",
- "sys.path.append(parent_directory)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from langchain.chains import ConversationalRetrievalChain, LLMChain\n",
- "from langchain.chains.combine_documents.stuff import StuffDocumentsChain\n",
- "from langchain.chat_models.base import SystemMessage, HumanMessage\n",
- "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n",
- "from langchain.vectorstores import VectorStore\n",
- "from langchain.memory.chat_message_histories import SQLChatMessageHistory\n",
- "\n",
- "from backend.config_renderer import get_config\n",
- "from backend.rag_components.embedding import get_embedding_model\n",
- "from backend.rag_components.llm import get_llm_model\n",
- "from backend.rag_components.vector_store import get_vector_store\n",
- "import frontend.lib.auth as auth"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def get_answer_chain(llm, docsearch: VectorStore, memory) -> ConversationalRetrievalChain:\n",
- " \"\"\"Returns an instance of ConversationalRetrievalChain based on the provided parameters.\"\"\"\n",
- " template = \"\"\"Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns.\n",
- "\n",
- "Chat history :\n",
- "{chat_history}\n",
- "Question : {question}\n",
- "\n",
- "Rephrased question :\n",
- "\"\"\"\n",
- " condense_question_prompt = PromptTemplate.from_template(template)\n",
- " condense_question_chain = LLMChain(\n",
- " llm=llm,\n",
- " prompt=condense_question_prompt,\n",
- " )\n",
- "\n",
- " messages = [\n",
- " SystemMessage(\n",
- " content=(\n",
- " \"\"\"As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail.\"\"\"\n",
- " )\n",
- " ),\n",
- " HumanMessage(content=\"Respond to the question taking into account the following context.\"),\n",
- " HumanMessagePromptTemplate.from_template(\"{context}\"),\n",
- " HumanMessagePromptTemplate.from_template(\"Question: {question}\"),\n",
- " ]\n",
- " system_prompt = ChatPromptTemplate(messages=messages)\n",
- " qa_chain = LLMChain(\n",
- " llm=llm,\n",
- " prompt=system_prompt,\n",
- " )\n",
- "\n",
- " doc_prompt = PromptTemplate(\n",
- " template=\"Content: {page_content}\\nSource: {source}\",\n",
- " input_variables=[\"page_content\", \"source\"],\n",
- " )\n",
- "\n",
- " final_qa_chain = StuffDocumentsChain(\n",
- " llm_chain=qa_chain,\n",
- " document_variable_name=\"context\",\n",
- " document_prompt=doc_prompt,\n",
- " )\n",
- "\n",
- " return ConversationalRetrievalChain(\n",
- " question_generator=condense_question_chain,\n",
- " retriever=docsearch.as_retriever(search_kwargs={\"k\": 10}),\n",
- " memory=memory,\n",
- " combine_docs_chain=final_qa_chain,\n",
- " verbose=True,\n",
- " )\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "config = get_config()\n",
- "llm = get_llm_model(config)\n",
- "embeddings = get_embedding_model(config)\n",
- "vector_store = get_vector_store(embeddings, config)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "username = \"slauzeral\"\n",
- "password = \"test\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "success = auth.sign_up(username, password)\n",
- "if success:\n",
- " token = auth.get_token(username, password)\n",
- " session = auth.create_session()\n",
- " auth_session = auth.authenticate_session(session, token)\n",
- "\n",
- "response = auth_session.post(\"/chat/new\")\n",
- "chat_id = response.json()[\"chat_id\"]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from datetime import datetime\n",
- "from uuid import uuid4\n",
- "from typing import Any\n",
- "\n",
- "from langchain.memory.chat_message_histories.sql import DefaultMessageConverter\n",
- "from langchain.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage\n",
- "from sqlalchemy import Column, DateTime, Integer, Text\n",
- "from sqlalchemy.orm import declarative_base\n",
- "from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict\n",
- "import json\n",
- "\n",
- "Base = declarative_base()\n",
- "\n",
- "class CustomMessage(Base):\n",
- " __tablename__ = \"message_test\"\n",
- "\n",
- " id = Column(Text, primary_key=True, default=lambda: str(uuid4())) # default=lambda: str(uuid4())\n",
- " timestamp = Column(DateTime)\n",
- " chat_id = Column(Text)\n",
- " sender = Column(Text)\n",
- " content = Column(Text)\n",
- " message = Column(Text)\n",
- "\n",
- "\n",
- "class CustomMessageConverter(DefaultMessageConverter):\n",
- "\n",
- " def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:\n",
- " sub_message = json.loads(message.content)\n",
- " return CustomMessage(\n",
- " id = sub_message[\"id\"],\n",
- " timestamp = datetime.strptime(sub_message[\"timestamp\"], \"%Y-%m-%d %H:%M:%S.%f\"),\n",
- " chat_id = session_id,\n",
- " sender = message.type,\n",
- " content = sub_message[\"content\"],\n",
- " message = json.dumps(_message_to_dict(message)),\n",
- " )\n",
- "\n",
- " def get_sql_model_class(self) -> Any:\n",
- " return CustomMessage\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "chat_message_history = SQLChatMessageHistory(\n",
- " session_id=chat_id,\n",
- " connection_string=\"sqlite:////Users/sarah.lauzeral/Library/CloudStorage/GoogleDrive-sarah.lauzeral@artefact.com/Mon Drive/internal_projects/skaff-rag-accelerator/database/database.sqlite\",\n",
- " table_name=\"message_test\",\n",
- " session_id_field_name=\"chat_id\",\n",
- " custom_message_converter=CustomMessageConverter(table_name=\"message_test\"),\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "chat_message_history.add_ai_message(json.dumps({\"content\":\"Hi\", \"timestamp\":f\"{datetime.utcnow()}\", \"id\":\"764528762\"}))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "chat_message_history.messages"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "skaff-rag-accelerator",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/notebooks/test_auth.ipynb b/notebooks/test_auth.ipynb
deleted file mode 100644
index 4980182..0000000
--- a/notebooks/test_auth.ipynb
+++ /dev/null
@@ -1,125 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import sys\n",
- "\n",
- "current_directory = os.getcwd()\n",
- "parent_directory = os.path.dirname(current_directory)\n",
- "sys.path.append(parent_directory)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from typing import NoReturn\n",
- "from fastapi.testclient import TestClient\n",
- "from lib.main import app\n",
- "\n",
- "import streamlit as st\n",
- "import requests\n",
- "\n",
- "client = TestClient(app)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def log_in(username: str, password: str) -> Optional[str]:\n",
- " response = client.post(\n",
- " \"/user/login\", data={\"username\": username, \"password\": password}\n",
- " )\n",
- " if response.status_code == 200 and \"access_token\" in response.json():\n",
- " return response.json()[\"access_token\"]\n",
- " else:\n",
- " return None\n",
- "\n",
- "def sign_up(username: str, password: str) -> str:\n",
- " response = client.post(\n",
- " \"/user/signup\", json={\"username\": username, \"password\": password}\n",
- " )\n",
- " if response.status_code == 200 and \"email\" in response.json():\n",
- " return f\"User {username} registered successfully.\"\n",
- " else:\n",
- " return \"Registration failed.\"\n",
- "\n",
- "def reset_pwd(username: str) -> str:\n",
- " # Assuming there's an endpoint to request a password reset\n",
- " response = client.post(\n",
- " \"/user/reset-password\", json={\"username\": username}\n",
- " )\n",
- " if response.status_code == 200:\n",
- " return \"Password reset link sent.\"\n",
- " else:\n",
- " return \"Failed to send password reset link.\"\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sign_up(\"sarah.lauzeral@artefact.com\", \"test_pwd\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sign_up(\"test@example.com\", \"test_pwd\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "skaff-rag-accelerator",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/requirements.in b/requirements.in
index 9ac0e15..d8f28c7 100644
--- a/requirements.in
+++ b/requirements.in
@@ -33,3 +33,5 @@ pytest
python-jose
trubrics[streamlit]
uvicorn
+apify-client
+sentence_transformers
diff --git a/requirements.txt b/requirements.txt
index 60c3dbf..2d53979 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,6 +32,10 @@ anyio==3.7.1
# langchain
# starlette
# watchfiles
+apify-client==1.6.1
+ # via -r requirements.in
+apify-shared==1.1.0
+ # via apify-client
async-timeout==4.0.3
# via
# aiohttp
@@ -140,6 +144,8 @@ fastjsonschema==2.19.0
filelock==3.13.1
# via
# huggingface-hub
+ # torch
+ # transformers
# virtualenv
filetype==1.2.0
# via unstructured
@@ -155,6 +161,7 @@ fsspec==2023.12.1
# gcsfs
# huggingface-hub
# s3fs
+ # torch
# universal-pathlib
gcsfs==2023.12.1
# via -r requirements.in
@@ -204,9 +211,14 @@ httpcore==1.0.2
httptools==0.6.1
# via uvicorn
httpx==0.25.2
- # via -r requirements.in
+ # via
+ # -r requirements.in
+ # apify-client
huggingface-hub==0.19.4
- # via tokenizers
+ # via
+ # sentence-transformers
+ # tokenizers
+ # transformers
humanfriendly==10.0
# via coloredlogs
identify==2.5.33
@@ -231,10 +243,13 @@ jinja2==3.1.2
# via
# altair
# pydeck
+ # torch
jmespath==1.0.1
# via botocore
joblib==1.3.2
- # via nltk
+ # via
+ # nltk
+ # scikit-learn
jsonpatch==1.33
# via langchain
jsonpointer==2.4
@@ -293,9 +308,13 @@ nbformat==5.9.2
nbstripout==0.6.1
# via -r requirements.in
networkx==3.2.1
- # via -r requirements.in
+ # via
+ # -r requirements.in
+ # torch
nltk==3.8.1
- # via unstructured
+ # via
+ # sentence-transformers
+ # unstructured
nodeenv==1.8.0
# via pre-commit
numpy==1.26.2
@@ -309,7 +328,12 @@ numpy==1.26.2
# pandas
# pyarrow
# pydeck
+ # scikit-learn
+ # scipy
+ # sentence-transformers
# streamlit
+ # torchvision
+ # transformers
# unstructured
oauthlib==3.2.2
# via requests-oauthlib
@@ -331,6 +355,7 @@ packaging==23.2
# onnxruntime
# pytest
# streamlit
+ # transformers
pandas==2.1.4
# via
# -r requirements.in
@@ -344,6 +369,7 @@ pillow==10.1.0
# via
# python-pptx
# streamlit
+ # torchvision
platformdirs==4.1.0
# via
# black
@@ -423,6 +449,7 @@ pyyaml==6.0.1
# huggingface-hub
# langchain
# pre-commit
+ # transformers
# uvicorn
rapidfuzz==3.5.2
# via unstructured
@@ -434,6 +461,7 @@ regex==2023.10.3
# via
# nltk
# tiktoken
+ # transformers
requests==2.31.0
# via
# -r requirements.in
@@ -454,6 +482,8 @@ requests==2.31.0
# requests-oauthlib
# streamlit
# tiktoken
+ # torchvision
+ # transformers
# trubrics
# unstructured
requests-oauthlib==1.3.1
@@ -474,6 +504,18 @@ ruff==0.1.7
# via -r requirements.in
s3fs==2023.12.1
# via -r requirements.in
+safetensors==0.4.1
+ # via transformers
+scikit-learn==1.3.2
+ # via sentence-transformers
+scipy==1.11.4
+ # via
+ # scikit-learn
+ # sentence-transformers
+sentence-transformers==2.2.2
+ # via -r requirements.in
+sentencepiece==0.1.99
+ # via sentence-transformers
six==1.16.0
# via
# azure-core
@@ -505,17 +547,23 @@ streamlit==1.29.0
streamlit-feedback==0.1.2
# via trubrics
sympy==1.12
- # via onnxruntime
+ # via
+ # onnxruntime
+ # torch
tabulate==0.9.0
# via unstructured
tenacity==8.2.3
# via
# langchain
# streamlit
+threadpoolctl==3.2.0
+ # via scikit-learn
tiktoken==0.5.2
# via -r requirements.in
tokenizers==0.15.0
- # via chromadb
+ # via
+ # chromadb
+ # transformers
toml==0.10.2
# via streamlit
tomli==2.0.1
@@ -524,6 +572,12 @@ tomli==2.0.1
# pytest
toolz==0.12.0
# via altair
+torch==2.1.2
+ # via
+ # sentence-transformers
+ # torchvision
+torchvision==0.16.2
+ # via sentence-transformers
tornado==6.4
# via streamlit
tqdm==4.66.1
@@ -533,10 +587,14 @@ tqdm==4.66.1
# huggingface-hub
# nltk
# openai
+ # sentence-transformers
+ # transformers
traitlets==5.14.0
# via
# jupyter-core
# nbformat
+transformers==4.36.2
+ # via sentence-transformers
trubrics[streamlit]==1.6.2
# via -r requirements.in
typer==0.9.0
@@ -558,6 +616,7 @@ typing-extensions==4.9.0
# pydantic-core
# sqlalchemy
# streamlit
+ # torch
# typer
# typing-inspect
# unstructured
diff --git a/tests/test_feedback.py b/tests/test_feedback.py
index 9993026..12ae887 100644
--- a/tests/test_feedback.py
+++ b/tests/test_feedback.py
@@ -6,7 +6,7 @@
from fastapi.testclient import TestClient
from lib.main import app
-from database.database import Database
+from backend.database import Database
os.environ["TESTING"] = "True"
client = TestClient(app)
diff --git a/tests/test_users.py b/tests/test_users.py
index f39e802..420a771 100644
--- a/tests/test_users.py
+++ b/tests/test_users.py
@@ -5,7 +5,7 @@
from fastapi.testclient import TestClient
from lib.main import app
-from database.database import Database
+from backend.database import Database
os.environ["TESTING"] = "True"
client = TestClient(app)