Skip to content

Commit

Permalink
upd: minor refactos
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisVLRT committed Dec 24, 2023
1 parent e5cda39 commit 95d3a79
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 58 deletions.
6 changes: 6 additions & 0 deletions backend/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ vector_store_provider:
model_source: Chroma
persist_directory: vector_database/
documents_to_retreive: 10
collection_metadata:
hnsw:space: cosine
search_type: similarity
search_options:
score_threshold: 0.5
top_k: 20

chat_message_history_config:
source: ChatMessageHistory
Expand Down
4 changes: 0 additions & 4 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@


async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
"""Get the current user by decoding the JWT token."""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
Expand All @@ -55,7 +54,6 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:

@app.post("/user/signup")
async def signup(user: User) -> dict:
"""Sign up a new user."""
if user_exists(user.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=f"User {user.email} already registered"
Expand All @@ -67,7 +65,6 @@ async def signup(user: User) -> dict:

@app.delete("/user/")
async def delete_user(current_user: User = Depends(get_current_user)) -> dict:
"""Delete an existing user."""
email = current_user.email
try:
user = get_user(email)
Expand All @@ -85,7 +82,6 @@ async def delete_user(current_user: User = Depends(get_current_user)) -> dict:

@app.post("/user/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
"""Log in a user and return an access token."""
user = authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
Expand Down
11 changes: 8 additions & 3 deletions backend/rag_components/chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chat_models.base import SystemMessage
from langchain.prompts import (
ChatPromptTemplate,
Expand Down Expand Up @@ -36,17 +37,21 @@ def get_answer_chain(config, docsearch: VectorStore, memory) -> ConversationalRe

context_with_docs_prompt = PromptTemplate(template=prompts.document_context, input_variables=["page_content", "source"])

final_qa_chain = StuffDocumentsChain(
stuffed_qa_chain = StuffDocumentsChain(
llm_chain=question_answering_chain,
document_variable_name="context",
document_prompt=context_with_docs_prompt,
)

reduced_qa_chain = ReduceDocumentsChain(
combine_documents_chain=stuffed_qa_chain,
)

chain = ConversationalRetrievalChain(
question_generator=condense_question_chain,
retriever=docsearch.as_retriever(search_kwargs={"k": config["vector_store_provider"]["documents_to_retreive"]}),
retriever=docsearch.as_retriever(search_type=config["vector_store_provider"]["search_type"], search_kwargs=config["vector_store_provider"]["search_options"]),
memory=memory,
combine_docs_chain=final_qa_chain,
combine_docs_chain=reduced_qa_chain,
)

return chain, callback_handler
Expand Down
8 changes: 3 additions & 5 deletions backend/rag_components/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from langchain.docstore.document import Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import filter_complex_metadata
from backend.rag_components.chain import get_answer_chain, get_response_stream

Expand All @@ -22,13 +23,11 @@ def __init__(self, config_file_path: Path = None):
self.config = get_config(config_file_path)
self.llm = get_llm_model(self.config)
self.embeddings = get_embedding_model(self.config)
self.vector_store = get_vector_store(self.embeddings, self.config)
self.vector_store: VectorStore = get_vector_store(self.embeddings, self.config)

def generate_response(self, message: Message):
embeddings = get_embedding_model(self.config)
vector_store = get_vector_store(embeddings, self.config)
memory = get_conversation_buffer_memory(self.config, message.chat_id)
answer_chain, callback_handler = get_answer_chain(self.config, vector_store, memory)
answer_chain, callback_handler = get_answer_chain(self.config, self.vector_store, memory)
response_stream = get_response_stream(answer_chain, callback_handler, message.content)
return response_stream

Expand All @@ -53,7 +52,6 @@ def serve():
if __name__ == "__main__":
file_path = Path(__file__).parent.parent.parent / "data"
rag = RAG()

for file in file_path.iterdir():
if file.is_file():
rag.load_file(file)
18 changes: 0 additions & 18 deletions bin/install_with_conda.sh

This file was deleted.

20 changes: 0 additions & 20 deletions bin/install_with_venv.sh

This file was deleted.

8 changes: 0 additions & 8 deletions database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,26 @@


class Database:
"""A simple wrapper class for SQLite database operations."""

def __init__(self):
"""Initialize the Database object by setting the database path based on the environment."""
db_name = (
"test.sqlite" if os.getenv("TESTING", "false").lower() == "true" else "database.sqlite"
)
self.db_path = Path(__file__).parent / db_name

def __enter__(self) -> "Database":
"""Enter the runtime context related to the database connection."""
self.conn = sqlite3.connect(self.db_path)
self.conn.row_factory = sqlite3.Row
return self

def __exit__(
self, exc_type: Optional[type], exc_val: Optional[type], exc_tb: Optional[type]
) -> None:
"""Exit the runtime context and close the database connection properly."""
if exc_type is not None:
self.conn.rollback()
self.conn.commit()
self.conn.close()

def query(self, query: str, params: Optional[Tuple] = None) -> List[List[sqlite3.Row]]:
"""Execute a query against the database."""
cursor = self.conn.cursor()
results = []
commands = filter(None, query.split(";"))
Expand All @@ -40,13 +34,11 @@ def query(self, query: str, params: Optional[Tuple] = None) -> List[List[sqlite3
return results

def query_from_file(self, file_path: Path) -> None:
"""Execute a query from a SQL file."""
with Path.open(file_path, "r") as file:
query = file.read()
self.query(query)

def delete_db(self) -> None:
"""Delete the database file from the filesystem."""
if self.conn:
self.conn.close()
if self.db_path.exists():
Expand Down

0 comments on commit 95d3a79

Please sign in to comment.