diff --git a/backend/config.yaml b/backend/config.yaml index 966e906..ab3d4b2 100644 --- a/backend/config.yaml +++ b/backend/config.yaml @@ -23,6 +23,12 @@ vector_store_provider: model_source: Chroma persist_directory: vector_database/ documents_to_retreive: 10 + collection_metadata: + hnsw:space: cosine + search_type: similarity + search_options: + score_threshold: 0.5 + top_k: 20 chat_message_history_config: source: ChatMessageHistory diff --git a/backend/main.py b/backend/main.py index 67a95ed..5b4659c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -33,7 +33,6 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: - """Get the current user by decoding the JWT token.""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -55,7 +54,6 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User: @app.post("/user/signup") async def signup(user: User) -> dict: - """Sign up a new user.""" if user_exists(user.email): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered" @@ -67,7 +65,6 @@ async def signup(user: User) -> dict: @app.delete("/user/") async def delete_user(current_user: User = Depends(get_current_user)) -> dict: - """Delete an existing user.""" email = current_user.email try: user = get_user(email) @@ -85,7 +82,6 @@ async def delete_user(current_user: User = Depends(get_current_user)) -> dict: @app.post("/user/login") async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict: - """Log in a user and return an access token.""" user = authenticate_user(form_data.username, form_data.password) if not user: raise HTTPException( diff --git a/backend/rag_components/chain.py b/backend/rag_components/chain.py index f5e7f91..f1299ee 100644 --- a/backend/rag_components/chain.py +++ b/backend/rag_components/chain.py @@ -1,6 +1,7 @@ import asyncio from langchain.chains import ConversationalRetrievalChain, LLMChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain +from langchain.chains.combine_documents.reduce import ReduceDocumentsChain from langchain.chat_models.base import SystemMessage from langchain.prompts import ( ChatPromptTemplate, @@ -36,17 +37,21 @@ def get_answer_chain(config, docsearch: VectorStore, memory) -> ConversationalRe context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"]) - final_qa_chain = StuffDocumentsChain( + stuffed_qa_chain = StuffDocumentsChain( llm_chain=question_answering_chain, document_variable_name="context", document_prompt=context_with_docs_prompt, ) + reduced_qa_chain = ReduceDocumentsChain( + combine_documents_chain=stuffed_qa_chain, + ) + chain = ConversationalRetrievalChain( question_generator=condense_question_chain, - retriever=docsearch.as_retriever(search_kwargs={"k": config["vector_store_provider"]["documents_to_retreive"]}), + retriever=docsearch.as_retriever(search_type=config["vector_store_provider"]["search_type"], search_kwargs=config["vector_store_provider"]["search_options"]), memory=memory, - combine_docs_chain=final_qa_chain, + combine_docs_chain=reduced_qa_chain, ) return chain, callback_handler diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py index 9a6736b..2a9e9ca 100644 --- a/backend/rag_components/main.py +++ b/backend/rag_components/main.py @@ -2,6 +2,7 @@ from typing import List from langchain.docstore.document import Document +from langchain.vectorstores import VectorStore from langchain.vectorstores.utils import filter_complex_metadata from backend.rag_components.chain import get_answer_chain, get_response_stream @@ -22,13 +23,11 @@ def __init__(self, config_file_path: Path = None): self.config = get_config(config_file_path) self.llm = get_llm_model(self.config) self.embeddings = get_embedding_model(self.config) - self.vector_store = get_vector_store(self.embeddings, self.config) + self.vector_store: VectorStore = get_vector_store(self.embeddings, self.config) def generate_response(self, message: Message): - embeddings = get_embedding_model(self.config) - vector_store = get_vector_store(embeddings, self.config) memory = get_conversation_buffer_memory(self.config, message.chat_id) - answer_chain, callback_handler = get_answer_chain(self.config, vector_store, memory) + answer_chain, callback_handler = get_answer_chain(self.config, self.vector_store, memory) response_stream = get_response_stream(answer_chain, callback_handler, message.content) return response_stream @@ -53,7 +52,6 @@ def serve(): if __name__ == "__main__": file_path = Path(__file__).parent.parent.parent / "data" rag = RAG() - for file in file_path.iterdir(): if file.is_file(): rag.load_file(file) diff --git a/bin/install_with_conda.sh b/bin/install_with_conda.sh deleted file mode 100644 index af3fd9c..0000000 --- a/bin/install_with_conda.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -e - -read -p "Want to install conda env named 'skaff-rag-accelerator'? (y/n)" answer -if [ "$answer" = "y" ]; then - echo "Installing conda env..." - conda create -n skaff-rag-accelerator python=3.11 -y - source $(conda info --base)/etc/profile.d/conda.sh - conda activate skaff-rag-accelerator - echo "Installing requirements..." - pip install -r requirements.txt - python3 -m ipykernel install --user --name=skaff-rag-accelerator - conda install -c conda-forge --name skaff-rag-accelerator notebook -y - echo "Installing pre-commit..." - make install_precommit - echo "Installation complete!"; -else - echo "Installation of conda env aborted!"; -fi diff --git a/bin/install_with_venv.sh b/bin/install_with_venv.sh deleted file mode 100644 index f610f2a..0000000 --- a/bin/install_with_venv.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -e - -read -p "Want to install virtual env named 'venv' in this project ? (y/n)" answer -if [ "$answer" = "y" ]; then - echo "Installing virtual env..." - declare VENV_DIR=$(pwd)/venv - if ! [ -d "$VENV_DIR" ]; then - python3 -m venv $VENV_DIR - fi - - source $VENV_DIR/bin/activate - echo "Installing requirements..." - pip install -r requirements.txt - python3 -m ipykernel install --user --name=venv - echo "Installing pre-commit..." - make install_precommit - echo "Installation complete!"; -else - echo "Installation of virtual env aborted!"; -fi diff --git a/database/database.py b/database/database.py index b618196..715db0f 100644 --- a/database/database.py +++ b/database/database.py @@ -5,17 +5,13 @@ class Database: - """A simple wrapper class for SQLite database operations.""" - def __init__(self): - """Initialize the Database object by setting the database path based on the environment.""" db_name = ( "test.sqlite" if os.getenv("TESTING", "false").lower() == "true" else "database.sqlite" ) self.db_path = Path(__file__).parent / db_name def __enter__(self) -> "Database": - """Enter the runtime context related to the database connection.""" self.conn = sqlite3.connect(self.db_path) self.conn.row_factory = sqlite3.Row return self @@ -23,14 +19,12 @@ def __enter__(self) -> "Database": def __exit__( self, exc_type: Optional[type], exc_val: Optional[type], exc_tb: Optional[type] ) -> None: - """Exit the runtime context and close the database connection properly.""" if exc_type is not None: self.conn.rollback() self.conn.commit() self.conn.close() def query(self, query: str, params: Optional[Tuple] = None) -> List[List[sqlite3.Row]]: - """Execute a query against the database.""" cursor = self.conn.cursor() results = [] commands = filter(None, query.split(";")) @@ -40,13 +34,11 @@ def query(self, query: str, params: Optional[Tuple] = None) -> List[List[sqlite3 return results def query_from_file(self, file_path: Path) -> None: - """Execute a query from a SQL file.""" with Path.open(file_path, "r") as file: query = file.read() self.query(query) def delete_db(self) -> None: - """Delete the database file from the filesystem.""" if self.conn: self.conn.close() if self.db_path.exists():