diff --git a/.gitignore b/.gitignore
index 44d5196..0dee274 100644
--- a/.gitignore
+++ b/.gitignore
@@ -137,6 +137,7 @@ secrets/*
 
 # Mac OS
 .DS_Store
 data/
+vector_database/
 *.sqlite
 *.sqlite3
\ No newline at end of file
diff --git a/backend/config.yaml b/backend/config.yaml
index 9ca450d..2833610 100644
--- a/backend/config.yaml
+++ b/backend/config.yaml
@@ -24,6 +24,7 @@ embedding_model_config:
 vector_store_provider:
   model_source: Chroma
   persist_directory: vector_database/
+  cleanup_mode: full
 
 chat_message_history_config:
   source: ChatMessageHistory
diff --git a/backend/main.py b/backend/main.py
index 10d71e1..58cea02 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -208,7 +208,7 @@ async def feedback_thumbs_down(
 @app.post("/index/documents")
 async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
     """Index documents in a specified storage backend."""
-    document_store.store_documents(chunks, bucket, storage_backend)
+    store_documents(chunks, bucket, storage_backend)
 
 
 if __name__ == "__main__":
diff --git a/backend/rag_components/main.py b/backend/rag_components/main.py
index db93076..c7ba067 100644
--- a/backend/rag_components/main.py
+++ b/backend/rag_components/main.py
@@ -1,7 +1,10 @@
+import os
 from pathlib import Path
 from typing import List
 
+from dotenv import load_dotenv
 from langchain.docstore.document import Document
+from langchain.indexes import SQLRecordManager, index
 from langchain.vectorstores.utils import filter_complex_metadata
 
 from backend.config_renderer import get_config
@@ -10,6 +13,8 @@
 from backend.rag_components.llm import get_llm_model
 from backend.rag_components.vector_store import get_vector_store
 
+load_dotenv()
+
 
 class RAG:
     def __init__(self):
@@ -21,15 +26,23 @@ def __init__(self):
     def generate_response():
         pass
 
-    def load_documents(self, documents: List[Document]):
-        # TODO improve the robustness of load_document
-        # TODO langchain agent that picks the best loader (get_best_loader)
-        self.vector_store.add_documents(documents)
-
-    def load_file(self, file_path: Path):
+    def load_documents(self, documents: List[Document], cleanup_mode: str):
+        record_manager = SQLRecordManager(
+            namespace="vector_store/my_docs", db_url=os.environ.get("DATABASE_CONNECTION_STRING")
+        )
+        record_manager.create_schema()
+        index(
+            documents,
+            record_manager,
+            self.vector_store,
+            cleanup=cleanup_mode,
+            source_id_key="source",
+        )
+
+    def load_file(self, file_path: Path) -> List[Document]:
         documents = get_documents(file_path, self.llm)
         filtered_documents = filter_complex_metadata(documents)
-        self.vector_store.add_documents(filtered_documents)
+        return filtered_documents
 
     # TODO for each file -> store a hash in the database
     # TODO before loading a file into the vector store, check whether its hash is already in our db and only add the doc to the vector store if the hash does not exist yet
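For context, here is a minimal sketch of how the reworked `RAG` class above could be driven end to end. It assumes the `.env` file provides `DATABASE_CONNECTION_STRING` for the `SQLRecordManager` and that `config.yaml` points the Chroma store at `vector_database/`; the CSV path and cleanup mode are illustrative values borrowed from the notebook below.

```python
from pathlib import Path

from backend.rag_components.main import RAG

# RAG() wires up the LLM, embeddings and Chroma vector store from config.yaml;
# load_dotenv() in the module picks up DATABASE_CONNECTION_STRING.
rag = RAG()

# load_file() now only loads and filters documents; nothing is indexed yet.
documents = rag.load_file(Path("data/billionaires_csv.csv"))

# load_documents() registers each chunk in the SQL record manager and indexes it
# into the vector store. "incremental" only purges stale chunks from sources seen
# in this batch; "full" would also delete anything missing from the batch.
rag.load_documents(documents, cleanup_mode="incremental")
```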
diff --git a/notebooks/docs_loader.ipynb b/notebooks/docs_loader.ipynb
index ffe377f..7c08f4e 100644
--- a/notebooks/docs_loader.ipynb
+++ b/notebooks/docs_loader.ipynb
@@ -30,8 +30,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain.document_loaders import PyPDFLoader\n",
-    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n"
+    "from backend.rag_components.main import RAG\n",
+    "from langchain.indexes import SQLRecordManager, index\n",
+    "from langchain.document_loaders.csv_loader import CSVLoader\n",
+    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
+    "from dotenv import load_dotenv\n",
+    "load_dotenv()"
    ]
   },
   {
@@ -40,16 +44,67 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "loader = PyPDFLoader(f\"{parent_directory}/data/Cheat sheet entretien.pdf\", extract_images=True)\n",
+    "rag = RAG()\n",
+    "rag.vector_store"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "namespace = f\"chromadb/my_docs\"\n",
+    "record_manager = SQLRecordManager(\n",
+    "    namespace, db_url=os.environ.get(\"DATABASE_CONNECTION_STRING\")\n",
+    ")\n",
+    "# point the record_manager to a table in the SQL db\n",
+    "record_manager.create_schema()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loader = CSVLoader(f\"{parent_directory}/data/billionaires_csv.csv\")\n",
     "documents = loader.load()\n",
     "text_splitter = RecursiveCharacterTextSplitter(\n",
-    "    separators=[\"\\n\\n\", \"\\n\", \".\"], chunk_size=1500, chunk_overlap=100\n",
+    "    separators=[\"\\n\\n\", \"\\n\"], chunk_size=1500, chunk_overlap=100\n",
     ")\n",
     "texts = text_splitter.split_documents(documents)\n",
-    "\n",
-    "with open('../data/local_documents.json', 'w') as f:\n",
-    "    for doc in texts:\n",
-    "        f.write(doc.json() + '\\n')"
+    "texts[:5]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loader = CSVLoader(f\"{parent_directory}/data/billionaires_csv_bis.csv\")\n",
+    "documents = loader.load()\n",
+    "text_splitter = RecursiveCharacterTextSplitter(\n",
+    "    separators=[\"\\n\\n\", \"\\n\"], chunk_size=1500, chunk_overlap=100\n",
+    ")\n",
+    "texts_bis = text_splitter.split_documents(documents)\n",
+    "texts_bis[:5]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index(\n",
+    "    [],\n",
+    "    record_manager,\n",
+    "    rag.vector_store,\n",
+    "    cleanup=\"full\", #incremental\n",
+    "    source_id_key=\"source\",\n",
+    ")"
    ]
   },
   {
@@ -58,11 +113,61 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import json\n",
-    "from langchain.docstore.document import Document\n",
+    "index(\n",
+    "    texts[:100],\n",
+    "    record_manager,\n",
+    "    rag.vector_store,\n",
+    "    cleanup=\"incremental\", #incremental\n",
+    "    source_id_key=\"source\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index(\n",
+    "    texts_bis[50:100],\n",
+    "    record_manager,\n",
+    "    rag.vector_store,\n",
+    "    cleanup=\"incremental\",\n",
+    "    source_id_key=\"source\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "# print(os.environ.get(\"APIFY_API_TOKEN\"))\n",
+    "\n",
+    "from langchain.document_loaders.base import Document\n",
+    "from langchain.utilities import ApifyWrapper\n",
+    "from dotenv import load_dotenv\n",
+    "load_dotenv()\n",
+    "\n",
+    "apify = ApifyWrapper()\n",
     "\n",
-    "with open('../data/local_documents.json', 'r') as f:\n",
-    "    json_data = [json.loads(line) for line in f]\n"
+    "loader = apify.call_actor(\n",
+    "    actor_id=\"apify/website-content-crawler\",\n",
+    "    run_input={\"startUrls\": [{\"url\": \"https://python.langchain.com/en/latest/modules/indexes/document_loaders.html\"}]},\n",
+    "    dataset_mapping_function=lambda item: Document(\n",
+    "        page_content=item[\"text\"] or \"\", metadata={\"source\": item[\"url\"]}\n",
+    "    ),\n",
+    ")"
    ]
   },
   {
"metadata": {}, "outputs": [], "source": [ - "documents = []\n", + "loader #.apify_client()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from apify_client import ApifyClient\n", + "\n", + "apify_client = loader.apify_client\n", "\n", - "for r in json_data:\n", - " document = Document(page_content=r[\"page_content\"], metadata=r[\"metadata\"])\n", - " documents.append(document)\n" + "len(apify_client.dataset(loader.dataset_id).list_items().items)" ] }, { @@ -84,7 +198,13 @@ "metadata": {}, "outputs": [], "source": [ - "documents" + "index(\n", + " [loader],\n", + " record_manager,\n", + " rag.vector_store,\n", + " cleanup=\"incremental\",\n", + " source_id_key=\"source\",\n", + ")" ] }, { diff --git a/requirements.in b/requirements.in index 9ac0e15..d8f28c7 100644 --- a/requirements.in +++ b/requirements.in @@ -33,3 +33,5 @@ pytest python-jose trubrics[streamlit] uvicorn +apify-client +sentence_transformers diff --git a/requirements.txt b/requirements.txt index 60c3dbf..2d53979 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,6 +32,10 @@ anyio==3.7.1 # langchain # starlette # watchfiles +apify-client==1.6.1 + # via -r requirements.in +apify-shared==1.1.0 + # via apify-client async-timeout==4.0.3 # via # aiohttp @@ -140,6 +144,8 @@ fastjsonschema==2.19.0 filelock==3.13.1 # via # huggingface-hub + # torch + # transformers # virtualenv filetype==1.2.0 # via unstructured @@ -155,6 +161,7 @@ fsspec==2023.12.1 # gcsfs # huggingface-hub # s3fs + # torch # universal-pathlib gcsfs==2023.12.1 # via -r requirements.in @@ -204,9 +211,14 @@ httpcore==1.0.2 httptools==0.6.1 # via uvicorn httpx==0.25.2 - # via -r requirements.in + # via + # -r requirements.in + # apify-client huggingface-hub==0.19.4 - # via tokenizers + # via + # sentence-transformers + # tokenizers + # transformers humanfriendly==10.0 # via coloredlogs identify==2.5.33 @@ -231,10 +243,13 @@ jinja2==3.1.2 # via # altair # pydeck + # torch jmespath==1.0.1 # via botocore joblib==1.3.2 - # via nltk + # via + # nltk + # scikit-learn jsonpatch==1.33 # via langchain jsonpointer==2.4 @@ -293,9 +308,13 @@ nbformat==5.9.2 nbstripout==0.6.1 # via -r requirements.in networkx==3.2.1 - # via -r requirements.in + # via + # -r requirements.in + # torch nltk==3.8.1 - # via unstructured + # via + # sentence-transformers + # unstructured nodeenv==1.8.0 # via pre-commit numpy==1.26.2 @@ -309,7 +328,12 @@ numpy==1.26.2 # pandas # pyarrow # pydeck + # scikit-learn + # scipy + # sentence-transformers # streamlit + # torchvision + # transformers # unstructured oauthlib==3.2.2 # via requests-oauthlib @@ -331,6 +355,7 @@ packaging==23.2 # onnxruntime # pytest # streamlit + # transformers pandas==2.1.4 # via # -r requirements.in @@ -344,6 +369,7 @@ pillow==10.1.0 # via # python-pptx # streamlit + # torchvision platformdirs==4.1.0 # via # black @@ -423,6 +449,7 @@ pyyaml==6.0.1 # huggingface-hub # langchain # pre-commit + # transformers # uvicorn rapidfuzz==3.5.2 # via unstructured @@ -434,6 +461,7 @@ regex==2023.10.3 # via # nltk # tiktoken + # transformers requests==2.31.0 # via # -r requirements.in @@ -454,6 +482,8 @@ requests==2.31.0 # requests-oauthlib # streamlit # tiktoken + # torchvision + # transformers # trubrics # unstructured requests-oauthlib==1.3.1 @@ -474,6 +504,18 @@ ruff==0.1.7 # via -r requirements.in s3fs==2023.12.1 # via -r requirements.in +safetensors==0.4.1 + # via transformers +scikit-learn==1.3.2 + # via sentence-transformers +scipy==1.11.4 + # via + # 
@@ -474,6 +504,18 @@ ruff==0.1.7
     # via -r requirements.in
 s3fs==2023.12.1
     # via -r requirements.in
+safetensors==0.4.1
+    # via transformers
+scikit-learn==1.3.2
+    # via sentence-transformers
+scipy==1.11.4
+    # via
+    #   scikit-learn
+    #   sentence-transformers
+sentence-transformers==2.2.2
+    # via -r requirements.in
+sentencepiece==0.1.99
+    # via sentence-transformers
 six==1.16.0
     # via
     #   azure-core
@@ -505,17 +547,23 @@ streamlit==1.29.0
 streamlit-feedback==0.1.2
     # via trubrics
 sympy==1.12
-    # via onnxruntime
+    # via
+    #   onnxruntime
+    #   torch
 tabulate==0.9.0
     # via unstructured
 tenacity==8.2.3
     # via
     #   langchain
     #   streamlit
+threadpoolctl==3.2.0
+    # via scikit-learn
 tiktoken==0.5.2
     # via -r requirements.in
 tokenizers==0.15.0
-    # via chromadb
+    # via
+    #   chromadb
+    #   transformers
 toml==0.10.2
     # via streamlit
 tomli==2.0.1
@@ -524,6 +572,12 @@ tomli==2.0.1
     #   pytest
 toolz==0.12.0
     # via altair
+torch==2.1.2
+    # via
+    #   sentence-transformers
+    #   torchvision
+torchvision==0.16.2
+    # via sentence-transformers
 tornado==6.4
     # via streamlit
 tqdm==4.66.1
@@ -533,10 +587,14 @@ tqdm==4.66.1
     #   huggingface-hub
     #   nltk
     #   openai
+    #   sentence-transformers
+    #   transformers
 traitlets==5.14.0
     # via
     #   jupyter-core
     #   nbformat
+transformers==4.36.2
+    # via sentence-transformers
 trubrics[streamlit]==1.6.2
     # via -r requirements.in
 typer==0.9.0
@@ -558,6 +616,7 @@ typing-extensions==4.9.0
     #   pydantic-core
     #   sqlalchemy
     #   streamlit
+    #   torch
     #   typer
     #   typing-inspect
     #   unstructured
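For reference, here is a small self-contained sketch of the two `cleanup` modes exercised in the notebook and exposed through the new `cleanup_mode` config key. The collection name, SQLite URL and document sources are placeholders, and `HuggingFaceEmbeddings` (backed by the newly added `sentence_transformers` dependency) merely stands in for whichever embedding model the config selects.

```python
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.indexes import SQLRecordManager, index
from langchain.vectorstores import Chroma

# Placeholder vector store and record manager; names are illustrative.
embeddings = HuggingFaceEmbeddings()  # local sentence-transformers model
vector_store = Chroma(collection_name="demo", embedding_function=embeddings)
record_manager = SQLRecordManager("chromadb/demo", db_url="sqlite:///record_manager_cache.sql")
record_manager.create_schema()

docs = [
    Document(page_content="doc 1", metadata={"source": "a.txt"}),
    Document(page_content="doc 2", metadata={"source": "b.txt"}),
]

# "incremental": skips unchanged chunks and only deletes stale chunks belonging
# to sources that appear in this batch.
index(docs, record_manager, vector_store, cleanup="incremental", source_id_key="source")

# "full": additionally deletes anything previously indexed that is absent from
# this batch, so an empty batch clears the collection (as done in the notebook).
index([], record_manager, vector_store, cleanup="full", source_id_key="source")
```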