# -*- coding: utf-8 -*-
"""Advanced_Fusion_Retriever.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1b_bn2k34oPBFXCNvl-1V7E7G6z4KKREq
"""
# Install dependencies (Colab): LlamaIndex core plus the file-reader, Gemini
# LLM, BM25 retriever, and HuggingFace embedding integrations, and PyMuPDF.
!pip install llama-index llama-index-readers-file llama-index-llms-gemini llama-index-retrievers-bm25 llama-index-embeddings-huggingface pymupdf

import nest_asyncio

# Patch the event loop so async retrieval calls can run inside the notebook.
nest_asyncio.apply()

# Download the Llama 2 paper as the test corpus.
!mkdir data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
from llama_index.readers.file import PyMuPDFReader

# Parse the PDF; PyMuPDFReader returns one Document per page.
loader = PyMuPDFReader()
documents = loader.load(file_path="./data/llama2.pdf")
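# Quick check that the download and parse worked:
print(f"Loaded {len(documents)} page-level documents from data/llama2.pdf")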
import os

# Use a placeholder; never hard-code a real API key in a shared notebook.
GOOGLE_API_KEY = "YOUR_GOOGLE_API_KEY"
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings

# Use a local HuggingFace embedding model (bge-small-en-v1.5) instead of the
# default OpenAI embeddings.
Settings.embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en-v1.5"
)
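# Optional sanity check (illustrative): embed a short string and confirm the
# vector has the 384 dimensions this model produces.
vec = Settings.embed_model.get_text_embedding("hello world")
print(len(vec))  # expect 384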
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.gemini import Gemini

llm = Gemini()
# Make Gemini the default LLM so downstream query engines also use it for
# response synthesis (otherwise LlamaIndex falls back to OpenAI).
Settings.llm = llm
embed_model = Settings.embed_model
# Split pages into ~1024-token chunks and build the vector index over them.
splitter = SentenceSplitter(chunk_size=1024)
index = VectorStoreIndex.from_documents(
    documents, transformations=[splitter], embed_model=embed_model
)
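# Optional (illustrative): persist the index so a later session can reload it
# instead of re-embedding the whole paper.
# index.storage_context.persist(persist_dir="./storage")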
from llama_index.core import PromptTemplate
query_str = "How do the models developed in this work compare to open-source chat models based on the benchmarks tested?"
query_gen_prompt_str = (
    "You are a helpful assistant that generates multiple search queries based on a "
    "single input query. Generate {num_queries} search queries, one on each line, "
    "related to the following input query:\n"
    "Query: {query}\n"
    "Queries:\n"
)
query_gen_prompt = PromptTemplate(query_gen_prompt_str)
def generate_queries(llm, query_str: str, num_queries: int = 4):
    """Ask the LLM to rewrite the input query into several search queries."""
    # Request num_queries - 1 rewrites of the original query.
    fmt_prompt = query_gen_prompt.format(
        num_queries=num_queries - 1, query=query_str
    )
    response = llm.complete(fmt_prompt)
    # One query per line; blank lines are filtered downstream in run_queries.
    queries = response.text.split("\n")
    return queries
queries = generate_queries(llm, query_str, num_queries=4)
print(queries)
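# The LLM typically returns one query per line, sometimes with numbering or
# blank lines; run_queries below strips the blank lines before retrieval.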
"""## Vector Search for each Query"""
from tqdm.asyncio import tqdm
from typing import List


async def run_queries(queries: List[str], retrievers: List) -> dict:
    """Run each query against each retriever concurrently."""
    # Filter out empty or whitespace-only queries (the LLM sometimes emits
    # blank lines), which would otherwise trigger useless retrieval calls.
    filtered_queries = [query for query in queries if query.strip()]

    # One task per (query, retriever) pair, so every result can be mapped
    # back to the query that produced it.
    pairs = [
        (query, retriever)
        for query in filtered_queries
        for retriever in retrievers
    ]
    tasks = [retriever.aretrieve(query) for query, retriever in pairs]
    task_results = await tqdm.gather(*tasks)

    # Key results by (query, task index); the index keeps results from
    # different retrievers for the same query distinct. Zipping against the
    # pairs (not the bare query list) avoids mislabeling or dropping the
    # second retriever's results.
    results_dict = {}
    for i, ((query, _), query_result) in enumerate(zip(pairs, task_results)):
        results_dict[(query, i)] = query_result
    return results_dict
# Get retrievers: dense (vector) and sparse (BM25) over the same nodes.
from llama_index.retrievers.bm25 import BM25Retriever

## vector retriever
vector_retriever = index.as_retriever(similarity_top_k=2)

## bm25 retriever (keyword-based, built from the index's docstore nodes)
bm25_retriever = BM25Retriever.from_defaults(
    docstore=index.docstore, similarity_top_k=2
)
print(len(index.docstore.docs), "nodes in the docstore")
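# Optional, illustrative comparison: query each retriever directly with the
# original question to see how their top hits differ before fusion.
for name, retriever in [("vector", vector_retriever), ("bm25", bm25_retriever)]:
    for n in retriever.retrieve(query_str):
        print(name, round(n.score or 0.0, 3), n.node.get_content()[:80])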
# Top-level await works here because nest_asyncio was applied above.
results_dict = await run_queries(queries, [vector_retriever, bm25_retriever])
"""## Perform Fusion"""
from typing import List
from llama_index.core.schema import NodeWithScore


def fuse_results(results_dict, similarity_top_k: int = 2):
    """Fuse retrieval results with reciprocal rank fusion (RRF)."""
    # `k` dampens the impact of outlier rankings: larger values flatten the
    # gap between top- and bottom-ranked results. 60 is the value used in
    # the original RRF paper.
    k = 60.0
    fused_scores = {}
    text_to_node = {}

    # Compute reciprocal rank scores: each node earns 1 / (rank + k) from
    # every result list it appears in, summed across lists.
    for nodes_with_scores in results_dict.values():
        for rank, node_with_score in enumerate(
            sorted(
                nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True
            )
        ):
            text = node_with_score.node.get_content()
            text_to_node[text] = node_with_score
            if text not in fused_scores:
                fused_scores[text] = 0.0
            fused_scores[text] += 1.0 / (rank + k)

    # Sort nodes by fused score, descending.
    reranked_results = dict(
        sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    )

    # Replace each node's retriever score with its fused score.
    reranked_nodes: List[NodeWithScore] = []
    for text, score in reranked_results.items():
        reranked_nodes.append(text_to_node[text])
        reranked_nodes[-1].score = score

    return reranked_nodes[:similarity_top_k]
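# A minimal sanity check of the fusion logic with toy nodes (hypothetical
# texts and scores, not from the paper). "B" appears in both result lists,
# so it should outrank nodes seen only once:
# 1/(60+1) + 1/(60+0) ≈ 0.0331 vs. 1/(60+0) ≈ 0.0167 for "A".
from llama_index.core.schema import TextNode

toy_results = {
    ("q1", 0): [
        NodeWithScore(node=TextNode(text="A"), score=0.9),
        NodeWithScore(node=TextNode(text="B"), score=0.5),
    ],
    ("q2", 1): [
        NodeWithScore(node=TextNode(text="B"), score=0.8),
        NodeWithScore(node=TextNode(text="C"), score=0.4),
    ],
}
for n in fuse_results(toy_results, similarity_top_k=3):
    print(round(n.score, 4), n.node.get_content())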
final_results = fuse_results(results_dict)
for n in final_results:
    print(n.score, "\n", n.text, "\n********\n")
"""## Plug into RetrieverQueryEngine"""
from typing import List
from llama_index.core import QueryBundle
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
import asyncio


class FusionRetriever(BaseRetriever):
    """Ensemble retriever with fusion."""

    def __init__(
        self,
        llm,
        retrievers: List[BaseRetriever],
        similarity_top_k: int = 2,
    ) -> None:
        """Init params."""
        self._retrievers = retrievers
        self._similarity_top_k = similarity_top_k
        self._llm = llm
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve: generate sub-queries, run them, and fuse the results."""
        queries = generate_queries(
            self._llm, query_bundle.query_str, num_queries=4
        )
        # asyncio.run inside the notebook's running loop is safe because
        # nest_asyncio.apply() was called at the top of the notebook.
        results = asyncio.run(run_queries(queries, self._retrievers))
        final_results = fuse_results(
            results, similarity_top_k=self._similarity_top_k
        )
        return final_results
from llama_index.core.query_engine import RetrieverQueryEngine

fusion_retriever = FusionRetriever(
    llm, [vector_retriever, bm25_retriever], similarity_top_k=2
)
query_engine = RetrieverQueryEngine(fusion_retriever)
response = query_engine.query(query_str)
print(response)
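# Optionally, inspect which fused nodes grounded the answer; the query engine
# exposes the retriever's output on the response object.
for n in response.source_nodes:
    print(round(n.score, 4), n.node.get_content()[:100], "...")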