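"""DSGPT: a Streamlit app for asking questions about the documents in SOURCE_DIRECTORY.

Rough flow, as implemented below: load PDF, .py, and Word files, split them into chunks,
embed the chunks into a Chroma vector store, and answer user questions with a "stuff"
question-answering chain over the most similar chunks.
"""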
import sys
# Chroma needs a newer sqlite3 than some hosting environments provide; swapping in
# pysqlite3 (the pysqlite3-binary package) is a common workaround for that.
sys.modules['sqlite3'] = __import__('pysqlite3')
import streamlit as st
from langchain.document_loaders import PyPDFLoader, PythonLoader, UnstructuredWordDocumentLoader
# Other loaders considered: CSVLoader, PDFMinerLoader, TextLoader, UnstructuredExcelLoader,
# Docx2txtLoader, UnstructuredFileLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter # CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores.chroma import Chroma
import chromadb
import os
# Local testing: load the key from a local module instead of Streamlit secrets
# from local_secrets.key import OPENAI_API_KEY
# os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
os.environ['OPENAI_API_KEY'] = st.secrets["OPENAI_API_KEY"]
from dirs import SOURCE_DIRECTORY, PERSIST_DIRECTORY
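# SOURCE_DIRECTORY is the folder of documents to index; PERSIST_DIRECTORY is where the
# Chroma index is written. Both are defined in dirs.py.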
def load_chunk_persist_pdf() -> Chroma:
    """Walk SOURCE_DIRECTORY, load supported documents, chunk them, and persist a Chroma vector store."""
    folder_path = SOURCE_DIRECTORY
    documents = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            # Join against root (not folder_path) so files in subdirectories resolve correctly
            file_path = os.path.join(root, file)
            if file.endswith('.pdf'):
                loader = PyPDFLoader(file_path)
                documents.extend(loader.load())
            elif file.endswith('.py'):
                loader = PythonLoader(file_path)
                documents.extend(loader.load())
            elif file.endswith(('.doc', '.docx')):
                loader = UnstructuredWordDocumentLoader(file_path)  # Docx2txtLoader is an alternative for .docx
                documents.extend(loader.load())
            # elif file.endswith('.txt'):
            #     loader = TextLoader(file_path)
            #     documents.extend(loader.load())
            # elif file.endswith('.md'):
            #     loader = UnstructuredMarkdownLoader(file_path)
            #     documents.extend(loader.load())
    # Loader mapping kept for reference:
    # ".txt": TextLoader,
    # ".md": UnstructuredMarkdownLoader,
    # ".py": TextLoader,
    # ".pdf": PDFMinerLoader or UnstructuredFileLoader,
    # ".csv": CSVLoader,
    # ".xls": UnstructuredExcelLoader,
    # ".xlsx": UnstructuredExcelLoader,
    # ".docx": Docx2txtLoader,
    # ".doc": Docx2txtLoader,
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    chunked_documents = text_splitter.split_documents(documents)
    client = chromadb.Client()
    collection_name = "dsgpt_collection_a5"
    # Create the collection only if it does not already exist
    if collection_name not in [c.name for c in client.list_collections()]:
        consent_collection = client.create_collection(collection_name)
    else:
        print("Collection already exists")
    vectordb = Chroma.from_documents(
        documents=chunked_documents,
        embedding=OpenAIEmbeddings(),
        persist_directory=PERSIST_DIRECTORY
    )
    vectordb.persist()
    return vectordb
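# Example usage (hypothetical, outside Streamlit): rebuild the index and sanity-check
# retrieval; assumes SOURCE_DIRECTORY contains at least one supported document.
#   db = load_chunk_persist_pdf()
#   print(db.similarity_search("How does this package load documents?", k=2))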
def create_agent_chain():
    """Build a question-answering chain that stuffs the retrieved documents into the prompt."""
    model_name = "gpt-3.5-turbo"
    llm = ChatOpenAI(model_name=model_name)
    chain = load_qa_chain(llm, chain_type="stuff")
    return chain
def get_llm_response(query):
    """Answer a query using the vector store and QA chain cached in st.session_state."""
    if "vectordb" not in st.session_state:
        st.session_state['vectordb'] = load_chunk_persist_pdf()
    if "chain" not in st.session_state:
        st.session_state['chain'] = create_agent_chain()
    matching_docs = st.session_state['vectordb'].similarity_search(query)
    answer = st.session_state['chain'].run(input_documents=matching_docs, question=query)
    return answer
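# Example call (hypothetical query, for illustration only):
#   get_llm_response("How do I split documents with RecursiveCharacterTextSplitter?")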
########################################################################################
# Load tools once per session; subsequent reruns reuse the objects cached in st.session_state.
if "vectordb" not in st.session_state:
    st.session_state['vectordb'] = load_chunk_persist_pdf()
if "chain" not in st.session_state:
    st.session_state['chain'] = create_agent_chain()
# User Interface
st.set_page_config(page_title="DSGPT", page_icon=":robot:")
st.header("DSGPT: Ask python libraries how they work")
form_input = st.text_input('Enter Question and Click Enter with your Mouse')
submit = st.button("Generate")
if submit:
    st.write(get_llm_response(form_input))