Commit
index docs to vector store
sarah-lauzeral committed Dec 22, 2023
1 parent fa9c948 commit a07a165
Showing 7 changed files with 228 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -137,6 +137,7 @@ secrets/*
# Mac OS
.DS_Store
data/
vector_database/

*.sqlite
*.sqlite3
1 change: 1 addition & 0 deletions backend/config.yaml
@@ -24,6 +24,7 @@ embedding_model_config:
vector_store_provider:
model_source: Chroma
persist_directory: vector_database/
cleanup_mode: full

chat_message_history_config:
source: ChatMessageHistory
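The new cleanup_mode value presumably maps onto the cleanup argument of LangChain's langchain.indexes.index(), which accepts "incremental", "full", or None. A minimal sketch of reading it through the project's config loader; the nested dict access mirrors this diff but is an assumption about how the rendered YAML is exposed:

from backend.config_renderer import get_config

config = get_config()
# assumption: the rendered YAML is exposed as nested dicts
cleanup_mode = config["vector_store_provider"]["cleanup_mode"]  # "full" per this commit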
2 changes: 1 addition & 1 deletion backend/main.py
@@ -208,7 +208,7 @@ async def feedback_thumbs_down(
@app.post("/index/documents")
async def index_documents(chunks: List[Doc], bucket: str, storage_backend: StorageBackend) -> None:
"""Index documents in a specified storage backend."""
document_store.store_documents(chunks, bucket, storage_backend)
store_documents(chunks, bucket, storage_backend)


if __name__ == "__main__":
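For reference, a hedged sketch of calling the updated endpoint from a client. The Doc payload shape, the bucket name, and the storage_backend value are assumptions, and it presumes FastAPI treats the chunks as the JSON body and the two scalar arguments as query parameters:

import requests

# hypothetical Doc shape; adjust to the real pydantic model
chunks = [
    {"content": "example chunk text", "metadata": {"source": "data/billionaires_csv.csv"}},
]

response = requests.post(
    "http://localhost:8000/index/documents",
    params={"bucket": "my-bucket", "storage_backend": "local"},  # assumed values
    json=chunks,
)
response.raise_for_status()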
27 changes: 20 additions & 7 deletions backend/rag_components/main.py
@@ -1,7 +1,10 @@
import os
from pathlib import Path
from typing import List

from dotenv import load_dotenv
from langchain.docstore.document import Document
from langchain.indexes import SQLRecordManager, index
from langchain.vectorstores.utils import filter_complex_metadata

from backend.config_renderer import get_config
@@ -10,6 +13,8 @@
from backend.rag_components.llm import get_llm_model
from backend.rag_components.vector_store import get_vector_store

load_dotenv()


class RAG:
def __init__(self):
@@ -21,15 +26,23 @@ def __init__(self):
def generate_response():
pass

def load_documents(self, documents: List[Document]):
# TODO: make document loading more robust
# TODO: LangChain agent that picks the best loader (get_best_loader)
self.vector_store.add_documents(documents)

def load_file(self, file_path: Path):
def load_documents(self, documents: List[Document], cleanup_mode: str):
record_manager = SQLRecordManager(
namespace="vector_store/my_docs", db_url=os.environ.get("DATABASE_CONNECTION_STRING")
)
record_manager.create_schema()
index(
documents,
record_manager,
self.vector_store,
cleanup=cleanup_mode,
source_id_key="source",
)

def load_file(self, file_path: Path) -> List[Document]:
documents = get_documents(file_path, self.llm)
filtered_documents = filter_complex_metadata(documents)
self.vector_store.add_documents(filtered_documents)
return filtered_documents

# TODO: for each file -> store a hash in the database
# TODO: before loading a file into the vector store, check whether its hash is already in our DB, and only append the document to the vector store if the hash does not exist yet
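A minimal usage sketch of the updated RAG class, assuming the config and credentials the class needs are in place, that DATABASE_CONNECTION_STRING is set in the environment (SQLRecordManager reads it), and that the sample path points at a real file:

from pathlib import Path

from backend.rag_components.main import RAG

rag = RAG()
chunks = rag.load_file(Path("data/billionaires_csv.csv"))  # returns the filtered Document chunks
rag.load_documents(chunks, cleanup_mode="incremental")     # index() dedupes against the record manager, keyed on the "source" metadata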
154 changes: 137 additions & 17 deletions notebooks/docs_loader.ipynb
@@ -30,8 +30,12 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import PyPDFLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n"
"from backend.rag_components.main import RAG\n",
"from langchain.indexes import SQLRecordManager, index\n",
"from langchain.document_loaders.csv_loader import CSVLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from dotenv import load_dotenv\n",
"load_dotenv()"
]
},
{
@@ -40,16 +44,67 @@
"metadata": {},
"outputs": [],
"source": [
"loader = PyPDFLoader(f\"{parent_directory}/data/Cheat sheet entretien.pdf\", extract_images=True)\n",
"rag = RAG()\n",
"rag.vector_store"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"namespace = f\"chromadb/my_docs\"\n",
"record_manager = SQLRecordManager(\n",
" namespace, db_url=os.environ.get(\"DATABASE_CONNECTION_STRING\")\n",
")\n",
"# pointer le record_manager vers une table dans db sql \n",
"record_manager.create_schema()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loader = CSVLoader(f\"{parent_directory}/data/billionaires_csv.csv\")\n",
"documents = loader.load()\n",
"text_splitter = RecursiveCharacterTextSplitter(\n",
" separators=[\"\\n\\n\", \"\\n\", \".\"], chunk_size=1500, chunk_overlap=100\n",
" separators=[\"\\n\\n\", \"\\n\"], chunk_size=1500, chunk_overlap=100\n",
")\n",
"texts = text_splitter.split_documents(documents)\n",
"\n",
"with open('../data/local_documents.json', 'w') as f:\n",
" for doc in texts:\n",
" f.write(doc.json() + '\\n')"
"texts[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loader = CSVLoader(f\"{parent_directory}/data/billionaires_csv_bis.csv\")\n",
"documents = loader.load()\n",
"text_splitter = RecursiveCharacterTextSplitter(\n",
" separators=[\"\\n\\n\", \"\\n\"], chunk_size=1500, chunk_overlap=100\n",
")\n",
"texts_bis = text_splitter.split_documents(documents)\n",
"texts_bis[:5]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index(\n",
" [],\n",
" record_manager,\n",
" rag.vector_store,\n",
" cleanup=\"full\", #incremental\n",
" source_id_key=\"source\",\n",
")"
]
},
{
@@ -58,11 +113,61 @@
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from langchain.docstore.document import Document\n",
"index(\n",
" texts[:100],\n",
" record_manager,\n",
" rag.vector_store,\n",
" cleanup=\"incremental\", #incremental\n",
" source_id_key=\"source\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"index(\n",
" texts_bis[50:100],\n",
" record_manager,\n",
" rag.vector_store,\n",
" cleanup=\"incremental\",\n",
" source_id_key=\"source\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# print(os.environ.get(\"APIFY_API_TOKEN\"))\n",
"\n",
"from langchain.document_loaders.base import Document\n",
"from langchain.utilities import ApifyWrapper\n",
"from dotenv import load_dotenv\n",
"load_dotenv()\n",
"\n",
"apify = ApifyWrapper()\n",
"\n",
"with open('../data/local_documents.json', 'r') as f:\n",
" json_data = [json.loads(line) for line in f]\n"
"loader = apify.call_actor(\n",
" actor_id=\"apify/website-content-crawler\",\n",
" run_input={\"startUrls\": [{\"url\": \"https://python.langchain.com/en/latest/modules/indexes/document_loaders.html\"}]},\n",
" dataset_mapping_function=lambda item: Document(\n",
" page_content=item[\"text\"] or \"\", metadata={\"source\": item[\"url\"]}\n",
" ),\n",
")"
]
},
{
@@ -71,11 +176,20 @@
"metadata": {},
"outputs": [],
"source": [
"documents = []\n",
"loader #.apify_client()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from apify_client import ApifyClient\n",
"\n",
"apify_client = loader.apify_client\n",
"\n",
"for r in json_data:\n",
" document = Document(page_content=r[\"page_content\"], metadata=r[\"metadata\"])\n",
" documents.append(document)\n"
"len(apify_client.dataset(loader.dataset_id).list_items().items)"
]
},
{
@@ -84,7 +198,13 @@
"metadata": {},
"outputs": [],
"source": [
"documents"
"index(\n",
" [loader],\n",
" record_manager,\n",
" rag.vector_store,\n",
" cleanup=\"incremental\",\n",
" source_id_key=\"source\",\n",
")"
]
},
{
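The notebook exercises LangChain's record-manager-based indexing: calling index() with an empty batch and cleanup="full" clears everything previously recorded under the namespace, while cleanup="incremental" only removes previously indexed chunks whose source documents reappear with changed content. A compact sketch of that behaviour, reusing rag.vector_store from the cells above and an in-memory SQLite record manager as a stand-in for DATABASE_CONNECTION_STRING (the sample documents are placeholders):

from langchain.docstore.document import Document
from langchain.indexes import SQLRecordManager, index

record_manager = SQLRecordManager("chromadb/my_docs", db_url="sqlite:///:memory:")
record_manager.create_schema()

doc_v1 = Document(page_content="row 1", metadata={"source": "billionaires_csv.csv"})
doc_v2 = Document(page_content="row 1 (edited)", metadata={"source": "billionaires_csv.csv"})

# first run adds the chunk; rerunning with identical content would skip it
index([doc_v1], record_manager, rag.vector_store, cleanup="incremental", source_id_key="source")

# same source, new content: the stale chunk for that source is deleted, the new one added
index([doc_v2], record_manager, rag.vector_store, cleanup="incremental", source_id_key="source")

# full cleanup with an empty batch deletes everything still tracked in this namespace
index([], record_manager, rag.vector_store, cleanup="full", source_id_key="source")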
2 changes: 2 additions & 0 deletions requirements.in
@@ -33,3 +33,5 @@ pytest
python-jose
trubrics[streamlit]
uvicorn
apify-client
sentence_transformers