
Commit

update modifs
sarah-lauzeral committed Dec 19, 2023
1 parent e895a98 commit 9872a18
Showing 9 changed files with 117 additions and 90 deletions.
README.md: 4 changes (4 additions, 0 deletions)
@@ -6,4 +6,8 @@
export PYTHONPATH="/Users/sarah.lauzeral/Library/CloudStorage/[email protected]/Mon Drive/internal_projects/skaff-rag-accelerator/"
```

```bash
python "/Users/sarah.lauzeral/Library/CloudStorage/[email protected]/Mon Drive/internal_projects/skaff-rag-accelerator/backend/main.py"
```

</div>
backend/chatbot.py: 107 changes (34 additions, 73 deletions)
@@ -1,17 +1,11 @@
from pathlib import Path
from typing import List, Optional, Tuple, Union

import pandas as pd
from config_renderer import get_config
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chat_models import AzureChatOpenAI
from langchain.document_loaders import (
CSVLoader,
Docx2txtLoader,
PyPDFLoader,
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
)
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import (
ConversationBufferMemory,
@@ -25,20 +19,20 @@
HumanMessagePromptTemplate,
PromptTemplate,
)
from langchain.schema.document import Document
from langchain.schema.messages import HumanMessage, SystemMessage
from langchain.text_splitter import (
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

from backend.llm import get_model_instance
from backend.embedding import get_embedding_model_instance
from backend.llm import get_llm_model_instance
from backend.vector_store import get_vector_store


def get_response(answer_chain: ConversationalRetrievalChain, query: str) -> str:
"""Processes the given query through the answer chain and returns the formatted response."""
return answer_chain.run(query)


def get_answer_chain(
llm, docsearch: Chroma, memory: ConversationBufferMemory
) -> ConversationalRetrievalChain:
@@ -111,56 +105,18 @@ def get_embeddings_model(embedding_api_base: str, embedding_api_key: str) -> OpenAIEmbeddings:
)


def get_documents(data: pd.DataFrame) -> List[Document]:
"""Converts a dataframe into a list of Document objects."""
docs = data["answer"].tolist()
metadatas = data[["source", "question"]].to_dict("records")

documents = []
for text, metadata in zip(docs, metadatas):
document = Document(page_content=text, metadata=metadata)
documents.append(document)
return documents


def load_documents(file_extension: str, file_path: str):
"""Loads documents based on the file extension and path provided."""
if file_extension == ".pdf":
loader = PyPDFLoader(file_path)
elif file_extension in [".csv"]:
loader = CSVLoader(file_path, encoding="utf-8-sig", csv_args={"delimiter": "\t"})
elif file_extension in [".xlsx"]:
loader = UnstructuredExcelLoader(file_path, mode="elements")
elif file_extension in [".pptx"]:
loader = UnstructuredPowerPointLoader(file_path)
elif file_extension in [".docx"]:
loader = Docx2txtLoader(file_path)
else:
raise Exception("Unsupported file type!")

return loader.load()


def get_chunks(
_documents: List[str], chunk_size: int, chunk_overlap: int, text_splitter_type: int
) -> List[str]:
"""Splits the documents into chunks."""
if text_splitter_type == "basic":
text_splitter = CharacterTextSplitter(
separator="\n\n",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
elif text_splitter_type == "recursive":
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n", " "], chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
return text_splitter.split_documents(_documents)


def get_vector_store(_texts: List[str], _embeddings: OpenAIEmbeddings) -> Chroma:
"""Returns an instance of Chroma based on the provided parameters."""
return Chroma.from_documents(_texts, _embeddings)
def text_chunker(
directory_path: Path,
chunk_size: int,
chunk_overlap: int,
separators: Optional[List[str]] = ["\n\n", "\n", "\t"],
):
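    """Load every .txt file under directory_path and split the documents into overlapping chunks."""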
loader = DirectoryLoader(directory_path, glob="**/*.txt")
docs = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
separators=separators, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
return text_splitter.split_documents(docs)


def choose_memory_type(
@@ -198,18 +154,23 @@ def choose_memory_type(
)
return msgs, memory

if __name__ == "__main__":
llm = get_model_instance()

embeddings = get_embeddings_model("https://poc-openai-artefact.openai.azure.com/", "")

documents = load_documents(".csv", str(Path(__file__).parent / "billionaires_csv.csv"))
texts = get_chunks(documents, chunk_size=1500, chunk_overlap=200, text_splitter_type="recursive")
docsearch = get_vector_store(texts, embeddings)
if __name__ == "__main__":
config = get_config()
llm = get_llm_model_instance(config)
embeddings = get_embedding_model_instance(config)
vector_store = get_vector_store(embeddings)
texts = text_chunker(
directory_path=Path(__file__).parent.parent / "data/",
chunk_size=2000,
chunk_overlap=200,
)
vector_store.add_documents(texts)
msgs, memory = choose_memory_type(memory_type="buffer")
answer_chain = get_answer_chain(llm, docsearch, memory)
answer_chain = get_answer_chain(llm, vector_store, memory)

prompt = "Qui sont les 3 personnes les plus riches en france ?"
# prompt = "Qui sont les 3 personnes les plus riches en france ?"
prompt = "Combien y a t'il de milliardaires en France ?"
response = get_response(answer_chain, prompt)
print("Prompt :", prompt)
print("Response: ", response)
backend/config_renderer.py: 23 changes (23 additions, 0 deletions)
@@ -0,0 +1,23 @@
import os
from pathlib import Path

import yaml
from dotenv import load_dotenv
from jinja2 import Environment, FileSystemLoader


def get_config() -> dict:
load_dotenv()
env = Environment(loader=FileSystemLoader(Path(__file__).parent))
template = env.get_template("models_config.yaml")
config = template.render(os.environ)
return yaml.safe_load(config)


def load_models_config():
with open(Path(__file__).parent / "models_config.yaml", "r") as file:
return yaml.safe_load(file)


if __name__ == "__main__":
print(get_config())
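
For reference, a minimal usage sketch of the new helper: Jinja2 substitutes the `{{ ... }}` placeholders in `models_config.yaml` with environment variables (loaded from a `.env` file via `load_dotenv()`), and the rendered YAML is parsed into a plain dict. This assumes `OPENAI_API_KEY` and `EMBEDDING_API_KEY` are set in the environment or a `.env` file:

```python
# Hedged sketch; the printed values depend on your environment/.env file.
from config_renderer import get_config  # import style used in backend/chatbot.py

config = get_config()
print(config["llm_model_config"]["model_source"])             # "AzureChatOpenAI"
print(bool(config["llm_provider_config"]["openai_api_key"]))  # True once the key is set
```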
backend/embedding.py: 12 changes (12 additions, 0 deletions)
@@ -0,0 +1,12 @@
from langchain import embeddings


def get_embedding_model_instance(config):
embedding_spec = getattr(embeddings, config["embedding_model_config"]["model_source"])
all_config_field = {**config["embedding_model_config"], **config["embedding_provider_config"]}
kwargs = {
key: value
for key, value in all_config_field.items()
if key in embedding_spec.__fields__.keys()
}
return embedding_spec(**kwargs)
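
This helper and `get_llm_model_instance` in `backend/llm.py` share the same trick: langchain model classes are pydantic models, so `__fields__` lists the constructor arguments they accept, and the merged config dict can be filtered down to exactly those keys. A toy sketch of the pattern (the class and field names here are hypothetical, not langchain's):

```python
# Hedged illustration of the __fields__-based filtering with a stand-in model.
from pydantic import BaseModel

class ToyEmbeddings(BaseModel):
    deployment: str
    chunk_size: int = 1

config_fields = {"deployment": "embeddings", "chunk_size": 16, "unrelated_key": "x"}
kwargs = {k: v for k, v in config_fields.items() if k in ToyEmbeddings.__fields__}
print(ToyEmbeddings(**kwargs))  # unrelated_key is silently dropped
```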
backend/llm.py: 17 changes (5 additions, 12 deletions)
@@ -1,17 +1,10 @@
from langchain import chat_models
from pathlib import Path
import yaml

def get_model_instance():
config = load_models_config()

def get_llm_model_instance(config):
llm_spec = getattr(chat_models, config["llm_model_config"]["model_source"])
all_config_field = {**config["llm_model_config"], **config["llm_provider_config"]}
kwargs = {key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys()}
kwargs = {
key: value for key, value in all_config_field.items() if key in llm_spec.__fields__.keys()
}
return llm_spec(**kwargs)


def load_models_config():
with open(Path(__file__).parent / "models_config.yaml", "r") as file:
return yaml.safe_load(file)
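
A minimal usage sketch of the refactored helper, which now takes the rendered config instead of loading the YAML itself (assumes the Azure keys are present in the environment):

```python
# Hedged sketch; model and deployment are resolved from models_config.yaml.
from backend.llm import get_llm_model_instance
from config_renderer import get_config

llm = get_llm_model_instance(get_config())  # AzureChatOpenAI per the config
print(llm.predict("Say hello in one short sentence."))
```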


backend/models_config.yaml: 15 changes (10 additions, 5 deletions)
@@ -2,7 +2,7 @@ llm_provider_config:
openai_api_type: azure
openai_api_base: https://poc-genai-gpt4.openai.azure.com/
openai_api_version: 2023-07-01-preview
openai_api_key:
openai_api_key: {{ OPENAI_API_KEY }}

llm_model_config:
model_source: AzureChatOpenAI
@@ -12,10 +12,15 @@ llm_model_config:
verbose: true

embedding_provider_config:
openai_api_base: "https://poc-openai-artefact.openai.azure.com/"
openai_api_key:
openai_api_base: https://poc-openai-artefact.openai.azure.com/
openai_api_key: {{ EMBEDDING_API_KEY }}
openai_api_type: azure

embedding_model_config:
model_source: OpenAIEmbeddings
deployment: text-embedding-ada-002
chunk_size: 16
deployment: embeddings
chunk_size: 16

vector_store_provider:
model_source: Chroma
persist_directory: database/
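
The new `vector_store_provider` block is consumed by `backend/vector_store.py` below: `model_source` names a class in `langchain.vectorstores`, and the remaining keys (here `persist_directory`) are forwarded to that class's constructor whenever they match its parameter names.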
backend/vector_store.py: 25 changes (25 additions, 0 deletions)
@@ -0,0 +1,25 @@
import inspect

from config_renderer import get_config
from langchain import vectorstores


def get_vector_store(embedding_model):
config = get_config()
vector_store_spec = getattr(vectorstores, config["vector_store_provider"]["model_source"])
all_config_field = config["vector_store_provider"]

    # The vector store classes in langchain don't share a uniform interface for
    # passing the embedding model, so we look up the __init__ parameter whose
    # annotation matches the 'Embeddings' type and instantiate the vector store
    # with our embedding model.
signature = inspect.signature(vector_store_spec.__init__)
parameters = signature.parameters
params_dict = dict(parameters)
embedding_param = next(
(param for param in params_dict.values() if "Embeddings" in str(param.annotation)), None
)

kwargs = {key: value for key, value in all_config_field.items() if key in parameters.keys()}
kwargs[embedding_param.name] = embedding_model
vector_store = vector_store_spec(**kwargs)
return vector_store
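
A minimal end-to-end sketch of the new wiring, assuming the config above and a writable `database/` directory for Chroma's `persist_directory`:

```python
# Hedged sketch: build the embedding model, let get_vector_store discover
# Chroma's Embeddings-typed __init__ parameter, then query the store.
from backend.embedding import get_embedding_model_instance
from backend.vector_store import get_vector_store
from config_renderer import get_config

config = get_config()
store = get_vector_store(get_embedding_model_instance(config))
for doc in store.similarity_search("Combien y a t'il de milliardaires en France ?", k=3):
    print(doc.page_content[:80])
```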
requirements.in: 1 change (1 addition, 0 deletions)
@@ -9,6 +9,7 @@ pandas
requests
bs4
streamlit
extra_streamlit_components
google-cloud-storage
openai==0.28.1
langchain==0.0.316
requirements.txt: 3 changes (3 additions, 0 deletions)
@@ -131,6 +131,8 @@ exceptiongroup==1.2.0
# via
# anyio
# pytest
extra-streamlit-components==0.1.60
# via -r requirements.in
fastapi==0.105.0
# via chromadb
fastjsonschema==2.19.0
@@ -497,6 +499,7 @@ stevedore==5.1.0
streamlit==1.29.0
# via
# -r requirements.in
# extra-streamlit-components
# streamlit-feedback
# trubrics
streamlit-feedback==0.1.2
