# hierarchical_indices.py (forked from NirDiamant/RAG_Techniques)
import asyncio
import os
import sys

from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.docstore.document import Document

# Add the parent directory to the path so the shared helpers can be imported
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from helper_functions import *
from evaluation.evalute_rag import *

# Load environment variables from a .env file
load_dotenv()

# Set the OpenAI API key environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
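
# Note: `retry_with_exponential_backoff` is provided by the wildcard import from
# helper_functions above; its exact signature is an assumption of this file (it is
# awaited with a single coroutine). As a general pattern, a retry helper needs a
# coroutine *factory* so every attempt gets a fresh coroutine; a minimal sketch of
# the idea (hypothetical name, not used below):
#
#   async def retry_async(make_coro, max_retries=5, base_delay=1.0):
#       for attempt in range(max_retries):
#           try:
#               return await make_coro()
#           except Exception:
#               if attempt == max_retries - 1:
#                   raise
#               await asyncio.sleep(base_delay * 2 ** attempt)  # 1s, 2s, 4s, ...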
# Encode a document at both the summary and chunk level, sharing the page metadata
async def encode_pdf_hierarchical(path, chunk_size=1000, chunk_overlap=200, is_string=False):
    """
    Asynchronously encodes a PDF book into a hierarchical vector store using OpenAI embeddings.
    Includes rate limit handling with exponential backoff.

    Args:
        path: Path to the PDF file, or the raw text itself when is_string=True.
        chunk_size: Desired size of each detailed text chunk.
        chunk_overlap: Overlap between consecutive chunks.
        is_string: Whether `path` is raw text rather than a file path.

    Returns:
        A tuple of (summary_vectorstore, detailed_vectorstore), both FAISS stores.
    """
    # Load the document: one Document per PDF page, or split the raw string directly
    if not is_string:
        loader = PyPDFLoader(path)
        documents = await asyncio.to_thread(loader.load)
    else:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len, is_separator_regex=False
        )
        documents = text_splitter.create_documents([path])

    # Summarize each page with a map-reduce chain
    summary_llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini", max_tokens=4000)
    summary_chain = load_summarize_chain(summary_llm, chain_type="map_reduce")

    async def summarize_doc(doc):
        """Summarize a single document, retrying on rate-limit errors."""
        summary_output = await retry_with_exponential_backoff(summary_chain.ainvoke([doc]))
        summary = summary_output['output_text']
        return Document(
            page_content=summary,
            # Raw-string input carries no page metadata, so default to page 0
            metadata={"source": path, "page": doc.metadata.get("page", 0), "summary": True},
        )
    # Summarize in small batches with a pause between them to ease rate limits
    summaries = []
    batch_size = 5
    for i in range(0, len(documents), batch_size):
        batch = documents[i:i + batch_size]
        batch_summaries = await asyncio.gather(*[summarize_doc(doc) for doc in batch])
        summaries.extend(batch_summaries)
        await asyncio.sleep(1)

    # Split the full documents into detailed chunks, tagging each with its page
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    detailed_chunks = await asyncio.to_thread(text_splitter.split_documents, documents)
    for i, chunk in enumerate(detailed_chunks):
        chunk.metadata.update({"chunk_id": i, "summary": False, "page": int(chunk.metadata.get("page", 0))})

    # Build the two FAISS stores concurrently, retrying on rate-limit errors
    embeddings = OpenAIEmbeddings()

    async def create_vectorstore(docs):
        return await retry_with_exponential_backoff(asyncio.to_thread(FAISS.from_documents, docs, embeddings))

    summary_vectorstore, detailed_vectorstore = await asyncio.gather(
        create_vectorstore(summaries),
        create_vectorstore(detailed_chunks)
    )
    return summary_vectorstore, detailed_vectorstore
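
# Example usage (a sketch; assumes OPENAI_API_KEY is set and the PDF exists):
#
#   summary_store, detailed_store = asyncio.run(
#       encode_pdf_hierarchical("../data/Understanding_Climate_Change.pdf")
#   )
#
# Raw text can be indexed the same way by passing is_string=True.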
def retrieve_hierarchical(query, summary_vectorstore, detailed_vectorstore, k_summaries=3, k_chunks=5):
    """
    Performs hierarchical retrieval: first finds the most relevant page summaries,
    then searches for detailed chunks only within the pages those summaries cover.
    """
    top_summaries = summary_vectorstore.similarity_search(query, k=k_summaries)

    relevant_chunks = []
    for summary in top_summaries:
        # Restrict the detailed search to chunks from the same page as the summary
        page_number = summary.metadata["page"]
        page_filter = lambda metadata: metadata["page"] == page_number
        page_chunks = detailed_vectorstore.similarity_search(query, k=k_chunks, filter=page_filter)
        relevant_chunks.extend(page_chunks)
    return relevant_chunks
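
# Note: recent versions of LangChain's FAISS store also accept a plain dict for
# exact metadata matches, so the callable filter above could likely be written as:
#
#   detailed_vectorstore.similarity_search(query, k=k_chunks, filter={"page": page_number})
#
# The callable form is kept because it generalizes to arbitrary predicates.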
class HierarchicalRAG:
    def __init__(self, pdf_path, chunk_size=1000, chunk_overlap=200):
        self.pdf_path = pdf_path
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.summary_store = None
        self.detailed_store = None

    async def run(self, query):
        # Reuse persisted vector stores when available; otherwise build and save them
        if os.path.exists("../vector_stores/summary_store") and os.path.exists("../vector_stores/detailed_store"):
            embeddings = OpenAIEmbeddings()
            self.summary_store = FAISS.load_local("../vector_stores/summary_store", embeddings, allow_dangerous_deserialization=True)
            self.detailed_store = FAISS.load_local("../vector_stores/detailed_store", embeddings, allow_dangerous_deserialization=True)
        else:
            self.summary_store, self.detailed_store = await encode_pdf_hierarchical(self.pdf_path, self.chunk_size, self.chunk_overlap)
            self.summary_store.save_local("../vector_stores/summary_store")
            self.detailed_store.save_local("../vector_stores/detailed_store")

        results = retrieve_hierarchical(query, self.summary_store, self.detailed_store)
        for chunk in results:
            print(f"Page: {chunk.metadata['page']}")
            print(f"Content: {chunk.page_content}...")
            print("---")
def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description="Run Hierarchical RAG on a given PDF.")
    parser.add_argument("--pdf_path", type=str, default="../data/Understanding_Climate_Change.pdf", help="Path to the PDF document.")
    parser.add_argument("--chunk_size", type=int, default=1000, help="Size of each text chunk.")
    parser.add_argument("--chunk_overlap", type=int, default=200, help="Overlap between consecutive chunks.")
    parser.add_argument("--query", type=str, default='What is the greenhouse effect',
                        help="Query to search in the document.")
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
    rag = HierarchicalRAG(args.pdf_path, args.chunk_size, args.chunk_overlap)
    asyncio.run(rag.run(args.query))
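
# Example invocation (run from this file's directory so the relative defaults resolve):
#
#   python hierarchical_indices.py --query "What is the greenhouse effect?"
#
# The first run builds and saves the vector stores under ../vector_stores/;
# later runs load and reuse them.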