# helper_functions.py
# Forked from NirDiamant/RAG_Techniques.
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain import PromptTemplate
import fitz
from typing import List
from rank_bm25 import BM25Okapi
import asyncio
import random
import textwrap
import numpy as np
def replace_t_with_space(list_of_documents):
    """
    Swap every tab character in each document's text for a single space.

    The documents are modified in place; the same container that was passed
    in is handed back to the caller.

    Args:
        list_of_documents: Iterable of objects exposing a writable
            'page_content' string attribute.

    Returns:
        The object passed as list_of_documents, after its items were updated.
    """
    for document in list_of_documents:
        # Splitting on tabs and re-joining with spaces maps each tab to one space.
        document.page_content = " ".join(document.page_content.split("\t"))
    return list_of_documents
def text_wrap(text, width=120):
    """
    Re-flow text so that no output line is longer than the given width.

    Args:
        text (str): Text to be wrapped.
        width (int): Maximum line length of the result.

    Returns:
        str: The wrapped text, with lines joined by newline characters.
    """
    wrapper = textwrap.TextWrapper(width=width)
    return wrapper.fill(text)
def encode_pdf(path, chunk_size=1000, chunk_overlap=200):
    """
    Build a FAISS vector store from the text of a PDF file.

    Args:
        path: Location of the PDF file to load.
        chunk_size: Maximum chunk length given to the splitter (measured with len).
        chunk_overlap: Overlap between neighbouring chunks.

    Returns:
        A FAISS vector store built from the tab-cleaned chunks using
        OpenAIEmbeddings.
    """
    # Load the PDF into documents
    pages = PyPDFLoader(path).load()

    # Cut the documents into overlapping chunks and replace tabs with spaces
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    chunks = replace_t_with_space(splitter.split_documents(pages))

    # Embed the chunks and index them
    return FAISS.from_documents(chunks, OpenAIEmbeddings())
def encode_from_string(content, chunk_size=1000, chunk_overlap=200):
    """
    Encodes a string into a vector store using OpenAI embeddings.

    Args:
        content (str): The text content to be encoded.
        chunk_size (int): The size of each chunk of text.
        chunk_overlap (int): The overlap between chunks.

    Returns:
        FAISS: A vector store containing the encoded content.

    Raises:
        ValueError: If content is not a non-blank string, chunk_size is not a
            positive integer, or chunk_overlap is not a non-negative integer.
        RuntimeError: If splitting, embedding or indexing fails. The original
            exception is attached as __cause__.
    """
    # Validate arguments up front so bad input is reported as ValueError
    # rather than being wrapped in the RuntimeError below.
    if not isinstance(content, str) or not content.strip():
        raise ValueError("Content must be a non-empty string.")
    if not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if not isinstance(chunk_overlap, int) or chunk_overlap < 0:
        raise ValueError("chunk_overlap must be a non-negative integer.")

    try:
        # Split the content into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False,
        )
        chunks = text_splitter.create_documents([content])

        # Assign metadata to each chunk
        for chunk in chunks:
            chunk.metadata['relevance_score'] = 1.0

        # Generate embeddings and create the vector store
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.from_documents(chunks, embeddings)
    except Exception as e:
        # Chain explicitly so the underlying failure is recorded as the direct cause.
        raise RuntimeError(f"An error occurred during the encoding process: {str(e)}") from e

    return vectorstore
def retrieve_context_per_question(question, chunks_query_retriever):
    """
    Fetch the documents relevant to a question and return their text.

    Args:
        question: The question to look up.
        chunks_query_retriever: Retriever object exposing
            get_relevant_documents(question).

    Returns:
        A list holding the page_content of each retrieved document, in the
        order the retriever returned them.
    """
    matches = chunks_query_retriever.get_relevant_documents(question)

    # Keep the chunks as separate list items rather than one joined string
    context = []
    for match in matches:
        context.append(match.page_content)
    return context
class QuestionAnswerFromContext(BaseModel):
    """
    Model to generate an answer to a query based on a given context.
    Attributes:
    answer_based_on_content (str): The generated answer based on the context.
    """
    # Structured-output schema: create_question_answer_from_context_chain passes
    # this class to llm.with_structured_output(), and answer_question_from_context
    # reads the field below from the chain's result.
    # NOTE(review): the class docstring and the Field description are presumably
    # exported with the schema that is shown to the model - confirm before
    # rewording either of them.
    answer_based_on_content: str = Field(description="Generates an answer to a query based on a given context.")
def create_question_answer_from_context_chain(llm):
    """
    Build a chain that answers a question from supplied context.

    Args:
        llm: Language model object; it must provide with_structured_output().

    Returns:
        The prompt template piped into the model, with the model's output
        structured as QuestionAnswerFromContext.
    """
    # Prompt text with {context} and {question} placeholders
    template_text = (
        "\n"
        "For the question below, provide a concise but suffice answer based ONLY on the provided context:\n"
        "{context}\n"
        "Question\n"
        "{question}\n"
    )
    prompt = PromptTemplate(
        template=template_text,
        input_variables=["context", "question"],
    )

    # Ask the model to return its answer in the QuestionAnswerFromContext shape
    structured_llm = llm.with_structured_output(QuestionAnswerFromContext)
    return prompt | structured_llm
def answer_question_from_context(question, context, question_answer_from_context_chain):
    """
    Run the question-answering chain on a question and its context.

    Args:
        question: The question to be answered.
        context: The context the answer should be drawn from.
        question_answer_from_context_chain: Chain object with an invoke()
            method whose result exposes 'answer_based_on_content'.

    Returns:
        A dictionary with the keys "answer", "context" and "question".
    """
    print("Answering the question from the retrieved context...")
    response = question_answer_from_context_chain.invoke(
        {"question": question, "context": context}
    )
    return {
        "answer": response.answer_based_on_content,
        "context": context,
        "question": question,
    }
def show_context(context):
    """
    Print every context item under a numbered heading.

    Args:
        context (list): The context items to display.

    Each item is printed after a "Context N:" heading (N starts at 1) and is
    followed by blank lines.
    """
    for position, item in enumerate(context, start=1):
        # Heading, item and a trailing "\n" are each written on their own line
        print(f"Context {position}:", item, "\n", sep="\n")
def read_pdf_to_string(path):
    """
    Read a PDF document from the specified path and return its content as a string.

    Args:
        path (str): The file path to the PDF document.

    Returns:
        str: The concatenated text content of all pages in the PDF document.

    The function uses the 'fitz' library (PyMuPDF) to open the document and
    joins the text extracted from each page, in page order.
    """
    # Open the document in a context manager so it is closed even if
    # text extraction raises part-way through.
    with fitz.open(path) as doc:
        # Join once instead of growing a string page by page
        return "".join(page.get_text() for page in doc)
def bm25_retrieval(bm25: BM25Okapi, cleaned_texts: List[str], query: str, k: int = 5) -> List[str]:
    """
    Return the k text chunks that score highest for a query under BM25.

    Args:
        bm25 (BM25Okapi): Pre-computed BM25 index.
        cleaned_texts (List[str]): Text chunks, indexed the same way as the
            scores returned by the BM25 index.
        query (str): The query string; it is tokenized by whitespace.
        k (int): The number of text chunks to retrieve.

    Returns:
        List[str]: Up to k chunks, ordered from highest to lowest BM25 score.
    """
    scores = bm25.get_scores(query.split())

    # argsort is ascending, so flip it to rank the best-scoring chunks first
    ranked = np.argsort(scores)[::-1]

    best = []
    for index in ranked[:k]:
        best.append(cleaned_texts[index])
    return best
async def exponential_backoff(attempt):
    """
    Pause before the next retry, using exponential backoff plus jitter.

    Args:
        attempt: The current retry attempt number.

    The pause lasts 2**attempt seconds plus a random amount drawn from
    random.uniform(0, 1); the chosen duration is printed before sleeping.
    """
    base_delay = 2 ** attempt
    jitter = random.uniform(0, 1)
    delay = base_delay + jitter
    print(f"Rate limit hit. Retrying in {delay:.2f} seconds...")
    # Sleep without blocking the event loop
    await asyncio.sleep(delay)
async def retry_with_exponential_backoff(coroutine, max_retries=5):
    """
    Run an awaitable, backing off exponentially when RateLimitError is raised.

    Args:
        coroutine: Either a zero-argument callable that returns a fresh
            awaitable on every call (this form is retried), or an
            already-created awaitable. A coroutine object can only be awaited
            once, so when one is passed directly it is attempted a single time
            and a RateLimitError from it is re-raised without retrying.
        max_retries: The maximum number of attempts.

    Returns:
        The result of the awaited operation if an attempt succeeds.

    Raises:
        RateLimitError: If the last possible attempt is rate limited.
        Exception: "Max retries reached" if no attempt was made
            (max_retries is not positive).
    """
    # NOTE(review): RateLimitError is not imported in the visible import block;
    # confirm it is in scope, otherwise the except clause raises NameError.
    factory = coroutine if callable(coroutine) else None
    # Awaiting the same coroutine object twice raises
    # "RuntimeError: cannot reuse already awaited coroutine".
    single_use = factory is None and asyncio.iscoroutine(coroutine)

    for attempt in range(max_retries):
        try:
            # Create a fresh awaitable per attempt when a factory was supplied
            awaitable = factory() if factory is not None else coroutine
            return await awaitable
        except RateLimitError:
            # Give up on the last attempt, or when the awaitable cannot be reused
            if single_use or attempt == max_retries - 1:
                raise
            # Wait for an exponential backoff period before retrying
            await exponential_backoff(attempt)

    # Only reached when the loop body never ran
    raise Exception("Max retries reached")