# llm_model.py

import os
from datetime import datetime
import streamlit as st
from langchain.prompts import PromptTemplate
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter
from langchain_chroma import Chroma

CURRENT_DATE = datetime.now().strftime("%d %B %Y")

# The Google API key is read from Streamlit secrets (.streamlit/secrets.toml).
os.environ["GOOGLE_API_KEY"] = st.secrets["API_KEY"]

DOCS = "Documents/FAQ.txt"
PERSIST_DIRECTORY = "Documents/embedding_db"
EMBEDDINGS_MODEL = "models/text-embedding-004"
MODEL_CONFIG = {
    "model": "gemini-2.0-flash-exp",
    "temperature": 0,
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 1000,
}
class QASplitter(TextSplitter):
    """Split FAQ text into one chunk per question, keyed on the "Q:" prefix."""

    def split_text(self, text):
        return [
            # Join wrapped lines with a space so words don't run together.
            "Q: " + split.replace("\n", " ").strip()
            for split in text.split("Q:")
            if split.strip()
        ]
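# A worked example of the split (hypothetical FAQ lines; the real contents of
# Documents/FAQ.txt are not shown here):
#
#   QASplitter().split_text("Q: What is the Savings Pot?\nOne-third.\nQ: And the rest?\nTwo-thirds.")
#   -> ["Q: What is the Savings Pot? One-third.", "Q: And the rest? Two-thirds."]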
def load_and_split_document(file_path):
    """Load the FAQ file and split it into one chunk per Q&A pair; chunks
    longer than 2500 characters are split further to keep them embeddable."""
    loader = TextLoader(file_path)
    docs = loader.load()
    qa_splitter = QASplitter()
    qa_splits = qa_splitter.split_documents(docs)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2500, chunk_overlap=200)
    final_splits = []
    for qa_split in qa_splits:
        if len(qa_split.page_content) > 2500:
            final_splits.extend(text_splitter.split_documents([qa_split]))
        else:
            final_splits.append(qa_split)
    return final_splits
def get_vectorstore():
    """Return a Chroma vectorstore, loading the persisted one if it exists
    and building it from the FAQ document otherwise."""
    embedding = GoogleGenerativeAIEmbeddings(model=EMBEDDINGS_MODEL)
    if os.path.exists(PERSIST_DIRECTORY):
        print("Loading existing vectorstore...")
        return Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedding)
    print("Creating new vectorstore...")
    documents = load_and_split_document(DOCS)
    return Chroma.from_documents(
        documents=documents, embedding=embedding, persist_directory=PERSIST_DIRECTORY
    )
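# The store is only built once; to re-embed after editing FAQ.txt, delete the
# persisted directory first (a minimal sketch, not part of the app flow):
#
#   import shutil
#   shutil.rmtree(PERSIST_DIRECTORY, ignore_errors=True)
#   vectorstore = get_vectorstore()  # recreates and re-embeds the documents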
template = """
You are 2PotGPT, a humorous expert on the Two-Pot System.
The Two-Pot Retirement System, effective in South Africa from 1 September 2024, divides contributions into a "Savings Pot" (one-third, accessible before retirement)
and a "Retirement Pot" (two-thirds, preserved for retirement income), promoting financial security and long-term savings.
Highlight the required and important parts with Markdown (e.g. bold, bullet points, tables).
**Use the Context, Chat History, and Current Date as your knowledge base**; if the user's question is out of context, ask for clarification.
**Carefully analyze the user's question step by step, and provide a thoughtful, accurate, and well-reasoned response.**
<Current Date>
{current_date}
</Current Date>
<Chat History>
{chat_history}
</Chat History>
<Context (FAQ)>
{context}
</Context (FAQ)>
<User Question>
{question}
</User Question>
**Concise Helpful Answer:**
"""

# Fill in the date up front; chat_history, context, and question are supplied
# per call by the chain.
prompt_template = PromptTemplate.from_template(template).partial(current_date=CURRENT_DATE)
def initialize_conversation_chain():
    """Wire the Gemini chat model, conversation memory, and FAQ retriever
    into a ConversationalRetrievalChain."""
    llm = ChatGoogleGenerativeAI(**MODEL_CONFIG)
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        return_messages=True,
    )
    vectorstore = get_vectorstore()
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
        combine_docs_chain_kwargs={"prompt": prompt_template},
    )
    return conversation_chain
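
# Minimal usage sketch (assumes st.secrets["API_KEY"] is configured, e.g. in
# .streamlit/secrets.toml; the question below is illustrative only):
if __name__ == "__main__":
    chain = initialize_conversation_chain()
    result = chain.invoke({"question": "How much of my contribution goes into the Savings Pot?"})
    print(result["answer"])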