diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df7c1d8..b7b65ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,6 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-toml - - id: check-yaml - id: check-json - id: check-added-large-files - repo: local @@ -17,7 +16,7 @@ repos: language: system - id: ruff name: Linting (ruff) - entry: ruff + entry: ruff --fix types: [python] language: system - id: nbstripout diff --git a/backend/config.py b/backend/config.py index 62e4de1..bef925d 100644 --- a/backend/config.py +++ b/backend/config.py @@ -71,7 +71,7 @@ class RagConfig: def from_yaml(cls, yaml_path: Path, env: dict = None): if env is None: env = os.environ - with open(yaml_path, "r") as file: + with Path.open(yaml_path, "r") as file: template = Template(file.read()) config_data = yaml.safe_load(template.render(env))["RagConfig"] diff --git a/backend/database.py b/backend/database.py index d8b0ad9..dad53b5 100644 --- a/backend/database.py +++ b/backend/database.py @@ -5,7 +5,7 @@ import sqlglot from dbutils.pooled_db import PooledDB -from dotenv import load_dotenv +from dotenv import load_dotenv # noqa from sqlalchemy.engine.url import make_url from backend.logger import get_logger @@ -75,8 +75,8 @@ def fetchall(self, query: str, params: Optional[tuple] = None) -> list: def initialize_schema(self): try: self.logger.debug("Initializing database schema") - sql_script = Path(__file__).parent.joinpath('db_init.sql').read_text() - transpiled_sql = sqlglot.transpile(sql_script, read='sqlite', write=self.url.drivername.replace("postgresql", "postgres")) + sql_script = Path(__file__).parent.joinpath("db_init.sql").read_text() + transpiled_sql = sqlglot.transpile(sql_script, read="sqlite", write=self.url.drivername.replace("postgresql", "postgres")) for statement in transpiled_sql: self.execute(statement) self.logger.info(f"Database schema initialized successfully for {self.url.drivername}") @@ -88,18 +88,36 @@ def _create_pool(self) -> PooledDB: if self.connection_string.startswith("sqlite:///"): import sqlite3 Path(self.connection_string.replace("sqlite:///", "")).parent.mkdir(parents=True, exist_ok=True) - return PooledDB(creator=sqlite3, database=self.connection_string.replace("sqlite:///", ""), maxconnections=5) + return PooledDB( + creator=sqlite3, + database=self.connection_string.replace("sqlite:///", ""), + maxconnections=5 + ) elif self.connection_string.startswith("postgresql://"): import psycopg2 - return PooledDB(creator=psycopg2, dsn=self.connection_string, maxconnections=5) - elif self.connection_string.startswith("mysql://"): + return PooledDB( + creator=psycopg2, + dsn=self.connection_string, + maxconnections=5 + ) + elif self.connection_string.startswith("mysql://") or \ + self.connection_string.startswith("mysql+pymysql://"): import mysql.connector - return PooledDB(creator=mysql.connector, user=self.url.username, password=self.url.password, host=self.url.host, port=self.url.port, database=self.url.database, maxconnections=5) - elif self.connection_string.startswith("mysql+pymysql://"): - import mysql.connector - return PooledDB(creator=mysql.connector, user=self.url.username, password=self.url.password, host=self.url.host, port=self.url.port, database=self.url.database, maxconnections=5) + return PooledDB( + creator=mysql.connector, + user=self.url.username, + password=self.url.password, + host=self.url.host, + port=self.url.port, + database=self.url.database, + maxconnections=5 + ) elif self.connection_string.startswith("sqlserver://"): import pyodbc - return PooledDB(creator=pyodbc, dsn=self.connection_string.replace("sqlserver://", ""), maxconnections=5) + return PooledDB( + creator=pyodbc, + dsn=self.connection_string.replace("sqlserver://", ""), + maxconnections=5 + ) else: raise ValueError(f"Unsupported database type: {self.url.drivername}") diff --git a/backend/main.py b/backend/main.py index f2ddcaa..30e2af9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -158,7 +158,9 @@ async def stream_response(rag: RAG, chat_id: str, question, response): await asyncio.sleep(0) except Exception as e: logger.error(f"Error generating response for chat {chat_id}: {e}", exc_info=True) - full_response = f"Sorry, there was an error generating a response. Please contact an administrator and provide them with the following error code: {response_id} \n\n {traceback.format_exc()}" + full_response = f"Sorry, there was an error generating a response. \ + Please contact an administrator and provide them with the following error code: \ + {response_id} \n\n {traceback.format_exc()}" yield full_response.encode("utf-8") finally: await log_response_to_db(chat_id, full_response) diff --git a/backend/rag_components/chain.py b/backend/rag_components/chain.py index 6acd6cf..2dc6344 100644 --- a/backend/rag_components/chain.py +++ b/backend/rag_components/chain.py @@ -14,11 +14,19 @@ DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}") -def get_answer_chain(config: RagConfig, vector_store: VectorStore, memory: ConversationBufferWindowMemory, logging_callback_handler: LoggingCallbackHandler = None): +def get_answer_chain( + config: RagConfig, + vector_store: VectorStore, + memory: ConversationBufferWindowMemory, + logging_callback_handler: LoggingCallbackHandler = None + ): llm_callbacks = [logging_callback_handler] if logging_callback_handler is not None else [] llm = get_llm_model(config, callbacks=llm_callbacks) - retriever = vector_store.as_retriever(search_type=config.vector_store.retriever_search_type, search_kwargs=config.vector_store.retriever_config) + retriever = vector_store.as_retriever( + search_type=config.vector_store.retriever_search_type, + search_kwargs=config.vector_store.retriever_config + ) condense_question_prompt = PromptTemplate.from_template(prompts.condense_history) question_answering_prompt = ChatPromptTemplate.from_template(prompts.respond_to_question) diff --git a/backend/rag_components/chat_message_history.py b/backend/rag_components/chat_message_history.py index 907964d..cbacae3 100644 --- a/backend/rag_components/chat_message_history.py +++ b/backend/rag_components/chat_message_history.py @@ -1,4 +1,3 @@ -import os from langchain.memory import ConversationBufferWindowMemory from langchain_community.chat_message_histories import SQLChatMessageHistory diff --git a/backend/rag_components/document_loader.py b/backend/rag_components/document_loader.py index 16917cb..87bbfdf 100644 --- a/backend/rag_components/document_loader.py +++ b/backend/rag_components/document_loader.py @@ -32,7 +32,8 @@ def get_best_loader(file_extension: str, llm: BaseChatModel): prompt = PromptTemplate( input_variables=["file_extension", "loaders"], template=""" - Among the following loaders, which is the best to load a "{file_extension}" file? Only give me one the class name without any other special characters. If no relevant loader is found, respond "None". + Among the following loaders, which is the best to load a "{file_extension}" file? \ + Only give me one the class name without any other special characters. If no relevant loader is found, respond "None". Loaders: {loaders} """, diff --git a/backend/rag_components/prompts.py b/backend/rag_components/prompts.py index d64f548..bbb7775 100644 --- a/backend/rag_components/prompts.py +++ b/backend/rag_components/prompts.py @@ -1,5 +1,6 @@ condense_history = """ -Given the conversation history and the following question, can you rephrase the user's question in its original language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns. +Given the conversation history and the following question, can you rephrase the user's question in its original \ + language so that it is self-sufficient. Make sure to avoid the use of unclear pronouns. Chat history : {chat_history} @@ -9,13 +10,17 @@ """ rag_system_prompt = """ -As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. -It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail. +As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on \ + the documents provided as input. +It is essential to respond in the same language in which the question was asked. Responses must be written in \ + a professional style and must demonstrate great attention to detail. """ respond_to_question = """ -As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on the documents provided as input. -It is essential to respond in the same language in which the question was asked. Responses must be written in a professional style and must demonstrate great attention to detail. +As a chatbot assistant, your mission is to respond to user inquiries in a precise and concise manner based on \ + the documents provided as input. +It is essential to respond in the same language in which the question was asked. Responses must be written in \ + a professional style and must demonstrate great attention to detail. Respond to the question taking into account the following context. diff --git a/backend/rag_components/rag.py b/backend/rag_components/rag.py index e58d3ba..fd92e31 100644 --- a/backend/rag_components/rag.py +++ b/backend/rag_components/rag.py @@ -1,4 +1,3 @@ -import asyncio from logging import Logger from pathlib import Path from typing import List, Union diff --git a/backend/rag_components/streaming_callback_handler.py b/backend/rag_components/streaming_callback_handler.py index 6cc3478..98e661b 100644 --- a/backend/rag_components/streaming_callback_handler.py +++ b/backend/rag_components/streaming_callback_handler.py @@ -7,7 +7,6 @@ class StreamingCallbackHandler(BaseCallbackHandler): queue = Queue() - def on_llm_new_token(self, token: str, **kwargs: AnyStr) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" + def on_llm_new_token(self, token: str, **kwargs: AnyStr) -> None: # type: ignore if token is not None and token != "": self.queue.put_nowait(token) diff --git a/frontend/lib/chat.py b/frontend/lib/chat.py index 486151c..ad50a53 100644 --- a/frontend/lib/chat.py +++ b/frontend/lib/chat.py @@ -42,7 +42,7 @@ def chat(): response = send_prompt(user_message) with st.chat_message("assistant"): placeholder = st.empty() - full_response = '' + full_response = "" for item in response: full_response += item placeholder.write(full_response) diff --git a/frontend/lib/sidebar.py b/frontend/lib/sidebar.py index 9737305..89dc3f9 100644 --- a/frontend/lib/sidebar.py +++ b/frontend/lib/sidebar.py @@ -3,7 +3,7 @@ import humanize import streamlit as st -from frontend.lib.chat import Message, new_chat +from frontend.lib.chat import Message def sidebar(): diff --git a/pyproject.toml b/pyproject.toml index 7857214..f9644a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,7 @@ select = [ "F", "I", "N", - "D", - "ANN", "Q", - "RET", - "ARG", "PTH", "PD", ] # See: https://beta.ruff.rs/docs/rules/ diff --git a/tests/test_feedback.py b/tests/test_feedback.py deleted file mode 100644 index 12ae887..0000000 --- a/tests/test_feedback.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -from pathlib import Path -from typing import Generator, Tuple - -import pytest -from fastapi.testclient import TestClient -from lib.main import app - -from backend.database import Database - -os.environ["TESTING"] = "True" -client = TestClient(app) - - -@pytest.fixture(scope="module") -def context() -> Generator[Tuple[dict, Database], None, None]: - """Set up the database context and provides a client with an authorized header.""" - db = Database() - with db: - db.query_from_file(Path(__file__).parents[1] / "database" / "database_init.sql") - - user_data = {"email": "test@example.com", "password": "testpassword"} - - response = client.post("/user/signup", json=user_data) - assert response.status_code == 200 - response = client.post( - "/user/login", data={"username": user_data["email"], "password": user_data["password"]} - ) - assert response.status_code == 200 - token = response.json()["access_token"] - client.headers = {**client.headers, "Authorization": f"Bearer {token}"} - - yield client.headers, db - db.delete_db() - - -def test_feedback_thumbs_up(context: Tuple[dict, Database]) -> None: - """Test to ensure that giving a thumbs up feedback works correctly.""" - headers, db = context[0], context[1] - message_id = "test_message_id_1" - response = client.post(f"/feedback/{message_id}/thumbs_up", headers=headers) - assert response.status_code == 200 - with db: - result = db.query("SELECT 1 FROM feedback WHERE message_id = ?", (message_id,))[0] - assert len(result) == 1 - - -def test_feedback_thumbs_down(context: Tuple[dict, Database]) -> None: - """Test to ensure that giving a thumbs down feedback works correctly.""" - headers, db = context[0], context[1] - message_id = "test_message_id_2" - response = client.post(f"/feedback/{message_id}/thumbs_down", headers=headers) - assert response.status_code == 200 - with db: - result = db.query("SELECT 1 FROM feedback WHERE message_id = ?", (message_id,))[0] - assert len(result) == 1 diff --git a/tests/test_users.py b/tests/test_users.py deleted file mode 100644 index 420a771..0000000 --- a/tests/test_users.py +++ /dev/null @@ -1,71 +0,0 @@ -import os -from pathlib import Path - -import pytest -from fastapi.testclient import TestClient -from lib.main import app - -from backend.database import Database - -os.environ["TESTING"] = "True" -client = TestClient(app) - - -@pytest.fixture() -def initialize_database() -> Database: - """Initialize the test database by applying the initialization SQL script.""" - db = Database() - with db: - db.query_from_file(Path(__file__).parents[1] / "database" / "database_init.sql") - yield db - db.delete_db() - - -def test_signup(initialize_database: Database) -> None: - """Test the user signup process.""" - response = client.post( - "/user/signup", json={"email": "test@example.com", "password": "testpassword"} - ) - assert response.status_code == 200 - assert response.json()["email"] == "test@example.com" - - response = client.post( - "/user/signup", json={"email": "test@example.com", "password": "testpassword"} - ) - assert response.status_code == 400 - assert "detail" in response.json() - assert response.json()["detail"] == "User test@example.com already registered" - - -def test_login(initialize_database: Database) -> None: - """Test the user login process.""" - response = client.post( - "/user/signup", json={"email": "test@example.com", "password": "testpassword"} - ) - assert response.status_code == 200 - assert response.json()["email"] == "test@example.com" - response = client.post( - "/user/login", data={"username": "test@example.com", "password": "testpassword"} - ) - assert response.status_code == 200 - assert "access_token" in response.json() - - -def test_user_me(initialize_database: Database) -> None: - """Test the retrieval of user profile information.""" - response = client.post( - "/user/signup", json={"email": "test@example.com", "password": "testpassword"} - ) - assert response.status_code == 200 - assert response.json()["email"] == "test@example.com" - - response = client.post( - "/user/login", data={"username": "test@example.com", "password": "testpassword"} - ) - assert response.status_code == 200 - assert "access_token" in response.json() - - token = response.json()["access_token"] - response = client.get("/user/me", headers={"Authorization": f"Bearer {token}"}) - assert response.status_code == 200 - assert response.json()["email"] == "test@example.com"